补全内容
This commit is contained in:
@ -159,4 +159,7 @@ def run_train(project_id: int, session: Session = Depends(get_db)):
|
||||
return rc.response_error("项目查询错误")
|
||||
if project_info.project_status == '1':
|
||||
return rc.response_error("项目当前存在训练进程,请稍后再试")
|
||||
return StreamingResponse(ps.run_train_yolo(project_info, session), media_type="text/plain")
|
||||
data, project, name, epochs, yolo_path, version_path = ps.run_train_yolo(project_info, session)
|
||||
return StreamingResponse(
|
||||
ps.run_commend(data, project, name, epochs, yolo_path, version_path, project_id, session),
|
||||
media_type="text/plain")
|
||||
|
@ -51,14 +51,12 @@ def update_project_status(project_id: int, project_status: str, session: Session
|
||||
:return:
|
||||
"""
|
||||
if project_status == '2':
|
||||
stmt = update(ProjectInfo).where(ProjectInfo.id == project_id).values({
|
||||
'train_status': project_status,
|
||||
session.query(ProjectInfo).filter_by(id=project_id).update({
|
||||
'project_status': project_status,
|
||||
'train_version': ProjectInfo.train_version + 1
|
||||
})
|
||||
session.execute(stmt)
|
||||
else:
|
||||
stmt = update(ProjectInfo).where(ProjectInfo.id == project_id).values({
|
||||
'train_status': project_status
|
||||
session.query(ProjectInfo).filter_by(id=project_id).update({
|
||||
'project_status': project_status
|
||||
})
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
@ -146,7 +146,8 @@ def run_train_yolo(project_info: ProjectInfoOut, session: Session):
|
||||
operate_img_label(project_images_val, img_path_val, label_path_val, session, label_id_list)
|
||||
|
||||
# 打包完成开始训练,训练前,更改项目的训练状态
|
||||
pic.update_project_status(project_info.id, '1')
|
||||
pic.update_project_status(project_info.id, '1', session)
|
||||
|
||||
# 开始训练
|
||||
data = yaml_file
|
||||
project = os.file_path(runs_url, project_info.project_no, 'train')
|
||||
@ -154,9 +155,20 @@ def run_train_yolo(project_info: ProjectInfoOut, session: Session):
|
||||
epochs = 10
|
||||
yolo_path = os.file_path(yolo_url, 'train.py')
|
||||
|
||||
return data, project, name, epochs, yolo_path, version_path
|
||||
|
||||
|
||||
def run_commend(data: str, project: str,
|
||||
name: str, epochs: int,
|
||||
yolo_path: str, version_path: str,
|
||||
project_id: int, session: Session):
|
||||
# 启动子进程
|
||||
with subprocess.Popen(
|
||||
["python", yolo_path, "--data=" + data, "--project=" + project, "--name=" + name, "--epochs=" + str(epochs)],
|
||||
["python", yolo_path,
|
||||
"--data=" + data,
|
||||
"--project=" + project,
|
||||
"--name=" + name,
|
||||
"--epochs=" + str(epochs)],
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
@ -173,12 +185,12 @@ def run_train_yolo(project_info: ProjectInfoOut, session: Session):
|
||||
# 等待进程结束并获取返回码
|
||||
return_code = process.wait()
|
||||
if return_code != 0:
|
||||
pic.update_project_status(project_info.id, '-1', session)
|
||||
pic.update_project_status(project_id, '-1', session)
|
||||
else:
|
||||
pic.update_project_status(project_info.id, '2', session)
|
||||
pic.update_project_status(project_id, '2', session)
|
||||
# 然后保存版本训练信息
|
||||
train = ProjectTrain()
|
||||
train.project_id = project_info.id
|
||||
train.project_id = project_id
|
||||
train.train_version = version_path
|
||||
bast_pt_path = os.file_path(project, name, 'weight', 'bast.pt')
|
||||
last_pt_path = os.file_path(project, name, 'weight', 'last.pt')
|
||||
|
Reference in New Issue
Block a user