优化训练过程
This commit is contained in:
@ -161,7 +161,7 @@ async def run_detect_img(weights: str, source: str, project: str, name: str, log
|
|||||||
yolo_path = os.file_path(yolo_url, 'detect.py')
|
yolo_path = os.file_path(yolo_url, 'detect.py')
|
||||||
room = 'detect_' + str(detect_id)
|
room = 'detect_' + str(detect_id)
|
||||||
await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
|
await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
|
||||||
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt"]
|
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt", "--conf-thres", "0.4"]
|
||||||
is_gpu = redis_conn.get('is_gpu')
|
is_gpu = redis_conn.get('is_gpu')
|
||||||
# 判断是否存在cuda版本
|
# 判断是否存在cuda版本
|
||||||
if is_gpu == 'True':
|
if is_gpu == 'True':
|
||||||
@ -259,7 +259,7 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
|
|||||||
pred = model(im, augment=False, visualize=False)
|
pred = model(im, augment=False, visualize=False)
|
||||||
# NMS
|
# NMS
|
||||||
with dt[2]:
|
with dt[2]:
|
||||||
pred = non_max_suppression(pred, 0.25, 0.45, None, False, max_det=1000)
|
pred = non_max_suppression(pred, 0.45, 0.45, None, False, max_det=1000)
|
||||||
|
|
||||||
# Process predictions
|
# Process predictions
|
||||||
for i, det in enumerate(pred): # per image
|
for i, det in enumerate(pred): # per image
|
||||||
|
@ -14,12 +14,11 @@ from app.websocket.web_socket_server import room_manager
|
|||||||
from app.util.csv_utils import read_csv
|
from app.util.csv_utils import read_csv
|
||||||
from app.common.redis_cli import redis_conn
|
from app.common.redis_cli import redis_conn
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from typing import List
|
|
||||||
from fastapi import UploadFile
|
|
||||||
import yaml
|
import yaml
|
||||||
import subprocess
|
import subprocess
|
||||||
import torch
|
from typing import List
|
||||||
|
from fastapi import UploadFile
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
||||||
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
|
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
|
||||||
|
@ -513,7 +513,7 @@ def check_font(font=FONT, progress=False):
|
|||||||
font = Path(font)
|
font = Path(font)
|
||||||
file = CONFIG_DIR / font.name
|
file = CONFIG_DIR / font.name
|
||||||
if not font.exists() and not file.exists():
|
if not font.exists() and not file.exists():
|
||||||
url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{font.name}"
|
url = f"https://ultralytics.com/assets/{font.name}"
|
||||||
LOGGER.info(f"Downloading {url} to {file}...")
|
LOGGER.info(f"Downloading {url} to {file}...")
|
||||||
torch.hub.download_url_to_file(url, str(file), progress=progress)
|
torch.hub.download_url_to_file(url, str(file), progress=progress)
|
||||||
|
|
||||||
|
@ -31,11 +31,11 @@ PyYAML>=5.3.1
|
|||||||
requests>=2.32.2
|
requests>=2.32.2
|
||||||
scipy==1.13.1
|
scipy==1.13.1
|
||||||
thop>=0.1.1 # FLOPs computation
|
thop>=0.1.1 # FLOPs computation
|
||||||
torch==2.3.1+cu121 --extra-index-url https://mirrors.aliyun.com/pytorch-wheels/cu121/
|
# torch==2.6.0+cu124 # 本地安装
|
||||||
torchvision==0.18.1+cu121 --extra-index-url https://mirrors.aliyun.com/pytorch-wheels/cu121/
|
# torchvision==0.21.0+cu124 # 本地安装
|
||||||
tqdm>=4.66.3
|
tqdm>=4.66.3
|
||||||
ultralytics==8.3.75 # https://ultralytics.com
|
ultralytics==8.3.75 # https://ultralytics.com
|
||||||
|
|
||||||
# Plotting --------------------------------------------------------------------
|
# Plotting --------------------------------------------------------------------
|
||||||
pandas==2.2.3
|
pandas==2.2.3
|
||||||
seaborn==0.11.0 # 对应的这个依赖也需要降
|
seaborn==0.11.2 # 对应的这个依赖也需要降
|
5
toch_test.py
Normal file
5
toch_test.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
is_gpu = torch.cuda.is_available()
|
||||||
|
|
||||||
|
print(str(is_gpu))
|
@ -409,7 +409,7 @@ def train(hyp, opt, device, callbacks):
|
|||||||
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
|
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
|
||||||
|
|
||||||
# Forward
|
# Forward
|
||||||
with torch.cuda.amp.autocast(amp):
|
with torch.amp.autocast(device_type='cuda', enabled=amp):
|
||||||
pred = model(imgs) # forward
|
pred = model(imgs) # forward
|
||||||
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
|
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
|
||||||
if RANK != -1:
|
if RANK != -1:
|
||||||
|
@ -513,7 +513,7 @@ def check_font(font=FONT, progress=False):
|
|||||||
font = Path(font)
|
font = Path(font)
|
||||||
file = CONFIG_DIR / font.name
|
file = CONFIG_DIR / font.name
|
||||||
if not font.exists() and not file.exists():
|
if not font.exists() and not file.exists():
|
||||||
url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{font.name}"
|
url = f"https://ultralytics.com/assets/{font.name}"
|
||||||
LOGGER.info(f"Downloading {url} to {file}...")
|
LOGGER.info(f"Downloading {url} to {file}...")
|
||||||
torch.hub.download_url_to_file(url, str(file), progress=progress)
|
torch.hub.download_url_to_file(url, str(file), progress=progress)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user