完成以线程方式启动训练,并通过 WebSocket 推送 shell 输出
This commit is contained in:
@ -3,18 +3,20 @@ from app.model.crud import project_label_crud as plc
|
|||||||
from app.model.crud import project_info_crud as pic
|
from app.model.crud import project_info_crud as pic
|
||||||
from app.model.crud import project_image_crud as pimc
|
from app.model.crud import project_image_crud as pimc
|
||||||
from app.model.crud import project_train_crud as ptnc
|
from app.model.crud import project_train_crud as ptnc
|
||||||
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
|
|
||||||
from app.model.schemas.project_label_schemas import ProjectLabel
|
|
||||||
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImagePager
|
|
||||||
from app.model.bussiness_model import ProjectLabel as pl
|
from app.model.bussiness_model import ProjectLabel as pl
|
||||||
|
from app.model.schemas.project_label_schemas import ProjectLabel
|
||||||
|
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
|
||||||
|
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImagePager
|
||||||
|
from app.model.schemas.project_train_schemas import ProjectTrainIn
|
||||||
from app.common.jwt_check import get_user_id
|
from app.common.jwt_check import get_user_id
|
||||||
from app.common import reponse_code as rc
|
from app.common import reponse_code as rc
|
||||||
from app.service import project_train_service as ps
|
from app.service import project_train_service as ps
|
||||||
from app.db.db_session import get_db
|
from app.db.db_session import get_db
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import asyncio
|
||||||
from typing import List
|
from typing import List
|
||||||
from fastapi import APIRouter, Depends, Request, UploadFile, File, Form
|
from fastapi import APIRouter, Depends, Request, UploadFile, File, Form
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@ -182,7 +184,7 @@ def get_image_list(image: ProjectImagePager, session: Session = Depends(get_db))
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if image.pagerNum is None and image.pagerSize is None:
|
if image.pagerNum is None and image.pagerSize is None:
|
||||||
image_list = pimc.get_image_list(image.project_id, session)
|
image_list = pimc.get_image_list(image.project_id, image.img_type, session)
|
||||||
result = jsonable_encoder(image_list)
|
result = jsonable_encoder(image_list)
|
||||||
return rc.response_success(data=result)
|
return rc.response_success(data=result)
|
||||||
else:
|
else:
|
||||||
@ -217,14 +219,15 @@ def get_img_leafer(image_id: int, session: Session = Depends(get_db)):
|
|||||||
return rc.response_success(data=img_leafer_out['leafer'])
|
return rc.response_success(data=img_leafer_out['leafer'])
|
||||||
|
|
||||||
|
|
||||||
@project.get("/run_train/{project_id}")
|
@project.post("/run_train")
|
||||||
async def run_train(project_id: int, session: Session = Depends(get_db)):
|
async def run_train(train_in: ProjectTrainIn, session: Session = Depends(get_db)):
|
||||||
"""
|
"""
|
||||||
执行项目训练方法
|
执行项目训练方法
|
||||||
:param project_id:
|
:param train_in:
|
||||||
:param session:
|
:param session:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
project_id = train_in.project_id
|
||||||
project_info = pic.get_project_by_id(project_id, session)
|
project_info = pic.get_project_by_id(project_id, session)
|
||||||
if project_info is None:
|
if project_info is None:
|
||||||
return rc.response_error("项目查询错误")
|
return rc.response_error("项目查询错误")
|
||||||
@ -240,10 +243,24 @@ async def run_train(project_id: int, session: Session = Depends(get_db)):
|
|||||||
return rc.response_error("请先上传验证图片")
|
return rc.response_error("请先上传验证图片")
|
||||||
if val_img_count < 5:
|
if val_img_count < 5:
|
||||||
return rc.response_error("验证图片少于5张,请继续上传验证图片")
|
return rc.response_error("验证图片少于5张,请继续上传验证图片")
|
||||||
data, project_name, name = ps.run_train_yolo(project_info, session)
|
data, project, name = ps.run_train_yolo(project_info, train_in, session)
|
||||||
return StreamingResponse(
|
thread_train = threading.Thread(target=run_event_loop, args=(data, project, name, train_in,
|
||||||
ps.run_commend(data, project_name, name, 50, project_id, session),
|
project_id, session,))
|
||||||
media_type="text/plain")
|
thread_train.start();
|
||||||
|
return rc.response_success(msg="执行成功")
|
||||||
|
|
||||||
|
|
||||||
|
def run_event_loop(data: str, project: str,
                   name: str, train_in: ProjectTrainIn,
                   project_id: int, session: Session):
    """Run the async training coroutine to completion on a private event loop.

    Used as the ``target`` of the ``threading.Thread`` started by the
    ``run_train`` endpoint: the calling thread gets its own event loop,
    drives ``ps.run_commend`` until it finishes, and then closes the loop.

    :param data: training dataset (forwarded unchanged to ``run_commend``)
    :param project: project directory for the training results
    :param name: experiment name
    :param train_in: training parameters; ``epochs``, ``patience`` and
        ``weights_id`` are read from it
    :param project_id: id of the project being trained
    :param session: database session forwarded to ``run_commend``
    :return: None
    """
    # NOTE(review): ``session`` is the one injected into the request handler;
    # confirm it is still open and safe to use from this worker thread.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        # Block this thread until the training coroutine has finished.
        loop.run_until_complete(ps.run_commend(data, project, name, train_in.epochs,
                                               train_in.patience, train_in.weights_id,
                                               project_id, session))
    finally:
        # Close the loop even when training raises, so it is never leaked.
        loop.close()
|
||||||
|
|
||||||
|
|
||||||
@project.get("/get_train_list/{project_id}")
|
@project.get("/get_train_list/{project_id}")
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from fastapi import APIRouter, WebSocket
|
from fastapi import APIRouter, WebSocket
|
||||||
|
from starlette.websockets import WebSocketState
|
||||||
|
|
||||||
from app.websocket.web_socket_server import room_manager
|
from app.websocket.web_socket_server import room_manager
|
||||||
|
|
||||||
@ -19,9 +20,11 @@ async def websocket_room(websocket: WebSocket, room: str):
|
|||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
data = await websocket.receive_text()
|
data = await websocket.receive_text()
|
||||||
await room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
|
room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"连接关闭: {e}")
|
if websocket.client_state != WebSocketState.DISCONNECTED:
|
||||||
|
await websocket.close(code=1000)
|
||||||
finally:
|
finally:
|
||||||
await room_manager.remove_from_room(room, websocket)
|
await room_manager.remove_from_room(room, websocket)
|
||||||
await websocket.close()
|
if websocket.client_state != WebSocketState.DISCONNECTED:
|
||||||
|
await websocket.close(code=1001)
|
||||||
|
@ -44,17 +44,17 @@ my_app.include_router(web_socket, prefix="/ws", tags=["websocket管理"])
|
|||||||
|
|
||||||
|
|
||||||
# fastapi定时执行任务
|
# fastapi定时执行任务
|
||||||
async def periodic_task():
|
async def serve_info_task():
|
||||||
while True:
|
while True:
|
||||||
server_json = get_server_json()
|
server_json = get_server_json()
|
||||||
await room_manager.send_to_room('server-info', server_json)
|
await room_manager.send_to_room('server-info', server_json)
|
||||||
await asyncio.sleep(2) # 每 10 秒执行
|
await asyncio.sleep(5) # 每 5 秒执行
|
||||||
|
|
||||||
|
|
||||||
@my_app.on_event("startup")
|
@my_app.on_event("startup")
|
||||||
async def start_periodic_task():
|
async def start_periodic_task():
|
||||||
# 在后台启动异步任务
|
# 在后台启动异步任务
|
||||||
asyncio.create_task(periodic_task())
|
asyncio.create_task(serve_info_task())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ from fastapi import status
|
|||||||
from app.db.page_util import Pager
|
from app.db.page_util import Pager
|
||||||
|
|
||||||
|
|
||||||
def response_code_view(code: int,msg: str) -> Response:
|
def response_code_view(code: int, msg: str) -> Response:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=code,
|
status_code=code,
|
||||||
content={
|
content={
|
||||||
|
@ -81,6 +81,10 @@ class ProjectTrain(DbCommon):
|
|||||||
__tablename__ = "project_train"
|
__tablename__ = "project_train"
|
||||||
project_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
project_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
train_version: Mapped[str] = mapped_column(String(32), nullable=False)
|
train_version: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||||
|
weights_id: Mapped[int] = mapped_column(Integer)
|
||||||
|
weights_name: Mapped[str] = mapped_column(String(32))
|
||||||
|
epochs: Mapped[int] = mapped_column(Integer)
|
||||||
|
patience: Mapped[int] = mapped_column(Integer)
|
||||||
best_pt: Mapped[str] = mapped_column(String(255), nullable=False)
|
best_pt: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
last_pt: Mapped[str] = mapped_column(String(255), nullable=False)
|
last_pt: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
|
||||||
|
@ -3,10 +3,20 @@ from pydantic import BaseModel, Field
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectTrainIn(BaseModel):
    """Request body of the ``run_train`` endpoint: parameters for one training run."""
    # Required field (no default).
    # NOTE(review): the Optional annotation still admits an explicit null — confirm intended.
    project_id: Optional[int] = Field(default=..., description="项目id")
    # Id of an earlier training record; run_commend looks it up and, when found,
    # passes that record's best_pt as --weights.
    weights_id: Optional[int] = Field(default=None, description="权重文件")
    # Number of training epochs; defaults to 50.
    epochs: Optional[int] = Field(default=50, description="训练轮数")
    # Early-stopping patience; defaults to 20.
    patience: Optional[int] = Field(default=20, description="早停的耐心值")
|
||||||
|
|
||||||
|
|
||||||
class ProjectTrainOut(BaseModel):
|
class ProjectTrainOut(BaseModel):
|
||||||
"""项目训练版本信息表"""
|
"""项目训练版本信息表"""
|
||||||
id: Optional[int] = Field(None, description="训练id")
|
id: Optional[int] = Field(None, description="训练id")
|
||||||
train_version: Optional[str] = Field(None, description="训练版本号")
|
train_version: Optional[str] = Field(None, description="训练版本号")
|
||||||
|
weights_name: Optional[str] = Field(None, description="权重名称")
|
||||||
|
epochs: Optional[int] = Field(None, description="训练轮数")
|
||||||
|
patience: Optional[int] = Field(None, description="早停的耐心值")
|
||||||
create_time: Optional[datetime] = Field(None, description="训练时间")
|
create_time: Optional[datetime] = Field(None, description="训练时间")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain
|
from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain
|
||||||
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
|
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
|
||||||
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut
|
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut
|
||||||
|
from app.model.schemas.project_train_schemas import ProjectTrainIn
|
||||||
from app.model.crud import project_info_crud as pic
|
from app.model.crud import project_info_crud as pic
|
||||||
from app.model.crud import project_image_crud as pimc
|
from app.model.crud import project_image_crud as pimc
|
||||||
from app.model.crud import project_label_crud as plc
|
from app.model.crud import project_label_crud as plc
|
||||||
@ -9,12 +10,14 @@ from app.model.crud import project_img_leafer_label_crud as pillc
|
|||||||
from app.util import os_utils as os
|
from app.util import os_utils as os
|
||||||
from app.util import random_utils as ru
|
from app.util import random_utils as ru
|
||||||
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
|
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
|
||||||
|
from app.websocket.web_socket_server import room_manager
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing import List
|
from typing import List
|
||||||
from fastapi import UploadFile
|
from fastapi import UploadFile
|
||||||
import yaml
|
import yaml
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
|
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
|
||||||
@ -119,9 +122,10 @@ def get_img_leafer(image_id: int, session: Session):
|
|||||||
return img_leafer_out
|
return img_leafer_out
|
||||||
|
|
||||||
|
|
||||||
def run_train_yolo(project_info: ProjectInfoOut, session: Session):
|
def run_train_yolo(project_info: ProjectInfoOut, train_in: ProjectTrainIn, session: Session):
|
||||||
"""
|
"""
|
||||||
yolov5执行训练任务
|
yolov5执行训练任务
|
||||||
|
:param train_in: 训练参数
|
||||||
:param project_info: 项目信息
|
:param project_info: 项目信息
|
||||||
:param session: 数据库session
|
:param session: 数据库session
|
||||||
:return:
|
:return:
|
||||||
@ -168,29 +172,44 @@ def run_train_yolo(project_info: ProjectInfoOut, session: Session):
|
|||||||
# 打包完成开始训练,训练前,更改项目的训练状态
|
# 打包完成开始训练,训练前,更改项目的训练状态
|
||||||
pic.update_project_status(project_info.id, '1', session)
|
pic.update_project_status(project_info.id, '1', session)
|
||||||
|
|
||||||
# 开始训练
|
# 开始执行异步训练
|
||||||
data = yaml_file
|
data = yaml_file
|
||||||
project = os.file_path(runs_url, project_info.project_no, 'train')
|
project = os.file_path(runs_url, project_info.project_no)
|
||||||
name = version_path
|
name = version_path
|
||||||
|
# thread_train = threading.Thread(target=ps.run_commend, args=(data, project, name, train_in.epochs,
|
||||||
|
# train_in.patience, train_in.weights_id,
|
||||||
|
# train_in.project_id, session,))
|
||||||
|
# thread_train.start();
|
||||||
return data, project, name
|
return data, project, name
|
||||||
|
|
||||||
|
|
||||||
def run_commend(data: str, project: str,
|
async def run_commend(data: str, project: str,
|
||||||
name: str, epochs: int,
|
name: str, epochs: int, patience: int, weights: str,
|
||||||
project_id: int, session: Session):
|
project_id: int, session: Session):
|
||||||
|
"""
|
||||||
|
执行训练
|
||||||
|
:param data: 训练数据集
|
||||||
|
:param project: 训练结果的项目目录
|
||||||
|
:param name: 实验名称
|
||||||
|
:param epochs: 训练轮数
|
||||||
|
:param patience: 早停耐心值
|
||||||
|
:param weights: 权重文件
|
||||||
|
:param project_id: 项目id
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
yolo_path = os.file_path(yolo_url, 'train.py')
|
yolo_path = os.file_path(yolo_url, 'train.py')
|
||||||
|
room = 'train_' + str(project_id)
|
||||||
yield f"stdout: 模型训练开始,请稍等。。。\n"
|
await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n")
|
||||||
|
commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name,
|
||||||
|
"--epochs=" + str(epochs), "--batch-size=4", "--exist-ok", "--patience=" + str(patience)]
|
||||||
|
if weights != None and weights != '':
|
||||||
|
train_info = ptc.get_train(weights, session)
|
||||||
|
if train_info != None:
|
||||||
|
commend.append("--weights=" + train_info.best_pt)
|
||||||
# 启动子进程
|
# 启动子进程
|
||||||
with subprocess.Popen(
|
with subprocess.Popen(
|
||||||
["python", '-u', yolo_path,
|
commend,
|
||||||
"--data=" + data,
|
|
||||||
"--project=" + project,
|
|
||||||
"--name=" + name,
|
|
||||||
"--epochs=" + str(epochs),
|
|
||||||
"--batch-size=4",
|
|
||||||
"--exist-ok"],
|
|
||||||
bufsize=1, # bufsize=0时,为不缓存;bufsize=1时,按行缓存;bufsize为其他正整数时,为按照近似该正整数的字节数缓存
|
bufsize=1, # bufsize=0时,为不缓存;bufsize=1时,按行缓存;bufsize为其他正整数时,为按照近似该正整数的字节数缓存
|
||||||
shell=False,
|
shell=False,
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
@ -202,13 +221,14 @@ def run_commend(data: str, project: str,
|
|||||||
line = process.stdout.readline()
|
line = process.stdout.readline()
|
||||||
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
|
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
|
||||||
if line != '\n':
|
if line != '\n':
|
||||||
yield line
|
await room_manager.send_to_room(room, line)
|
||||||
|
|
||||||
# 等待进程结束并获取返回码
|
# 等待进程结束并获取返回码
|
||||||
return_code = process.wait()
|
return_code = process.wait()
|
||||||
if return_code != 0:
|
if return_code != 0:
|
||||||
pic.update_project_status(project_id, '-1', session)
|
pic.update_project_status(project_id, '-1', session)
|
||||||
else:
|
else:
|
||||||
|
await room_manager.send_to_room(room, 'success')
|
||||||
pic.update_project_status(project_id, '2', session)
|
pic.update_project_status(project_id, '2', session)
|
||||||
# 然后保存版本训练信息
|
# 然后保存版本训练信息
|
||||||
train = ProjectTrain()
|
train = ProjectTrain()
|
||||||
@ -218,6 +238,11 @@ def run_commend(data: str, project: str,
|
|||||||
last_pt_path = os.file_path(project, name, 'weights', 'last.pt')
|
last_pt_path = os.file_path(project, name, 'weights', 'last.pt')
|
||||||
train.best_pt = bast_pt_path
|
train.best_pt = bast_pt_path
|
||||||
train.last_pt = last_pt_path
|
train.last_pt = last_pt_path
|
||||||
|
if weights != None and weights != '':
|
||||||
|
train.weights_id = weights
|
||||||
|
train.weights_name = train_info.train_version
|
||||||
|
train.patience = patience
|
||||||
|
train.epochs = epochs
|
||||||
ptc.add_train(train, session)
|
ptc.add_train(train, session)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ import psutil
|
|||||||
import platform
|
import platform
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
# 获取服务器运行状态
|
||||||
|
|
||||||
|
|
||||||
def get_server_info():
|
def get_server_info():
|
||||||
@ -74,4 +75,4 @@ def get_server_info():
|
|||||||
|
|
||||||
def get_server_json():
    """Return the result of ``get_server_info()`` as a JSON string.

    The output is indented by two spaces and non-ASCII characters are
    written as-is rather than escaped.
    """
    return json.dumps(get_server_info(), ensure_ascii=False, indent=2)
|
||||||
|
@ -10,6 +10,7 @@ python-multipart==0.0.5
|
|||||||
redis~=4.1.4
|
redis~=4.1.4
|
||||||
SQLAlchemy~=2.0.34
|
SQLAlchemy~=2.0.34
|
||||||
uvicorn~=0.17.5
|
uvicorn~=0.17.5
|
||||||
|
uvicorn[standard]
|
||||||
loguru~=0.6.0
|
loguru~=0.6.0
|
||||||
xlrd~=2.0.1
|
xlrd~=2.0.1
|
||||||
bcrypt==3.2.0
|
bcrypt==3.2.0
|
||||||
|
Reference in New Issue
Block a user