Files
ultralytics_demo/algo/model.py
2025-09-15 14:37:32 +08:00

93 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from ultralytics import YOLO
from pathlib import Path
import cv2, subprocess, numpy as np, time
# Training-start callback
def on_train_start(trainer):
    """Ultralytics ``on_train_start`` callback: log that training has begun.

    Args:
        trainer: the Ultralytics trainer object; only ``trainer.save_dir``
            is read.

    Prints a fixed message followed by the name of the folder that contains
    ``save_dir`` (presumably the ``project`` directory passed to
    ``YoloModel.train`` and used as a project ID — confirm with callers).
    """
    project_dir = Path(trainer.save_dir).parent
    print('开始模型训练', project_dir.name)
# Training-end callback
def on_train_end(trainer):
    """Ultralytics ``on_train_end`` callback: log that training has finished.

    Args:
        trainer: the Ultralytics trainer object; only ``trainer.save_dir``
            is read.

    Prints a fixed message followed by the name of the parent folder of
    ``save_dir`` (presumably the ``project`` directory — confirm with callers).
    """
    run_parent_name = Path(trainer.save_dir).parent.name
    print('模型训练结束', run_parent_name)
# Per-epoch callback, fired at the end of every training epoch
def on_train_epoch_end(trainer):
    """Ultralytics ``on_train_epoch_end`` callback: log epoch progress.

    Args:
        trainer: the Ultralytics trainer object; ``trainer.save_dir`` and
            ``trainer.epoch`` are read.

    Prints the current epoch index and the name of the parent folder of
    ``save_dir``, which the message labels as the project ID.
    """
    project_id = Path(trainer.save_dir).parent.name
    print('当前训练轮数', trainer.epoch, '当前项目编号', project_id)
class YoloModel:
    """Wrapper around an Ultralytics ``YOLO`` model for training and for
    real-time detection on a video stream.

    The module-level callbacks ``on_train_start``, ``on_train_end`` and
    ``on_train_epoch_end`` are registered on the model so that training
    progress is printed.
    """

    def __init__(self, pt_url):
        """Load the model and register the training callbacks.

        Args:
            pt_url: weights/model spec passed straight to ``YOLO(...)``.
        """
        self.model = YOLO(pt_url)
        self.model.add_callback('on_train_start', on_train_start)
        self.model.add_callback('on_train_end', on_train_end)
        self.model.add_callback('on_train_epoch_end', on_train_epoch_end)

    def train(self, data, epochs, project, name, imgsz=640, device=0):
        """Train the wrapped model and return whatever ``YOLO.train`` returns.

        Args:
            data: dataset spec forwarded to ``YOLO.train`` (presumably a
                dataset yaml path — confirm with callers).
            epochs: number of training epochs.
            project: output root directory forwarded to ``YOLO.train``.
            name: run name forwarded to ``YOLO.train``.
            imgsz: training image size (default 640, as before).
            device: training device (default 0, i.e. the first GPU, as before).
        """
        return self.model.train(
            data=data,
            epochs=epochs,
            imgsz=imgsz,
            device=device,
            project=project,
            name=name,
            verbose=False,
        )

    def detect(self, input_rtsp, output_rtsp, fps: int,
               width: int = 1920, height: int = 1080):
        """Run detection on a live stream and re-publish the annotated video.

        One ffmpeg process decodes ``input_rtsp`` into raw bgr24 frames of
        ``width`` x ``height`` at ``fps``. Each frame is run through the
        model, and detections are drawn with ``Results.plot()``. The result
        is converted to I420 and piped to a second ffmpeg process, which
        encodes H.264 and publishes to ``output_rtsp``.

        Blocks until the input stream ends (the decoding ffmpeg exits), then
        stops both ffmpeg processes. They are also stopped if an exception
        escapes the loop.

        Args:
            input_rtsp: source URL/path understood by ffmpeg.
            output_rtsp: destination URL. ``rtsp://``/``rtsps://`` targets
                use the RTSP muxer (e.g. mediamtx on port 8554); anything
                else (e.g. ``rtmp://``) is sent as FLV.
            fps: frame rate used both for sampling the input and for the output.
            width, height: processing/output resolution. Both must be even,
                because I420 needs even dimensions.

        Raises:
            BrokenPipeError: if the encoding ffmpeg exits while frames are
                still being written.
        """
        w, h = width, height
        frame_size = w * h * 3  # bytes in one bgr24 frame
        pull_proc = subprocess.Popen(
            self._pull_cmd(input_rtsp, fps, w, h),
            stdout=subprocess.PIPE,
            bufsize=frame_size * 2,
        )
        try:
            push_proc = subprocess.Popen(
                self._push_cmd(output_rtsp, fps, w, h),
                stdin=subprocess.PIPE,
            )
            try:
                while True:
                    # A buffered read(n) on a pipe only comes back short at
                    # EOF, i.e. once the decoding ffmpeg has exited.
                    # Retrying would spin forever, so treat it as end of stream.
                    raw = pull_proc.stdout.read(frame_size)
                    if len(raw) != frame_size:
                        break
                    frame = np.frombuffer(raw, np.uint8).reshape((h, w, 3))
                    # No manual resize: the model letterboxes internally to its
                    # own imgsz.
                    for result in self.model(frame, stream=True, verbose=False):
                        annotated = result.plot()  # BGR ndarray with boxes drawn
                        yuv = cv2.cvtColor(annotated, cv2.COLOR_BGR2YUV_I420)
                        push_proc.stdin.write(yuv.tobytes())
            finally:
                self._stop_push(push_proc)
        finally:
            self._stop_pull(pull_proc)

    @staticmethod
    def _pull_cmd(src, fps, w, h):
        """Build the ffmpeg command that decodes ``src`` to raw bgr24 on stdout."""
        return [
            'ffmpeg', '-hide_banner', '-loglevel', 'error',
            '-i', src,
            '-an', '-f', 'rawvideo', '-pix_fmt', 'bgr24',
            '-vf', f'fps={fps},scale={w}:{h}', '-',
        ]

    @staticmethod
    def _push_cmd(dst, fps, w, h):
        """Build the ffmpeg command that encodes yuv420p from stdin to ``dst``."""
        if dst.lower().startswith(('rtsp://', 'rtsps://')):
            # FLV cannot be published to an rtsp:// URL; use the RTSP muxer.
            mux = ['-f', 'rtsp']
        else:
            # RTMP and other targets keep the FLV container.
            mux = ['-f', 'flv']
        return [
            'ffmpeg', '-hide_banner', '-loglevel', 'error', '-y',
            '-f', 'rawvideo', '-pix_fmt', 'yuv420p',
            '-s', f'{w}x{h}', '-r', str(fps),
            '-i', '-',
            '-c:v', 'libx264', '-preset', 'ultrafast', '-tune', 'zerolatency',
            *mux, dst,
        ]

    @staticmethod
    def _stop_push(proc, timeout=5):
        """Close the encoder's stdin so it can flush, then reap it (kill on timeout)."""
        try:
            proc.stdin.close()
        except BrokenPipeError:
            pass  # encoder already gone; nothing left to flush
        try:
            proc.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()

    @staticmethod
    def _stop_pull(proc, timeout=5):
        """Terminate the decoder if it is still running, reap it and close its pipe."""
        if proc.poll() is None:
            proc.terminate()
        try:
            proc.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
        proc.stdout.close()