完成推理模块的转移

This commit is contained in:
2025-04-17 15:57:16 +08:00
parent 74e8f0d415
commit b0379e64c9
130 changed files with 14269 additions and 3201 deletions

View File

@ -5,10 +5,18 @@
# @File : crud.py
# @IDE : PyCharm
# @desc : 数据访问层
from sqlalchemy.ext.asyncio import AsyncSession
from core.crud import DalBase
from . import schemas, models
from utils.random_utils import random_str
from utils import os_utils as os
from application.settings import detect_url
from utils.huawei_obs import ObsClient
from utils import status
from core.exception import CustomException
from fastapi import UploadFile
class ProjectDetectDal(DalBase):
@ -17,17 +25,86 @@ class ProjectDetectDal(DalBase):
super(ProjectDetectDal, self).__init__()
self.db = db
self.model = models.ProjectDetect
self.schema = schemas.ProjectDetectSimpleOut
self.schema = schemas.ProjectDetectOut
async def check_name(self, name: str,project_id: int) -> bool:
"""
校验推理集合名称
"""
count = self.get_count(
v_where=[models.ProjectDetect.project_id == project_id, models.ProjectDetect.detect_name == name])
return count > 0
async def add_detect(self, data: schemas.ProjectDetectIn):
"""
新增集合
"""
detect = models.ProjectDetect(**data.model_dump())
detect.detect_no = random_str(6)
detect.detect_version = 0
detect.detect_status = '0'
url = os.create_folder(detect_url, detect.detect_no, 'images')
detect.folder_url = url
await self.create_data(data)
return detect
async def delete_detects(self, ids: list[int]):
"""
删除集合数据+文件夹的文件夹+每次推理日志的文件
"""
for id_ in ids:
detect_info = await self.get_data(data_id=id_)
if detect_info.file_type != 'rtsp':
os.delete_paths(detect_info.folder_url)
logs = await ProjectDetectLogDal(self.db).get_datas(v_where=[models.ProjectDetectLog.detect_id == ids])
for log in logs:
os.delete_paths(log.folder_url)
await self.delete_datas(ids=ids, v_soft=False)
class ProjectDetectImgDal(DalBase):
class ProjectDetectFileDal(DalBase):
def __init__(self, db: AsyncSession):
super(ProjectDetectImgDal, self).__init__()
super(ProjectDetectFileDal, self).__init__()
self.db = db
self.model = models.ProjectDetectImg
self.schema = schemas.ProjectDetectImgSimpleOut
self.model = models.ProjectDetectFile
self.schema = schemas.ProjectDetectFileOut
async def file_count(self, detect_id: int) -> int:
count = self.get_count(v_where=[models.ProjectDetectFile.detect_id == detect_id])
return count
async def add_file(self, detect: models.ProjectDetect, files: list[UploadFile]):
images = []
for file in files:
image = models.ProjectDetectFile()
image.detect_id = detect.id
image.file_name = file.filename
# 保存原图
path = os.save_images(detect.folder_url, file=file)
image.image_url = path
# 上传到obs
object_key = detect.detect_no + '/' + file.filename
success, key, url = ObsClient.put_file(object_key=object_key, file_path=path)
if success:
image.object_key = object_key
image.thumb_image_url = url
else:
raise CustomException("obs上传失败", code=status.HTTP_ERROR)
images.append(image)
await self.create_datas(images)
async def delete_files(self, ids: list[int]):
file_urls = []
object_keys = []
for id_ in ids:
file = self.get_data(data_id=id_)
if file:
file_urls.append(file.file_url)
object_keys.append(file.object_key)
os.delete_paths(file_urls)
ObsClient.del_objects(object_keys)
await self.delete_datas(ids, v_soft=False)
class ProjectDetectLogDal(DalBase):
@ -38,10 +115,10 @@ class ProjectDetectLogDal(DalBase):
self.schema = schemas.ProjectDetectLogSimpleOut
class ProjectDetectLogImgDal(DalBase):
class ProjectDetectLogFileDal(DalBase):
def __init__(self, db: AsyncSession):
super(ProjectDetectLogImgDal, self).__init__()
super(ProjectDetectLogFileDal, self).__init__()
self.db = db
self.model = models.ProjectDetectLogImg
self.schema = schemas.ProjectDetectLogImgSimpleOut
self.model = models.ProjectDetectLogFile
self.schema = schemas.ProjectDetectLogFileOut

View File

@ -0,0 +1 @@
from .detect import ProjectDetect, ProjectDetectFile, ProjectDetectLog, ProjectDetectLogFile

View File

@ -22,17 +22,18 @@ class ProjectDetect(BaseModel):
user_id: Mapped[int] = mapped_column(Integer, nullable=False)
class ProjectDetectImg(BaseModel):
class ProjectDetectFile(BaseModel):
"""
待推理图片
"""
__tablename__ = "project_detect_img"
__table_args__ = ({'comment': '待推理图片'})
__tablename__ = "project_detect_file"
__table_args__ = ({'comment': '待推理文件'})
detect_id: Mapped[int] = mapped_column(Integer, nullable=False)
file_name: Mapped[str] = mapped_column(String(64), nullable=False)
image_url: Mapped[str] = mapped_column(String(255), nullable=False)
thumb_image_url: Mapped[str] = mapped_column(String(255), nullable=False)
file_url: Mapped[str] = mapped_column(String(255), nullable=False)
object_key: Mapped[str] = mapped_column(String(255), nullable=False)
thumb_file_url: Mapped[str] = mapped_column(String(255), nullable=False)
user_id: Mapped[int] = mapped_column(Integer, nullable=False)
@ -55,13 +56,13 @@ class ProjectDetectLog(BaseModel):
user_id: Mapped[int] = mapped_column(Integer, nullable=False)
class ProjectDetectLogImg(BaseModel):
class ProjectDetectLogFile(BaseModel):
"""
推理完成的图片
"""
__tablename__ = "project_detect_log_img"
__tablename__ = "project_detect_log_file"
__table_args__ = ({'comment': '项目训练版本信息表'})
log_id: Mapped[int] = mapped_column(Integer, nullable=False)
file_name: Mapped[str] = mapped_column(String(64), nullable=False)
image_url: Mapped[str] = mapped_column(String(255), nullable=False)
file_url: Mapped[str] = mapped_column(String(255), nullable=False)

View File

@ -1,4 +1,4 @@
from .project_detect import ProjectDetectParams
from .project_detect_img import ProjectDetectImgParams
from .project_detect_file import ProjectDetectFileParams
from .project_detect_log import ProjectDetectLogParams
from .project_detect_log_img import ProjectDetectLogImgParams
from .project_detect_log_file import ProjectDetectLogFileParams

View File

@ -6,10 +6,14 @@
# @IDE : PyCharm
# @desc : 项目推理集合信息
from fastapi import Depends
from fastapi import Depends, Query
from core.dependencies import Paging, QueryParams
class ProjectDetectParams(QueryParams):
def __init__(self, params: Paging = Depends()):
def __init__(
self,
project_id: int | None = Query(None, title="项目id"),
params: Paging = Depends()):
super().__init__(params)
self.project_id = project_id

View File

@ -0,0 +1,19 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:30
# @File : project_detect_file.py
# @IDE : PyCharm
# @desc : 项目推理集合图片信息
from fastapi import Depends, Query
from core.dependencies import Paging, QueryParams
class ProjectDetectFileParams(QueryParams):
def __init__(
self,
detect_id: int | None = Query(0, title="推理集合id"),
params: Paging = Depends()):
super().__init__(params)
self.detect_id = detect_id

View File

@ -1,15 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:30
# @File : project_detect_img.py
# @IDE : PyCharm
# @desc : 项目推理集合图片信息
from fastapi import Depends
from core.dependencies import Paging, QueryParams
class ProjectDetectImgParams(QueryParams):
def __init__(self, params: Paging = Depends()):
super().__init__(params)

View File

@ -6,10 +6,14 @@
# @IDE : PyCharm
# @desc : 项目推理记录信息
from fastapi import Depends
from fastapi import Depends, Query
from core.dependencies import Paging, QueryParams
class ProjectDetectLogParams(QueryParams):
def __init__(self, params: Paging = Depends()):
def __init__(
self,
detect_id: int | None = Query(0, title="推理集合id"),
params: Paging = Depends()):
super().__init__(params)
self.detect_id = detect_id

View File

@ -0,0 +1,19 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:31
# @File : project_detect_log_file.py
# @IDE : PyCharm
# @desc : 项目推理记录图片信息
from fastapi import Depends, Query
from core.dependencies import Paging, QueryParams
class ProjectDetectLogFileParams(QueryParams):
def __init__(
self,
log_id: int | None = Query(0, title="推理记录id"),
params: Paging = Depends()):
super().__init__(params)
self.log_id = log_id

View File

@ -1,15 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:31
# @File : project_detect_log_img.py
# @IDE : PyCharm
# @desc : 项目推理记录图片信息
from fastapi import Depends
from core.dependencies import Paging, QueryParams
class ProjectDetectLogImgParams(QueryParams):
def __init__(self, params: Paging = Depends()):
super().__init__(params)

View File

@ -1,4 +1,4 @@
from .project_detect import ProjectDetect, ProjectDetectSimpleOut
from .project_detect_img import ProjectDetectImg, ProjectDetectImgSimpleOut
from .project_detect_log import ProjectDetectLog, ProjectDetectLogSimpleOut
from .project_detect_log_img import ProjectDetectLogImg, ProjectDetectLogImgSimpleOut
from .project_detect import ProjectDetectIn, ProjectDetectPager, ProjectDetectOut, ProjectDetectList
from .project_detect_file import ProjectDetectFilePager, ProjectDetectFileOut
from .project_detect_log import ProjectDetectLogIn, ProjectDetectLogOut
from .project_detect_log_file import ProjectDetectLogFileOut

View File

@ -7,24 +7,43 @@
# @desc : pydantic 模型,用于数据库序列化操作
from pydantic import BaseModel, Field, ConfigDict
from core.data_types import DatetimeStr
from typing import Optional
from datetime import datetime
class ProjectDetect(BaseModel):
project_id: int = Field(..., title="None")
detect_name: str = Field(..., title="None")
detect_version: int = Field(..., title="None")
detect_no: str = Field(..., title="None")
detect_status: int = Field(..., title="None")
file_type: str = Field(..., title="None")
folder_url: str = Field(..., title="None")
rtsp_url: str = Field(..., title="None")
user_id: int = Field(..., title="None")
class ProjectDetectIn(BaseModel):
project_id: Optional[int] = Field(..., description="项目id")
file_type: Optional[str] = Field('img', description="推理集合文件类别")
detect_name: Optional[str] = Field(..., description="推理集合名称")
rtsp_url: Optional[str] = Field(None, description="视频流地址")
class ProjectDetectSimpleOut(ProjectDetect):
class ProjectDetectPager(BaseModel):
project_id: Optional[int] = Field(..., description="项目id")
detect_name: Optional[str] = Field(None, description="推理集合名称")
pagerNum: Optional[int] = Field(1, description="当前页码")
pagerSize: Optional[int] = Field(10, description="每页数量")
model_config = ConfigDict(from_attributes=True)
id: int = Field(..., title="编号")
create_datetime: DatetimeStr = Field(..., title="创建时间")
update_datetime: DatetimeStr = Field(..., title="更新时间")
class ProjectDetectOut(BaseModel):
id: Optional[int]
project_id: Optional[int]
detect_name: Optional[str]
detect_no: Optional[str]
detect_version: Optional[int]
file_type: Optional[str]
folder_url: Optional[str]
rtsp_url: Optional[str]
create_time: Optional[datetime]
model_config = ConfigDict(from_attributes=True)
class ProjectDetectList(BaseModel):
id: Optional[int]
file_type: Optional[str]
detect_name: Optional[str]
model_config = ConfigDict(from_attributes=True)

View File

@ -0,0 +1,26 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:30
# @File : project_detect_file.py
# @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional
from datetime import datetime
class ProjectDetectFilePager(BaseModel):
detect_id: Optional[int] = Field(..., description="训练集合id")
pagerNum: Optional[int] = Field(None, description="当前页码")
pagerSize: Optional[int] = Field(None, description="每页数量")
class ProjectDetectFileOut(BaseModel):
id: Optional[int] = Field(None, description="id")
detect_id: Optional[int] = Field(..., description="训练集合id")
file_name: Optional[str] = Field(None, description="文件名称")
thumb_file_url: Optional[str] = Field(None, description="文件路径")
create_time: Optional[datetime] = Field(None, description="上传时间")
model_config = ConfigDict(from_attributes=True)

View File

@ -1,26 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:30
# @File : project_detect_img.py
# @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作
from pydantic import BaseModel, Field, ConfigDict
from core.data_types import DatetimeStr
class ProjectDetectImg(BaseModel):
detect_id: int = Field(..., title="None")
file_name: str = Field(..., title="None")
image_url: str = Field(..., title="None")
thumb_image_url: str = Field(..., title="None")
user_id: int = Field(..., title="None")
class ProjectDetectImgSimpleOut(ProjectDetectImg):
model_config = ConfigDict(from_attributes=True)
id: int = Field(..., title="编号")
create_datetime: DatetimeStr = Field(..., title="创建时间")
update_datetime: DatetimeStr = Field(..., title="更新时间")

View File

@ -7,25 +7,24 @@
# @desc : pydantic 模型,用于数据库序列化操作
from pydantic import BaseModel, Field, ConfigDict
from core.data_types import DatetimeStr
from typing import Optional
from datetime import datetime
class ProjectDetectLog(BaseModel):
detect_id: int = Field(..., title="None")
detect_version: str = Field(..., title="None")
detect_name: str = Field(..., title="None")
train_id: int = Field(..., title="None")
train_version: str = Field(..., title="None")
pt_type: str = Field(..., title="None")
pt_url: str = Field(..., title="None")
folder_url: str = Field(..., title="None")
detect_folder_url: str = Field(..., title="None")
user_id: int = Field(..., title="None")
class ProjectDetectLogIn(BaseModel):
detect_id: Optional[int] = Field(..., description="推理集合id")
train_id: Optional[int] = Field(..., description="训练结果id")
pt_type: Optional[str] = Field('best', description="权重文件类型")
class ProjectDetectLogSimpleOut(ProjectDetectLog):
class ProjectDetectLogOut(BaseModel):
id: Optional[int]
detect_id: Optional[int]
detect_version: Optional[str]
detect_name: Optional[str]
train_id: Optional[int]
train_version: Optional[str]
pt_type: Optional[str]
create_time: Optional[datetime]
model_config = ConfigDict(from_attributes=True)
id: int = Field(..., title="编号")
create_datetime: DatetimeStr = Field(..., title="创建时间")
update_datetime: DatetimeStr = Field(..., title="更新时间")

View File

@ -0,0 +1,20 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:31
# @File : project_detect_log_file.py
# @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作
from pydantic import BaseModel, ConfigDict
from typing import Optional
from datetime import datetime
class ProjectDetectLogFileOut(BaseModel):
id: Optional[int]
file_name: Optional[str]
thumb_file_url: Optional[str]
create_time: Optional[datetime]
model_config = ConfigDict(from_attributes=True)

View File

@ -1,24 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:31
# @File : project_detect_log_img.py
# @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作
from pydantic import BaseModel, Field, ConfigDict
from core.data_types import DatetimeStr
class ProjectDetectLogImg(BaseModel):
log_id: int = Field(..., title="None")
file_name: str = Field(..., title="None")
image_url: str = Field(..., title="None")
class ProjectDetectLogImgSimpleOut(ProjectDetectLogImg):
model_config = ConfigDict(from_attributes=True)
id: int = Field(..., title="编号")
create_datetime: DatetimeStr = Field(..., title="创建时间")
update_datetime: DatetimeStr = Field(..., title="更新时间")

View File

@ -0,0 +1,226 @@
from application.settings import yolo_url, detect_url
from utils.websocket_server import room_manager
from utils import os_utils as os
from . import models, crud, schemas
from apps.business.train import models as train_models
from utils.yolov5.models.common import DetectMultiBackend
from utils.yolov5.utils.torch_utils import select_device
from utils.yolov5.utils.dataloaders import LoadStreams
from utils.yolov5.utils.general import check_img_size, Profile, non_max_suppression, cv2, scale_boxes
from ultralytics.utils.plotting import Annotator, colors
import time
import torch
import asyncio
import subprocess
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
async def before_detect(
detect_in: schemas.ProjectDetectLogIn,
detect: models.ProjectDetect,
train: train_models.ProjectTrain,
db: AsyncSession):
"""
开始推理
:param detect:
:param detect_in:
:param train:
:param db:
:return:
"""
# 推理版本
version_path = 'v' + str(detect.detect_version + 1)
# 权重文件
pt_url = train.best_pt if detect_in.pt_type == 'best' else train.last_pt
# 推理集合文件路径
img_url = detect.folder_url
out_url = os.file_path(detect_url, detect.detect_no, 'detect')
# 构建推理记录数据
detect_log = models.ProjectDetectLog()
detect_log.detect_name = detect.detect_name
detect_log.detect_id = detect.id
detect_log.detect_version = version_path
detect_log.train_id = train.id
detect_log.train_version = train.train_version
detect_log.pt_type = detect_in.pt_type
detect_log.pt_url = pt_url
detect_log.folder_url = img_url
detect_log.detect_folder_url = out_url
await crud.ProjectDetectLogDal(db).create_data(detect_log)
return detect_log
async def run_detect_img(
weights: str,
source: str,
project: str,
name: str,
log_id: int,
detect_id: int,
db: AsyncSession,
rd: Redis):
"""
执行yolov5的推理
:param weights: 权重文件
:param source: 图片所在文件
:param project: 推理完成的文件位置
:param name: 版本名称
:param log_id: 日志id
:param detect_id: 推理集合id
:param db: 数据库session
:param rd: Redis
:return:
"""
yolo_path = os.file_path(yolo_url, 'detect.py')
room = 'detect_' + str(detect_id)
await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project",
project, "--save-txt", "--conf-thres", "0.4"]
is_gpu = rd.get('is_gpu')
# 判断是否存在cuda版本
if is_gpu == 'True':
commend.append("--device", "0")
# 启动子进程
with subprocess.Popen(
commend,
bufsize=1, # bufsize=0时为不缓存bufsize=1时按行缓存bufsize为其他正整数时为按照近似该正整数的字节数缓存
shell=False,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # 这里可以显示yolov5训练过程中出现的进度条等信息
text=True, # 缓存内容为文本,避免后续编码显示问题
encoding='utf-8',
) as process:
while process.poll() is None:
line = process.stdout.readline()
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
if line != '\n':
await room_manager.send_to_room(room, line + '\n')
# 等待进程结束并获取返回码
return_code = process.wait()
if return_code != 0:
await room_manager.send_to_room(room, 'error')
else:
await room_manager.send_to_room(room, 'success')
detect_files = crud.ProjectDetectFileDal(db).get_data(
v_where=[models.ProjectDetectFile.detect_id == detect_id])
detect_log_files = []
for detect_file in detect_files:
detect_log_img = models.ProjectDetectLogFile()
detect_log_img.log_id = log_id
image_url = os.file_path(project, name, detect_file.file_name)
detect_log_img.image_url = image_url
detect_log_img.file_name = detect_file.file_name
detect_log_files.append(detect_log_img)
await crud.ProjectDetectLogFileDal(db).create_datas(detect_log_files)
async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: int, rd: Redis):
"""
rtsp 视频流推理
:param detect_id: 训练集的id
:param weights_pt: 权重文件
:param rtsp_url: 视频流地址
:param data: yaml文件
:param rd: Redis :redis
:return:
"""
room = 'detect_rtsp_' + str(detect_id)
# 选择设备CPU 或 GPU
device = select_device('cpu')
is_gpu = rd.get('is_gpu')
# 判断是否存在cuda版本
if is_gpu == 'True':
device = select_device('cuda:0')
# 加载模型
model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
stride, names, pt = model.stride, model.names, model.pt
imgsz = check_img_size((640, 640), s=stride) # check image size
dataset = LoadStreams(rtsp_url, img_size=imgsz, stride=stride, auto=pt, vid_stride=1)
bs = len(dataset)
model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))
seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
time.sleep(3) # 等待3s等待websocket进入
for path, im, im0s, vid_cap, s in dataset:
if room_manager.rooms.get(room):
with dt[0]:
im = torch.from_numpy(im).to(model.device)
im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
im /= 255 # 0 - 255 to 0.0 - 1.0
if len(im.shape) == 3:
im = im[None] # expand for batch dim
if model.xml and im.shape[0] > 1:
ims = torch.chunk(im, im.shape[0], 0)
# Inference
with dt[1]:
if model.xml and im.shape[0] > 1:
pred = None
for image in ims:
if pred is None:
pred = model(image, augment=False, visualize=False).unsqueeze(0)
else:
pred = torch.cat((pred, model(image, augment=False, visualize=False).unsqueeze(0)),
dim=0)
pred = [pred, None]
else:
pred = model(im, augment=False, visualize=False)
# NMS
with dt[2]:
pred = non_max_suppression(pred, 0.45, 0.45, None, False, max_det=1000)
# Process predictions
for i, det in enumerate(pred): # per image
p, im0, frame = path[i], im0s[i].copy(), dataset.count
annotator = Annotator(im0, line_width=3, example=str(names))
if len(det):
# Rescale boxes from img_size to im0 size
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
# Write results
for *xyxy, conf, cls in reversed(det):
c = int(cls) # integer class
label = None if False else (names[c] if False else f"{names[c]} {conf:.2f}")
annotator.box_label(xyxy, label, color=colors(c, True))
# Stream results
im0 = annotator.result()
# 将帧编码为 JPEG
ret, jpeg = cv2.imencode('.jpg', im0)
if ret:
frame_data = jpeg.tobytes()
await room_manager.send_stream_to_room(room, frame_data)
else:
print(room, '结束推理')
break
def run_img_loop(weights: str, source: str, project: str, name: str, log_id: int, detect_id: int, db: AsyncSession):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 运行异步函数
loop.run_until_complete(run_detect_img(weights, source, project, name, log_id, detect_id, db))
# 可选: 关闭循环
loop.close()
def run_rtsp_loop(weights_pt: str, rtsp_url: str, data: str, detect_id: int, rd: Redis):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 运行异步函数
loop.run_until_complete(run_detect_rtsp(weights_pt, rtsp_url, data, detect_id, rd))
# 可选: 关闭循环
loop.close()

View File

@ -5,15 +5,21 @@
# @File : views.py
# @IDE : PyCharm
# @desc : 路由,视图文件
from core.dependencies import IdList
from apps.vadmin.auth.utils.validation.auth import Auth
from sqlalchemy.ext.asyncio import AsyncSession
from apps.vadmin.auth.utils.current import AllUserAuth
from core.database import db_getter
from . import schemas, crud, models, params
from fastapi import Depends, APIRouter
from utils.response import SuccessResponse
import service
from . import schemas, crud, params
from core.dependencies import IdList
from core.database import redis_getter
from utils.websocket_server import room_manager
from apps.business.train.crud import ProjectTrainDal
from apps.vadmin.auth.utils.current import AllUserAuth
from apps.vadmin.auth.utils.validation.auth import Auth
from utils.response import SuccessResponse, ErrorResponse
import threading
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi import Depends, APIRouter, Form, UploadFile
app = APIRouter()
@ -22,129 +28,120 @@ app = APIRouter()
###########################################################
# 项目推理集合信息
###########################################################
@app.get("/project/detect", summary="获取项目推理集合信息列表", tags=["项目推理集合信息"])
async def get_project_detect_list(p: params.ProjectDetectParams = Depends(), auth: Auth = Depends(AllUserAuth())):
@app.get("/list", summary="获取项目推理集合信息列表")
async def detect_list(
p: params.ProjectDetectParams = Depends(),
auth: Auth = Depends(AllUserAuth())):
datas, count = await crud.ProjectDetectDal(auth.db).get_datas(**p.dict(), v_return_count=True)
return SuccessResponse(datas, count=count)
@app.post("/project/detect", summary="创建项目推理集合信息", tags=["项目推理集合信息"])
async def create_project_detect(data: schemas.ProjectDetect, auth: Auth = Depends(AllUserAuth())):
return SuccessResponse(await crud.ProjectDetectDal(auth.db).create_data(data=data))
@app.post("/", summary="创建项目推理集合信息")
async def add_detect(
data: schemas.ProjectDetectIn,
auth: Auth = Depends(AllUserAuth())):
detect_dal = crud.ProjectDetectDal(auth.db)
if await detect_dal.check_name(data.detect_name, data.project_id):
return ErrorResponse(msg="该项目中存在相同名称的集合")
await detect_dal.create_data(data=data)
return SuccessResponse(msg="保存成功")
@app.delete("/project/detect", summary="删除项目推理集合信息", description="硬删除", tags=["项目推理集合信息"])
async def delete_project_detect_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):
@app.delete("/", summary="删除项目推理集合信息")
async def delete_detect(
ids: IdList = Depends(),
auth: Auth = Depends(AllUserAuth())):
await crud.ProjectDetectDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
return SuccessResponse("删除成功")
@app.put("/project/detect/{data_id}", summary="更新项目推理集合信息", tags=["项目推理集合信息"])
async def put_project_detect(data_id: int, data: schemas.ProjectDetect, auth: Auth = Depends(AllUserAuth())):
return SuccessResponse(await crud.ProjectDetectDal(auth.db).put_data(data_id, data))
@app.get("/project/detect/{data_id}", summary="获取项目推理集合信息信息", tags=["项目推理集合信息"])
async def get_project_detect(data_id: int, db: AsyncSession = Depends(db_getter)):
schema = schemas.ProjectDetectSimpleOut
return SuccessResponse(await crud.ProjectDetectDal(db).get_data(data_id, v_schema=schema))
###########################################################
# 项目推理集合图片信息
# 项目推理集合文件信息
###########################################################
@app.get("/project/detect/img", summary="获取项目推理集合图片信息列表", tags=["项目推理集合图片信息"])
async def get_project_detect_img_list(p: params.ProjectDetectImgParams = Depends(), auth: Auth = Depends(AllUserAuth())):
datas, count = await crud.ProjectDetectImgDal(auth.db).get_datas(**p.dict(), v_return_count=True)
return SuccessResponse(datas, count=count)
@app.get("/file", summary="获取项目推理集合文件信息列表")
async def file_list(
p: params.ProjectDetectFileParams = Depends(),
auth: Auth = Depends(AllUserAuth())):
if p.limit:
datas, count = await crud.ProjectDetectFileDal(auth.db).get_datas(**p.dict(), v_return_count=True)
return SuccessResponse(datas, count=count)
else:
datas = await crud.ProjectDetectFileDal(auth.db).get_datas(**p.dict(), v_return_count=False)
return SuccessResponse(datas)
@app.post("/project/detect/img", summary="创建项目推理集合图片信息", tags=["项目推理集合图片信息"])
async def create_project_detect_img(data: schemas.ProjectDetectImg, auth: Auth = Depends(AllUserAuth())):
return SuccessResponse(await crud.ProjectDetectImgDal(auth.db).create_data(data=data))
@app.post("/file", summary="上传项目推理集合文件")
async def upload_file(
detect_id: int = Form(...),
files: list[UploadFile] = Form(...),
auth: Auth = Depends(AllUserAuth())):
file_dal = crud.ProjectDetectFileDal(auth.db)
detect_out = file_dal.get_data(data_id=detect_id)
if detect_out is None:
return ErrorResponse("训练集合查询失败,请刷新后再试")
await file_dal.add_file(detect_out, files)
return SuccessResponse(msg="上传成功")
@app.delete("/project/detect/img", summary="删除项目推理集合图片信息", description="硬删除", tags=["项目推理集合图片信息"])
async def delete_project_detect_img_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):
await crud.ProjectDetectImgDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
@app.delete("/file", summary="删除项目推理集合文件信息")
async def delete_file(
ids: IdList = Depends(),
auth: Auth = Depends(AllUserAuth())):
await crud.ProjectDetectFileDal(auth.db).delete_files(ids=ids.ids)
return SuccessResponse("删除成功")
@app.put("/project/detect/img/{data_id}", summary="更新项目推理集合图片信息", tags=["项目推理集合图片信息"])
async def put_project_detect_img(data_id: int, data: schemas.ProjectDetectImg, auth: Auth = Depends(AllUserAuth())):
return SuccessResponse(await crud.ProjectDetectImgDal(auth.db).put_data(data_id, data))
@app.get("/project/detect/img/{data_id}", summary="获取项目推理集合图片信息信息", tags=["项目推理集合图片信息"])
async def get_project_detect_img(data_id: int, db: AsyncSession = Depends(db_getter)):
schema = schemas.ProjectDetectImgSimpleOut
return SuccessResponse(await crud.ProjectDetectImgDal(db).get_data(data_id, v_schema=schema))
@app.post("/detect", summary="开始推理")
def run_detect_yolo(
detect_log_in: schemas.ProjectDetectLogIn,
auth: Auth = Depends(AllUserAuth()),
rd: Redis = Depends(redis_getter)):
detect_dal = crud.ProjectDetectDal(auth.db)
train_dal = ProjectTrainDal(auth.db)
detect = detect_dal.get_data(detect_log_in.detect_id)
if detect is None:
return ErrorResponse(msg="训练集合不存在")
train = train_dal.get_data(detect_log_in.train_id)
if train is None:
return ErrorResponse("训练权重不存在")
file_count = crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id)
if file_count == 0 and detect.rtsp_url is None:
return ErrorResponse("推理集合中没有内容,请先到推理集合中上传图片")
if detect.file_type == 'img' or detect.file_type == 'video':
detect_log = service.before_detect(detect_log_in, detect, train, auth.db)
thread_train = threading.Thread(target=service.run_img_loop,
args=(detect_log.pt_url, detect_log.folder_url,
detect_log.detect_folder_url, detect_log.detect_version,
detect_log.id, detect_log.detect_id, auth.db,))
thread_train.start()
elif detect.file_type == 'rtsp':
room = 'detect_rtsp_' + str(detect.id)
if not room_manager.rooms.get(room):
if detect_log_in.pt_type == 'best':
weights_pt = train.best_pt
else:
weights_pt = train.last_pt
thread_train = threading.Thread(target=service.run_rtsp_loop,
args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, rd,))
thread_train.start()
return SuccessResponse(msg="执行成功")
###########################################################
# 项目推理记录信息
###########################################################
@app.get("/project/detect/log", summary="获取项目推理记录信息列表", tags=["项目推理记录信息"])
async def get_project_detect_log_list(p: params.ProjectDetectLogParams = Depends(), auth: Auth = Depends(AllUserAuth())):
@app.get("/log", summary="获取项目推理记录列表")
async def log_pager(
p: params.ProjectDetectLogParams = Depends(),
auth: Auth = Depends(AllUserAuth())):
datas, count = await crud.ProjectDetectLogDal(auth.db).get_datas(**p.dict(), v_return_count=True)
return SuccessResponse(datas, count=count)
@app.post("/project/detect/log", summary="创建项目推理记录信息", tags=["项目推理记录信息"])
async def create_project_detect_log(data: schemas.ProjectDetectLog, auth: Auth = Depends(AllUserAuth())):
return SuccessResponse(await crud.ProjectDetectLogDal(auth.db).create_data(data=data))
@app.delete("/project/detect/log", summary="删除项目推理记录信息", description="硬删除", tags=["项目推理记录信息"])
async def delete_project_detect_log_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):
await crud.ProjectDetectLogDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
return SuccessResponse("删除成功")
@app.put("/project/detect/log/{data_id}", summary="更新项目推理记录信息", tags=["项目推理记录信息"])
async def put_project_detect_log(data_id: int, data: schemas.ProjectDetectLog, auth: Auth = Depends(AllUserAuth())):
return SuccessResponse(await crud.ProjectDetectLogDal(auth.db).put_data(data_id, data))
@app.get("/project/detect/log/{data_id}", summary="获取项目推理记录信息信息", tags=["项目推理记录信息"])
async def get_project_detect_log(data_id: int, db: AsyncSession = Depends(db_getter)):
schema = schemas.ProjectDetectLogSimpleOut
return SuccessResponse(await crud.ProjectDetectLogDal(db).get_data(data_id, v_schema=schema))
###########################################################
# 项目推理记录图片信息
###########################################################
@app.get("/project/detect/log/img", summary="获取项目推理记录图片信息列表", tags=["项目推理记录图片信息"])
async def get_project_detect_log_img_list(p: params.ProjectDetectLogImgParams = Depends(), auth: Auth = Depends(AllUserAuth())):
datas, count = await crud.ProjectDetectLogImgDal(auth.db).get_datas(**p.dict(), v_return_count=True)
return SuccessResponse(datas, count=count)
@app.post("/project/detect/log/img", summary="创建项目推理记录图片信息", tags=["项目推理记录图片信息"])
async def create_project_detect_log_img(data: schemas.ProjectDetectLogImg, auth: Auth = Depends(AllUserAuth())):
return SuccessResponse(await crud.ProjectDetectLogImgDal(auth.db).create_data(data=data))
@app.delete("/project/detect/log/img", summary="删除项目推理记录图片信息", description="硬删除", tags=["项目推理记录图片信息"])
async def delete_project_detect_log_img_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):
await crud.ProjectDetectLogImgDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
return SuccessResponse("删除成功")
@app.put("/project/detect/log/img/{data_id}", summary="更新项目推理记录图片信息", tags=["项目推理记录图片信息"])
async def put_project_detect_log_img(data_id: int, data: schemas.ProjectDetectLogImg, auth: Auth = Depends(AllUserAuth())):
return SuccessResponse(await crud.ProjectDetectLogImgDal(auth.db).put_data(data_id, data))
@app.get("/project/detect/log/img/{data_id}", summary="获取项目推理记录图片信息信息", tags=["项目推理记录图片信息"])
async def get_project_detect_log_img(data_id: int, db: AsyncSession = Depends(db_getter)):
schema = schemas.ProjectDetectLogImgSimpleOut
return SuccessResponse(await crud.ProjectDetectLogImgDal(db).get_data(data_id, v_schema=schema))
@app.get("/log_files", summary="获取项目推理记录文件列表")
async def log_files(
p: params.ProjectDetectLogFileParams = Depends(),
auth: Auth = Depends(AllUserAuth())):
datas = await crud.ProjectDetectLogFileDal(auth.db).get_datas(**p.dict(), v_return_count=False)
return SuccessResponse(datas)

View File

@ -10,7 +10,7 @@ from . import params, schemas, crud, models
from core.dependencies import IdList
from typing import List
from fastapi import APIRouter, Depends, UploadFile, File, Form
from fastapi import APIRouter, Depends, UploadFile, Form
from apps.vadmin.auth.utils.current import FullAdminAuth
from apps.vadmin.auth.utils.validation.auth import Auth
@ -124,7 +124,7 @@ async def project_pager(
@app.post("/img", summary="上传图片")
async def up_img(
project_id: int = Form(...),
files: List[UploadFile] = File(...),
files: List[UploadFile] = Form(...),
img_type: str = Form(...),
auth: Auth = Depends(FullAdminAuth())
):

View File

@ -1,14 +1,13 @@
from . import schemas, models, crud
from apps.business.project import schemas as proj_schemas, models as proj_models, crud as proj_crud
from utils import os_utils as os
from application.settings import *
from . import schemas, models, crud
from utils.websocket_server import room_manager
from apps.business.project import models as proj_models, crud as proj_crud
import yaml
import asyncio
import subprocess
from typing import List
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
@ -73,7 +72,7 @@ async def before_train(proj_info: proj_models.ProjectInfo, db: AsyncSession):
async def operate_img_label(
img_list: List[proj_models.ProjectImgLabel],
img_list: list[proj_models.ProjectImgLabel],
img_path: str,
label_path: str,
db: AsyncSession,