完成推理模块的转移
This commit is contained in:
@ -5,10 +5,18 @@
|
||||
# @File : crud.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : 数据访问层
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from core.crud import DalBase
|
||||
from . import schemas, models
|
||||
from utils.random_utils import random_str
|
||||
from utils import os_utils as os
|
||||
from application.settings import detect_url
|
||||
from utils.huawei_obs import ObsClient
|
||||
from utils import status
|
||||
from core.exception import CustomException
|
||||
|
||||
from fastapi import UploadFile
|
||||
|
||||
|
||||
class ProjectDetectDal(DalBase):
|
||||
@ -17,17 +25,86 @@ class ProjectDetectDal(DalBase):
|
||||
super(ProjectDetectDal, self).__init__()
|
||||
self.db = db
|
||||
self.model = models.ProjectDetect
|
||||
self.schema = schemas.ProjectDetectSimpleOut
|
||||
self.schema = schemas.ProjectDetectOut
|
||||
|
||||
async def check_name(self, name: str,project_id: int) -> bool:
|
||||
"""
|
||||
校验推理集合名称
|
||||
"""
|
||||
count = self.get_count(
|
||||
v_where=[models.ProjectDetect.project_id == project_id, models.ProjectDetect.detect_name == name])
|
||||
return count > 0
|
||||
|
||||
async def add_detect(self, data: schemas.ProjectDetectIn):
|
||||
"""
|
||||
新增集合
|
||||
"""
|
||||
detect = models.ProjectDetect(**data.model_dump())
|
||||
detect.detect_no = random_str(6)
|
||||
detect.detect_version = 0
|
||||
detect.detect_status = '0'
|
||||
url = os.create_folder(detect_url, detect.detect_no, 'images')
|
||||
detect.folder_url = url
|
||||
await self.create_data(data)
|
||||
return detect
|
||||
|
||||
async def delete_detects(self, ids: list[int]):
|
||||
"""
|
||||
删除集合数据+文件夹的文件夹+每次推理日志的文件
|
||||
"""
|
||||
for id_ in ids:
|
||||
detect_info = await self.get_data(data_id=id_)
|
||||
if detect_info.file_type != 'rtsp':
|
||||
os.delete_paths(detect_info.folder_url)
|
||||
logs = await ProjectDetectLogDal(self.db).get_datas(v_where=[models.ProjectDetectLog.detect_id == ids])
|
||||
for log in logs:
|
||||
os.delete_paths(log.folder_url)
|
||||
await self.delete_datas(ids=ids, v_soft=False)
|
||||
|
||||
|
||||
class ProjectDetectImgDal(DalBase):
|
||||
class ProjectDetectFileDal(DalBase):
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
super(ProjectDetectImgDal, self).__init__()
|
||||
super(ProjectDetectFileDal, self).__init__()
|
||||
self.db = db
|
||||
self.model = models.ProjectDetectImg
|
||||
self.schema = schemas.ProjectDetectImgSimpleOut
|
||||
self.model = models.ProjectDetectFile
|
||||
self.schema = schemas.ProjectDetectFileOut
|
||||
|
||||
async def file_count(self, detect_id: int) -> int:
|
||||
count = self.get_count(v_where=[models.ProjectDetectFile.detect_id == detect_id])
|
||||
return count
|
||||
|
||||
async def add_file(self, detect: models.ProjectDetect, files: list[UploadFile]):
|
||||
images = []
|
||||
for file in files:
|
||||
image = models.ProjectDetectFile()
|
||||
image.detect_id = detect.id
|
||||
image.file_name = file.filename
|
||||
# 保存原图
|
||||
path = os.save_images(detect.folder_url, file=file)
|
||||
image.image_url = path
|
||||
# 上传到obs
|
||||
object_key = detect.detect_no + '/' + file.filename
|
||||
success, key, url = ObsClient.put_file(object_key=object_key, file_path=path)
|
||||
if success:
|
||||
image.object_key = object_key
|
||||
image.thumb_image_url = url
|
||||
else:
|
||||
raise CustomException("obs上传失败", code=status.HTTP_ERROR)
|
||||
images.append(image)
|
||||
await self.create_datas(images)
|
||||
|
||||
async def delete_files(self, ids: list[int]):
|
||||
file_urls = []
|
||||
object_keys = []
|
||||
for id_ in ids:
|
||||
file = self.get_data(data_id=id_)
|
||||
if file:
|
||||
file_urls.append(file.file_url)
|
||||
object_keys.append(file.object_key)
|
||||
os.delete_paths(file_urls)
|
||||
ObsClient.del_objects(object_keys)
|
||||
await self.delete_datas(ids, v_soft=False)
|
||||
|
||||
class ProjectDetectLogDal(DalBase):
|
||||
|
||||
@ -38,10 +115,10 @@ class ProjectDetectLogDal(DalBase):
|
||||
self.schema = schemas.ProjectDetectLogSimpleOut
|
||||
|
||||
|
||||
class ProjectDetectLogImgDal(DalBase):
|
||||
class ProjectDetectLogFileDal(DalBase):
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
super(ProjectDetectLogImgDal, self).__init__()
|
||||
super(ProjectDetectLogFileDal, self).__init__()
|
||||
self.db = db
|
||||
self.model = models.ProjectDetectLogImg
|
||||
self.schema = schemas.ProjectDetectLogImgSimpleOut
|
||||
self.model = models.ProjectDetectLogFile
|
||||
self.schema = schemas.ProjectDetectLogFileOut
|
||||
|
@ -0,0 +1 @@
|
||||
from .detect import ProjectDetect, ProjectDetectFile, ProjectDetectLog, ProjectDetectLogFile
|
@ -22,17 +22,18 @@ class ProjectDetect(BaseModel):
|
||||
user_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
|
||||
|
||||
class ProjectDetectImg(BaseModel):
|
||||
class ProjectDetectFile(BaseModel):
|
||||
"""
|
||||
待推理图片
|
||||
"""
|
||||
__tablename__ = "project_detect_img"
|
||||
__table_args__ = ({'comment': '待推理图片'})
|
||||
__tablename__ = "project_detect_file"
|
||||
__table_args__ = ({'comment': '待推理文件'})
|
||||
|
||||
detect_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
file_name: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
image_url: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
thumb_image_url: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
file_url: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
object_key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
thumb_file_url: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
user_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
|
||||
|
||||
@ -55,13 +56,13 @@ class ProjectDetectLog(BaseModel):
|
||||
user_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
|
||||
|
||||
class ProjectDetectLogImg(BaseModel):
|
||||
class ProjectDetectLogFile(BaseModel):
|
||||
"""
|
||||
推理完成的图片
|
||||
"""
|
||||
__tablename__ = "project_detect_log_img"
|
||||
__tablename__ = "project_detect_log_file"
|
||||
__table_args__ = ({'comment': '项目训练版本信息表'})
|
||||
|
||||
log_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
file_name: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
image_url: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
file_url: Mapped[str] = mapped_column(String(255), nullable=False)
|
@ -1,4 +1,4 @@
|
||||
from .project_detect import ProjectDetectParams
|
||||
from .project_detect_img import ProjectDetectImgParams
|
||||
from .project_detect_file import ProjectDetectFileParams
|
||||
from .project_detect_log import ProjectDetectLogParams
|
||||
from .project_detect_log_img import ProjectDetectLogImgParams
|
||||
from .project_detect_log_file import ProjectDetectLogFileParams
|
||||
|
@ -6,10 +6,14 @@
|
||||
# @IDE : PyCharm
|
||||
# @desc : 项目推理集合信息
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import Depends, Query
|
||||
from core.dependencies import Paging, QueryParams
|
||||
|
||||
|
||||
class ProjectDetectParams(QueryParams):
|
||||
def __init__(self, params: Paging = Depends()):
|
||||
def __init__(
|
||||
self,
|
||||
project_id: int | None = Query(None, title="项目id"),
|
||||
params: Paging = Depends()):
|
||||
super().__init__(params)
|
||||
self.project_id = project_id
|
||||
|
19
apps/business/detect/params/project_detect_file.py
Normal file
19
apps/business/detect/params/project_detect_file.py
Normal file
@ -0,0 +1,19 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Create Time : 2025/04/03 10:30
|
||||
# @File : project_detect_file.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : 项目推理集合图片信息
|
||||
|
||||
from fastapi import Depends, Query
|
||||
from core.dependencies import Paging, QueryParams
|
||||
|
||||
|
||||
class ProjectDetectFileParams(QueryParams):
|
||||
def __init__(
|
||||
self,
|
||||
detect_id: int | None = Query(0, title="推理集合id"),
|
||||
params: Paging = Depends()):
|
||||
super().__init__(params)
|
||||
self.detect_id = detect_id
|
@ -1,15 +0,0 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Create Time : 2025/04/03 10:30
|
||||
# @File : project_detect_img.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : 项目推理集合图片信息
|
||||
|
||||
from fastapi import Depends
|
||||
from core.dependencies import Paging, QueryParams
|
||||
|
||||
|
||||
class ProjectDetectImgParams(QueryParams):
|
||||
def __init__(self, params: Paging = Depends()):
|
||||
super().__init__(params)
|
@ -6,10 +6,14 @@
|
||||
# @IDE : PyCharm
|
||||
# @desc : 项目推理记录信息
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import Depends, Query
|
||||
from core.dependencies import Paging, QueryParams
|
||||
|
||||
|
||||
class ProjectDetectLogParams(QueryParams):
|
||||
def __init__(self, params: Paging = Depends()):
|
||||
def __init__(
|
||||
self,
|
||||
detect_id: int | None = Query(0, title="推理集合id"),
|
||||
params: Paging = Depends()):
|
||||
super().__init__(params)
|
||||
self.detect_id = detect_id
|
||||
|
19
apps/business/detect/params/project_detect_log_file.py
Normal file
19
apps/business/detect/params/project_detect_log_file.py
Normal file
@ -0,0 +1,19 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Create Time : 2025/04/03 10:31
|
||||
# @File : project_detect_log_file.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : 项目推理记录图片信息
|
||||
|
||||
from fastapi import Depends, Query
|
||||
from core.dependencies import Paging, QueryParams
|
||||
|
||||
|
||||
class ProjectDetectLogFileParams(QueryParams):
|
||||
def __init__(
|
||||
self,
|
||||
log_id: int | None = Query(0, title="推理记录id"),
|
||||
params: Paging = Depends()):
|
||||
super().__init__(params)
|
||||
self.log_id = log_id
|
@ -1,15 +0,0 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Create Time : 2025/04/03 10:31
|
||||
# @File : project_detect_log_img.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : 项目推理记录图片信息
|
||||
|
||||
from fastapi import Depends
|
||||
from core.dependencies import Paging, QueryParams
|
||||
|
||||
|
||||
class ProjectDetectLogImgParams(QueryParams):
|
||||
def __init__(self, params: Paging = Depends()):
|
||||
super().__init__(params)
|
@ -1,4 +1,4 @@
|
||||
from .project_detect import ProjectDetect, ProjectDetectSimpleOut
|
||||
from .project_detect_img import ProjectDetectImg, ProjectDetectImgSimpleOut
|
||||
from .project_detect_log import ProjectDetectLog, ProjectDetectLogSimpleOut
|
||||
from .project_detect_log_img import ProjectDetectLogImg, ProjectDetectLogImgSimpleOut
|
||||
from .project_detect import ProjectDetectIn, ProjectDetectPager, ProjectDetectOut, ProjectDetectList
|
||||
from .project_detect_file import ProjectDetectFilePager, ProjectDetectFileOut
|
||||
from .project_detect_log import ProjectDetectLogIn, ProjectDetectLogOut
|
||||
from .project_detect_log_file import ProjectDetectLogFileOut
|
||||
|
@ -7,24 +7,43 @@
|
||||
# @desc : pydantic 模型,用于数据库序列化操作
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from core.data_types import DatetimeStr
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ProjectDetect(BaseModel):
|
||||
project_id: int = Field(..., title="None")
|
||||
detect_name: str = Field(..., title="None")
|
||||
detect_version: int = Field(..., title="None")
|
||||
detect_no: str = Field(..., title="None")
|
||||
detect_status: int = Field(..., title="None")
|
||||
file_type: str = Field(..., title="None")
|
||||
folder_url: str = Field(..., title="None")
|
||||
rtsp_url: str = Field(..., title="None")
|
||||
user_id: int = Field(..., title="None")
|
||||
class ProjectDetectIn(BaseModel):
|
||||
project_id: Optional[int] = Field(..., description="项目id")
|
||||
file_type: Optional[str] = Field('img', description="推理集合文件类别")
|
||||
detect_name: Optional[str] = Field(..., description="推理集合名称")
|
||||
rtsp_url: Optional[str] = Field(None, description="视频流地址")
|
||||
|
||||
|
||||
class ProjectDetectSimpleOut(ProjectDetect):
|
||||
class ProjectDetectPager(BaseModel):
|
||||
project_id: Optional[int] = Field(..., description="项目id")
|
||||
detect_name: Optional[str] = Field(None, description="推理集合名称")
|
||||
pagerNum: Optional[int] = Field(1, description="当前页码")
|
||||
pagerSize: Optional[int] = Field(10, description="每页数量")
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int = Field(..., title="编号")
|
||||
create_datetime: DatetimeStr = Field(..., title="创建时间")
|
||||
update_datetime: DatetimeStr = Field(..., title="更新时间")
|
||||
|
||||
class ProjectDetectOut(BaseModel):
|
||||
id: Optional[int]
|
||||
project_id: Optional[int]
|
||||
detect_name: Optional[str]
|
||||
detect_no: Optional[str]
|
||||
detect_version: Optional[int]
|
||||
file_type: Optional[str]
|
||||
folder_url: Optional[str]
|
||||
rtsp_url: Optional[str]
|
||||
create_time: Optional[datetime]
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ProjectDetectList(BaseModel):
|
||||
id: Optional[int]
|
||||
file_type: Optional[str]
|
||||
detect_name: Optional[str]
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
26
apps/business/detect/schemas/project_detect_file.py
Normal file
26
apps/business/detect/schemas/project_detect_file.py
Normal file
@ -0,0 +1,26 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Create Time : 2025/04/03 10:30
|
||||
# @File : project_detect_file.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : pydantic 模型,用于数据库序列化操作
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ProjectDetectFilePager(BaseModel):
|
||||
detect_id: Optional[int] = Field(..., description="训练集合id")
|
||||
pagerNum: Optional[int] = Field(None, description="当前页码")
|
||||
pagerSize: Optional[int] = Field(None, description="每页数量")
|
||||
|
||||
|
||||
class ProjectDetectFileOut(BaseModel):
|
||||
id: Optional[int] = Field(None, description="id")
|
||||
detect_id: Optional[int] = Field(..., description="训练集合id")
|
||||
file_name: Optional[str] = Field(None, description="文件名称")
|
||||
thumb_file_url: Optional[str] = Field(None, description="文件路径")
|
||||
create_time: Optional[datetime] = Field(None, description="上传时间")
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
@ -1,26 +0,0 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Create Time : 2025/04/03 10:30
|
||||
# @File : project_detect_img.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : pydantic 模型,用于数据库序列化操作
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from core.data_types import DatetimeStr
|
||||
|
||||
|
||||
class ProjectDetectImg(BaseModel):
|
||||
detect_id: int = Field(..., title="None")
|
||||
file_name: str = Field(..., title="None")
|
||||
image_url: str = Field(..., title="None")
|
||||
thumb_image_url: str = Field(..., title="None")
|
||||
user_id: int = Field(..., title="None")
|
||||
|
||||
|
||||
class ProjectDetectImgSimpleOut(ProjectDetectImg):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int = Field(..., title="编号")
|
||||
create_datetime: DatetimeStr = Field(..., title="创建时间")
|
||||
update_datetime: DatetimeStr = Field(..., title="更新时间")
|
@ -7,25 +7,24 @@
|
||||
# @desc : pydantic 模型,用于数据库序列化操作
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from core.data_types import DatetimeStr
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ProjectDetectLog(BaseModel):
|
||||
detect_id: int = Field(..., title="None")
|
||||
detect_version: str = Field(..., title="None")
|
||||
detect_name: str = Field(..., title="None")
|
||||
train_id: int = Field(..., title="None")
|
||||
train_version: str = Field(..., title="None")
|
||||
pt_type: str = Field(..., title="None")
|
||||
pt_url: str = Field(..., title="None")
|
||||
folder_url: str = Field(..., title="None")
|
||||
detect_folder_url: str = Field(..., title="None")
|
||||
user_id: int = Field(..., title="None")
|
||||
class ProjectDetectLogIn(BaseModel):
|
||||
detect_id: Optional[int] = Field(..., description="推理集合id")
|
||||
train_id: Optional[int] = Field(..., description="训练结果id")
|
||||
pt_type: Optional[str] = Field('best', description="权重文件类型")
|
||||
|
||||
|
||||
class ProjectDetectLogSimpleOut(ProjectDetectLog):
|
||||
class ProjectDetectLogOut(BaseModel):
|
||||
id: Optional[int]
|
||||
detect_id: Optional[int]
|
||||
detect_version: Optional[str]
|
||||
detect_name: Optional[str]
|
||||
train_id: Optional[int]
|
||||
train_version: Optional[str]
|
||||
pt_type: Optional[str]
|
||||
create_time: Optional[datetime]
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int = Field(..., title="编号")
|
||||
create_datetime: DatetimeStr = Field(..., title="创建时间")
|
||||
update_datetime: DatetimeStr = Field(..., title="更新时间")
|
||||
|
20
apps/business/detect/schemas/project_detect_log_file.py
Normal file
20
apps/business/detect/schemas/project_detect_log_file.py
Normal file
@ -0,0 +1,20 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Create Time : 2025/04/03 10:31
|
||||
# @File : project_detect_log_file.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : pydantic 模型,用于数据库序列化操作
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ProjectDetectLogFileOut(BaseModel):
|
||||
id: Optional[int]
|
||||
file_name: Optional[str]
|
||||
thumb_file_url: Optional[str]
|
||||
create_time: Optional[datetime]
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
@ -1,24 +0,0 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Create Time : 2025/04/03 10:31
|
||||
# @File : project_detect_log_img.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : pydantic 模型,用于数据库序列化操作
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from core.data_types import DatetimeStr
|
||||
|
||||
|
||||
class ProjectDetectLogImg(BaseModel):
|
||||
log_id: int = Field(..., title="None")
|
||||
file_name: str = Field(..., title="None")
|
||||
image_url: str = Field(..., title="None")
|
||||
|
||||
|
||||
class ProjectDetectLogImgSimpleOut(ProjectDetectLogImg):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int = Field(..., title="编号")
|
||||
create_datetime: DatetimeStr = Field(..., title="创建时间")
|
||||
update_datetime: DatetimeStr = Field(..., title="更新时间")
|
226
apps/business/detect/service.py
Normal file
226
apps/business/detect/service.py
Normal file
@ -0,0 +1,226 @@
|
||||
from application.settings import yolo_url, detect_url
|
||||
from utils.websocket_server import room_manager
|
||||
from utils import os_utils as os
|
||||
from . import models, crud, schemas
|
||||
from apps.business.train import models as train_models
|
||||
from utils.yolov5.models.common import DetectMultiBackend
|
||||
from utils.yolov5.utils.torch_utils import select_device
|
||||
from utils.yolov5.utils.dataloaders import LoadStreams
|
||||
from utils.yolov5.utils.general import check_img_size, Profile, non_max_suppression, cv2, scale_boxes
|
||||
from ultralytics.utils.plotting import Annotator, colors
|
||||
|
||||
import time
|
||||
import torch
|
||||
import asyncio
|
||||
import subprocess
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
async def before_detect(
|
||||
detect_in: schemas.ProjectDetectLogIn,
|
||||
detect: models.ProjectDetect,
|
||||
train: train_models.ProjectTrain,
|
||||
db: AsyncSession):
|
||||
"""
|
||||
开始推理
|
||||
:param detect:
|
||||
:param detect_in:
|
||||
:param train:
|
||||
:param db:
|
||||
:return:
|
||||
"""
|
||||
# 推理版本
|
||||
version_path = 'v' + str(detect.detect_version + 1)
|
||||
|
||||
# 权重文件
|
||||
pt_url = train.best_pt if detect_in.pt_type == 'best' else train.last_pt
|
||||
|
||||
# 推理集合文件路径
|
||||
img_url = detect.folder_url
|
||||
|
||||
out_url = os.file_path(detect_url, detect.detect_no, 'detect')
|
||||
|
||||
# 构建推理记录数据
|
||||
detect_log = models.ProjectDetectLog()
|
||||
detect_log.detect_name = detect.detect_name
|
||||
detect_log.detect_id = detect.id
|
||||
detect_log.detect_version = version_path
|
||||
detect_log.train_id = train.id
|
||||
detect_log.train_version = train.train_version
|
||||
detect_log.pt_type = detect_in.pt_type
|
||||
detect_log.pt_url = pt_url
|
||||
detect_log.folder_url = img_url
|
||||
detect_log.detect_folder_url = out_url
|
||||
await crud.ProjectDetectLogDal(db).create_data(detect_log)
|
||||
return detect_log
|
||||
|
||||
|
||||
async def run_detect_img(
|
||||
weights: str,
|
||||
source: str,
|
||||
project: str,
|
||||
name: str,
|
||||
log_id: int,
|
||||
detect_id: int,
|
||||
db: AsyncSession,
|
||||
rd: Redis):
|
||||
"""
|
||||
执行yolov5的推理
|
||||
:param weights: 权重文件
|
||||
:param source: 图片所在文件
|
||||
:param project: 推理完成的文件位置
|
||||
:param name: 版本名称
|
||||
:param log_id: 日志id
|
||||
:param detect_id: 推理集合id
|
||||
:param db: 数据库session
|
||||
:param rd: Redis
|
||||
:return:
|
||||
"""
|
||||
yolo_path = os.file_path(yolo_url, 'detect.py')
|
||||
room = 'detect_' + str(detect_id)
|
||||
await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
|
||||
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project",
|
||||
project, "--save-txt", "--conf-thres", "0.4"]
|
||||
is_gpu = rd.get('is_gpu')
|
||||
# 判断是否存在cuda版本
|
||||
if is_gpu == 'True':
|
||||
commend.append("--device", "0")
|
||||
# 启动子进程
|
||||
with subprocess.Popen(
|
||||
commend,
|
||||
bufsize=1, # bufsize=0时,为不缓存;bufsize=1时,按行缓存;bufsize为其他正整数时,为按照近似该正整数的字节数缓存
|
||||
shell=False,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT, # 这里可以显示yolov5训练过程中出现的进度条等信息
|
||||
text=True, # 缓存内容为文本,避免后续编码显示问题
|
||||
encoding='utf-8',
|
||||
) as process:
|
||||
while process.poll() is None:
|
||||
line = process.stdout.readline()
|
||||
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
|
||||
if line != '\n':
|
||||
await room_manager.send_to_room(room, line + '\n')
|
||||
|
||||
# 等待进程结束并获取返回码
|
||||
return_code = process.wait()
|
||||
if return_code != 0:
|
||||
await room_manager.send_to_room(room, 'error')
|
||||
else:
|
||||
await room_manager.send_to_room(room, 'success')
|
||||
detect_files = crud.ProjectDetectFileDal(db).get_data(
|
||||
v_where=[models.ProjectDetectFile.detect_id == detect_id])
|
||||
detect_log_files = []
|
||||
for detect_file in detect_files:
|
||||
detect_log_img = models.ProjectDetectLogFile()
|
||||
detect_log_img.log_id = log_id
|
||||
image_url = os.file_path(project, name, detect_file.file_name)
|
||||
detect_log_img.image_url = image_url
|
||||
detect_log_img.file_name = detect_file.file_name
|
||||
detect_log_files.append(detect_log_img)
|
||||
await crud.ProjectDetectLogFileDal(db).create_datas(detect_log_files)
|
||||
|
||||
|
||||
async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: int, rd: Redis):
|
||||
"""
|
||||
rtsp 视频流推理
|
||||
:param detect_id: 训练集的id
|
||||
:param weights_pt: 权重文件
|
||||
:param rtsp_url: 视频流地址
|
||||
:param data: yaml文件
|
||||
:param rd: Redis :redis
|
||||
:return:
|
||||
"""
|
||||
room = 'detect_rtsp_' + str(detect_id)
|
||||
# 选择设备(CPU 或 GPU)
|
||||
device = select_device('cpu')
|
||||
is_gpu = rd.get('is_gpu')
|
||||
# 判断是否存在cuda版本
|
||||
if is_gpu == 'True':
|
||||
device = select_device('cuda:0')
|
||||
|
||||
# 加载模型
|
||||
model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
|
||||
|
||||
stride, names, pt = model.stride, model.names, model.pt
|
||||
imgsz = check_img_size((640, 640), s=stride) # check image size
|
||||
|
||||
dataset = LoadStreams(rtsp_url, img_size=imgsz, stride=stride, auto=pt, vid_stride=1)
|
||||
bs = len(dataset)
|
||||
|
||||
model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))
|
||||
|
||||
seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
|
||||
|
||||
time.sleep(3) # 等待3s,等待websocket进入
|
||||
|
||||
for path, im, im0s, vid_cap, s in dataset:
|
||||
if room_manager.rooms.get(room):
|
||||
with dt[0]:
|
||||
im = torch.from_numpy(im).to(model.device)
|
||||
im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
|
||||
im /= 255 # 0 - 255 to 0.0 - 1.0
|
||||
if len(im.shape) == 3:
|
||||
im = im[None] # expand for batch dim
|
||||
if model.xml and im.shape[0] > 1:
|
||||
ims = torch.chunk(im, im.shape[0], 0)
|
||||
|
||||
# Inference
|
||||
with dt[1]:
|
||||
if model.xml and im.shape[0] > 1:
|
||||
pred = None
|
||||
for image in ims:
|
||||
if pred is None:
|
||||
pred = model(image, augment=False, visualize=False).unsqueeze(0)
|
||||
else:
|
||||
pred = torch.cat((pred, model(image, augment=False, visualize=False).unsqueeze(0)),
|
||||
dim=0)
|
||||
pred = [pred, None]
|
||||
else:
|
||||
pred = model(im, augment=False, visualize=False)
|
||||
# NMS
|
||||
with dt[2]:
|
||||
pred = non_max_suppression(pred, 0.45, 0.45, None, False, max_det=1000)
|
||||
|
||||
# Process predictions
|
||||
for i, det in enumerate(pred): # per image
|
||||
p, im0, frame = path[i], im0s[i].copy(), dataset.count
|
||||
annotator = Annotator(im0, line_width=3, example=str(names))
|
||||
if len(det):
|
||||
# Rescale boxes from img_size to im0 size
|
||||
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
|
||||
|
||||
# Write results
|
||||
for *xyxy, conf, cls in reversed(det):
|
||||
c = int(cls) # integer class
|
||||
label = None if False else (names[c] if False else f"{names[c]} {conf:.2f}")
|
||||
annotator.box_label(xyxy, label, color=colors(c, True))
|
||||
|
||||
# Stream results
|
||||
im0 = annotator.result()
|
||||
# 将帧编码为 JPEG
|
||||
ret, jpeg = cv2.imencode('.jpg', im0)
|
||||
if ret:
|
||||
frame_data = jpeg.tobytes()
|
||||
await room_manager.send_stream_to_room(room, frame_data)
|
||||
else:
|
||||
print(room, '结束推理')
|
||||
break
|
||||
|
||||
|
||||
def run_img_loop(weights: str, source: str, project: str, name: str, log_id: int, detect_id: int, db: AsyncSession):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
# 运行异步函数
|
||||
loop.run_until_complete(run_detect_img(weights, source, project, name, log_id, detect_id, db))
|
||||
# 可选: 关闭循环
|
||||
loop.close()
|
||||
|
||||
|
||||
def run_rtsp_loop(weights_pt: str, rtsp_url: str, data: str, detect_id: int, rd: Redis):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
# 运行异步函数
|
||||
loop.run_until_complete(run_detect_rtsp(weights_pt, rtsp_url, data, detect_id, rd))
|
||||
# 可选: 关闭循环
|
||||
loop.close()
|
@ -5,15 +5,21 @@
|
||||
# @File : views.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : 路由,视图文件
|
||||
from core.dependencies import IdList
|
||||
from apps.vadmin.auth.utils.validation.auth import Auth
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from apps.vadmin.auth.utils.current import AllUserAuth
|
||||
from core.database import db_getter
|
||||
from . import schemas, crud, models, params
|
||||
from fastapi import Depends, APIRouter
|
||||
from utils.response import SuccessResponse
|
||||
|
||||
import service
|
||||
from . import schemas, crud, params
|
||||
from core.dependencies import IdList
|
||||
from core.database import redis_getter
|
||||
from utils.websocket_server import room_manager
|
||||
from apps.business.train.crud import ProjectTrainDal
|
||||
from apps.vadmin.auth.utils.current import AllUserAuth
|
||||
from apps.vadmin.auth.utils.validation.auth import Auth
|
||||
from utils.response import SuccessResponse, ErrorResponse
|
||||
|
||||
import threading
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from fastapi import Depends, APIRouter, Form, UploadFile
|
||||
|
||||
|
||||
app = APIRouter()
|
||||
@ -22,129 +28,120 @@ app = APIRouter()
|
||||
###########################################################
|
||||
# 项目推理集合信息
|
||||
###########################################################
|
||||
@app.get("/project/detect", summary="获取项目推理集合信息列表", tags=["项目推理集合信息"])
|
||||
async def get_project_detect_list(p: params.ProjectDetectParams = Depends(), auth: Auth = Depends(AllUserAuth())):
|
||||
@app.get("/list", summary="获取项目推理集合信息列表")
|
||||
async def detect_list(
|
||||
p: params.ProjectDetectParams = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
datas, count = await crud.ProjectDetectDal(auth.db).get_datas(**p.dict(), v_return_count=True)
|
||||
return SuccessResponse(datas, count=count)
|
||||
|
||||
|
||||
@app.post("/project/detect", summary="创建项目推理集合信息", tags=["项目推理集合信息"])
|
||||
async def create_project_detect(data: schemas.ProjectDetect, auth: Auth = Depends(AllUserAuth())):
|
||||
return SuccessResponse(await crud.ProjectDetectDal(auth.db).create_data(data=data))
|
||||
@app.post("/", summary="创建项目推理集合信息")
|
||||
async def add_detect(
|
||||
data: schemas.ProjectDetectIn,
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
detect_dal = crud.ProjectDetectDal(auth.db)
|
||||
if await detect_dal.check_name(data.detect_name, data.project_id):
|
||||
return ErrorResponse(msg="该项目中存在相同名称的集合")
|
||||
await detect_dal.create_data(data=data)
|
||||
return SuccessResponse(msg="保存成功")
|
||||
|
||||
|
||||
@app.delete("/project/detect", summary="删除项目推理集合信息", description="硬删除", tags=["项目推理集合信息"])
|
||||
async def delete_project_detect_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):
|
||||
@app.delete("/", summary="删除项目推理集合信息")
|
||||
async def delete_detect(
|
||||
ids: IdList = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
await crud.ProjectDetectDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
|
||||
return SuccessResponse("删除成功")
|
||||
|
||||
|
||||
@app.put("/project/detect/{data_id}", summary="更新项目推理集合信息", tags=["项目推理集合信息"])
|
||||
async def put_project_detect(data_id: int, data: schemas.ProjectDetect, auth: Auth = Depends(AllUserAuth())):
|
||||
return SuccessResponse(await crud.ProjectDetectDal(auth.db).put_data(data_id, data))
|
||||
|
||||
|
||||
@app.get("/project/detect/{data_id}", summary="获取项目推理集合信息信息", tags=["项目推理集合信息"])
|
||||
async def get_project_detect(data_id: int, db: AsyncSession = Depends(db_getter)):
|
||||
schema = schemas.ProjectDetectSimpleOut
|
||||
return SuccessResponse(await crud.ProjectDetectDal(db).get_data(data_id, v_schema=schema))
|
||||
|
||||
|
||||
|
||||
|
||||
###########################################################
|
||||
# 项目推理集合图片信息
|
||||
# 项目推理集合文件信息
|
||||
###########################################################
|
||||
@app.get("/project/detect/img", summary="获取项目推理集合图片信息列表", tags=["项目推理集合图片信息"])
|
||||
async def get_project_detect_img_list(p: params.ProjectDetectImgParams = Depends(), auth: Auth = Depends(AllUserAuth())):
|
||||
datas, count = await crud.ProjectDetectImgDal(auth.db).get_datas(**p.dict(), v_return_count=True)
|
||||
return SuccessResponse(datas, count=count)
|
||||
@app.get("/file", summary="获取项目推理集合文件信息列表")
|
||||
async def file_list(
|
||||
p: params.ProjectDetectFileParams = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
if p.limit:
|
||||
datas, count = await crud.ProjectDetectFileDal(auth.db).get_datas(**p.dict(), v_return_count=True)
|
||||
return SuccessResponse(datas, count=count)
|
||||
else:
|
||||
datas = await crud.ProjectDetectFileDal(auth.db).get_datas(**p.dict(), v_return_count=False)
|
||||
return SuccessResponse(datas)
|
||||
|
||||
|
||||
@app.post("/project/detect/img", summary="创建项目推理集合图片信息", tags=["项目推理集合图片信息"])
|
||||
async def create_project_detect_img(data: schemas.ProjectDetectImg, auth: Auth = Depends(AllUserAuth())):
|
||||
return SuccessResponse(await crud.ProjectDetectImgDal(auth.db).create_data(data=data))
|
||||
@app.post("/file", summary="上传项目推理集合文件")
|
||||
async def upload_file(
|
||||
detect_id: int = Form(...),
|
||||
files: list[UploadFile] = Form(...),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
file_dal = crud.ProjectDetectFileDal(auth.db)
|
||||
detect_out = file_dal.get_data(data_id=detect_id)
|
||||
if detect_out is None:
|
||||
return ErrorResponse("训练集合查询失败,请刷新后再试")
|
||||
await file_dal.add_file(detect_out, files)
|
||||
return SuccessResponse(msg="上传成功")
|
||||
|
||||
|
||||
@app.delete("/project/detect/img", summary="删除项目推理集合图片信息", description="硬删除", tags=["项目推理集合图片信息"])
|
||||
async def delete_project_detect_img_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):
|
||||
await crud.ProjectDetectImgDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
|
||||
@app.delete("/file", summary="删除项目推理集合文件信息")
|
||||
async def delete_file(
|
||||
ids: IdList = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
await crud.ProjectDetectFileDal(auth.db).delete_files(ids=ids.ids)
|
||||
return SuccessResponse("删除成功")
|
||||
|
||||
|
||||
@app.put("/project/detect/img/{data_id}", summary="更新项目推理集合图片信息", tags=["项目推理集合图片信息"])
|
||||
async def put_project_detect_img(data_id: int, data: schemas.ProjectDetectImg, auth: Auth = Depends(AllUserAuth())):
|
||||
return SuccessResponse(await crud.ProjectDetectImgDal(auth.db).put_data(data_id, data))
|
||||
|
||||
|
||||
@app.get("/project/detect/img/{data_id}", summary="获取项目推理集合图片信息信息", tags=["项目推理集合图片信息"])
|
||||
async def get_project_detect_img(data_id: int, db: AsyncSession = Depends(db_getter)):
|
||||
schema = schemas.ProjectDetectImgSimpleOut
|
||||
return SuccessResponse(await crud.ProjectDetectImgDal(db).get_data(data_id, v_schema=schema))
|
||||
|
||||
|
||||
@app.post("/detect", summary="开始推理")
|
||||
def run_detect_yolo(
|
||||
detect_log_in: schemas.ProjectDetectLogIn,
|
||||
auth: Auth = Depends(AllUserAuth()),
|
||||
rd: Redis = Depends(redis_getter)):
|
||||
detect_dal = crud.ProjectDetectDal(auth.db)
|
||||
train_dal = ProjectTrainDal(auth.db)
|
||||
detect = detect_dal.get_data(detect_log_in.detect_id)
|
||||
if detect is None:
|
||||
return ErrorResponse(msg="训练集合不存在")
|
||||
train = train_dal.get_data(detect_log_in.train_id)
|
||||
if train is None:
|
||||
return ErrorResponse("训练权重不存在")
|
||||
file_count = crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id)
|
||||
if file_count == 0 and detect.rtsp_url is None:
|
||||
return ErrorResponse("推理集合中没有内容,请先到推理集合中上传图片")
|
||||
if detect.file_type == 'img' or detect.file_type == 'video':
|
||||
detect_log = service.before_detect(detect_log_in, detect, train, auth.db)
|
||||
thread_train = threading.Thread(target=service.run_img_loop,
|
||||
args=(detect_log.pt_url, detect_log.folder_url,
|
||||
detect_log.detect_folder_url, detect_log.detect_version,
|
||||
detect_log.id, detect_log.detect_id, auth.db,))
|
||||
thread_train.start()
|
||||
elif detect.file_type == 'rtsp':
|
||||
room = 'detect_rtsp_' + str(detect.id)
|
||||
if not room_manager.rooms.get(room):
|
||||
if detect_log_in.pt_type == 'best':
|
||||
weights_pt = train.best_pt
|
||||
else:
|
||||
weights_pt = train.last_pt
|
||||
thread_train = threading.Thread(target=service.run_rtsp_loop,
|
||||
args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, rd,))
|
||||
thread_train.start()
|
||||
return SuccessResponse(msg="执行成功")
|
||||
|
||||
|
||||
###########################################################
|
||||
# 项目推理记录信息
|
||||
###########################################################
|
||||
@app.get("/project/detect/log", summary="获取项目推理记录信息列表", tags=["项目推理记录信息"])
|
||||
async def get_project_detect_log_list(p: params.ProjectDetectLogParams = Depends(), auth: Auth = Depends(AllUserAuth())):
|
||||
@app.get("/log", summary="获取项目推理记录列表")
|
||||
async def log_pager(
|
||||
p: params.ProjectDetectLogParams = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
datas, count = await crud.ProjectDetectLogDal(auth.db).get_datas(**p.dict(), v_return_count=True)
|
||||
return SuccessResponse(datas, count=count)
|
||||
|
||||
|
||||
@app.post("/project/detect/log", summary="创建项目推理记录信息", tags=["项目推理记录信息"])
|
||||
async def create_project_detect_log(data: schemas.ProjectDetectLog, auth: Auth = Depends(AllUserAuth())):
|
||||
return SuccessResponse(await crud.ProjectDetectLogDal(auth.db).create_data(data=data))
|
||||
|
||||
|
||||
@app.delete("/project/detect/log", summary="删除项目推理记录信息", description="硬删除", tags=["项目推理记录信息"])
|
||||
async def delete_project_detect_log_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):
|
||||
await crud.ProjectDetectLogDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
|
||||
return SuccessResponse("删除成功")
|
||||
|
||||
|
||||
@app.put("/project/detect/log/{data_id}", summary="更新项目推理记录信息", tags=["项目推理记录信息"])
|
||||
async def put_project_detect_log(data_id: int, data: schemas.ProjectDetectLog, auth: Auth = Depends(AllUserAuth())):
|
||||
return SuccessResponse(await crud.ProjectDetectLogDal(auth.db).put_data(data_id, data))
|
||||
|
||||
|
||||
@app.get("/project/detect/log/{data_id}", summary="获取项目推理记录信息信息", tags=["项目推理记录信息"])
|
||||
async def get_project_detect_log(data_id: int, db: AsyncSession = Depends(db_getter)):
|
||||
schema = schemas.ProjectDetectLogSimpleOut
|
||||
return SuccessResponse(await crud.ProjectDetectLogDal(db).get_data(data_id, v_schema=schema))
|
||||
|
||||
|
||||
|
||||
|
||||
###########################################################
|
||||
# 项目推理记录图片信息
|
||||
###########################################################
|
||||
@app.get("/project/detect/log/img", summary="获取项目推理记录图片信息列表", tags=["项目推理记录图片信息"])
|
||||
async def get_project_detect_log_img_list(p: params.ProjectDetectLogImgParams = Depends(), auth: Auth = Depends(AllUserAuth())):
|
||||
datas, count = await crud.ProjectDetectLogImgDal(auth.db).get_datas(**p.dict(), v_return_count=True)
|
||||
return SuccessResponse(datas, count=count)
|
||||
|
||||
|
||||
@app.post("/project/detect/log/img", summary="创建项目推理记录图片信息", tags=["项目推理记录图片信息"])
|
||||
async def create_project_detect_log_img(data: schemas.ProjectDetectLogImg, auth: Auth = Depends(AllUserAuth())):
|
||||
return SuccessResponse(await crud.ProjectDetectLogImgDal(auth.db).create_data(data=data))
|
||||
|
||||
|
||||
@app.delete("/project/detect/log/img", summary="删除项目推理记录图片信息", description="硬删除", tags=["项目推理记录图片信息"])
|
||||
async def delete_project_detect_log_img_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):
|
||||
await crud.ProjectDetectLogImgDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
|
||||
return SuccessResponse("删除成功")
|
||||
|
||||
|
||||
@app.put("/project/detect/log/img/{data_id}", summary="更新项目推理记录图片信息", tags=["项目推理记录图片信息"])
|
||||
async def put_project_detect_log_img(data_id: int, data: schemas.ProjectDetectLogImg, auth: Auth = Depends(AllUserAuth())):
|
||||
return SuccessResponse(await crud.ProjectDetectLogImgDal(auth.db).put_data(data_id, data))
|
||||
|
||||
|
||||
@app.get("/project/detect/log/img/{data_id}", summary="获取项目推理记录图片信息信息", tags=["项目推理记录图片信息"])
|
||||
async def get_project_detect_log_img(data_id: int, db: AsyncSession = Depends(db_getter)):
|
||||
schema = schemas.ProjectDetectLogImgSimpleOut
|
||||
return SuccessResponse(await crud.ProjectDetectLogImgDal(db).get_data(data_id, v_schema=schema))
|
||||
@app.get("/log_files", summary="获取项目推理记录文件列表")
|
||||
async def log_files(
|
||||
p: params.ProjectDetectLogFileParams = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
datas = await crud.ProjectDetectLogFileDal(auth.db).get_datas(**p.dict(), v_return_count=False)
|
||||
return SuccessResponse(datas)
|
||||
|
||||
|
@ -10,7 +10,7 @@ from . import params, schemas, crud, models
|
||||
from core.dependencies import IdList
|
||||
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, UploadFile, File, Form
|
||||
from fastapi import APIRouter, Depends, UploadFile, Form
|
||||
from apps.vadmin.auth.utils.current import FullAdminAuth
|
||||
from apps.vadmin.auth.utils.validation.auth import Auth
|
||||
|
||||
@ -124,7 +124,7 @@ async def project_pager(
|
||||
@app.post("/img", summary="上传图片")
|
||||
async def up_img(
|
||||
project_id: int = Form(...),
|
||||
files: List[UploadFile] = File(...),
|
||||
files: List[UploadFile] = Form(...),
|
||||
img_type: str = Form(...),
|
||||
auth: Auth = Depends(FullAdminAuth())
|
||||
):
|
||||
|
@ -1,14 +1,13 @@
|
||||
from . import schemas, models, crud
|
||||
from apps.business.project import schemas as proj_schemas, models as proj_models, crud as proj_crud
|
||||
from utils import os_utils as os
|
||||
from application.settings import *
|
||||
from . import schemas, models, crud
|
||||
from utils.websocket_server import room_manager
|
||||
from apps.business.project import models as proj_models, crud as proj_crud
|
||||
|
||||
|
||||
import yaml
|
||||
import asyncio
|
||||
import subprocess
|
||||
from typing import List
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@ -73,7 +72,7 @@ async def before_train(proj_info: proj_models.ProjectInfo, db: AsyncSession):
|
||||
|
||||
|
||||
async def operate_img_label(
|
||||
img_list: List[proj_models.ProjectImgLabel],
|
||||
img_list: list[proj_models.ProjectImgLabel],
|
||||
img_path: str,
|
||||
label_path: str,
|
||||
db: AsyncSession,
|
||||
|
Reference in New Issue
Block a user