From 3d39e89e26d60e9eff4ec23cdbeb449c35b9d481 Mon Sep 17 00:00:00 2001 From: sunyugang Date: Thu, 6 Mar 2025 16:35:01 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E9=83=A8=E5=88=86=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/business/project_detect_api.py | 16 ++++++++++++++-- app/api/business/project_train_api.py | 2 +- app/model/crud/project_detect_crud.py | 12 ++++++++++++ yolov5/train.py | 2 +- 4 files changed, 28 insertions(+), 4 deletions(-) diff --git a/app/api/business/project_detect_api.py b/app/api/business/project_detect_api.py index e06fa60..fd3ba54 100644 --- a/app/api/business/project_detect_api.py +++ b/app/api/business/project_detect_api.py @@ -1,14 +1,15 @@ from typing import List from fastapi import APIRouter, Depends, UploadFile, File, Form from fastapi.responses import StreamingResponse -from fastapi.encoders import jsonable_encoder +from fastapi.encoders import jsonable_encoder from sqlalchemy.orm import Session from app.common import reponse_code as rc from app.model.crud import project_detect_crud as pdc from app.service import project_detect_service as pds from app.model.crud.project_train_crud import get_train -from app.model.schemas.project_detect_schemas import ProjectDetectPager, ProjectDetectIn, ProjectDetectImgPager, ProjectDetectLogIn +from app.model.schemas.project_detect_schemas import ProjectDetectPager, ProjectDetectIn,\ + ProjectDetectImgPager, ProjectDetectLogIn from app.db.db_session import get_db detect = APIRouter() @@ -26,6 +27,17 @@ def get_pager(detect_pager: ProjectDetectPager, session: Session = Depends(get_d return rc.response_success_pager(pager) +@detect.get("/detect_list/{project_id}") +def get_detect_list(project_id: int, session: Session = Depends(get_db)): + """ + 根据项目id获取全部推理集合 + :param project_id: + :param session: + :return: + """ + return rc.response_success(data=pdc.get_detect_list(project_id, session)) + + @detect.post("/add_detect") def add_detect(detect_in: ProjectDetectIn, session: Session = Depends(get_db)): """ diff --git a/app/api/business/project_train_api.py b/app/api/business/project_train_api.py index ecae2db..65ecc52 100644 --- a/app/api/business/project_train_api.py +++ b/app/api/business/project_train_api.py @@ -230,7 +230,7 @@ async def run_train(project_id: int, session: Session = Depends(get_db)): return rc.response_error("项目当前存在训练进程,请稍后再试") data, project_name, name = ps.run_train_yolo(project_info, session) return StreamingResponse( - ps.run_commend(data, project_name, name, 100, project_id, session), + ps.run_commend(data, project_name, name, 50, project_id, session), media_type="text/plain") diff --git a/app/model/crud/project_detect_crud.py b/app/model/crud/project_detect_crud.py index 10d444a..6e06783 100644 --- a/app/model/crud/project_detect_crud.py +++ b/app/model/crud/project_detect_crud.py @@ -51,6 +51,18 @@ def get_detect_pager(detect_pager: ProjectDetectPager, session: Session): return pager +def get_detect_list(project_id: int, session: Session): + """ + 根据项目id查询所有集合的列表 + :param project_id: + :param session: + :return: + """ + query = session.query(ProjectDetect).filter_by(project_id=project_id).order_by(asc(ProjectDetect.id)) + result = jsonable_encoder(query.all()) + return result + + def update_detect_status(detect_id: int, detect_status: int, session: Session): """ 更新项目训练状态,如果是已完成的话,train_version自动+1 diff --git a/yolov5/train.py b/yolov5/train.py index a5734f6..8d8e2de 100644 --- a/yolov5/train.py +++ b/yolov5/train.py @@ -409,7 +409,7 @@ def train(hyp, opt, device, callbacks): imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False) # Forward - with torch.cuda.amp.autocast(amp): + with torch.amp.autocast(device_type='cuda', enabled=amp): pred = model(imgs) # forward loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size if RANK != -1: