import torch
import cv2
import numpy as np
import torch
import os
import importlib
from model.plugins.ModelBase import ModelBase
from loguru import logger

'''
class ModelManager_tmp():
    def __init__(self):
        print("ModelInit")

    def __del__(self):
        print("ModelManager DEL")

    def __preprocess_image(self,image, cfg, bgr2rgb=True):
        """图片预处理"""
        img, scale_ratio, pad_size = letterbox(image, new_shape=cfg['input_shape'])
        if bgr2rgb:
            img = img[:, :, ::-1]
        img = img.transpose(2, 0, 1)  # HWC2CHW
        img = np.ascontiguousarray(img, dtype=np.float32)
        return img, scale_ratio, pad_size

    def __draw_bbox(self,bbox, img0, color, wt, names):
        """在图片上画预测框"""
        det_result_str = ''
        for idx, class_id in enumerate(bbox[:, 5]):
            if float(bbox[idx][4] < float(0.05)):
                continue
            img0 = cv2.rectangle(img0, (int(bbox[idx][0]), int(bbox[idx][1])), (int(bbox[idx][2]), int(bbox[idx][3])),
                                 color, wt)
            img0 = cv2.putText(img0, str(idx) + ' ' + names[int(class_id)], (int(bbox[idx][0]), int(bbox[idx][1] + 16)),
                               cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
            img0 = cv2.putText(img0, '{:.4f}'.format(bbox[idx][4]), (int(bbox[idx][0]), int(bbox[idx][1] + 32)),
                               cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
            det_result_str += '{} {} {} {} {} {}\n'.format(
                names[bbox[idx][5]], str(bbox[idx][4]), bbox[idx][0], bbox[idx][1], bbox[idx][2], bbox[idx][3])
        return img0

    def __get_labels_from_txt(self,path):
        """从txt文件获取图片标签"""
        labels_dict = dict()
        with open(path) as f:
            for cat_id, label in enumerate(f.readlines()):
                labels_dict[cat_id] = label.strip()
        return labels_dict

    def __draw_prediction(self,pred, image, labels):
        """在图片上画出预测框并进行可视化展示"""
        imgbox = widgets.Image(format='jpg', height=720, width=1280)
        img_dw = self.__draw_bbox(pred, image, (0, 255, 0), 2, labels)
        imgbox.value = cv2.imencode('.jpg', img_dw)[1].tobytes()
        display(imgbox)

    def __infer_image(self,img_path, model, class_names, cfg):
        """图片推理"""
        # 图片载入
        image = cv2.imread(img_path)
        # 数据预处理
        img, scale_ratio, pad_size = self.__preprocess_image(image, cfg)
        # 模型推理
        output = model.infer([img])[0]

        output = torch.tensor(output)
        # 非极大值抑制后处理
        boxout = nms(output, conf_thres=cfg["conf_thres"], iou_thres=cfg["iou_thres"])
        pred_all = boxout[0].numpy()
        # 预测坐标转换
        scale_coords(cfg['input_shape'], pred_all[:, :4], image.shape, ratio_pad=(scale_ratio, pad_size))
        # 图片预测结果可视化
        self.__draw_prediction(pred_all, image, class_names)

    def __infer_frame_with_vis(self,image, model, labels_dict, cfg, bgr2rgb=True):
        # 数据预处理
        img, scale_ratio, pad_size = self.__preprocess_image(image, cfg, bgr2rgb)
        # 模型推理
        output = model.infer([img])[0]

        output = torch.tensor(output)
        # 非极大值抑制后处理
        boxout = nms(output, conf_thres=cfg["conf_thres"], iou_thres=cfg["iou_thres"])
        pred_all = boxout[0].numpy()
        # 预测坐标转换
        scale_coords(cfg['input_shape'], pred_all[:, :4], image.shape, ratio_pad=(scale_ratio, pad_size))
        # 图片预测结果可视化
        img_vis = self.__draw_bbox(pred_all, image, (0, 255, 0), 2, labels_dict)
        return img_vis

    def __img2bytes(self,image):
        """将图片转换为字节码"""
        return bytes(cv2.imencode('.jpg', image)[1])
    def __infer_camera(self,model, labels_dict, cfg):
        """外设摄像头实时推理"""

        def find_camera_index():
            max_index_to_check = 10  # Maximum index to check for camera

            for index in range(max_index_to_check):
                cap = cv2.VideoCapture(index)
                if cap.read()[0]:
                    cap.release()
                    return index

            # If no camera is found
            raise ValueError("No camera found.")

        # 获取摄像头    --这里可以换成RTSP流
        camera_index = find_camera_index()
        cap = cv2.VideoCapture(camera_index)
        # 初始化可视化对象
        image_widget = widgets.Image(format='jpeg', width=1280, height=720)
        display(image_widget)
        while True:
            # 对摄像头每一帧进行推理和可视化
            _, img_frame = cap.read()
            image_pred = self.__infer_frame_with_vis(img_frame, model, labels_dict, cfg)
            image_widget.value = self.__img2bytes(image_pred)

    def __infer_video(self,video_path, model, labels_dict, cfg):
        """视频推理"""
        image_widget = widgets.Image(format='jpeg', width=800, height=600)
        display(image_widget)

        # 读入视频
        cap = cv2.VideoCapture(video_path)
        while True:
            ret, img_frame = cap.read()
            if not ret:
                break
            # 对视频帧进行推理
            image_pred = self.__infer_frame_with_vis(img_frame, model, labels_dict, cfg, bgr2rgb=True)
            image_widget.value = self.__img2bytes(image_pred)

    def startWork(self,infer_mode,file_paht = ""):
        cfg = {
            'conf_thres': 0.4,  # 模型置信度阈值,阈值越低,得到的预测框越多
            'iou_thres': 0.5,  # IOU阈值,高于这个阈值的重叠预测框会被过滤掉
            'input_shape': [640, 640],  # 模型输入尺寸
        }

        model_path = 'yolo.om'
        label_path = './coco_names.txt'
        # 初始化推理模型
        model = InferSession(0, model_path)
        labels_dict = self.__get_labels_from_txt(label_path)

        #执行验证
        if infer_mode == 'image':
            img_path = 'world_cup.jpg'
            self.__infer_image(img_path, model, labels_dict, cfg)
        elif infer_mode == 'camera':
            self.__infer_camera(model, labels_dict, cfg)
        elif infer_mode == 'video':
            video_path = 'racing.mp4'
            self.__infer_video(video_path, model, labels_dict, cfg)
'''

'''
算法实现类，实现算法执行线程，根据配置内容，以线程方式执行算法模块（当前 doWork 尚为空实现）
'''
class ModelManager:
    """Placeholder manager for running algorithm (model) modules.

    Only lifecycle logging is implemented: construction and finalization each
    print a line to stdout, and ``doWork`` does nothing yet.
    """

    def __init__(self):
        """Announce construction on stdout."""
        message = "ModelManager init"
        print(message)

    def __del__(self):
        """Announce finalization on stdout."""
        message = "ModelManager del"
        print(message)

    def doWork(self):
        """Run the managed work. Not implemented yet; always returns None."""
        return None

#动态导入文件 -- 方法二  -- 相对推荐使用该方法  但spec感觉没什么用
def import_source(spec, plgpath):
    module = None
    if os.path.exists(plgpath):
        module_spec = importlib.util.spec_from_file_location(spec, plgpath)
        module = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(module)
    else:
        logger.error("{}文件不存在".format(plgpath))
    return module

def run_plugin(plgpath, target, copy_flag=True):
    """Load a model plugin file, instantiate its ``Model`` class and run it.

    Args:
        plgpath: Path (a single string) to the plugin's Python source file.
            It is passed straight to ``import_source``.
        target: Currently unused; kept for interface compatibility.
        copy_flag: Currently unused; kept for interface compatibility.

    Returns:
        Whatever the plugin instance's ``doWork`` returns, or ``None`` if the
        plugin file could not be loaded.

    Raises:
        AttributeError: if the plugin module defines no ``Model`` attribute.
        TypeError: if ``Model()`` does not produce a ``ModelBase`` instance.
            TypeError is a subclass of Exception, so existing
            ``except Exception`` handlers still catch it.
    """
    module = import_source("", plgpath)
    if module is None:
        logger.error("模型加载失败")
        return None

    # Each plugin is expected to expose a callable named ``Model`` that builds
    # a ModelBase instance.
    plugin = getattr(module, "Model")()
    if not isinstance(plugin, ModelBase):
        raise TypeError("{} not rx_Model".format(plugin))

    # The four empty-string arguments are passed through as-is. Their meaning
    # is defined by ModelBase.doWork, which is not visible here (TODO confirm).
    return plugin.doWork("", "", "", "")

def test():
    """Smoke-run the RYRQ ACL model plugin through ``run_plugin``.

    The plugin path is relative, so it is resolved against the current
    working directory. The plugin's result is discarded.
    """
    plugin_file = "plugins/RYRQ_Model_ACL.py"
    run_plugin(plugin_file, "")


if __name__ == "__main__":
    test()