使用cuda版torch进行训练和推理

This commit is contained in:
2025-03-17 15:45:46 +08:00
parent ff7b6fda0e
commit ebc4e9df4f
5 changed files with 37 additions and 15 deletions

View File

@ -1,6 +1,7 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import asyncio
import torch
from app.application.token_middleware import TokenMiddleware
from app.application.logger_middleware import LoggerMiddleware

View File

@ -1,5 +1,4 @@
import time
from sqlalchemy.orm import Session
from typing import List
from fastapi import UploadFile
@ -20,6 +19,7 @@ from app.util import os_utils as os
from app.util import random_utils as ru
from app.config.config_reader import yolo_url
from app.websocket.web_socket_server import room_manager
from app.common.redis_cli import redis_conn
def add_detect(detect_in: ProjectDetectIn, session: Session):
@ -160,8 +160,12 @@ async def run_detect_img(weights: str, source: str, project: str, name: str, log
"""
yolo_path = os.file_path(yolo_url, 'detect.py')
room = 'detect_' + str(detect_id)
await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n")
await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt"]
is_gpu = redis_conn.get('is_gpu')
# 判断是否存在cuda版本
if is_gpu == 'True':
commend.append("--device", "0")
# 启动子进程
with subprocess.Popen(
commend,
@ -209,6 +213,10 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
room = 'detect_rtsp_' + str(detect_id)
# 选择设备CPU 或 GPU
device = select_device('cpu')
is_gpu = redis_conn.get('is_gpu')
# 判断是否存在cuda版本
if is_gpu == 'True':
device = select_device('cuda:0')
# 加载模型
model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)

View File

@ -12,12 +12,14 @@ from app.util import random_utils as ru
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
from app.websocket.web_socket_server import room_manager
from app.util.csv_utils import read_csv
from app.common.redis_cli import redis_conn
from sqlalchemy.orm import Session
from typing import List
from fastapi import UploadFile
import yaml
import subprocess
import torch
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
@ -180,9 +182,7 @@ def run_train_yolo(project_info: ProjectInfoOut, train_in: ProjectTrainIn, sessi
return data, project, name
async def run_commend(data: str, project: str,
name: str, epochs: int, patience: int, weights: str,
project_id: int, session: Session):
async def run_commend(data: str, project: str, name: str, epochs: int, patience: int, weights: str, project_id: int, session: Session):
"""
执行训练
:param data: 训练数据集
@ -197,13 +197,20 @@ async def run_commend(data: str, project: str,
"""
yolo_path = os.file_path(yolo_url, 'train.py')
room = 'train_' + str(project_id)
await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n")
await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name,
"--epochs=" + str(epochs), "--batch-size=4", "--exist-ok", "--patience=" + str(patience)]
"--epochs=" + str(epochs), "--batch-size=8", "--exist-ok", "--patience=" + str(patience)]
# 增加权重文件,在之前训练的基础上重新巡逻
if weights != '' and weights is not None:
train_info = ptc.get_train(int(weights), session)
if train_info is not None:
commend.append("--weights=" + train_info.best_pt)
is_gpu = redis_conn.get('is_gpu')
# 判断是否存在cuda版本
if is_gpu == 'True':
commend.append("--device=0")
# 启动子进程
with subprocess.Popen(
commend,
@ -212,7 +219,7 @@ async def run_commend(data: str, project: str,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # 这里可以显示yolov5训练过程中出现的进度条等信息
text=True, # 缓存内容为文本,避免后续编码显示问题
encoding='utf-8',
encoding='latin1',
) as process:
while process.poll() is None:
line = process.stdout.readline()

View File

@ -1,5 +1,12 @@
import uvicorn
import torch
from app.common.redis_cli import redis_conn
from app.application.app import my_app
if __name__ == '__main__':
uvicorn.run("main:my_app", host='0.0.0.0', port=8080, reload=True)
# 在主线程初始化cuda
is_gpu = torch.cuda.is_available()
redis_conn.set('is_gpu', str(is_gpu))
uvicorn.run("main:my_app", host='0.0.0.0', port=8080, reload=False)

View File

@ -24,7 +24,7 @@ python-socketio == 5.12.1
# BASE ------------------------------------------------------------------------
gitpython>=3.1.30
matplotlib>=3.3
numpy==2.0.2
numpy==1.24.0 # cuda版本的numpy需要降一下版本
opencv-python>=4.1.1
pillow>=10.3.0
psutil # system resources
@ -32,12 +32,11 @@ PyYAML>=5.3.1
requests>=2.32.2
scipy==1.13.1
thop>=0.1.1 # FLOPs computation
torch --extra-index-url https://download.pytorch.org/whl/cu121
torchvision --extra-index-url https://download.pytorch.org/whl/cu121
torchaudio --extra-index-url https://download.pytorch.org/whl/cu121
torch==2.3.1+cu121 --extra-index-url https://mirrors.aliyun.com/pytorch-wheels/cu121/
torchvision==0.18.1+cu121 --extra-index-url https://mirrors.aliyun.com/pytorch-wheels/cu121/
tqdm>=4.66.3
ultralytics>=8.2.34 # https://ultralytics.com
ultralytics==8.3.75 # https://ultralytics.com
# Plotting --------------------------------------------------------------------
pandas==2.2.3
seaborn>=0.11.0
seaborn==0.11.0 # 对应的这个依赖也需要降