Files
aicheckv2/app/service/project_detect_service.py

291 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import time
from sqlalchemy.orm import Session
from typing import List
from fastapi import UploadFile
import subprocess
import torch
from app.util.yolov5.models.common import DetectMultiBackend
from app.util.yolov5.utils.torch_utils import select_device
from app.util.yolov5.utils.dataloaders import LoadStreams
from app.util.yolov5.utils.general import check_img_size, Profile, non_max_suppression, cv2, scale_boxes
from ultralytics.utils.plotting import Annotator, colors
from app.model.crud import project_detect_crud as pdc
from app.model.schemas.project_detect_schemas import ProjectDetectIn, ProjectDetectOut, ProjectDetectLogIn
from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectTrain, ProjectDetectLog, \
ProjectDetectLogImg
from app.util.random_utils import random_str
from app.config.config_reader import detect_url
from app.util import os_utils as os
from app.util import random_utils as ru
from app.config.config_reader import yolo_url
from app.websocket.web_socket_server import room_manager
from app.common.redis_cli import redis_conn
def add_detect(detect_in: ProjectDetectIn, session: Session):
    """
    Create a new detect (inference) set record together with its image folder.

    The record receives a random 6-character ``detect_no``, version 0 and status '0'.
    Its ``folder_url`` is whatever ``os.create_folder(detect_url, detect_no, 'images')``
    returns.

    :param detect_in: incoming detect-set fields
    :param session: database session
    :return: the record returned by ``pdc.add_detect``
    """
    new_detect = ProjectDetect(**detect_in.dict())
    new_detect.detect_no = random_str(6)
    new_detect.detect_version = 0
    new_detect.detect_status = '0'
    new_detect.folder_url = os.create_folder(detect_url, new_detect.detect_no, 'images')
    return pdc.add_detect(new_detect, session)
def del_detect(detect_id: int, session: Session):
    """
    Delete a detect set and, for non-RTSP sets, remove its folders from disk.

    The folders removed are the set's own ``folder_url`` and the ``detect_folder_url``
    of every detect log returned by ``pdc.get_logs``. Folder removal is best-effort and
    happens only after the database commit has succeeded.

    NOTE(review): detect-log rows themselves are not deleted here unless the ORM
    relationship cascades the delete - confirm against the model definitions.

    :param detect_id: id of the detect set
    :param session: database session
    :return: None (also when no detect set with this id exists)
    """
    import shutil

    detect = pdc.get_detect_by_id(detect_id, session)
    if detect is None:
        # Nothing to delete; session.delete(None) would raise.
        return

    folder_urls = []
    if detect.file_type != 'rtsp':
        folder_urls.append(detect.folder_url)
        # Query the logs before deleting the set, so an autoflush/cascade of the
        # delete cannot hide them from us.
        folder_urls.extend(log.detect_folder_url for log in pdc.get_logs(detect_id, session))

    session.delete(detect)
    session.commit()

    # Remove the files only once the rows are gone, so a failed commit never leaves
    # records pointing at deleted folders. Best-effort: missing folders are ignored.
    for folder in folder_urls:
        if folder:
            shutil.rmtree(folder, ignore_errors=True)
def check_image_name(detect_id: int, files: List[UploadFile], session: Session):
    """
    Check that none of the uploaded file names already exist in the detect set.

    Files are checked in order with ``pdc.check_img_name`` and the check stops at the
    first rejected name.

    :param detect_id: id of the detect set
    :param files: uploaded files
    :param session: database session
    :return: ``(True, None)`` if every name passes, otherwise ``(False, <first rejected name>)``
    """
    for upload in files:
        candidate = upload.filename
        name_ok = pdc.check_img_name(detect_id, candidate, session)
        if not name_ok:
            return False, candidate
    return True, None
def upload_detect_imgs(detect: ProjectDetectOut, files: List[UploadFile], session: Session):
    """
    Store uploaded images in a detect set and generate a thumbnail for each one.

    Every upload is saved under ``detect.folder_url`` with ``os.save_images``. A thumbnail
    named with 10 random characters plus ``.jpg`` is created under the ``thumb`` sub-path.
    The resulting ``ProjectDetectImg`` rows are persisted together via ``pdc.add_detect_imgs``.

    :param detect: target detect set
    :param files: uploaded image files
    :param session: database session
    :return: None
    """
    def to_record(upload: UploadFile) -> ProjectDetectImg:
        # Save the original first; the thumbnail is derived from the saved file.
        original_path = os.save_images(detect.folder_url, file=upload)
        thumb_path = os.file_path(detect.folder_url, 'thumb', ru.random_str(10) + ".jpg")
        os.create_thumbnail(original_path, thumb_path)

        record = ProjectDetectImg()
        record.detect_id = detect.id
        record.file_name = upload.filename
        record.image_url = original_path
        record.thumb_image_url = thumb_path
        return record

    pdc.add_detect_imgs([to_record(upload) for upload in files], session)
def del_detect_img(detect_img_id: int, session: Session):
    """
    Delete one detect-set image: its image and thumbnail files, then its database row.

    :param detect_img_id: id of the ProjectDetectImg row
    :param session: database session
    :return: 1 if the image existed and was deleted, 0 if no row matched the id
    """
    target = session.query(ProjectDetectImg).filter_by(id=detect_img_id).first()
    if target is not None:
        os.delete_file_if_exists(target.image_url, target.thumb_image_url)
        session.delete(target)
        session.commit()
        return 1
    return 0
def run_detect_yolo(detect_in: ProjectDetectLogIn, detect: ProjectDetect, train: ProjectTrain, session: Session):
    """
    Create and persist the detect-log record for the next detect run of a set.

    This does not launch inference itself. It only records the parameters of the run:
    - the version label ``'v' + (detect.detect_version + 1)``;
    - the chosen weights (``train.best_pt`` when ``pt_type == 'best'``, else ``train.last_pt``);
    - the input folder (``detect.folder_url``);
    - the output folder built from ``(detect_url, detect_no, 'detect')``.

    :param detect_in: request carrying ``pt_type``
    :param detect: the detect set being run
    :param train: the training run whose weights are used
    :param session: database session
    :return: the record returned by ``pdc.add_detect_log``
    """
    next_version = 'v' + str(detect.detect_version + 1)
    use_best = detect_in.pt_type == 'best'
    weights = train.best_pt if use_best else train.last_pt
    source_dir = detect.folder_url
    output_dir = os.file_path(detect_url, detect.detect_no, 'detect')

    log = ProjectDetectLog()
    log.detect_id = detect.id
    log.detect_name = detect.detect_name
    log.detect_version = next_version
    log.train_id = train.id
    log.train_version = train.train_version
    log.pt_type = detect_in.pt_type
    log.pt_url = weights
    log.folder_url = source_dir
    log.detect_folder_url = output_dir
    return pdc.add_detect_log(log, session)
async def run_detect_img(weights: str, source: str, project: str, name: str, log_id: int, detect_id: int,
                         session: Session):
    """
    Run yolov5 ``detect.py`` in a subprocess and stream its output to a websocket room.

    Every non-blank output line (stdout and stderr merged) is forwarded to room
    ``detect_<detect_id>``.

    On exit code 0:
    - the room receives ``'success'``;
    - the detect set status is set to 2;
    - one ``ProjectDetectLogImg`` per image of the set is stored for ``log_id``, with
      ``image_url`` built from ``(project, name, file_name)``.

    On a non-zero exit code the detect set status is set to -1.

    :param weights: weights file
    :param source: folder containing the images
    :param project: output root folder passed to detect.py
    :param name: run (version) name, used by detect.py as output sub-folder
    :param log_id: detect log id
    :param detect_id: detect set id
    :param session: database session
    :return: None
    """
    import asyncio

    yolo_path = os.file_path(yolo_url, 'detect.py')
    room = 'detect_' + str(detect_id)
    # NOTE(review): the message says "training started" although this runs detection.
    await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
    command = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project",
               project, "--save-txt", "--conf-thres", "0.4"]
    # NOTE(review): assumes redis returns str (decode_responses=True); with bytes this
    # comparison is always False - confirm the redis client configuration.
    if redis_conn.get('is_gpu') == 'True':
        # list.append accepts a single item; append("--device", "0") raised TypeError.
        command.extend(["--device", "0"])

    loop = asyncio.get_running_loop()
    with subprocess.Popen(
            command,
            bufsize=1,  # line-buffered (valid because text=True)
            shell=False,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr so yolov5 progress output is forwarded too
            text=True,  # decode output as text to avoid encoding issues downstream
            encoding='utf-8',
    ) as process:
        while True:
            # readline() blocks; run it in the default executor so the event loop
            # (and the websocket traffic it serves) keeps running meanwhile.
            line = await loop.run_in_executor(None, process.stdout.readline)
            if not line:
                # EOF. Reading until EOF, rather than until poll() reports exit,
                # guarantees the final buffered lines are not dropped.
                break
            if line != '\n':
                await room_manager.send_to_room(room, line + '\n')
        return_code = process.wait()

    if return_code != 0:
        pdc.update_detect_status(detect_id, -1, session)
        return

    await room_manager.send_to_room(room, 'success')
    pdc.update_detect_status(detect_id, 2, session)
    detect_log_imgs = []
    for detect_img in pdc.get_img_list(detect_id, session):
        detect_log_img = ProjectDetectLogImg()
        detect_log_img.log_id = log_id
        detect_log_img.image_url = os.file_path(project, name, detect_img.file_name)
        detect_log_img.file_name = detect_img.file_name
        detect_log_imgs.append(detect_log_img)
    pdc.add_detect_imgs(detect_log_imgs, session)
async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: int):
    """
    Run YOLO inference on an RTSP stream and push annotated JPEG frames to a websocket room.

    Frames go to room ``detect_rtsp_<detect_id>``. The room is re-checked in
    ``room_manager.rooms`` before every frame, and the loop stops at the first frame
    for which the room is missing or empty.

    :param weights_pt: weights file
    :param rtsp_url: stream URL
    :param data: dataset yaml file
    :param detect_id: id of the detect set
    :return: None
    """
    import asyncio

    room = 'detect_rtsp_' + str(detect_id)
    # Select the device exactly once. Upstream yolov5's select_device('cpu') sets
    # CUDA_VISIBLE_DEVICES=-1 for the whole process, so calling it before
    # select_device('cuda:0') could hide the GPU (and leak into child processes).
    # Confirm this against the vendored app.util.yolov5 copy.
    is_gpu = redis_conn.get('is_gpu')
    device = select_device('cuda:0' if is_gpu == 'True' else 'cpu')

    model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
    stride, names, pt = model.stride, model.names, model.pt
    imgsz = check_img_size((640, 640), s=stride)
    dataset = LoadStreams(rtsp_url, img_size=imgsz, stride=stride, auto=pt, vid_stride=1)
    bs = len(dataset)
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))

    # Give the websocket client time to join the room. Awaiting (instead of
    # time.sleep) lets the event loop actually process that join meanwhile.
    await asyncio.sleep(3)

    for _path, im, im0s, _vid_cap, _s in dataset:
        if not room_manager.rooms.get(room):
            print(room, '结束推理')
            break

        im = torch.from_numpy(im).to(model.device)
        im = im.half() if model.fp16 else im.float()  # uint8 -> fp16/32
        im /= 255  # 0-255 -> 0.0-1.0
        if len(im.shape) == 3:
            im = im[None]  # add batch dimension

        # no_grad: inference only, no autograd graph needed. Kept local so the
        # grad mode is not held across the awaits below.
        with torch.no_grad():
            if model.xml and im.shape[0] > 1:
                # xml backend: run the model one image at a time and stack the results.
                stacked = torch.cat(
                    [model(chunk, augment=False, visualize=False).unsqueeze(0)
                     for chunk in torch.chunk(im, im.shape[0], 0)],
                    dim=0,
                )
                pred = [stacked, None]
            else:
                pred = model(im, augment=False, visualize=False)

        # Positional args: conf_thres=0.45, iou_thres=0.45, classes=None, agnostic=False
        pred = non_max_suppression(pred, 0.45, 0.45, None, False, max_det=1000)

        for i, det in enumerate(pred):  # one entry per image in the batch
            im0 = im0s[i].copy()
            annotator = Annotator(im0, line_width=3, example=str(names))
            if len(det):
                # Rescale boxes from the network input size back to the original frame size.
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
                for *xyxy, conf, cls in reversed(det):
                    c = int(cls)
                    annotator.box_label(xyxy, f"{names[c]} {conf:.2f}", color=colors(c, True))
            encoded_ok, jpeg = cv2.imencode('.jpg', annotator.result())
            if encoded_ok:
                await room_manager.send_stream_to_room(room, jpeg.tobytes())