108 lines
3.2 KiB
Python
108 lines
3.2 KiB
Python
from algo import YoloModel
|
|
from utils import os_utils as os
|
|
from . import models, crud, schemas
|
|
from application.settings import detect_url
|
|
from apps.business.train import models as train_models
|
|
from apps.business.deepsort import service as deepsort_service
|
|
|
|
import asyncio
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
async def before_detect(
        detect_in: schemas.ProjectDetectLogIn,
        detect: models.ProjectDetect,
        train: train_models.ProjectTrain,
        db: AsyncSession,
        user_id: int):
    """
    Create and persist the log record for an upcoming inference run.

    The record captures which trained checkpoint is used, which image folder
    is read and where the results are expected to be written. This function
    does not run inference itself; it only prepares and saves the log row.

    :param detect_in: request payload; ``pt_type == 'best'`` selects
        ``train.best_pt``, any other value selects ``train.last_pt``.
    :param detect: the inference set being run; its ``detect_version`` is
        read (not modified) to label this run as ``v<detect_version + 1>``.
    :param train: the training run whose weights are used.
    :param db: async database session used to save the record.
    :param user_id: id of the user who triggered the run.
    :return: the ``ProjectDetectLog`` instance passed to ``create_model``.
    """
    # Label for the run about to happen: one past the current version.
    run_label = f"v{detect.detect_version + 1}"

    chosen_weights = train.best_pt if detect_in.pt_type == 'best' else train.last_pt

    # Output location; os here is the project's utils.os_utils helper,
    # presumably joining path components — confirm against utils.os_utils.
    results_dir = os.file_path(detect_url, detect.detect_no, 'detect')

    # Field values in the same order they are assigned to the new record.
    log_fields = {
        'detect_name': detect.detect_name,
        'detect_id': detect.id,
        'detect_version': run_label,
        'train_id': train.id,
        'train_version': train.train_version,
        'pt_type': detect_in.pt_type,
        'pt_url': chosen_weights,
        'folder_url': detect.folder_url,
        'detect_folder_url': results_dir,
        'user_id': user_id,
    }

    new_log = models.ProjectDetectLog()
    for field_name, field_value in log_fields.items():
        setattr(new_log, field_name, field_value)

    await crud.ProjectDetectLogDal(db).create_model(new_log)
    return new_log
|
|
|
|
|
|
def run_detect_folder(
        weights: str,
        source: str,
        project: str,
        name: str):
    """
    Run YOLOv5 inference over a folder of images.

    :param weights: path of the weights file used to build the model.
    :param source: folder containing the images to run inference on.
    :param project: directory the inference results are written under.
    :param name: version name of this run (forwarded to ``predict_folder``).
    :return: None; the return value of ``predict_folder`` is discarded.
    """
    YoloModel(weights).predict_folder(source=source, project=project, name=name)
|
|
|
|
|
|
async def update_sql(db: AsyncSession, detect_id: int, log_id: int, project, name):
    """
    Record a finished inference run in the database.

    Increments the inference set's ``detect_version`` by one and saves it,
    then creates one ``ProjectDetectLogFile`` per ``ProjectDetectFile`` that
    belongs to the set. Each new row points at the result image built from
    ``project``, ``name`` and the original file name.

    :param db: async database session.
    :param detect_id: id of the inference set (``ProjectDetect``).
    :param log_id: id of the ``ProjectDetectLog`` the new file rows belong to.
    :param project: output directory that was passed to the inference run.
    :param name: version sub-directory name of that run.
    """
    set_dal = crud.ProjectDetectDal(db)
    detect_set = await set_dal.get_data(detect_id)
    detect_set.detect_version += 1
    await set_dal.put_data(data_id=detect_id, data=detect_set)

    # NOTE(review): limit=0 presumably means "no limit" in this DAL — confirm.
    source_files = await crud.ProjectDetectFileDal(db).get_datas(
        limit=0,
        v_where=[models.ProjectDetectFile.detect_id == detect_id],
        v_return_objs=True,
        v_return_count=False,
    )

    def _to_log_file(source_file):
        # Build one log-file row for a single source file of the set.
        record = models.ProjectDetectLogFile()
        record.log_id = log_id
        record.file_url = os.file_path(project, name, source_file.file_name)
        record.file_name = source_file.file_name
        return record

    log_file_rows = [_to_log_file(item) for item in source_files]
    await crud.ProjectDetectLogFileDal(db).create_models(log_file_rows)
|
|
|
|
|
|
def run_detect_rtsp(weights_pt: str, rtsp_url: str, room_name: str):
    """
    Run inference on an RTSP video stream.

    :param weights_pt: path of the weights file used to build the model.
    :param rtsp_url: address of the RTSP video stream.
    :param room_name: websocket connection name, forwarded to ``predict_rtsp``.
    :return: None; the return value of ``predict_rtsp`` is discarded.
    """
    YoloModel(weights_pt).predict_rtsp(rtsp_url, room_name)