提高版本到yolov11
This commit is contained in:
@ -25,8 +25,7 @@ app = APIRouter()
|
||||
@app.post("/start", summary="执行训练")
|
||||
async def run_train(
|
||||
train_in: schemas.ProjectTrainIn,
|
||||
auth: Auth = Depends(AllUserAuth()),
|
||||
rd: Redis = Depends(redis_getter)):
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
proj_id = train_in.project_id
|
||||
proj_dal = ProjectInfoDal(auth.db)
|
||||
proj_img_dal = ProjectImageDal(auth.db)
|
||||
@ -48,14 +47,14 @@ async def run_train(
|
||||
if val_label_count > 0:
|
||||
return ErrorResponse("验证图片中存在未标注的图片")
|
||||
data, project, name = await service.before_train(proj_info, auth.db)
|
||||
is_gpu = await rd.get('is_gpu')
|
||||
train_info = None
|
||||
if train_in.weights_id is not None:
|
||||
train_info = await crud.ProjectTrainDal(auth.db).get_data(train_in.weights_id)
|
||||
# 异步执行操作,操作过程通过websocket进行同步
|
||||
thread_train = threading.Thread(
|
||||
target=service.run_event_loop,
|
||||
args=(data, project, name, train_in, proj_id, train_info, is_gpu))
|
||||
target=service.run_commend,
|
||||
args=(data, project, name, train_in.epochs, train_in.patience, train_info)
|
||||
)
|
||||
thread_train.start()
|
||||
await service.add_train(auth.db, proj_id, name, project, data, train_in, auth.user.id)
|
||||
return SuccessResponse(msg="执行成功")
|
||||
|
Reference in New Issue
Block a user