diff --git a/apps/business/detect/crud.py b/apps/business/detect/crud.py index 0bae20c..44657d1 100644 --- a/apps/business/detect/crud.py +++ b/apps/business/detect/crud.py @@ -27,25 +27,27 @@ class ProjectDetectDal(DalBase): self.model = models.ProjectDetect self.schema = schemas.ProjectDetectOut - async def check_name(self, name: str,project_id: int) -> bool: + async def check_name(self, name: str, project_id: int) -> bool: """ 校验推理集合名称 """ - count = self.get_count( + count = await self.get_count( v_where=[models.ProjectDetect.project_id == project_id, models.ProjectDetect.detect_name == name]) return count > 0 - async def add_detect(self, data: schemas.ProjectDetectIn): + async def add_detect(self, data: schemas.ProjectDetectIn, user_id: int): """ 新增集合 """ - detect = models.ProjectDetect(**data.model_dump()) + detect = self.model(**data.model_dump()) detect.detect_no = random_str(6) detect.detect_version = 0 detect.detect_status = '0' - url = os.create_folder(detect_url, detect.detect_no, 'images') - detect.folder_url = url - await self.create_data(data) + detect.user_id = user_id + if detect.file_type != 'rtsp': + url = os.create_folder(detect_url, detect.detect_no, 'images') + detect.folder_url = url + await self.create_model(detect) return detect async def delete_detects(self, ids: list[int]): @@ -55,10 +57,17 @@ class ProjectDetectDal(DalBase): for id_ in ids: detect_info = await self.get_data(data_id=id_) if detect_info.file_type != 'rtsp': - os.delete_paths(detect_info.folder_url) - logs = await ProjectDetectLogDal(self.db).get_datas(v_where=[models.ProjectDetectLog.detect_id == ids]) + os.delete_paths([detect_info.folder_url]) + logs = await ProjectDetectLogDal(self.db).get_datas(v_where=[models.ProjectDetectLog.detect_id == id_]) + log_ids = [] + log_urls = [] for log in logs: - os.delete_paths(log.folder_url) + log_urls.append(log.folder_url) + log_ids.append(log.id) + if log_ids: + await ProjectDetectLogDal(self.db).delete_datas(ids=log_ids, v_soft=False) + if log_urls: + os.delete_paths(log_urls) await self.delete_datas(ids=ids, v_soft=False) @@ -71,10 +80,10 @@ class ProjectDetectFileDal(DalBase): self.schema = schemas.ProjectDetectFileOut async def file_count(self, detect_id: int) -> int: - count = self.get_count(v_where=[models.ProjectDetectFile.detect_id == detect_id]) + count = await self.get_count(v_where=[models.ProjectDetectFile.detect_id == detect_id]) return count - async def add_file(self, detect: models.ProjectDetect, files: list[UploadFile]): + async def add_files(self, detect: models.ProjectDetect, files: list[UploadFile], user_id: int): images = [] obs = MyObs() for file in files: @@ -83,23 +92,24 @@ class ProjectDetectFileDal(DalBase): image.file_name = file.filename # 保存原图 path = os.save_images(detect.folder_url, file=file) - image.image_url = path + image.file_url = path + image.user_id = user_id # 上传到obs object_key = detect.detect_no + '/' + file.filename success, key, url = obs.put_file(object_key=object_key, file_path=path) if success: image.object_key = object_key - image.thumb_image_url = url + image.thumb_file_url = url else: raise CustomException("obs上传失败", code=status.HTTP_ERROR) images.append(image) - await self.create_datas(images) + await self.create_models(images) async def delete_files(self, ids: list[int]): file_urls = [] object_keys = [] for id_ in ids: - file = self.get_data(data_id=id_) + file = await self.get_data(data_id=id_) if file: file_urls.append(file.file_url) object_keys.append(file.object_key) @@ -107,13 +117,14 @@ class ProjectDetectFileDal(DalBase): MyObs().del_objects(object_keys) await self.delete_datas(ids, v_soft=False) + class ProjectDetectLogDal(DalBase): def __init__(self, db: AsyncSession): super(ProjectDetectLogDal, self).__init__() self.db = db self.model = models.ProjectDetectLog - self.schema = schemas.ProjectDetectLogSimpleOut + self.schema = schemas.ProjectDetectLogOut class ProjectDetectLogFileDal(DalBase): diff --git a/apps/business/detect/models/detect.py b/apps/business/detect/models/detect.py index 7019a91..1305b3e 100644 --- a/apps/business/detect/models/detect.py +++ b/apps/business/detect/models/detect.py @@ -11,15 +11,15 @@ class ProjectDetect(BaseModel): __tablename__ = "project_detect" __table_args__ = ({'comment': '项目推理集合'}) - project_id: Mapped[int] = mapped_column(Integer, nullable=False) - detect_name: Mapped[str] = mapped_column(String(64), nullable=False) + project_id: Mapped[int] = mapped_column(Integer) + detect_name: Mapped[str] = mapped_column(String(64)) detect_version: Mapped[int] = mapped_column(Integer) detect_no: Mapped[str] = mapped_column(String(32)) detect_status: Mapped[int] = mapped_column(Integer) file_type: Mapped[str] = mapped_column(String(10)) - folder_url: Mapped[str] = mapped_column(String(255)) - rtsp_url: Mapped[str] = mapped_column(String(255)) - user_id: Mapped[int] = mapped_column(Integer, nullable=False) + folder_url: Mapped[str] = mapped_column(String(255), nullable=True) + rtsp_url: Mapped[str] = mapped_column(String(255), nullable=True) + user_id: Mapped[int] = mapped_column(Integer) class ProjectDetectFile(BaseModel): diff --git a/apps/business/detect/schemas/project_detect.py b/apps/business/detect/schemas/project_detect.py index 696f886..7c09ada 100644 --- a/apps/business/detect/schemas/project_detect.py +++ b/apps/business/detect/schemas/project_detect.py @@ -18,6 +18,8 @@ class ProjectDetectIn(BaseModel): detect_name: Optional[str] = Field(..., description="推理集合名称") rtsp_url: Optional[str] = Field(None, description="视频流地址") + model_config = ConfigDict(from_attributes=True) + class ProjectDetectPager(BaseModel): project_id: Optional[int] = Field(..., description="项目id") diff --git a/apps/business/detect/service.py b/apps/business/detect/service.py index 528439f..15eb950 100644 --- a/apps/business/detect/service.py +++ b/apps/business/detect/service.py @@ -21,13 +21,15 @@ async def before_detect( detect_in: schemas.ProjectDetectLogIn, detect: models.ProjectDetect, train: train_models.ProjectTrain, - db: AsyncSession): + db: AsyncSession, + user_id: int): """ 开始推理 :param detect: :param detect_in: :param train: :param db: + :param user_id: :return: """ # 推理版本 @@ -52,19 +54,33 @@ async def before_detect( detect_log.pt_url = pt_url detect_log.folder_url = img_url detect_log.detect_folder_url = out_url - await crud.ProjectDetectLogDal(db).create_data(detect_log) + detect_log.user_id = user_id + await crud.ProjectDetectLogDal(db).create_model(detect_log) return detect_log +def run_img_loop( + weights: str, + source: str, + project: str, + name: str, + detect_id: int, + is_gpu: str): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + # 运行异步函数 + loop.run_until_complete(run_detect_img(weights, source, project, name, detect_id, is_gpu)) + # 可选: 关闭循环 + loop.close() + + async def run_detect_img( weights: str, source: str, project: str, name: str, - log_id: int, detect_id: int, - db: AsyncSession, - rd: Redis): + is_gpu: str): """ 执行yolov5的推理 :param weights: 权重文件 @@ -74,7 +90,7 @@ async def run_detect_img( :param log_id: 日志id :param detect_id: 推理集合id :param db: 数据库session - :param rd: Redis + :param is_gpu: 是否gpu加速 :return: """ yolo_path = os.file_path(yolo_url, 'detect.py') @@ -82,10 +98,9 @@ async def run_detect_img( await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n") commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt", "--conf-thres", "0.4"] - is_gpu = rd.get('is_gpu') # 判断是否存在cuda版本 if is_gpu == 'True': - commend.append("--device", "0") + commend.append("--device=0") # 启动子进程 with subprocess.Popen( commend, @@ -101,40 +116,51 @@ async def run_detect_img( process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死 if line != '\n': await room_manager.send_to_room(room, line + '\n') - # 等待进程结束并获取返回码 return_code = process.wait() if return_code != 0: await room_manager.send_to_room(room, 'error') else: await room_manager.send_to_room(room, 'success') - detect_files = crud.ProjectDetectFileDal(db).get_data( - v_where=[models.ProjectDetectFile.detect_id == detect_id]) - detect_log_files = [] - for detect_file in detect_files: - detect_log_img = models.ProjectDetectLogFile() - detect_log_img.log_id = log_id - image_url = os.file_path(project, name, detect_file.file_name) - detect_log_img.image_url = image_url - detect_log_img.file_name = detect_file.file_name - detect_log_files.append(detect_log_img) - await crud.ProjectDetectLogFileDal(db).create_datas(detect_log_files) -async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: int, rd: Redis): +async def update_sql(db: AsyncSession, detect_id: int, log_id: int, project, name): + """ + 更新推理集合的状态 + """ + detect_dal = crud.ProjectDetectDal(db) + detect = await detect_dal.get_data(detect_id) + detect.detect_version = detect.detect_version + 1 + await detect_dal.put_data(data_id=detect_id, data=detect) + detect_files = await crud.ProjectDetectFileDal(db).get_datas( + limit=0, + v_where=[models.ProjectDetectFile.detect_id == detect_id], + v_return_objs=True, + v_return_count=False) + detect_log_files = [] + for detect_file in detect_files: + detect_log_img = models.ProjectDetectLogFile() + detect_log_img.log_id = log_id + image_url = os.file_path(project, name, detect_file.file_name) + detect_log_img.file_url = image_url + detect_log_img.file_name = detect_file.file_name + detect_log_files.append(detect_log_img) + await crud.ProjectDetectLogFileDal(db).create_models(detect_log_files) + + +async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_gpu: str): """ rtsp 视频流推理 :param detect_id: 训练集的id :param weights_pt: 权重文件 :param rtsp_url: 视频流地址 :param data: yaml文件 - :param rd: Redis :redis + :param is_gpu: 是否启用加速 :return: """ room = 'detect_rtsp_' + str(detect_id) # 选择设备(CPU 或 GPU) device = select_device('cpu') - is_gpu = rd.get('is_gpu') # 判断是否存在cuda版本 if is_gpu == 'True': device = select_device('cuda:0') @@ -208,19 +234,10 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: break -def run_img_loop(weights: str, source: str, project: str, name: str, log_id: int, detect_id: int, db: AsyncSession): +def run_rtsp_loop(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_gpu: str): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) # 运行异步函数 - loop.run_until_complete(run_detect_img(weights, source, project, name, log_id, detect_id, db)) - # 可选: 关闭循环 - loop.close() - - -def run_rtsp_loop(weights_pt: str, rtsp_url: str, data: str, detect_id: int, rd: Redis): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - # 运行异步函数 - loop.run_until_complete(run_detect_rtsp(weights_pt, rtsp_url, data, detect_id, rd)) + loop.run_until_complete(run_detect_rtsp(weights_pt, rtsp_url, data, detect_id, is_gpu)) # 可选: 关闭循环 loop.close() \ No newline at end of file diff --git a/apps/business/detect/views.py b/apps/business/detect/views.py index 2ee8545..3bae1b6 100644 --- a/apps/business/detect/views.py +++ b/apps/business/detect/views.py @@ -6,18 +6,24 @@ # @IDE : PyCharm # @desc : 路由,视图文件 +from utils import os_utils as osu from core.dependencies import IdList from core.database import redis_getter -from . import schemas, crud, params, service +from . import schemas, crud, params, service, models from utils.websocket_server import room_manager from apps.business.train.crud import ProjectTrainDal from apps.vadmin.auth.utils.current import AllUserAuth from apps.vadmin.auth.utils.validation.auth import Auth from utils.response import SuccessResponse, ErrorResponse +import os +import shutil +import zipfile +import tempfile import threading from redis.asyncio import Redis -from fastapi import Depends, APIRouter, Form, UploadFile +from fastapi.responses import FileResponse +from fastapi import Depends, APIRouter, Form, UploadFile, BackgroundTasks app = APIRouter() @@ -26,22 +32,27 @@ app = APIRouter() ########################################################### # 项目推理集合信息 ########################################################### -@app.get("/list", summary="获取项目推理集合信息列表") +@app.get("/list/{proj_id}", summary="获取推理集合列表") async def detect_list( - p: params.ProjectDetectParams = Depends(), + proj_id: int, auth: Auth = Depends(AllUserAuth())): - datas, count = await crud.ProjectDetectDal(auth.db).get_datas(**p.dict(), v_return_count=True) - return SuccessResponse(datas, count=count) + datas = await crud.ProjectDetectDal(auth.db).get_datas( + limit=0, + v_where=[models.ProjectDetect.project_id == proj_id], + v_order="asc", + v_order_field="id", + v_return_count=False) + return SuccessResponse(datas) -@app.post("/", summary="创建项目推理集合信息") +@app.post("/", summary="创建推理集合") async def add_detect( data: schemas.ProjectDetectIn, auth: Auth = Depends(AllUserAuth())): detect_dal = crud.ProjectDetectDal(auth.db) if await detect_dal.check_name(data.detect_name, data.project_id): return ErrorResponse(msg="该项目中存在相同名称的集合") - await detect_dal.create_data(data=data) + await detect_dal.add_detect(data=data, user_id=auth.user.id) return SuccessResponse(msg="保存成功") @@ -49,18 +60,18 @@ async def add_detect( async def delete_detect( ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())): - await crud.ProjectDetectDal(auth.db).delete_datas(ids=ids.ids, v_soft=False) + await crud.ProjectDetectDal(auth.db).delete_detects(ids=ids.ids) return SuccessResponse("删除成功") ########################################################### # 项目推理集合文件信息 ########################################################### -@app.get("/file", summary="获取项目推理集合文件信息列表") +@app.get("/files", summary="获取推理集合文件列表") async def file_list( p: params.ProjectDetectFileParams = Depends(), auth: Auth = Depends(AllUserAuth())): - if p.limit: + if p.limit > 0: datas, count = await crud.ProjectDetectFileDal(auth.db).get_datas(**p.dict(), v_return_count=True) return SuccessResponse(datas, count=count) else: @@ -68,20 +79,24 @@ async def file_list( return SuccessResponse(datas) -@app.post("/file", summary="上传项目推理集合文件") +@app.post("/files", summary="上传项目推理集合文件") async def upload_file( detect_id: int = Form(...), files: list[UploadFile] = Form(...), auth: Auth = Depends(AllUserAuth())): + detect_dal = crud.ProjectDetectDal(auth.db) file_dal = crud.ProjectDetectFileDal(auth.db) - detect_out = file_dal.get_data(data_id=detect_id) - if detect_out is None: + detect_info = await detect_dal.get_data(data_id=detect_id) + if detect_info is None: return ErrorResponse("训练集合查询失败,请刷新后再试") - await file_dal.add_file(detect_out, files) + if detect_info.file_type != 'rtsp': + if not any(osu.is_extensions(detect_info.file_type, file.filename) for file in files): + return ErrorResponse("上传的文件中存在不符合的文件类型") + await file_dal.add_files(detect_info, files, auth.user.id) return SuccessResponse(msg="上传成功") -@app.delete("/file", summary="删除项目推理集合文件信息") +@app.delete("/files", summary="删除推理集合文件") async def delete_file( ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())): @@ -89,29 +104,34 @@ async def delete_file( return SuccessResponse("删除成功") -@app.post("/detect", summary="开始推理") -def run_detect_yolo( +@app.post("/start", summary="开始推理") +async def run_detect_yolo( detect_log_in: schemas.ProjectDetectLogIn, auth: Auth = Depends(AllUserAuth()), rd: Redis = Depends(redis_getter)): detect_dal = crud.ProjectDetectDal(auth.db) train_dal = ProjectTrainDal(auth.db) - detect = detect_dal.get_data(detect_log_in.detect_id) + detect = await detect_dal.get_data(detect_log_in.detect_id) if detect is None: return ErrorResponse(msg="训练集合不存在") - train = train_dal.get_data(detect_log_in.train_id) + train = await train_dal.get_data(detect_log_in.train_id) if train is None: return ErrorResponse("训练权重不存在") - file_count = crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id) + file_count = await crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id) if file_count == 0 and detect.rtsp_url is None: return ErrorResponse("推理集合中没有内容,请先到推理集合中上传图片") + is_gpu = await rd.get('is_gpu') if detect.file_type == 'img' or detect.file_type == 'video': - detect_log = service.before_detect(detect_log_in, detect, train, auth.db) + detect_log = await service.before_detect(detect_log_in, detect, train, auth.db, auth.user.id) thread_train = threading.Thread(target=service.run_img_loop, args=(detect_log.pt_url, detect_log.folder_url, detect_log.detect_folder_url, detect_log.detect_version, - detect_log.id, detect_log.detect_id, auth.db,)) + detect_log.detect_id, is_gpu)) thread_train.start() + await service.update_sql( + auth.db, detect_log.detect_id, + detect_log.id, detect_log.detect_folder_url, + detect_log.detect_version) elif detect.file_type == 'rtsp': room = 'detect_rtsp_' + str(detect.id) if not room_manager.rooms.get(room): @@ -120,7 +140,7 @@ def run_detect_yolo( else: weights_pt = train.last_pt thread_train = threading.Thread(target=service.run_rtsp_loop, - args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, rd,)) + args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, is_gpu,)) thread_train.start() return SuccessResponse(msg="执行成功") @@ -128,16 +148,50 @@ def run_detect_yolo( ########################################################### # 项目推理记录信息 ########################################################### -@app.get("/log", summary="获取项目推理记录列表") -async def log_pager( +@app.get("/logs", summary="获取推理记录列表") +async def logs( p: params.ProjectDetectLogParams = Depends(), auth: Auth = Depends(AllUserAuth())): datas, count = await crud.ProjectDetectLogDal(auth.db).get_datas(**p.dict(), v_return_count=True) return SuccessResponse(datas, count=count) -@app.get("/log_files", summary="获取项目推理记录文件列表") -async def log_files( +@app.get("/logs/download/{log_id}", summary="下载推理结果") +async def logs_download( + log_id: int, + background_tasks: BackgroundTasks, + auth: Auth = Depends(AllUserAuth()) +): + # 获取日志记录 + detect_log = await crud.ProjectDetectLogDal(auth.db).get_data(data_id=log_id) + if detect_log is None: + return ErrorResponse(msg="推理结果查询错误,请刷新页面再试") + + # 检查源文件夹是否存在 + folder_path = os.path.join(detect_log.detect_folder_url, detect_log.detect_version) + if not os.path.exists(folder_path): + return ErrorResponse(msg="推理结果文件夹不存在") + + zip_filename = f"{detect_log.detect_name}_{detect_log.detect_version}.zip" + zip_file_path = osu.zip_folder(folder_path, zip_filename) + + # 后台任务删掉这个 + background_tasks.add_task(cleanup_temp_dir, zip_file_path) + + # 直接返回文件响应 + return FileResponse(path=zip_file_path, filename=os.path.basename(zip_file_path)) + + +def cleanup_temp_dir(temp_dir: str): + """清理临时目录的后台任务""" + try: + shutil.rmtree(temp_dir, ignore_errors=True) + except Exception as e: + print(f"清理临时目录失败: {e}") + + +@app.get("/logs/files", summary="获取项目推理记录文件列表") +async def logs_files( p: params.ProjectDetectLogFileParams = Depends(), auth: Auth = Depends(AllUserAuth())): datas = await crud.ProjectDetectLogFileDal(auth.db).get_datas(**p.dict(), v_return_count=False) diff --git a/utils/os_utils.py b/utils/os_utils.py index 4b1a9d1..8422d31 100644 --- a/utils/os_utils.py +++ b/utils/os_utils.py @@ -1,7 +1,13 @@ import os import shutil -from fastapi import UploadFile +import zipfile +import tempfile from PIL import Image +from fastapi import UploadFile + +img_extensions = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'} + +video_extensions = {'avi', 'wmv', 'rmvb', 'mp4', 'm4v', 'avi'} def file_path(*path): @@ -83,7 +89,7 @@ def copy_and_rename_file(src_file_path, dst_dir, new_name): def delete_file_if_exists(*file_paths: str): """ 删除文件 - :param file_path: + :param file_paths: :return: """ for path in file_paths: @@ -110,3 +116,53 @@ def delete_paths(paths): print(f"路径删除失败 {path}: {e}") else: print(f"路径不存在: {path}") + + +def is_extensions(extension_type: str, file_name: str): + """ + 校验文件名 + """ + if extension_type == 'img': + file_extensions = img_extensions + elif extension_type == 'video': + file_extensions = video_extensions + else: + file_extensions = [] + return '.' in file_name and file_name.rsplit('.', 1)[1].lower() in file_extensions + + +def zip_folder(folder_path: str, zip_filename: str) -> str: + """ + 将指定文件夹打包成 ZIP 文件,并返回 ZIP 文件的路径。 + + :param folder_path: 要打包的文件夹路径 + :param zip_filename: 生成的 ZIP 文件名(不带扩展名) + :return: 生成的 ZIP 文件的完整路径 + """ + # 检查文件夹是否存在 + if not os.path.isdir(folder_path): + raise ValueError(f"文件夹路径不存在: {folder_path}") + + # 确保 ZIP 文件名以 .zip 结尾 + if not zip_filename.endswith(".zip"): + zip_filename += ".zip" + + # 创建临时目录用于存储 ZIP 文件 + temp_dir = tempfile.mkdtemp() + zip_file_path = os.path.join(temp_dir, zip_filename) + + try: + # 打包文件夹为 ZIP 文件 + with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED, allowZip64=True) as zip_f: + for root, dirs, files in os.walk(folder_path): + for file in files: + file_path = os.path.join(root, file) + arc_name = os.path.relpath(file_path, folder_path) # 保持相对路径 + zip_f.write(file_path, arc_name) + + # 返回生成的 ZIP 文件路径 + return zip_file_path + except Exception as e: + # 清理临时文件夹并重新抛出异常 + shutil.rmtree(temp_dir, ignore_errors=True) + raise RuntimeError(f"打包失败: {e}") diff --git a/utils/yolov5/models/yolo.py b/utils/yolov5/models/yolo.py index 13498ac..8165f7e 100644 --- a/utils/yolov5/models/yolo.py +++ b/utils/yolov5/models/yolo.py @@ -25,7 +25,7 @@ if str(ROOT) not in sys.path: if platform.system() != "Windows": ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative -from models.common import ( +from utils.yolov5.models.common import ( C3, C3SPP, C3TR, @@ -49,11 +49,11 @@ from models.common import ( GhostConv, Proto, ) -from models.experimental import MixConv2d -from utils.autoanchor import check_anchor_order -from utils.general import LOGGER, check_version, check_yaml, colorstr, make_divisible, print_args -from utils.plots import feature_visualization -from utils.torch_utils import ( +from utils.yolov5.models.experimental import MixConv2d +from utils.yolov5.utils.autoanchor import check_anchor_order +from utils.yolov5.utils.general import LOGGER, check_version, check_yaml, colorstr, make_divisible, print_args +from utils.yolov5.utils.plots import feature_visualization +from utils.yolov5.utils.torch_utils import ( fuse_conv_and_bn, initialize_weights, model_info, diff --git a/utils/yolov5/utils/downloads.py b/utils/yolov5/utils/downloads.py index f51d67a..90b6d52 100644 --- a/utils/yolov5/utils/downloads.py +++ b/utils/yolov5/utils/downloads.py @@ -63,7 +63,7 @@ def safe_download(file, url, url2=None, min_bytes=1e0, error_msg=""): Removes incomplete downloads. """ - from utils.general import LOGGER + from utils.yolov5.utils.general import LOGGER file = Path(file) assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}" @@ -89,7 +89,7 @@ def attempt_download(file, repo="ultralytics/yolov5", release="v7.0"): """Downloads a file from GitHub release assets or via direct URL if not found locally, supporting backup versions. """ - from utils.general import LOGGER + from utils.yolov5.utils.general import LOGGER def github_assets(repository, version="latest"): """Fetches GitHub repository release tag and asset names using the GitHub API."""