完成训练模块的转移
This commit is contained in:
@ -5,23 +5,19 @@
|
||||
# @File : crud.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : 数据访问层
|
||||
import application.settings
|
||||
from . import schemas, models, params
|
||||
from apps.vadmin.auth.utils.validation.auth import Auth
|
||||
from utils import os_utils as os, random_utils as ru
|
||||
from utils.huawei_obs import ObsClient
|
||||
from utils import status
|
||||
from core.exception import CustomException
|
||||
if application.settings.DEBUG:
|
||||
from application.config.development import datasets_url, runs_url, images_url
|
||||
else:
|
||||
from application.config.production import datasets_url, runs_url, images_url
|
||||
from application.settings import datasets_url, runs_url, images_url
|
||||
|
||||
from typing import Any, List
|
||||
from core.crud import DalBase
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, case, and_
|
||||
from sqlalchemy import select, func, case
|
||||
|
||||
|
||||
class ProjectInfoDal(DalBase):
|
||||
@ -96,6 +92,9 @@ class ProjectInfoDal(DalBase):
|
||||
project: schemas.ProjectInfoIn,
|
||||
auth: Auth
|
||||
) -> Any:
|
||||
"""
|
||||
新建项目
|
||||
"""
|
||||
obj = self.model(**project.model_dump())
|
||||
obj.user_id = auth.user.id
|
||||
obj.project_no = ru.random_str(6)
|
||||
@ -106,7 +105,9 @@ class ProjectInfoDal(DalBase):
|
||||
obj.dept_id = 0
|
||||
else:
|
||||
obj.dept_id = auth.dept_ids[0]
|
||||
# 新建数据集文件夹
|
||||
os.create_folder(datasets_url, obj.project_no)
|
||||
# 新建训练文件夹
|
||||
os.create_folder(runs_url, obj.project_no)
|
||||
await self.flush(obj)
|
||||
return await self.out_dict(obj, None, False, schemas.ProjectInfoOut)
|
||||
@ -214,6 +215,55 @@ class ProjectImageDal(DalBase):
|
||||
ObsClient.del_objects(object_keys)
|
||||
await self.delete_datas(ids)
|
||||
|
||||
async def get_img_count(
|
||||
self,
|
||||
proj_id: int) -> int:
|
||||
"""
|
||||
查询图片数量
|
||||
"""
|
||||
train_count = await self.get_count(
|
||||
v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'train'])
|
||||
val_count = await self.get_count(
|
||||
v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'val'])
|
||||
return train_count, val_count
|
||||
|
||||
async def check_image_label(
|
||||
self,
|
||||
proj_id: int) -> int:
|
||||
"""
|
||||
查询图片未标注数量
|
||||
"""
|
||||
# 1 子查询
|
||||
subquery = (
|
||||
select(
|
||||
models.ProjectImgLabel.image_id,
|
||||
func.ifnull(func.count(models.ProjectImgLabel.id), 0).label('label_count')
|
||||
)
|
||||
.group_by(models.ProjectImgLabel.image_id)
|
||||
.subquery()
|
||||
)
|
||||
# 2 主查询
|
||||
query = (
|
||||
select(
|
||||
models.ProjectImage,
|
||||
func.ifnull(subquery.c.label_count, 0).label('label_count')
|
||||
)
|
||||
.outerjoin(subquery, models.ProjectImage.id == subquery.c.image_id)
|
||||
)
|
||||
train_count_sql = await self.filter_core(
|
||||
v_start_sql=query,
|
||||
v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'train'],
|
||||
v_return_sql=True)
|
||||
train_count = await self.get_count(train_count_sql)
|
||||
|
||||
val_count_sql = await self.filter_core(
|
||||
v_start_sql=query,
|
||||
v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'val'],
|
||||
v_return_sql=True)
|
||||
val_count = await self.get_count(val_count_sql)
|
||||
|
||||
return train_count, val_count
|
||||
|
||||
|
||||
class ProjectLabelDal(DalBase):
|
||||
"""
|
||||
@ -233,14 +283,27 @@ class ProjectLabelDal(DalBase):
|
||||
label_id: int = None
|
||||
):
|
||||
wheres = [
|
||||
models.ProjectLabel.project_id == pro_id,
|
||||
models.ProjectLabel.label_name == name
|
||||
self.model.project_id == pro_id,
|
||||
self.model.label_name == name
|
||||
]
|
||||
if label_id:
|
||||
wheres.append(models.ProjectLabel.id != label_id)
|
||||
wheres.append(self.model.id != label_id)
|
||||
count = await self.get_count(v_where=wheres)
|
||||
return count > 0
|
||||
|
||||
async def get_label_for_train(self, project_id: int):
|
||||
id_list = []
|
||||
name_list = []
|
||||
label_list = self.get_datas(
|
||||
v_where=[self.model.project_id == project_id],
|
||||
v_order='asc',
|
||||
v_order_field='id',
|
||||
v_return_count=False)
|
||||
for label in label_list:
|
||||
id_list.append(label.id)
|
||||
name_list.append(label.label_name)
|
||||
return id_list, name_list
|
||||
|
||||
|
||||
class ProjectImgLabelDal(DalBase):
|
||||
"""
|
||||
@ -260,6 +323,13 @@ class ProjectImgLabelDal(DalBase):
|
||||
img.image_id = image_id
|
||||
await self.create_datas(img_labels)
|
||||
|
||||
async def get_img_label_list(self, image_id: int):
|
||||
return await self.get_datas(
|
||||
v_return_count=False,
|
||||
v_where=[self.model.image_id == image_id],
|
||||
v_order="asc",
|
||||
v_order_field="id")
|
||||
|
||||
|
||||
class ProjectImgLeaferDal(DalBase):
|
||||
"""
|
||||
|
@ -15,14 +15,10 @@ class ProjectInfoParams(QueryParams):
|
||||
self,
|
||||
project_name: str | None = Query(None, title="项目名称"),
|
||||
type_code: str | None = Query(None, title="项目类别"),
|
||||
dept_id: str | None = Query(None, title="部门id"),
|
||||
user_id: str | None = Query(None, title="用户id"),
|
||||
params: Paging = Depends()
|
||||
):
|
||||
super().__init__(params)
|
||||
self.project_name = ("like", project_name)
|
||||
self.type_code = type_code
|
||||
self.dept_id = dept_id
|
||||
self.user_id = user_id
|
||||
|
||||
|
||||
|
@ -9,6 +9,8 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from core.crud import DalBase
|
||||
from . import models, schemas
|
||||
from utils import os_utils as os
|
||||
from utils.csv_utils import read_csv
|
||||
|
||||
|
||||
class ProjectTrainDal(DalBase):
|
||||
@ -17,4 +19,75 @@ class ProjectTrainDal(DalBase):
|
||||
super(ProjectTrainDal, self).__init__()
|
||||
self.db = db
|
||||
self.model = models.ProjectTrain
|
||||
self.schema = schemas.ProjectTrainSimpleOut
|
||||
self.schema = schemas.ProjectTrainOut
|
||||
|
||||
async def get_result(self, train_id: int):
|
||||
"""
|
||||
查询训练报告
|
||||
"""
|
||||
data = await self.get_data(data_id=train_id)
|
||||
if data is None:
|
||||
return None
|
||||
result_csv_path = os.file_path(data.train_url, 'results.csv')
|
||||
result_row = read_csv(result_csv_path)
|
||||
report_data = {}
|
||||
# 轮数
|
||||
epoch_data = []
|
||||
# 边界框回归损失(Bounding Box Loss),衡量预测框位置(中心坐标、宽高)与真实框的差异,值越低表示定位越准。
|
||||
train_box_loss = []
|
||||
# 目标置信度损失(Objectness Loss),衡量检测到目标的置信度误差(即是否包含物体),值越低表示模型越能正确判断有无物体。
|
||||
train_obj_loss = []
|
||||
# 分类损失(Classification Loss),衡量预测类别与真实类别的差异,值越低表示分类越准。
|
||||
train_cls_loss = []
|
||||
|
||||
# 验证集的边界框回归损失,反映模型在未见数据上的定位能力。
|
||||
val_box_loss = []
|
||||
# 验证集的目标置信度损失,反映模型在未见数据上判断物体存在的能力。
|
||||
val_obj_loss = []
|
||||
# 验证集的分类损失,反映模型在未见数据上的分类准确性。
|
||||
val_cls_loss = []
|
||||
|
||||
# 精确率(Precision):正确检测的正样本占所有预测为正样本的比例,反映“误检率”。值越高说明误检越少。
|
||||
m_p = []
|
||||
# 召回率(Recall):正确检测的正样本占所有真实正样本的比例,反映“漏检率”。值越高说明漏检越少。
|
||||
m_r = []
|
||||
|
||||
# 主干网络(Backbone)的学习率。
|
||||
x_lr0 = []
|
||||
# 检测头(Head)的学习率。
|
||||
x_lr1 = []
|
||||
|
||||
for row in result_row:
|
||||
epoch_data.append(row[0].strip())
|
||||
|
||||
train_box_loss.append(row[1].strip())
|
||||
train_obj_loss.append(row[2].strip())
|
||||
train_cls_loss.append(row[3].strip())
|
||||
|
||||
val_box_loss.append(row[8].strip())
|
||||
val_obj_loss.append(row[9].strip())
|
||||
val_cls_loss.append(row[10].strip())
|
||||
|
||||
m_p.append(row[4].strip())
|
||||
m_r.append(row[5].strip())
|
||||
|
||||
x_lr0.append(row[11].strip())
|
||||
x_lr1.append(row[12].strip())
|
||||
|
||||
report_data['epoch_data'] = epoch_data
|
||||
|
||||
report_data['train_box_loss'] = train_box_loss
|
||||
report_data['train_obj_loss'] = train_obj_loss
|
||||
report_data['train_cls_loss'] = train_cls_loss
|
||||
|
||||
report_data['val_box_loss'] = val_box_loss
|
||||
report_data['val_obj_loss'] = val_obj_loss
|
||||
report_data['val_cls_loss'] = val_cls_loss
|
||||
|
||||
report_data['m_p'] = m_p
|
||||
report_data['m_r'] = m_r
|
||||
|
||||
report_data['x_lr0'] = x_lr0
|
||||
report_data['x_lr1'] = x_lr1
|
||||
|
||||
return report_data
|
||||
|
@ -0,0 +1 @@
|
||||
from .train import ProjectTrain
|
@ -6,7 +6,7 @@ from db.db_base import BaseModel
|
||||
|
||||
class ProjectTrain(BaseModel):
|
||||
"""
|
||||
项目训练版本信息表
|
||||
项目训练信息表
|
||||
"""
|
||||
__tablename__ = "project_train"
|
||||
__table_args__ = ({'comment': '项目训练版本信息表'})
|
||||
|
@ -6,10 +6,15 @@
|
||||
# @IDE : PyCharm
|
||||
# @desc : 项目巡逻片信息
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import Depends, Query
|
||||
from core.dependencies import Paging, QueryParams
|
||||
|
||||
|
||||
class ProjectTrainParams(QueryParams):
|
||||
def __init__(self, params: Paging = Depends()):
|
||||
def __init__(
|
||||
self,
|
||||
project_id: int | 0 = Query(0, title="项目id"),
|
||||
params: Paging = Depends()
|
||||
):
|
||||
super().__init__(params)
|
||||
self.project_id = project_id
|
||||
|
@ -1 +1 @@
|
||||
from .project_train import ProjectTrain, ProjectTrainSimpleOut
|
||||
from .project_train import ProjectTrainIn, ProjectTrainOut
|
||||
|
@ -6,27 +6,29 @@
|
||||
# @IDE : PyCharm
|
||||
# @desc : pydantic 模型,用于数据库序列化操作
|
||||
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from core.data_types import DatetimeStr
|
||||
from typing import Optional
|
||||
|
||||
"""
|
||||
项目训练版本信息表
|
||||
"""
|
||||
|
||||
|
||||
class ProjectTrain(BaseModel):
|
||||
project_id: int = Field(..., title="None")
|
||||
train_version: str = Field(..., title="None")
|
||||
train_url: str = Field(..., title="None")
|
||||
train_data: str = Field(..., title="None")
|
||||
weights_id: int = Field(..., title="None")
|
||||
weights_name: str = Field(..., title="None")
|
||||
epochs: int = Field(..., title="None")
|
||||
patience: int = Field(..., title="None")
|
||||
best_pt: str = Field(..., title="None")
|
||||
last_pt: str = Field(..., title="None")
|
||||
user_id: int = Field(..., title="None")
|
||||
class ProjectTrainIn(BaseModel):
|
||||
project_id: Optional[int] = Field(..., description="项目id")
|
||||
weights_id: Optional[str] = Field(None, description="权重文件")
|
||||
epochs: Optional[int] = Field(50, description="训练轮数")
|
||||
patience: Optional[int] = Field(20, description="早停的耐心值")
|
||||
|
||||
|
||||
class ProjectTrainSimpleOut(ProjectTrain):
|
||||
class ProjectTrainOut(BaseModel):
|
||||
id: Optional[int] = Field(None, description="训练id")
|
||||
train_version: Optional[str] = Field(None, description="训练版本号")
|
||||
weights_name: Optional[str] = Field(None, description="权重名称")
|
||||
epochs: Optional[int] = Field(None, description="训练轮数")
|
||||
patience: Optional[int] = Field(None, description="早停的耐心值")
|
||||
create_time: Optional[datetime] = Field(None, description="训练时间")
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int = Field(..., title="编号")
|
||||
create_datetime: DatetimeStr = Field(..., title="创建时间")
|
||||
update_datetime: DatetimeStr = Field(..., title="更新时间")
|
||||
|
200
apps/business/train/service.py
Normal file
200
apps/business/train/service.py
Normal file
@ -0,0 +1,200 @@
|
||||
from . import schemas, models, crud
|
||||
from apps.business.project import schemas as proj_schemas, models as proj_models, crud as proj_crud
|
||||
from utils import os_utils as os
|
||||
from application.settings import *
|
||||
from utils.websocket_server import room_manager
|
||||
|
||||
|
||||
import yaml
|
||||
import asyncio
|
||||
import subprocess
|
||||
from typing import List
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
async def before_train(proj_info: proj_models.ProjectInfo, db: AsyncSession):
|
||||
"""
|
||||
yolov5执行训练任务
|
||||
:param proj_info: 项目信息
|
||||
:param db: 数据库session
|
||||
:return:
|
||||
"""
|
||||
proj_dal = proj_crud.ProjectInfoDal(db)
|
||||
img_dal = proj_crud.ProjectImageDal(db)
|
||||
label_dal = proj_crud.ProjectLabelDal(db)
|
||||
# 先查询两个图片列表
|
||||
project_images_train = img_dal.get_data(
|
||||
v_where=[proj_models.ProjectImage.project_id == proj_info.id, proj_models.ProjectImage.img_type == 'train'])
|
||||
project_images_val = img_dal.get_data(
|
||||
v_where=[proj_models.ProjectImage.project_id == proj_info.id, proj_models.ProjectImage.img_type == 'val'])
|
||||
|
||||
# 得到训练版本
|
||||
version_path = 'v' + str(proj_info.train_version + 1)
|
||||
|
||||
# 创建训练的根目录
|
||||
train_path = os.create_folder(datasets_url, proj_info.project_no, version_path)
|
||||
|
||||
# 查询项目所属标签,返回两个 id,name一一对应的数组
|
||||
label_id_list, label_name_list = label_dal.get_label_for_train(proj_info.id)
|
||||
|
||||
# 创建图片的的两个文件夹
|
||||
img_path_train = os.create_folder(train_path, 'images', 'train')
|
||||
img_path_val = os.create_folder(train_path, 'images', 'val')
|
||||
|
||||
# 创建标签的两个文件夹
|
||||
label_path_train = os.create_folder(train_path, 'labels', 'train')
|
||||
label_path_val = os.create_folder(train_path, 'labels', 'val')
|
||||
|
||||
# 在根目录下创建yaml文件
|
||||
yaml_file = os.file_path(train_path, proj_info.project_no + '.yaml')
|
||||
yaml_data = {
|
||||
'path': train_path,
|
||||
'train': 'images/train',
|
||||
'val': 'images/val',
|
||||
'test': None,
|
||||
'names': {i: name for i, name in enumerate(label_name_list)}
|
||||
}
|
||||
with open(yaml_file, 'w', encoding='utf-8') as file:
|
||||
yaml.dump(yaml_data, file, allow_unicode=True, default_flow_style=False)
|
||||
|
||||
# 开始循环复制图片和生成label.txt
|
||||
# 先操作train
|
||||
operate_img_label(project_images_train, img_path_train, label_path_train, db, label_id_list)
|
||||
# 再操作val
|
||||
operate_img_label(project_images_val, img_path_val, label_path_val, db, label_id_list)
|
||||
|
||||
# 开始执行异步训练
|
||||
data = yaml_file
|
||||
project = os.file_path(runs_url, proj_info.project_no)
|
||||
name = version_path
|
||||
|
||||
return data, project, name
|
||||
|
||||
|
||||
async def operate_img_label(
|
||||
img_list: List[proj_models.ProjectImgLabel],
|
||||
img_path: str,
|
||||
label_path: str,
|
||||
db: AsyncSession,
|
||||
label_id_list: []):
|
||||
"""
|
||||
生成图片和标签内容
|
||||
:param label_id_list:
|
||||
:param db: 数据库session
|
||||
:param img_list:
|
||||
:param img_path:
|
||||
:param label_path:
|
||||
:return:
|
||||
"""
|
||||
for i in range(len(img_list)):
|
||||
image = img_list[i]
|
||||
# 先复制图片,并把图片改名,不改后缀
|
||||
file_name = 'image' + str(i)
|
||||
os.copy_and_rename_file(image.image_url, img_path, file_name)
|
||||
# 查询这张图片的label信息然后生成这张照片的txt文件
|
||||
img_label_list = await proj_crud.ProjectImgLabelDal(db).get_img_label_list(image.id)
|
||||
label_txt_path = os.file_path(label_path, file_name + '.txt')
|
||||
with open(label_txt_path, 'w', encoding='utf-8') as file:
|
||||
for image_label in img_label_list:
|
||||
index = label_id_list.index(image_label.label_id)
|
||||
file.write(str(index) + ' ' + image_label.mark_center_x + ' '
|
||||
+ image_label.mark_center_y + ' '
|
||||
+ image_label.mark_width + ' '
|
||||
+ image_label.mark_height + '\n')
|
||||
|
||||
|
||||
async def run_event_loop(
|
||||
data: str,
|
||||
project: str,
|
||||
name: str,
|
||||
train_in: schemas.ProjectTrainIn,
|
||||
project_id: int,
|
||||
db: AsyncSession):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
# 运行异步函数
|
||||
loop.run_until_complete(run_commend(data, project, name, train_in.epochs, train_in.patience, train_in.weights_id,
|
||||
project_id, db))
|
||||
# 可选: 关闭循环
|
||||
loop.close()
|
||||
|
||||
|
||||
async def run_commend(
|
||||
data: str,
|
||||
project: str,
|
||||
name: str,
|
||||
epochs: int,
|
||||
patience: int,
|
||||
weights: str,
|
||||
project_id: int,
|
||||
db: AsyncSession,
|
||||
rd: Redis):
|
||||
"""
|
||||
执行训练
|
||||
:param data: 训练数据集
|
||||
:param project: 训练结果的项目目录
|
||||
:param name: 实验名称
|
||||
:param epochs: 训练轮数
|
||||
:param patience: 早停耐心值
|
||||
:param weights: 权重文件
|
||||
:param project_id: 项目id
|
||||
:param db: 数据库session
|
||||
:param rd: redis连接
|
||||
:return:
|
||||
"""
|
||||
yolo_path = os.file_path(yolo_url, 'train.py')
|
||||
room = 'train_' + str(project_id)
|
||||
await room_manager.send_to_room(room, f"AiCheckV2.0: 模型训练开始,请稍等。。。\n")
|
||||
commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name,
|
||||
"--epochs=" + str(epochs), "--batch-size=8", "--exist-ok", "--patience=" + str(patience)]
|
||||
|
||||
# 增加权重文件,在之前训练的基础上重新巡逻
|
||||
if weights != '' and weights is not None:
|
||||
train_info = await crud.ProjectTrainDal(db).get_data(data_id=int(weights))
|
||||
if train_info is not None:
|
||||
commend.append("--weights=" + train_info.best_pt)
|
||||
|
||||
is_gpu = rd.get('is_gpu')
|
||||
# 判断是否存在cuda版本
|
||||
if is_gpu == 'True':
|
||||
commend.append("--device=0")
|
||||
# 启动子进程
|
||||
with subprocess.Popen(
|
||||
commend,
|
||||
bufsize=1, # bufsize=0时,为不缓存;bufsize=1时,按行缓存;bufsize为其他正整数时,为按照近似该正整数的字节数缓存
|
||||
shell=False,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT, # 这里可以显示yolov5训练过程中出现的进度条等信息
|
||||
text=True, # 缓存内容为文本,避免后续编码显示问题
|
||||
encoding='utf-8',
|
||||
) as process:
|
||||
while process.poll() is None:
|
||||
line = process.stdout.readline()
|
||||
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
|
||||
if line != '\n' and '0%' not in line:
|
||||
await room_manager.send_to_room(room, line + '\n')
|
||||
|
||||
# 等待进程结束并获取返回码
|
||||
return_code = process.wait()
|
||||
if return_code != 0:
|
||||
await room_manager.send_to_room(room, 'error')
|
||||
else:
|
||||
await room_manager.send_to_room(room, 'success')
|
||||
# 然后保存版本训练信息
|
||||
train = models.ProjectTrain()
|
||||
train.project_id = project_id
|
||||
train.train_version = name
|
||||
train_url = os.file_path(project, name)
|
||||
train.train_url = train_url
|
||||
train.train_data = data
|
||||
bast_pt_path = os.file_path(train_url, 'weights', 'best.pt')
|
||||
last_pt_path = os.file_path(train_url, 'weights', 'last.pt')
|
||||
train.best_pt = bast_pt_path
|
||||
train.last_pt = last_pt_path
|
||||
if weights is not None and weights != '':
|
||||
train.weights_id = weights
|
||||
train.weights_name = train_info.train_version
|
||||
train.patience = patience
|
||||
train.epochs = epochs
|
||||
await crud.ProjectTrainDal(db).create_data(data=train)
|
@ -3,49 +3,70 @@
|
||||
# @version : 1.0
|
||||
# @Create Time : 2025/04/03 10:32
|
||||
# @File : views.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : 路由,视图文件
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from fastapi import APIRouter, Depends
|
||||
from . import models, schemas, crud, params
|
||||
from core.dependencies import IdList
|
||||
from . import models, schemas, crud
|
||||
from apps.business.project.crud import ProjectInfoDal, ProjectImageDal
|
||||
from utils.response import SuccessResponse, ErrorResponse
|
||||
from apps.vadmin.auth.utils.current import AllUserAuth
|
||||
from utils.response import SuccessResponse
|
||||
from apps.vadmin.auth.utils.validation.auth import Auth
|
||||
from core.database import db_getter
|
||||
import service
|
||||
|
||||
import threading
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
|
||||
app = APIRouter()
|
||||
|
||||
|
||||
###########################################################
|
||||
# 项目巡逻片信息
|
||||
# 项目训练信息
|
||||
###########################################################
|
||||
@app.get("/project/train", summary="获取项目巡逻片信息列表", tags=["项目巡逻片信息"])
|
||||
async def get_project_train_list(p: params.ProjectTrainParams = Depends(), auth: Auth = Depends(AllUserAuth())):
|
||||
datas, count = await crud.ProjectTrainDal(auth.db).get_datas(**p.dict(), v_return_count=True)
|
||||
return SuccessResponse(datas, count=count)
|
||||
@app.post("/", summary="执行训练")
|
||||
async def run_train(
|
||||
train_in: schemas.ProjectTrainIn,
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
proj_id = train_in.project_id
|
||||
proj_dal = ProjectInfoDal(auth.db)
|
||||
proj_img_dal = ProjectImageDal(auth.db)
|
||||
proj_info = await proj_dal.get_data(proj_id)
|
||||
if proj_info is None:
|
||||
return ErrorResponse(msg="项目信息查询错误")
|
||||
train_count, val_count = await proj_img_dal.get_img_count(proj_id)
|
||||
if train_count == 0:
|
||||
return ErrorResponse("请先上传训练图片")
|
||||
if train_count < 10:
|
||||
return ErrorResponse("训练图片少于10张,请继续上传训练图片")
|
||||
if val_count == 0:
|
||||
return ErrorResponse("请先上传验证图片")
|
||||
if val_count < 5:
|
||||
return ErrorResponse("验证图片少于5张,请继续上传验证图片")
|
||||
train_label_count, val_label_count = await proj_img_dal.check_image_label(proj_id)
|
||||
if train_label_count > 0:
|
||||
return ErrorResponse("训练图片中存在未标注的图片")
|
||||
if val_label_count > 0:
|
||||
return ErrorResponse("验证图片中存在未标注的图片")
|
||||
data, project, name = service.before_train(proj_info, auth.db)
|
||||
# 异步执行操作,操作过程通过websocket进行同步
|
||||
thread_train = threading.Thread(
|
||||
target=service.run_event_loop,
|
||||
args=(data, project, name, train_in, proj_id, auth.db,))
|
||||
thread_train.start()
|
||||
return SuccessResponse(msg="执行成功")
|
||||
|
||||
|
||||
@app.post("/project/train", summary="创建项目巡逻片信息", tags=["项目巡逻片信息"])
|
||||
async def create_project_train(data: schemas.ProjectTrain, auth: Auth = Depends(AllUserAuth())):
|
||||
return SuccessResponse(await crud.ProjectTrainDal(auth.db).create_data(data=data))
|
||||
@app.get("/{proj_id}", summary="查询训练列表")
|
||||
async def train_list(
|
||||
proj_id: int,
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
datas = await crud.ProjectTrainDal(auth.db).get_datas(
|
||||
v_where=[models.ProjectTrain.project_id == proj_id],
|
||||
v_schema=schemas.ProjectTrainOut,
|
||||
v_order="asc",
|
||||
v_order_field="id",v_return_count=False)
|
||||
return SuccessResponse(data=datas)
|
||||
|
||||
|
||||
@app.delete("/project/train", summary="删除项目巡逻片信息", description="硬删除", tags=["项目巡逻片信息"])
|
||||
async def delete_project_train_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):
|
||||
await crud.ProjectTrainDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
|
||||
return SuccessResponse("删除成功")
|
||||
|
||||
|
||||
@app.put("/project/train/{data_id}", summary="更新项目巡逻片信息", tags=["项目巡逻片信息"])
|
||||
async def put_project_train(data_id: int, data: schemas.ProjectTrain, auth: Auth = Depends(AllUserAuth())):
|
||||
return SuccessResponse(await crud.ProjectTrainDal(auth.db).put_data(data_id, data))
|
||||
|
||||
|
||||
@app.get("/project/train/{data_id}", summary="获取项目巡逻片信息信息", tags=["项目巡逻片信息"])
|
||||
async def get_project_train(data_id: int, db: AsyncSession = Depends(db_getter)):
|
||||
schema = schemas.ProjectTrainSimpleOut
|
||||
return SuccessResponse(await crud.ProjectTrainDal(db).get_data(data_id, v_schema=schema))
|
||||
@app.get("/result/{proj_id}", summary="查询训练报告")
|
||||
async def get_result(train_id:int, auth: Auth = Depends(AllUserAuth())):
|
||||
result = await crud.ProjectTrainDal(auth.db).get_result(train_id)
|
||||
return SuccessResponse(data=result)
|
||||
|
||||
|
Reference in New Issue
Block a user