完成项目推理模块的接口测试

This commit is contained in:
sunyg 2025-04-23 16:38:55 +08:00
parent 0033746fe1
commit 5b38e91f61
8 changed files with 234 additions and 94 deletions

View File

@ -27,25 +27,27 @@ class ProjectDetectDal(DalBase):
self.model = models.ProjectDetect
self.schema = schemas.ProjectDetectOut
async def check_name(self, name: str,project_id: int) -> bool:
async def check_name(self, name: str, project_id: int) -> bool:
"""
校验推理集合名称
"""
count = self.get_count(
count = await self.get_count(
v_where=[models.ProjectDetect.project_id == project_id, models.ProjectDetect.detect_name == name])
return count > 0
async def add_detect(self, data: schemas.ProjectDetectIn):
async def add_detect(self, data: schemas.ProjectDetectIn, user_id: int):
"""
新增集合
"""
detect = models.ProjectDetect(**data.model_dump())
detect = self.model(**data.model_dump())
detect.detect_no = random_str(6)
detect.detect_version = 0
detect.detect_status = '0'
url = os.create_folder(detect_url, detect.detect_no, 'images')
detect.folder_url = url
await self.create_data(data)
detect.user_id = user_id
if detect.file_type != 'rtsp':
url = os.create_folder(detect_url, detect.detect_no, 'images')
detect.folder_url = url
await self.create_model(detect)
return detect
async def delete_detects(self, ids: list[int]):
@ -55,10 +57,17 @@ class ProjectDetectDal(DalBase):
for id_ in ids:
detect_info = await self.get_data(data_id=id_)
if detect_info.file_type != 'rtsp':
os.delete_paths(detect_info.folder_url)
logs = await ProjectDetectLogDal(self.db).get_datas(v_where=[models.ProjectDetectLog.detect_id == ids])
os.delete_paths([detect_info.folder_url])
logs = await ProjectDetectLogDal(self.db).get_datas(v_where=[models.ProjectDetectLog.detect_id == id_])
log_ids = []
log_urls = []
for log in logs:
os.delete_paths(log.folder_url)
log_urls.append(log.folder_url)
log_ids.append(log.id)
if log_ids:
await ProjectDetectLogDal(self.db).delete_datas(ids=log_ids, v_soft=False)
if log_urls:
os.delete_paths(log_urls)
await self.delete_datas(ids=ids, v_soft=False)
@ -71,10 +80,10 @@ class ProjectDetectFileDal(DalBase):
self.schema = schemas.ProjectDetectFileOut
async def file_count(self, detect_id: int) -> int:
count = self.get_count(v_where=[models.ProjectDetectFile.detect_id == detect_id])
count = await self.get_count(v_where=[models.ProjectDetectFile.detect_id == detect_id])
return count
async def add_file(self, detect: models.ProjectDetect, files: list[UploadFile]):
async def add_files(self, detect: models.ProjectDetect, files: list[UploadFile], user_id: int):
images = []
obs = MyObs()
for file in files:
@ -83,23 +92,24 @@ class ProjectDetectFileDal(DalBase):
image.file_name = file.filename
# 保存原图
path = os.save_images(detect.folder_url, file=file)
image.image_url = path
image.file_url = path
image.user_id = user_id
# 上传到obs
object_key = detect.detect_no + '/' + file.filename
success, key, url = obs.put_file(object_key=object_key, file_path=path)
if success:
image.object_key = object_key
image.thumb_image_url = url
image.thumb_file_url = url
else:
raise CustomException("obs上传失败", code=status.HTTP_ERROR)
images.append(image)
await self.create_datas(images)
await self.create_models(images)
async def delete_files(self, ids: list[int]):
file_urls = []
object_keys = []
for id_ in ids:
file = self.get_data(data_id=id_)
file = await self.get_data(data_id=id_)
if file:
file_urls.append(file.file_url)
object_keys.append(file.object_key)
@ -107,13 +117,14 @@ class ProjectDetectFileDal(DalBase):
MyObs().del_objects(object_keys)
await self.delete_datas(ids, v_soft=False)
class ProjectDetectLogDal(DalBase):
def __init__(self, db: AsyncSession):
super(ProjectDetectLogDal, self).__init__()
self.db = db
self.model = models.ProjectDetectLog
self.schema = schemas.ProjectDetectLogSimpleOut
self.schema = schemas.ProjectDetectLogOut
class ProjectDetectLogFileDal(DalBase):

View File

@ -11,15 +11,15 @@ class ProjectDetect(BaseModel):
__tablename__ = "project_detect"
__table_args__ = ({'comment': '项目推理集合'})
project_id: Mapped[int] = mapped_column(Integer, nullable=False)
detect_name: Mapped[str] = mapped_column(String(64), nullable=False)
project_id: Mapped[int] = mapped_column(Integer)
detect_name: Mapped[str] = mapped_column(String(64))
detect_version: Mapped[int] = mapped_column(Integer)
detect_no: Mapped[str] = mapped_column(String(32))
detect_status: Mapped[int] = mapped_column(Integer)
file_type: Mapped[str] = mapped_column(String(10))
folder_url: Mapped[str] = mapped_column(String(255))
rtsp_url: Mapped[str] = mapped_column(String(255))
user_id: Mapped[int] = mapped_column(Integer, nullable=False)
folder_url: Mapped[str] = mapped_column(String(255), nullable=True)
rtsp_url: Mapped[str] = mapped_column(String(255), nullable=True)
user_id: Mapped[int] = mapped_column(Integer)
class ProjectDetectFile(BaseModel):

View File

@ -18,6 +18,8 @@ class ProjectDetectIn(BaseModel):
detect_name: Optional[str] = Field(..., description="推理集合名称")
rtsp_url: Optional[str] = Field(None, description="视频流地址")
model_config = ConfigDict(from_attributes=True)
class ProjectDetectPager(BaseModel):
project_id: Optional[int] = Field(..., description="项目id")

View File

@ -21,13 +21,15 @@ async def before_detect(
detect_in: schemas.ProjectDetectLogIn,
detect: models.ProjectDetect,
train: train_models.ProjectTrain,
db: AsyncSession):
db: AsyncSession,
user_id: int):
"""
开始推理
:param detect:
:param detect_in:
:param train:
:param db:
:param user_id:
:return:
"""
# 推理版本
@ -52,19 +54,33 @@ async def before_detect(
detect_log.pt_url = pt_url
detect_log.folder_url = img_url
detect_log.detect_folder_url = out_url
await crud.ProjectDetectLogDal(db).create_data(detect_log)
detect_log.user_id = user_id
await crud.ProjectDetectLogDal(db).create_model(detect_log)
return detect_log
def run_img_loop(
weights: str,
source: str,
project: str,
name: str,
detect_id: int,
is_gpu: str):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 运行异步函数
loop.run_until_complete(run_detect_img(weights, source, project, name, detect_id, is_gpu))
# 可选: 关闭循环
loop.close()
async def run_detect_img(
weights: str,
source: str,
project: str,
name: str,
log_id: int,
detect_id: int,
db: AsyncSession,
rd: Redis):
is_gpu: str):
"""
执行yolov5的推理
:param weights: 权重文件
@ -74,7 +90,7 @@ async def run_detect_img(
:param log_id: 日志id
:param detect_id: 推理集合id
:param db: 数据库session
:param rd: Redis
:param is_gpu: 是否gpu加速
:return:
"""
yolo_path = os.file_path(yolo_url, 'detect.py')
@ -82,10 +98,9 @@ async def run_detect_img(
await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project",
project, "--save-txt", "--conf-thres", "0.4"]
is_gpu = rd.get('is_gpu')
# 判断是否存在cuda版本
if is_gpu == 'True':
commend.append("--device", "0")
commend.append("--device=0")
# 启动子进程
with subprocess.Popen(
commend,
@ -101,40 +116,51 @@ async def run_detect_img(
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
if line != '\n':
await room_manager.send_to_room(room, line + '\n')
# 等待进程结束并获取返回码
return_code = process.wait()
if return_code != 0:
await room_manager.send_to_room(room, 'error')
else:
await room_manager.send_to_room(room, 'success')
detect_files = crud.ProjectDetectFileDal(db).get_data(
v_where=[models.ProjectDetectFile.detect_id == detect_id])
detect_log_files = []
for detect_file in detect_files:
detect_log_img = models.ProjectDetectLogFile()
detect_log_img.log_id = log_id
image_url = os.file_path(project, name, detect_file.file_name)
detect_log_img.image_url = image_url
detect_log_img.file_name = detect_file.file_name
detect_log_files.append(detect_log_img)
await crud.ProjectDetectLogFileDal(db).create_datas(detect_log_files)
async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: int, rd: Redis):
async def update_sql(db: AsyncSession, detect_id: int, log_id: int, project, name):
"""
更新推理集合的状态
"""
detect_dal = crud.ProjectDetectDal(db)
detect = await detect_dal.get_data(detect_id)
detect.detect_version = detect.detect_version + 1
await detect_dal.put_data(data_id=detect_id, data=detect)
detect_files = await crud.ProjectDetectFileDal(db).get_datas(
limit=0,
v_where=[models.ProjectDetectFile.detect_id == detect_id],
v_return_objs=True,
v_return_count=False)
detect_log_files = []
for detect_file in detect_files:
detect_log_img = models.ProjectDetectLogFile()
detect_log_img.log_id = log_id
image_url = os.file_path(project, name, detect_file.file_name)
detect_log_img.file_url = image_url
detect_log_img.file_name = detect_file.file_name
detect_log_files.append(detect_log_img)
await crud.ProjectDetectLogFileDal(db).create_models(detect_log_files)
async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_gpu: str):
"""
rtsp 视频流推理
:param detect_id: 训练集的id
:param weights_pt: 权重文件
:param rtsp_url: 视频流地址
:param data: yaml文件
:param rd: Redis :redis
:param is_gpu: 是否启用加速
:return:
"""
room = 'detect_rtsp_' + str(detect_id)
# 选择设备CPU 或 GPU
device = select_device('cpu')
is_gpu = rd.get('is_gpu')
# 判断是否存在cuda版本
if is_gpu == 'True':
device = select_device('cuda:0')
@ -208,19 +234,10 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
break
def run_img_loop(weights: str, source: str, project: str, name: str, log_id: int, detect_id: int, db: AsyncSession):
def run_rtsp_loop(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_gpu: str):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 运行异步函数
loop.run_until_complete(run_detect_img(weights, source, project, name, log_id, detect_id, db))
# 可选: 关闭循环
loop.close()
def run_rtsp_loop(weights_pt: str, rtsp_url: str, data: str, detect_id: int, rd: Redis):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 运行异步函数
loop.run_until_complete(run_detect_rtsp(weights_pt, rtsp_url, data, detect_id, rd))
loop.run_until_complete(run_detect_rtsp(weights_pt, rtsp_url, data, detect_id, is_gpu))
# 可选: 关闭循环
loop.close()

View File

@ -6,18 +6,24 @@
# @IDE : PyCharm
# @desc : 路由,视图文件
from utils import os_utils as osu
from core.dependencies import IdList
from core.database import redis_getter
from . import schemas, crud, params, service
from . import schemas, crud, params, service, models
from utils.websocket_server import room_manager
from apps.business.train.crud import ProjectTrainDal
from apps.vadmin.auth.utils.current import AllUserAuth
from apps.vadmin.auth.utils.validation.auth import Auth
from utils.response import SuccessResponse, ErrorResponse
import os
import shutil
import zipfile
import tempfile
import threading
from redis.asyncio import Redis
from fastapi import Depends, APIRouter, Form, UploadFile
from fastapi.responses import FileResponse
from fastapi import Depends, APIRouter, Form, UploadFile, BackgroundTasks
app = APIRouter()
@ -26,22 +32,27 @@ app = APIRouter()
###########################################################
# 项目推理集合信息
###########################################################
@app.get("/list", summary="获取项目推理集合信息列表")
@app.get("/list/{proj_id}", summary="获取推理集合列表")
async def detect_list(
p: params.ProjectDetectParams = Depends(),
proj_id: int,
auth: Auth = Depends(AllUserAuth())):
datas, count = await crud.ProjectDetectDal(auth.db).get_datas(**p.dict(), v_return_count=True)
return SuccessResponse(datas, count=count)
datas = await crud.ProjectDetectDal(auth.db).get_datas(
limit=0,
v_where=[models.ProjectDetect.project_id == proj_id],
v_order="asc",
v_order_field="id",
v_return_count=False)
return SuccessResponse(datas)
@app.post("/", summary="创建项目推理集合信息")
@app.post("/", summary="创建推理集合")
async def add_detect(
data: schemas.ProjectDetectIn,
auth: Auth = Depends(AllUserAuth())):
detect_dal = crud.ProjectDetectDal(auth.db)
if await detect_dal.check_name(data.detect_name, data.project_id):
return ErrorResponse(msg="该项目中存在相同名称的集合")
await detect_dal.create_data(data=data)
await detect_dal.add_detect(data=data, user_id=auth.user.id)
return SuccessResponse(msg="保存成功")
@ -49,18 +60,18 @@ async def add_detect(
async def delete_detect(
ids: IdList = Depends(),
auth: Auth = Depends(AllUserAuth())):
await crud.ProjectDetectDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
await crud.ProjectDetectDal(auth.db).delete_detects(ids=ids.ids)
return SuccessResponse("删除成功")
###########################################################
# 项目推理集合文件信息
###########################################################
@app.get("/file", summary="获取项目推理集合文件信息列表")
@app.get("/files", summary="获取推理集合文件列表")
async def file_list(
p: params.ProjectDetectFileParams = Depends(),
auth: Auth = Depends(AllUserAuth())):
if p.limit:
if p.limit > 0:
datas, count = await crud.ProjectDetectFileDal(auth.db).get_datas(**p.dict(), v_return_count=True)
return SuccessResponse(datas, count=count)
else:
@ -68,20 +79,24 @@ async def file_list(
return SuccessResponse(datas)
@app.post("/file", summary="上传项目推理集合文件")
@app.post("/files", summary="上传项目推理集合文件")
async def upload_file(
detect_id: int = Form(...),
files: list[UploadFile] = Form(...),
auth: Auth = Depends(AllUserAuth())):
detect_dal = crud.ProjectDetectDal(auth.db)
file_dal = crud.ProjectDetectFileDal(auth.db)
detect_out = file_dal.get_data(data_id=detect_id)
if detect_out is None:
detect_info = await detect_dal.get_data(data_id=detect_id)
if detect_info is None:
return ErrorResponse("训练集合查询失败,请刷新后再试")
await file_dal.add_file(detect_out, files)
if detect_info.file_type != 'rtsp':
if not any(osu.is_extensions(detect_info.file_type, file.filename) for file in files):
return ErrorResponse("上传的文件中存在不符合的文件类型")
await file_dal.add_files(detect_info, files, auth.user.id)
return SuccessResponse(msg="上传成功")
@app.delete("/file", summary="删除项目推理集合文件信息")
@app.delete("/files", summary="删除推理集合文件")
async def delete_file(
ids: IdList = Depends(),
auth: Auth = Depends(AllUserAuth())):
@ -89,29 +104,34 @@ async def delete_file(
return SuccessResponse("删除成功")
@app.post("/detect", summary="开始推理")
def run_detect_yolo(
@app.post("/start", summary="开始推理")
async def run_detect_yolo(
detect_log_in: schemas.ProjectDetectLogIn,
auth: Auth = Depends(AllUserAuth()),
rd: Redis = Depends(redis_getter)):
detect_dal = crud.ProjectDetectDal(auth.db)
train_dal = ProjectTrainDal(auth.db)
detect = detect_dal.get_data(detect_log_in.detect_id)
detect = await detect_dal.get_data(detect_log_in.detect_id)
if detect is None:
return ErrorResponse(msg="训练集合不存在")
train = train_dal.get_data(detect_log_in.train_id)
train = await train_dal.get_data(detect_log_in.train_id)
if train is None:
return ErrorResponse("训练权重不存在")
file_count = crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id)
file_count = await crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id)
if file_count == 0 and detect.rtsp_url is None:
return ErrorResponse("推理集合中没有内容,请先到推理集合中上传图片")
is_gpu = await rd.get('is_gpu')
if detect.file_type == 'img' or detect.file_type == 'video':
detect_log = service.before_detect(detect_log_in, detect, train, auth.db)
detect_log = await service.before_detect(detect_log_in, detect, train, auth.db, auth.user.id)
thread_train = threading.Thread(target=service.run_img_loop,
args=(detect_log.pt_url, detect_log.folder_url,
detect_log.detect_folder_url, detect_log.detect_version,
detect_log.id, detect_log.detect_id, auth.db,))
detect_log.detect_id, is_gpu))
thread_train.start()
await service.update_sql(
auth.db, detect_log.detect_id,
detect_log.id, detect_log.detect_folder_url,
detect_log.detect_version)
elif detect.file_type == 'rtsp':
room = 'detect_rtsp_' + str(detect.id)
if not room_manager.rooms.get(room):
@ -120,7 +140,7 @@ def run_detect_yolo(
else:
weights_pt = train.last_pt
thread_train = threading.Thread(target=service.run_rtsp_loop,
args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, rd,))
args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, is_gpu,))
thread_train.start()
return SuccessResponse(msg="执行成功")
@ -128,16 +148,50 @@ def run_detect_yolo(
###########################################################
# 项目推理记录信息
###########################################################
@app.get("/log", summary="获取项目推理记录列表")
async def log_pager(
@app.get("/logs", summary="获取推理记录列表")
async def logs(
p: params.ProjectDetectLogParams = Depends(),
auth: Auth = Depends(AllUserAuth())):
datas, count = await crud.ProjectDetectLogDal(auth.db).get_datas(**p.dict(), v_return_count=True)
return SuccessResponse(datas, count=count)
@app.get("/log_files", summary="获取项目推理记录文件列表")
async def log_files(
@app.get("/logs/download/{log_id}", summary="下载推理结果")
async def logs_download(
log_id: int,
background_tasks: BackgroundTasks,
auth: Auth = Depends(AllUserAuth())
):
# 获取日志记录
detect_log = await crud.ProjectDetectLogDal(auth.db).get_data(data_id=log_id)
if detect_log is None:
return ErrorResponse(msg="推理结果查询错误,请刷新页面再试")
# 检查源文件夹是否存在
folder_path = os.path.join(detect_log.detect_folder_url, detect_log.detect_version)
if not os.path.exists(folder_path):
return ErrorResponse(msg="推理结果文件夹不存在")
zip_filename = f"{detect_log.detect_name}_{detect_log.detect_version}.zip"
zip_file_path = osu.zip_folder(folder_path, zip_filename)
# 后台任务删掉这个
background_tasks.add_task(cleanup_temp_dir, zip_file_path)
# 直接返回文件响应
return FileResponse(path=zip_file_path, filename=os.path.basename(zip_file_path))
def cleanup_temp_dir(temp_dir: str):
"""清理临时目录的后台任务"""
try:
shutil.rmtree(temp_dir, ignore_errors=True)
except Exception as e:
print(f"清理临时目录失败: {e}")
@app.get("/logs/files", summary="获取项目推理记录文件列表")
async def logs_files(
p: params.ProjectDetectLogFileParams = Depends(),
auth: Auth = Depends(AllUserAuth())):
datas = await crud.ProjectDetectLogFileDal(auth.db).get_datas(**p.dict(), v_return_count=False)

View File

@ -1,7 +1,13 @@
import os
import shutil
from fastapi import UploadFile
import zipfile
import tempfile
from PIL import Image
from fastapi import UploadFile
img_extensions = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}
video_extensions = {'avi', 'wmv', 'rmvb', 'mp4', 'm4v', 'avi'}
def file_path(*path):
@ -83,7 +89,7 @@ def copy_and_rename_file(src_file_path, dst_dir, new_name):
def delete_file_if_exists(*file_paths: str):
"""
删除文件
:param file_path:
:param file_paths:
:return:
"""
for path in file_paths:
@ -110,3 +116,53 @@ def delete_paths(paths):
print(f"路径删除失败 {path}: {e}")
else:
print(f"路径不存在: {path}")
def is_extensions(extension_type: str, file_name: str):
"""
校验文件名
"""
if extension_type == 'img':
file_extensions = img_extensions
elif extension_type == 'video':
file_extensions = video_extensions
else:
file_extensions = []
return '.' in file_name and file_name.rsplit('.', 1)[1].lower() in file_extensions
def zip_folder(folder_path: str, zip_filename: str) -> str:
"""
将指定文件夹打包成 ZIP 文件并返回 ZIP 文件的路径
:param folder_path: 要打包的文件夹路径
:param zip_filename: 生成的 ZIP 文件名不带扩展名
:return: 生成的 ZIP 文件的完整路径
"""
# 检查文件夹是否存在
if not os.path.isdir(folder_path):
raise ValueError(f"文件夹路径不存在: {folder_path}")
# 确保 ZIP 文件名以 .zip 结尾
if not zip_filename.endswith(".zip"):
zip_filename += ".zip"
# 创建临时目录用于存储 ZIP 文件
temp_dir = tempfile.mkdtemp()
zip_file_path = os.path.join(temp_dir, zip_filename)
try:
# 打包文件夹为 ZIP 文件
with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED, allowZip64=True) as zip_f:
for root, dirs, files in os.walk(folder_path):
for file in files:
file_path = os.path.join(root, file)
arc_name = os.path.relpath(file_path, folder_path) # 保持相对路径
zip_f.write(file_path, arc_name)
# 返回生成的 ZIP 文件路径
return zip_file_path
except Exception as e:
# 清理临时文件夹并重新抛出异常
shutil.rmtree(temp_dir, ignore_errors=True)
raise RuntimeError(f"打包失败: {e}")

View File

@ -25,7 +25,7 @@ if str(ROOT) not in sys.path:
if platform.system() != "Windows":
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
from models.common import (
from utils.yolov5.models.common import (
C3,
C3SPP,
C3TR,
@ -49,11 +49,11 @@ from models.common import (
GhostConv,
Proto,
)
from models.experimental import MixConv2d
from utils.autoanchor import check_anchor_order
from utils.general import LOGGER, check_version, check_yaml, colorstr, make_divisible, print_args
from utils.plots import feature_visualization
from utils.torch_utils import (
from utils.yolov5.models.experimental import MixConv2d
from utils.yolov5.utils.autoanchor import check_anchor_order
from utils.yolov5.utils.general import LOGGER, check_version, check_yaml, colorstr, make_divisible, print_args
from utils.yolov5.utils.plots import feature_visualization
from utils.yolov5.utils.torch_utils import (
fuse_conv_and_bn,
initialize_weights,
model_info,

View File

@ -63,7 +63,7 @@ def safe_download(file, url, url2=None, min_bytes=1e0, error_msg=""):
Removes incomplete downloads.
"""
from utils.general import LOGGER
from utils.yolov5.utils.general import LOGGER
file = Path(file)
assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
@ -89,7 +89,7 @@ def attempt_download(file, repo="ultralytics/yolov5", release="v7.0"):
"""Downloads a file from GitHub release assets or via direct URL if not found locally, supporting backup
versions.
"""
from utils.general import LOGGER
from utils.yolov5.utils.general import LOGGER
def github_assets(repository, version="latest"):
"""Fetches GitHub repository release tag and asset names using the GitHub API."""