查询训练报告接口

This commit is contained in:
2025-03-11 11:42:09 +08:00
parent 7d736c4ac4
commit f49a6caf10
10 changed files with 143 additions and 39 deletions

View File

@ -1,16 +1,18 @@
import threading
import asyncio
from typing import List from typing import List
from fastapi import APIRouter, Depends, UploadFile, File, Form from fastapi import APIRouter, Depends, UploadFile, File, Form
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db.db_session import get_db
from app.common import reponse_code as rc from app.common import reponse_code as rc
from app.model.crud import project_detect_crud as pdc from app.model.crud import project_detect_crud as pdc
from app.service import project_detect_service as pds from app.service import project_detect_service as pds
from app.model.crud.project_train_crud import get_train from app.model.crud.project_train_crud import get_train
from app.model.schemas.project_detect_schemas import ProjectDetectPager, ProjectDetectIn,\ from app.model.schemas.project_detect_schemas import ProjectDetectPager, ProjectDetectIn,\
ProjectDetectImgPager, ProjectDetectLogIn, ProjectDetectLogPager ProjectDetectImgPager, ProjectDetectLogIn, ProjectDetectLogPager
from app.db.db_session import get_db
detect = APIRouter() detect = APIRouter()
@ -116,11 +118,22 @@ def run_detect_yolo(detect_log_in: ProjectDetectLogIn, session: Session = Depend
if train is None: if train is None:
return rc.response_error("训练权重不存在") return rc.response_error("训练权重不存在")
detect_log = pds.run_detect_yolo(detect_log_in, detect, train, session) detect_log = pds.run_detect_yolo(detect_log_in, detect, train, session)
return StreamingResponse(pds.run_commend(detect_log.pt_url, thread_train = threading.Thread(target=run_event_loop, args=(detect_log.pt_url,
detect_log.folder_url, detect_log.folder_url,
detect_log.detect_folder_url, detect_log.detect_folder_url,
detect_log.detect_version, detect_log.detect_version,
detect_log.id, detect_log.detect_id, session), media_type="text/plain") detect_log.id, detect_log.detect_id, session,))
thread_train.start()
return rc.response_success(msg="执行成功")
def run_event_loop(weights: str, source: str, project: str, name: str, log_id: int, detect_id: int, session: Session):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 运行异步函数
loop.run_until_complete(pds.run_commend(weights, source, project, name, log_id, detect_id, session))
# 可选: 关闭循环
loop.close()
@detect.post("/get_log_pager") @detect.post("/get_log_pager")

View File

@ -246,19 +246,17 @@ async def run_train(train_in: ProjectTrainIn, session: Session = Depends(get_db)
data, project, name = ps.run_train_yolo(project_info, train_in, session) data, project, name = ps.run_train_yolo(project_info, train_in, session)
thread_train = threading.Thread(target=run_event_loop, args=(data, project, name, train_in, thread_train = threading.Thread(target=run_event_loop, args=(data, project, name, train_in,
project_id, session,)) project_id, session,))
thread_train.start(); thread_train.start()
return rc.response_success(msg="执行成功") return rc.response_success(msg="执行成功")
def run_event_loop(data: str, project: str, def run_event_loop(data: str, project: str, name: str, train_in: ProjectTrainIn,
name: str, train_in: ProjectTrainIn,
project_id: int, session: Session): project_id: int, session: Session):
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
# 运行异步函数 # 运行异步函数
loop.run_until_complete(ps.run_commend(data, project, name, train_in.epochs, loop.run_until_complete(ps.run_commend(data, project, name, train_in.epochs,
train_in.patience, train_in.weights_id, train_in.patience, train_in.weights_id, project_id, session))
project_id, session))
# 可选: 关闭循环 # 可选: 关闭循环
loop.close() loop.close()
@ -276,4 +274,15 @@ def get_train_list(project_id: int, session: Session = Depends(get_db)):
return rc.response_success(data=result) return rc.response_success(data=result)
@project.get("/get_train_report/{train_id}")
def get_train_report(train_id: int, session: Session = Depends(get_db)):
"""
查询训练报告
:param train_id:
:param session:
:return:
"""
result_row = ps.get_train_result(train_id, session)
if result_row is None:
return rc.response_error("查询失败")
return rc.response_success(data=result_row)

View File

@ -20,7 +20,7 @@ async def websocket_room(websocket: WebSocket, room: str):
try: try:
while True: while True:
data = await websocket.receive_text() data = await websocket.receive_text()
room_manager.broadcast_to_room(room, data, exclude_websocket=websocket) await room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
except Exception as e: except Exception as e:
if websocket.client_state != WebSocketState.DISCONNECTED: if websocket.client_state != WebSocketState.DISCONNECTED:
await websocket.close(code=1000) await websocket.close(code=1000)

View File

@ -8,3 +8,8 @@ from app import my_app
@my_app.exception_handlers(HTTPException) @my_app.exception_handlers(HTTPException)
async def http_exception(request: Request, he: HTTPException): async def http_exception(request: Request, he: HTTPException):
return response_error(request.url + "出现异常:" + he.detail) return response_error(request.url + "出现异常:" + he.detail)
@my_app.exception_handlers(Exception)
async def http_exception(request: Request, he: Exception):
return response_error(request.url + "出现异常:" + he.detail)

View File

@ -81,6 +81,7 @@ class ProjectTrain(DbCommon):
__tablename__ = "project_train" __tablename__ = "project_train"
project_id: Mapped[int] = mapped_column(Integer, nullable=False) project_id: Mapped[int] = mapped_column(Integer, nullable=False)
train_version: Mapped[str] = mapped_column(String(32), nullable=False) train_version: Mapped[str] = mapped_column(String(32), nullable=False)
train_url: Mapped[str] = mapped_column(String(255), nullable=False)
weights_id: Mapped[int] = mapped_column(Integer) weights_id: Mapped[int] = mapped_column(Integer)
weights_name: Mapped[str] = mapped_column(String(32)) weights_name: Mapped[str] = mapped_column(String(32))
epochs: Mapped[int] = mapped_column(Integer) epochs: Mapped[int] = mapped_column(Integer)

View File

@ -11,6 +11,7 @@ from app.config.config_reader import detect_url
from app.util import os_utils as os from app.util import os_utils as os
from app.util import random_utils as ru from app.util import random_utils as ru
from app.config.config_reader import yolo_url from app.config.config_reader import yolo_url
from app.websocket.web_socket_server import room_manager
def add_detect(detect_in: ProjectDetectIn, session: Session): def add_detect(detect_in: ProjectDetectIn, session: Session):
@ -133,14 +134,12 @@ def run_commend(weights: str, source: str, project: str, name: str,
:return: :return:
""" """
yolo_path = os.file_path(yolo_url, 'detect.py') yolo_path = os.file_path(yolo_url, 'detect.py')
yield f"stdout: 模型推理开始,请稍等。。。 \n" room = 'detect_' + str(detect_id)
await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n")
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt"]
# 启动子进程 # 启动子进程
with subprocess.Popen( with subprocess.Popen(
["python", '-u', yolo_path, commend,
"--weights", weights,
"--source", source,
"--name", name,
"--project", project],
bufsize=1, # bufsize=0时为不缓存bufsize=1时按行缓存bufsize为其他正整数时为按照近似该正整数的字节数缓存 bufsize=1, # bufsize=0时为不缓存bufsize=1时按行缓存bufsize为其他正整数时为按照近似该正整数的字节数缓存
shell=False, shell=False,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
@ -152,13 +151,14 @@ def run_commend(weights: str, source: str, project: str, name: str,
line = process.stdout.readline() line = process.stdout.readline()
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死 process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
if line != '\n': if line != '\n':
yield line await room_manager.send_to_room(room, line + '\n')
# 等待进程结束并获取返回码 # 等待进程结束并获取返回码
return_code = process.wait() return_code = process.wait()
if return_code != 0: if return_code != 0:
pdc.update_detect_status(detect_id, -1, session) pdc.update_detect_status(detect_id, -1, session)
else: else:
await room_manager.send_to_room(room, 'success')
pdc.update_detect_status(detect_id, 2, session) pdc.update_detect_status(detect_id, 2, session)
detect_imgs = pdc.get_img_list(detect_id, session) detect_imgs = pdc.get_img_list(detect_id, session)
detect_log_imgs = [] detect_log_imgs = []

View File

@ -11,13 +11,13 @@ from app.util import os_utils as os
from app.util import random_utils as ru from app.util import random_utils as ru
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
from app.websocket.web_socket_server import room_manager from app.websocket.web_socket_server import room_manager
from app.util.csv_utils import read_csv
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List from typing import List
from fastapi import UploadFile from fastapi import UploadFile
import yaml import yaml
import subprocess import subprocess
import asyncio
def add_project(info: ProjectInfoIn, session: Session, user_id: int): def add_project(info: ProjectInfoIn, session: Session, user_id: int):
@ -221,7 +221,7 @@ async def run_commend(data: str, project: str,
line = process.stdout.readline() line = process.stdout.readline()
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死 process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
if line != '\n': if line != '\n':
await room_manager.send_to_room(room, line) await room_manager.send_to_room(room, line + '\n')
# 等待进程结束并获取返回码 # 等待进程结束并获取返回码
return_code = process.wait() return_code = process.wait()
@ -234,8 +234,10 @@ async def run_commend(data: str, project: str,
train = ProjectTrain() train = ProjectTrain()
train.project_id = project_id train.project_id = project_id
train.train_version = name train.train_version = name
bast_pt_path = os.file_path(project, name, 'weights', 'best.pt') train_url = os.file_path(project, name)
last_pt_path = os.file_path(project, name, 'weights', 'last.pt') train.train_url = train_url
bast_pt_path = os.file_path(train_url, 'weights', 'best.pt')
last_pt_path = os.file_path(train_url, 'weights', 'last.pt')
train.best_pt = bast_pt_path train.best_pt = bast_pt_path
train.last_pt = last_pt_path train.last_pt = last_pt_path
if weights != None and weights != '': if weights != None and weights != '':
@ -275,19 +277,76 @@ def operate_img_label(img_list: List[ProjectImgLabel],
+ image_label.mark_height + '\n') + image_label.mark_height + '\n')
def split_array(data): def get_train_result(train_id: int, session: Session):
""" """
将数组按照3:1的比例切分成两个数组。 根据result.csv文件查询训练报告
:param data: 原始数组 :param train_id:
:return: 按照3:1比例分割后的两个数组 :param session:
:return:
""" """
total_length = len(data) train_info = ptc.get_train(train_id, session)
if total_length < 4: if train_info is None:
raise ValueError("数组长度至少需要为4才能进行3:1的分割") return None
# 计算分割点 result_csv_path = os.file_path(train_info.train_url, 'results.csv')
split_index = (total_length * 3) // 4 result_row = read_csv(result_csv_path)
# 使用切片分割数组 report_data = {}
part1 = data[:split_index] # 轮数
part2 = data[split_index:] epoch_data = []
return part1, part2 # 边界框回归损失Bounding Box Loss衡量预测框位置中心坐标、宽高与真实框的差异值越低表示定位越准。
train_box_loss = []
# 目标置信度损失Objectness Loss衡量检测到目标的置信度误差即是否包含物体值越低表示模型越能正确判断有无物体。
train_obj_loss = []
# 分类损失Classification Loss衡量预测类别与真实类别的差异值越低表示分类越准。
train_cls_loss = []
# 验证集的边界框回归损失,反映模型在未见数据上的定位能力。
val_box_loss = []
# 验证集的目标置信度损失,反映模型在未见数据上判断物体存在的能力。
val_obj_loss = []
# 验证集的分类损失,反映模型在未见数据上的分类准确性。
val_cls_loss = []
# 精确率Precision正确检测的正样本占所有预测为正样本的比例反映“误检率”。值越高说明误检越少。
m_p = []
# 召回率Recall正确检测的正样本占所有真实正样本的比例反映“漏检率”。值越高说明漏检越少。
m_r = []
# 主干网络Backbone的学习率。
x_lr0 = []
# 检测头Head的学习率。
x_lr1 = []
for row in result_row:
epoch_data.append(row[0].strip())
train_box_loss.append(row[1].strip())
train_obj_loss.append(row[2].strip())
train_cls_loss.append(row[3].strip())
val_box_loss.append(row[8].strip())
val_obj_loss.append(row[9].strip())
val_cls_loss.append(row[10].strip())
m_p.append(row[4].strip())
m_r.append(row[5].strip())
x_lr0.append(row[11].strip())
x_lr1.append(row[12].strip())
report_data['epoch_data'] = epoch_data
report_data['train_box_loss'] = train_box_loss
report_data['train_obj_loss'] = train_obj_loss
report_data['train_cls_loss'] = train_cls_loss
report_data['val_box_loss'] = val_box_loss
report_data['val_obj_loss'] = val_obj_loss
report_data['val_cls_loss'] = val_cls_loss
report_data['m_p'] = m_p
report_data['m_r'] = m_r
report_data['x_lr0'] = x_lr0
report_data['x_lr1'] = x_lr1
return report_data

19
app/util/csv_utils.py Normal file
View File

@ -0,0 +1,19 @@
import csv
def read_csv(file_path):
"""
根据文件路径读取csv文件
:param file_path:
:return:
"""
with open(file_path, 'r', encoding='utf-8') as file:
# 创建 CSV 阅读器对象
csv_reader = csv.reader(file)
# 跳过标题行
next(csv_reader)
result_row = []
for row in csv_reader:
result_row.append(row)
return result_row
return None

View File

@ -2,7 +2,6 @@ import os
import shutil import shutil
from fastapi import UploadFile from fastapi import UploadFile
from PIL import Image from PIL import Image
"""文件处理相关的util"""
def file_path(*path): def file_path(*path):

View File

@ -2,11 +2,10 @@ import psutil
import platform import platform
import json import json
from datetime import datetime from datetime import datetime
# 获取服务器运行状态
def get_server_info(): def get_server_info():
# 获取服务器运行状态
info = {} info = {}
# 1. 系统基本信息 # 1. 系统基本信息