使用cuda版torch进行训练和推理

This commit is contained in:
2025-03-17 15:45:46 +08:00
parent ff7b6fda0e
commit ebc4e9df4f
5 changed files with 37 additions and 15 deletions

View File

@ -1,5 +1,4 @@
import time
from sqlalchemy.orm import Session
from typing import List
from fastapi import UploadFile
@ -20,6 +19,7 @@ from app.util import os_utils as os
from app.util import random_utils as ru
from app.config.config_reader import yolo_url
from app.websocket.web_socket_server import room_manager
from app.common.redis_cli import redis_conn
def add_detect(detect_in: ProjectDetectIn, session: Session):
@ -160,8 +160,12 @@ async def run_detect_img(weights: str, source: str, project: str, name: str, log
"""
yolo_path = os.file_path(yolo_url, 'detect.py')
room = 'detect_' + str(detect_id)
await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n")
await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt"]
is_gpu = redis_conn.get('is_gpu')
# 判断是否存在cuda版本
if is_gpu == 'True':
commend.append("--device", "0")
# 启动子进程
with subprocess.Popen(
commend,
@ -209,6 +213,10 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
room = 'detect_rtsp_' + str(detect_id)
# 选择设备CPU 或 GPU
device = select_device('cpu')
is_gpu = redis_conn.get('is_gpu')
# 判断是否存在cuda版本
if is_gpu == 'True':
device = select_device('cuda:0')
# 加载模型
model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)