优化训练过程

This commit is contained in:
2025-03-20 16:41:51 +08:00
parent 358bb40a2a
commit bba39adcfc
7 changed files with 16 additions and 12 deletions

View File

@ -161,7 +161,7 @@ async def run_detect_img(weights: str, source: str, project: str, name: str, log
yolo_path = os.file_path(yolo_url, 'detect.py') yolo_path = os.file_path(yolo_url, 'detect.py')
room = 'detect_' + str(detect_id) room = 'detect_' + str(detect_id)
await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n") await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt"] commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt", "--conf-thres", "0.4"]
is_gpu = redis_conn.get('is_gpu') is_gpu = redis_conn.get('is_gpu')
# 判断是否存在cuda版本 # 判断是否存在cuda版本
if is_gpu == 'True': if is_gpu == 'True':
@ -259,7 +259,7 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
pred = model(im, augment=False, visualize=False) pred = model(im, augment=False, visualize=False)
# NMS # NMS
with dt[2]: with dt[2]:
pred = non_max_suppression(pred, 0.25, 0.45, None, False, max_det=1000) pred = non_max_suppression(pred, 0.45, 0.45, None, False, max_det=1000)
# Process predictions # Process predictions
for i, det in enumerate(pred): # per image for i, det in enumerate(pred): # per image

View File

@ -14,12 +14,11 @@ from app.websocket.web_socket_server import room_manager
from app.util.csv_utils import read_csv from app.util.csv_utils import read_csv
from app.common.redis_cli import redis_conn from app.common.redis_cli import redis_conn
from sqlalchemy.orm import Session
from typing import List
from fastapi import UploadFile
import yaml import yaml
import subprocess import subprocess
import torch from typing import List
from fastapi import UploadFile
from sqlalchemy.orm import Session
def add_project(info: ProjectInfoIn, session: Session, user_id: int): def add_project(info: ProjectInfoIn, session: Session, user_id: int):

View File

@ -513,7 +513,7 @@ def check_font(font=FONT, progress=False):
font = Path(font) font = Path(font)
file = CONFIG_DIR / font.name file = CONFIG_DIR / font.name
if not font.exists() and not file.exists(): if not font.exists() and not file.exists():
url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{font.name}" url = f"https://ultralytics.com/assets/{font.name}"
LOGGER.info(f"Downloading {url} to {file}...") LOGGER.info(f"Downloading {url} to {file}...")
torch.hub.download_url_to_file(url, str(file), progress=progress) torch.hub.download_url_to_file(url, str(file), progress=progress)

View File

@ -31,11 +31,11 @@ PyYAML>=5.3.1
requests>=2.32.2 requests>=2.32.2
scipy==1.13.1 scipy==1.13.1
thop>=0.1.1 # FLOPs computation thop>=0.1.1 # FLOPs computation
torch==2.3.1+cu121 --extra-index-url https://mirrors.aliyun.com/pytorch-wheels/cu121/ # torch==2.6.0+cu124 # 本地安装
torchvision==0.18.1+cu121 --extra-index-url https://mirrors.aliyun.com/pytorch-wheels/cu121/ # torchvision==0.21.0+cu124 # 本地安装
tqdm>=4.66.3 tqdm>=4.66.3
ultralytics==8.3.75 # https://ultralytics.com ultralytics==8.3.75 # https://ultralytics.com
# Plotting -------------------------------------------------------------------- # Plotting --------------------------------------------------------------------
pandas==2.2.3 pandas==2.2.3
seaborn==0.11.0 # 对应的这个依赖也需要降 seaborn==0.11.2 # 对应的这个依赖也需要降

5
toch_test.py Normal file
View File

@ -0,0 +1,5 @@
import torch
is_gpu = torch.cuda.is_available()
print(str(is_gpu))

View File

@ -409,7 +409,7 @@ def train(hyp, opt, device, callbacks):
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False) imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
# Forward # Forward
with torch.cuda.amp.autocast(amp): with torch.amp.autocast(device_type='cuda', enabled=amp):
pred = model(imgs) # forward pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
if RANK != -1: if RANK != -1:

View File

@ -513,7 +513,7 @@ def check_font(font=FONT, progress=False):
font = Path(font) font = Path(font)
file = CONFIG_DIR / font.name file = CONFIG_DIR / font.name
if not font.exists() and not file.exists(): if not font.exists() and not file.exists():
url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{font.name}" url = f"https://ultralytics.com/assets/{font.name}"
LOGGER.info(f"Downloading {url} to {file}...") LOGGER.info(f"Downloading {url} to {file}...")
torch.hub.download_url_to_file(url, str(file), progress=progress) torch.hub.download_url_to_file(url, str(file), progress=progress)