增加删除推理集合接口

This commit is contained in:
2025-03-14 17:54:21 +08:00
parent 9e79fb6a6d
commit af65911db3
4 changed files with 123 additions and 52 deletions

View File

@ -2,10 +2,10 @@ import threading
import asyncio import asyncio
from typing import List from typing import List
from fastapi import APIRouter, Depends, UploadFile, File, Form from fastapi import APIRouter, Depends, UploadFile, File, Form
from fastapi.responses import StreamingResponse
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.websocket.web_socket_server import room_manager
from app.db.db_session import get_db from app.db.db_session import get_db
from app.common import reponse_code as rc from app.common import reponse_code as rc
from app.model.crud import project_detect_crud as pdc from app.model.crud import project_detect_crud as pdc
@ -52,6 +52,18 @@ def add_detect(detect_in: ProjectDetectIn, session: Session = Depends(get_db)):
return rc.response_success(msg="新增成功", data=detect.id) return rc.response_success(msg="新增成功", data=detect.id)
@detect.get("/del_detect/{detect_id}")
def del_detect(detect_id: int, session: Session = Depends(get_db)):
"""
删除训练集合
:param detect_id:
:param session:
:return:
"""
pds.del_detect(detect_id, session)
return rc.response_success(msg="删除成功")
@detect.post("/get_img_list") @detect.post("/get_img_list")
def get_img_list(detect_img_pager: ProjectDetectImgPager, session: Session = Depends(get_db)): def get_img_list(detect_img_pager: ProjectDetectImgPager, session: Session = Depends(get_db)):
""" """
@ -127,6 +139,8 @@ def run_detect_yolo(detect_log_in: ProjectDetectLogIn, session: Session = Depend
detect_log.detect_version, detect_log.id, detect_log.detect_id, session,)) detect_log.detect_version, detect_log.id, detect_log.detect_id, session,))
thread_train.start() thread_train.start()
elif detect.file_type == 'rtsp': elif detect.file_type == 'rtsp':
room = 'detect_rtsp_' + str(detect.id)
if not room_manager.rooms.get(room):
if detect_log_in.pt_type == 'best': if detect_log_in.pt_type == 'best':
weights_pt = train.best_pt weights_pt = train.best_pt
else: else:

View File

@ -190,6 +190,17 @@ def get_log_list(detect_id: int, session: Session):
return result return result
def get_logs(detect_id: int, session: Session):
"""
获取推理记录
:param detect_id:
:param session:
:return:
"""
query = session.query(ProjectDetectLog).filter_by(detect_id=detect_id).order_by(asc(ProjectDetectLog.id))
return query.all()
def get_log_pager(detect_log_pager: ProjectDetectLogPager, session: Session): def get_log_pager(detect_log_pager: ProjectDetectLogPager, session: Session):
""" """
获取分页数据 获取分页数据

View File

@ -1,3 +1,5 @@
import time
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List from typing import List
from fastapi import UploadFile from fastapi import UploadFile
@ -37,6 +39,24 @@ def add_detect(detect_in: ProjectDetectIn, session: Session):
return detect return detect
def del_detect(detect_id: int, session: Session):
"""
删除推理集合和推理记录
:param detect_id:
:param session:
:return:
"""
folder_url = []
detect = pdc.get_detect_by_id(detect_id, session)
session.delete(detect)
folder_url.append(detect.folder_url)
detect_logs = pdc.get_logs(detect_id, session)
for log in detect_logs:
folder_url.append(log.detect_folder_url)
os.create_folder(folder_url)
session.commit()
def check_image_name(detect_id: int, files: List[UploadFile], session: Session): def check_image_name(detect_id: int, files: List[UploadFile], session: Session):
""" """
校验上传的文件名称是否重复 校验上传的文件名称是否重复
@ -187,7 +207,6 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
:return: :return:
""" """
room = 'detect_rtsp_' + str(detect_id) room = 'detect_rtsp_' + str(detect_id)
await room_manager.send_to_room(room, '开始推理rtsp视频流')
# 选择设备CPU 或 GPU # 选择设备CPU 或 GPU
device = select_device('cpu') device = select_device('cpu')
@ -204,7 +223,10 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device)) seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
time.sleep(3)# 等待3s等待websocket进入
for path, im, im0s, vid_cap, s in dataset: for path, im, im0s, vid_cap, s in dataset:
if room_manager.rooms.get(room):
with dt[0]: with dt[0]:
im = torch.from_numpy(im).to(model.device) im = torch.from_numpy(im).to(model.device)
im = im.half() if model.fp16 else im.float() # uint8 to fp16/32 im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
@ -252,6 +274,9 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
if ret: if ret:
frame_data = jpeg.tobytes() frame_data = jpeg.tobytes()
await room_manager.send_stream_to_room(room, frame_data) await room_manager.send_stream_to_room(room, frame_data)
else:
break

View File

@ -89,3 +89,24 @@ def delete_file_if_exists(*file_paths: str):
for path in file_paths: for path in file_paths:
if os.path.exists(path): # 检查文件是否存在 if os.path.exists(path): # 检查文件是否存在
os.remove(path) # 删除文件 os.remove(path) # 删除文件
def delete_paths(paths):
"""
删除给定路径数组中的每个路径及其包含的所有内容。
:param paths: 文件或目录路径的列表
"""
for path in paths:
if os.path.exists(path):
try:
if os.path.isfile(path) or os.path.islink(path):
# 如果是文件或符号链接,则删除
os.remove(path)
print(f"Deleted file: {path}")
elif os.path.isdir(path):
# 如果是目录,则递归删除
shutil.rmtree(path)
except Exception as e:
print(f"路径删除失败 {path}: {e}")
else:
print(f"路径不存在: {path}")