完成项目推理模块的接口测试
This commit is contained in:
@ -1,7 +1,13 @@
|
||||
import os
|
||||
import shutil
|
||||
from fastapi import UploadFile
|
||||
import zipfile
|
||||
import tempfile
|
||||
from PIL import Image
|
||||
from fastapi import UploadFile
|
||||
|
||||
img_extensions = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}
|
||||
|
||||
video_extensions = {'avi', 'wmv', 'rmvb', 'mp4', 'm4v', 'avi'}
|
||||
|
||||
|
||||
def file_path(*path):
|
||||
@ -83,7 +89,7 @@ def copy_and_rename_file(src_file_path, dst_dir, new_name):
|
||||
def delete_file_if_exists(*file_paths: str):
|
||||
"""
|
||||
删除文件
|
||||
:param file_path:
|
||||
:param file_paths:
|
||||
:return:
|
||||
"""
|
||||
for path in file_paths:
|
||||
@ -110,3 +116,53 @@ def delete_paths(paths):
|
||||
print(f"路径删除失败 {path}: {e}")
|
||||
else:
|
||||
print(f"路径不存在: {path}")
|
||||
|
||||
|
||||
def is_extensions(extension_type: str, file_name: str):
|
||||
"""
|
||||
校验文件名
|
||||
"""
|
||||
if extension_type == 'img':
|
||||
file_extensions = img_extensions
|
||||
elif extension_type == 'video':
|
||||
file_extensions = video_extensions
|
||||
else:
|
||||
file_extensions = []
|
||||
return '.' in file_name and file_name.rsplit('.', 1)[1].lower() in file_extensions
|
||||
|
||||
|
||||
def zip_folder(folder_path: str, zip_filename: str) -> str:
|
||||
"""
|
||||
将指定文件夹打包成 ZIP 文件,并返回 ZIP 文件的路径。
|
||||
|
||||
:param folder_path: 要打包的文件夹路径
|
||||
:param zip_filename: 生成的 ZIP 文件名(不带扩展名)
|
||||
:return: 生成的 ZIP 文件的完整路径
|
||||
"""
|
||||
# 检查文件夹是否存在
|
||||
if not os.path.isdir(folder_path):
|
||||
raise ValueError(f"文件夹路径不存在: {folder_path}")
|
||||
|
||||
# 确保 ZIP 文件名以 .zip 结尾
|
||||
if not zip_filename.endswith(".zip"):
|
||||
zip_filename += ".zip"
|
||||
|
||||
# 创建临时目录用于存储 ZIP 文件
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
zip_file_path = os.path.join(temp_dir, zip_filename)
|
||||
|
||||
try:
|
||||
# 打包文件夹为 ZIP 文件
|
||||
with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED, allowZip64=True) as zip_f:
|
||||
for root, dirs, files in os.walk(folder_path):
|
||||
for file in files:
|
||||
file_path = os.path.join(root, file)
|
||||
arc_name = os.path.relpath(file_path, folder_path) # 保持相对路径
|
||||
zip_f.write(file_path, arc_name)
|
||||
|
||||
# 返回生成的 ZIP 文件路径
|
||||
return zip_file_path
|
||||
except Exception as e:
|
||||
# 清理临时文件夹并重新抛出异常
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
raise RuntimeError(f"打包失败: {e}")
|
||||
|
@ -25,7 +25,7 @@ if str(ROOT) not in sys.path:
|
||||
if platform.system() != "Windows":
|
||||
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
||||
|
||||
from models.common import (
|
||||
from utils.yolov5.models.common import (
|
||||
C3,
|
||||
C3SPP,
|
||||
C3TR,
|
||||
@ -49,11 +49,11 @@ from models.common import (
|
||||
GhostConv,
|
||||
Proto,
|
||||
)
|
||||
from models.experimental import MixConv2d
|
||||
from utils.autoanchor import check_anchor_order
|
||||
from utils.general import LOGGER, check_version, check_yaml, colorstr, make_divisible, print_args
|
||||
from utils.plots import feature_visualization
|
||||
from utils.torch_utils import (
|
||||
from utils.yolov5.models.experimental import MixConv2d
|
||||
from utils.yolov5.utils.autoanchor import check_anchor_order
|
||||
from utils.yolov5.utils.general import LOGGER, check_version, check_yaml, colorstr, make_divisible, print_args
|
||||
from utils.yolov5.utils.plots import feature_visualization
|
||||
from utils.yolov5.utils.torch_utils import (
|
||||
fuse_conv_and_bn,
|
||||
initialize_weights,
|
||||
model_info,
|
||||
|
@ -63,7 +63,7 @@ def safe_download(file, url, url2=None, min_bytes=1e0, error_msg=""):
|
||||
|
||||
Removes incomplete downloads.
|
||||
"""
|
||||
from utils.general import LOGGER
|
||||
from utils.yolov5.utils.general import LOGGER
|
||||
|
||||
file = Path(file)
|
||||
assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
|
||||
@ -89,7 +89,7 @@ def attempt_download(file, repo="ultralytics/yolov5", release="v7.0"):
|
||||
"""Downloads a file from GitHub release assets or via direct URL if not found locally, supporting backup
|
||||
versions.
|
||||
"""
|
||||
from utils.general import LOGGER
|
||||
from utils.yolov5.utils.general import LOGGER
|
||||
|
||||
def github_assets(repository, version="latest"):
|
||||
"""Fetches GitHub repository release tag and asset names using the GitHub API."""
|
||||
|
Reference in New Issue
Block a user