查询训练报告接口
This commit is contained in:
@ -1,16 +1,18 @@
|
|||||||
|
import threading
|
||||||
|
import asyncio
|
||||||
from typing import List
|
from typing import List
|
||||||
from fastapi import APIRouter, Depends, UploadFile, File, Form
|
from fastapi import APIRouter, Depends, UploadFile, File, Form
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.db.db_session import get_db
|
||||||
from app.common import reponse_code as rc
|
from app.common import reponse_code as rc
|
||||||
from app.model.crud import project_detect_crud as pdc
|
from app.model.crud import project_detect_crud as pdc
|
||||||
from app.service import project_detect_service as pds
|
from app.service import project_detect_service as pds
|
||||||
from app.model.crud.project_train_crud import get_train
|
from app.model.crud.project_train_crud import get_train
|
||||||
from app.model.schemas.project_detect_schemas import ProjectDetectPager, ProjectDetectIn,\
|
from app.model.schemas.project_detect_schemas import ProjectDetectPager, ProjectDetectIn,\
|
||||||
ProjectDetectImgPager, ProjectDetectLogIn, ProjectDetectLogPager
|
ProjectDetectImgPager, ProjectDetectLogIn, ProjectDetectLogPager
|
||||||
from app.db.db_session import get_db
|
|
||||||
|
|
||||||
detect = APIRouter()
|
detect = APIRouter()
|
||||||
|
|
||||||
@ -116,11 +118,22 @@ def run_detect_yolo(detect_log_in: ProjectDetectLogIn, session: Session = Depend
|
|||||||
if train is None:
|
if train is None:
|
||||||
return rc.response_error("训练权重不存在")
|
return rc.response_error("训练权重不存在")
|
||||||
detect_log = pds.run_detect_yolo(detect_log_in, detect, train, session)
|
detect_log = pds.run_detect_yolo(detect_log_in, detect, train, session)
|
||||||
return StreamingResponse(pds.run_commend(detect_log.pt_url,
|
thread_train = threading.Thread(target=run_event_loop, args=(detect_log.pt_url,
|
||||||
detect_log.folder_url,
|
detect_log.folder_url,
|
||||||
detect_log.detect_folder_url,
|
detect_log.detect_folder_url,
|
||||||
detect_log.detect_version,
|
detect_log.detect_version,
|
||||||
detect_log.id, detect_log.detect_id, session), media_type="text/plain")
|
detect_log.id, detect_log.detect_id, session,))
|
||||||
|
thread_train.start()
|
||||||
|
return rc.response_success(msg="执行成功")
|
||||||
|
|
||||||
|
|
||||||
|
def run_event_loop(weights: str, source: str, project: str, name: str, log_id: int, detect_id: int, session: Session):
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
# 运行异步函数
|
||||||
|
loop.run_until_complete(pds.run_commend(weights, source, project, name, log_id, detect_id, session))
|
||||||
|
# 可选: 关闭循环
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
@detect.post("/get_log_pager")
|
@detect.post("/get_log_pager")
|
||||||
|
@ -246,19 +246,17 @@ async def run_train(train_in: ProjectTrainIn, session: Session = Depends(get_db)
|
|||||||
data, project, name = ps.run_train_yolo(project_info, train_in, session)
|
data, project, name = ps.run_train_yolo(project_info, train_in, session)
|
||||||
thread_train = threading.Thread(target=run_event_loop, args=(data, project, name, train_in,
|
thread_train = threading.Thread(target=run_event_loop, args=(data, project, name, train_in,
|
||||||
project_id, session,))
|
project_id, session,))
|
||||||
thread_train.start();
|
thread_train.start()
|
||||||
return rc.response_success(msg="执行成功")
|
return rc.response_success(msg="执行成功")
|
||||||
|
|
||||||
|
|
||||||
def run_event_loop(data: str, project: str,
|
def run_event_loop(data: str, project: str, name: str, train_in: ProjectTrainIn,
|
||||||
name: str, train_in: ProjectTrainIn,
|
|
||||||
project_id: int, session: Session):
|
project_id: int, session: Session):
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
# 运行异步函数
|
# 运行异步函数
|
||||||
loop.run_until_complete(ps.run_commend(data, project, name, train_in.epochs,
|
loop.run_until_complete(ps.run_commend(data, project, name, train_in.epochs,
|
||||||
train_in.patience, train_in.weights_id,
|
train_in.patience, train_in.weights_id, project_id, session))
|
||||||
project_id, session))
|
|
||||||
# 可选: 关闭循环
|
# 可选: 关闭循环
|
||||||
loop.close()
|
loop.close()
|
||||||
|
|
||||||
@ -276,4 +274,15 @@ def get_train_list(project_id: int, session: Session = Depends(get_db)):
|
|||||||
return rc.response_success(data=result)
|
return rc.response_success(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
@project.get("/get_train_report/{train_id}")
|
||||||
|
def get_train_report(train_id: int, session: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
查询训练报告
|
||||||
|
:param train_id:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
result_row = ps.get_train_result(train_id, session)
|
||||||
|
if result_row is None:
|
||||||
|
return rc.response_error("查询失败")
|
||||||
|
return rc.response_success(data=result_row)
|
||||||
|
@ -20,7 +20,7 @@ async def websocket_room(websocket: WebSocket, room: str):
|
|||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
data = await websocket.receive_text()
|
data = await websocket.receive_text()
|
||||||
room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
|
await room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if websocket.client_state != WebSocketState.DISCONNECTED:
|
if websocket.client_state != WebSocketState.DISCONNECTED:
|
||||||
await websocket.close(code=1000)
|
await websocket.close(code=1000)
|
||||||
|
@ -8,3 +8,8 @@ from app import my_app
|
|||||||
@my_app.exception_handlers(HTTPException)
|
@my_app.exception_handlers(HTTPException)
|
||||||
async def http_exception(request: Request, he: HTTPException):
|
async def http_exception(request: Request, he: HTTPException):
|
||||||
return response_error(request.url + "出现异常:" + he.detail)
|
return response_error(request.url + "出现异常:" + he.detail)
|
||||||
|
|
||||||
|
|
||||||
|
@my_app.exception_handlers(Exception)
|
||||||
|
async def http_exception(request: Request, he: Exception):
|
||||||
|
return response_error(request.url + "出现异常:" + he.detail)
|
||||||
|
@ -81,6 +81,7 @@ class ProjectTrain(DbCommon):
|
|||||||
__tablename__ = "project_train"
|
__tablename__ = "project_train"
|
||||||
project_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
project_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
train_version: Mapped[str] = mapped_column(String(32), nullable=False)
|
train_version: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||||
|
train_url: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
weights_id: Mapped[int] = mapped_column(Integer)
|
weights_id: Mapped[int] = mapped_column(Integer)
|
||||||
weights_name: Mapped[str] = mapped_column(String(32))
|
weights_name: Mapped[str] = mapped_column(String(32))
|
||||||
epochs: Mapped[int] = mapped_column(Integer)
|
epochs: Mapped[int] = mapped_column(Integer)
|
||||||
|
@ -11,6 +11,7 @@ from app.config.config_reader import detect_url
|
|||||||
from app.util import os_utils as os
|
from app.util import os_utils as os
|
||||||
from app.util import random_utils as ru
|
from app.util import random_utils as ru
|
||||||
from app.config.config_reader import yolo_url
|
from app.config.config_reader import yolo_url
|
||||||
|
from app.websocket.web_socket_server import room_manager
|
||||||
|
|
||||||
|
|
||||||
def add_detect(detect_in: ProjectDetectIn, session: Session):
|
def add_detect(detect_in: ProjectDetectIn, session: Session):
|
||||||
@ -133,14 +134,12 @@ def run_commend(weights: str, source: str, project: str, name: str,
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
yolo_path = os.file_path(yolo_url, 'detect.py')
|
yolo_path = os.file_path(yolo_url, 'detect.py')
|
||||||
yield f"stdout: 模型推理开始,请稍等。。。 \n"
|
room = 'detect_' + str(detect_id)
|
||||||
|
await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n")
|
||||||
|
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt"]
|
||||||
# 启动子进程
|
# 启动子进程
|
||||||
with subprocess.Popen(
|
with subprocess.Popen(
|
||||||
["python", '-u', yolo_path,
|
commend,
|
||||||
"--weights", weights,
|
|
||||||
"--source", source,
|
|
||||||
"--name", name,
|
|
||||||
"--project", project],
|
|
||||||
bufsize=1, # bufsize=0时,为不缓存;bufsize=1时,按行缓存;bufsize为其他正整数时,为按照近似该正整数的字节数缓存
|
bufsize=1, # bufsize=0时,为不缓存;bufsize=1时,按行缓存;bufsize为其他正整数时,为按照近似该正整数的字节数缓存
|
||||||
shell=False,
|
shell=False,
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
@ -152,13 +151,14 @@ def run_commend(weights: str, source: str, project: str, name: str,
|
|||||||
line = process.stdout.readline()
|
line = process.stdout.readline()
|
||||||
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
|
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
|
||||||
if line != '\n':
|
if line != '\n':
|
||||||
yield line
|
await room_manager.send_to_room(room, line + '\n')
|
||||||
|
|
||||||
# 等待进程结束并获取返回码
|
# 等待进程结束并获取返回码
|
||||||
return_code = process.wait()
|
return_code = process.wait()
|
||||||
if return_code != 0:
|
if return_code != 0:
|
||||||
pdc.update_detect_status(detect_id, -1, session)
|
pdc.update_detect_status(detect_id, -1, session)
|
||||||
else:
|
else:
|
||||||
|
await room_manager.send_to_room(room, 'success')
|
||||||
pdc.update_detect_status(detect_id, 2, session)
|
pdc.update_detect_status(detect_id, 2, session)
|
||||||
detect_imgs = pdc.get_img_list(detect_id, session)
|
detect_imgs = pdc.get_img_list(detect_id, session)
|
||||||
detect_log_imgs = []
|
detect_log_imgs = []
|
||||||
|
@ -11,13 +11,13 @@ from app.util import os_utils as os
|
|||||||
from app.util import random_utils as ru
|
from app.util import random_utils as ru
|
||||||
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
|
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
|
||||||
from app.websocket.web_socket_server import room_manager
|
from app.websocket.web_socket_server import room_manager
|
||||||
|
from app.util.csv_utils import read_csv
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing import List
|
from typing import List
|
||||||
from fastapi import UploadFile
|
from fastapi import UploadFile
|
||||||
import yaml
|
import yaml
|
||||||
import subprocess
|
import subprocess
|
||||||
import asyncio
|
|
||||||
|
|
||||||
|
|
||||||
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
|
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
|
||||||
@ -221,7 +221,7 @@ async def run_commend(data: str, project: str,
|
|||||||
line = process.stdout.readline()
|
line = process.stdout.readline()
|
||||||
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
|
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
|
||||||
if line != '\n':
|
if line != '\n':
|
||||||
await room_manager.send_to_room(room, line)
|
await room_manager.send_to_room(room, line + '\n')
|
||||||
|
|
||||||
# 等待进程结束并获取返回码
|
# 等待进程结束并获取返回码
|
||||||
return_code = process.wait()
|
return_code = process.wait()
|
||||||
@ -234,8 +234,10 @@ async def run_commend(data: str, project: str,
|
|||||||
train = ProjectTrain()
|
train = ProjectTrain()
|
||||||
train.project_id = project_id
|
train.project_id = project_id
|
||||||
train.train_version = name
|
train.train_version = name
|
||||||
bast_pt_path = os.file_path(project, name, 'weights', 'best.pt')
|
train_url = os.file_path(project, name)
|
||||||
last_pt_path = os.file_path(project, name, 'weights', 'last.pt')
|
train.train_url = train_url
|
||||||
|
bast_pt_path = os.file_path(train_url, 'weights', 'best.pt')
|
||||||
|
last_pt_path = os.file_path(train_url, 'weights', 'last.pt')
|
||||||
train.best_pt = bast_pt_path
|
train.best_pt = bast_pt_path
|
||||||
train.last_pt = last_pt_path
|
train.last_pt = last_pt_path
|
||||||
if weights != None and weights != '':
|
if weights != None and weights != '':
|
||||||
@ -275,19 +277,76 @@ def operate_img_label(img_list: List[ProjectImgLabel],
|
|||||||
+ image_label.mark_height + '\n')
|
+ image_label.mark_height + '\n')
|
||||||
|
|
||||||
|
|
||||||
def split_array(data):
|
def get_train_result(train_id: int, session: Session):
|
||||||
"""
|
"""
|
||||||
将数组按照3:1的比例切分成两个数组。
|
根据result.csv文件查询训练报告
|
||||||
:param data: 原始数组
|
:param train_id:
|
||||||
:return: 按照3:1比例分割后的两个数组
|
:param session:
|
||||||
|
:return:
|
||||||
"""
|
"""
|
||||||
total_length = len(data)
|
train_info = ptc.get_train(train_id, session)
|
||||||
if total_length < 4:
|
if train_info is None:
|
||||||
raise ValueError("数组长度至少需要为4才能进行3:1的分割")
|
return None
|
||||||
# 计算分割点
|
result_csv_path = os.file_path(train_info.train_url, 'results.csv')
|
||||||
split_index = (total_length * 3) // 4
|
result_row = read_csv(result_csv_path)
|
||||||
# 使用切片分割数组
|
report_data = {}
|
||||||
part1 = data[:split_index]
|
# 轮数
|
||||||
part2 = data[split_index:]
|
epoch_data = []
|
||||||
return part1, part2
|
# 边界框回归损失(Bounding Box Loss),衡量预测框位置(中心坐标、宽高)与真实框的差异,值越低表示定位越准。
|
||||||
|
train_box_loss = []
|
||||||
|
# 目标置信度损失(Objectness Loss),衡量检测到目标的置信度误差(即是否包含物体),值越低表示模型越能正确判断有无物体。
|
||||||
|
train_obj_loss = []
|
||||||
|
# 分类损失(Classification Loss),衡量预测类别与真实类别的差异,值越低表示分类越准。
|
||||||
|
train_cls_loss = []
|
||||||
|
|
||||||
|
# 验证集的边界框回归损失,反映模型在未见数据上的定位能力。
|
||||||
|
val_box_loss = []
|
||||||
|
# 验证集的目标置信度损失,反映模型在未见数据上判断物体存在的能力。
|
||||||
|
val_obj_loss = []
|
||||||
|
# 验证集的分类损失,反映模型在未见数据上的分类准确性。
|
||||||
|
val_cls_loss = []
|
||||||
|
|
||||||
|
# 精确率(Precision):正确检测的正样本占所有预测为正样本的比例,反映“误检率”。值越高说明误检越少。
|
||||||
|
m_p = []
|
||||||
|
# 召回率(Recall):正确检测的正样本占所有真实正样本的比例,反映“漏检率”。值越高说明漏检越少。
|
||||||
|
m_r = []
|
||||||
|
|
||||||
|
# 主干网络(Backbone)的学习率。
|
||||||
|
x_lr0 = []
|
||||||
|
# 检测头(Head)的学习率。
|
||||||
|
x_lr1 = []
|
||||||
|
|
||||||
|
for row in result_row:
|
||||||
|
epoch_data.append(row[0].strip())
|
||||||
|
|
||||||
|
train_box_loss.append(row[1].strip())
|
||||||
|
train_obj_loss.append(row[2].strip())
|
||||||
|
train_cls_loss.append(row[3].strip())
|
||||||
|
|
||||||
|
val_box_loss.append(row[8].strip())
|
||||||
|
val_obj_loss.append(row[9].strip())
|
||||||
|
val_cls_loss.append(row[10].strip())
|
||||||
|
|
||||||
|
m_p.append(row[4].strip())
|
||||||
|
m_r.append(row[5].strip())
|
||||||
|
|
||||||
|
x_lr0.append(row[11].strip())
|
||||||
|
x_lr1.append(row[12].strip())
|
||||||
|
|
||||||
|
report_data['epoch_data'] = epoch_data
|
||||||
|
|
||||||
|
report_data['train_box_loss'] = train_box_loss
|
||||||
|
report_data['train_obj_loss'] = train_obj_loss
|
||||||
|
report_data['train_cls_loss'] = train_cls_loss
|
||||||
|
|
||||||
|
report_data['val_box_loss'] = val_box_loss
|
||||||
|
report_data['val_obj_loss'] = val_obj_loss
|
||||||
|
report_data['val_cls_loss'] = val_cls_loss
|
||||||
|
|
||||||
|
report_data['m_p'] = m_p
|
||||||
|
report_data['m_r'] = m_r
|
||||||
|
|
||||||
|
report_data['x_lr0'] = x_lr0
|
||||||
|
report_data['x_lr1'] = x_lr1
|
||||||
|
|
||||||
|
return report_data
|
||||||
|
19
app/util/csv_utils.py
Normal file
19
app/util/csv_utils.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
import csv
|
||||||
|
|
||||||
|
|
||||||
|
def read_csv(file_path):
|
||||||
|
"""
|
||||||
|
根据文件路径读取csv文件
|
||||||
|
:param file_path:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as file:
|
||||||
|
# 创建 CSV 阅读器对象
|
||||||
|
csv_reader = csv.reader(file)
|
||||||
|
# 跳过标题行
|
||||||
|
next(csv_reader)
|
||||||
|
result_row = []
|
||||||
|
for row in csv_reader:
|
||||||
|
result_row.append(row)
|
||||||
|
return result_row
|
||||||
|
return None
|
@ -2,7 +2,6 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
from fastapi import UploadFile
|
from fastapi import UploadFile
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
"""文件处理相关的util"""
|
|
||||||
|
|
||||||
|
|
||||||
def file_path(*path):
|
def file_path(*path):
|
||||||
|
@ -2,11 +2,10 @@ import psutil
|
|||||||
import platform
|
import platform
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
# 获取服务器运行状态
|
|
||||||
|
|
||||||
|
|
||||||
def get_server_info():
|
def get_server_info():
|
||||||
|
# 获取服务器运行状态
|
||||||
info = {}
|
info = {}
|
||||||
|
|
||||||
# 1. 系统基本信息
|
# 1. 系统基本信息
|
||||||
|
Reference in New Issue
Block a user