from utils import os_utils as osu
from application.settings import *
from . import schemas, models, crud
from utils.websocket_server import room_manager
from apps.business.project import models as proj_models, crud as proj_crud


import yaml
import asyncio
import subprocess
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession


async def before_train(proj_info: proj_models.ProjectInfo, db: AsyncSession):
    """
    Prepare the on-disk YOLOv5 dataset for the project's next training version.

    Creates ``<datasets_url>/<project_no>/v<train_version + 1>`` with
    ``images/{train,val}`` and ``labels/{train,val}`` sub-folders, writes the
    dataset yaml file into it, then copies every project image and writes its
    label file via ``operate_img_label``.

    :param proj_info: project information
    :param db: database session
    :return: tuple ``(yaml_file_path, runs_project_dir, version_name)``, i.e. the
        values later passed to train.py as ``--data``, ``--project`` and ``--name``
    """
    image_dal = proj_crud.ProjectImageDal(db)
    label_dal = proj_crud.ProjectLabelDal(db)
    image_model = proj_models.ProjectImage

    async def fetch_images(split):
        # Every image of this project for the given split, unpaged, as ORM objects.
        return await image_dal.get_datas(
            v_where=[image_model.project_id == proj_info.id, image_model.img_type == split],
            limit=0,
            v_return_count=False,
            v_return_objs=True)

    train_images = await fetch_images('train')
    val_images = await fetch_images('val')

    # The new version is one past the project's current training version.
    version = f"v{proj_info.train_version + 1}"
    root_dir = osu.create_folder(datasets_url, proj_info.project_no, version)

    # Parallel lists: label_ids[i] corresponds to label_names[i].
    label_ids, label_names = await label_dal.get_label_for_train(proj_info.id)

    train_img_dir = osu.create_folder(root_dir, 'images', 'train')
    val_img_dir = osu.create_folder(root_dir, 'images', 'val')
    train_label_dir = osu.create_folder(root_dir, 'labels', 'train')
    val_label_dir = osu.create_folder(root_dir, 'labels', 'val')

    yaml_file = osu.file_path(root_dir, proj_info.project_no + '.yaml')
    dataset_config = {
        'path': root_dir,
        'train': 'images/train',
        'val': 'images/val',
        'test': None,
        # Class index -> label name, in the order returned by get_label_for_train.
        'names': dict(enumerate(label_names)),
    }
    with open(yaml_file, 'w', encoding='utf-8') as fh:
        yaml.dump(dataset_config, fh, allow_unicode=True, default_flow_style=False)

    # Copy images and write label files: train split first, then val.
    splits = (
        (train_images, train_img_dir, train_label_dir),
        (val_images, val_img_dir, val_label_dir),
    )
    for images, img_dir, label_dir in splits:
        await operate_img_label(images, img_dir, label_dir, db, label_ids)

    return yaml_file, osu.file_path(runs_url, proj_info.project_no), version


async def operate_img_label(
        img_list: list[proj_models.ProjectImage],
        img_path: str,
        label_path: str,
        db: AsyncSession,
        label_id_list: list):
    """
    Copy each image into ``img_path`` and write its label file into ``label_path``.

    The i-th image is copied as ``image<i>`` (per the original author's note the
    file extension is kept by ``osu.copy_and_rename_file``). A matching
    ``image<i>.txt`` is written with one line per annotation:
    ``<class index> <center_x> <center_y> <width> <height>``, where the class
    index is the position of the annotation's label id in ``label_id_list``.

    :param img_list: images to export
    :param img_path: destination folder for the copied images
    :param label_path: destination folder for the label txt files
    :param db: database session
    :param label_id_list: label ids, ordered by class index
    :return: None
    :raises ValueError: if an annotation references a label id not in
        ``label_id_list``, or has a missing (None) coordinate
    """
    img_label_dal = proj_crud.ProjectImgLabelDal(db)

    # Build the label id -> class index map once instead of list.index() per
    # annotation. setdefault keeps the first occurrence, matching list.index().
    class_index = {}
    for idx, label_id in enumerate(label_id_list):
        class_index.setdefault(label_id, idx)

    for i, image in enumerate(img_list):
        file_name = 'image' + str(i)
        osu.copy_and_rename_file(image.image_url, img_path, file_name)

        img_label_list = await img_label_dal.get_img_label_list(image.id)
        label_txt_path = osu.file_path(label_path, file_name + '.txt')
        with open(label_txt_path, 'w', encoding='utf-8') as file:
            for image_label in img_label_list:
                index = class_index.get(image_label.label_id)
                if index is None:
                    raise ValueError(
                        f'label id {image_label.label_id!r} is not in label_id_list')
                coords = (image_label.mark_center_x, image_label.mark_center_y,
                          image_label.mark_width, image_label.mark_height)
                # Refuse to write "None" into a label file; str() accepts both
                # string- and number-typed coordinate columns.
                if any(c is None for c in coords):
                    raise ValueError(
                        f'image label of image {image.id!r} has a missing coordinate')
                file.write(str(index) + ' ' + ' '.join(str(c) for c in coords) + '\n')


def run_event_loop(
        data: str,
        project: str,
        name: str,
        train_in: schemas.ProjectTrainIn,
        project_id: int,
        train_info: models.ProjectTrain,
        is_gup: str):
    """
    Run the training coroutine ``run_commend`` to completion on a fresh event loop.

    The new loop is installed as the current thread's loop for the duration of
    the run. It is always closed afterwards, so repeated training runs do not
    leak event loops. Presumably this is invoked from a worker thread; confirm
    against the caller.

    :param data: dataset yaml path (train.py ``--data``)
    :param project: runs directory (train.py ``--project``)
    :param name: experiment / version name (train.py ``--name``)
    :param train_in: training request; its ``epochs`` and ``patience`` are used
    :param project_id: project id, used for the websocket room name
    :param train_info: previous training record, or None to train from scratch
    :param is_gup: the string 'True' enables GPU training (forwarded as ``is_gpu``)
    :return: None
    """
    loop_run = asyncio.new_event_loop()
    asyncio.set_event_loop(loop_run)
    try:
        loop_run.run_until_complete(run_commend(data, project, name, train_in.epochs, train_in.patience,
                                                project_id, train_info, is_gup))
    finally:
        # Finalize pending async generators, release the loop's resources and
        # detach the now-closed loop from this thread.
        try:
            loop_run.run_until_complete(loop_run.shutdown_asyncgens())
        finally:
            loop_run.close()
            asyncio.set_event_loop(None)


async def run_commend(
        data: str,
        project: str,
        name: str,
        epochs: int,
        patience: int,
        project_id: int,
        train_info: models.ProjectTrain,
        is_gpu: str):
    """
    Run the YOLOv5 train.py script in a child process and stream its output.

    Every output line (stdout and stderr merged), except blank lines and lines
    containing '0%', is sent to the websocket room ``train_<project_id>``. When
    the process exits, 'success' or 'error' is sent depending on its return code.

    :param data: dataset yaml file (``--data``)
    :param project: directory for training results (``--project``)
    :param name: experiment name (``--name``)
    :param epochs: number of training epochs (``--epochs``)
    :param patience: early-stopping patience (``--patience``)
    :param project_id: project id, used to build the room name
    :param train_info: previous training record; when not None its ``best_pt``
        is passed as ``--weights`` to continue from the earlier training
    :param is_gpu: the string 'True' adds ``--device=0``
    :return: None
    """
    yolo_path = osu.file_path(yolo_url, 'train.py')
    room = 'train_' + str(project_id)
    await room_manager.send_to_room(room, f"AiCheckV2.0: 模型训练开始,请稍等。。。\n")
    commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name,
               "--epochs=" + str(epochs), "--batch-size=8", "--exist-ok", "--patience=" + str(patience)]

    # Continue from the previous training's best weights, if any.
    if train_info is not None:
        commend.append("--weights=" + train_info.best_pt)

    # Use the first CUDA device when running in a GPU environment.
    if is_gpu == 'True':
        commend.append("--device=0")

    with subprocess.Popen(
            commend,
            bufsize=1,  # line-buffered (valid because text mode is on)
            shell=False,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr so progress/error output is streamed too
            text=True,
            encoding='utf-8',
            errors='replace',  # never abort streaming on undecodable child output
    ) as process:
        # Read until EOF instead of polling the exit status: polling loses lines
        # still buffered in the pipe when the child exits, and at EOF readline()
        # returns '' repeatedly, which would spam empty messages.
        # NOTE(review): readline() blocks this event loop while the child is
        # quiet; acceptable only if nothing else needs to run on this loop.
        for line in iter(process.stdout.readline, ''):
            if line != '\n' and '0%' not in line:
                await room_manager.send_to_room(room, line + '\n')

        return_code = process.wait()
        if return_code != 0:
            await room_manager.send_to_room(room, 'error')
        else:
            await room_manager.send_to_room(room, 'success')


async def add_train(
        db,
        project_id,
        name,
        project,
        data,
        train_in,
        user_id):
    """
    Bump the project's training version and persist a new ProjectTrain record.

    :param db: database session
    :param project_id: id of the trained project
    :param name: training version name; also the run sub-directory under ``project``
    :param project: directory holding the project's training runs
    :param data: dataset yaml file used for the training
    :param train_in: training request, or None; when None the weights,
        patience and epochs fields are left unset instead of crashing
    :param user_id: id of the user who started the training
    :return: None
    """
    await proj_crud.ProjectInfoDal(db).update_version(data_id=project_id)

    train = models.ProjectTrain()
    train.project_id = project_id
    train.train_version = name
    train_url = osu.file_path(project, name)
    train.train_url = train_url
    train.train_data = data
    train.user_id = user_id
    # Paths where train.py stores the best and last checkpoints of this run.
    train.best_pt = osu.file_path(train_url, 'weights', 'best.pt')
    train.last_pt = osu.file_path(train_url, 'weights', 'last.pt')
    # All request-derived fields live under the None guard; previously
    # patience/epochs were read unconditionally and raised AttributeError.
    if train_in is not None:
        train.weights_id = train_in.weights_id
        train.weights_name = train_in.weights_name
        train.patience = train_in.patience
        train.epochs = train_in.epochs
    await crud.ProjectTrainDal(db).create_model(data=train)