提高版本到yolov11

This commit is contained in:
2025-06-09 15:27:45 +08:00
parent 86c6669593
commit 9e99b08d13
223 changed files with 333 additions and 43191 deletions

View File

@ -18,7 +18,7 @@ from typing import Optional
class ProjectTrainIn(BaseModel):
project_id: Optional[int] = Field(..., description="项目id")
weights_id: Optional[str] = Field(None, description="权重文件")
weights_id: Optional[int] = Field(None, description="权重文件")
weights_name: Optional[str] = Field(None, description="权重文件名称")
epochs: Optional[int] = Field(50, description="训练轮数")
patience: Optional[int] = Field(20, description="早停的耐心值")

View File

@ -1,13 +1,12 @@
from algo import YoloModel
from utils import os_utils as osu
from application.settings import *
from . import schemas, models, crud
from utils.websocket_server import room_manager
from apps.business.project import models as proj_models, crud as proj_crud
import yaml
import asyncio
import subprocess
from sqlalchemy.ext.asyncio import AsyncSession
@ -118,19 +117,16 @@ def run_event_loop(
# 运行异步函数,开始训练
loop_run = asyncio.new_event_loop()
asyncio.set_event_loop(loop_run)
loop_run.run_until_complete(run_commend(data, project, name, train_in.epochs, train_in.patience,
project_id, train_info, is_gup))
loop_run.run_until_complete(run_commend(data, project, name, train_in.epochs, train_in.patience, train_info))
async def run_commend(
def run_commend(
data: str,
project: str,
name: str,
epochs: int,
patience: int,
project_id: int,
train_info: models.ProjectTrain,
is_gpu: str):
train_info: models.ProjectTrain):
"""
执行训练
:param data: 训练数据集
@ -138,47 +134,14 @@ async def run_commend(
:param name: 实验名称
:param epochs: 训练轮数
:param patience: 早停耐心值
:param weights: 权重文件
:param project_id: 项目id
:param train_info: 训练信息
:param is_gpu: 是否是gpu环境
:return:
"""
yolo_path = osu.file_path(yolo_url, 'train.py')
room = 'train_' + str(project_id)
await room_manager.send_to_room(room, f"AiCheckV2.0: 模型训练开始,请稍等。。。\n")
commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name,
"--epochs=" + str(epochs), "--batch-size=8", "--exist-ok", "--patience=" + str(patience)]
# 增加权重文件,在之前训练的基础上重新训练
if train_info is not None:
commend.append("--weights=" + train_info.best_pt)
# 判断是否存在cuda版本
if is_gpu == 'True':
commend.append("--device=0")
# 启动子进程
with subprocess.Popen(
commend,
bufsize=1, # bufsize=0时为不缓存bufsize=1时按行缓存bufsize为其他正整数时为按照近似该正整数的字节数缓存
shell=False,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # 这里可以显示yolov5训练过程中出现的进度条等信息
text=True, # 缓存内容为文本,避免后续编码显示问题
encoding='utf-8',
) as process:
while process.poll() is None:
line = process.stdout.readline()
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
if line != '\n' and '0%' not in line and 'yolo' not in line:
await room_manager.send_to_room(room, line + '\n')
# 等待进程结束并获取返回码
return_code = process.wait()
if return_code != 0:
await room_manager.send_to_room(room, 'error')
else:
await room_manager.send_to_room(room, 'success')
if train_info is None:
model = YoloModel()
else:
model = YoloModel(train_info.best_pt)
model.train(data=data, epochs=epochs, project=project, name=name, patience=patience)
async def add_train(

View File

@ -25,8 +25,7 @@ app = APIRouter()
@app.post("/start", summary="执行训练")
async def run_train(
train_in: schemas.ProjectTrainIn,
auth: Auth = Depends(AllUserAuth()),
rd: Redis = Depends(redis_getter)):
auth: Auth = Depends(AllUserAuth())):
proj_id = train_in.project_id
proj_dal = ProjectInfoDal(auth.db)
proj_img_dal = ProjectImageDal(auth.db)
@ -48,14 +47,14 @@ async def run_train(
if val_label_count > 0:
return ErrorResponse("验证图片中存在未标注的图片")
data, project, name = await service.before_train(proj_info, auth.db)
is_gpu = await rd.get('is_gpu')
train_info = None
if train_in.weights_id is not None:
train_info = await crud.ProjectTrainDal(auth.db).get_data(train_in.weights_id)
# 异步执行操作操作过程通过websocket进行同步
thread_train = threading.Thread(
target=service.run_event_loop,
args=(data, project, name, train_in, proj_id, train_info, is_gpu))
target=service.run_commend,
args=(data, project, name, train_in.epochs, train_in.patience, train_info)
)
thread_train.start()
await service.add_train(auth.db, proj_id, name, project, data, train_in, auth.user.id)
return SuccessResponse(msg="执行成功")