使用cuda版torch进行训练和推理

This commit is contained in:
2025-03-17 15:45:46 +08:00
parent ff7b6fda0e
commit ebc4e9df4f
5 changed files with 37 additions and 15 deletions

View File

@ -1,6 +1,7 @@
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import asyncio import asyncio
import torch
from app.application.token_middleware import TokenMiddleware from app.application.token_middleware import TokenMiddleware
from app.application.logger_middleware import LoggerMiddleware from app.application.logger_middleware import LoggerMiddleware

View File

@ -1,5 +1,4 @@
import time import time
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List from typing import List
from fastapi import UploadFile from fastapi import UploadFile
@ -20,6 +19,7 @@ from app.util import os_utils as os
from app.util import random_utils as ru from app.util import random_utils as ru
from app.config.config_reader import yolo_url from app.config.config_reader import yolo_url
from app.websocket.web_socket_server import room_manager from app.websocket.web_socket_server import room_manager
from app.common.redis_cli import redis_conn
def add_detect(detect_in: ProjectDetectIn, session: Session): def add_detect(detect_in: ProjectDetectIn, session: Session):
@ -160,8 +160,12 @@ async def run_detect_img(weights: str, source: str, project: str, name: str, log
""" """
yolo_path = os.file_path(yolo_url, 'detect.py') yolo_path = os.file_path(yolo_url, 'detect.py')
room = 'detect_' + str(detect_id) room = 'detect_' + str(detect_id)
await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n") await room_manager.send_to_room(room, f"AiCheck: 模型检测开始,请稍等。。。\n")
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt"] commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt"]
is_gpu = redis_conn.get('is_gpu')
# 判断是否存在cuda版本
if is_gpu == 'True':
commend.extend(["--device", "0"])
# 启动子进程 # 启动子进程
with subprocess.Popen( with subprocess.Popen(
commend, commend,
@ -209,6 +213,10 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
room = 'detect_rtsp_' + str(detect_id) room = 'detect_rtsp_' + str(detect_id)
# 选择设备CPU 或 GPU # 选择设备CPU 或 GPU
device = select_device('cpu') device = select_device('cpu')
is_gpu = redis_conn.get('is_gpu')
# 判断是否存在cuda版本
if is_gpu == 'True':
device = select_device('cuda:0')
# 加载模型 # 加载模型
model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False) model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)

View File

@ -12,12 +12,14 @@ from app.util import random_utils as ru
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
from app.websocket.web_socket_server import room_manager from app.websocket.web_socket_server import room_manager
from app.util.csv_utils import read_csv from app.util.csv_utils import read_csv
from app.common.redis_cli import redis_conn
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List from typing import List
from fastapi import UploadFile from fastapi import UploadFile
import yaml import yaml
import subprocess import subprocess
import torch
def add_project(info: ProjectInfoIn, session: Session, user_id: int): def add_project(info: ProjectInfoIn, session: Session, user_id: int):
@ -180,9 +182,7 @@ def run_train_yolo(project_info: ProjectInfoOut, train_in: ProjectTrainIn, sessi
return data, project, name return data, project, name
async def run_commend(data: str, project: str, async def run_commend(data: str, project: str, name: str, epochs: int, patience: int, weights: str, project_id: int, session: Session):
name: str, epochs: int, patience: int, weights: str,
project_id: int, session: Session):
""" """
执行训练 执行训练
:param data: 训练数据集 :param data: 训练数据集
@ -197,13 +197,20 @@ async def run_commend(data: str, project: str,
""" """
yolo_path = os.file_path(yolo_url, 'train.py') yolo_path = os.file_path(yolo_url, 'train.py')
room = 'train_' + str(project_id) room = 'train_' + str(project_id)
await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n") await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name, commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name,
"--epochs=" + str(epochs), "--batch-size=4", "--exist-ok", "--patience=" + str(patience)] "--epochs=" + str(epochs), "--batch-size=8", "--exist-ok", "--patience=" + str(patience)]
# 增加权重文件,在之前训练的基础上重新训练
if weights != '' and weights is not None: if weights != '' and weights is not None:
train_info = ptc.get_train(int(weights), session) train_info = ptc.get_train(int(weights), session)
if train_info is not None: if train_info is not None:
commend.append("--weights=" + train_info.best_pt) commend.append("--weights=" + train_info.best_pt)
is_gpu = redis_conn.get('is_gpu')
# 判断是否存在cuda版本
if is_gpu == 'True':
commend.append("--device=0")
# 启动子进程 # 启动子进程
with subprocess.Popen( with subprocess.Popen(
commend, commend,
@ -212,7 +219,7 @@ async def run_commend(data: str, project: str,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # 这里可以显示yolov5训练过程中出现的进度条等信息 stderr=subprocess.STDOUT, # 这里可以显示yolov5训练过程中出现的进度条等信息
text=True, # 缓存内容为文本,避免后续编码显示问题 text=True, # 缓存内容为文本,避免后续编码显示问题
encoding='utf-8', encoding='latin1',
) as process: ) as process:
while process.poll() is None: while process.poll() is None:
line = process.stdout.readline() line = process.stdout.readline()

View File

@ -1,5 +1,12 @@
import uvicorn import uvicorn
import torch
from app.common.redis_cli import redis_conn
from app.application.app import my_app from app.application.app import my_app
if __name__ == '__main__': if __name__ == '__main__':
uvicorn.run("main:my_app", host='0.0.0.0', port=8080, reload=True) # 在主线程初始化cuda
is_gpu = torch.cuda.is_available()
redis_conn.set('is_gpu', str(is_gpu))
uvicorn.run("main:my_app", host='0.0.0.0', port=8080, reload=False)

View File

@ -24,7 +24,7 @@ python-socketio == 5.12.1
# BASE ------------------------------------------------------------------------ # BASE ------------------------------------------------------------------------
gitpython>=3.1.30 gitpython>=3.1.30
matplotlib>=3.3 matplotlib>=3.3
numpy==2.0.2 numpy==1.24.0 # cuda版本的numpy需要降一下版本
opencv-python>=4.1.1 opencv-python>=4.1.1
pillow>=10.3.0 pillow>=10.3.0
psutil # system resources psutil # system resources
@ -32,12 +32,11 @@ PyYAML>=5.3.1
requests>=2.32.2 requests>=2.32.2
scipy==1.13.1 scipy==1.13.1
thop>=0.1.1 # FLOPs computation thop>=0.1.1 # FLOPs computation
torch --extra-index-url https://download.pytorch.org/whl/cu121 torch==2.3.1+cu121 --extra-index-url https://mirrors.aliyun.com/pytorch-wheels/cu121/
torchvision --extra-index-url https://download.pytorch.org/whl/cu121 torchvision==0.18.1+cu121 --extra-index-url https://mirrors.aliyun.com/pytorch-wheels/cu121/
torchaudio --extra-index-url https://download.pytorch.org/whl/cu121
tqdm>=4.66.3 tqdm>=4.66.3
ultralytics>=8.2.34 # https://ultralytics.com ultralytics==8.3.75 # https://ultralytics.com
# Plotting -------------------------------------------------------------------- # Plotting --------------------------------------------------------------------
pandas==2.2.3 pandas==2.2.3
seaborn>=0.11.0 seaborn==0.11.0 # 对应的这个依赖也需要降