# aicheckv2-api/apps/business/detect/service.py
import asyncio

from sqlalchemy.ext.asyncio import AsyncSession

from algo import YoloModel
from utils import os_utils as os
from application.settings import detect_url
from apps.business.train import models as train_models
from apps.business.deepsort import service as deepsort_service
from . import models, crud, schemas


async def before_detect(
        detect_in: schemas.ProjectDetectLogIn,
        detect: models.ProjectDetect,
        train: train_models.ProjectTrain,
        db: AsyncSession,
        user_id: int):
    """
    Prepare an inference run and persist the detect log record.
    :param detect_in: request payload for the inference run
    :param detect: the detect set to run inference on
    :param train: the training record whose weights are used
    :param db: async database session
    :param user_id: id of the user triggering the run
    :return: the created ProjectDetectLog
    """
    # Inference version (next version of the detect set)
    version_path = 'v' + str(detect.detect_version + 1)
    # Weights file (best or last checkpoint of the selected training run)
    pt_url = train.best_pt if detect_in.pt_type == 'best' else train.last_pt
    # Source folder of the detect set and output folder for the results
    img_url = detect.folder_url
    out_url = os.file_path(detect_url, detect.detect_no, 'detect')
    # Build the detect log record
    detect_log = models.ProjectDetectLog()
    detect_log.detect_name = detect.detect_name
    detect_log.detect_id = detect.id
    detect_log.detect_version = version_path
    detect_log.train_id = train.id
    detect_log.train_version = train.train_version
    detect_log.pt_type = detect_in.pt_type
    detect_log.pt_url = pt_url
    detect_log.folder_url = img_url
    detect_log.detect_folder_url = out_url
    detect_log.user_id = user_id
    await crud.ProjectDetectLogDal(db).create_model(detect_log)
    return detect_log
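

# Note on the paths built above (assuming os_utils.file_path joins segments much
# like os.path.join): out_url resolves to <detect_url>/<detect_no>/detect, and the
# results of this run end up under <out_url>/<version_path>/, since
# run_detect_folder later receives out_url as `project` and the version string as
# `name`, and update_sql records file URLs as file_path(project, name, file_name).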


def run_img_loop(
        weights: str,
        source: str,
        project: str,
        name: str,
        detect_id: int,
        is_gpu: str):
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # Run the async inference. detect_id and is_gpu are part of the caller's
    # signature but are not used by folder inference.
    loop.run_until_complete(run_detect_folder(weights, source, project, name))
    # Optional: close the loop
    loop.close()
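

# Usage sketch (an assumption, not taken from the original code): the inference is
# blocking, so a caller would typically hand run_img_loop to a worker thread and
# keep the main asyncio event loop responsive. The argument values mirror the
# fields prepared by before_detect.
#
#   import threading
#   threading.Thread(
#       target=run_img_loop,
#       args=(detect_log.pt_url, detect_log.folder_url,
#             detect_log.detect_folder_url, detect_log.detect_version,
#             detect_log.detect_id, 'False'),
#       daemon=True,
#   ).start()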


async def run_detect_folder(
        weights: str,
        source: str,
        project: str,
        name: str):
    """
    Run YOLOv5 inference over a folder of images.
    Declared async so run_img_loop can drive it with run_until_complete.
    :param weights: weights file
    :param source: folder containing the source images
    :param project: destination folder for the inference output
    :param name: version name (sub-folder under project)
    :return:
    """
    model = YoloModel(weights)
    model.predict_folder(
        source=source,
        project=project,
        name=name
    )


async def update_sql(db: AsyncSession, detect_id: int, log_id: int, project: str, name: str):
    """
    Update the detect set after an inference run: bump its version and
    record an output file entry for every source file.
    """
    detect_dal = crud.ProjectDetectDal(db)
    detect = await detect_dal.get_data(detect_id)
    detect.detect_version = detect.detect_version + 1
    await detect_dal.put_data(data_id=detect_id, data=detect)
    detect_files = await crud.ProjectDetectFileDal(db).get_datas(
        limit=0,
        v_where=[models.ProjectDetectFile.detect_id == detect_id],
        v_return_objs=True,
        v_return_count=False)
    detect_log_files = []
    for detect_file in detect_files:
        detect_log_img = models.ProjectDetectLogFile()
        detect_log_img.log_id = log_id
        image_url = os.file_path(project, name, detect_file.file_name)
        detect_log_img.file_url = image_url
        detect_log_img.file_name = detect_file.file_name
        detect_log_files.append(detect_log_img)
    await crud.ProjectDetectLogFileDal(db).create_models(detect_log_files)
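

# End-to-end sketch (an assumption about how these helpers fit together, not code
# taken from the original service): create the log, run the blocking inference off
# the event loop, then record the results. `detect_log.id` is assumed to be
# populated by create_model.
#
#   detect_log = await before_detect(detect_in, detect, train, db, user_id)
#   await asyncio.get_running_loop().run_in_executor(
#       None, run_img_loop,
#       detect_log.pt_url, detect_log.folder_url, detect_log.detect_folder_url,
#       detect_log.detect_version, detect_log.detect_id, 'False')
#   await update_sql(db, detect_log.detect_id, detect_log.id,
#                    detect_log.detect_folder_url, detect_log.detect_version)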


async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_gpu: str):
    """
    RTSP video-stream inference.
    :param detect_id: id of the detect set
    :param weights_pt: weights file
    :param rtsp_url: address of the video stream
    :param data: dataset yaml file
    :param is_gpu: whether to enable GPU acceleration
    :return:
    """
    # room = 'detect_rtsp_' + str(detect_id)
    # # Select device: CPU or GPU
    # device = select_device('cpu')
    # # Check whether a CUDA build is available
    # if is_gpu == 'True':
    #     device = select_device('cuda:0')
    #
    # # Load the model
    # model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
    #
    # stride, names, pt = model.stride, model.names, model.pt
    # img_sz = check_img_size((640, 640), s=stride)  # check image size
    #
    # dataset = LoadStreams(rtsp_url, img_size=img_sz, stride=stride, auto=pt, vid_stride=1)
    # bs = len(dataset)
    #
    # model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *img_sz))
    #
    # time.sleep(3)  # wait 3 s for the websocket client to join the room
    #
    # start_time = time.time()
    #
    # for path, im, im0s, vid_cap, s in dataset:
    #     # Stop once the maximum run time of 10 minutes (600 s) is exceeded
    #     elapsed_time = time.time() - start_time
    #     if elapsed_time > 600:  # 600 seconds = 10 minutes
    #         print(room, "maximum run time reached, stopping inference.")
    #         break
    #     if room_manager.rooms.get(room):
    #         im = torch.from_numpy(im).to(model.device)
    #         im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
    #         im /= 255  # 0 - 255 to 0.0 - 1.0
    #         if len(im.shape) == 3:
    #             im = im[None]  # expand for batch dim
    #
    #         # Inference
    #         pred = model(im, augment=False, visualize=False)
    #         # NMS
    #         pred = non_max_suppression(pred, 0.25, 0.45, None, False, max_det=1000)
    #
    #         # Process predictions
    #         for i, det in enumerate(pred):  # per image
    #             p, im0, frame = path[i], im0s[i].copy(), dataset.count
    #             annotator = Annotator(im0, line_width=3, example=str(names))
    #             if len(det):
    #                 # Rescale boxes from img_size to im0 size
    #                 det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
    #
    #                 # Write results
    #                 for *xyxy, conf, cls in reversed(det):
    #                     c = int(cls)  # integer class
    #                     label = f"{names[c]} {conf:.2f}"
    #                     annotator.box_label(xyxy, label, color=colors(c, True))
    #
    #             # Stream results
    #             im0 = annotator.result()
    #             # Encode the frame as JPEG
    #             ret, jpeg = cv2.imencode('.jpg', im0)
    #             if ret:
    #                 frame_data = jpeg.tobytes()
    #                 await room_manager.send_stream_to_room(room, frame_data)
    #     else:
    #         print(room, 'stopping inference')
    #         break


def run_rtsp_loop(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_gpu: str):
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # Run the async function
    loop.run_until_complete(
        run_detect_rtsp(
            weights_pt,
            rtsp_url,
            data,
            detect_id,
            is_gpu
        )
    )
    # Optional: close the loop
    loop.close()


def run_deepsort_loop(
        detect_id: int,
        weights_pt: str,
        data: str,
        idx_to_class: dict,
        sort_type: str = 'video',
        video_path: str = None,
        rtsp_url: str = None
):
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # Run the async function
    loop.run_until_complete(
        deepsort_service.run_deepsort(
            detect_id,
            weights_pt,
            data,
            idx_to_class,
            sort_type,
            video_path,
            rtsp_url
        )
    )
    # Optional: close the loop
    loop.close()
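

# Usage sketch (an assumption, not from the original code): run_deepsort_loop is
# meant to be driven from a worker thread, tracking either a local video file or
# an RTSP stream depending on sort_type. The weights path, yaml path and class
# mapping below are placeholders.
#
#   import threading
#   threading.Thread(
#       target=run_deepsort_loop,
#       args=(detect_id, 'weights/best.pt', 'data/data.yaml', {0: 'person'}),
#       kwargs={'sort_type': 'rtsp', 'rtsp_url': 'rtsp://example/stream'},
#       daemon=True,
#   ).start()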