使用cuda版torch进行训练和推理
This commit is contained in:
@ -1,5 +1,4 @@
|
||||
import time
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from fastapi import UploadFile
|
||||
@ -20,6 +19,7 @@ from app.util import os_utils as os
|
||||
from app.util import random_utils as ru
|
||||
from app.config.config_reader import yolo_url
|
||||
from app.websocket.web_socket_server import room_manager
|
||||
from app.common.redis_cli import redis_conn
|
||||
|
||||
|
||||
def add_detect(detect_in: ProjectDetectIn, session: Session):
|
||||
@ -160,8 +160,12 @@ async def run_detect_img(weights: str, source: str, project: str, name: str, log
|
||||
"""
|
||||
yolo_path = os.file_path(yolo_url, 'detect.py')
|
||||
room = 'detect_' + str(detect_id)
|
||||
await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n")
|
||||
await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
|
||||
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt"]
|
||||
is_gpu = redis_conn.get('is_gpu')
|
||||
# 判断是否存在cuda版本
|
||||
if is_gpu == 'True':
|
||||
commend.append("--device", "0")
|
||||
# 启动子进程
|
||||
with subprocess.Popen(
|
||||
commend,
|
||||
@ -209,6 +213,10 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
|
||||
room = 'detect_rtsp_' + str(detect_id)
|
||||
# 选择设备(CPU 或 GPU)
|
||||
device = select_device('cpu')
|
||||
is_gpu = redis_conn.get('is_gpu')
|
||||
# 判断是否存在cuda版本
|
||||
if is_gpu == 'True':
|
||||
device = select_device('cuda:0')
|
||||
|
||||
# 加载模型
|
||||
model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
|
||||
|
Reference in New Issue
Block a user