from utils import os_utils as os
from . import models, crud, schemas
from utils.websocket_server import room_manager
from application.settings import yolo_url, detect_url
from apps.business.train import models as train_models
from utils.yolov5.utils.dataloaders import LoadStreams
from utils.yolov5.utils.torch_utils import select_device
from ultralytics.utils.plotting import Annotator, colors
from utils.yolov5.models.common import DetectMultiBackend
from apps.business.deepsort import service as deepsort_service
from utils.yolov5.utils.general import check_img_size, non_max_suppression, cv2, scale_boxes

import time
import torch
import asyncio
import subprocess
from sqlalchemy.ext.asyncio import AsyncSession


async def before_detect(
        detect_in: schemas.ProjectDetectLogIn,
        detect: models.ProjectDetect,
        train: train_models.ProjectTrain,
        db: AsyncSession,
        user_id: int):
    """
    Build and persist the log record for a new inference (detect) run.

    :param detect_in: request payload; ``pt_type == 'best'`` selects ``train.best_pt``,
        any other value selects ``train.last_pt``
    :param detect: the detect set being run; the log's version label is
        ``'v' + str(detect.detect_version + 1)`` (``detect`` itself is not modified here)
    :param train: training record supplying the weight paths, id and version
    :param db: async database session used to insert the log row
    :param user_id: id of the user starting the run
    :return: the newly created ``models.ProjectDetectLog`` instance
    """
    weights_file = train.best_pt if detect_in.pt_type == 'best' else train.last_pt
    output_folder = os.file_path(detect_url, detect.detect_no, 'detect')

    # Column values for the new log row, assigned in this order.
    column_values = {
        'detect_name': detect.detect_name,
        'detect_id': detect.id,
        'detect_version': 'v' + str(detect.detect_version + 1),  # label of the upcoming version
        'train_id': train.id,
        'train_version': train.train_version,
        'pt_type': detect_in.pt_type,
        'pt_url': weights_file,
        'folder_url': detect.folder_url,  # source image folder of the detect set
        'detect_folder_url': output_folder,
        'user_id': user_id,
    }

    new_log = models.ProjectDetectLog()
    for column, value in column_values.items():
        setattr(new_log, column, value)

    await crud.ProjectDetectLogDal(db).create_model(new_log)
    return new_log


def run_img_loop(
        weights: str,
        source: str,
        project: str,
        name: str,
        detect_id: int,
        is_gpu: str):
    """
    Synchronous entry point that runs ``run_detect_img`` on a fresh event loop.

    A new loop is created and installed as the current thread's event loop, so this
    can be called from a thread that has no loop of its own.

    :param weights: path of the weight (.pt) file
    :param source: folder containing the images to run inference on
    :param project: output folder for the inference results
    :param name: run/version name
    :param detect_id: id of the detect set
    :param is_gpu: the string ``'True'`` enables GPU inference
    :return: None
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(run_detect_img(weights, source, project, name, detect_id, is_gpu))
    finally:
        # Close the loop even if the coroutine raises, so its resources are not leaked.
        loop.close()


async def run_detect_img(
        weights: str,
        source: str,
        project: str,
        name: str,
        detect_id: int,
        is_gpu: str):
    """
    Run YOLOv5 ``detect.py`` in a subprocess and stream its console output to the detect room.

    Each non-blank output line that does not contain ``'yolo'`` is forwarded to the websocket
    room ``detect_<detect_id>``. After the process exits, ``'success'`` is sent for exit code 0
    and ``'error'`` for any other exit code.

    :param weights: path of the weight (.pt) file
    :param source: folder containing the images to run inference on
    :param project: output folder, passed as ``--project``
    :param name: run/version name, passed as ``--name``
    :param detect_id: id of the detect set; used to build the room name
    :param is_gpu: the string ``'True'`` adds ``--device=0``; otherwise no device flag is passed
    :return: None
    """
    yolo_path = os.file_path(yolo_url, 'detect.py')
    room = 'detect_' + str(detect_id)
    # NOTE(review): this message says "model training started" although this runs inference;
    # it is kept unchanged because clients may display or match on it.
    await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
    command = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project",
               project, "--save-txt", "--conf-thres", "0.6"]
    if is_gpu == 'True':
        command.append("--device=0")
    with subprocess.Popen(
            command,
            bufsize=1,  # line-buffered (valid in text mode)
            shell=False,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr so progress bars / errors are streamed too
            text=True,
            encoding='utf-8',
            errors='replace',  # a non-UTF-8 byte from the child must not abort the stream
    ) as process:
        # Read until EOF instead of looping on ``process.poll()``. A poll-driven loop can stop
        # while lines are still buffered in the pipe, dropping the end of the output (often
        # the final summary or traceback), and it spins on '' once the pipe is exhausted.
        for line in process.stdout:
            if line != '\n' and 'yolo' not in line:
                await room_manager.send_to_room(room, line + '\n')
        return_code = process.wait()
        if return_code != 0:
            await room_manager.send_to_room(room, 'error')
        else:
            await room_manager.send_to_room(room, 'success')


async def update_sql(db: AsyncSession, detect_id: int, log_id: int, project, name):
    """
    Advance a detect set to its next version and record its result files against a log.

    Increments ``detect_version`` of the detect set ``detect_id`` and saves it. Then, for
    every ``ProjectDetectFile`` of that set, it inserts a ``ProjectDetectLogFile`` linked to
    ``log_id``. The file path is ``os.file_path(project, name, <file_name>)``.

    :param db: async database session
    :param detect_id: id of the detect set to update
    :param log_id: id of the detect log the new file rows belong to
    :param project: output root folder of the inference run
    :param name: run/version folder name under ``project``
    :return: None
    """
    detect_dal = crud.ProjectDetectDal(db)
    record = await detect_dal.get_data(detect_id)
    record.detect_version += 1
    await detect_dal.put_data(data_id=detect_id, data=record)

    source_files = await crud.ProjectDetectFileDal(db).get_datas(
        limit=0,
        v_where=[models.ProjectDetectFile.detect_id == detect_id],
        v_return_objs=True,
        v_return_count=False,
    )

    def _as_log_file(src):
        # One log-file row per source file, pointing at its copy in the output folder.
        entry = models.ProjectDetectLogFile()
        entry.log_id = log_id
        entry.file_url = os.file_path(project, name, src.file_name)
        entry.file_name = src.file_name
        return entry

    await crud.ProjectDetectLogFileDal(db).create_models([_as_log_file(src) for src in source_files])


async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_gpu: str,
                          conf_thres: float = 0.25, iou_thres: float = 0.45, max_seconds: float = 600):
    """
    Run YOLOv5 inference on an RTSP stream and push annotated JPEG frames to a websocket room.

    Frames are processed only while the room ``detect_rtsp_<detect_id>`` is present (truthy)
    in ``room_manager.rooms``. The loop stops as soon as the room is gone, or once more than
    ``max_seconds`` have elapsed.

    :param weights_pt: path of the weight (.pt) file
    :param rtsp_url: stream address
    :param data: dataset yaml file passed to the model loader
    :param detect_id: id of the detect set; used to build the room name
    :param is_gpu: the string ``'True'`` selects ``cuda:0``, anything else selects the CPU
    :param conf_thres: NMS confidence threshold (default 0.25)
    :param iou_thres: NMS IoU threshold (default 0.45)
    :param max_seconds: maximum run time in seconds (default 600 = 10 minutes)
    :return: None
    """
    room = 'detect_rtsp_' + str(detect_id)
    # Pick the device once instead of calling select_device('cpu') and then overriding it.
    # NOTE(review): YOLOv5's select_device is believed to set the process-wide
    # CUDA_VISIBLE_DEVICES when given 'cpu'; avoiding that call on the GPU path keeps
    # unrelated GPU work in this process unaffected — confirm against the vendored utils.
    device = select_device('cuda:0' if is_gpu == 'True' else 'cpu')

    model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)

    stride, names, pt = model.stride, model.names, model.pt
    img_sz = check_img_size((640, 640), s=stride)

    dataset = LoadStreams(rtsp_url, img_size=img_sz, stride=stride, auto=pt, vid_stride=1)
    bs = len(dataset)

    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *img_sz))

    # Give the websocket client time to join the room. Use a non-blocking sleep so the
    # event loop running this coroutine is not stalled.
    await asyncio.sleep(3)

    start_time = time.monotonic()  # monotonic: immune to wall-clock adjustments

    for _path, im, im0s, _vid_cap, _s in dataset:
        if time.monotonic() - start_time > max_seconds:
            print(room, "已达到最大执行时间,结束推理。")
            break
        if not room_manager.rooms.get(room):
            print(room, '结束推理')
            break

        im = torch.from_numpy(im).to(model.device)
        im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
        im /= 255  # 0 - 255 to 0.0 - 1.0
        if len(im.shape) == 3:
            im = im[None]  # add batch dimension

        # Pure inference: no autograd bookkeeping is needed per frame.
        with torch.no_grad():
            pred = model(im, augment=False, visualize=False)
            pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=1000)

        for i, det in enumerate(pred):  # one entry per image in the batch
            im0 = im0s[i].copy()
            annotator = Annotator(im0, line_width=3, example=str(names))
            if len(det):
                # Rescale boxes from the network input size back to the original frame size.
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
                for *xyxy, conf, cls in reversed(det):
                    c = int(cls)
                    annotator.box_label(xyxy, f"{names[c]} {conf:.2f}", color=colors(c, True))

            ok, jpeg = cv2.imencode('.jpg', annotator.result())
            if ok:
                await room_manager.send_stream_to_room(room, jpeg.tobytes())


def run_rtsp_loop(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_gpu: str):
    """
    Synchronous entry point that runs ``run_detect_rtsp`` on a fresh event loop.

    A new loop is created and installed as the current thread's event loop, so this
    can be called from a thread that has no loop of its own.

    :param weights_pt: path of the weight (.pt) file
    :param rtsp_url: stream address
    :param data: dataset yaml file
    :param detect_id: id of the detect set
    :param is_gpu: the string ``'True'`` enables GPU inference
    :return: None
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(
            run_detect_rtsp(
                weights_pt,
                rtsp_url,
                data,
                detect_id,
                is_gpu
            )
        )
    finally:
        # Close the loop even if the coroutine raises, so its resources are not leaked.
        loop.close()


def run_deepsort_loop(
        detect_id: int,
        weights_pt: str,
        data: str,
        idx_to_class: dict,
        sort_type: str = 'video',
        video_path: str = None,
        rtsp_url: str = None
):
    """
    Synchronous entry point that runs ``deepsort_service.run_deepsort`` on a fresh event loop.

    A new loop is created and installed as the current thread's event loop, so this
    can be called from a thread that has no loop of its own.

    :param detect_id: id of the detect set
    :param weights_pt: path of the weight (.pt) file
    :param data: dataset yaml file
    :param idx_to_class: mapping forwarded unchanged to ``run_deepsort``
        (presumably class index -> class name; verify against the deepsort service)
    :param sort_type: tracking source type, default ``'video'``
    :param video_path: video file path, forwarded to ``run_deepsort``
    :param rtsp_url: stream address, forwarded to ``run_deepsort``
    :return: None
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(
            deepsort_service.run_deepsort(
                detect_id,
                weights_pt,
                data,
                idx_to_class,
                sort_type,
                video_path,
                rtsp_url
            )
        )
    finally:
        # Close the loop even if the coroutine raises, so its resources are not leaked.
        loop.close()