使用cuda版torch进行训练和推理
This commit is contained in:
@ -1,6 +1,7 @@
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import torch
|
||||||
|
|
||||||
from app.application.token_middleware import TokenMiddleware
|
from app.application.token_middleware import TokenMiddleware
|
||||||
from app.application.logger_middleware import LoggerMiddleware
|
from app.application.logger_middleware import LoggerMiddleware
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing import List
|
from typing import List
|
||||||
from fastapi import UploadFile
|
from fastapi import UploadFile
|
||||||
@ -20,6 +19,7 @@ from app.util import os_utils as os
|
|||||||
from app.util import random_utils as ru
|
from app.util import random_utils as ru
|
||||||
from app.config.config_reader import yolo_url
|
from app.config.config_reader import yolo_url
|
||||||
from app.websocket.web_socket_server import room_manager
|
from app.websocket.web_socket_server import room_manager
|
||||||
|
from app.common.redis_cli import redis_conn
|
||||||
|
|
||||||
|
|
||||||
def add_detect(detect_in: ProjectDetectIn, session: Session):
|
def add_detect(detect_in: ProjectDetectIn, session: Session):
|
||||||
@ -160,8 +160,12 @@ async def run_detect_img(weights: str, source: str, project: str, name: str, log
|
|||||||
"""
|
"""
|
||||||
yolo_path = os.file_path(yolo_url, 'detect.py')
|
yolo_path = os.file_path(yolo_url, 'detect.py')
|
||||||
room = 'detect_' + str(detect_id)
|
room = 'detect_' + str(detect_id)
|
||||||
await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n")
|
await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
|
||||||
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt"]
|
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt"]
|
||||||
|
is_gpu = redis_conn.get('is_gpu')
|
||||||
|
# 判断是否存在cuda版本
|
||||||
|
if is_gpu == 'True':
|
||||||
|
commend.append("--device", "0")
|
||||||
# 启动子进程
|
# 启动子进程
|
||||||
with subprocess.Popen(
|
with subprocess.Popen(
|
||||||
commend,
|
commend,
|
||||||
@ -209,6 +213,10 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
|
|||||||
room = 'detect_rtsp_' + str(detect_id)
|
room = 'detect_rtsp_' + str(detect_id)
|
||||||
# 选择设备(CPU 或 GPU)
|
# 选择设备(CPU 或 GPU)
|
||||||
device = select_device('cpu')
|
device = select_device('cpu')
|
||||||
|
is_gpu = redis_conn.get('is_gpu')
|
||||||
|
# 判断是否存在cuda版本
|
||||||
|
if is_gpu == 'True':
|
||||||
|
device = select_device('cuda:0')
|
||||||
|
|
||||||
# 加载模型
|
# 加载模型
|
||||||
model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
|
model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
|
||||||
|
@ -12,12 +12,14 @@ from app.util import random_utils as ru
|
|||||||
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
|
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
|
||||||
from app.websocket.web_socket_server import room_manager
|
from app.websocket.web_socket_server import room_manager
|
||||||
from app.util.csv_utils import read_csv
|
from app.util.csv_utils import read_csv
|
||||||
|
from app.common.redis_cli import redis_conn
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing import List
|
from typing import List
|
||||||
from fastapi import UploadFile
|
from fastapi import UploadFile
|
||||||
import yaml
|
import yaml
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
|
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
|
||||||
@ -180,9 +182,7 @@ def run_train_yolo(project_info: ProjectInfoOut, train_in: ProjectTrainIn, sessi
|
|||||||
return data, project, name
|
return data, project, name
|
||||||
|
|
||||||
|
|
||||||
async def run_commend(data: str, project: str,
|
async def run_commend(data: str, project: str, name: str, epochs: int, patience: int, weights: str, project_id: int, session: Session):
|
||||||
name: str, epochs: int, patience: int, weights: str,
|
|
||||||
project_id: int, session: Session):
|
|
||||||
"""
|
"""
|
||||||
执行训练
|
执行训练
|
||||||
:param data: 训练数据集
|
:param data: 训练数据集
|
||||||
@ -197,13 +197,20 @@ async def run_commend(data: str, project: str,
|
|||||||
"""
|
"""
|
||||||
yolo_path = os.file_path(yolo_url, 'train.py')
|
yolo_path = os.file_path(yolo_url, 'train.py')
|
||||||
room = 'train_' + str(project_id)
|
room = 'train_' + str(project_id)
|
||||||
await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n")
|
await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
|
||||||
commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name,
|
commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name,
|
||||||
"--epochs=" + str(epochs), "--batch-size=4", "--exist-ok", "--patience=" + str(patience)]
|
"--epochs=" + str(epochs), "--batch-size=8", "--exist-ok", "--patience=" + str(patience)]
|
||||||
|
|
||||||
|
# 增加权重文件,在之前训练的基础上重新巡逻
|
||||||
if weights != '' and weights is not None:
|
if weights != '' and weights is not None:
|
||||||
train_info = ptc.get_train(int(weights), session)
|
train_info = ptc.get_train(int(weights), session)
|
||||||
if train_info is not None:
|
if train_info is not None:
|
||||||
commend.append("--weights=" + train_info.best_pt)
|
commend.append("--weights=" + train_info.best_pt)
|
||||||
|
|
||||||
|
is_gpu = redis_conn.get('is_gpu')
|
||||||
|
# 判断是否存在cuda版本
|
||||||
|
if is_gpu == 'True':
|
||||||
|
commend.append("--device=0")
|
||||||
# 启动子进程
|
# 启动子进程
|
||||||
with subprocess.Popen(
|
with subprocess.Popen(
|
||||||
commend,
|
commend,
|
||||||
@ -212,7 +219,7 @@ async def run_commend(data: str, project: str,
|
|||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.STDOUT, # 这里可以显示yolov5训练过程中出现的进度条等信息
|
stderr=subprocess.STDOUT, # 这里可以显示yolov5训练过程中出现的进度条等信息
|
||||||
text=True, # 缓存内容为文本,避免后续编码显示问题
|
text=True, # 缓存内容为文本,避免后续编码显示问题
|
||||||
encoding='utf-8',
|
encoding='latin1',
|
||||||
) as process:
|
) as process:
|
||||||
while process.poll() is None:
|
while process.poll() is None:
|
||||||
line = process.stdout.readline()
|
line = process.stdout.readline()
|
||||||
|
9
main.py
9
main.py
@ -1,5 +1,12 @@
|
|||||||
import uvicorn
|
import uvicorn
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from app.common.redis_cli import redis_conn
|
||||||
from app.application.app import my_app
|
from app.application.app import my_app
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
uvicorn.run("main:my_app", host='0.0.0.0', port=8080, reload=True)
|
# 在主线程初始化cuda
|
||||||
|
is_gpu = torch.cuda.is_available()
|
||||||
|
redis_conn.set('is_gpu', str(is_gpu))
|
||||||
|
uvicorn.run("main:my_app", host='0.0.0.0', port=8080, reload=False)
|
||||||
|
@ -24,7 +24,7 @@ python-socketio == 5.12.1
|
|||||||
# BASE ------------------------------------------------------------------------
|
# BASE ------------------------------------------------------------------------
|
||||||
gitpython>=3.1.30
|
gitpython>=3.1.30
|
||||||
matplotlib>=3.3
|
matplotlib>=3.3
|
||||||
numpy==2.0.2
|
numpy==1.24.0 # cuda版本的numpy需要降一下版本
|
||||||
opencv-python>=4.1.1
|
opencv-python>=4.1.1
|
||||||
pillow>=10.3.0
|
pillow>=10.3.0
|
||||||
psutil # system resources
|
psutil # system resources
|
||||||
@ -32,12 +32,11 @@ PyYAML>=5.3.1
|
|||||||
requests>=2.32.2
|
requests>=2.32.2
|
||||||
scipy==1.13.1
|
scipy==1.13.1
|
||||||
thop>=0.1.1 # FLOPs computation
|
thop>=0.1.1 # FLOPs computation
|
||||||
torch --extra-index-url https://download.pytorch.org/whl/cu121
|
torch==2.3.1+cu121 --extra-index-url https://mirrors.aliyun.com/pytorch-wheels/cu121/
|
||||||
torchvision --extra-index-url https://download.pytorch.org/whl/cu121
|
torchvision==0.18.1+cu121 --extra-index-url https://mirrors.aliyun.com/pytorch-wheels/cu121/
|
||||||
torchaudio --extra-index-url https://download.pytorch.org/whl/cu121
|
|
||||||
tqdm>=4.66.3
|
tqdm>=4.66.3
|
||||||
ultralytics>=8.2.34 # https://ultralytics.com
|
ultralytics==8.3.75 # https://ultralytics.com
|
||||||
|
|
||||||
# Plotting --------------------------------------------------------------------
|
# Plotting --------------------------------------------------------------------
|
||||||
pandas==2.2.3
|
pandas==2.2.3
|
||||||
seaborn>=0.11.0
|
seaborn==0.11.0 # 对应的这个依赖也需要降
|
Reference in New Issue
Block a user