查询训练报告接口

This commit is contained in:
2025-03-11 11:42:09 +08:00
parent 7d736c4ac4
commit f49a6caf10
10 changed files with 143 additions and 39 deletions

View File

@ -246,19 +246,17 @@ async def run_train(train_in: ProjectTrainIn, session: Session = Depends(get_db)
data, project, name = ps.run_train_yolo(project_info, train_in, session)
thread_train = threading.Thread(target=run_event_loop, args=(data, project, name, train_in,
project_id, session,))
thread_train.start();
thread_train.start()
return rc.response_success(msg="执行成功")
def run_event_loop(data: str, project: str,
name: str, train_in: ProjectTrainIn,
project_id: int, session: Session):
def run_event_loop(data: str, project: str, name: str, train_in: ProjectTrainIn,
project_id: int, session: Session):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 运行异步函数
loop.run_until_complete(ps.run_commend(data, project, name, train_in.epochs,
train_in.patience, train_in.weights_id,
project_id, session))
train_in.patience, train_in.weights_id, project_id, session))
# 可选: 关闭循环
loop.close()
@ -276,4 +274,15 @@ def get_train_list(project_id: int, session: Session = Depends(get_db)):
return rc.response_success(data=result)
@project.get("/get_train_report/{train_id}")
def get_train_report(train_id: int, session: Session = Depends(get_db)):
"""
查询训练报告
:param train_id:
:param session:
:return:
"""
result_row = ps.get_train_result(train_id, session)
if result_row is None:
return rc.response_error("查询失败")
return rc.response_success(data=result_row)