Fix some minor issues
@@ -1,7 +1,6 @@
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 import asyncio
-import torch
 
 from app.application.token_middleware import TokenMiddleware
 from app.application.logger_middleware import LoggerMiddleware
@@ -54,7 +54,7 @@ def del_detect(detect_id: int, session: Session):
     detect_logs = pdc.get_logs(detect_id, session)
     for log in detect_logs:
         folder_url.append(log.detect_folder_url)
-    os.create_folder(folder_url)
+    os.delete_paths(folder_url)
     session.commit()
 
 
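
Note on this hunk: the `os` used here is evidently a project utility module rather than Python's standard `os`, which has no `create_folder` or `delete_paths`. A minimal sketch of what a `delete_paths` helper could look like, purely as an assumption about its intent (remove each detect result folder, skipping paths that no longer exist):

import os
import shutil
from typing import Iterable


def delete_paths(paths: Iterable[str]) -> None:
    """Delete every file or directory in `paths`, skipping entries that are already gone."""
    for path in paths:
        if not os.path.exists(path):
            continue
        if os.path.isdir(path):
            shutil.rmtree(path, ignore_errors=True)  # remove the whole result folder
        else:
            os.remove(path)  # remove a single file
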
@@ -168,7 +168,7 @@ async def run_detect_img(weights: str, source: str, project: str, name: str, log
     is_gpu = redis_conn.get('is_gpu')
     # Check whether a CUDA build is available
     if is_gpu == 'True':
-        commend.append("--device", "0")
+        commend.append("--device=0")
     # Start the subprocess
     with subprocess.Popen(
             commend,
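
Note on this hunk: `list.append` takes exactly one argument, so the old `commend.append("--device", "0")` raises `TypeError: append() takes exactly one argument (2 given)`; the fix passes a single `--device=0` token instead. A small sketch of the same fix; the command list below is illustrative only, the project's real flags and paths will differ:

# Illustrative command list, not the project's actual arguments.
commend = ["python", "detect.py", "--weights", "best.pt", "--source", "images/"]

is_gpu = "True"  # stand-in for redis_conn.get('is_gpu')
if is_gpu == "True":
    commend.append("--device=0")  # single token, as in the fix
    # an equivalent alternative: commend.extend(["--device", "0"])

print(" ".join(commend))
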
@@ -225,12 +225,12 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
     model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
 
     stride, names, pt = model.stride, model.names, model.pt
-    imgsz = check_img_size((640, 640), s=stride)  # check image size
+    img_sz = check_img_size((640, 640), s=stride)  # check image size
 
-    dataset = LoadStreams(rtsp_url, img_size=imgsz, stride=stride, auto=pt, vid_stride=1)
+    dataset = LoadStreams(rtsp_url, img_size=img_sz, stride=stride, auto=pt, vid_stride=1)
     bs = len(dataset)
 
-    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))
+    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *img_sz))
 
     seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
 
@@ -244,22 +244,11 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
            im /= 255  # 0 - 255 to 0.0 - 1.0
            if len(im.shape) == 3:
                im = im[None]  # expand for batch dim
-            if model.xml and im.shape[0] > 1:
-                ims = torch.chunk(im, im.shape[0], 0)
 
        # Inference
        with dt[1]:
-            if model.xml and im.shape[0] > 1:
-                pred = None
-                for image in ims:
-                    if pred is None:
-                        pred = model(image, augment=False, visualize=False).unsqueeze(0)
-                    else:
-                        pred = torch.cat((pred, model(image, augment=False, visualize=False).unsqueeze(0)),
-                                         dim=0)
-                pred = [pred, None]
-            else:
-                pred = model(im, augment=False, visualize=False)
+            pred = model(im, augment=False, visualize=False)
 
        # NMS
        with dt[2]:
            pred = non_max_suppression(pred, 0.45, 0.45, None, False, max_det=1000)
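
Note on this hunk: the removed branch was the special case for batched OpenVINO (`model.xml`) inputs, which chunks the batch and runs one image at a time; after the change every input goes through the single `model(im, ...)` call. For reference, the positional NMS arguments used above map to the keyword parameters in this minimal sketch, which assumes the YOLOv5 repo is on the import path and uses a random tensor as a stand-in for a real prediction:

import torch
from utils.general import non_max_suppression  # assumes the YOLOv5 repo is importable

# Dummy raw prediction: batch of 1, 100 candidate boxes, 5 box/objectness values + 80 class scores.
pred = torch.rand(1, 100, 85)

# Keyword form of the positional call in the hunk above.
det = non_max_suppression(pred, conf_thres=0.45, iou_thres=0.45,
                          classes=None, agnostic=False, max_det=1000)
print(det[0].shape)  # one (n, 6) tensor per image: x1, y1, x2, y2, confidence, class
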
@@ -286,5 +275,5 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
                 frame_data = jpeg.tobytes()
                 await room_manager.send_stream_to_room(room, frame_data)
             else:
-                print(room, '结束推理');
+                print(room, '结束推理')
                 break
@@ -88,7 +88,6 @@ def del_img(image_id: int, session: Session):
     session.commit()
 
 
-
 def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session):
     """
     Save the image's label bounding-box info; each save first deletes all of the image's existing info, then re-saves it
@@ -223,7 +222,7 @@ async def run_commend(data: str, project: str, name: str, epochs: int, patience:
     while process.poll() is None:
         line = process.stdout.readline()
         process.stdout.flush()  # flush the buffer to keep excessive buffering from causing a hang
-        if line != '\n' and '0%' not in line:
+        if line != '\n' and '0%' not in line and 'yolo' not in line and 'YOLO' not in line:
             await room_manager.send_to_room(room, line + '\n')
 
     # Wait for the process to finish and get its return code
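
Note on this hunk: the training subprocess's stdout is read line by line and forwarded to a websocket room, and the extended condition now also drops YOLO banner lines in addition to empty lines and progress-bar ('0%') output. A minimal sketch of the same read-filter-forward pattern, assuming a generic async `send` callback in place of `room_manager.send_to_room` and a throwaway command in place of the real training command:

import asyncio
import subprocess


async def stream_filtered_output(cmd, send):
    """Run cmd and forward its stdout line by line, skipping empty, progress-bar and banner lines."""
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                               text=True, bufsize=1)
    while process.poll() is None:
        line = process.stdout.readline()
        if line != '\n' and '0%' not in line and 'yolo' not in line and 'YOLO' not in line:
            await send(line)
    return process.wait()  # return code, checked afterwards as in the next hunk


async def main():
    async def send(text):  # stand-in for room_manager.send_to_room(room, text)
        print(text, end='')

    cmd = ['python', '-c', 'import time; print("epoch 1/10", flush=True); time.sleep(0.5)']
    code = await stream_filtered_output(cmd, send)
    print('return code:', code)


asyncio.run(main())
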
@@ -232,24 +231,24 @@ async def run_commend(data: str, project: str, name: str, epochs: int, patience:
         pic.update_project_status(project_id, '-1', session)
     else:
         await room_manager.send_to_room(room, 'success')
         pic.update_project_status(project_id, '2', session)
         # Then save the training version info
         train = ProjectTrain()
         train.project_id = project_id
         train.train_version = name
         train_url = os.file_path(project, name)
         train.train_url = train_url
         train.train_data = data
         bast_pt_path = os.file_path(train_url, 'weights', 'best.pt')
         last_pt_path = os.file_path(train_url, 'weights', 'last.pt')
         train.best_pt = bast_pt_path
         train.last_pt = last_pt_path
         if weights != None and weights != '':
             train.weights_id = weights
             train.weights_name = train_info.train_version
         train.patience = patience
         train.epochs = epochs
         ptc.add_train(train, session)
 
 
 def operate_img_label(img_list: List[ProjectImgLabel],
@@ -15,8 +15,6 @@ class SocketManager:
            self.rooms[room].remove(websocket)
            if len(self.rooms[room]) == 0:
                del self.rooms[room]
-            if room.startswith('detect_rtsp_'):
-                print()
 
    async def broadcast_to_room(self, room: str, message: str, exclude_websocket: WebSocket = None):
        if room in self.rooms:
@@ -1,38 +0,0 @@
-# deep_sort ---------------------- not needed yet
-atomicwrites==1.3.0
-attrs==19.3.0
-colorama==0.4.3
-easydict~=1.13
-entrypoints==0.3
-et-xmlfile==1.0.1
-flake8==3.7.9
-flake8-import-order==0.18.1
-importlib-metadata==1.6.0
-jdcal==1.4.1
-joblib==1.2.0
-lap==0.4.0
-mccabe==0.6.1
-more-itertools==8.2.0
-motmetrics==1.2.0
-openpyxl==3.0.3
-packaging==20.3
-pluggy==0.13.1
-py==1.10.0
-py-cpuinfo==5.0.0
-pycodestyle==2.5.0
-pyflakes==2.1.1
-pyparsing==2.4.7
-pytest==5.4.1
-pytest-benchmark==3.2.3
-python-dateutil==2.8.1
-pytz==2019.3
-scikit-learn==1.6.1
-six==1.14.0
-sklearn==0.0
-Vizer==0.1.5
-wcwidth==0.1.9
-xmltodict==0.12.0
-zipp==3.1.0
-mmdet~=3.0.0
-pycocotools~=2.0.6
-python-dotenv~=0.21.0
@@ -32,9 +32,9 @@ PyYAML>=5.3.1
 requests>=2.32.2
 scipy==1.13.1
 thop>=0.1.1  # FLOPs computation
-# torch==2.6.0+cu124  # install locally
-# torchvision==0.21.0+cu124  # install locally
+# torch==2.6.0+cu124
+# torchvision==0.21.0+cu124
 tqdm>=4.66.3
-ultralytics==8.3.75  # https://ultralytics.com
+ultralytics==8.3.75
 pandas==2.2.3
 seaborn==0.11.2  # the corresponding dependency also needs to be downgraded
@@ -485,7 +485,7 @@ def train(hyp, opt, device, callbacks):
                "ema": deepcopy(ema.ema).half(),
                "updates": ema.updates,
                "optimizer": optimizer.state_dict(),
-                "opt": vars(opt),
+                # "opt": vars(opt),
                "git": GIT_INFO,  # {remote, branch, commit} if a git repo
                "date": datetime.now().isoformat(),
            }
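
Note on this hunk: with `"opt": vars(opt)` commented out, checkpoints written by this train.py no longer embed the argparse options, so downstream code that reads `ckpt["opt"]` (for example when resuming or inspecting hyperparameters) has to tolerate the missing key. A minimal defensive-loading sketch; the checkpoint path is a placeholder, not a file from this repo:

import torch

# weights_only=False because YOLOv5 checkpoints store full model objects, not just tensors.
ckpt = torch.load("runs/train/exp/weights/best.pt", map_location="cpu", weights_only=False)

opt = ckpt.get("opt", {})  # "opt" may be absent after this change
print("epoch:", ckpt.get("epoch"), "| opt stored:", bool(opt))
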