diff --git a/app/service/project_detect_service.py b/app/service/project_detect_service.py index e45f482..b6ea6d7 100644 --- a/app/service/project_detect_service.py +++ b/app/service/project_detect_service.py @@ -161,7 +161,7 @@ async def run_detect_img(weights: str, source: str, project: str, name: str, log yolo_path = os.file_path(yolo_url, 'detect.py') room = 'detect_' + str(detect_id) await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n") - commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt"] + commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt", "--conf-thres", "0.4"] is_gpu = redis_conn.get('is_gpu') # 判断是否存在cuda版本 if is_gpu == 'True': @@ -259,7 +259,7 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: pred = model(im, augment=False, visualize=False) # NMS with dt[2]: - pred = non_max_suppression(pred, 0.25, 0.45, None, False, max_det=1000) + pred = non_max_suppression(pred, 0.45, 0.45, None, False, max_det=1000) # Process predictions for i, det in enumerate(pred): # per image diff --git a/app/service/project_train_service.py b/app/service/project_train_service.py index 8ad992b..8c0d9ec 100644 --- a/app/service/project_train_service.py +++ b/app/service/project_train_service.py @@ -14,12 +14,11 @@ from app.websocket.web_socket_server import room_manager from app.util.csv_utils import read_csv from app.common.redis_cli import redis_conn -from sqlalchemy.orm import Session -from typing import List -from fastapi import UploadFile import yaml import subprocess -import torch +from typing import List +from fastapi import UploadFile +from sqlalchemy.orm import Session def add_project(info: ProjectInfoIn, session: Session, user_id: int): diff --git a/app/util/yolov5/utils/general.py b/app/util/yolov5/utils/general.py index d016ca3..1cf24c4 100644 --- a/app/util/yolov5/utils/general.py +++ b/app/util/yolov5/utils/general.py @@ -513,7 +513,7 @@ def check_font(font=FONT, progress=False): font = Path(font) file = CONFIG_DIR / font.name if not font.exists() and not file.exists(): - url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{font.name}" + url = f"https://ultralytics.com/assets/{font.name}" LOGGER.info(f"Downloading {url} to {file}...") torch.hub.download_url_to_file(url, str(file), progress=progress) diff --git a/requirements_gpu.txt b/requirements_gpu.txt index 2f74c5a..8870527 100644 --- a/requirements_gpu.txt +++ b/requirements_gpu.txt @@ -31,11 +31,11 @@ PyYAML>=5.3.1 requests>=2.32.2 scipy==1.13.1 thop>=0.1.1 # FLOPs computation -torch==2.3.1+cu121 --extra-index-url https://mirrors.aliyun.com/pytorch-wheels/cu121/ -torchvision==0.18.1+cu121 --extra-index-url https://mirrors.aliyun.com/pytorch-wheels/cu121/ +# torch==2.6.0+cu124 # 本地安装 +# torchvision==0.21.0+cu124 # 本地安装 tqdm>=4.66.3 ultralytics==8.3.75 # https://ultralytics.com # Plotting -------------------------------------------------------------------- pandas==2.2.3 -seaborn==0.11.0 # 对应的这个依赖也需要降 \ No newline at end of file +seaborn==0.11.2 # 对应的这个依赖也需要降 \ No newline at end of file diff --git a/toch_test.py b/toch_test.py new file mode 100644 index 0000000..f28ade7 --- /dev/null +++ b/toch_test.py @@ -0,0 +1,5 @@ +import torch + +is_gpu = torch.cuda.is_available() + +print(str(is_gpu)) diff --git a/yolov5/train.py b/yolov5/train.py index 1401ccb..d65fee7 100644 --- a/yolov5/train.py +++ b/yolov5/train.py @@ -409,7 +409,7 @@ def train(hyp, opt, device, callbacks): imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False) # Forward - with torch.cuda.amp.autocast(amp): + with torch.amp.autocast(device_type='cuda', enabled=amp): pred = model(imgs) # forward loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size if RANK != -1: diff --git a/yolov5/utils/general.py b/yolov5/utils/general.py index 89bbc61..302bc8a 100644 --- a/yolov5/utils/general.py +++ b/yolov5/utils/general.py @@ -513,7 +513,7 @@ def check_font(font=FONT, progress=False): font = Path(font) file = CONFIG_DIR / font.name if not font.exists() and not file.exists(): - url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{font.name}" + url = f"https://ultralytics.com/assets/{font.name}" LOGGER.info(f"Downloading {url} to {file}...") torch.hub.download_url_to_file(url, str(file), progress=progress)