完成项目训练模块的接口测试
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
from utils import os_utils as os
|
||||
from utils import os_utils as osu
|
||||
from application.settings import *
|
||||
from . import schemas, models, crud
|
||||
from utils.websocket_server import room_manager
|
||||
@ -19,34 +19,39 @@ async def before_train(proj_info: proj_models.ProjectInfo, db: AsyncSession):
|
||||
:param db: 数据库session
|
||||
:return:
|
||||
"""
|
||||
proj_dal = proj_crud.ProjectInfoDal(db)
|
||||
img_dal = proj_crud.ProjectImageDal(db)
|
||||
label_dal = proj_crud.ProjectLabelDal(db)
|
||||
# 先查询两个图片列表
|
||||
project_images_train = img_dal.get_data(
|
||||
v_where=[proj_models.ProjectImage.project_id == proj_info.id, proj_models.ProjectImage.img_type == 'train'])
|
||||
project_images_val = img_dal.get_data(
|
||||
v_where=[proj_models.ProjectImage.project_id == proj_info.id, proj_models.ProjectImage.img_type == 'val'])
|
||||
project_images_train = await img_dal.get_datas(
|
||||
v_where=[proj_models.ProjectImage.project_id == proj_info.id, proj_models.ProjectImage.img_type == 'train'],
|
||||
limit=0,
|
||||
v_return_count=False,
|
||||
v_return_objs=True)
|
||||
project_images_val = await img_dal.get_datas(
|
||||
v_where=[proj_models.ProjectImage.project_id == proj_info.id, proj_models.ProjectImage.img_type == 'val'],
|
||||
limit=0,
|
||||
v_return_count=False,
|
||||
v_return_objs=True)
|
||||
|
||||
# 得到训练版本
|
||||
version_path = 'v' + str(proj_info.train_version + 1)
|
||||
|
||||
# 创建训练的根目录
|
||||
train_path = os.create_folder(datasets_url, proj_info.project_no, version_path)
|
||||
train_path = osu.create_folder(datasets_url, proj_info.project_no, version_path)
|
||||
|
||||
# 查询项目所属标签,返回两个 id,name一一对应的数组
|
||||
label_id_list, label_name_list = label_dal.get_label_for_train(proj_info.id)
|
||||
label_id_list, label_name_list = await label_dal.get_label_for_train(proj_info.id)
|
||||
|
||||
# 创建图片的的两个文件夹
|
||||
img_path_train = os.create_folder(train_path, 'images', 'train')
|
||||
img_path_val = os.create_folder(train_path, 'images', 'val')
|
||||
img_path_train = osu.create_folder(train_path, 'images', 'train')
|
||||
img_path_val = osu.create_folder(train_path, 'images', 'val')
|
||||
|
||||
# 创建标签的两个文件夹
|
||||
label_path_train = os.create_folder(train_path, 'labels', 'train')
|
||||
label_path_val = os.create_folder(train_path, 'labels', 'val')
|
||||
label_path_train = osu.create_folder(train_path, 'labels', 'train')
|
||||
label_path_val = osu.create_folder(train_path, 'labels', 'val')
|
||||
|
||||
# 在根目录下创建yaml文件
|
||||
yaml_file = os.file_path(train_path, proj_info.project_no + '.yaml')
|
||||
yaml_file = osu.file_path(train_path, proj_info.project_no + '.yaml')
|
||||
yaml_data = {
|
||||
'path': train_path,
|
||||
'train': 'images/train',
|
||||
@ -59,20 +64,20 @@ async def before_train(proj_info: proj_models.ProjectInfo, db: AsyncSession):
|
||||
|
||||
# 开始循环复制图片和生成label.txt
|
||||
# 先操作train
|
||||
operate_img_label(project_images_train, img_path_train, label_path_train, db, label_id_list)
|
||||
await operate_img_label(project_images_train, img_path_train, label_path_train, db, label_id_list)
|
||||
# 再操作val
|
||||
operate_img_label(project_images_val, img_path_val, label_path_val, db, label_id_list)
|
||||
await operate_img_label(project_images_val, img_path_val, label_path_val, db, label_id_list)
|
||||
|
||||
# 开始执行异步训练
|
||||
data = yaml_file
|
||||
project = os.file_path(runs_url, proj_info.project_no)
|
||||
project = osu.file_path(runs_url, proj_info.project_no)
|
||||
name = version_path
|
||||
|
||||
return data, project, name
|
||||
|
||||
|
||||
async def operate_img_label(
|
||||
img_list: list[proj_models.ProjectImgLabel],
|
||||
img_list: list[proj_models.ProjectImage],
|
||||
img_path: str,
|
||||
label_path: str,
|
||||
db: AsyncSession,
|
||||
@ -90,10 +95,10 @@ async def operate_img_label(
|
||||
image = img_list[i]
|
||||
# 先复制图片,并把图片改名,不改后缀
|
||||
file_name = 'image' + str(i)
|
||||
os.copy_and_rename_file(image.image_url, img_path, file_name)
|
||||
osu.copy_and_rename_file(image.image_url, img_path, file_name)
|
||||
# 查询这张图片的label信息然后生成这张照片的txt文件
|
||||
img_label_list = await proj_crud.ProjectImgLabelDal(db).get_img_label_list(image.id)
|
||||
label_txt_path = os.file_path(label_path, file_name + '.txt')
|
||||
label_txt_path = osu.file_path(label_path, file_name + '.txt')
|
||||
with open(label_txt_path, 'w', encoding='utf-8') as file:
|
||||
for image_label in img_label_list:
|
||||
index = label_id_list.index(image_label.label_id)
|
||||
@ -103,20 +108,19 @@ async def operate_img_label(
|
||||
+ image_label.mark_height + '\n')
|
||||
|
||||
|
||||
async def run_event_loop(
|
||||
def run_event_loop(
|
||||
data: str,
|
||||
project: str,
|
||||
name: str,
|
||||
train_in: schemas.ProjectTrainIn,
|
||||
project_id: int,
|
||||
db: AsyncSession):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
# 运行异步函数
|
||||
loop.run_until_complete(run_commend(data, project, name, train_in.epochs, train_in.patience, train_in.weights_id,
|
||||
project_id, db))
|
||||
# 可选: 关闭循环
|
||||
loop.close()
|
||||
train_info: models.ProjectTrain,
|
||||
is_gup: str):
|
||||
# 运行异步函数,开始训练
|
||||
loop_run = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop_run)
|
||||
loop_run.run_until_complete(run_commend(data, project, name, train_in.epochs, train_in.patience,
|
||||
project_id, train_info, is_gup))
|
||||
|
||||
|
||||
async def run_commend(
|
||||
@ -125,10 +129,9 @@ async def run_commend(
|
||||
name: str,
|
||||
epochs: int,
|
||||
patience: int,
|
||||
weights: str,
|
||||
project_id: int,
|
||||
db: AsyncSession,
|
||||
rd: Redis):
|
||||
train_info: models.ProjectTrain,
|
||||
is_gpu: str):
|
||||
"""
|
||||
执行训练
|
||||
:param data: 训练数据集
|
||||
@ -138,23 +141,20 @@ async def run_commend(
|
||||
:param patience: 早停耐心值
|
||||
:param weights: 权重文件
|
||||
:param project_id: 项目id
|
||||
:param db: 数据库session
|
||||
:param rd: redis连接
|
||||
:param train_info: 训练信息
|
||||
:param is_gpu: 是否是gpu环境
|
||||
:return:
|
||||
"""
|
||||
yolo_path = os.file_path(yolo_url, 'train.py')
|
||||
yolo_path = osu.file_path(yolo_url, 'train.py')
|
||||
room = 'train_' + str(project_id)
|
||||
await room_manager.send_to_room(room, f"AiCheckV2.0: 模型训练开始,请稍等。。。\n")
|
||||
commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name,
|
||||
"--epochs=" + str(epochs), "--batch-size=8", "--exist-ok", "--patience=" + str(patience)]
|
||||
|
||||
# 增加权重文件,在之前训练的基础上重新巡逻
|
||||
if weights != '' and weights is not None:
|
||||
train_info = await crud.ProjectTrainDal(db).get_data(data_id=int(weights))
|
||||
if train_info is not None:
|
||||
commend.append("--weights=" + train_info.best_pt)
|
||||
# 增加权重文件,在之前训练的基础上重新训练
|
||||
if train_info is not None:
|
||||
commend.append("--weights=" + train_info.best_pt)
|
||||
|
||||
is_gpu = rd.get('is_gpu')
|
||||
# 判断是否存在cuda版本
|
||||
if is_gpu == 'True':
|
||||
commend.append("--device=0")
|
||||
@ -180,20 +180,33 @@ async def run_commend(
|
||||
await room_manager.send_to_room(room, 'error')
|
||||
else:
|
||||
await room_manager.send_to_room(room, 'success')
|
||||
# 然后保存版本训练信息
|
||||
train = models.ProjectTrain()
|
||||
train.project_id = project_id
|
||||
train.train_version = name
|
||||
train_url = os.file_path(project, name)
|
||||
train.train_url = train_url
|
||||
train.train_data = data
|
||||
bast_pt_path = os.file_path(train_url, 'weights', 'best.pt')
|
||||
last_pt_path = os.file_path(train_url, 'weights', 'last.pt')
|
||||
train.best_pt = bast_pt_path
|
||||
train.last_pt = last_pt_path
|
||||
if weights is not None and weights != '':
|
||||
train.weights_id = weights
|
||||
train.weights_name = train_info.train_version
|
||||
train.patience = patience
|
||||
train.epochs = epochs
|
||||
await crud.ProjectTrainDal(db).create_data(data=train)
|
||||
|
||||
|
||||
async def add_train(
|
||||
db,
|
||||
project_id,
|
||||
name,
|
||||
project,
|
||||
data,
|
||||
train_in,
|
||||
user_id):
|
||||
# 更新版本信息
|
||||
await proj_crud.ProjectInfoDal(db).update_version(data_id=project_id)
|
||||
# 增加训练版本信息
|
||||
train = models.ProjectTrain()
|
||||
train.project_id = project_id
|
||||
train.train_version = name
|
||||
train_url = osu.file_path(project, name)
|
||||
train.train_url = train_url
|
||||
train.train_data = data
|
||||
train.user_id = user_id
|
||||
bast_pt_path = osu.file_path(train_url, 'weights', 'best.pt')
|
||||
last_pt_path = osu.file_path(train_url, 'weights', 'last.pt')
|
||||
train.best_pt = bast_pt_path
|
||||
train.last_pt = last_pt_path
|
||||
if train_in is not None:
|
||||
train.weights_id = train_in.weights_id
|
||||
train.weights_name = train_in.weights_name
|
||||
train.patience = train_in.patience
|
||||
train.epochs = train_in.epochs
|
||||
await crud.ProjectTrainDal(db).create_model(data=train)
|
Reference in New Issue
Block a user