新增部分接口
This commit is contained in:
@ -1,14 +1,15 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
from fastapi import APIRouter, Depends, UploadFile, File, Form
|
from fastapi import APIRouter, Depends, UploadFile, File, Form
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.common import reponse_code as rc
|
from app.common import reponse_code as rc
|
||||||
from app.model.crud import project_detect_crud as pdc
|
from app.model.crud import project_detect_crud as pdc
|
||||||
from app.service import project_detect_service as pds
|
from app.service import project_detect_service as pds
|
||||||
from app.model.crud.project_train_crud import get_train
|
from app.model.crud.project_train_crud import get_train
|
||||||
from app.model.schemas.project_detect_schemas import ProjectDetectPager, ProjectDetectIn, ProjectDetectImgPager, ProjectDetectLogIn
|
from app.model.schemas.project_detect_schemas import ProjectDetectPager, ProjectDetectIn,\
|
||||||
|
ProjectDetectImgPager, ProjectDetectLogIn
|
||||||
from app.db.db_session import get_db
|
from app.db.db_session import get_db
|
||||||
|
|
||||||
detect = APIRouter()
|
detect = APIRouter()
|
||||||
@ -26,6 +27,17 @@ def get_pager(detect_pager: ProjectDetectPager, session: Session = Depends(get_d
|
|||||||
return rc.response_success_pager(pager)
|
return rc.response_success_pager(pager)
|
||||||
|
|
||||||
|
|
||||||
|
@detect.get("/detect_list/{project_id}")
|
||||||
|
def get_detect_list(project_id: int, session: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
根据项目id获取全部推理集合
|
||||||
|
:param project_id:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return rc.response_success(data=pdc.get_detect_list(project_id, session))
|
||||||
|
|
||||||
|
|
||||||
@detect.post("/add_detect")
|
@detect.post("/add_detect")
|
||||||
def add_detect(detect_in: ProjectDetectIn, session: Session = Depends(get_db)):
|
def add_detect(detect_in: ProjectDetectIn, session: Session = Depends(get_db)):
|
||||||
"""
|
"""
|
||||||
|
@ -230,7 +230,7 @@ async def run_train(project_id: int, session: Session = Depends(get_db)):
|
|||||||
return rc.response_error("项目当前存在训练进程,请稍后再试")
|
return rc.response_error("项目当前存在训练进程,请稍后再试")
|
||||||
data, project_name, name = ps.run_train_yolo(project_info, session)
|
data, project_name, name = ps.run_train_yolo(project_info, session)
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
ps.run_commend(data, project_name, name, 100, project_id, session),
|
ps.run_commend(data, project_name, name, 50, project_id, session),
|
||||||
media_type="text/plain")
|
media_type="text/plain")
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,6 +51,18 @@ def get_detect_pager(detect_pager: ProjectDetectPager, session: Session):
|
|||||||
return pager
|
return pager
|
||||||
|
|
||||||
|
|
||||||
|
def get_detect_list(project_id: int, session: Session):
|
||||||
|
"""
|
||||||
|
根据项目id查询所有集合的列表
|
||||||
|
:param project_id:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
query = session.query(ProjectDetect).filter_by(project_id=project_id).order_by(asc(ProjectDetect.id))
|
||||||
|
result = jsonable_encoder(query.all())
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def update_detect_status(detect_id: int, detect_status: int, session: Session):
|
def update_detect_status(detect_id: int, detect_status: int, session: Session):
|
||||||
"""
|
"""
|
||||||
更新项目训练状态,如果是已完成的话,train_version自动+1
|
更新项目训练状态,如果是已完成的话,train_version自动+1
|
||||||
|
@ -409,7 +409,7 @@ def train(hyp, opt, device, callbacks):
|
|||||||
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
|
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
|
||||||
|
|
||||||
# Forward
|
# Forward
|
||||||
with torch.cuda.amp.autocast(amp):
|
with torch.amp.autocast(device_type='cuda', enabled=amp):
|
||||||
pred = model(imgs) # forward
|
pred = model(imgs) # forward
|
||||||
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
|
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
|
||||||
if RANK != -1:
|
if RANK != -1:
|
||||||
|
Reference in New Issue
Block a user