主题功能完成

This commit is contained in:
2025-02-24 11:05:17 +08:00
parent e56e03f545
commit 47a9164083
43 changed files with 10352 additions and 69 deletions

View File

@ -14,6 +14,7 @@ from sqlalchemy.orm import Session
from typing import List
from fastapi import UploadFile
import yaml
import select
import subprocess
@ -113,12 +114,6 @@ def run_train_yolo(project_info: ProjectInfoOut, session: Session):
# 查询项目所属标签,返回两个 idname一一对应的数组
label_id_list, label_name_list = plc.get_label_for_train(project_info.id, session)
# 在根目录创建classes.txt文件
classes_txt = os.file_path(train_path, 'classes.txt')
with open(classes_txt, 'w', encoding='utf-8') as file:
for label_name in label_name_list:
file.write(label_name + '\n')
# 创建图片的的两个文件夹
img_path_train = os.create_folder(train_path, 'images', 'train')
img_path_val = os.create_folder(train_path, 'images', 'val')
@ -152,35 +147,34 @@ def run_train_yolo(project_info: ProjectInfoOut, session: Session):
data = yaml_file
project = os.file_path(runs_url, project_info.project_no, 'train')
name = version_path
epochs = 10
yolo_path = os.file_path(yolo_url, 'train.py')
return data, project, name, epochs, yolo_path, version_path
return data, project, name
def run_commend(data: str, project: str,
name: str, epochs: int,
yolo_path: str, version_path: str,
project_id: int, session: Session):
yolo_path = os.file_path(yolo_url, 'train.py')
yield f"stdout: 模型训练开始,请稍等。。。"
# 启动子进程
with subprocess.Popen(
["python", yolo_path,
["python", '-u', yolo_path,
"--data=" + data,
"--project=" + project,
"--name=" + name,
"--epochs=" + str(epochs)],
shell=True,
bufsize=1, # bufsize=0时为不缓存bufsize=1时按行缓存bufsize为其他正整数时为按照近似该正整数的字节数缓存
shell=False,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True, # 确保输出以字符串形式返回而不是字节
bufsize=1, # 行缓冲
stderr=subprocess.STDOUT, # 这里可以显示yolov5训练过程中出现的进度条等信息
text=True, # 缓存内容为文本,避免后续编码显示问题
encoding='utf-8',
) as process:
# 使用iter逐行读取stdout和stderr
for line in process.stdout:
yield f"stdout: {line.strip()} \n"
for line in process.stderr:
yield f"stderr: {line.strip()} \n"
while process.poll() is None:
line = process.stdout.readline()
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
if line != '\n':
yield line
# 等待进程结束并获取返回码
return_code = process.wait()
@ -191,7 +185,7 @@ def run_commend(data: str, project: str,
# 然后保存版本训练信息
train = ProjectTrain()
train.project_id = project_id
train.train_version = version_path
train.train_version = name
bast_pt_path = os.file_path(project, name, 'weight', 'bast.pt')
last_pt_path = os.file_path(project, name, 'weight', 'last.pt')
train.best_pt = bast_pt_path