from tools.ToolBase import ToolBase
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.x509 import NameOID
import re
import json
import pexpect

class OpensslTool(ToolBase):
    def validate_instruction(self, instruction):
        #指令过滤
        timeout = 60*5
        return instruction,timeout

    def do_worker_pexpect(self, str_instruction, timeout, ext_params):
        """Run a shell command under pexpect and return its captured output.

        The command is executed as ``bash -c <str_instruction>`` and output is
        collected until the process reaches EOF or ``timeout`` seconds elapse.
        The child process is always closed afterwards, so a command that is
        still running after a timeout is terminated rather than left behind.

        Args:
            str_instruction: Command line passed verbatim to ``bash -c``.
            timeout: Maximum number of seconds to wait for the command.
            ext_params: Extra parameters from ``create_extparams``; not used here.

        Returns:
            The captured output as text. On timeout, the partial output followed
            by an "执行超时{timeout}秒" notice. If spawning or reading fails, an
            "执行错误: ..." message.
        """
        child = None
        try:
            # bash -c allows shell syntax (pipes, redirects) in the command.
            # codec_errors='replace' keeps stray non-UTF-8 bytes from raising a
            # UnicodeDecodeError that would throw away the whole capture.
            child = pexpect.spawn('bash', ['-c', str_instruction], timeout=timeout,
                                  encoding='utf-8', codec_errors='replace')
            # With a two-pattern list, expect() can only return 0 or 1.
            index = child.expect([pexpect.TIMEOUT, pexpect.EOF])
            result = str(child.before)
            if index == 0:
                result += f"\n执行超时{timeout}秒"
            return result
        except Exception as e:
            return f"执行错误: {str(e)}"
        finally:
            if child is not None:
                # force=True kills a command still running after a timeout.
                # Best-effort cleanup: a failure here must not mask the result.
                try:
                    child.close(force=True)
                except Exception:
                    pass

    def execute_instruction(self, instruction_old):
        ext_params = self.create_extparams()
        # 第一步:验证指令合法性
        instruction,time_out = self.validate_instruction(instruction_old)
        if not instruction:
            return False, instruction_old, "该指令暂不执行!","",ext_params
        # 过滤修改后的指令是否需要判重?同样指令再执行结果一致?待定---#?

        # 第二步:执行指令---需要对ftp指令进行区分判断
        #output = self.do_worker_script(instruction, time_out, ext_params)
        output = self.do_worker_pexpect(instruction, time_out, ext_params)

        # 第三步:分析执行结果
        analysis = self.analyze_result(output,instruction,"","")

        return True, instruction, analysis,output,ext_params


    def parse_name(self, name):
        """Group the attributes of an X.509 name by a fixed set of OIDs.

        Args:
            name: An ``x509.Name`` (any object exposing
                ``get_attributes_for_oid`` works).

        Returns:
            dict keyed, in this order, by the country, state/province,
            locality, organization, common-name and organizational-unit
            ``NameOID`` values, each mapped to whatever
            ``name.get_attributes_for_oid(oid)`` returns for it.
        """
        wanted_oids = (
            NameOID.COUNTRY_NAME,
            NameOID.STATE_OR_PROVINCE_NAME,
            NameOID.LOCALITY_NAME,
            NameOID.ORGANIZATION_NAME,
            NameOID.COMMON_NAME,
            NameOID.ORGANIZATIONAL_UNIT_NAME,
        )
        return {oid: name.get_attributes_for_oid(oid) for oid in wanted_oids}

    def parse_ssl_info(self,output):
        # 提取证书内容
        certs = re.findall(
            r'-----BEGIN CERTIFICATE-----(.*?)-----END CERTIFICATE-----',
            output,
            re.DOTALL
        )

        results = []
        cert_obj = None
        for cert in certs:
            cert_data = "-----BEGIN CERTIFICATE-----" + cert + "-----END CERTIFICATE-----"
            try:
                cert_obj = x509.load_pem_x509_certificate(cert_data.encode(), default_backend())
            except ValueError as e:
                print(f"证书加载失败:{str(e)}")
                continue

            san_list = []
            try:
                san_ext = cert_obj.extensions.get_extension_for_class(x509.SubjectAlternativeName)
                san_list = san_ext.value.get_values_for_type(x509.DNSName)
            except x509.ExtensionNotFound:
                pass

            if cert_obj:
                results.append({
                    'subject': str(cert_obj.subject),
                    'issuer': str(cert_obj.issuer),
                    'san': str(san_list),
                    'validity': {
                        'start': str(cert_obj.not_valid_before),
                        'end': str(cert_obj.not_valid_after)
                    },
                    'signature_algorithm': str(cert_obj.signature_algorithm_oid._name)
                })

        return results

    def analyze_result(self, result,instruction,stderr,stdout):
        #指令结果分析
        if len(result) > 3000:
            result = self.parse_ssl_info(stdout)
            result = json.dumps(result,ensure_ascii=False)
        return result