Files
aicheckv2-api/algo/yolov11.py
2025-06-11 16:19:38 +08:00

115 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import time
import asyncio
from pathlib import Path
from ultralytics import YOLO
from utils.websocket_server import room_manager
def on_train_start(trainer):
    """Ultralytics ``on_train_start`` hook: send ``'start'`` to the training room.

    The room is named ``'train_'`` followed by the name of the parent directory
    of ``trainer.save_dir``.
    """
    # NOTE(review): asyncio.run() creates a fresh event loop per call and raises
    # RuntimeError if this thread already has a running loop; presumably training
    # runs off the server's event-loop thread -- confirm, and confirm that
    # room_manager's connections may be used from a separate loop.
    room_name = 'train_' + Path(trainer.save_dir).parent.name
    asyncio.run(room_manager.send_to_room(room_name, 'start'))
def on_train_end(trainer):
    """Ultralytics ``on_train_end`` hook: send ``'end'`` to the training room.

    The room is named ``'train_'`` followed by the name of the parent directory
    of ``trainer.save_dir``.
    """
    run_parent = Path(trainer.save_dir).parent
    # NOTE(review): see asyncio.run() caveat -- fails if this thread already
    # runs an event loop; confirm the calling context.
    asyncio.run(room_manager.send_to_room('train_' + run_parent.name, 'end'))
def on_train_epoch_end(trainer):
    """Ultralytics ``on_train_epoch_end`` hook: send ``trainer.epoch`` to the training room.

    The room is named ``'train_'`` followed by the name of the parent directory
    of ``trainer.save_dir``. The raw ``trainer.epoch`` value is sent as-is
    (presumably the current epoch index -- confirm whether the client expects
    0- or 1-based numbering).
    """
    room_name = 'train_' + Path(trainer.save_dir).parent.name
    progress = trainer.epoch
    asyncio.run(room_manager.send_to_room(room_name, progress))
def on_predict_start(predictor):
    """Ultralytics ``on_predict_start`` hook: send ``'start'`` to the detection room.

    The room is named ``'detect_'`` followed by the name of the grandparent
    directory of ``predictor.save_dir``.
    """
    # NOTE(review): asyncio.run() raises RuntimeError if this thread already has
    # a running event loop; confirm prediction runs off the server loop's thread.
    grandparent = Path(predictor.save_dir).parent.parent
    asyncio.run(room_manager.send_to_room('detect_' + grandparent.name, 'start'))
def on_predict_batch_end(predictor):
    """Ultralytics ``on_predict_batch_end`` hook: send ``predictor.seen`` to the detection room.

    The room is named ``'detect_'`` followed by the name of the grandparent
    directory of ``predictor.save_dir``. ``predictor.seen`` is sent unchanged
    (presumably a running count of processed items -- confirm).
    """
    room_name = 'detect_' + Path(predictor.save_dir).parent.parent.name
    asyncio.run(room_manager.send_to_room(room_name, predictor.seen))
def on_predict_end(predictor):
    """Ultralytics ``on_predict_end`` hook: send ``'end'`` to the detection room.

    The room is named ``'detect_'`` followed by the name of the grandparent
    directory of ``predictor.save_dir``.
    """
    save_dir = Path(predictor.save_dir)
    room_name = 'detect_' + save_dir.parent.parent.name
    asyncio.run(room_manager.send_to_room(room_name, 'end'))
class YoloModel:
    """Wrapper around an Ultralytics ``YOLO`` model.

    Offers training, prediction over a folder, and prediction over an RTSP
    stream. Training and folder prediction attach the module-level callbacks
    that report progress to websocket rooms through ``room_manager``.
    """

    def __init__(self, pt_url=None):
        """Load weights from ``pt_url``, or ``'yolo11n.pt'`` when it is falsy."""
        if pt_url:
            self.model = YOLO(pt_url)
        else:
            self.model = YOLO('yolo11n.pt')

    def _register_callbacks(self, callbacks):
        """Attach each ``event -> func`` pair to the model unless it is already attached.

        ``add_callback`` appends to the event's callback list every time it is
        called, so without this check each repeated ``train``/``predict_folder``
        call on the same instance would add another copy and every websocket
        message would be sent once per copy.
        """
        registered = self.model.callbacks
        for event, func in callbacks.items():
            if func not in registered.get(event, ()):
                self.model.add_callback(event, func)

    def train(self, data, epochs, project, name, patience):
        """Train the model, reporting progress to a ``train_*`` websocket room.

        Args:
            data: dataset config, forwarded to Ultralytics ``train`` as ``data``.
            epochs: number of training epochs.
            project: output project directory.
            name: run name inside ``project``.
            patience: forwarded to Ultralytics ``train`` as ``patience``.

        Image size (640), device (0) and ``verbose=False`` are hard-coded.
        """
        self._register_callbacks({
            'on_train_start': on_train_start,
            'on_train_end': on_train_end,
            'on_train_epoch_end': on_train_epoch_end,
        })
        self.model.train(
            data=data,
            epochs=epochs,
            imgsz=640,
            device=0,
            project=project,
            name=name,
            patience=patience,
            verbose=False
        )

    def predict_folder(self, source, name, project):
        """Run prediction on the contents (images or videos) of a folder.

        Annotated outputs and label ``.txt`` files are saved under
        ``project/name``; progress is reported to a ``detect_*`` websocket room.
        Uses device ``'0'`` and a confidence threshold of 0.6.
        """
        self._register_callbacks({
            'on_predict_start': on_predict_start,
            'on_predict_batch_end': on_predict_batch_end,
            'on_predict_end': on_predict_end,
        })
        # stream=True yields results lazily; draining the generator performs the
        # same saving and fires the same callbacks, but avoids holding every
        # Results object (one per frame for videos) in memory at once -- the
        # collected list was discarded anyway.
        for _ in self.model.predict(
            source=source,
            name=name,
            project=project,
            save=True,
            save_txt=True,
            device='0',
            conf=0.6,
            stream=True
        ):
            pass

    def predict_rtsp(self, source, room_name, max_seconds=600):
        """Run prediction on an RTSP video stream and push annotated JPEG frames to a room.

        Args:
            source: stream URL passed to Ultralytics ``predict``.
            room_name: websocket room that receives the JPEG bytes of each frame.
            max_seconds: stop after this many seconds of elapsed time
                (default 600, i.e. 10 minutes).

        The limit is only checked when a new result arrives, so a stalled
        stream can keep this loop alive past the deadline.
        """
        # Monotonic clock: immune to wall-clock adjustments (NTP, manual changes)
        # that could cut the session short or extend it.
        deadline = time.monotonic() + max_seconds
        for result in self.model.predict(source=source, stream=True, device='0', conf=0.6):
            if time.monotonic() > deadline:
                break
            ok, jpeg = cv2.imencode('.jpg', result.plot())
            if ok:
                # NOTE(review): asyncio.run() builds and tears down an event loop
                # for every frame, and raises if this thread already runs a loop;
                # confirm room_manager tolerates being driven this way.
                asyncio.run(room_manager.send_stream_to_room(room_name, jpeg.tobytes()))