完成推理模块的主体功能
This commit is contained in:
@ -1,11 +1,14 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
from fastapi import APIRouter, Depends, Request, UploadFile, File, Form
|
from fastapi import APIRouter, Depends, UploadFile, File, Form
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.common import reponse_code as rc
|
from app.common import reponse_code as rc
|
||||||
from app.model.crud import project_detect_crud as pdc
|
from app.model.crud import project_detect_crud as pdc
|
||||||
from app.model.schemas.project_detect_schemas import ProjectDetectPager
|
from app.service import project_detect_service as pds
|
||||||
|
from app.model.crud.project_train_crud import get_train
|
||||||
|
from app.model.schemas.project_detect_schemas import ProjectDetectPager, ProjectDetectIn, ProjectDetectImgPager, ProjectDetectLogIn
|
||||||
from app.db.db_session import get_db
|
from app.db.db_session import get_db
|
||||||
|
|
||||||
detect = APIRouter()
|
detect = APIRouter()
|
||||||
@ -14,10 +17,121 @@ detect = APIRouter()
|
|||||||
@detect.post("/detect_pager")
|
@detect.post("/detect_pager")
|
||||||
def get_pager(detect_pager: ProjectDetectPager, session: Session = Depends(get_db)):
|
def get_pager(detect_pager: ProjectDetectPager, session: Session = Depends(get_db)):
|
||||||
"""
|
"""
|
||||||
获取训练集的照片
|
获取训练集合
|
||||||
:param detect_pager:
|
:param detect_pager:
|
||||||
:param session:
|
:param session:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
pager = pdc.get_detect_pager(detect_pager, session)
|
pager = pdc.get_detect_pager(detect_pager, session)
|
||||||
return rc.response_success_pager(pager)
|
return rc.response_success_pager(pager)
|
||||||
|
|
||||||
|
|
||||||
|
@detect.post("/add_detect")
|
||||||
|
def add_detect(detect_in: ProjectDetectIn, session: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
新增训练集合
|
||||||
|
:param detect_in:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
pds.add_detect(detect_in, session)
|
||||||
|
return rc.response_success("新增成功")
|
||||||
|
|
||||||
|
|
||||||
|
@detect.post("/get_img_list")
|
||||||
|
def get_img_list(detect_img_pager: ProjectDetectImgPager, session: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
查询训练集合中的图片列表
|
||||||
|
:param detect_img_pager:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if detect_img_pager.pagerNum is None and detect_img_pager.pagerSize is None:
|
||||||
|
img_list = pdc.get_img_list(detect_img_pager.detect_id, session)
|
||||||
|
img_list = jsonable_encoder(img_list)
|
||||||
|
return rc.response_success(data=img_list)
|
||||||
|
else:
|
||||||
|
pager = pdc.get_img_pager(detect_img_pager, session)
|
||||||
|
return rc.response_success_pager(pager)
|
||||||
|
|
||||||
|
|
||||||
|
@detect.post("/upload_detect_img")
|
||||||
|
def upload_detect_img(detect_id: int = Form(...), files: List[UploadFile] = File(...), session: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
上传训练集合中的照片
|
||||||
|
:param detect_id:
|
||||||
|
:param files:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
detect_out = pdc.get_detect_by_id(detect_id, session)
|
||||||
|
if detect_out is None:
|
||||||
|
return rc.response_error("训练集合查询失败,请刷新后再试")
|
||||||
|
is_check, file_name = pds.check_image_name(detect_id, files, session)
|
||||||
|
if not is_check:
|
||||||
|
return rc.response_error(msg="存在重名的图片文件:" + file_name)
|
||||||
|
pds.upload_detect_imgs(detect_out, files, session)
|
||||||
|
return rc.response_success("上传成功")
|
||||||
|
|
||||||
|
|
||||||
|
@detect.get("/del_detect_img/{detect_img_id}")
|
||||||
|
def del_detect_img(detect_img_id: int, session: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
删除训练集合照片
|
||||||
|
:param detect_img_id:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
result = pds.del_detect_img(detect_img_id, session)
|
||||||
|
if result > 0:
|
||||||
|
return rc.response_success(msg="删除成功")
|
||||||
|
else:
|
||||||
|
return rc.response_error(msg="删除失败")
|
||||||
|
|
||||||
|
|
||||||
|
@detect.get("/run_detect_yolo")
|
||||||
|
def run_detect_yolo(detect_log_in: ProjectDetectLogIn, session: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
开始执行训练
|
||||||
|
:param detect_log_in:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
detect = pdc.get_detect_by_id(detect_log_in.detect_id, session)
|
||||||
|
if detect is None:
|
||||||
|
return rc.response_error("训练集合不存在")
|
||||||
|
train = get_train(detect_log_in.train_id, session)
|
||||||
|
if train is None:
|
||||||
|
return rc.response_error("训练权重不存在")
|
||||||
|
detect_log = pds.run_detect_yolo(detect_log_in, detect, train, session)
|
||||||
|
return StreamingResponse(pds.run_commend(detect_log.pt_url,
|
||||||
|
detect_log.folder_url,
|
||||||
|
detect_log.detect_folder_url,
|
||||||
|
detect_log.detect_version,
|
||||||
|
detect_log.id, detect_log.detect_id, session), media_type="text/plain")
|
||||||
|
|
||||||
|
|
||||||
|
@detect.get("/get_log_list/{detect_id}")
|
||||||
|
def get_log_list(detect_id: int, session: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
根据推理集合id获取推理记录
|
||||||
|
:param detect_id:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
result = pdc.get_log_list(detect_id, session)
|
||||||
|
result = jsonable_encoder(result)
|
||||||
|
return rc.response_success(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
@detect.get("/get_log_imgs/{log_id}")
|
||||||
|
def get_log_imgs(log_id: int, session: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
根据推理集合中的结果图片
|
||||||
|
:param log_id:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
result = pdc.get_log_imgs(log_id, session)
|
||||||
|
result = jsonable_encoder(result)
|
||||||
|
return rc.response_success(data=result)
|
||||||
|
@ -32,7 +32,7 @@ def get_type_list(session: Session = Depends(get_db)):
|
|||||||
@project.post("/list")
|
@project.post("/list")
|
||||||
def project_pager(info: ProjectInfoPager, session: Session = Depends(get_db)):
|
def project_pager(info: ProjectInfoPager, session: Session = Depends(get_db)):
|
||||||
"""
|
"""
|
||||||
|
项目列表
|
||||||
:param info:
|
:param info:
|
||||||
:param session:
|
:param session:
|
||||||
:return:
|
:return:
|
||||||
|
@ -4,6 +4,7 @@ from starlette.responses import FileResponse
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.model.crud.project_image_crud import get_img_url
|
from app.model.crud.project_image_crud import get_img_url
|
||||||
|
from app.model.crud.project_detect_crud import get_detect_img_url
|
||||||
from app.config.config_reader import images_url
|
from app.config.config_reader import images_url
|
||||||
from app.db.db_session import get_db
|
from app.db.db_session import get_db
|
||||||
|
|
||||||
@ -40,3 +41,38 @@ def view_thumb(image_id: int, session: Session = Depends(get_db)):
|
|||||||
if not os.path.isfile(image_path):
|
if not os.path.isfile(image_path):
|
||||||
raise HTTPException(status_code=404, detail="Image not found")
|
raise HTTPException(status_code=404, detail="Image not found")
|
||||||
return FileResponse(image_path, media_type='image/jpeg')
|
return FileResponse(image_path, media_type='image/jpeg')
|
||||||
|
|
||||||
|
|
||||||
|
@view.get("/view_detect_img/{image_id}")
|
||||||
|
def view_detect_img(image_id: int, session: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
查看图片
|
||||||
|
:param session:
|
||||||
|
:param image_id: 图片id
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
sour_url, thumb_url = get_detect_img_url(image_id, session)
|
||||||
|
image_path = os.path.join(images_url, sour_url)
|
||||||
|
# 检查文件是否存在以及是否是文件
|
||||||
|
if not os.path.isfile(image_path):
|
||||||
|
raise HTTPException(status_code=404, detail="Image not found")
|
||||||
|
return FileResponse(image_path, media_type='image/jpeg')
|
||||||
|
|
||||||
|
|
||||||
|
@view.get("/view_detect_thumb/{image_id}")
|
||||||
|
def view_detect_thumb(image_id: int, session: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
查看图片
|
||||||
|
:param session:
|
||||||
|
:param image_id: 图片id
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
sour_url, thumb_url = get_detect_img_url(image_id, session)
|
||||||
|
image_path = os.path.join(images_url, thumb_url)
|
||||||
|
# 检查文件是否存在以及是否是文件
|
||||||
|
if not os.path.isfile(image_path):
|
||||||
|
raise HTTPException(status_code=404, detail="Image not found")
|
||||||
|
return FileResponse(image_path, media_type='image/jpeg')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ from app.api.sys.login_api import login
|
|||||||
from app.api.sys.sys_user_api import user
|
from app.api.sys.sys_user_api import user
|
||||||
from app.api.business.project_train_api import project
|
from app.api.business.project_train_api import project
|
||||||
from app.api.common.view_img import view
|
from app.api.common.view_img import view
|
||||||
|
from app.api.business.project_detect_api import detect
|
||||||
|
|
||||||
my_app = FastAPI()
|
my_app = FastAPI()
|
||||||
|
|
||||||
@ -33,5 +34,6 @@ my_app.include_router(login, prefix="/login", tags=["用户登录接口"])
|
|||||||
my_app.include_router(upload_files, prefix="/upload", tags=["文件上传API"])
|
my_app.include_router(upload_files, prefix="/upload", tags=["文件上传API"])
|
||||||
my_app.include_router(view, tags=["查看图片"])
|
my_app.include_router(view, tags=["查看图片"])
|
||||||
my_app.include_router(user, prefix="/user", tags=["用户管理API"])
|
my_app.include_router(user, prefix="/user", tags=["用户管理API"])
|
||||||
my_app.include_router(project, prefix="/proj", tags=["项目管理API"])
|
my_app.include_router(project, prefix="/proj", tags=["项目训练API"])
|
||||||
|
my_app.include_router(detect, prefix="/detect", tags=["项目推理API"])
|
||||||
|
|
||||||
|
@ -116,9 +116,11 @@ class ProjectDetectLog(DbCommon):
|
|||||||
__tablename__ = "project_detect_log"
|
__tablename__ = "project_detect_log"
|
||||||
detect_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
detect_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
detect_version: Mapped[str] = mapped_column(String(10))
|
detect_version: Mapped[str] = mapped_column(String(10))
|
||||||
|
detect_name: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
train_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
train_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
train_version: Mapped[str] = mapped_column(String(10))
|
train_version: Mapped[str] = mapped_column(String(10))
|
||||||
pt_type: Mapped[str] = mapped_column(String(32))
|
pt_type: Mapped[str] = mapped_column(String(10))
|
||||||
|
pt_url: Mapped[str] = mapped_column(String(255))
|
||||||
folder_url: Mapped[str] = mapped_column(String(255))
|
folder_url: Mapped[str] = mapped_column(String(255))
|
||||||
detect_folder_url: Mapped[str] = mapped_column(String(255))
|
detect_folder_url: Mapped[str] = mapped_column(String(255))
|
||||||
|
|
||||||
|
@ -2,10 +2,10 @@ from typing import List
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy import asc, and_
|
from sqlalchemy import asc, and_
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi import UploadFile
|
|
||||||
|
|
||||||
from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectDetectLog, ProjectDetectLogImg
|
from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectDetectLog, ProjectDetectLogImg
|
||||||
from app.model.schemas.project_detect_schemas import ProjectDetectOut, ProjectDetectPager
|
from app.model.schemas.project_detect_schemas import ProjectDetectOut, ProjectDetectPager, \
|
||||||
|
ProjectDetectImageOut, ProjectDetectImgPager, ProjectDetectLogOut, ProjectDetectLogImgOut
|
||||||
from app.db.page_util import get_pager
|
from app.db.page_util import get_pager
|
||||||
|
|
||||||
|
|
||||||
@ -21,6 +21,18 @@ def add_detect(detect: ProjectDetect, session: Session):
|
|||||||
return detect
|
return detect
|
||||||
|
|
||||||
|
|
||||||
|
def get_detect_by_id(detect_id: int, session: Session):
|
||||||
|
"""
|
||||||
|
根据id查询训练集合
|
||||||
|
:param detect_id:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
detect = session.query(ProjectDetect).filter_by(id=detect_id).first()
|
||||||
|
detect_out = ProjectDetectOut.from_orm(detect)
|
||||||
|
return detect_out
|
||||||
|
|
||||||
|
|
||||||
def get_detect_pager(detect_pager: ProjectDetectPager, session: Session):
|
def get_detect_pager(detect_pager: ProjectDetectPager, session: Session):
|
||||||
"""
|
"""
|
||||||
查询推理集合分页数据
|
查询推理集合分页数据
|
||||||
@ -39,6 +51,26 @@ def get_detect_pager(detect_pager: ProjectDetectPager, session: Session):
|
|||||||
return pager
|
return pager
|
||||||
|
|
||||||
|
|
||||||
|
def update_detect_status(detect_id: int, detect_status: int, session: Session):
|
||||||
|
"""
|
||||||
|
更新项目训练状态,如果是已完成的话,train_version自动+1
|
||||||
|
:param detect_id:
|
||||||
|
:param detect_status: 0-未运行,1-运行中,2-已完成,-1-执行失败
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if detect_status == 2:
|
||||||
|
session.query(ProjectDetect).filter_by(id=detect_id).update({
|
||||||
|
'detect_status': detect_status,
|
||||||
|
'detect_version': ProjectDetect.detect_version + 1
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
session.query(ProjectDetect).filter_by(id=detect_id).update({
|
||||||
|
'detect_status': detect_status
|
||||||
|
})
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
def check_img_name(detect_id: int, file_name: str, session: Session):
|
def check_img_name(detect_id: int, file_name: str, session: Session):
|
||||||
"""
|
"""
|
||||||
校验上传的图片名称是否重名
|
校验上传的图片名称是否重名
|
||||||
@ -65,6 +97,40 @@ def add_detect_imgs(detect_imgs: List[ProjectDetectImg], session: Session):
|
|||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def get_img_list(detect_id: int, session: Session):
|
||||||
|
"""
|
||||||
|
获取训练集合中的图片列表
|
||||||
|
:param detect_id:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
query = session.query(ProjectDetectImg).filter_by(detect_id=detect_id).order_by(asc(ProjectDetectImg.id))
|
||||||
|
image_list = [ProjectDetectImageOut.from_orm(image) for image in query.all()]
|
||||||
|
return image_list
|
||||||
|
|
||||||
|
|
||||||
|
def get_img_pager(detect_img_pager: ProjectDetectImgPager, session: Session):
|
||||||
|
"""
|
||||||
|
获取训练集合中的图片列表,返回pager对象
|
||||||
|
:param detect_img_pager:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
query = session.query(ProjectDetectImg).filter_by(detect_id=detect_img_pager.detect_id)\
|
||||||
|
.order_by(asc(ProjectDetectImg.id))
|
||||||
|
pager = get_pager(query, detect_img_pager.pagerNum, detect_img_pager.pagerSize)
|
||||||
|
pager.data = [ProjectDetectImageOut.from_orm(image) for image in pager.data]
|
||||||
|
pager.data = jsonable_encoder(pager.data)
|
||||||
|
return pager
|
||||||
|
|
||||||
|
|
||||||
|
def get_detect_img_url(image_id: int, session: Session):
|
||||||
|
result = session.query(ProjectDetectImg).filter_by(id=image_id).first()
|
||||||
|
sour_url = result.image_url
|
||||||
|
thumb_url = result.thumb_image_url
|
||||||
|
return sour_url, thumb_url
|
||||||
|
|
||||||
|
|
||||||
def add_detect_log(detect_log: ProjectDetectLog, session: Session):
|
def add_detect_log(detect_log: ProjectDetectLog, session: Session):
|
||||||
"""
|
"""
|
||||||
新增推理记录
|
新增推理记录
|
||||||
@ -86,3 +152,22 @@ def add_detect_log_imgs(detect_log_imgs: List[ProjectDetectLogImg], session: Ses
|
|||||||
"""
|
"""
|
||||||
session.add_all(detect_log_imgs)
|
session.add_all(detect_log_imgs)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def get_log_list(detect_id: int, session: Session):
|
||||||
|
"""
|
||||||
|
获取推理记录
|
||||||
|
:param detect_id:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
query = session.query(ProjectDetectLog).filter_by(detect_id=detect_id).order_by(asc(ProjectDetectLog.id))
|
||||||
|
result = [ProjectDetectLogOut.from_orm(log) for log in query.all()]
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_log_imgs(log_id: int, session: Session):
|
||||||
|
query = session.query(ProjectDetectLogImg).filter_by(log_id=log_id).order_by(asc(ProjectDetectLogImg.id))
|
||||||
|
result = [ProjectDetectLogImgOut.from_orm(img) for img in query.all()]
|
||||||
|
return result
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ class ProjectDetectPager(BaseModel):
|
|||||||
class ProjectDetectOut(BaseModel):
|
class ProjectDetectOut(BaseModel):
|
||||||
id: Optional[int]
|
id: Optional[int]
|
||||||
project_id: Optional[int]
|
project_id: Optional[int]
|
||||||
detect_name: Optional[int]
|
detect_name: Optional[str]
|
||||||
detect_no: Optional[str]
|
detect_no: Optional[str]
|
||||||
detect_version: Optional[int]
|
detect_version: Optional[int]
|
||||||
file_type: Optional[str]
|
file_type: Optional[str]
|
||||||
@ -33,6 +33,25 @@ class ProjectDetectOut(BaseModel):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectDetectImgPager(BaseModel):
|
||||||
|
detect_id: Optional[int] = Field(..., description="训练集合id")
|
||||||
|
pagerNum: Optional[int] = Field(None, description="当前页码")
|
||||||
|
pagerSize: Optional[int] = Field(None, description="每页数量")
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectDetectImageOut(BaseModel):
|
||||||
|
id: Optional[int] = Field(None, description="id")
|
||||||
|
detect_id: Optional[int] = Field(..., description="训练集合id")
|
||||||
|
file_name: Optional[str] = Field(None, description="文件名称")
|
||||||
|
create_time: Optional[datetime] = Field(None, description="上传时间")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
orm_mode = True
|
||||||
|
json_encoders = {
|
||||||
|
datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ProjectDetectLogIn(BaseModel):
|
class ProjectDetectLogIn(BaseModel):
|
||||||
detect_id: Optional[int] = Field(..., description="推理集合id")
|
detect_id: Optional[int] = Field(..., description="推理集合id")
|
||||||
train_id: Optional[int] = Field(..., description="训练结果id")
|
train_id: Optional[int] = Field(..., description="训练结果id")
|
||||||
@ -43,6 +62,7 @@ class ProjectDetectLogOut(BaseModel):
|
|||||||
id: Optional[int]
|
id: Optional[int]
|
||||||
detect_id: Optional[int]
|
detect_id: Optional[int]
|
||||||
detect_version: Optional[str]
|
detect_version: Optional[str]
|
||||||
|
detect_name: Optional[str]
|
||||||
train_id: Optional[int]
|
train_id: Optional[int]
|
||||||
train_version: Optional[int]
|
train_version: Optional[int]
|
||||||
pt_type: Optional[str]
|
pt_type: Optional[str]
|
||||||
|
@ -5,7 +5,7 @@ import subprocess
|
|||||||
|
|
||||||
from app.model.crud import project_detect_crud as pdc
|
from app.model.crud import project_detect_crud as pdc
|
||||||
from app.model.schemas.project_detect_schemas import ProjectDetectIn, ProjectDetectOut, ProjectDetectLogIn
|
from app.model.schemas.project_detect_schemas import ProjectDetectIn, ProjectDetectOut, ProjectDetectLogIn
|
||||||
from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectTrain, ProjectDetectLog
|
from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectTrain, ProjectDetectLog, ProjectDetectLogImg
|
||||||
from app.util.random_utils import random_str
|
from app.util.random_utils import random_str
|
||||||
from app.config.config_reader import detect_url
|
from app.config.config_reader import detect_url
|
||||||
from app.util import os_utils as os
|
from app.util import os_utils as os
|
||||||
@ -23,6 +23,7 @@ def add_detect(detect_in: ProjectDetectIn, session: Session):
|
|||||||
detect = ProjectDetect(**detect_in.dict())
|
detect = ProjectDetect(**detect_in.dict())
|
||||||
detect.detect_no = random_str(6)
|
detect.detect_no = random_str(6)
|
||||||
detect.detect_version = 0
|
detect.detect_version = 0
|
||||||
|
detect.detect_status = '0'
|
||||||
url = os.create_folder(detect_url, detect.detect_no, 'images')
|
url = os.create_folder(detect_url, detect.detect_no, 'images')
|
||||||
detect.folder_url = url
|
detect.folder_url = url
|
||||||
detect = pdc.add_detect(detect, session)
|
detect = pdc.add_detect(detect, session)
|
||||||
@ -67,6 +68,22 @@ def upload_detect_imgs(detect: ProjectDetectOut, files: List[UploadFile], sessio
|
|||||||
pdc.add_detect_imgs(images, session)
|
pdc.add_detect_imgs(images, session)
|
||||||
|
|
||||||
|
|
||||||
|
def del_detect_img(detect_img_id: int, session: Session):
|
||||||
|
"""
|
||||||
|
删除训练集合图片
|
||||||
|
:param detect_img_id:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
detect_img = session.query(ProjectDetectImg).filter_by(id=detect_img_id).first()
|
||||||
|
if detect_img is None:
|
||||||
|
return 0
|
||||||
|
os.delete_file_if_exists(detect_img.image_url, detect_img.thumb_image_url)
|
||||||
|
session.delete(detect_img)
|
||||||
|
session.commit()
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
def run_detect_yolo(detect_in: ProjectDetectLogIn, detect: ProjectDetect, train: ProjectTrain, session: Session):
|
def run_detect_yolo(detect_in: ProjectDetectLogIn, detect: ProjectDetect, train: ProjectTrain, session: Session):
|
||||||
"""
|
"""
|
||||||
开始推理
|
开始推理
|
||||||
@ -94,25 +111,35 @@ def run_detect_yolo(detect_in: ProjectDetectLogIn, detect: ProjectDetect, train:
|
|||||||
detect_log.train_id = train.id
|
detect_log.train_id = train.id
|
||||||
detect_log.train_version = train.train_version
|
detect_log.train_version = train.train_version
|
||||||
detect_log.pt_type = detect_in.pt_type
|
detect_log.pt_type = detect_in.pt_type
|
||||||
detect_log.folder_url = detect.folder_url
|
detect_log.pt_url = pt_url
|
||||||
|
detect_log.folder_url = img_url
|
||||||
detect_log.detect_folder_url = out_url
|
detect_log.detect_folder_url = out_url
|
||||||
detect_log = pdc.add_detect_log(detect_log, session)
|
detect_log = pdc.add_detect_log(detect_log, session)
|
||||||
return detect_log
|
return detect_log
|
||||||
|
|
||||||
|
|
||||||
def run_commend(weights: str, source: str, project: str, name: str,
|
def run_commend(weights: str, source: str, project: str, name: str,
|
||||||
detect_log_id: int, session: Session):
|
log_id: int, detect_id: int, session: Session):
|
||||||
|
"""
|
||||||
|
执行yolov5的推理
|
||||||
|
:param weights: 权重文件
|
||||||
|
:param source: 图片所在文件
|
||||||
|
:param project: 推理完成的文件位置
|
||||||
|
:param name: 版本名称
|
||||||
|
:param log_id: 日志id
|
||||||
|
:param detect_id: 推理集合id
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
yolo_path = os.file_path(yolo_url, 'detect.py')
|
yolo_path = os.file_path(yolo_url, 'detect.py')
|
||||||
|
|
||||||
yield f"stdout: 模型推理开始,请稍等。。。 \n"
|
yield f"stdout: 模型推理开始,请稍等。。。 \n"
|
||||||
# 启动子进程
|
# 启动子进程
|
||||||
with subprocess.Popen(
|
with subprocess.Popen(
|
||||||
["python", '-u', yolo_path,
|
["python", '-u', yolo_path,
|
||||||
"--weights =" + weights,
|
"--weights", weights,
|
||||||
"--source =" + source,
|
"--source", source,
|
||||||
"--name=" + name,
|
"--name", name,
|
||||||
"--project=" + project,
|
"--project", project],
|
||||||
"--view-img"],
|
|
||||||
bufsize=1, # bufsize=0时,为不缓存;bufsize=1时,按行缓存;bufsize为其他正整数时,为按照近似该正整数的字节数缓存
|
bufsize=1, # bufsize=0时,为不缓存;bufsize=1时,按行缓存;bufsize为其他正整数时,为按照近似该正整数的字节数缓存
|
||||||
shell=False,
|
shell=False,
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
@ -126,6 +153,24 @@ def run_commend(weights: str, source: str, project: str, name: str,
|
|||||||
if line != '\n':
|
if line != '\n':
|
||||||
yield line
|
yield line
|
||||||
|
|
||||||
|
# 等待进程结束并获取返回码
|
||||||
|
return_code = process.wait()
|
||||||
|
if return_code != 0:
|
||||||
|
pdc.update_detect_status(detect_id, -1, session)
|
||||||
|
else:
|
||||||
|
pdc.update_detect_status(detect_id, 2, session)
|
||||||
|
detect_imgs = pdc.get_img_list(detect_id, session)
|
||||||
|
detect_log_imgs = []
|
||||||
|
for detect_img in detect_imgs:
|
||||||
|
detect_log_img = ProjectDetectLogImg()
|
||||||
|
detect_log_img.log_id = log_id
|
||||||
|
image_url = os.file_path(project, name, detect_img.file_name)
|
||||||
|
detect_log_img.image_url = image_url
|
||||||
|
detect_log_img.file_name = detect_img.file_name
|
||||||
|
detect_log_imgs.append(detect_log_img)
|
||||||
|
pdc.add_detect_imgs(detect_log_imgs, session)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -212,7 +212,7 @@ def run_commend(data: str, project: str,
|
|||||||
train = ProjectTrain()
|
train = ProjectTrain()
|
||||||
train.project_id = project_id
|
train.project_id = project_id
|
||||||
train.train_version = name
|
train.train_version = name
|
||||||
bast_pt_path = os.file_path(project, name, 'weights', 'bast.pt')
|
bast_pt_path = os.file_path(project, name, 'weights', 'best.pt')
|
||||||
last_pt_path = os.file_path(project, name, 'weights', 'last.pt')
|
last_pt_path = os.file_path(project, name, 'weights', 'last.pt')
|
||||||
train.best_pt = bast_pt_path
|
train.best_pt = bast_pt_path
|
||||||
train.last_pt = last_pt_path
|
train.last_pt = last_pt_path
|
||||||
|
Reference in New Issue
Block a user