Review notes for this patch. This is plain text ahead of the first file header;
patch(1) and git-apply skip leading text like this.

Formatting
  * The patch arrived with its line breaks collapsed. Line breaks are restored
    here; indentation and blank context lines were reconstructed from the hunk
    line counts, so compare against the original commit before applying.
  * The "index" hashes of files whose added lines were edited below no longer
    describe the resulting blobs.

Changes made to the added lines
  * app/service/project_detect_service.py: the start message said "training"
    (训练) although this is the inference path and the line it replaces said
    "inference" (推理); the inference wording is restored.
  * app/application/exception_handler.py: the new handler for Exception reused
    the name http_exception and read he.detail, which a plain Exception does
    not provide. It is renamed and formats the exception itself. Both handlers
    now build the message with an f-string so request.url does not have to
    support "+" with a str.
  * app/util/csv_utils.py: read_csv no longer raises StopIteration on an empty
    file, the unreachable "return None" is gone, and the file is opened with
    newline='' as the csv module documentation asks (hunk count 19 -> 15).
  * Comments and docstrings on added lines are now in English. Context lines,
    removed lines and runtime strings are untouched apart from the fix above.

Open issues that cannot be fixed inside the visible hunks
  * project_detect_service.run_commend: both hunk headers show
    "def run_commend(" (not "async def"), yet the patch adds "await"
    statements. The def line is outside every hunk and must become
    "async def", otherwise the module will not compile.
  * Both API modules pass the request-scoped Session to a background thread.
    Presumably get_db closes it when the request ends; confirm, and open a
    dedicated session inside the worker if so.
  * "@my_app.exception_handlers(...)": on FastAPI the decorator is
    "exception_handler" and "exception_handlers" is a dict. Confirm what
    my_app is before relying on these handlers.
  * ProjectTrain.train_url is added as nullable=False; existing rows
    presumably need a migration or default.
  * "line + '\n'": readline() normally keeps the trailing newline, so this may
    emit doubled newlines. Confirm what the websocket client expects.

diff --git a/app/api/business/project_detect_api.py b/app/api/business/project_detect_api.py
index e911db2..203c098 100644
--- a/app/api/business/project_detect_api.py
+++ b/app/api/business/project_detect_api.py
@@ -1,16 +1,18 @@
+import threading
+import asyncio
 from typing import List
 from fastapi import APIRouter, Depends, UploadFile, File, Form
 from fastapi.responses import StreamingResponse
 from fastapi.encoders import jsonable_encoder
 from sqlalchemy.orm import Session
 
+from app.db.db_session import get_db
 from app.common import reponse_code as rc
 from app.model.crud import project_detect_crud as pdc
 from app.service import project_detect_service as pds
 from app.model.crud.project_train_crud import get_train
 from app.model.schemas.project_detect_schemas import ProjectDetectPager, ProjectDetectIn,\
     ProjectDetectImgPager, ProjectDetectLogIn, ProjectDetectLogPager
-from app.db.db_session import get_db
 
 detect = APIRouter()
 
@@ -116,11 +118,22 @@ def run_detect_yolo(detect_log_in: ProjectDetectLogIn, session: Session = Depend
     if train is None:
         return rc.response_error("训练权重不存在")
     detect_log = pds.run_detect_yolo(detect_log_in, detect, train, session)
-    return StreamingResponse(pds.run_commend(detect_log.pt_url,
+    thread_train = threading.Thread(target=run_event_loop, args=(detect_log.pt_url,
                              detect_log.folder_url,
                              detect_log.detect_folder_url,
                              detect_log.detect_version,
-                             detect_log.id, detect_log.detect_id, session), media_type="text/plain")
+                             detect_log.id, detect_log.detect_id, session,))
+    thread_train.start()
+    return rc.response_success(msg="执行成功")
+
+
+def run_event_loop(weights: str, source: str, project: str, name: str, log_id: int, detect_id: int, session: Session):
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    # Run the inference coroutine to completion on this thread's own event loop
+    loop.run_until_complete(pds.run_commend(weights, source, project, name, log_id, detect_id, session))
+    # Optional: close the loop once the coroutine has finished
+    loop.close()
 
 
 @detect.post("/get_log_pager")
diff --git a/app/api/business/project_train_api.py b/app/api/business/project_train_api.py
index d3e1b9f..be3d89b 100644
--- a/app/api/business/project_train_api.py
+++ b/app/api/business/project_train_api.py
@@ -246,19 +246,17 @@ async def run_train(train_in: ProjectTrainIn, session: Session = Depends(get_db)
     data, project, name = ps.run_train_yolo(project_info, train_in, session)
     thread_train = threading.Thread(target=run_event_loop, args=(data, project, name, train_in,
                                                                   project_id, session,))
-    thread_train.start();
+    thread_train.start()
     return rc.response_success(msg="执行成功")
 
 
-def run_event_loop(data: str, project: str,
-                   name: str, train_in: ProjectTrainIn,
-                   project_id: int, session: Session):
+def run_event_loop(data: str, project: str, name: str, train_in: ProjectTrainIn,
+                   project_id: int, session: Session):
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
     # 运行异步函数
     loop.run_until_complete(ps.run_commend(data, project, name, train_in.epochs,
-                                           train_in.patience, train_in.weights_id,
-                                           project_id, session))
+                                           train_in.patience, train_in.weights_id, project_id, session))
     # 可选: 关闭循环
     loop.close()
 
@@ -276,4 +274,15 @@ def get_train_list(project_id: int, session: Session = Depends(get_db)):
     return rc.response_success(data=result)
 
 
-
+@project.get("/get_train_report/{train_id}")
+def get_train_report(train_id: int, session: Session = Depends(get_db)):
+    """
+    Return the training report built by ps.get_train_result for one training record.
+    :param train_id: id of the training record
+    :param session: database session injected by get_db
+    :return: success response carrying the report, or an error response when no report is returned
+    """
+    result_row = ps.get_train_result(train_id, session)
+    if result_row is None:
+        return rc.response_error("查询失败")
+    return rc.response_success(data=result_row)
diff --git a/app/api/business/websocket_api.py b/app/api/business/websocket_api.py
index 740f97f..6bfa93b 100644
--- a/app/api/business/websocket_api.py
+++ b/app/api/business/websocket_api.py
@@ -20,7 +20,7 @@ async def websocket_room(websocket: WebSocket, room: str):
     try:
         while True:
             data = await websocket.receive_text()
-            room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
+            await room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
     except Exception as e:
         if websocket.client_state != WebSocketState.DISCONNECTED:
             await websocket.close(code=1000)
diff --git a/app/application/exception_handler.py b/app/application/exception_handler.py
index f5b394f..e76adc6 100644
--- a/app/application/exception_handler.py
+++ b/app/application/exception_handler.py
@@ -7,4 +7,9 @@ from app import my_app
 
 @my_app.exception_handlers(HTTPException)
 async def http_exception(request: Request, he: HTTPException):
-    return response_error(request.url + "出现异常:" + he.detail)
\ No newline at end of file
+    return response_error(f"{request.url}出现异常:{he.detail}")
+
+
+@my_app.exception_handlers(Exception)
+async def unhandled_exception(request: Request, he: Exception):
+    return response_error(f"{request.url}出现异常:{he}")
diff --git a/app/model/bussiness_model.py b/app/model/bussiness_model.py
index 771c302..ec90be1 100644
--- a/app/model/bussiness_model.py
+++ b/app/model/bussiness_model.py
@@ -81,6 +81,7 @@ class ProjectTrain(DbCommon):
     __tablename__ = "project_train"
     project_id: Mapped[int] = mapped_column(Integer, nullable=False)
     train_version: Mapped[str] = mapped_column(String(32), nullable=False)
+    train_url: Mapped[str] = mapped_column(String(255), nullable=False)
     weights_id: Mapped[int] = mapped_column(Integer)
     weights_name: Mapped[str] = mapped_column(String(32))
     epochs: Mapped[int] = mapped_column(Integer)
diff --git a/app/service/project_detect_service.py b/app/service/project_detect_service.py
index 1f12060..e6cf170 100644
--- a/app/service/project_detect_service.py
+++ b/app/service/project_detect_service.py
@@ -11,6 +11,7 @@ from app.config.config_reader import detect_url
 from app.util import os_utils as os
 from app.util import random_utils as ru
 from app.config.config_reader import yolo_url
+from app.websocket.web_socket_server import room_manager
 
 
 def add_detect(detect_in: ProjectDetectIn, session: Session):
@@ -133,14 +134,12 @@ def run_commend(weights: str, source: str, project: str, name: str,
     :return:
     """
     yolo_path = os.file_path(yolo_url, 'detect.py')
-    yield f"stdout: 模型推理开始,请稍等。。。 \n"
+    room = 'detect_' + str(detect_id)
+    await room_manager.send_to_room(room, f"stdout: 模型推理开始,请稍等。。。\n")
+    commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt"]
     # 启动子进程
     with subprocess.Popen(
-            ["python", '-u', yolo_path,
-             "--weights", weights,
-             "--source", source,
-             "--name", name,
-             "--project", project],
+            commend,
             bufsize=1, # bufsize=0时,为不缓存;bufsize=1时,按行缓存;bufsize为其他正整数时,为按照近似该正整数的字节数缓存
             shell=False,
             stdout=subprocess.PIPE,
@@ -152,13 +151,14 @@ def run_commend(weights: str, source: str, project: str, name: str,
             line = process.stdout.readline()
             process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
             if line != '\n':
-                yield line
+                await room_manager.send_to_room(room, line + '\n')
 
         # 等待进程结束并获取返回码
         return_code = process.wait()
         if return_code != 0:
             pdc.update_detect_status(detect_id, -1, session)
         else:
+            await room_manager.send_to_room(room, 'success')
             pdc.update_detect_status(detect_id, 2, session)
             detect_imgs = pdc.get_img_list(detect_id, session)
             detect_log_imgs = []
diff --git a/app/service/project_train_service.py b/app/service/project_train_service.py
index 7530617..61d8366 100644
--- a/app/service/project_train_service.py
+++ b/app/service/project_train_service.py
@@ -11,13 +11,13 @@ from app.util import os_utils as os
 from app.util import random_utils as ru
 from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
 from app.websocket.web_socket_server import room_manager
+from app.util.csv_utils import read_csv
 
 from sqlalchemy.orm import Session
 from typing import List
 from fastapi import UploadFile
 import yaml
 import subprocess
-import asyncio
 
 
 def add_project(info: ProjectInfoIn, session: Session, user_id: int):
@@ -221,7 +221,7 @@ async def run_commend(data: str, project: str,
             line = process.stdout.readline()
             process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
             if line != '\n':
-                await room_manager.send_to_room(room, line)
+                await room_manager.send_to_room(room, line + '\n')
 
         # 等待进程结束并获取返回码
         return_code = process.wait()
@@ -234,8 +234,10 @@ async def run_commend(data: str, project: str,
             train = ProjectTrain()
             train.project_id = project_id
             train.train_version = name
-            bast_pt_path = os.file_path(project, name, 'weights', 'best.pt')
-            last_pt_path = os.file_path(project, name, 'weights', 'last.pt')
+            train_url = os.file_path(project, name)
+            train.train_url = train_url
+            bast_pt_path = os.file_path(train_url, 'weights', 'best.pt')
+            last_pt_path = os.file_path(train_url, 'weights', 'last.pt')
             train.best_pt = bast_pt_path
             train.last_pt = last_pt_path
             if weights != None and weights != '':
@@ -275,19 +277,76 @@ def operate_img_label(img_list: List[ProjectImgLabel],
                         + image_label.mark_height + '\n')
 
 
-def split_array(data):
+def get_train_result(train_id: int, session: Session):
     """
-    将数组按照3:1的比例切分成两个数组。
-    :param data: 原始数组
-    :return: 按照3:1比例分割后的两个数组
+    Build the training report from the results.csv file under the run's train_url.
+    :param train_id: id of the training record to report on
+    :param session: database session
+    :return: dict of per-epoch value lists (strings as read from the CSV), or None if the record is not found
     """
-    total_length = len(data)
-    if total_length < 4:
-        raise ValueError("数组长度至少需要为4才能进行3:1的分割")
-    # 计算分割点
-    split_index = (total_length * 3) // 4
-    # 使用切片分割数组
-    part1 = data[:split_index]
-    part2 = data[split_index:]
-    return part1, part2
+    train_info = ptc.get_train(train_id, session)
+    if train_info is None:
+        return None
+    result_csv_path = os.file_path(train_info.train_url, 'results.csv')
+    result_row = read_csv(result_csv_path)
+    report_data = {}
+    # Epoch numbers (column 0)
+    epoch_data = []
+    # Training bounding-box regression loss (column 1); lower means predicted boxes sit closer to the ground truth.
+    train_box_loss = []
+    # Training objectness loss (column 2); lower means better object / no-object confidence.
+    train_obj_loss = []
+    # Training classification loss (column 3); lower means better class prediction.
+    train_cls_loss = []
+    # Validation bounding-box regression loss (column 8): localization on unseen data.
+    val_box_loss = []
+    # Validation objectness loss (column 9): object / no-object judgement on unseen data.
+    val_obj_loss = []
+    # Validation classification loss (column 10): classification accuracy on unseen data.
+    val_cls_loss = []
+
+    # Precision (column 4): share of predicted positives that are correct; higher means fewer false detections.
+    m_p = []
+    # Recall (column 5): share of real positives that were detected; higher means fewer misses.
+    m_r = []
+
+    # Learning rate read from column 11 (backbone, per the original author's note).
+    x_lr0 = []
+    # Learning rate read from column 12 (detection head, per the original author's note).
+    x_lr1 = []
+
+    for row in result_row:
+        epoch_data.append(row[0].strip())
+
+        train_box_loss.append(row[1].strip())
+        train_obj_loss.append(row[2].strip())
+        train_cls_loss.append(row[3].strip())
+
+        val_box_loss.append(row[8].strip())
+        val_obj_loss.append(row[9].strip())
+        val_cls_loss.append(row[10].strip())
+
+        m_p.append(row[4].strip())
+        m_r.append(row[5].strip())
+
+        x_lr0.append(row[11].strip())
+        x_lr1.append(row[12].strip())
+
+    report_data['epoch_data'] = epoch_data
+
+    report_data['train_box_loss'] = train_box_loss
+    report_data['train_obj_loss'] = train_obj_loss
+    report_data['train_cls_loss'] = train_cls_loss
+
+    report_data['val_box_loss'] = val_box_loss
+    report_data['val_obj_loss'] = val_obj_loss
+    report_data['val_cls_loss'] = val_cls_loss
+
+    report_data['m_p'] = m_p
+    report_data['m_r'] = m_r
+
+    report_data['x_lr0'] = x_lr0
+    report_data['x_lr1'] = x_lr1
+
+    return report_data
 
diff --git a/app/util/csv_utils.py b/app/util/csv_utils.py
new file mode 100644
index 0000000..a336e98
--- /dev/null
+++ b/app/util/csv_utils.py
@@ -0,0 +1,15 @@
+import csv
+
+
+def read_csv(file_path):
+    """
+    Read a CSV file and return its data rows, skipping the header row.
+    :param file_path: path of the CSV file to read
+    :return: list of rows, each a list of strings; empty when the file has no data rows
+    """
+    # newline='' is what the csv module documentation asks for when opening a file for csv.reader
+    with open(file_path, 'r', encoding='utf-8', newline='') as file:
+        csv_reader = csv.reader(file)
+        # Skip the header row; the default keeps an empty file from raising StopIteration
+        next(csv_reader, None)
+        return list(csv_reader)
diff --git a/app/util/os_utils.py b/app/util/os_utils.py
index 378f0c9..7f5cb01 100644
--- a/app/util/os_utils.py
+++ b/app/util/os_utils.py
@@ -2,7 +2,6 @@ import os
 import shutil
 from fastapi import UploadFile
 from PIL import Image
-"""文件处理相关的util"""
 
 
 def file_path(*path):
diff --git a/app/util/ps_util.py b/app/util/ps_util.py
index be17760..3f296c2 100644
--- a/app/util/ps_util.py
+++ b/app/util/ps_util.py
@@ -2,11 +2,10 @@ import psutil
 import platform
 import json
 from datetime import datetime
-# 获取服务器运行状态
 
 
 def get_server_info():
-
+    # Collect the server's runtime status
     info = {}
 
     # 1. 系统基本信息