from algo import YoloModel
from utils import os_utils as os
from . import models, crud, schemas
from application.settings import detect_url
from apps.business.train import models as train_models
from apps.business.deepsort import service as deepsort_service
import asyncio
from sqlalchemy.ext.asyncio import AsyncSession


async def before_detect(
        detect_in: schemas.ProjectDetectLogIn,
        detect: models.ProjectDetect,
        train: train_models.ProjectTrain,
        db: AsyncSession,
        user_id: int):
    """
    Start a detection run by creating and persisting its log record.

    :param detect_in: request payload; ``pt_type == 'best'`` selects ``train.best_pt``,
        any other value selects ``train.last_pt``
    :param detect: the detection set being run
    :param train: the training record whose weights are used
    :param db: async database session
    :param user_id: id of the user starting the run
    :return: the ``ProjectDetectLog`` instance passed to ``create_model``
    """
    # Label this run with the version after the set's current one.
    next_version = 'v' + str(detect.detect_version + 1)

    # Choose the weights file according to the requested type.
    if detect_in.pt_type == 'best':
        weights_file = train.best_pt
    else:
        weights_file = train.last_pt

    source_folder = detect.folder_url
    output_folder = os.file_path(detect_url, detect.detect_no, 'detect')

    log = models.ProjectDetectLog()
    log.detect_name = detect.detect_name
    log.detect_id = detect.id
    log.detect_version = next_version
    log.train_id = train.id
    log.train_version = train.train_version
    log.pt_type = detect_in.pt_type
    log.pt_url = weights_file
    log.folder_url = source_folder
    log.detect_folder_url = output_folder
    log.user_id = user_id

    await crud.ProjectDetectLogDal(db).create_model(log)
    return log


def run_detect_folder(
        weights: str,
        source: str,
        project: str,
        name: str):
    """
    Run YOLOv5 inference over a folder of images.

    :param weights: path to the weights file
    :param source: folder containing the input images
    :param project: output location for the inference results
    :param name: version name (sub-folder label) for this run
    :return: None
    """
    YoloModel(weights).predict_folder(source=source, project=project, name=name)


def _make_detect_log_file(log_id: int, project, name, file_name):
    """Build (without persisting) one ``ProjectDetectLogFile`` for ``file_name`` under ``project``/``name``."""
    log_file = models.ProjectDetectLogFile()
    log_file.log_id = log_id
    log_file.file_url = os.file_path(project, name, file_name)
    log_file.file_name = file_name
    return log_file


async def update_sql(db: AsyncSession, detect_id: int, log_id: int, project, name):
    """
    Update the detection set after a run and record the run's output files.

    Increments the set's ``detect_version`` and saves it. Then creates one
    ``ProjectDetectLogFile`` per ``ProjectDetectFile`` of the set, each
    pointing at ``os.file_path(project, name, <file_name>)``.

    :param db: async database session
    :param detect_id: id of the ``ProjectDetect`` record
    :param log_id: id of the ``ProjectDetectLog`` the file records belong to
    :param project: output root that was passed to the inference run
    :param name: version name that was passed to the inference run
    :return: None
    """
    detect_dal = crud.ProjectDetectDal(db)
    detect = await detect_dal.get_data(detect_id)
    detect.detect_version += 1
    await detect_dal.put_data(data_id=detect_id, data=detect)

    # NOTE(review): limit=0 presumably means "no limit" in this DAL -- confirm.
    source_files = await crud.ProjectDetectFileDal(db).get_datas(
        limit=0,
        v_where=[models.ProjectDetectFile.detect_id == detect_id],
        v_return_objs=True,
        v_return_count=False,
    )
    log_files = [
        _make_detect_log_file(log_id, project, name, source_file.file_name)
        for source_file in source_files
    ]
    await crud.ProjectDetectLogFileDal(db).create_models(log_files)


def run_detect_rtsp(weights_pt: str, rtsp_url: str, room_name: str):
    """
    Run inference on an RTSP video stream.

    :param weights_pt: path to the weights file
    :param rtsp_url: RTSP stream address
    :param room_name: websocket connection (room) name passed to ``predict_rtsp``
    :return: None
    """
    YoloModel(weights_pt).predict_rtsp(rtsp_url, room_name)