优化训练过程

This commit is contained in:
2025-03-20 16:41:51 +08:00
parent 358bb40a2a
commit bba39adcfc
7 changed files with 16 additions and 12 deletions

View File

@ -161,7 +161,7 @@ async def run_detect_img(weights: str, source: str, project: str, name: str, log
yolo_path = os.file_path(yolo_url, 'detect.py')
room = 'detect_' + str(detect_id)
await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt"]
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt", "--conf-thres", "0.4"]
is_gpu = redis_conn.get('is_gpu')
# 判断是否存在cuda版本
if is_gpu == 'True':
@ -259,7 +259,7 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
pred = model(im, augment=False, visualize=False)
# NMS
with dt[2]:
pred = non_max_suppression(pred, 0.25, 0.45, None, False, max_det=1000)
pred = non_max_suppression(pred, 0.45, 0.45, None, False, max_det=1000)
# Process predictions
for i, det in enumerate(pred): # per image

View File

@ -14,12 +14,11 @@ from app.websocket.web_socket_server import room_manager
from app.util.csv_utils import read_csv
from app.common.redis_cli import redis_conn
from sqlalchemy.orm import Session
from typing import List
from fastapi import UploadFile
import yaml
import subprocess
import torch
from typing import List
from fastapi import UploadFile
from sqlalchemy.orm import Session
def add_project(info: ProjectInfoIn, session: Session, user_id: int):

View File

@ -513,7 +513,7 @@ def check_font(font=FONT, progress=False):
font = Path(font)
file = CONFIG_DIR / font.name
if not font.exists() and not file.exists():
url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{font.name}"
url = f"https://ultralytics.com/assets/{font.name}"
LOGGER.info(f"Downloading {url} to {file}...")
torch.hub.download_url_to_file(url, str(file), progress=progress)

View File

@ -31,11 +31,11 @@ PyYAML>=5.3.1
requests>=2.32.2
scipy==1.13.1
thop>=0.1.1 # FLOPs computation
torch==2.3.1+cu121 --extra-index-url https://mirrors.aliyun.com/pytorch-wheels/cu121/
torchvision==0.18.1+cu121 --extra-index-url https://mirrors.aliyun.com/pytorch-wheels/cu121/
# torch==2.6.0+cu124 # 本地安装
# torchvision==0.21.0+cu124 # 本地安装
tqdm>=4.66.3
ultralytics==8.3.75 # https://ultralytics.com
# Plotting --------------------------------------------------------------------
pandas==2.2.3
seaborn==0.11.0 # 对应的这个依赖也需要降
seaborn==0.11.2 # 对应的这个依赖也需要降

5
toch_test.py Normal file
View File

@ -0,0 +1,5 @@
import torch
is_gpu = torch.cuda.is_available()
print(str(is_gpu))

View File

@ -409,7 +409,7 @@ def train(hyp, opt, device, callbacks):
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
# Forward
with torch.cuda.amp.autocast(amp):
with torch.amp.autocast(device_type='cuda', enabled=amp):
pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
if RANK != -1:

View File

@ -513,7 +513,7 @@ def check_font(font=FONT, progress=False):
font = Path(font)
file = CONFIG_DIR / font.name
if not font.exists() and not file.exists():
url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{font.name}"
url = f"https://ultralytics.com/assets/{font.name}"
LOGGER.info(f"Downloading {url} to {file}...")
torch.hub.download_url_to_file(url, str(file), progress=progress)