This commit is contained in:
wudong 2022-11-08 13:37:30 +08:00
commit 16c9fc98d4

View File

@ -29,8 +29,10 @@ import os
import platform
import sys
from pathlib import Path
from app.schemas.TrainResult import DetectReport, DetectProcessValueDice
from app.controller.AlgorithmController import algorithm_process_value_websocket
import torch
from datetime import datetime
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0] # YOLOv5 root directory
@ -113,14 +115,28 @@ def run(id,
# Run inference
model.warmup(imgsz=(1 if pt else bs, 3, *imgsz)) # warmup
seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
#回调函数参数定义
report = DetectReport(rate_of_progess=0, precision=[], id=id)
@algorithm_process_value_websocket()
def report_cellback(i, num_epochs, ori_img, res_img):
report.rate_of_progess = ((i + 1) / num_epochs) * 100
#report.progress = (i + 1)
report.end_time = datetime.now()
#report.precision[0].value.append(reportAccu)
process_value_list = DetectProcessValueDice(ori_img=ori_img, res_img=res_img)
report.precision.append(process_value_list)
return report.dict()
#######定义声明完成##################
count = 0
for path, im, im0s, vid_cap, s in dataset:
count = count + 1
with dt[0]:
im = torch.from_numpy(im).to(device)
im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
im /= 255 # 0 - 255 to 0.0 - 1.0
if len(im.shape) == 3:
im = im[None] # expand for batch dim
# Inference
with dt[1]:
visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
@ -206,9 +222,9 @@ def run(id,
vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
vid_writer[i].write(im0)
# 原始图像路径/结果图像路径 传参
#online_img_tools.get_res_img(res_path=save_path, img_path=path, proj_no=pro)
report_cellback(count,len(os.listdir(source)),path,save_path)
# Print time (inference-only)
LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
#######统计检测结果:图片总数量,成功数量,失败数量