93 lines
2.8 KiB
Python
93 lines
2.8 KiB
Python
from ultralytics import YOLO
|
||
from pathlib import Path
|
||
import cv2, subprocess, numpy as np, time
|
||
|
||
|
||
# Callback invoked when training starts.
def on_train_start(trainer):
    """Print a training-started notice tagged with the run's parent folder.

    Args:
        trainer: object exposing a ``save_dir`` attribute (str or path-like).
            The name of the directory that *contains* ``save_dir`` is printed.
    """
    print('开始模型训练', Path(trainer.save_dir).parent.name)
|
||
|
||
|
||
# Callback invoked when training finishes.
def on_train_end(trainer):
    """Print a training-finished notice tagged with the run's parent folder.

    Args:
        trainer: object exposing a ``save_dir`` attribute (str or path-like).
            The name of the directory that *contains* ``save_dir`` is printed.
    """
    save_path = Path(trainer.save_dir)
    parent_folder = save_path.parent
    print('模型训练结束', parent_folder.name)
|
||
|
||
|
||
# Callback invoked at the end of every training epoch.
def on_train_epoch_end(trainer):
    """Print the trainer's epoch counter and the project folder name.

    Args:
        trainer: object exposing ``save_dir`` (str or path-like) and ``epoch``.
            The printed "project number" is the name of the directory that
            contains ``save_dir`` (presumably the per-project folder — confirm).

    NOTE(review): ``trainer.epoch`` is printed exactly as the trainer reports
    it; in Ultralytics it appears to be 0-based, so the first epoch prints 0.
    Confirm whether consumers of this output expect that before changing it.
    """
    project_id = Path(trainer.save_dir).parent.name
    epoch = trainer.epoch
    print('当前训练轮数', epoch, '当前项目编号', project_id)
|
||
|
||
|
||
class YoloModel:
    """Wrapper around an Ultralytics ``YOLO`` model.

    On construction the module-level ``on_train_start``, ``on_train_end`` and
    ``on_train_epoch_end`` callbacks are registered on the model. ``train``
    runs training; ``detect`` pulls a video stream through one ffmpeg
    subprocess, runs the model on every frame, and publishes the annotated
    frames through a second ffmpeg subprocess.
    """

    # Seconds to wait for an ffmpeg child to exit on shutdown before killing it.
    _STOP_TIMEOUT = 5

    def __init__(self, pt_url):
        """Load the model and register the training progress callbacks.

        Args:
            pt_url: model source passed straight to ``YOLO(...)``
                (e.g. a ``.pt`` weights path).
        """
        self.model = YOLO(pt_url)
        self.model.add_callback('on_train_start', on_train_start)
        self.model.add_callback('on_train_end', on_train_end)
        self.model.add_callback('on_train_epoch_end', on_train_epoch_end)

    def train(self, data, epochs, project, name, imgsz=640, device=0):
        """Train the model and return the result of ``self.model.train``.

        Args:
            data: dataset description passed as ``data=``.
            epochs: number of training epochs.
            project: output project directory passed as ``project=``. The
                registered callbacks print the parent folder name of the
                trainer's ``save_dir`` (presumably this directory — confirm).
            name: run name passed as ``name=``.
            imgsz: training image size (previously hard-coded to 640).
            device: device passed to Ultralytics (previously hard-coded to 0).

        Returns:
            Whatever ``self.model.train(...)`` returns.
        """
        return self.model.train(
            data=data,
            epochs=epochs,
            imgsz=imgsz,
            device=device,
            project=project,
            name=name,
            verbose=False,
        )

    def detect(self, input_rtsp, output_rtsp, fps: int, width=1920, height=1080):
        """Run live detection on *input_rtsp* and publish the annotated stream.

        An ffmpeg subprocess decodes the input to raw BGR frames at *fps*,
        scaled to ``width x height``. Each frame goes through the model, the
        detections are drawn on it, and the result is piped to a second ffmpeg
        that encodes with libx264. The output container is RTSP for
        ``rtsp://``/``rtsps://`` destinations (e.g. a mediamtx server on port
        8554) and FLV for anything else (e.g. ``rtmp://`` URLs).

        Args:
            input_rtsp: input URL/path handed to ffmpeg ``-i``.
            output_rtsp: destination URL/path for the annotated stream.
            fps: frame rate used for both decoding and encoding.
            width: frame width in pixels (default 1920); must be even.
            height: frame height in pixels (default 1080); must be even.

        Returns:
            None, once the input stream ends (the pulling ffmpeg closes its
            stdout). Both ffmpeg processes are stopped and reaped before
            returning, and also when an exception propagates (for example
            ``BrokenPipeError`` if the publishing ffmpeg dies).

        Raises:
            ValueError: if *width* or *height* is odd (yuv420p/I420 needs
                even dimensions).
        """
        if width % 2 or height % 2:
            raise ValueError(
                f'width and height must be even for yuv420p, got {width}x{height}'
            )
        frame_bytes = width * height * 3
        # 1. Pull: decode the input to raw BGR frames on stdout.
        pull_proc = subprocess.Popen(
            self._pull_cmd(input_rtsp, fps, width, height),
            stdout=subprocess.PIPE,
            bufsize=frame_bytes * 2,
        )
        push_proc = None
        try:
            # 2. Push: encode raw I420 frames from stdin and publish them.
            push_proc = subprocess.Popen(
                self._push_cmd(output_rtsp, fps, width, height),
                stdin=subprocess.PIPE,
            )
            self._relay(pull_proc, push_proc, width, height)
        finally:
            self._shutdown(pull_proc, push_proc)

    @staticmethod
    def _pull_cmd(src, fps, width, height):
        """Build the ffmpeg command that decodes *src* to raw BGR24 on stdout."""
        return [
            'ffmpeg', '-hide_banner', '-loglevel', 'error',
            '-i', src,
            '-an', '-f', 'rawvideo', '-pix_fmt', 'bgr24',
            '-vf', f'fps={fps},scale={width}:{height}', '-',
        ]

    @staticmethod
    def _push_cmd(dst, fps, width, height):
        """Build the ffmpeg command that encodes raw yuv420p from stdin to *dst*."""
        # FLV cannot be published to an RTSP server (mediamtx listens for RTSP
        # on 8554), so RTSP URLs get the rtsp muxer; all other destinations
        # keep the original FLV container (works for rtmp:// and files).
        is_rtsp = dst.lower().startswith(('rtsp://', 'rtsps://'))
        container = 'rtsp' if is_rtsp else 'flv'
        return [
            'ffmpeg', '-hide_banner', '-loglevel', 'error', '-y',
            '-f', 'rawvideo', '-pix_fmt', 'yuv420p',
            '-s', f'{width}x{height}', '-r', str(fps),
            '-i', '-',
            '-c:v', 'libx264', '-preset', 'ultrafast', '-tune', 'zerolatency',
            '-f', container, dst,
        ]

    def _relay(self, pull_proc, push_proc, width, height):
        """Read frames from *pull_proc*, annotate them, and write them to *push_proc*."""
        frame_bytes = width * height * 3
        while True:
            raw = pull_proc.stdout.read(frame_bytes)
            if len(raw) < frame_bytes:
                # A blocking read(n) on the pipe's buffered reader only comes
                # back short at EOF, i.e. the pulling ffmpeg has exited (stream
                # ended or connection lost). The old "sleep and continue" spun
                # on b'' forever, because nothing can arrive on a closed pipe.
                break
            frame = np.frombuffer(raw, np.uint8).reshape((height, width, 3))
            # 3. Inference on the decoded frame. The old cv2.resize to the same
            # (width, height) was a no-op copy, so it was dropped.
            for result in self.model(frame, stream=True, verbose=False):
                # 4. Draw the detections. The input frame is BGR, and plot()
                # presumably returns the same channel order — confirm. Convert
                # to planar I420 to match the push side's `-pix_fmt yuv420p`.
                annotated = result.plot()
                yuv = cv2.cvtColor(annotated, cv2.COLOR_BGR2YUV_I420)
                # 5. Write the encoded-ready frame back to the publisher.
                push_proc.stdin.write(yuv.tobytes())

    @classmethod
    def _shutdown(cls, pull_proc, push_proc):
        """Stop and reap the ffmpeg subprocesses started by ``detect``."""
        # The puller is still running if the loop was left via an exception.
        if pull_proc.poll() is None:
            pull_proc.terminate()
        pull_proc.stdout.close()
        cls._reap(pull_proc)
        if push_proc is not None:
            try:
                # Closing stdin sends EOF so the encoder can flush and exit.
                push_proc.stdin.close()
            except OSError:
                # e.g. BrokenPipeError: the publishing ffmpeg already exited.
                pass
            cls._reap(push_proc)

    @classmethod
    def _reap(cls, proc):
        """Wait up to ``_STOP_TIMEOUT`` seconds for *proc*, then kill it."""
        try:
            proc.wait(timeout=cls._STOP_TIMEOUT)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
|