#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version        : 1.0
# @Create Time    : 2025/04/03 10:32
# @File           : crud.py
# @IDE            : PyCharm
# @desc           : 数据访问层

from sqlalchemy.ext.asyncio import AsyncSession
from core.crud import DalBase
from . import models, schemas
from utils import os_utils as os
from utils.csv_utils import read_csv


class ProjectTrainDal(DalBase):

    def __init__(self, db: AsyncSession):
        super(ProjectTrainDal, self).__init__()
        self.db = db
        self.model = models.ProjectTrain
        self.schema = schemas.ProjectTrainOut

    async def get_result(self, train_id: int):
        """
        查询训练报告
        """
        data = await self.get_data(data_id=train_id)
        if data is None:
            return None
        result_csv_path = os.file_path(data.train_url, 'results.csv')
        result_row = read_csv(result_csv_path)
        report_data = {}
        # 轮数
        epoch_data = []
        # 边界框回归损失(Bounding Box Loss),衡量预测框位置(中心坐标、宽高)与真实框的差异,值越低表示定位越准。
        train_box_loss = []
        # 目标置信度损失(Objectness Loss),衡量检测到目标的置信度误差(即是否包含物体),值越低表示模型越能正确判断有无物体。
        train_obj_loss = []
        # 分类损失(Classification Loss),衡量预测类别与真实类别的差异,值越低表示分类越准。
        train_cls_loss = []

        # 验证集的边界框回归损失,反映模型在未见数据上的定位能力。
        val_box_loss = []
        # 验证集的目标置信度损失,反映模型在未见数据上判断物体存在的能力。
        val_obj_loss = []
        # 验证集的分类损失,反映模型在未见数据上的分类准确性。
        val_cls_loss = []

        # 精确率(Precision):正确检测的正样本占所有预测为正样本的比例,反映“误检率”。值越高说明误检越少。
        m_p = []
        # 召回率(Recall):正确检测的正样本占所有真实正样本的比例,反映“漏检率”。值越高说明漏检越少。
        m_r = []

        # 主干网络(Backbone)的学习率。
        x_lr0 = []
        # 检测头(Head)的学习率。
        x_lr1 = []

        for row in result_row:
            epoch_data.append(row[0].strip())

            train_box_loss.append(row[1].strip())
            train_obj_loss.append(row[2].strip())
            train_cls_loss.append(row[3].strip())

            val_box_loss.append(row[8].strip())
            val_obj_loss.append(row[9].strip())
            val_cls_loss.append(row[10].strip())

            m_p.append(row[4].strip())
            m_r.append(row[5].strip())

            x_lr0.append(row[11].strip())
            x_lr1.append(row[12].strip())

        report_data['epoch_data'] = epoch_data

        report_data['train_box_loss'] = train_box_loss
        report_data['train_obj_loss'] = train_obj_loss
        report_data['train_cls_loss'] = train_cls_loss

        report_data['val_box_loss'] = val_box_loss
        report_data['val_obj_loss'] = val_obj_loss
        report_data['val_cls_loss'] = val_cls_loss

        report_data['m_p'] = m_p
        report_data['m_r'] = m_r

        report_data['x_lr0'] = x_lr0
        report_data['x_lr1'] = x_lr1

        return report_data