查询训练报告接口
This commit is contained in:
@ -11,6 +11,7 @@ from app.config.config_reader import detect_url
|
||||
from app.util import os_utils as os
|
||||
from app.util import random_utils as ru
|
||||
from app.config.config_reader import yolo_url
|
||||
from app.websocket.web_socket_server import room_manager
|
||||
|
||||
|
||||
def add_detect(detect_in: ProjectDetectIn, session: Session):
|
||||
@ -133,14 +134,12 @@ def run_commend(weights: str, source: str, project: str, name: str,
|
||||
:return:
|
||||
"""
|
||||
yolo_path = os.file_path(yolo_url, 'detect.py')
|
||||
yield f"stdout: 模型推理开始,请稍等。。。 \n"
|
||||
room = 'detect_' + str(detect_id)
|
||||
await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n")
|
||||
commend = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project", project, "--save-txt"]
|
||||
# 启动子进程
|
||||
with subprocess.Popen(
|
||||
["python", '-u', yolo_path,
|
||||
"--weights", weights,
|
||||
"--source", source,
|
||||
"--name", name,
|
||||
"--project", project],
|
||||
commend,
|
||||
bufsize=1, # bufsize=0时,为不缓存;bufsize=1时,按行缓存;bufsize为其他正整数时,为按照近似该正整数的字节数缓存
|
||||
shell=False,
|
||||
stdout=subprocess.PIPE,
|
||||
@ -152,13 +151,14 @@ def run_commend(weights: str, source: str, project: str, name: str,
|
||||
line = process.stdout.readline()
|
||||
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
|
||||
if line != '\n':
|
||||
yield line
|
||||
await room_manager.send_to_room(room, line + '\n')
|
||||
|
||||
# 等待进程结束并获取返回码
|
||||
return_code = process.wait()
|
||||
if return_code != 0:
|
||||
pdc.update_detect_status(detect_id, -1, session)
|
||||
else:
|
||||
await room_manager.send_to_room(room, 'success')
|
||||
pdc.update_detect_status(detect_id, 2, session)
|
||||
detect_imgs = pdc.get_img_list(detect_id, session)
|
||||
detect_log_imgs = []
|
||||
|
@ -11,13 +11,13 @@ from app.util import os_utils as os
|
||||
from app.util import random_utils as ru
|
||||
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
|
||||
from app.websocket.web_socket_server import room_manager
|
||||
from app.util.csv_utils import read_csv
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from fastapi import UploadFile
|
||||
import yaml
|
||||
import subprocess
|
||||
import asyncio
|
||||
|
||||
|
||||
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
|
||||
@ -221,7 +221,7 @@ async def run_commend(data: str, project: str,
|
||||
line = process.stdout.readline()
|
||||
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
|
||||
if line != '\n':
|
||||
await room_manager.send_to_room(room, line)
|
||||
await room_manager.send_to_room(room, line + '\n')
|
||||
|
||||
# 等待进程结束并获取返回码
|
||||
return_code = process.wait()
|
||||
@ -234,8 +234,10 @@ async def run_commend(data: str, project: str,
|
||||
train = ProjectTrain()
|
||||
train.project_id = project_id
|
||||
train.train_version = name
|
||||
bast_pt_path = os.file_path(project, name, 'weights', 'best.pt')
|
||||
last_pt_path = os.file_path(project, name, 'weights', 'last.pt')
|
||||
train_url = os.file_path(project, name)
|
||||
train.train_url = train_url
|
||||
bast_pt_path = os.file_path(train_url, 'weights', 'best.pt')
|
||||
last_pt_path = os.file_path(train_url, 'weights', 'last.pt')
|
||||
train.best_pt = bast_pt_path
|
||||
train.last_pt = last_pt_path
|
||||
if weights != None and weights != '':
|
||||
@ -275,19 +277,76 @@ def operate_img_label(img_list: List[ProjectImgLabel],
|
||||
+ image_label.mark_height + '\n')
|
||||
|
||||
|
||||
def split_array(data):
|
||||
def get_train_result(train_id: int, session: Session):
|
||||
"""
|
||||
将数组按照3:1的比例切分成两个数组。
|
||||
:param data: 原始数组
|
||||
:return: 按照3:1比例分割后的两个数组
|
||||
根据result.csv文件查询训练报告
|
||||
:param train_id:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
total_length = len(data)
|
||||
if total_length < 4:
|
||||
raise ValueError("数组长度至少需要为4才能进行3:1的分割")
|
||||
# 计算分割点
|
||||
split_index = (total_length * 3) // 4
|
||||
# 使用切片分割数组
|
||||
part1 = data[:split_index]
|
||||
part2 = data[split_index:]
|
||||
return part1, part2
|
||||
train_info = ptc.get_train(train_id, session)
|
||||
if train_info is None:
|
||||
return None
|
||||
result_csv_path = os.file_path(train_info.train_url, 'results.csv')
|
||||
result_row = read_csv(result_csv_path)
|
||||
report_data = {}
|
||||
# 轮数
|
||||
epoch_data = []
|
||||
# 边界框回归损失(Bounding Box Loss),衡量预测框位置(中心坐标、宽高)与真实框的差异,值越低表示定位越准。
|
||||
train_box_loss = []
|
||||
# 目标置信度损失(Objectness Loss),衡量检测到目标的置信度误差(即是否包含物体),值越低表示模型越能正确判断有无物体。
|
||||
train_obj_loss = []
|
||||
# 分类损失(Classification Loss),衡量预测类别与真实类别的差异,值越低表示分类越准。
|
||||
train_cls_loss = []
|
||||
|
||||
# 验证集的边界框回归损失,反映模型在未见数据上的定位能力。
|
||||
val_box_loss = []
|
||||
# 验证集的目标置信度损失,反映模型在未见数据上判断物体存在的能力。
|
||||
val_obj_loss = []
|
||||
# 验证集的分类损失,反映模型在未见数据上的分类准确性。
|
||||
val_cls_loss = []
|
||||
|
||||
# 精确率(Precision):正确检测的正样本占所有预测为正样本的比例,反映“误检率”。值越高说明误检越少。
|
||||
m_p = []
|
||||
# 召回率(Recall):正确检测的正样本占所有真实正样本的比例,反映“漏检率”。值越高说明漏检越少。
|
||||
m_r = []
|
||||
|
||||
# 主干网络(Backbone)的学习率。
|
||||
x_lr0 = []
|
||||
# 检测头(Head)的学习率。
|
||||
x_lr1 = []
|
||||
|
||||
for row in result_row:
|
||||
epoch_data.append(row[0].strip())
|
||||
|
||||
train_box_loss.append(row[1].strip())
|
||||
train_obj_loss.append(row[2].strip())
|
||||
train_cls_loss.append(row[3].strip())
|
||||
|
||||
val_box_loss.append(row[8].strip())
|
||||
val_obj_loss.append(row[9].strip())
|
||||
val_cls_loss.append(row[10].strip())
|
||||
|
||||
m_p.append(row[4].strip())
|
||||
m_r.append(row[5].strip())
|
||||
|
||||
x_lr0.append(row[11].strip())
|
||||
x_lr1.append(row[12].strip())
|
||||
|
||||
report_data['epoch_data'] = epoch_data
|
||||
|
||||
report_data['train_box_loss'] = train_box_loss
|
||||
report_data['train_obj_loss'] = train_obj_loss
|
||||
report_data['train_cls_loss'] = train_cls_loss
|
||||
|
||||
report_data['val_box_loss'] = val_box_loss
|
||||
report_data['val_obj_loss'] = val_obj_loss
|
||||
report_data['val_cls_loss'] = val_cls_loss
|
||||
|
||||
report_data['m_p'] = m_p
|
||||
report_data['m_r'] = m_r
|
||||
|
||||
report_data['x_lr0'] = x_lr0
|
||||
report_data['x_lr1'] = x_lr1
|
||||
|
||||
return report_data
|
||||
|
Reference in New Issue
Block a user