deep_sort 提升版本

This commit is contained in:
2025-06-11 16:19:38 +08:00
parent 9e99b08d13
commit ca245e4cec
47 changed files with 114 additions and 3644 deletions

View File

@ -51,21 +51,6 @@ async def before_detect(
return detect_log
def run_img_loop(
weights: str,
source: str,
project: str,
name: str,
detect_id: int,
is_gpu: str):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 运行异步函数
loop.run_until_complete(run_detect_folder(weights, source, project, name, detect_id, is_gpu))
# 可选: 关闭循环
loop.close()
def run_detect_folder(
weights: str,
source: str,
@ -111,121 +96,13 @@ async def update_sql(db: AsyncSession, detect_id: int, log_id: int, project, nam
await crud.ProjectDetectLogFileDal(db).create_models(detect_log_files)
async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_gpu: str):
def run_detect_rtsp(weights_pt: str, rtsp_url: str, room_name: str):
"""
rtsp 视频流推理
:param detect_id: 训练集的id
:param room_name: websocket链接名称
:param weights_pt: 权重文件
:param rtsp_url: 视频流地址
:param data: yaml文件
:param is_gpu: 是否启用加速
:return:
"""
# room = 'detect_rtsp_' + str(detect_id)
# # 选择设备CPU 或 GPU
# device = select_device('cpu')
# # 判断是否存在cuda版本
# if is_gpu == 'True':
# device = select_device('cuda:0')
#
# # 加载模型
# model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
#
# stride, names, pt = model.stride, model.names, model.pt
# img_sz = check_img_size((640, 640), s=stride) # check image size
#
# dataset = LoadStreams(rtsp_url, img_size=img_sz, stride=stride, auto=pt, vid_stride=1)
# bs = len(dataset)
#
# model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *img_sz))
#
# time.sleep(3) # 等待3s等待websocket进入
#
# start_time = time.time()
#
# for path, im, im0s, vid_cap, s in dataset:
# # 检查是否已经超过10分钟600秒
# elapsed_time = time.time() - start_time
# if elapsed_time > 600: # 600 seconds = 10 minutes
# print(room, "已达到最大执行时间,结束推理。")
# break
# if room_manager.rooms.get(room):
# im = torch.from_numpy(im).to(model.device)
# im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
# im /= 255 # 0 - 255 to 0.0 - 1.0
# if len(im.shape) == 3:
# im = im[None] # expand for batch dim
#
# # Inference
# pred = model(im, augment=False, visualize=False)
# # NMS
# pred = non_max_suppression(pred, 0.25, 0.45, None, False, max_det=1000)
#
# # Process predictions
# for i, det in enumerate(pred): # per image
# p, im0, frame = path[i], im0s[i].copy(), dataset.count
# annotator = Annotator(im0, line_width=3, example=str(names))
# if len(det):
# # Rescale boxes from img_size to im0 size
# det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
#
# # Write results
# for *xyxy, conf, cls in reversed(det):
# c = int(cls) # integer class
# label = None if False else (names[c] if False else f"{names[c]} {conf:.2f}")
# annotator.box_label(xyxy, label, color=colors(c, True))
#
# # Stream results
# im0 = annotator.result()
# # 将帧编码为 JPEG
# ret, jpeg = cv2.imencode('.jpg', im0)
# if ret:
# frame_data = jpeg.tobytes()
# await room_manager.send_stream_to_room(room, frame_data)
# else:
# print(room, '结束推理')
# break
def run_rtsp_loop(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_gpu: str):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 运行异步函数
loop.run_until_complete(
run_detect_rtsp(
weights_pt,
rtsp_url,
data,
detect_id,
is_gpu
)
)
# 可选: 关闭循环
loop.close()
def run_deepsort_loop(
detect_id: int,
weights_pt: str,
data: str,
idx_to_class: {},
sort_type: str = 'video',
video_path: str = None,
rtsp_url: str = None
):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 运行异步函数
loop.run_until_complete(
deepsort_service.run_deepsort(
detect_id,
weights_pt,
data,
idx_to_class,
sort_type,
video_path,
rtsp_url
)
)
# 可选: 关闭循环
loop.close()
model = YoloModel(weights_pt)
model.predict_rtsp(rtsp_url, room_name)

View File

@ -11,10 +11,11 @@ from core.dependencies import IdList
from utils.websocket_server import room_manager
from . import schemas, crud, params, service, models
from apps.business.train.crud import ProjectTrainDal
from apps.business.project.crud import ProjectInfoDal, ProjectLabelDal
from apps.vadmin.auth.utils.current import AllUserAuth
from apps.vadmin.auth.utils.validation.auth import Auth
from utils.response import SuccessResponse, ErrorResponse
from apps.business.deepsort import service as deep_sort_service
from apps.business.project.crud import ProjectInfoDal, ProjectLabelDal
import os
import shutil
@ -147,13 +148,11 @@ async def run_detect_yolo(
else:
weights_pt = train.last_pt
thread_train = threading.Thread(
target=service.run_rtsp_loop,
target=service.run_detect_rtsp,
args=(
weights_pt,
detect.rtsp_url,
train.train_data,
detect.id,
None
room
)
)
thread_train.start()
@ -164,22 +163,16 @@ async def run_detect_yolo(
# 查询项目所属标签,返回两个 idname一一对应的数组
label_id_list, label_name_list = await ProjectLabelDal(auth.db).get_label_for_train(project_info.id)
idx_to_class = {str(i): name for i, name in enumerate(label_name_list)}
# if detect_log_in.pt_type == 'best':
# weights_pt = train.best_pt
# else:
# weights_pt = train.last_pt
if detect.file_type == 'rtsp':
threading_main = threading.Thread(
target=service.run_deepsort_loop,
args=(detect.id, train.best_pt, train.train_data,
idx_to_class, 'rtsp', None, detect.rtsp_url))
threading_main.start()
elif detect.file_type == 'video':
threading_main = threading.Thread(
target=service.run_deepsort_loop,
args=(detect.id, train.best_pt, train.train_data,
idx_to_class, 'video', detect.folder_url, None))
threading_main.start()
threading_main = threading.Thread(
target=deep_sort_service.run_deepsort,
args=(
detect.id,
train.best_pt,
idx_to_class,
detect.rtsp_url
)
)
threading_main.start()
return SuccessResponse(msg="执行成功")