主题功能完成
This commit is contained in:
@ -153,13 +153,19 @@ def get_img_leafer(image_id: int, session: Session = Depends(get_db)):
|
||||
|
||||
|
||||
@project.get("/run_train/{project_id}")
|
||||
def run_train(project_id: int, session: Session = Depends(get_db)):
|
||||
async def run_train(project_id: int, session: Session = Depends(get_db)):
|
||||
"""
|
||||
执行项目训练方法
|
||||
:param project_id:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
project_info = pic.get_project_by_id(project_id, session)
|
||||
if project_info is None:
|
||||
return rc.response_error("项目查询错误")
|
||||
if project_info.project_status == '1':
|
||||
return rc.response_error("项目当前存在训练进程,请稍后再试")
|
||||
data, project, name, epochs, yolo_path, version_path = ps.run_train_yolo(project_info, session)
|
||||
data, project_name, name = ps.run_train_yolo(project_info, session)
|
||||
return StreamingResponse(
|
||||
ps.run_commend(data, project, name, epochs, yolo_path, version_path, project_id, session),
|
||||
ps.run_commend(data, project_name, name, 10, project_id, session),
|
||||
media_type="text/plain")
|
||||
|
@ -1,42 +0,0 @@
|
||||
import asyncio
|
||||
import subprocess
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
test = APIRouter()
|
||||
|
||||
|
||||
async def generate_data():
|
||||
for i in range(1, 10): # 生成 5 行数据
|
||||
await asyncio.sleep(1) # 等待 1 秒
|
||||
yield f"data: This is line {i}\n\n" # 返回 SSE 格式的数据
|
||||
|
||||
|
||||
def run_command(command):
|
||||
"""执行命令并实时打印每一行输出"""
|
||||
# 启动子进程
|
||||
with subprocess.Popen(
|
||||
command,
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True, # 确保输出以字符串形式返回而不是字节
|
||||
bufsize=1, # 行缓冲
|
||||
) as process:
|
||||
# 使用iter逐行读取stdout和stderr
|
||||
for line in process.stdout:
|
||||
yield f"stdout: {line.strip()} \n"
|
||||
|
||||
for line in process.stderr:
|
||||
yield f"stderr: {line.strip()} \n"
|
||||
|
||||
# 等待进程结束并获取返回码
|
||||
return_code = process.wait()
|
||||
if return_code != 0:
|
||||
print(f"Process exited with non-zero code: {return_code}")
|
||||
|
||||
|
||||
@test.get("/stream")
|
||||
async def stream_response():
|
||||
return StreamingResponse(run_command(["ping", "-n", "10", "127.0.0.1"]), media_type="text/plain")
|
@ -9,7 +9,6 @@ from app.api.sys.login_api import login
|
||||
from app.api.sys.sys_user_api import user
|
||||
from app.api.business.project_api import project
|
||||
from app.api.common.view_img import view
|
||||
from app.api.common.test_api import test
|
||||
|
||||
my_app = FastAPI()
|
||||
|
||||
@ -36,5 +35,4 @@ my_app.include_router(upload_files, prefix="/upload", tags=["文件上传API"])
|
||||
my_app.include_router(view, prefix="/view_img", tags=["查看图片"])
|
||||
my_app.include_router(user, prefix="/user", tags=["用户管理API"])
|
||||
my_app.include_router(project, prefix="/proj", tags=["项目管理API"])
|
||||
my_app.include_router(test, prefix="/test", tags=["测试用API"])
|
||||
|
||||
|
@ -14,6 +14,7 @@ from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from fastapi import UploadFile
|
||||
import yaml
|
||||
import select
|
||||
import subprocess
|
||||
|
||||
|
||||
@ -113,12 +114,6 @@ def run_train_yolo(project_info: ProjectInfoOut, session: Session):
|
||||
# 查询项目所属标签,返回两个 id,name一一对应的数组
|
||||
label_id_list, label_name_list = plc.get_label_for_train(project_info.id, session)
|
||||
|
||||
# 在根目录创建classes.txt文件
|
||||
classes_txt = os.file_path(train_path, 'classes.txt')
|
||||
with open(classes_txt, 'w', encoding='utf-8') as file:
|
||||
for label_name in label_name_list:
|
||||
file.write(label_name + '\n')
|
||||
|
||||
# 创建图片的的两个文件夹
|
||||
img_path_train = os.create_folder(train_path, 'images', 'train')
|
||||
img_path_val = os.create_folder(train_path, 'images', 'val')
|
||||
@ -152,35 +147,34 @@ def run_train_yolo(project_info: ProjectInfoOut, session: Session):
|
||||
data = yaml_file
|
||||
project = os.file_path(runs_url, project_info.project_no, 'train')
|
||||
name = version_path
|
||||
epochs = 10
|
||||
yolo_path = os.file_path(yolo_url, 'train.py')
|
||||
|
||||
return data, project, name, epochs, yolo_path, version_path
|
||||
return data, project, name
|
||||
|
||||
|
||||
def run_commend(data: str, project: str,
|
||||
name: str, epochs: int,
|
||||
yolo_path: str, version_path: str,
|
||||
project_id: int, session: Session):
|
||||
yolo_path = os.file_path(yolo_url, 'train.py')
|
||||
yield f"stdout: 模型训练开始,请稍等。。。"
|
||||
# 启动子进程
|
||||
with subprocess.Popen(
|
||||
["python", yolo_path,
|
||||
["python", '-u', yolo_path,
|
||||
"--data=" + data,
|
||||
"--project=" + project,
|
||||
"--name=" + name,
|
||||
"--epochs=" + str(epochs)],
|
||||
shell=True,
|
||||
bufsize=1, # bufsize=0时,为不缓存;bufsize=1时,按行缓存;bufsize为其他正整数时,为按照近似该正整数的字节数缓存
|
||||
shell=False,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True, # 确保输出以字符串形式返回而不是字节
|
||||
bufsize=1, # 行缓冲
|
||||
stderr=subprocess.STDOUT, # 这里可以显示yolov5训练过程中出现的进度条等信息
|
||||
text=True, # 缓存内容为文本,避免后续编码显示问题
|
||||
encoding='utf-8',
|
||||
) as process:
|
||||
# 使用iter逐行读取stdout和stderr
|
||||
for line in process.stdout:
|
||||
yield f"stdout: {line.strip()} \n"
|
||||
|
||||
for line in process.stderr:
|
||||
yield f"stderr: {line.strip()} \n"
|
||||
while process.poll() is None:
|
||||
line = process.stdout.readline()
|
||||
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
|
||||
if line != '\n':
|
||||
yield line
|
||||
|
||||
# 等待进程结束并获取返回码
|
||||
return_code = process.wait()
|
||||
@ -191,7 +185,7 @@ def run_commend(data: str, project: str,
|
||||
# 然后保存版本训练信息
|
||||
train = ProjectTrain()
|
||||
train.project_id = project_id
|
||||
train.train_version = version_path
|
||||
train.train_version = name
|
||||
bast_pt_path = os.file_path(project, name, 'weight', 'bast.pt')
|
||||
last_pt_path = os.file_path(project, name, 'weight', 'last.pt')
|
||||
train.best_pt = bast_pt_path
|
||||
|
Reference in New Issue
Block a user