import time
import asyncio
from pathlib import Path

from ultralytics import YOLO

from utils.websocket_server import room_manager

# Websocket room-name prefixes; the suffix is a directory name or detect id.
_TRAIN_ROOM_PREFIX = 'train_'
_DETECT_ROOM_PREFIX = 'detect_'
_RTSP_ROOM_PREFIX = 'detect_rtsp_'

# Default cap on how long predict_rtsp keeps streaming, in seconds (10 minutes).
_RTSP_MAX_DURATION = 600


def _room_name(save_dir, prefix, levels_up):
    """Return ``prefix`` + the name of the directory ``levels_up`` above ``save_dir``.

    Repeated ``Path.parent`` is used (rather than ``Path.parents[i]``) so a
    short relative path yields the name of ``'.'`` instead of raising
    ``IndexError``, matching the original ``p.parent.parent.name`` behavior.
    """
    path = Path(save_dir)
    for _ in range(levels_up):
        path = path.parent
    return prefix + path.name


def _train_room(trainer):
    """Room for a training run, keyed by the parent directory of ``save_dir``."""
    # NOTE(review): presumably save_dir is "<project>/<name>", which makes this
    # the project directory's name -- confirm against the code that builds it.
    return _room_name(trainer.save_dir, _TRAIN_ROOM_PREFIX, 1)


def _detect_room(predictor):
    """Room for a folder prediction, keyed by the grandparent of ``save_dir``."""
    return _room_name(predictor.save_dir, _DETECT_ROOM_PREFIX, 2)


def _notify(room_name, message):
    """Synchronously deliver ``message`` to ``room_name`` through ``room_manager``.

    NOTE(review): ``asyncio.run`` starts a fresh event loop on every call and
    raises ``RuntimeError`` if the calling thread is already running a loop,
    so these callbacks assume training/prediction runs outside the server's
    event-loop thread -- confirm. Exceptions raised by the send are not
    caught here and propagate into the caller.
    """
    asyncio.run(room_manager.send_to_room(room_name, message))


def on_train_start(trainer):
    """Training-start callback: send ``'start'`` to the training room."""
    _notify(_train_room(trainer), 'start')


def on_train_end(trainer):
    """Training-end callback: send ``'end'`` to the training room."""
    _notify(_train_room(trainer), 'end')


def on_train_epoch_end(trainer):
    """Per-epoch callback: send the trainer's current ``epoch`` value."""
    _notify(_train_room(trainer), trainer.epoch)


def on_predict_start(predictor):
    """Prediction-start callback: send ``'start'`` to the detection room."""
    _notify(_detect_room(predictor), 'start')


def on_predict_batch_end(predictor):
    """Per-batch prediction callback: send ``predictor.seen`` as progress."""
    _notify(_detect_room(predictor), predictor.seen)


def on_predict_end(predictor):
    """Prediction-end callback: send ``'end'`` to the detection room."""
    _notify(_detect_room(predictor), 'end')


class YoloModel:
    """Wrapper around an Ultralytics ``YOLO`` model that reports progress
    to websocket rooms via ``room_manager``."""

    def __init__(self, pt_url=None):
        """Load weights from ``pt_url``, or ``'yolo11n.pt'`` when it is falsy."""
        self.model = YOLO(pt_url if pt_url else 'yolo11n.pt')

    def _register_callback(self, event, func):
        """Attach ``func`` to ``event`` unless it is already attached.

        ``add_callback`` appends unconditionally, so calling ``train`` or
        ``predict_folder`` more than once on the same instance would otherwise
        register (and fire) every notification callback multiple times.
        """
        if func not in self.model.callbacks.get(event, ()):
            self.model.add_callback(event, func)

    def train(self, data, epochs, project, name, patience):
        """Train the model, reporting start / per-epoch / end to the train room.

        Args:
            data: dataset config passed to ``YOLO.train``.
            epochs: number of training epochs.
            project: output project directory.
            name: run name inside ``project``.
            patience: early-stopping patience passed to ``YOLO.train``.
        """
        self._register_callback('on_train_start', on_train_start)
        self._register_callback('on_train_end', on_train_end)
        self._register_callback('on_train_epoch_end', on_train_epoch_end)
        self.model.train(
            data=data,
            epochs=epochs,
            imgsz=640,
            device=0,
            project=project,
            name=name,
            patience=patience,
            verbose=False,
        )

    def predict_folder(self, source, name, project):
        """Predict on the contents (images or videos) of a folder.

        Annotated outputs and label text files are saved under
        ``project/name``; progress is reported to the detection room.
        """
        self._register_callback('on_predict_start', on_predict_start)
        self._register_callback('on_predict_batch_end', on_predict_batch_end)
        self._register_callback('on_predict_end', on_predict_end)
        self.model.predict(
            source=source,
            name=name,
            project=project,
            save=True,
            save_txt=True,
            device='0',
            conf=0.6,
        )

    def predict_rtsp(self, source, detect_id, max_duration=_RTSP_MAX_DURATION):
        """Run streaming prediction on an RTSP source and push annotated frames.

        Args:
            source: RTSP stream URL (any source accepted by ``YOLO.predict``).
            detect_id: identifier appended to the ``'detect_rtsp_'`` room name.
            max_duration: seconds after which streaming stops (default 600,
                i.e. 10 minutes).

        NOTE(review): predict callbacks registered by an earlier
        ``predict_folder`` call on this instance remain attached and will also
        fire during streaming -- confirm that is intended.
        """
        room_name = _RTSP_ROOM_PREFIX + str(detect_id)
        # Monotonic clock: immune to wall-clock adjustments during the stream.
        start_time = time.monotonic()
        results = self.model.predict(source=source, stream=True, device='0', conf=0.6)
        try:
            for result in results:
                if time.monotonic() - start_time > max_duration:
                    break
                # NOTE(review): raw pixel bytes carry no width/height/format;
                # the receiver must already know the frame geometry -- confirm.
                frame_data = result.plot().tobytes()
                # NOTE(review): asyncio.run per frame creates a new event loop
                # every time; consider a single long-lived loop if this is hot.
                asyncio.run(room_manager.send_stream_to_room(room_name, frame_data))
        finally:
            # Close the stream generator now rather than leaving its cleanup
            # to garbage collection.
            results.close()