115 lines
3.4 KiB
Python
115 lines
3.4 KiB
Python
import cv2
|
||
import time
|
||
import asyncio
|
||
from pathlib import Path
|
||
from ultralytics import YOLO
|
||
|
||
from utils.websocket_server import room_manager
|
||
|
||
|
||
# 开始训练回调
|
||
def on_train_start(trainer):
    """Announce the start of a training run to its websocket room.

    The room is named ``'train_'`` followed by the name of the directory
    that contains ``trainer.save_dir``. The literal message ``'start'`` is
    sent to that room through ``room_manager.send_to_room``.

    Args:
        trainer: Object handed to the callback; only its ``save_dir``
            attribute is read here.
    """
    run_dir = Path(trainer.save_dir)
    target_room = f'train_{run_dir.parent.name}'
    # This function is synchronous, so asyncio.run drives the coroutine to
    # completion on a fresh event loop before returning.
    asyncio.run(room_manager.send_to_room(target_room, 'start'))
|
||
|
||
|
||
# 结束训练回调
|
||
def on_train_end(trainer):
    """Announce the end of a training run to its websocket room.

    The room is named ``'train_'`` followed by the name of the directory
    that contains ``trainer.save_dir``. The literal message ``'end'`` is
    sent to that room through ``room_manager.send_to_room``.

    Args:
        trainer: Object handed to the callback; only its ``save_dir``
            attribute is read here.
    """
    parent_dir = Path(trainer.save_dir).parent
    target_room = 'train_' + parent_dir.name
    notification = room_manager.send_to_room(target_room, 'end')
    # Run the coroutine to completion from this synchronous function.
    asyncio.run(notification)
|
||
|
||
|
||
# 每轮训练结束回调函数
|
||
def on_train_epoch_end(trainer):
    """Report the current epoch to the training run's websocket room.

    The room is named ``'train_'`` followed by the name of the directory
    that contains ``trainer.save_dir``. The value of ``trainer.epoch`` is
    sent as the message through ``room_manager.send_to_room``.

    Args:
        trainer: Object handed to the callback; its ``save_dir`` and
            ``epoch`` attributes are read here.
    """
    save_location = Path(trainer.save_dir)
    target_room = f'train_{save_location.parent.name}'
    current_epoch = trainer.epoch
    # Run the coroutine to completion from this synchronous function.
    asyncio.run(room_manager.send_to_room(target_room, current_epoch))
|
||
|
||
|
||
# 预测开始回调
|
||
def on_predict_start(predictor):
    """Announce the start of a prediction job to its websocket room.

    The room is named ``'detect_'`` followed by the name of the directory
    two levels above ``predictor.save_dir``. The literal message
    ``'start'`` is sent to that room through ``room_manager.send_to_room``.

    Args:
        predictor: Object handed to the callback; only its ``save_dir``
            attribute is read here.
    """
    output_dir = Path(predictor.save_dir)
    grandparent = output_dir.parent.parent
    target_room = f'detect_{grandparent.name}'
    # Run the coroutine to completion from this synchronous function.
    asyncio.run(room_manager.send_to_room(target_room, 'start'))
|
||
|
||
|
||
# 每个批次预测结束回调
|
||
def on_predict_batch_end(predictor):
    """Report prediction progress to the job's websocket room.

    The room is named ``'detect_'`` followed by the name of the directory
    two levels above ``predictor.save_dir``. The value of
    ``predictor.seen`` is sent as the message through
    ``room_manager.send_to_room``.

    Args:
        predictor: Object handed to the callback; its ``save_dir`` and
            ``seen`` attributes are read here.
    """
    grandparent = Path(predictor.save_dir).parent.parent
    target_room = 'detect_' + grandparent.name
    # NOTE(review): ``seen`` is presumably the running count of processed
    # items -- confirm against the predictor implementation.
    progress = predictor.seen
    asyncio.run(room_manager.send_to_room(target_room, progress))
|
||
|
||
|
||
# 预测结束回调
|
||
def on_predict_end(predictor):
    """Announce the end of a prediction job to its websocket room.

    The room is named ``'detect_'`` followed by the name of the directory
    two levels above ``predictor.save_dir``. The literal message ``'end'``
    is sent to that room through ``room_manager.send_to_room``.

    Args:
        predictor: Object handed to the callback; only its ``save_dir``
            attribute is read here.
    """
    output_dir = Path(predictor.save_dir)
    job_dir = output_dir.parent.parent
    target_room = f'detect_{job_dir.name}'
    notification = room_manager.send_to_room(target_room, 'end')
    # Run the coroutine to completion from this synchronous function.
    asyncio.run(notification)
|
||
|
||
|
||
class YoloModel:
    """Wrapper around a ``YOLO`` model that reports progress to websocket rooms.

    ``train`` and ``predict_folder`` attach the module-level callback
    functions to the model so that start/progress/end messages are sent to
    the matching room. ``predict_rtsp`` sends annotated JPEG frames to a
    room directly.
    """

    def __init__(self, pt_url=None):
        """Load the model weights.

        Args:
            pt_url: Weights passed to ``YOLO``. Any falsy value (``None``,
                ``''``) selects the default ``'yolo11n.pt'``.
        """
        self.model = YOLO(pt_url if pt_url else 'yolo11n.pt')
        # Events whose callback has already been attached to ``self.model``.
        self._registered_events = set()

    def _register_callbacks(self, callbacks):
        """Attach each callback to the model at most once per instance.

        ``add_callback`` appends to the model's callback list. Calling it
        on every ``train``/``predict_folder`` invocation would therefore
        make a reused instance send every message once more per earlier
        call.

        Args:
            callbacks: Mapping of event name to callback function, attached
                in iteration order.
        """
        # Fall back to a fresh set for instances created without __init__.
        registered = getattr(self, '_registered_events', None)
        if registered is None:
            registered = self._registered_events = set()
        for event, func in callbacks.items():
            if event in registered:
                continue
            self.model.add_callback(event, func)
            registered.add(event)

    def train(self, data, epochs, project, name, patience):
        """Train the model and report progress to the training room.

        The image size (640), device (0) and ``verbose=False`` are fixed.

        Args:
            data: Forwarded to ``model.train`` as ``data``.
            epochs: Forwarded as ``epochs``.
            project: Forwarded as ``project``.
            name: Forwarded as ``name``.
            patience: Forwarded as ``patience``.
        """
        self._register_callbacks({
            'on_train_start': on_train_start,
            'on_train_end': on_train_end,
            'on_train_epoch_end': on_train_epoch_end,
        })
        self.model.train(
            data=data,
            epochs=epochs,
            imgsz=640,
            device=0,
            project=project,
            name=name,
            patience=patience,
            verbose=False,
        )

    def predict_folder(self, source, name, project):
        """Run prediction on the contents of a folder (images or videos).

        Results are saved (``save=True``, ``save_txt=True``) and progress is
        reported to the detection room through the registered callbacks.
        The device (``'0'``) and confidence threshold (0.6) are fixed.

        Args:
            source: Forwarded to ``model.predict`` as ``source``.
            name: Forwarded as ``name``.
            project: Forwarded as ``project``.
        """
        self._register_callbacks({
            'on_predict_start': on_predict_start,
            'on_predict_batch_end': on_predict_batch_end,
            'on_predict_end': on_predict_end,
        })
        self.model.predict(
            source=source,
            name=name,
            project=project,
            save=True,
            save_txt=True,
            device='0',
            conf=0.6,
        )

    def predict_rtsp(self, source, room_name, max_seconds=600):
        """Run prediction on an RTSP stream and push annotated frames to a room.

        Each result is plotted, JPEG-encoded and sent to ``room_name``
        through ``room_manager.send_stream_to_room``. Frames that fail to
        encode are skipped.

        The time limit is only checked when a new result arrives, so a
        stalled stream is not interrupted by it.

        Args:
            source: Stream source forwarded to ``model.predict``.
            room_name: Room that receives the encoded frames.
            max_seconds: Maximum streaming duration in seconds. Defaults to
                600 (10 minutes).
        """
        # A monotonic clock keeps the limit immune to wall-clock adjustments.
        started_at = time.monotonic()
        results = self.model.predict(
            source=source, stream=True, device='0', conf=0.6
        )
        for result in results:
            if time.monotonic() - started_at > max_seconds:
                break
            frame = result.plot()
            ok, jpeg = cv2.imencode('.jpg', frame)
            if not ok:
                continue
            asyncio.run(
                room_manager.send_stream_to_room(room_name, jpeg.tobytes())
            )