Files
aicheckv2-api/apps/business/detect/service.py
2025-06-11 16:19:38 +08:00

108 lines
3.2 KiB
Python

from algo import YoloModel
from utils import os_utils as os
from . import models, crud, schemas
from application.settings import detect_url
from apps.business.train import models as train_models
from apps.business.deepsort import service as deepsort_service
import asyncio
from sqlalchemy.ext.asyncio import AsyncSession
async def before_detect(
        detect_in: schemas.ProjectDetectLogIn,
        detect: models.ProjectDetect,
        train: train_models.ProjectTrain,
        db: AsyncSession,
        user_id: int):
    """
    Build and persist the log record that describes an upcoming detection run.

    :param detect_in: request payload; its ``pt_type`` selects which weights file is used
    :param detect: detection set the run belongs to
    :param train: training record that supplies the weights paths and train version
    :param db: async database session handed to the DAL
    :param user_id: id stored on the log record
    :return: the ``ProjectDetectLog`` instance that was passed to ``create_model``
    """
    # The log carries the *next* version label; detect.detect_version itself
    # is not modified here.
    next_version = f"v{detect.detect_version + 1!s}"

    # 'best' picks train.best_pt; every other pt_type value falls back to last_pt.
    if detect_in.pt_type == 'best':
        weights_path = train.best_pt
    else:
        weights_path = train.last_pt

    # Input images come from the set's folder; the output folder is derived
    # from detect_url, the set's detect_no and the literal 'detect'.
    source_folder = detect.folder_url
    output_folder = os.file_path(detect_url, detect.detect_no, 'detect')

    record = models.ProjectDetectLog()
    field_values = {
        'detect_name': detect.detect_name,
        'detect_id': detect.id,
        'detect_version': next_version,
        'train_id': train.id,
        'train_version': train.train_version,
        'pt_type': detect_in.pt_type,
        'pt_url': weights_path,
        'folder_url': source_folder,
        'detect_folder_url': output_folder,
        'user_id': user_id,
    }
    for field_name, field_value in field_values.items():
        setattr(record, field_name, field_value)

    await crud.ProjectDetectLogDal(db).create_model(record)
    return record
def run_detect_folder(
        weights: str,
        source: str,
        project: str,
        name: str):
    """
    Run yolov5 inference over a folder of images.

    :param weights: weights file used to construct the ``YoloModel``
    :param source: folder that contains the images
    :param project: location where the inference output is written
    :param name: version name
    :return: None
    """
    # All three arguments are forwarded by keyword, unchanged.
    YoloModel(weights).predict_folder(source=source, project=project, name=name)
async def update_sql(db: AsyncSession, detect_id: int, log_id: int, project, name):
    """
    Update the state of a detection set.

    Increments the set's ``detect_version``, then creates one
    ``ProjectDetectLogFile`` row per file that belongs to the set.

    :param db: async database session handed to the DALs
    :param detect_id: id of the detection set to update
    :param log_id: id of the detection log the new file rows are attached to
    :param project: first path component used to build each ``file_url``
    :param name: second path component used to build each ``file_url``
    :return: None
    """
    set_dal = crud.ProjectDetectDal(db)
    detect_set = await set_dal.get_data(detect_id)
    detect_set.detect_version += 1
    await set_dal.put_data(data_id=detect_id, data=detect_set)

    # NOTE(review): limit=0 presumably disables pagination so every file of
    # the set is returned -- confirm against the DAL implementation.
    source_files = await crud.ProjectDetectFileDal(db).get_datas(
        limit=0,
        v_where=[models.ProjectDetectFile.detect_id == detect_id],
        v_return_objs=True,
        v_return_count=False)

    def _as_log_file(source_file):
        # Build the log-file row for one source file of the set.
        log_file = models.ProjectDetectLogFile()
        log_file.log_id = log_id
        log_file.file_url = os.file_path(project, name, source_file.file_name)
        log_file.file_name = source_file.file_name
        return log_file

    log_files = [_as_log_file(source_file) for source_file in source_files]
    await crud.ProjectDetectLogFileDal(db).create_models(log_files)
def run_detect_rtsp(weights_pt: str, rtsp_url: str, room_name: str):
    """
    Run inference on an RTSP video stream.

    :param weights_pt: weights file used to construct the ``YoloModel``
    :param rtsp_url: address of the video stream
    :param room_name: name of the websocket connection
    :return: None
    """
    # Both arguments are passed positionally, in the same order as before.
    YoloModel(weights_pt).predict_rtsp(rtsp_url, room_name)