完成推理模块的转移
This commit is contained in:
@ -5,15 +5,21 @@
|
||||
# @File : views.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : 路由,视图文件
|
||||
from core.dependencies import IdList
|
||||
from apps.vadmin.auth.utils.validation.auth import Auth
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from apps.vadmin.auth.utils.current import AllUserAuth
|
||||
from core.database import db_getter
|
||||
from . import schemas, crud, models, params
|
||||
from fastapi import Depends, APIRouter
|
||||
from utils.response import SuccessResponse
|
||||
|
||||
import service
|
||||
from . import schemas, crud, params
|
||||
from core.dependencies import IdList
|
||||
from core.database import redis_getter
|
||||
from utils.websocket_server import room_manager
|
||||
from apps.business.train.crud import ProjectTrainDal
|
||||
from apps.vadmin.auth.utils.current import AllUserAuth
|
||||
from apps.vadmin.auth.utils.validation.auth import Auth
|
||||
from utils.response import SuccessResponse, ErrorResponse
|
||||
|
||||
import threading
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from fastapi import Depends, APIRouter, Form, UploadFile
|
||||
|
||||
|
||||
app = APIRouter()
|
||||
@ -22,129 +28,120 @@ app = APIRouter()
|
||||
###########################################################
|
||||
# 项目推理集合信息
|
||||
###########################################################
|
||||
@app.get("/project/detect", summary="获取项目推理集合信息列表", tags=["项目推理集合信息"])
|
||||
async def get_project_detect_list(p: params.ProjectDetectParams = Depends(), auth: Auth = Depends(AllUserAuth())):
|
||||
@app.get("/list", summary="获取项目推理集合信息列表")
|
||||
async def detect_list(
|
||||
p: params.ProjectDetectParams = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
datas, count = await crud.ProjectDetectDal(auth.db).get_datas(**p.dict(), v_return_count=True)
|
||||
return SuccessResponse(datas, count=count)
|
||||
|
||||
|
||||
@app.post("/project/detect", summary="创建项目推理集合信息", tags=["项目推理集合信息"])
|
||||
async def create_project_detect(data: schemas.ProjectDetect, auth: Auth = Depends(AllUserAuth())):
|
||||
return SuccessResponse(await crud.ProjectDetectDal(auth.db).create_data(data=data))
|
||||
@app.post("/", summary="创建项目推理集合信息")
|
||||
async def add_detect(
|
||||
data: schemas.ProjectDetectIn,
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
detect_dal = crud.ProjectDetectDal(auth.db)
|
||||
if await detect_dal.check_name(data.detect_name, data.project_id):
|
||||
return ErrorResponse(msg="该项目中存在相同名称的集合")
|
||||
await detect_dal.create_data(data=data)
|
||||
return SuccessResponse(msg="保存成功")
|
||||
|
||||
|
||||
@app.delete("/project/detect", summary="删除项目推理集合信息", description="硬删除", tags=["项目推理集合信息"])
|
||||
async def delete_project_detect_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):
|
||||
@app.delete("/", summary="删除项目推理集合信息")
|
||||
async def delete_detect(
|
||||
ids: IdList = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
await crud.ProjectDetectDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
|
||||
return SuccessResponse("删除成功")
|
||||
|
||||
|
||||
@app.put("/project/detect/{data_id}", summary="更新项目推理集合信息", tags=["项目推理集合信息"])
|
||||
async def put_project_detect(data_id: int, data: schemas.ProjectDetect, auth: Auth = Depends(AllUserAuth())):
|
||||
return SuccessResponse(await crud.ProjectDetectDal(auth.db).put_data(data_id, data))
|
||||
|
||||
|
||||
@app.get("/project/detect/{data_id}", summary="获取项目推理集合信息信息", tags=["项目推理集合信息"])
|
||||
async def get_project_detect(data_id: int, db: AsyncSession = Depends(db_getter)):
|
||||
schema = schemas.ProjectDetectSimpleOut
|
||||
return SuccessResponse(await crud.ProjectDetectDal(db).get_data(data_id, v_schema=schema))
|
||||
|
||||
|
||||
|
||||
|
||||
###########################################################
|
||||
# 项目推理集合图片信息
|
||||
# 项目推理集合文件信息
|
||||
###########################################################
|
||||
@app.get("/project/detect/img", summary="获取项目推理集合图片信息列表", tags=["项目推理集合图片信息"])
|
||||
async def get_project_detect_img_list(p: params.ProjectDetectImgParams = Depends(), auth: Auth = Depends(AllUserAuth())):
|
||||
datas, count = await crud.ProjectDetectImgDal(auth.db).get_datas(**p.dict(), v_return_count=True)
|
||||
return SuccessResponse(datas, count=count)
|
||||
@app.get("/file", summary="获取项目推理集合文件信息列表")
|
||||
async def file_list(
|
||||
p: params.ProjectDetectFileParams = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
if p.limit:
|
||||
datas, count = await crud.ProjectDetectFileDal(auth.db).get_datas(**p.dict(), v_return_count=True)
|
||||
return SuccessResponse(datas, count=count)
|
||||
else:
|
||||
datas = await crud.ProjectDetectFileDal(auth.db).get_datas(**p.dict(), v_return_count=False)
|
||||
return SuccessResponse(datas)
|
||||
|
||||
|
||||
@app.post("/project/detect/img", summary="创建项目推理集合图片信息", tags=["项目推理集合图片信息"])
|
||||
async def create_project_detect_img(data: schemas.ProjectDetectImg, auth: Auth = Depends(AllUserAuth())):
|
||||
return SuccessResponse(await crud.ProjectDetectImgDal(auth.db).create_data(data=data))
|
||||
@app.post("/file", summary="上传项目推理集合文件")
|
||||
async def upload_file(
|
||||
detect_id: int = Form(...),
|
||||
files: list[UploadFile] = Form(...),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
file_dal = crud.ProjectDetectFileDal(auth.db)
|
||||
detect_out = file_dal.get_data(data_id=detect_id)
|
||||
if detect_out is None:
|
||||
return ErrorResponse("训练集合查询失败,请刷新后再试")
|
||||
await file_dal.add_file(detect_out, files)
|
||||
return SuccessResponse(msg="上传成功")
|
||||
|
||||
|
||||
@app.delete("/project/detect/img", summary="删除项目推理集合图片信息", description="硬删除", tags=["项目推理集合图片信息"])
|
||||
async def delete_project_detect_img_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):
|
||||
await crud.ProjectDetectImgDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
|
||||
@app.delete("/file", summary="删除项目推理集合文件信息")
|
||||
async def delete_file(
|
||||
ids: IdList = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
await crud.ProjectDetectFileDal(auth.db).delete_files(ids=ids.ids)
|
||||
return SuccessResponse("删除成功")
|
||||
|
||||
|
||||
@app.put("/project/detect/img/{data_id}", summary="更新项目推理集合图片信息", tags=["项目推理集合图片信息"])
|
||||
async def put_project_detect_img(data_id: int, data: schemas.ProjectDetectImg, auth: Auth = Depends(AllUserAuth())):
|
||||
return SuccessResponse(await crud.ProjectDetectImgDal(auth.db).put_data(data_id, data))
|
||||
|
||||
|
||||
@app.get("/project/detect/img/{data_id}", summary="获取项目推理集合图片信息信息", tags=["项目推理集合图片信息"])
|
||||
async def get_project_detect_img(data_id: int, db: AsyncSession = Depends(db_getter)):
|
||||
schema = schemas.ProjectDetectImgSimpleOut
|
||||
return SuccessResponse(await crud.ProjectDetectImgDal(db).get_data(data_id, v_schema=schema))
|
||||
|
||||
|
||||
@app.post("/detect", summary="开始推理")
|
||||
def run_detect_yolo(
|
||||
detect_log_in: schemas.ProjectDetectLogIn,
|
||||
auth: Auth = Depends(AllUserAuth()),
|
||||
rd: Redis = Depends(redis_getter)):
|
||||
detect_dal = crud.ProjectDetectDal(auth.db)
|
||||
train_dal = ProjectTrainDal(auth.db)
|
||||
detect = detect_dal.get_data(detect_log_in.detect_id)
|
||||
if detect is None:
|
||||
return ErrorResponse(msg="训练集合不存在")
|
||||
train = train_dal.get_data(detect_log_in.train_id)
|
||||
if train is None:
|
||||
return ErrorResponse("训练权重不存在")
|
||||
file_count = crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id)
|
||||
if file_count == 0 and detect.rtsp_url is None:
|
||||
return ErrorResponse("推理集合中没有内容,请先到推理集合中上传图片")
|
||||
if detect.file_type == 'img' or detect.file_type == 'video':
|
||||
detect_log = service.before_detect(detect_log_in, detect, train, auth.db)
|
||||
thread_train = threading.Thread(target=service.run_img_loop,
|
||||
args=(detect_log.pt_url, detect_log.folder_url,
|
||||
detect_log.detect_folder_url, detect_log.detect_version,
|
||||
detect_log.id, detect_log.detect_id, auth.db,))
|
||||
thread_train.start()
|
||||
elif detect.file_type == 'rtsp':
|
||||
room = 'detect_rtsp_' + str(detect.id)
|
||||
if not room_manager.rooms.get(room):
|
||||
if detect_log_in.pt_type == 'best':
|
||||
weights_pt = train.best_pt
|
||||
else:
|
||||
weights_pt = train.last_pt
|
||||
thread_train = threading.Thread(target=service.run_rtsp_loop,
|
||||
args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, rd,))
|
||||
thread_train.start()
|
||||
return SuccessResponse(msg="执行成功")
|
||||
|
||||
|
||||
###########################################################
|
||||
# 项目推理记录信息
|
||||
###########################################################
|
||||
@app.get("/project/detect/log", summary="获取项目推理记录信息列表", tags=["项目推理记录信息"])
|
||||
async def get_project_detect_log_list(p: params.ProjectDetectLogParams = Depends(), auth: Auth = Depends(AllUserAuth())):
|
||||
@app.get("/log", summary="获取项目推理记录列表")
|
||||
async def log_pager(
|
||||
p: params.ProjectDetectLogParams = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
datas, count = await crud.ProjectDetectLogDal(auth.db).get_datas(**p.dict(), v_return_count=True)
|
||||
return SuccessResponse(datas, count=count)
|
||||
|
||||
|
||||
@app.post("/project/detect/log", summary="创建项目推理记录信息", tags=["项目推理记录信息"])
|
||||
async def create_project_detect_log(data: schemas.ProjectDetectLog, auth: Auth = Depends(AllUserAuth())):
|
||||
return SuccessResponse(await crud.ProjectDetectLogDal(auth.db).create_data(data=data))
|
||||
|
||||
|
||||
@app.delete("/project/detect/log", summary="删除项目推理记录信息", description="硬删除", tags=["项目推理记录信息"])
|
||||
async def delete_project_detect_log_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):
|
||||
await crud.ProjectDetectLogDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
|
||||
return SuccessResponse("删除成功")
|
||||
|
||||
|
||||
@app.put("/project/detect/log/{data_id}", summary="更新项目推理记录信息", tags=["项目推理记录信息"])
|
||||
async def put_project_detect_log(data_id: int, data: schemas.ProjectDetectLog, auth: Auth = Depends(AllUserAuth())):
|
||||
return SuccessResponse(await crud.ProjectDetectLogDal(auth.db).put_data(data_id, data))
|
||||
|
||||
|
||||
@app.get("/project/detect/log/{data_id}", summary="获取项目推理记录信息信息", tags=["项目推理记录信息"])
|
||||
async def get_project_detect_log(data_id: int, db: AsyncSession = Depends(db_getter)):
|
||||
schema = schemas.ProjectDetectLogSimpleOut
|
||||
return SuccessResponse(await crud.ProjectDetectLogDal(db).get_data(data_id, v_schema=schema))
|
||||
|
||||
|
||||
|
||||
|
||||
###########################################################
|
||||
# 项目推理记录图片信息
|
||||
###########################################################
|
||||
@app.get("/project/detect/log/img", summary="获取项目推理记录图片信息列表", tags=["项目推理记录图片信息"])
|
||||
async def get_project_detect_log_img_list(p: params.ProjectDetectLogImgParams = Depends(), auth: Auth = Depends(AllUserAuth())):
|
||||
datas, count = await crud.ProjectDetectLogImgDal(auth.db).get_datas(**p.dict(), v_return_count=True)
|
||||
return SuccessResponse(datas, count=count)
|
||||
|
||||
|
||||
@app.post("/project/detect/log/img", summary="创建项目推理记录图片信息", tags=["项目推理记录图片信息"])
|
||||
async def create_project_detect_log_img(data: schemas.ProjectDetectLogImg, auth: Auth = Depends(AllUserAuth())):
|
||||
return SuccessResponse(await crud.ProjectDetectLogImgDal(auth.db).create_data(data=data))
|
||||
|
||||
|
||||
@app.delete("/project/detect/log/img", summary="删除项目推理记录图片信息", description="硬删除", tags=["项目推理记录图片信息"])
|
||||
async def delete_project_detect_log_img_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):
|
||||
await crud.ProjectDetectLogImgDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
|
||||
return SuccessResponse("删除成功")
|
||||
|
||||
|
||||
@app.put("/project/detect/log/img/{data_id}", summary="更新项目推理记录图片信息", tags=["项目推理记录图片信息"])
|
||||
async def put_project_detect_log_img(data_id: int, data: schemas.ProjectDetectLogImg, auth: Auth = Depends(AllUserAuth())):
|
||||
return SuccessResponse(await crud.ProjectDetectLogImgDal(auth.db).put_data(data_id, data))
|
||||
|
||||
|
||||
@app.get("/project/detect/log/img/{data_id}", summary="获取项目推理记录图片信息信息", tags=["项目推理记录图片信息"])
|
||||
async def get_project_detect_log_img(data_id: int, db: AsyncSession = Depends(db_getter)):
|
||||
schema = schemas.ProjectDetectLogImgSimpleOut
|
||||
return SuccessResponse(await crud.ProjectDetectLogImgDal(db).get_data(data_id, v_schema=schema))
|
||||
@app.get("/log_files", summary="获取项目推理记录文件列表")
|
||||
async def log_files(
|
||||
p: params.ProjectDetectLogFileParams = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
datas = await crud.ProjectDetectLogFileDal(auth.db).get_datas(**p.dict(), v_return_count=False)
|
||||
return SuccessResponse(datas)
|
||||
|
||||
|
Reference in New Issue
Block a user