使用cuda版torch进行训练和推理
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import asyncio
|
||||
import torch
|
||||
|
||||
from app.application.token_middleware import TokenMiddleware
|
||||
from app.application.logger_middleware import LoggerMiddleware
|
||||
|
@ -1,5 +1,4 @@
|
||||
import time
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from fastapi import UploadFile
|
||||
@ -20,6 +19,7 @@ from app.util import os_utils as os
|
||||
from app.util import random_utils as ru
|
||||
from app.config.config_reader import yolo_url
|
||||
from app.websocket.web_socket_server import room_manager
|
||||
from app.common.redis_cli import redis_conn
|
||||
|
||||
|
||||
def add_detect(detect_in: ProjectDetectIn, session: Session):
|
||||
@ -160,8 +160,12 @@ async def run_detect_img(weights: str, source: str, project: str, name: str, log
|
||||
"""
|
||||
yolo_path = os.file_path(yolo_url, 'detect.py')
|
||||
room = 'detect_' + str(detect_id)
|
||||
await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n")
|
||||
await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
|
||||
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt"]
|
||||
is_gpu = redis_conn.get('is_gpu')
|
||||
# 判断是否存在cuda版本
|
||||
if is_gpu == 'True':
|
||||
commend.append("--device", "0")
|
||||
# 启动子进程
|
||||
with subprocess.Popen(
|
||||
commend,
|
||||
@ -209,6 +213,10 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
|
||||
room = 'detect_rtsp_' + str(detect_id)
|
||||
# 选择设备(CPU 或 GPU)
|
||||
device = select_device('cpu')
|
||||
is_gpu = redis_conn.get('is_gpu')
|
||||
# 判断是否存在cuda版本
|
||||
if is_gpu == 'True':
|
||||
device = select_device('cuda:0')
|
||||
|
||||
# 加载模型
|
||||
model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
|
||||
|
@ -12,12 +12,14 @@ from app.util import random_utils as ru
|
||||
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
|
||||
from app.websocket.web_socket_server import room_manager
|
||||
from app.util.csv_utils import read_csv
|
||||
from app.common.redis_cli import redis_conn
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from fastapi import UploadFile
|
||||
import yaml
|
||||
import subprocess
|
||||
import torch
|
||||
|
||||
|
||||
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
|
||||
@ -180,9 +182,7 @@ def run_train_yolo(project_info: ProjectInfoOut, train_in: ProjectTrainIn, sessi
|
||||
return data, project, name
|
||||
|
||||
|
||||
async def run_commend(data: str, project: str,
|
||||
name: str, epochs: int, patience: int, weights: str,
|
||||
project_id: int, session: Session):
|
||||
async def run_commend(data: str, project: str, name: str, epochs: int, patience: int, weights: str, project_id: int, session: Session):
|
||||
"""
|
||||
执行训练
|
||||
:param data: 训练数据集
|
||||
@ -197,13 +197,20 @@ async def run_commend(data: str, project: str,
|
||||
"""
|
||||
yolo_path = os.file_path(yolo_url, 'train.py')
|
||||
room = 'train_' + str(project_id)
|
||||
await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n")
|
||||
await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
|
||||
commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name,
|
||||
"--epochs=" + str(epochs), "--batch-size=4", "--exist-ok", "--patience=" + str(patience)]
|
||||
"--epochs=" + str(epochs), "--batch-size=8", "--exist-ok", "--patience=" + str(patience)]
|
||||
|
||||
# 增加权重文件,在之前训练的基础上重新训练
|
||||
if weights != '' and weights is not None:
|
||||
train_info = ptc.get_train(int(weights), session)
|
||||
if train_info is not None:
|
||||
commend.append("--weights=" + train_info.best_pt)
|
||||
|
||||
is_gpu = redis_conn.get('is_gpu')
|
||||
# 判断是否存在cuda版本
|
||||
if is_gpu == 'True':
|
||||
commend.append("--device=0")
|
||||
# 启动子进程
|
||||
with subprocess.Popen(
|
||||
commend,
|
||||
@ -212,7 +219,7 @@ async def run_commend(data: str, project: str,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT, # 这里可以显示yolov5训练过程中出现的进度条等信息
|
||||
text=True, # 缓存内容为文本,避免后续编码显示问题
|
||||
encoding='utf-8',
|
||||
encoding='latin1',
|
||||
) as process:
|
||||
while process.poll() is None:
|
||||
line = process.stdout.readline()
|
||||
|
9
main.py
9
main.py
@ -1,5 +1,12 @@
|
||||
import uvicorn
|
||||
import torch
|
||||
|
||||
from app.common.redis_cli import redis_conn
|
||||
from app.application.app import my_app
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
uvicorn.run("main:my_app", host='0.0.0.0', port=8080, reload=True)
|
||||
# 在主线程初始化cuda
|
||||
is_gpu = torch.cuda.is_available()
|
||||
redis_conn.set('is_gpu', str(is_gpu))
|
||||
uvicorn.run("main:my_app", host='0.0.0.0', port=8080, reload=False)
|
||||
|
@ -24,7 +24,7 @@ python-socketio == 5.12.1
|
||||
# BASE ------------------------------------------------------------------------
|
||||
gitpython>=3.1.30
|
||||
matplotlib>=3.3
|
||||
numpy==2.0.2
|
||||
numpy==1.24.0 # cuda版本的numpy需要降一下版本
|
||||
opencv-python>=4.1.1
|
||||
pillow>=10.3.0
|
||||
psutil # system resources
|
||||
@ -32,12 +32,11 @@ PyYAML>=5.3.1
|
||||
requests>=2.32.2
|
||||
scipy==1.13.1
|
||||
thop>=0.1.1 # FLOPs computation
|
||||
torch --extra-index-url https://download.pytorch.org/whl/cu121
|
||||
torchvision --extra-index-url https://download.pytorch.org/whl/cu121
|
||||
torchaudio --extra-index-url https://download.pytorch.org/whl/cu121
|
||||
torch==2.3.1+cu121 --extra-index-url https://mirrors.aliyun.com/pytorch-wheels/cu121/
|
||||
torchvision==0.18.1+cu121 --extra-index-url https://mirrors.aliyun.com/pytorch-wheels/cu121/
|
||||
tqdm>=4.66.3
|
||||
ultralytics>=8.2.34 # https://ultralytics.com
|
||||
ultralytics==8.3.75 # https://ultralytics.com
|
||||
|
||||
# Plotting --------------------------------------------------------------------
|
||||
pandas==2.2.3
|
||||
seaborn>=0.11.0
|
||||
seaborn==0.11.0 # 对应的这个依赖也需要降
|
Reference in New Issue
Block a user