完成项目推理模块的接口测试
This commit is contained in:
parent
0033746fe1
commit
5b38e91f61
@ -27,25 +27,27 @@ class ProjectDetectDal(DalBase):
|
||||
self.model = models.ProjectDetect
|
||||
self.schema = schemas.ProjectDetectOut
|
||||
|
||||
async def check_name(self, name: str,project_id: int) -> bool:
|
||||
async def check_name(self, name: str, project_id: int) -> bool:
|
||||
"""
|
||||
校验推理集合名称
|
||||
"""
|
||||
count = self.get_count(
|
||||
count = await self.get_count(
|
||||
v_where=[models.ProjectDetect.project_id == project_id, models.ProjectDetect.detect_name == name])
|
||||
return count > 0
|
||||
|
||||
async def add_detect(self, data: schemas.ProjectDetectIn):
|
||||
async def add_detect(self, data: schemas.ProjectDetectIn, user_id: int):
|
||||
"""
|
||||
新增集合
|
||||
"""
|
||||
detect = models.ProjectDetect(**data.model_dump())
|
||||
detect = self.model(**data.model_dump())
|
||||
detect.detect_no = random_str(6)
|
||||
detect.detect_version = 0
|
||||
detect.detect_status = '0'
|
||||
url = os.create_folder(detect_url, detect.detect_no, 'images')
|
||||
detect.folder_url = url
|
||||
await self.create_data(data)
|
||||
detect.user_id = user_id
|
||||
if detect.file_type != 'rtsp':
|
||||
url = os.create_folder(detect_url, detect.detect_no, 'images')
|
||||
detect.folder_url = url
|
||||
await self.create_model(detect)
|
||||
return detect
|
||||
|
||||
async def delete_detects(self, ids: list[int]):
|
||||
@ -55,10 +57,17 @@ class ProjectDetectDal(DalBase):
|
||||
for id_ in ids:
|
||||
detect_info = await self.get_data(data_id=id_)
|
||||
if detect_info.file_type != 'rtsp':
|
||||
os.delete_paths(detect_info.folder_url)
|
||||
logs = await ProjectDetectLogDal(self.db).get_datas(v_where=[models.ProjectDetectLog.detect_id == ids])
|
||||
os.delete_paths([detect_info.folder_url])
|
||||
logs = await ProjectDetectLogDal(self.db).get_datas(v_where=[models.ProjectDetectLog.detect_id == id_])
|
||||
log_ids = []
|
||||
log_urls = []
|
||||
for log in logs:
|
||||
os.delete_paths(log.folder_url)
|
||||
log_urls.append(log.folder_url)
|
||||
log_ids.append(log.id)
|
||||
if log_ids:
|
||||
await ProjectDetectLogDal(self.db).delete_datas(ids=log_ids, v_soft=False)
|
||||
if log_urls:
|
||||
os.delete_paths(log_urls)
|
||||
await self.delete_datas(ids=ids, v_soft=False)
|
||||
|
||||
|
||||
@ -71,10 +80,10 @@ class ProjectDetectFileDal(DalBase):
|
||||
self.schema = schemas.ProjectDetectFileOut
|
||||
|
||||
async def file_count(self, detect_id: int) -> int:
|
||||
count = self.get_count(v_where=[models.ProjectDetectFile.detect_id == detect_id])
|
||||
count = await self.get_count(v_where=[models.ProjectDetectFile.detect_id == detect_id])
|
||||
return count
|
||||
|
||||
async def add_file(self, detect: models.ProjectDetect, files: list[UploadFile]):
|
||||
async def add_files(self, detect: models.ProjectDetect, files: list[UploadFile], user_id: int):
|
||||
images = []
|
||||
obs = MyObs()
|
||||
for file in files:
|
||||
@ -83,23 +92,24 @@ class ProjectDetectFileDal(DalBase):
|
||||
image.file_name = file.filename
|
||||
# 保存原图
|
||||
path = os.save_images(detect.folder_url, file=file)
|
||||
image.image_url = path
|
||||
image.file_url = path
|
||||
image.user_id = user_id
|
||||
# 上传到obs
|
||||
object_key = detect.detect_no + '/' + file.filename
|
||||
success, key, url = obs.put_file(object_key=object_key, file_path=path)
|
||||
if success:
|
||||
image.object_key = object_key
|
||||
image.thumb_image_url = url
|
||||
image.thumb_file_url = url
|
||||
else:
|
||||
raise CustomException("obs上传失败", code=status.HTTP_ERROR)
|
||||
images.append(image)
|
||||
await self.create_datas(images)
|
||||
await self.create_models(images)
|
||||
|
||||
async def delete_files(self, ids: list[int]):
|
||||
file_urls = []
|
||||
object_keys = []
|
||||
for id_ in ids:
|
||||
file = self.get_data(data_id=id_)
|
||||
file = await self.get_data(data_id=id_)
|
||||
if file:
|
||||
file_urls.append(file.file_url)
|
||||
object_keys.append(file.object_key)
|
||||
@ -107,13 +117,14 @@ class ProjectDetectFileDal(DalBase):
|
||||
MyObs().del_objects(object_keys)
|
||||
await self.delete_datas(ids, v_soft=False)
|
||||
|
||||
|
||||
class ProjectDetectLogDal(DalBase):
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
super(ProjectDetectLogDal, self).__init__()
|
||||
self.db = db
|
||||
self.model = models.ProjectDetectLog
|
||||
self.schema = schemas.ProjectDetectLogSimpleOut
|
||||
self.schema = schemas.ProjectDetectLogOut
|
||||
|
||||
|
||||
class ProjectDetectLogFileDal(DalBase):
|
||||
|
@ -11,15 +11,15 @@ class ProjectDetect(BaseModel):
|
||||
__tablename__ = "project_detect"
|
||||
__table_args__ = ({'comment': '项目推理集合'})
|
||||
|
||||
project_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
detect_name: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
project_id: Mapped[int] = mapped_column(Integer)
|
||||
detect_name: Mapped[str] = mapped_column(String(64))
|
||||
detect_version: Mapped[int] = mapped_column(Integer)
|
||||
detect_no: Mapped[str] = mapped_column(String(32))
|
||||
detect_status: Mapped[int] = mapped_column(Integer)
|
||||
file_type: Mapped[str] = mapped_column(String(10))
|
||||
folder_url: Mapped[str] = mapped_column(String(255))
|
||||
rtsp_url: Mapped[str] = mapped_column(String(255))
|
||||
user_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
folder_url: Mapped[str] = mapped_column(String(255), nullable=True)
|
||||
rtsp_url: Mapped[str] = mapped_column(String(255), nullable=True)
|
||||
user_id: Mapped[int] = mapped_column(Integer)
|
||||
|
||||
|
||||
class ProjectDetectFile(BaseModel):
|
||||
|
@ -18,6 +18,8 @@ class ProjectDetectIn(BaseModel):
|
||||
detect_name: Optional[str] = Field(..., description="推理集合名称")
|
||||
rtsp_url: Optional[str] = Field(None, description="视频流地址")
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ProjectDetectPager(BaseModel):
|
||||
project_id: Optional[int] = Field(..., description="项目id")
|
||||
|
@ -21,13 +21,15 @@ async def before_detect(
|
||||
detect_in: schemas.ProjectDetectLogIn,
|
||||
detect: models.ProjectDetect,
|
||||
train: train_models.ProjectTrain,
|
||||
db: AsyncSession):
|
||||
db: AsyncSession,
|
||||
user_id: int):
|
||||
"""
|
||||
开始推理
|
||||
:param detect:
|
||||
:param detect_in:
|
||||
:param train:
|
||||
:param db:
|
||||
:param user_id:
|
||||
:return:
|
||||
"""
|
||||
# 推理版本
|
||||
@ -52,19 +54,33 @@ async def before_detect(
|
||||
detect_log.pt_url = pt_url
|
||||
detect_log.folder_url = img_url
|
||||
detect_log.detect_folder_url = out_url
|
||||
await crud.ProjectDetectLogDal(db).create_data(detect_log)
|
||||
detect_log.user_id = user_id
|
||||
await crud.ProjectDetectLogDal(db).create_model(detect_log)
|
||||
return detect_log
|
||||
|
||||
|
||||
def run_img_loop(
|
||||
weights: str,
|
||||
source: str,
|
||||
project: str,
|
||||
name: str,
|
||||
detect_id: int,
|
||||
is_gpu: str):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
# 运行异步函数
|
||||
loop.run_until_complete(run_detect_img(weights, source, project, name, detect_id, is_gpu))
|
||||
# 可选: 关闭循环
|
||||
loop.close()
|
||||
|
||||
|
||||
async def run_detect_img(
|
||||
weights: str,
|
||||
source: str,
|
||||
project: str,
|
||||
name: str,
|
||||
log_id: int,
|
||||
detect_id: int,
|
||||
db: AsyncSession,
|
||||
rd: Redis):
|
||||
is_gpu: str):
|
||||
"""
|
||||
执行yolov5的推理
|
||||
:param weights: 权重文件
|
||||
@ -74,7 +90,7 @@ async def run_detect_img(
|
||||
:param log_id: 日志id
|
||||
:param detect_id: 推理集合id
|
||||
:param db: 数据库session
|
||||
:param rd: Redis
|
||||
:param is_gpu: 是否gpu加速
|
||||
:return:
|
||||
"""
|
||||
yolo_path = os.file_path(yolo_url, 'detect.py')
|
||||
@ -82,10 +98,9 @@ async def run_detect_img(
|
||||
await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
|
||||
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project",
|
||||
project, "--save-txt", "--conf-thres", "0.4"]
|
||||
is_gpu = rd.get('is_gpu')
|
||||
# 判断是否存在cuda版本
|
||||
if is_gpu == 'True':
|
||||
commend.append("--device", "0")
|
||||
commend.append("--device=0")
|
||||
# 启动子进程
|
||||
with subprocess.Popen(
|
||||
commend,
|
||||
@ -101,40 +116,51 @@ async def run_detect_img(
|
||||
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
|
||||
if line != '\n':
|
||||
await room_manager.send_to_room(room, line + '\n')
|
||||
|
||||
# 等待进程结束并获取返回码
|
||||
return_code = process.wait()
|
||||
if return_code != 0:
|
||||
await room_manager.send_to_room(room, 'error')
|
||||
else:
|
||||
await room_manager.send_to_room(room, 'success')
|
||||
detect_files = crud.ProjectDetectFileDal(db).get_data(
|
||||
v_where=[models.ProjectDetectFile.detect_id == detect_id])
|
||||
detect_log_files = []
|
||||
for detect_file in detect_files:
|
||||
detect_log_img = models.ProjectDetectLogFile()
|
||||
detect_log_img.log_id = log_id
|
||||
image_url = os.file_path(project, name, detect_file.file_name)
|
||||
detect_log_img.image_url = image_url
|
||||
detect_log_img.file_name = detect_file.file_name
|
||||
detect_log_files.append(detect_log_img)
|
||||
await crud.ProjectDetectLogFileDal(db).create_datas(detect_log_files)
|
||||
|
||||
|
||||
async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: int, rd: Redis):
|
||||
async def update_sql(db: AsyncSession, detect_id: int, log_id: int, project, name):
|
||||
"""
|
||||
更新推理集合的状态
|
||||
"""
|
||||
detect_dal = crud.ProjectDetectDal(db)
|
||||
detect = await detect_dal.get_data(detect_id)
|
||||
detect.detect_version = detect.detect_version + 1
|
||||
await detect_dal.put_data(data_id=detect_id, data=detect)
|
||||
detect_files = await crud.ProjectDetectFileDal(db).get_datas(
|
||||
limit=0,
|
||||
v_where=[models.ProjectDetectFile.detect_id == detect_id],
|
||||
v_return_objs=True,
|
||||
v_return_count=False)
|
||||
detect_log_files = []
|
||||
for detect_file in detect_files:
|
||||
detect_log_img = models.ProjectDetectLogFile()
|
||||
detect_log_img.log_id = log_id
|
||||
image_url = os.file_path(project, name, detect_file.file_name)
|
||||
detect_log_img.file_url = image_url
|
||||
detect_log_img.file_name = detect_file.file_name
|
||||
detect_log_files.append(detect_log_img)
|
||||
await crud.ProjectDetectLogFileDal(db).create_models(detect_log_files)
|
||||
|
||||
|
||||
async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_gpu: str):
|
||||
"""
|
||||
rtsp 视频流推理
|
||||
:param detect_id: 训练集的id
|
||||
:param weights_pt: 权重文件
|
||||
:param rtsp_url: 视频流地址
|
||||
:param data: yaml文件
|
||||
:param rd: Redis :redis
|
||||
:param is_gpu: 是否启用加速
|
||||
:return:
|
||||
"""
|
||||
room = 'detect_rtsp_' + str(detect_id)
|
||||
# 选择设备(CPU 或 GPU)
|
||||
device = select_device('cpu')
|
||||
is_gpu = rd.get('is_gpu')
|
||||
# 判断是否存在cuda版本
|
||||
if is_gpu == 'True':
|
||||
device = select_device('cuda:0')
|
||||
@ -208,19 +234,10 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
|
||||
break
|
||||
|
||||
|
||||
def run_img_loop(weights: str, source: str, project: str, name: str, log_id: int, detect_id: int, db: AsyncSession):
|
||||
def run_rtsp_loop(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_gpu: str):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
# 运行异步函数
|
||||
loop.run_until_complete(run_detect_img(weights, source, project, name, log_id, detect_id, db))
|
||||
# 可选: 关闭循环
|
||||
loop.close()
|
||||
|
||||
|
||||
def run_rtsp_loop(weights_pt: str, rtsp_url: str, data: str, detect_id: int, rd: Redis):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
# 运行异步函数
|
||||
loop.run_until_complete(run_detect_rtsp(weights_pt, rtsp_url, data, detect_id, rd))
|
||||
loop.run_until_complete(run_detect_rtsp(weights_pt, rtsp_url, data, detect_id, is_gpu))
|
||||
# 可选: 关闭循环
|
||||
loop.close()
|
@ -6,18 +6,24 @@
|
||||
# @IDE : PyCharm
|
||||
# @desc : 路由,视图文件
|
||||
|
||||
from utils import os_utils as osu
|
||||
from core.dependencies import IdList
|
||||
from core.database import redis_getter
|
||||
from . import schemas, crud, params, service
|
||||
from . import schemas, crud, params, service, models
|
||||
from utils.websocket_server import room_manager
|
||||
from apps.business.train.crud import ProjectTrainDal
|
||||
from apps.vadmin.auth.utils.current import AllUserAuth
|
||||
from apps.vadmin.auth.utils.validation.auth import Auth
|
||||
from utils.response import SuccessResponse, ErrorResponse
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
import tempfile
|
||||
import threading
|
||||
from redis.asyncio import Redis
|
||||
from fastapi import Depends, APIRouter, Form, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi import Depends, APIRouter, Form, UploadFile, BackgroundTasks
|
||||
|
||||
|
||||
app = APIRouter()
|
||||
@ -26,22 +32,27 @@ app = APIRouter()
|
||||
###########################################################
|
||||
# 项目推理集合信息
|
||||
###########################################################
|
||||
@app.get("/list", summary="获取项目推理集合信息列表")
|
||||
@app.get("/list/{proj_id}", summary="获取推理集合列表")
|
||||
async def detect_list(
|
||||
p: params.ProjectDetectParams = Depends(),
|
||||
proj_id: int,
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
datas, count = await crud.ProjectDetectDal(auth.db).get_datas(**p.dict(), v_return_count=True)
|
||||
return SuccessResponse(datas, count=count)
|
||||
datas = await crud.ProjectDetectDal(auth.db).get_datas(
|
||||
limit=0,
|
||||
v_where=[models.ProjectDetect.project_id == proj_id],
|
||||
v_order="asc",
|
||||
v_order_field="id",
|
||||
v_return_count=False)
|
||||
return SuccessResponse(datas)
|
||||
|
||||
|
||||
@app.post("/", summary="创建项目推理集合信息")
|
||||
@app.post("/", summary="创建推理集合")
|
||||
async def add_detect(
|
||||
data: schemas.ProjectDetectIn,
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
detect_dal = crud.ProjectDetectDal(auth.db)
|
||||
if await detect_dal.check_name(data.detect_name, data.project_id):
|
||||
return ErrorResponse(msg="该项目中存在相同名称的集合")
|
||||
await detect_dal.create_data(data=data)
|
||||
await detect_dal.add_detect(data=data, user_id=auth.user.id)
|
||||
return SuccessResponse(msg="保存成功")
|
||||
|
||||
|
||||
@ -49,18 +60,18 @@ async def add_detect(
|
||||
async def delete_detect(
|
||||
ids: IdList = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
await crud.ProjectDetectDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
|
||||
await crud.ProjectDetectDal(auth.db).delete_detects(ids=ids.ids)
|
||||
return SuccessResponse("删除成功")
|
||||
|
||||
|
||||
###########################################################
|
||||
# 项目推理集合文件信息
|
||||
###########################################################
|
||||
@app.get("/file", summary="获取项目推理集合文件信息列表")
|
||||
@app.get("/files", summary="获取推理集合文件列表")
|
||||
async def file_list(
|
||||
p: params.ProjectDetectFileParams = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
if p.limit:
|
||||
if p.limit > 0:
|
||||
datas, count = await crud.ProjectDetectFileDal(auth.db).get_datas(**p.dict(), v_return_count=True)
|
||||
return SuccessResponse(datas, count=count)
|
||||
else:
|
||||
@ -68,20 +79,24 @@ async def file_list(
|
||||
return SuccessResponse(datas)
|
||||
|
||||
|
||||
@app.post("/file", summary="上传项目推理集合文件")
|
||||
@app.post("/files", summary="上传项目推理集合文件")
|
||||
async def upload_file(
|
||||
detect_id: int = Form(...),
|
||||
files: list[UploadFile] = Form(...),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
detect_dal = crud.ProjectDetectDal(auth.db)
|
||||
file_dal = crud.ProjectDetectFileDal(auth.db)
|
||||
detect_out = file_dal.get_data(data_id=detect_id)
|
||||
if detect_out is None:
|
||||
detect_info = await detect_dal.get_data(data_id=detect_id)
|
||||
if detect_info is None:
|
||||
return ErrorResponse("训练集合查询失败,请刷新后再试")
|
||||
await file_dal.add_file(detect_out, files)
|
||||
if detect_info.file_type != 'rtsp':
|
||||
if not any(osu.is_extensions(detect_info.file_type, file.filename) for file in files):
|
||||
return ErrorResponse("上传的文件中存在不符合的文件类型")
|
||||
await file_dal.add_files(detect_info, files, auth.user.id)
|
||||
return SuccessResponse(msg="上传成功")
|
||||
|
||||
|
||||
@app.delete("/file", summary="删除项目推理集合文件信息")
|
||||
@app.delete("/files", summary="删除推理集合文件")
|
||||
async def delete_file(
|
||||
ids: IdList = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
@ -89,29 +104,34 @@ async def delete_file(
|
||||
return SuccessResponse("删除成功")
|
||||
|
||||
|
||||
@app.post("/detect", summary="开始推理")
|
||||
def run_detect_yolo(
|
||||
@app.post("/start", summary="开始推理")
|
||||
async def run_detect_yolo(
|
||||
detect_log_in: schemas.ProjectDetectLogIn,
|
||||
auth: Auth = Depends(AllUserAuth()),
|
||||
rd: Redis = Depends(redis_getter)):
|
||||
detect_dal = crud.ProjectDetectDal(auth.db)
|
||||
train_dal = ProjectTrainDal(auth.db)
|
||||
detect = detect_dal.get_data(detect_log_in.detect_id)
|
||||
detect = await detect_dal.get_data(detect_log_in.detect_id)
|
||||
if detect is None:
|
||||
return ErrorResponse(msg="训练集合不存在")
|
||||
train = train_dal.get_data(detect_log_in.train_id)
|
||||
train = await train_dal.get_data(detect_log_in.train_id)
|
||||
if train is None:
|
||||
return ErrorResponse("训练权重不存在")
|
||||
file_count = crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id)
|
||||
file_count = await crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id)
|
||||
if file_count == 0 and detect.rtsp_url is None:
|
||||
return ErrorResponse("推理集合中没有内容,请先到推理集合中上传图片")
|
||||
is_gpu = await rd.get('is_gpu')
|
||||
if detect.file_type == 'img' or detect.file_type == 'video':
|
||||
detect_log = service.before_detect(detect_log_in, detect, train, auth.db)
|
||||
detect_log = await service.before_detect(detect_log_in, detect, train, auth.db, auth.user.id)
|
||||
thread_train = threading.Thread(target=service.run_img_loop,
|
||||
args=(detect_log.pt_url, detect_log.folder_url,
|
||||
detect_log.detect_folder_url, detect_log.detect_version,
|
||||
detect_log.id, detect_log.detect_id, auth.db,))
|
||||
detect_log.detect_id, is_gpu))
|
||||
thread_train.start()
|
||||
await service.update_sql(
|
||||
auth.db, detect_log.detect_id,
|
||||
detect_log.id, detect_log.detect_folder_url,
|
||||
detect_log.detect_version)
|
||||
elif detect.file_type == 'rtsp':
|
||||
room = 'detect_rtsp_' + str(detect.id)
|
||||
if not room_manager.rooms.get(room):
|
||||
@ -120,7 +140,7 @@ def run_detect_yolo(
|
||||
else:
|
||||
weights_pt = train.last_pt
|
||||
thread_train = threading.Thread(target=service.run_rtsp_loop,
|
||||
args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, rd,))
|
||||
args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, is_gpu,))
|
||||
thread_train.start()
|
||||
return SuccessResponse(msg="执行成功")
|
||||
|
||||
@ -128,16 +148,50 @@ def run_detect_yolo(
|
||||
###########################################################
|
||||
# 项目推理记录信息
|
||||
###########################################################
|
||||
@app.get("/log", summary="获取项目推理记录列表")
|
||||
async def log_pager(
|
||||
@app.get("/logs", summary="获取推理记录列表")
|
||||
async def logs(
|
||||
p: params.ProjectDetectLogParams = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
datas, count = await crud.ProjectDetectLogDal(auth.db).get_datas(**p.dict(), v_return_count=True)
|
||||
return SuccessResponse(datas, count=count)
|
||||
|
||||
|
||||
@app.get("/log_files", summary="获取项目推理记录文件列表")
|
||||
async def log_files(
|
||||
@app.get("/logs/download/{log_id}", summary="下载推理结果")
|
||||
async def logs_download(
|
||||
log_id: int,
|
||||
background_tasks: BackgroundTasks,
|
||||
auth: Auth = Depends(AllUserAuth())
|
||||
):
|
||||
# 获取日志记录
|
||||
detect_log = await crud.ProjectDetectLogDal(auth.db).get_data(data_id=log_id)
|
||||
if detect_log is None:
|
||||
return ErrorResponse(msg="推理结果查询错误,请刷新页面再试")
|
||||
|
||||
# 检查源文件夹是否存在
|
||||
folder_path = os.path.join(detect_log.detect_folder_url, detect_log.detect_version)
|
||||
if not os.path.exists(folder_path):
|
||||
return ErrorResponse(msg="推理结果文件夹不存在")
|
||||
|
||||
zip_filename = f"{detect_log.detect_name}_{detect_log.detect_version}.zip"
|
||||
zip_file_path = osu.zip_folder(folder_path, zip_filename)
|
||||
|
||||
# 后台任务删掉这个
|
||||
background_tasks.add_task(cleanup_temp_dir, zip_file_path)
|
||||
|
||||
# 直接返回文件响应
|
||||
return FileResponse(path=zip_file_path, filename=os.path.basename(zip_file_path))
|
||||
|
||||
|
||||
def cleanup_temp_dir(temp_dir: str):
|
||||
"""清理临时目录的后台任务"""
|
||||
try:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
except Exception as e:
|
||||
print(f"清理临时目录失败: {e}")
|
||||
|
||||
|
||||
@app.get("/logs/files", summary="获取项目推理记录文件列表")
|
||||
async def logs_files(
|
||||
p: params.ProjectDetectLogFileParams = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
datas = await crud.ProjectDetectLogFileDal(auth.db).get_datas(**p.dict(), v_return_count=False)
|
||||
|
@ -1,7 +1,13 @@
|
||||
import os
|
||||
import shutil
|
||||
from fastapi import UploadFile
|
||||
import zipfile
|
||||
import tempfile
|
||||
from PIL import Image
|
||||
from fastapi import UploadFile
|
||||
|
||||
img_extensions = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}
|
||||
|
||||
video_extensions = {'avi', 'wmv', 'rmvb', 'mp4', 'm4v', 'avi'}
|
||||
|
||||
|
||||
def file_path(*path):
|
||||
@ -83,7 +89,7 @@ def copy_and_rename_file(src_file_path, dst_dir, new_name):
|
||||
def delete_file_if_exists(*file_paths: str):
|
||||
"""
|
||||
删除文件
|
||||
:param file_path:
|
||||
:param file_paths:
|
||||
:return:
|
||||
"""
|
||||
for path in file_paths:
|
||||
@ -110,3 +116,53 @@ def delete_paths(paths):
|
||||
print(f"路径删除失败 {path}: {e}")
|
||||
else:
|
||||
print(f"路径不存在: {path}")
|
||||
|
||||
|
||||
def is_extensions(extension_type: str, file_name: str):
|
||||
"""
|
||||
校验文件名
|
||||
"""
|
||||
if extension_type == 'img':
|
||||
file_extensions = img_extensions
|
||||
elif extension_type == 'video':
|
||||
file_extensions = video_extensions
|
||||
else:
|
||||
file_extensions = []
|
||||
return '.' in file_name and file_name.rsplit('.', 1)[1].lower() in file_extensions
|
||||
|
||||
|
||||
def zip_folder(folder_path: str, zip_filename: str) -> str:
|
||||
"""
|
||||
将指定文件夹打包成 ZIP 文件,并返回 ZIP 文件的路径。
|
||||
|
||||
:param folder_path: 要打包的文件夹路径
|
||||
:param zip_filename: 生成的 ZIP 文件名(不带扩展名)
|
||||
:return: 生成的 ZIP 文件的完整路径
|
||||
"""
|
||||
# 检查文件夹是否存在
|
||||
if not os.path.isdir(folder_path):
|
||||
raise ValueError(f"文件夹路径不存在: {folder_path}")
|
||||
|
||||
# 确保 ZIP 文件名以 .zip 结尾
|
||||
if not zip_filename.endswith(".zip"):
|
||||
zip_filename += ".zip"
|
||||
|
||||
# 创建临时目录用于存储 ZIP 文件
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
zip_file_path = os.path.join(temp_dir, zip_filename)
|
||||
|
||||
try:
|
||||
# 打包文件夹为 ZIP 文件
|
||||
with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED, allowZip64=True) as zip_f:
|
||||
for root, dirs, files in os.walk(folder_path):
|
||||
for file in files:
|
||||
file_path = os.path.join(root, file)
|
||||
arc_name = os.path.relpath(file_path, folder_path) # 保持相对路径
|
||||
zip_f.write(file_path, arc_name)
|
||||
|
||||
# 返回生成的 ZIP 文件路径
|
||||
return zip_file_path
|
||||
except Exception as e:
|
||||
# 清理临时文件夹并重新抛出异常
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
raise RuntimeError(f"打包失败: {e}")
|
||||
|
@ -25,7 +25,7 @@ if str(ROOT) not in sys.path:
|
||||
if platform.system() != "Windows":
|
||||
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
||||
|
||||
from models.common import (
|
||||
from utils.yolov5.models.common import (
|
||||
C3,
|
||||
C3SPP,
|
||||
C3TR,
|
||||
@ -49,11 +49,11 @@ from models.common import (
|
||||
GhostConv,
|
||||
Proto,
|
||||
)
|
||||
from models.experimental import MixConv2d
|
||||
from utils.autoanchor import check_anchor_order
|
||||
from utils.general import LOGGER, check_version, check_yaml, colorstr, make_divisible, print_args
|
||||
from utils.plots import feature_visualization
|
||||
from utils.torch_utils import (
|
||||
from utils.yolov5.models.experimental import MixConv2d
|
||||
from utils.yolov5.utils.autoanchor import check_anchor_order
|
||||
from utils.yolov5.utils.general import LOGGER, check_version, check_yaml, colorstr, make_divisible, print_args
|
||||
from utils.yolov5.utils.plots import feature_visualization
|
||||
from utils.yolov5.utils.torch_utils import (
|
||||
fuse_conv_and_bn,
|
||||
initialize_weights,
|
||||
model_info,
|
||||
|
@ -63,7 +63,7 @@ def safe_download(file, url, url2=None, min_bytes=1e0, error_msg=""):
|
||||
|
||||
Removes incomplete downloads.
|
||||
"""
|
||||
from utils.general import LOGGER
|
||||
from utils.yolov5.utils.general import LOGGER
|
||||
|
||||
file = Path(file)
|
||||
assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
|
||||
@ -89,7 +89,7 @@ def attempt_download(file, repo="ultralytics/yolov5", release="v7.0"):
|
||||
"""Downloads a file from GitHub release assets or via direct URL if not found locally, supporting backup
|
||||
versions.
|
||||
"""
|
||||
from utils.general import LOGGER
|
||||
from utils.yolov5.utils.general import LOGGER
|
||||
|
||||
def github_assets(repository, version="latest"):
|
||||
"""Fetches GitHub repository release tag and asset names using the GitHub API."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user