完成项目推理模块的接口测试

This commit is contained in:
2025-04-23 16:38:55 +08:00
parent 0033746fe1
commit 5b38e91f61
8 changed files with 234 additions and 94 deletions

View File

@ -1,7 +1,13 @@
import os
import shutil
from fastapi import UploadFile
import zipfile
import tempfile
from PIL import Image
from fastapi import UploadFile
img_extensions = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}
video_extensions = {'avi', 'wmv', 'rmvb', 'mp4', 'm4v', 'avi'}
def file_path(*path):
@ -83,7 +89,7 @@ def copy_and_rename_file(src_file_path, dst_dir, new_name):
def delete_file_if_exists(*file_paths: str):
"""
删除文件
:param file_path:
:param file_paths:
:return:
"""
for path in file_paths:
@ -110,3 +116,53 @@ def delete_paths(paths):
print(f"路径删除失败 {path}: {e}")
else:
print(f"路径不存在: {path}")
def is_extensions(extension_type: str, file_name: str):
"""
校验文件名
"""
if extension_type == 'img':
file_extensions = img_extensions
elif extension_type == 'video':
file_extensions = video_extensions
else:
file_extensions = []
return '.' in file_name and file_name.rsplit('.', 1)[1].lower() in file_extensions
def zip_folder(folder_path: str, zip_filename: str) -> str:
"""
将指定文件夹打包成 ZIP 文件,并返回 ZIP 文件的路径。
:param folder_path: 要打包的文件夹路径
:param zip_filename: 生成的 ZIP 文件名(不带扩展名)
:return: 生成的 ZIP 文件的完整路径
"""
# 检查文件夹是否存在
if not os.path.isdir(folder_path):
raise ValueError(f"文件夹路径不存在: {folder_path}")
# 确保 ZIP 文件名以 .zip 结尾
if not zip_filename.endswith(".zip"):
zip_filename += ".zip"
# 创建临时目录用于存储 ZIP 文件
temp_dir = tempfile.mkdtemp()
zip_file_path = os.path.join(temp_dir, zip_filename)
try:
# 打包文件夹为 ZIP 文件
with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED, allowZip64=True) as zip_f:
for root, dirs, files in os.walk(folder_path):
for file in files:
file_path = os.path.join(root, file)
arc_name = os.path.relpath(file_path, folder_path) # 保持相对路径
zip_f.write(file_path, arc_name)
# 返回生成的 ZIP 文件路径
return zip_file_path
except Exception as e:
# 清理临时文件夹并重新抛出异常
shutil.rmtree(temp_dir, ignore_errors=True)
raise RuntimeError(f"打包失败: {e}")

View File

@ -25,7 +25,7 @@ if str(ROOT) not in sys.path:
if platform.system() != "Windows":
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
from models.common import (
from utils.yolov5.models.common import (
C3,
C3SPP,
C3TR,
@ -49,11 +49,11 @@ from models.common import (
GhostConv,
Proto,
)
from models.experimental import MixConv2d
from utils.autoanchor import check_anchor_order
from utils.general import LOGGER, check_version, check_yaml, colorstr, make_divisible, print_args
from utils.plots import feature_visualization
from utils.torch_utils import (
from utils.yolov5.models.experimental import MixConv2d
from utils.yolov5.utils.autoanchor import check_anchor_order
from utils.yolov5.utils.general import LOGGER, check_version, check_yaml, colorstr, make_divisible, print_args
from utils.yolov5.utils.plots import feature_visualization
from utils.yolov5.utils.torch_utils import (
fuse_conv_and_bn,
initialize_weights,
model_info,

View File

@ -63,7 +63,7 @@ def safe_download(file, url, url2=None, min_bytes=1e0, error_msg=""):
Removes incomplete downloads.
"""
from utils.general import LOGGER
from utils.yolov5.utils.general import LOGGER
file = Path(file)
assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
@ -89,7 +89,7 @@ def attempt_download(file, repo="ultralytics/yolov5", release="v7.0"):
"""Downloads a file from GitHub release assets or via direct URL if not found locally, supporting backup
versions.
"""
from utils.general import LOGGER
from utils.yolov5.utils.general import LOGGER
def github_assets(repository, version="latest"):
"""Fetches GitHub repository release tag and asset names using the GitHub API."""