新增部分接口

This commit is contained in:
2025-03-06 16:35:01 +08:00
parent 05e2e8a59a
commit 3d39e89e26
4 changed files with 28 additions and 4 deletions

View File

@ -1,14 +1,15 @@
from typing import List from typing import List
from fastapi import APIRouter, Depends, UploadFile, File, Form from fastapi import APIRouter, Depends, UploadFile, File, Form
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.common import reponse_code as rc from app.common import reponse_code as rc
from app.model.crud import project_detect_crud as pdc from app.model.crud import project_detect_crud as pdc
from app.service import project_detect_service as pds from app.service import project_detect_service as pds
from app.model.crud.project_train_crud import get_train from app.model.crud.project_train_crud import get_train
from app.model.schemas.project_detect_schemas import ProjectDetectPager, ProjectDetectIn, ProjectDetectImgPager, ProjectDetectLogIn from app.model.schemas.project_detect_schemas import ProjectDetectPager, ProjectDetectIn,\
ProjectDetectImgPager, ProjectDetectLogIn
from app.db.db_session import get_db from app.db.db_session import get_db
detect = APIRouter() detect = APIRouter()
@ -26,6 +27,17 @@ def get_pager(detect_pager: ProjectDetectPager, session: Session = Depends(get_d
return rc.response_success_pager(pager) return rc.response_success_pager(pager)
@detect.get("/detect_list/{project_id}")
def get_detect_list(project_id: int, session: Session = Depends(get_db)):
"""
根据项目id获取全部推理集合
:param project_id:
:param session:
:return:
"""
return rc.response_success(data=pdc.get_detect_list(project_id, session))
@detect.post("/add_detect") @detect.post("/add_detect")
def add_detect(detect_in: ProjectDetectIn, session: Session = Depends(get_db)): def add_detect(detect_in: ProjectDetectIn, session: Session = Depends(get_db)):
""" """

View File

@ -230,7 +230,7 @@ async def run_train(project_id: int, session: Session = Depends(get_db)):
return rc.response_error("项目当前存在训练进程,请稍后再试") return rc.response_error("项目当前存在训练进程,请稍后再试")
data, project_name, name = ps.run_train_yolo(project_info, session) data, project_name, name = ps.run_train_yolo(project_info, session)
return StreamingResponse( return StreamingResponse(
ps.run_commend(data, project_name, name, 100, project_id, session), ps.run_commend(data, project_name, name, 50, project_id, session),
media_type="text/plain") media_type="text/plain")

View File

@ -51,6 +51,18 @@ def get_detect_pager(detect_pager: ProjectDetectPager, session: Session):
return pager return pager
def get_detect_list(project_id: int, session: Session):
"""
根据项目id查询所有集合的列表
:param project_id:
:param session:
:return:
"""
query = session.query(ProjectDetect).filter_by(project_id=project_id).order_by(asc(ProjectDetect.id))
result = jsonable_encoder(query.all())
return result
def update_detect_status(detect_id: int, detect_status: int, session: Session): def update_detect_status(detect_id: int, detect_status: int, session: Session):
""" """
更新项目训练状态,如果是已完成的话train_version自动+1 更新项目训练状态,如果是已完成的话train_version自动+1

View File

@ -409,7 +409,7 @@ def train(hyp, opt, device, callbacks):
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False) imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
# Forward # Forward
with torch.cuda.amp.autocast(amp): with torch.amp.autocast(device_type='cuda', enabled=amp):
pred = model(imgs) # forward pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
if RANK != -1: if RANK != -1: