完成项目推理模块的接口测试

This commit is contained in:
sunyg 2025-04-23 16:38:55 +08:00
parent 0033746fe1
commit 5b38e91f61
8 changed files with 234 additions and 94 deletions

View File

@ -27,25 +27,27 @@ class ProjectDetectDal(DalBase):
self.model = models.ProjectDetect self.model = models.ProjectDetect
self.schema = schemas.ProjectDetectOut self.schema = schemas.ProjectDetectOut
async def check_name(self, name: str,project_id: int) -> bool: async def check_name(self, name: str, project_id: int) -> bool:
""" """
校验推理集合名称 校验推理集合名称
""" """
count = self.get_count( count = await self.get_count(
v_where=[models.ProjectDetect.project_id == project_id, models.ProjectDetect.detect_name == name]) v_where=[models.ProjectDetect.project_id == project_id, models.ProjectDetect.detect_name == name])
return count > 0 return count > 0
async def add_detect(self, data: schemas.ProjectDetectIn): async def add_detect(self, data: schemas.ProjectDetectIn, user_id: int):
""" """
新增集合 新增集合
""" """
detect = models.ProjectDetect(**data.model_dump()) detect = self.model(**data.model_dump())
detect.detect_no = random_str(6) detect.detect_no = random_str(6)
detect.detect_version = 0 detect.detect_version = 0
detect.detect_status = '0' detect.detect_status = '0'
url = os.create_folder(detect_url, detect.detect_no, 'images') detect.user_id = user_id
detect.folder_url = url if detect.file_type != 'rtsp':
await self.create_data(data) url = os.create_folder(detect_url, detect.detect_no, 'images')
detect.folder_url = url
await self.create_model(detect)
return detect return detect
async def delete_detects(self, ids: list[int]): async def delete_detects(self, ids: list[int]):
@ -55,10 +57,17 @@ class ProjectDetectDal(DalBase):
for id_ in ids: for id_ in ids:
detect_info = await self.get_data(data_id=id_) detect_info = await self.get_data(data_id=id_)
if detect_info.file_type != 'rtsp': if detect_info.file_type != 'rtsp':
os.delete_paths(detect_info.folder_url) os.delete_paths([detect_info.folder_url])
logs = await ProjectDetectLogDal(self.db).get_datas(v_where=[models.ProjectDetectLog.detect_id == ids]) logs = await ProjectDetectLogDal(self.db).get_datas(v_where=[models.ProjectDetectLog.detect_id == id_])
log_ids = []
log_urls = []
for log in logs: for log in logs:
os.delete_paths(log.folder_url) log_urls.append(log.folder_url)
log_ids.append(log.id)
if log_ids:
await ProjectDetectLogDal(self.db).delete_datas(ids=log_ids, v_soft=False)
if log_urls:
os.delete_paths(log_urls)
await self.delete_datas(ids=ids, v_soft=False) await self.delete_datas(ids=ids, v_soft=False)
@ -71,10 +80,10 @@ class ProjectDetectFileDal(DalBase):
self.schema = schemas.ProjectDetectFileOut self.schema = schemas.ProjectDetectFileOut
async def file_count(self, detect_id: int) -> int: async def file_count(self, detect_id: int) -> int:
count = self.get_count(v_where=[models.ProjectDetectFile.detect_id == detect_id]) count = await self.get_count(v_where=[models.ProjectDetectFile.detect_id == detect_id])
return count return count
async def add_file(self, detect: models.ProjectDetect, files: list[UploadFile]): async def add_files(self, detect: models.ProjectDetect, files: list[UploadFile], user_id: int):
images = [] images = []
obs = MyObs() obs = MyObs()
for file in files: for file in files:
@ -83,23 +92,24 @@ class ProjectDetectFileDal(DalBase):
image.file_name = file.filename image.file_name = file.filename
# 保存原图 # 保存原图
path = os.save_images(detect.folder_url, file=file) path = os.save_images(detect.folder_url, file=file)
image.image_url = path image.file_url = path
image.user_id = user_id
# 上传到obs # 上传到obs
object_key = detect.detect_no + '/' + file.filename object_key = detect.detect_no + '/' + file.filename
success, key, url = obs.put_file(object_key=object_key, file_path=path) success, key, url = obs.put_file(object_key=object_key, file_path=path)
if success: if success:
image.object_key = object_key image.object_key = object_key
image.thumb_image_url = url image.thumb_file_url = url
else: else:
raise CustomException("obs上传失败", code=status.HTTP_ERROR) raise CustomException("obs上传失败", code=status.HTTP_ERROR)
images.append(image) images.append(image)
await self.create_datas(images) await self.create_models(images)
async def delete_files(self, ids: list[int]): async def delete_files(self, ids: list[int]):
file_urls = [] file_urls = []
object_keys = [] object_keys = []
for id_ in ids: for id_ in ids:
file = self.get_data(data_id=id_) file = await self.get_data(data_id=id_)
if file: if file:
file_urls.append(file.file_url) file_urls.append(file.file_url)
object_keys.append(file.object_key) object_keys.append(file.object_key)
@ -107,13 +117,14 @@ class ProjectDetectFileDal(DalBase):
MyObs().del_objects(object_keys) MyObs().del_objects(object_keys)
await self.delete_datas(ids, v_soft=False) await self.delete_datas(ids, v_soft=False)
class ProjectDetectLogDal(DalBase): class ProjectDetectLogDal(DalBase):
def __init__(self, db: AsyncSession): def __init__(self, db: AsyncSession):
super(ProjectDetectLogDal, self).__init__() super(ProjectDetectLogDal, self).__init__()
self.db = db self.db = db
self.model = models.ProjectDetectLog self.model = models.ProjectDetectLog
self.schema = schemas.ProjectDetectLogSimpleOut self.schema = schemas.ProjectDetectLogOut
class ProjectDetectLogFileDal(DalBase): class ProjectDetectLogFileDal(DalBase):

View File

@ -11,15 +11,15 @@ class ProjectDetect(BaseModel):
__tablename__ = "project_detect" __tablename__ = "project_detect"
__table_args__ = ({'comment': '项目推理集合'}) __table_args__ = ({'comment': '项目推理集合'})
project_id: Mapped[int] = mapped_column(Integer, nullable=False) project_id: Mapped[int] = mapped_column(Integer)
detect_name: Mapped[str] = mapped_column(String(64), nullable=False) detect_name: Mapped[str] = mapped_column(String(64))
detect_version: Mapped[int] = mapped_column(Integer) detect_version: Mapped[int] = mapped_column(Integer)
detect_no: Mapped[str] = mapped_column(String(32)) detect_no: Mapped[str] = mapped_column(String(32))
detect_status: Mapped[int] = mapped_column(Integer) detect_status: Mapped[int] = mapped_column(Integer)
file_type: Mapped[str] = mapped_column(String(10)) file_type: Mapped[str] = mapped_column(String(10))
folder_url: Mapped[str] = mapped_column(String(255)) folder_url: Mapped[str] = mapped_column(String(255), nullable=True)
rtsp_url: Mapped[str] = mapped_column(String(255)) rtsp_url: Mapped[str] = mapped_column(String(255), nullable=True)
user_id: Mapped[int] = mapped_column(Integer, nullable=False) user_id: Mapped[int] = mapped_column(Integer)
class ProjectDetectFile(BaseModel): class ProjectDetectFile(BaseModel):

View File

@ -18,6 +18,8 @@ class ProjectDetectIn(BaseModel):
detect_name: Optional[str] = Field(..., description="推理集合名称") detect_name: Optional[str] = Field(..., description="推理集合名称")
rtsp_url: Optional[str] = Field(None, description="视频流地址") rtsp_url: Optional[str] = Field(None, description="视频流地址")
model_config = ConfigDict(from_attributes=True)
class ProjectDetectPager(BaseModel): class ProjectDetectPager(BaseModel):
project_id: Optional[int] = Field(..., description="项目id") project_id: Optional[int] = Field(..., description="项目id")

View File

@ -21,13 +21,15 @@ async def before_detect(
detect_in: schemas.ProjectDetectLogIn, detect_in: schemas.ProjectDetectLogIn,
detect: models.ProjectDetect, detect: models.ProjectDetect,
train: train_models.ProjectTrain, train: train_models.ProjectTrain,
db: AsyncSession): db: AsyncSession,
user_id: int):
""" """
开始推理 开始推理
:param detect: :param detect:
:param detect_in: :param detect_in:
:param train: :param train:
:param db: :param db:
:param user_id:
:return: :return:
""" """
# 推理版本 # 推理版本
@ -52,19 +54,33 @@ async def before_detect(
detect_log.pt_url = pt_url detect_log.pt_url = pt_url
detect_log.folder_url = img_url detect_log.folder_url = img_url
detect_log.detect_folder_url = out_url detect_log.detect_folder_url = out_url
await crud.ProjectDetectLogDal(db).create_data(detect_log) detect_log.user_id = user_id
await crud.ProjectDetectLogDal(db).create_model(detect_log)
return detect_log return detect_log
def run_img_loop(
weights: str,
source: str,
project: str,
name: str,
detect_id: int,
is_gpu: str):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 运行异步函数
loop.run_until_complete(run_detect_img(weights, source, project, name, detect_id, is_gpu))
# 可选: 关闭循环
loop.close()
async def run_detect_img( async def run_detect_img(
weights: str, weights: str,
source: str, source: str,
project: str, project: str,
name: str, name: str,
log_id: int,
detect_id: int, detect_id: int,
db: AsyncSession, is_gpu: str):
rd: Redis):
""" """
执行yolov5的推理 执行yolov5的推理
:param weights: 权重文件 :param weights: 权重文件
@ -74,7 +90,7 @@ async def run_detect_img(
:param log_id: 日志id :param log_id: 日志id
:param detect_id: 推理集合id :param detect_id: 推理集合id
:param db: 数据库session :param db: 数据库session
:param rd: Redis :param is_gpu: 是否gpu加速
:return: :return:
""" """
yolo_path = os.file_path(yolo_url, 'detect.py') yolo_path = os.file_path(yolo_url, 'detect.py')
@ -82,10 +98,9 @@ async def run_detect_img(
await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n") await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project",
project, "--save-txt", "--conf-thres", "0.4"] project, "--save-txt", "--conf-thres", "0.4"]
is_gpu = rd.get('is_gpu')
# 判断是否存在cuda版本 # 判断是否存在cuda版本
if is_gpu == 'True': if is_gpu == 'True':
commend.append("--device", "0") commend.append("--device=0")
# 启动子进程 # 启动子进程
with subprocess.Popen( with subprocess.Popen(
commend, commend,
@ -101,40 +116,51 @@ async def run_detect_img(
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死 process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
if line != '\n': if line != '\n':
await room_manager.send_to_room(room, line + '\n') await room_manager.send_to_room(room, line + '\n')
# 等待进程结束并获取返回码 # 等待进程结束并获取返回码
return_code = process.wait() return_code = process.wait()
if return_code != 0: if return_code != 0:
await room_manager.send_to_room(room, 'error') await room_manager.send_to_room(room, 'error')
else: else:
await room_manager.send_to_room(room, 'success') await room_manager.send_to_room(room, 'success')
detect_files = crud.ProjectDetectFileDal(db).get_data(
v_where=[models.ProjectDetectFile.detect_id == detect_id])
detect_log_files = []
for detect_file in detect_files:
detect_log_img = models.ProjectDetectLogFile()
detect_log_img.log_id = log_id
image_url = os.file_path(project, name, detect_file.file_name)
detect_log_img.image_url = image_url
detect_log_img.file_name = detect_file.file_name
detect_log_files.append(detect_log_img)
await crud.ProjectDetectLogFileDal(db).create_datas(detect_log_files)
async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: int, rd: Redis): async def update_sql(db: AsyncSession, detect_id: int, log_id: int, project, name):
"""
更新推理集合的状态
"""
detect_dal = crud.ProjectDetectDal(db)
detect = await detect_dal.get_data(detect_id)
detect.detect_version = detect.detect_version + 1
await detect_dal.put_data(data_id=detect_id, data=detect)
detect_files = await crud.ProjectDetectFileDal(db).get_datas(
limit=0,
v_where=[models.ProjectDetectFile.detect_id == detect_id],
v_return_objs=True,
v_return_count=False)
detect_log_files = []
for detect_file in detect_files:
detect_log_img = models.ProjectDetectLogFile()
detect_log_img.log_id = log_id
image_url = os.file_path(project, name, detect_file.file_name)
detect_log_img.file_url = image_url
detect_log_img.file_name = detect_file.file_name
detect_log_files.append(detect_log_img)
await crud.ProjectDetectLogFileDal(db).create_models(detect_log_files)
async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_gpu: str):
""" """
rtsp 视频流推理 rtsp 视频流推理
:param detect_id: 训练集的id :param detect_id: 训练集的id
:param weights_pt: 权重文件 :param weights_pt: 权重文件
:param rtsp_url: 视频流地址 :param rtsp_url: 视频流地址
:param data: yaml文件 :param data: yaml文件
:param rd: Redis :redis :param is_gpu: 是否启用加速
:return: :return:
""" """
room = 'detect_rtsp_' + str(detect_id) room = 'detect_rtsp_' + str(detect_id)
# 选择设备CPU 或 GPU # 选择设备CPU 或 GPU
device = select_device('cpu') device = select_device('cpu')
is_gpu = rd.get('is_gpu')
# 判断是否存在cuda版本 # 判断是否存在cuda版本
if is_gpu == 'True': if is_gpu == 'True':
device = select_device('cuda:0') device = select_device('cuda:0')
@ -208,19 +234,10 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
break break
def run_img_loop(weights: str, source: str, project: str, name: str, log_id: int, detect_id: int, db: AsyncSession): def run_rtsp_loop(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_gpu: str):
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
# 运行异步函数 # 运行异步函数
loop.run_until_complete(run_detect_img(weights, source, project, name, log_id, detect_id, db)) loop.run_until_complete(run_detect_rtsp(weights_pt, rtsp_url, data, detect_id, is_gpu))
# 可选: 关闭循环
loop.close()
def run_rtsp_loop(weights_pt: str, rtsp_url: str, data: str, detect_id: int, rd: Redis):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 运行异步函数
loop.run_until_complete(run_detect_rtsp(weights_pt, rtsp_url, data, detect_id, rd))
# 可选: 关闭循环 # 可选: 关闭循环
loop.close() loop.close()

View File

@ -6,18 +6,24 @@
# @IDE : PyCharm # @IDE : PyCharm
# @desc : 路由,视图文件 # @desc : 路由,视图文件
from utils import os_utils as osu
from core.dependencies import IdList from core.dependencies import IdList
from core.database import redis_getter from core.database import redis_getter
from . import schemas, crud, params, service from . import schemas, crud, params, service, models
from utils.websocket_server import room_manager from utils.websocket_server import room_manager
from apps.business.train.crud import ProjectTrainDal from apps.business.train.crud import ProjectTrainDal
from apps.vadmin.auth.utils.current import AllUserAuth from apps.vadmin.auth.utils.current import AllUserAuth
from apps.vadmin.auth.utils.validation.auth import Auth from apps.vadmin.auth.utils.validation.auth import Auth
from utils.response import SuccessResponse, ErrorResponse from utils.response import SuccessResponse, ErrorResponse
import os
import shutil
import zipfile
import tempfile
import threading import threading
from redis.asyncio import Redis from redis.asyncio import Redis
from fastapi import Depends, APIRouter, Form, UploadFile from fastapi.responses import FileResponse
from fastapi import Depends, APIRouter, Form, UploadFile, BackgroundTasks
app = APIRouter() app = APIRouter()
@ -26,22 +32,27 @@ app = APIRouter()
########################################################### ###########################################################
# 项目推理集合信息 # 项目推理集合信息
########################################################### ###########################################################
@app.get("/list", summary="获取项目推理集合信息列表") @app.get("/list/{proj_id}", summary="获取推理集合列表")
async def detect_list( async def detect_list(
p: params.ProjectDetectParams = Depends(), proj_id: int,
auth: Auth = Depends(AllUserAuth())): auth: Auth = Depends(AllUserAuth())):
datas, count = await crud.ProjectDetectDal(auth.db).get_datas(**p.dict(), v_return_count=True) datas = await crud.ProjectDetectDal(auth.db).get_datas(
return SuccessResponse(datas, count=count) limit=0,
v_where=[models.ProjectDetect.project_id == proj_id],
v_order="asc",
v_order_field="id",
v_return_count=False)
return SuccessResponse(datas)
@app.post("/", summary="创建项目推理集合信息") @app.post("/", summary="创建推理集合")
async def add_detect( async def add_detect(
data: schemas.ProjectDetectIn, data: schemas.ProjectDetectIn,
auth: Auth = Depends(AllUserAuth())): auth: Auth = Depends(AllUserAuth())):
detect_dal = crud.ProjectDetectDal(auth.db) detect_dal = crud.ProjectDetectDal(auth.db)
if await detect_dal.check_name(data.detect_name, data.project_id): if await detect_dal.check_name(data.detect_name, data.project_id):
return ErrorResponse(msg="该项目中存在相同名称的集合") return ErrorResponse(msg="该项目中存在相同名称的集合")
await detect_dal.create_data(data=data) await detect_dal.add_detect(data=data, user_id=auth.user.id)
return SuccessResponse(msg="保存成功") return SuccessResponse(msg="保存成功")
@ -49,18 +60,18 @@ async def add_detect(
async def delete_detect( async def delete_detect(
ids: IdList = Depends(), ids: IdList = Depends(),
auth: Auth = Depends(AllUserAuth())): auth: Auth = Depends(AllUserAuth())):
await crud.ProjectDetectDal(auth.db).delete_datas(ids=ids.ids, v_soft=False) await crud.ProjectDetectDal(auth.db).delete_detects(ids=ids.ids)
return SuccessResponse("删除成功") return SuccessResponse("删除成功")
########################################################### ###########################################################
# 项目推理集合文件信息 # 项目推理集合文件信息
########################################################### ###########################################################
@app.get("/file", summary="获取项目推理集合文件信息列表") @app.get("/files", summary="获取推理集合文件列表")
async def file_list( async def file_list(
p: params.ProjectDetectFileParams = Depends(), p: params.ProjectDetectFileParams = Depends(),
auth: Auth = Depends(AllUserAuth())): auth: Auth = Depends(AllUserAuth())):
if p.limit: if p.limit > 0:
datas, count = await crud.ProjectDetectFileDal(auth.db).get_datas(**p.dict(), v_return_count=True) datas, count = await crud.ProjectDetectFileDal(auth.db).get_datas(**p.dict(), v_return_count=True)
return SuccessResponse(datas, count=count) return SuccessResponse(datas, count=count)
else: else:
@ -68,20 +79,24 @@ async def file_list(
return SuccessResponse(datas) return SuccessResponse(datas)
@app.post("/file", summary="上传项目推理集合文件") @app.post("/files", summary="上传项目推理集合文件")
async def upload_file( async def upload_file(
detect_id: int = Form(...), detect_id: int = Form(...),
files: list[UploadFile] = Form(...), files: list[UploadFile] = Form(...),
auth: Auth = Depends(AllUserAuth())): auth: Auth = Depends(AllUserAuth())):
detect_dal = crud.ProjectDetectDal(auth.db)
file_dal = crud.ProjectDetectFileDal(auth.db) file_dal = crud.ProjectDetectFileDal(auth.db)
detect_out = file_dal.get_data(data_id=detect_id) detect_info = await detect_dal.get_data(data_id=detect_id)
if detect_out is None: if detect_info is None:
return ErrorResponse("训练集合查询失败,请刷新后再试") return ErrorResponse("训练集合查询失败,请刷新后再试")
await file_dal.add_file(detect_out, files) if detect_info.file_type != 'rtsp':
if not any(osu.is_extensions(detect_info.file_type, file.filename) for file in files):
return ErrorResponse("上传的文件中存在不符合的文件类型")
await file_dal.add_files(detect_info, files, auth.user.id)
return SuccessResponse(msg="上传成功") return SuccessResponse(msg="上传成功")
@app.delete("/file", summary="删除项目推理集合文件信息") @app.delete("/files", summary="删除推理集合文件")
async def delete_file( async def delete_file(
ids: IdList = Depends(), ids: IdList = Depends(),
auth: Auth = Depends(AllUserAuth())): auth: Auth = Depends(AllUserAuth())):
@ -89,29 +104,34 @@ async def delete_file(
return SuccessResponse("删除成功") return SuccessResponse("删除成功")
@app.post("/detect", summary="开始推理") @app.post("/start", summary="开始推理")
def run_detect_yolo( async def run_detect_yolo(
detect_log_in: schemas.ProjectDetectLogIn, detect_log_in: schemas.ProjectDetectLogIn,
auth: Auth = Depends(AllUserAuth()), auth: Auth = Depends(AllUserAuth()),
rd: Redis = Depends(redis_getter)): rd: Redis = Depends(redis_getter)):
detect_dal = crud.ProjectDetectDal(auth.db) detect_dal = crud.ProjectDetectDal(auth.db)
train_dal = ProjectTrainDal(auth.db) train_dal = ProjectTrainDal(auth.db)
detect = detect_dal.get_data(detect_log_in.detect_id) detect = await detect_dal.get_data(detect_log_in.detect_id)
if detect is None: if detect is None:
return ErrorResponse(msg="训练集合不存在") return ErrorResponse(msg="训练集合不存在")
train = train_dal.get_data(detect_log_in.train_id) train = await train_dal.get_data(detect_log_in.train_id)
if train is None: if train is None:
return ErrorResponse("训练权重不存在") return ErrorResponse("训练权重不存在")
file_count = crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id) file_count = await crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id)
if file_count == 0 and detect.rtsp_url is None: if file_count == 0 and detect.rtsp_url is None:
return ErrorResponse("推理集合中没有内容,请先到推理集合中上传图片") return ErrorResponse("推理集合中没有内容,请先到推理集合中上传图片")
is_gpu = await rd.get('is_gpu')
if detect.file_type == 'img' or detect.file_type == 'video': if detect.file_type == 'img' or detect.file_type == 'video':
detect_log = service.before_detect(detect_log_in, detect, train, auth.db) detect_log = await service.before_detect(detect_log_in, detect, train, auth.db, auth.user.id)
thread_train = threading.Thread(target=service.run_img_loop, thread_train = threading.Thread(target=service.run_img_loop,
args=(detect_log.pt_url, detect_log.folder_url, args=(detect_log.pt_url, detect_log.folder_url,
detect_log.detect_folder_url, detect_log.detect_version, detect_log.detect_folder_url, detect_log.detect_version,
detect_log.id, detect_log.detect_id, auth.db,)) detect_log.detect_id, is_gpu))
thread_train.start() thread_train.start()
await service.update_sql(
auth.db, detect_log.detect_id,
detect_log.id, detect_log.detect_folder_url,
detect_log.detect_version)
elif detect.file_type == 'rtsp': elif detect.file_type == 'rtsp':
room = 'detect_rtsp_' + str(detect.id) room = 'detect_rtsp_' + str(detect.id)
if not room_manager.rooms.get(room): if not room_manager.rooms.get(room):
@ -120,7 +140,7 @@ def run_detect_yolo(
else: else:
weights_pt = train.last_pt weights_pt = train.last_pt
thread_train = threading.Thread(target=service.run_rtsp_loop, thread_train = threading.Thread(target=service.run_rtsp_loop,
args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, rd,)) args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, is_gpu,))
thread_train.start() thread_train.start()
return SuccessResponse(msg="执行成功") return SuccessResponse(msg="执行成功")
@ -128,16 +148,50 @@ def run_detect_yolo(
########################################################### ###########################################################
# 项目推理记录信息 # 项目推理记录信息
########################################################### ###########################################################
@app.get("/log", summary="获取项目推理记录列表") @app.get("/logs", summary="获取推理记录列表")
async def log_pager( async def logs(
p: params.ProjectDetectLogParams = Depends(), p: params.ProjectDetectLogParams = Depends(),
auth: Auth = Depends(AllUserAuth())): auth: Auth = Depends(AllUserAuth())):
datas, count = await crud.ProjectDetectLogDal(auth.db).get_datas(**p.dict(), v_return_count=True) datas, count = await crud.ProjectDetectLogDal(auth.db).get_datas(**p.dict(), v_return_count=True)
return SuccessResponse(datas, count=count) return SuccessResponse(datas, count=count)
@app.get("/log_files", summary="获取项目推理记录文件列表") @app.get("/logs/download/{log_id}", summary="下载推理结果")
async def log_files( async def logs_download(
log_id: int,
background_tasks: BackgroundTasks,
auth: Auth = Depends(AllUserAuth())
):
# 获取日志记录
detect_log = await crud.ProjectDetectLogDal(auth.db).get_data(data_id=log_id)
if detect_log is None:
return ErrorResponse(msg="推理结果查询错误,请刷新页面再试")
# 检查源文件夹是否存在
folder_path = os.path.join(detect_log.detect_folder_url, detect_log.detect_version)
if not os.path.exists(folder_path):
return ErrorResponse(msg="推理结果文件夹不存在")
zip_filename = f"{detect_log.detect_name}_{detect_log.detect_version}.zip"
zip_file_path = osu.zip_folder(folder_path, zip_filename)
# 后台任务删掉这个
background_tasks.add_task(cleanup_temp_dir, zip_file_path)
# 直接返回文件响应
return FileResponse(path=zip_file_path, filename=os.path.basename(zip_file_path))
def cleanup_temp_dir(temp_dir: str):
"""清理临时目录的后台任务"""
try:
shutil.rmtree(temp_dir, ignore_errors=True)
except Exception as e:
print(f"清理临时目录失败: {e}")
@app.get("/logs/files", summary="获取项目推理记录文件列表")
async def logs_files(
p: params.ProjectDetectLogFileParams = Depends(), p: params.ProjectDetectLogFileParams = Depends(),
auth: Auth = Depends(AllUserAuth())): auth: Auth = Depends(AllUserAuth())):
datas = await crud.ProjectDetectLogFileDal(auth.db).get_datas(**p.dict(), v_return_count=False) datas = await crud.ProjectDetectLogFileDal(auth.db).get_datas(**p.dict(), v_return_count=False)

View File

@ -1,7 +1,13 @@
import os import os
import shutil import shutil
from fastapi import UploadFile import zipfile
import tempfile
from PIL import Image from PIL import Image
from fastapi import UploadFile
img_extensions = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}
video_extensions = {'avi', 'wmv', 'rmvb', 'mp4', 'm4v', 'avi'}
def file_path(*path): def file_path(*path):
@ -83,7 +89,7 @@ def copy_and_rename_file(src_file_path, dst_dir, new_name):
def delete_file_if_exists(*file_paths: str): def delete_file_if_exists(*file_paths: str):
""" """
删除文件 删除文件
:param file_path: :param file_paths:
:return: :return:
""" """
for path in file_paths: for path in file_paths:
@ -110,3 +116,53 @@ def delete_paths(paths):
print(f"路径删除失败 {path}: {e}") print(f"路径删除失败 {path}: {e}")
else: else:
print(f"路径不存在: {path}") print(f"路径不存在: {path}")
def is_extensions(extension_type: str, file_name: str):
"""
校验文件名
"""
if extension_type == 'img':
file_extensions = img_extensions
elif extension_type == 'video':
file_extensions = video_extensions
else:
file_extensions = []
return '.' in file_name and file_name.rsplit('.', 1)[1].lower() in file_extensions
def zip_folder(folder_path: str, zip_filename: str) -> str:
"""
将指定文件夹打包成 ZIP 文件并返回 ZIP 文件的路径
:param folder_path: 要打包的文件夹路径
:param zip_filename: 生成的 ZIP 文件名不带扩展名
:return: 生成的 ZIP 文件的完整路径
"""
# 检查文件夹是否存在
if not os.path.isdir(folder_path):
raise ValueError(f"文件夹路径不存在: {folder_path}")
# 确保 ZIP 文件名以 .zip 结尾
if not zip_filename.endswith(".zip"):
zip_filename += ".zip"
# 创建临时目录用于存储 ZIP 文件
temp_dir = tempfile.mkdtemp()
zip_file_path = os.path.join(temp_dir, zip_filename)
try:
# 打包文件夹为 ZIP 文件
with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED, allowZip64=True) as zip_f:
for root, dirs, files in os.walk(folder_path):
for file in files:
file_path = os.path.join(root, file)
arc_name = os.path.relpath(file_path, folder_path) # 保持相对路径
zip_f.write(file_path, arc_name)
# 返回生成的 ZIP 文件路径
return zip_file_path
except Exception as e:
# 清理临时文件夹并重新抛出异常
shutil.rmtree(temp_dir, ignore_errors=True)
raise RuntimeError(f"打包失败: {e}")

View File

@ -25,7 +25,7 @@ if str(ROOT) not in sys.path:
if platform.system() != "Windows": if platform.system() != "Windows":
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
from models.common import ( from utils.yolov5.models.common import (
C3, C3,
C3SPP, C3SPP,
C3TR, C3TR,
@ -49,11 +49,11 @@ from models.common import (
GhostConv, GhostConv,
Proto, Proto,
) )
from models.experimental import MixConv2d from utils.yolov5.models.experimental import MixConv2d
from utils.autoanchor import check_anchor_order from utils.yolov5.utils.autoanchor import check_anchor_order
from utils.general import LOGGER, check_version, check_yaml, colorstr, make_divisible, print_args from utils.yolov5.utils.general import LOGGER, check_version, check_yaml, colorstr, make_divisible, print_args
from utils.plots import feature_visualization from utils.yolov5.utils.plots import feature_visualization
from utils.torch_utils import ( from utils.yolov5.utils.torch_utils import (
fuse_conv_and_bn, fuse_conv_and_bn,
initialize_weights, initialize_weights,
model_info, model_info,

View File

@ -63,7 +63,7 @@ def safe_download(file, url, url2=None, min_bytes=1e0, error_msg=""):
Removes incomplete downloads. Removes incomplete downloads.
""" """
from utils.general import LOGGER from utils.yolov5.utils.general import LOGGER
file = Path(file) file = Path(file)
assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}" assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
@ -89,7 +89,7 @@ def attempt_download(file, repo="ultralytics/yolov5", release="v7.0"):
"""Downloads a file from GitHub release assets or via direct URL if not found locally, supporting backup """Downloads a file from GitHub release assets or via direct URL if not found locally, supporting backup
versions. versions.
""" """
from utils.general import LOGGER from utils.yolov5.utils.general import LOGGER
def github_assets(repository, version="latest"): def github_assets(repository, version="latest"):
"""Fetches GitHub repository release tag and asset names using the GitHub API.""" """Fetches GitHub repository release tag and asset names using the GitHub API."""