修改一些小问题

This commit is contained in:
2025-06-05 08:45:44 +08:00
parent f32cd5b9a2
commit 0cd3d914e9
7 changed files with 31 additions and 84 deletions

View File

@ -1,7 +1,6 @@
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import asyncio import asyncio
import torch
from app.application.token_middleware import TokenMiddleware from app.application.token_middleware import TokenMiddleware
from app.application.logger_middleware import LoggerMiddleware from app.application.logger_middleware import LoggerMiddleware

View File

@ -54,7 +54,7 @@ def del_detect(detect_id: int, session: Session):
detect_logs = pdc.get_logs(detect_id, session) detect_logs = pdc.get_logs(detect_id, session)
for log in detect_logs: for log in detect_logs:
folder_url.append(log.detect_folder_url) folder_url.append(log.detect_folder_url)
os.create_folder(folder_url) os.delete_paths(folder_url)
session.commit() session.commit()
@ -168,7 +168,7 @@ async def run_detect_img(weights: str, source: str, project: str, name: str, log
is_gpu = redis_conn.get('is_gpu') is_gpu = redis_conn.get('is_gpu')
# 判断是否存在cuda版本 # 判断是否存在cuda版本
if is_gpu == 'True': if is_gpu == 'True':
commend.append("--device", "0") commend.append("--device=0")
# 启动子进程 # 启动子进程
with subprocess.Popen( with subprocess.Popen(
commend, commend,
@ -225,12 +225,12 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False) model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
stride, names, pt = model.stride, model.names, model.pt stride, names, pt = model.stride, model.names, model.pt
imgsz = check_img_size((640, 640), s=stride) # check image size img_sz = check_img_size((640, 640), s=stride) # check image size
dataset = LoadStreams(rtsp_url, img_size=imgsz, stride=stride, auto=pt, vid_stride=1) dataset = LoadStreams(rtsp_url, img_size=img_sz, stride=stride, auto=pt, vid_stride=1)
bs = len(dataset) bs = len(dataset)
model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *img_sz))
seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device)) seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
@ -244,22 +244,11 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
im /= 255 # 0 - 255 to 0.0 - 1.0 im /= 255 # 0 - 255 to 0.0 - 1.0
if len(im.shape) == 3: if len(im.shape) == 3:
im = im[None] # expand for batch dim im = im[None] # expand for batch dim
if model.xml and im.shape[0] > 1:
ims = torch.chunk(im, im.shape[0], 0)
# Inference # Inference
with dt[1]: with dt[1]:
if model.xml and im.shape[0] > 1: pred = model(im, augment=False, visualize=False)
pred = None
for image in ims:
if pred is None:
pred = model(image, augment=False, visualize=False).unsqueeze(0)
else:
pred = torch.cat((pred, model(image, augment=False, visualize=False).unsqueeze(0)),
dim=0)
pred = [pred, None]
else:
pred = model(im, augment=False, visualize=False)
# NMS # NMS
with dt[2]: with dt[2]:
pred = non_max_suppression(pred, 0.45, 0.45, None, False, max_det=1000) pred = non_max_suppression(pred, 0.45, 0.45, None, False, max_det=1000)
@ -286,5 +275,5 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
frame_data = jpeg.tobytes() frame_data = jpeg.tobytes()
await room_manager.send_stream_to_room(room, frame_data) await room_manager.send_stream_to_room(room, frame_data)
else: else:
print(room, '结束推理'); print(room, '结束推理')
break break

View File

@ -88,7 +88,6 @@ def del_img(image_id: int, session: Session):
session.commit() session.commit()
def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session): def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session):
""" """
保存图片的标签框选信息,每次保存都会针对图片的信息全部删除,然后重新保存 保存图片的标签框选信息,每次保存都会针对图片的信息全部删除,然后重新保存
@ -223,7 +222,7 @@ async def run_commend(data: str, project: str, name: str, epochs: int, patience:
while process.poll() is None: while process.poll() is None:
line = process.stdout.readline() line = process.stdout.readline()
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死 process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
if line != '\n' and '0%' not in line: if line != '\n' and '0%' not in line and 'yolo' not in line and 'YOLO' not in line:
await room_manager.send_to_room(room, line + '\n') await room_manager.send_to_room(room, line + '\n')
# 等待进程结束并获取返回码 # 等待进程结束并获取返回码
@ -232,24 +231,24 @@ async def run_commend(data: str, project: str, name: str, epochs: int, patience:
pic.update_project_status(project_id, '-1', session) pic.update_project_status(project_id, '-1', session)
else: else:
await room_manager.send_to_room(room, 'success') await room_manager.send_to_room(room, 'success')
pic.update_project_status(project_id, '2', session) pic.update_project_status(project_id, '2', session)
# 然后保存版本训练信息 # 然后保存版本训练信息
train = ProjectTrain() train = ProjectTrain()
train.project_id = project_id train.project_id = project_id
train.train_version = name train.train_version = name
train_url = os.file_path(project, name) train_url = os.file_path(project, name)
train.train_url = train_url train.train_url = train_url
train.train_data = data train.train_data = data
bast_pt_path = os.file_path(train_url, 'weights', 'best.pt') bast_pt_path = os.file_path(train_url, 'weights', 'best.pt')
last_pt_path = os.file_path(train_url, 'weights', 'last.pt') last_pt_path = os.file_path(train_url, 'weights', 'last.pt')
train.best_pt = bast_pt_path train.best_pt = bast_pt_path
train.last_pt = last_pt_path train.last_pt = last_pt_path
if weights != None and weights != '': if weights != None and weights != '':
train.weights_id = weights train.weights_id = weights
train.weights_name = train_info.train_version train.weights_name = train_info.train_version
train.patience = patience train.patience = patience
train.epochs = epochs train.epochs = epochs
ptc.add_train(train, session) ptc.add_train(train, session)
def operate_img_label(img_list: List[ProjectImgLabel], def operate_img_label(img_list: List[ProjectImgLabel],

View File

@ -15,8 +15,6 @@ class SocketManager:
self.rooms[room].remove(websocket) self.rooms[room].remove(websocket)
if len(self.rooms[room]) == 0: if len(self.rooms[room]) == 0:
del self.rooms[room] del self.rooms[room]
if room.startswith('detect_rtsp_'):
print()
async def broadcast_to_room(self, room: str, message: str, exclude_websocket: WebSocket = None): async def broadcast_to_room(self, room: str, message: str, exclude_websocket: WebSocket = None):
if room in self.rooms: if room in self.rooms:

View File

@ -1,38 +0,0 @@
# deep_sort ----------------------这个暂时还不需要
atomicwrites==1.3.0
attrs==19.3.0
colorama==0.4.3
easydict~=1.13
entrypoints==0.3
et-xmlfile==1.0.1
flake8==3.7.9
flake8-import-order==0.18.1
importlib-metadata==1.6.0
jdcal==1.4.1
joblib==1.2.0
lap==0.4.0
mccabe==0.6.1
more-itertools==8.2.0
motmetrics==1.2.0
openpyxl==3.0.3
packaging==20.3
pluggy==0.13.1
py==1.10.0
py-cpuinfo==5.0.0
pycodestyle==2.5.0
pyflakes==2.1.1
pyparsing==2.4.7
pytest==5.4.1
pytest-benchmark==3.2.3
python-dateutil==2.8.1
pytz==2019.3
scikit-learn==1.6.1
six==1.14.0
sklearn==0.0
Vizer==0.1.5
wcwidth==0.1.9
xmltodict==0.12.0
zipp==3.1.0
mmdet~=3.0.0
pycocotools~=2.0.6
python-dotenv~=0.21.0

View File

@ -32,9 +32,9 @@ PyYAML>=5.3.1
requests>=2.32.2 requests>=2.32.2
scipy==1.13.1 scipy==1.13.1
thop>=0.1.1 # FLOPs computation thop>=0.1.1 # FLOPs computation
# torch==2.6.0+cu124 # 本地安装 # torch==2.6.0+cu124
# torchvision==0.21.0+cu124 # 本地安装 # torchvision==0.21.0+cu124
tqdm>=4.66.3 tqdm>=4.66.3
ultralytics==8.3.75 # https://ultralytics.com ultralytics==8.3.75
pandas==2.2.3 pandas==2.2.3
seaborn==0.11.2 # 对应的这个依赖也需要降 seaborn==0.11.2 # 对应的这个依赖也需要降

View File

@ -485,7 +485,7 @@ def train(hyp, opt, device, callbacks):
"ema": deepcopy(ema.ema).half(), "ema": deepcopy(ema.ema).half(),
"updates": ema.updates, "updates": ema.updates,
"optimizer": optimizer.state_dict(), "optimizer": optimizer.state_dict(),
"opt": vars(opt), # "opt": vars(opt),
"git": GIT_INFO, # {remote, branch, commit} if a git repo "git": GIT_INFO, # {remote, branch, commit} if a git repo
"date": datetime.now().isoformat(), "date": datetime.now().isoformat(),
} }