#!/usr/bin/python3
# Copyright 2025 VMware by Broadcom, Inc. All rights reserved.

import base64
import json
import logging
import shlex
import subprocess
import sys
import traceback

import requests
import yaml

# Module-level logger; logger_setup() attaches its handler and sets its level.
logger = logging.getLogger(__name__)
# Scratch directory for temporary files written by this script (e.g. the
# kubectl command script used by ClusterClient.run_cmd).
# NOTE(review): the directory is not created in the visible code — presumably
# created elsewhere before use; confirm.
TMP_DIR = "/tmp/update-ca"

class CustomLogFormatter(logging.Formatter):
    """Colorizing formatter producing ``<module>[<LEVEL>]: <message>`` lines.

    Each level gets an ANSI color prefix (INFO has none) and a reset suffix.
    ``WARNING`` is displayed as ``WARN`` and ``CRITICAL`` as ``ERROR``.
    """
    grey   = "\x1b[90m"
    yellow = "\x1b[93m"
    red    = "\x1b[91m"
    reset  = "\x1b[0m"
    # Base layout. Not called ``format`` so the format() method below does
    # not shadow it.
    base_format = '%(module)s[%(levelname)s]: %(message)s'

    FORMATS = {
        logging.DEBUG: grey + base_format + reset,
        logging.INFO: base_format + reset,
        logging.WARNING: yellow + base_format + reset,
        logging.ERROR: red + base_format + reset,
        logging.CRITICAL: red + base_format + reset
    }

    # One Formatter per level, built once instead of on every record.
    _FORMATTERS = {level: logging.Formatter(fmt) for level, fmt in FORMATS.items()}

    # Display names that differ from logging's standard level names.
    LEVEL_ALIASES = {'WARNING': 'WARN', 'CRITICAL': 'ERROR'}

    def format(self, record):
        """Return *record* rendered with its level's colored layout.

        The record is copied before its level name is rewritten, so other
        handlers that receive the same record still see the standard name.
        """
        shown = logging.makeLogRecord(record.__dict__)
        shown.levelname = self.LEVEL_ALIASES.get(record.levelname, record.levelname)
        formatter = self._FORMATTERS.get(record.levelno)
        if formatter is None:
            # Unknown/custom level: same result as logging.Formatter(None).
            formatter = logging.Formatter()
        return formatter.format(shown)

def logger_setup(log_dst="console", log_level='info', log_file='/common/logs/update-cert.log'):
    """Attach a CustomLogFormatter handler to the module logger and set its level.

    Args:
        log_dst: ``"console"`` for a StreamHandler (stderr) or ``"file"`` to
            append to *log_file*.
        log_level: ``debug``, ``info``, ``warning``, ``error`` or ``critical``
            (case-insensitive). Any other value leaves the level unchanged.
        log_file: path used when *log_dst* is ``"file"``.

    Raises:
        ValueError: if *log_dst* is neither ``"console"`` nor ``"file"``.
    """
    log_levels = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "warning": logging.WARNING,
        "error": logging.ERROR,
        "critical": logging.CRITICAL,
    }

    if log_dst == "file":
        handler = logging.FileHandler(log_file)
    elif log_dst == "console":
        handler = logging.StreamHandler()
    else:
        # Fail with a clear message instead of an AttributeError on a None handler.
        raise ValueError("unsupported log destination %r, expected 'console' or 'file'"
                         % (log_dst,))

    handler.setFormatter(CustomLogFormatter())
    logger.addHandler(handler)

    if isinstance(log_level, str) and log_level.lower() in log_levels:
        logger.setLevel(log_levels[log_level.lower()])

class Cert:
    """A single PEM certificate identified by its openssl subject line.

    The content is stored with surrounding whitespace stripped. The subject
    is obtained lazily from ``openssl x509 -noout -subject`` and cached.
    """
    # Kept for backward compatibility only. get_dn() feeds the PEM to openssl
    # on stdin (GET_SUBJECT_ARGV) instead of writing a temporary file.
    GET_SUBJECT_CMD = "openssl x509 -noout -subject -in %s"
    GET_SUBJECT_ARGV = ("openssl", "x509", "-noout", "-subject")

    def __init__(self, cert_content):
        self.content = cert_content.strip()
        self.subject = None

    def get_dn(self):
        """Return the certificate's subject, computing and caching it on first use.

        The value is openssl's raw ``-subject`` output, including its trailing
        newline. If openssl cannot parse the certificate, the base64 of the
        PEM text is used instead, so the certificate still gets a distinct key.

        Raises:
            OSError: if the openssl binary cannot be executed.
        """
        if self.subject:
            return self.subject
        # Reading from stdin avoids a shared temp file under TMP_DIR. That
        # file had to exist beforehand and was overwritten by every call.
        proc = subprocess.run(self.GET_SUBJECT_ARGV,
                              input=self.content.encode('utf-8'),
                              stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if proc.returncode:
            error = proc.stderr.decode('utf-8', 'replace').strip()
            logger.warning("failed to extract certificate subject, err: %s, cert content"
                           ": %s" % (error, self.content))
            logger.warning("assignning cert base64 as dn of this certificate")
            self.subject = base64.b64encode(self.content.encode('utf-8')).decode('utf-8')
        else:
            self.subject = proc.stdout.decode('utf-8')
        return self.subject

    def get_content(self):
        """Return the stripped PEM text."""
        return self.content

class Certs:
    """An insertion-ordered PEM bundle keyed by certificate subject."""
    START_LINE = '-----BEGIN CERTIFICATE-----'

    def __init__(self, certs_content):
        """Parse *certs_content* into one Cert per BEGIN CERTIFICATE block.

        Text before the first BEGIN line is ignored. A later certificate with
        the same subject replaces an earlier one.
        """
        self.certs = {}
        logger.debug("cert content: %s" % certs_content)
        for cert in self._split_pem(certs_content):
            self.certs[cert.get_dn()] = cert
        self.merged = 0

    @classmethod
    def _split_pem(cls, content):
        """Return one Cert per BEGIN CERTIFICATE block of *content* (preamble dropped)."""
        return [Cert(cls.START_LINE + body) for body in content.split(cls.START_LINE)[1:]]

    def merge(self, new_cert_content):
        """Add or replace each certificate of *new_cert_content* by subject.

        *new_cert_content* may hold a single certificate or a chain. Each
        certificate is merged separately, so a chain matches the per-certificate
        entries produced when a dumped bundle is parsed again. Content is
        compared after Cert's whitespace stripping. The merged flag is set only
        when something actually changed.
        """
        # Text without any BEGIN line is still treated as one (possibly
        # unparsable) certificate, as before.
        new_certs = self._split_pem(new_cert_content) or [Cert(new_cert_content)]
        for new_cert in new_certs:
            dn = new_cert.get_dn()
            existing = self.certs.get(dn)
            if existing is None or existing.get_content() != new_cert.get_content():
                self.certs[dn] = new_cert
                self.merged = 1

    def is_merged(self):
        """Return 1 if merge() changed the bundle, else 0."""
        return self.merged

    def dump(self):
        """Return all certificates joined by newlines, in insertion order."""
        return '\n'.join(cert.get_content() for cert in self.certs.values())

class AirgapRepo:
    """The airgap image repository (fqdn) and the CA file used to trust it.

    Construction always returns one shared instance. The first call creates
    it, and every call re-runs __init__, so the latest fqdn/cafile win and
    the CA is reloaded.
    """
    # Set to True on the instance by a successful verify().
    verified = False

    def __new__(cls, fqdn, cafile):
        if not hasattr(cls, 'instance'):
            cls.instance = super(AirgapRepo, cls).__new__(cls)
        return cls.instance

    def __init__(self, fqdn, cafile):
        self.fqdn = fqdn
        self.cafile = cafile
        self.ca_plain_text = None
        self.ca_base64 = None
        self.load_ca()

    def load_ca(self):
        """Read cafile and cache its stripped text and its base64 encoding.

        If the file is missing or unreadable, the error is logged and both
        cached values keep their current values (None right after __init__).
        """
        try:
            with open(self.cafile, 'r', encoding='ascii') as file:
                plain = file.read().strip()
        except FileNotFoundError:
            logger.error("CA file %s is not found." % self.cafile)
            return
        except PermissionError:
            logger.error("You don't have permission to access CA file %s." % self.cafile)
            return
        self.ca_plain_text = plain
        self.ca_base64 = base64.b64encode(plain.encode('ascii')).decode('ascii')

    def get_base64_ca(self):
        return self.ca_base64

    def get_plain_ca(self):
        return self.ca_plain_text

    def get_fqdn(self):
        return self.fqdn

    def verify(self, timeout=30):
        """Probe ``https://<fqdn>`` using cafile as the only trust store.

        Args:
            timeout: seconds requests waits on the connection; without it an
                unresponsive repository would block forever.

        Returns:
            (True, "") when the response is ``ok`` (status below 400),
            otherwise (False, <HTTP reason>). Connection/TLS errors propagate.
        """
        response = requests.get("https://%s" % self.fqdn, verify=self.cafile,
                                timeout=timeout)
        if not response.ok:
            return False, response.reason
        self.verified = True
        return True, ""

    def has_verified(self):
        return self.verified

    def is_valid(self):
        """Like verify(), but any exception is returned as (False, repr(exc))."""
        try:
            return self.verify()
        except Exception as e:
            return False, repr(e)

    def dump(self):
        """Log the repository settings and the result of is_valid()."""
        logger.info("dumpping the repo: %s" % self.fqdn)
        logger.info("repo cafile: %s" % self.cafile)
        logger.info("repo ca plain text: %s" % self.ca_plain_text)
        logger.info("repo ca base64: %s" % self.ca_base64)
        logger.info("repo is_valid: %s, %s" % self.is_valid())

class Kbs:
    """HTTP client for the API served at ``http://<k8s_service_ip>:8888/api/v1``.

    Lookup failures are logged and end the process with exit status 1, so
    callers of the script can detect them.
    """

    def __init__(self, k8s_service_ip, timeout=60):
        self.k8s_service_ip = k8s_service_ip
        # Seconds to wait on each HTTP request. requests has no default
        # timeout, so an unresponsive endpoint would otherwise block forever.
        self.timeout = timeout

    def get_tkg_context(self, tkg_id):
        """Return TkgContext *tkg_id* (requested with plaintext=true) as parsed JSON.

        Exits with status 1 on any request, HTTP-status or JSON error.
        """
        url = "http://%s:8888/api/v1/tkgcontext/%s?plaintext=true" % (self.k8s_service_ip, tkg_id)
        try:
            response = requests.get(url, timeout=self.timeout)
            response.raise_for_status()
            tkg_context = response.json()
        except Exception as error:
            logger.error("Failed to get TkgContext, err: %s" % error)
            sys.exit(1)
        logger.info("Successfully get TkgContext")
        return tkg_context

    def get_mgmt_cluster(self, mc_name):
        """Return the management cluster entry whose clusterName is *mc_name*.

        Exits with status 1 if the list cannot be fetched or parsed, or if no
        entry matches.
        """
        url = "http://%s:8888/api/v1/managementclusters" % self.k8s_service_ip
        try:
            response = requests.get(url, timeout=self.timeout)
            response.raise_for_status()
            for mc in response.json():
                if mc['clusterName'] == mc_name:
                    return mc
        except Exception as error:
            logger.error("Failed to get management clusters, err: %s" % error)
            sys.exit(1)
        logger.error("Management cluster %s doesn't exist" % mc_name)
        sys.exit(1)

    def update_tkg_context(self, tkg_context):
        """PUT *tkg_context* back to the API.

        Exits with status 1 unless the response is 200. Request-level errors
        (connection, timeout) propagate to the caller.
        """
        tkg_id = tkg_context['id']
        url = "http://%s:8888/api/v1/tkgcontext/%s" % (self.k8s_service_ip, tkg_id)
        response = requests.put(url, json=tkg_context, timeout=self.timeout)
        response.encoding = 'utf-8'
        if response.status_code == 200:
            logger.info("Updated tkgcontext %s with response %s" % (tkg_id, response))
        else:
            logger.error("Failed to update tkgcontext %s, err: %s" % (tkg_id, response.text))
            sys.exit(1)

    def update_mgmt_cluster(self, cluster, repo, mc_kubeconfig_file):
        """Push the repo CA into management cluster *cluster*.

        Clusters with spec.topology go through cc_update(); others go through
        update(). Errors are logged with a traceback, not raised.
        """
        cluster_client = MgmtClusterClient(mc_kubeconfig_file, repo)
        try:
            # Update relevant CR instances based on Cluster type
            if cluster_client.IsClusterClass(cluster["clusterName"]):
                cluster_client.cc_update(cluster["clusterName"])
            else:
                cluster_client.update()
        except Exception as e:
            logger.critical("failed to update mgmt cluster configmap/crs with errors: %s" % repr(e))
            traceback.print_exc()

    def get_mgmtcluster_kubeconfig(self, mc_id):
        """Return the kubeconfig text of management cluster *mc_id*.

        Exits with status 1 on a request error or a non-2xx response, so an
        error body is never returned as a kubeconfig.
        """
        url = "http://%s:8888/api/v1/managementcluster/%s/kubeconfig" % (self.k8s_service_ip, mc_id)
        try:
            response = requests.get(url, timeout=self.timeout)
            response.raise_for_status()
        except Exception as error:
            logger.error("Failed to get management cluster Kubeconfig, err: %s" % error)
            sys.exit(1)
        logger.info("Successfully get management cluster Kubeconfig")
        return response.text

class ClusterClient:
    """Thin kubectl wrapper bound to a single kubeconfig file."""
    RUN_SH = "%s/cmd.sh" % TMP_DIR

    def __init__(self, kubeconfig_file):
        self.kubeconfig_file = kubeconfig_file

    def run_cmd(self, cmd):
        """Execute ``kubectl <cmd>`` through bash and return its stdout bytes.

        The full command line is first written to RUN_SH and then run by
        bash, so *cmd* may use shell features such as ``| base64 -d``.

        Raises:
            NameError: the command failed and stderr mentions "NotFound".
            RuntimeError: the command failed for any other reason.
        """
        full_cmd = "kubectl --kubeconfig %s --request-timeout 30s %s" % (
            self.kubeconfig_file, cmd)
        with open(self.RUN_SH, 'w') as script:
            script.write(full_cmd)

        logger.debug("running command: %s" % full_cmd)
        proc = subprocess.run(('bash ' + self.RUN_SH).split(),
                              stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                              close_fds=True)
        if proc.returncode:
            err_text = proc.stderr.decode('utf-8')
            # Callers catch NameError as the "resource does not exist" signal.
            failure = NameError if "NotFound" in err_text else RuntimeError
            raise failure("run \"%s\" failed, err: %s" % (full_cmd, err_text))
        logger.debug(proc.stdout.decode('utf-8'))
        return proc.stdout

    def get_json(self, cmd):
        """Return raw ``kubectl get <cmd> -o json`` output."""
        return self.run_cmd("get %s -o json" % cmd)

    def get_yaml(self, cmd):
        """Return raw ``kubectl get <cmd> -o yaml`` output."""
        return self.run_cmd("get %s -o yaml" % cmd)

    def apply_yaml(self, file):
        """Run ``kubectl apply -f <file>`` and return its output."""
        return self.run_cmd("apply -f %s" % file)

    def get_dict(self, cmd):
        """Return ``kubectl get <cmd>`` parsed from JSON into Python objects."""
        raw = self.get_json(cmd)
        return json.loads(raw)

    def get_secret_datavalues(self, ns, name):
        """Return the YAML-decoded ``values.yaml`` key of secret *name* in *ns*."""
        query = "get secret %s -n %s -o jsonpath={.data.\"values\\.yaml\"} | base64 -d" % (name, ns)
        raw = self.run_cmd(query)
        return yaml.safe_load(raw)

    def get_secret_data_value(self, ns, name):
        """Return the base64-decoded ``value`` key of secret *name* in *ns* as bytes."""
        query = "get secret %s -n %s -o jsonpath={.data.value} | base64 -d" % (name, ns)
        return self.run_cmd(query)

class TkgClusterClient(ClusterClient):
    """kubectl client that keeps kapp-controller's trusted CAs in sync with a repo CA.

    *repo* supplies the CA text through get_plain_ca() (presumably an
    AirgapRepo — confirm against callers).
    """
    PATCH_KAPP_CM_CMD = "patch cm -n tkg-system kapp-controller-config -p"
    GET_KAPP_CM_CMD = "cm -n tkg-system kapp-controller-config"
    RESTART_KAPP_DEPLOYMENT_CMD = "rollout restart deployment -n tkg-system kapp-controller"

    def __init__(self, kubeconfig_file, repo):
        # Let the base class own kubeconfig_file instead of duplicating it.
        super().__init__(kubeconfig_file)
        self.repo = repo

    def merge_certs(self, old):
        """Merge the repo CA into the PEM bundle *old*.

        Returns:
            The repo CA alone when *old* is None; the merged bundle when the
            repo CA was missing from, or different in, *old*; None when *old*
            already contains it unchanged (nothing to do).
        """
        certstr = self.repo.get_plain_ca()
        if old is None:
            return certstr
        certs = Certs(old)
        certs.merge(certstr)
        return certs.dump() if certs.is_merged() else None

    def generate_cm_cacerts_patch_json(self, orig):
        """Build the ``{"data":{"caCerts":...}}`` patch for kapp-controller-config.

        Returns None when *orig* already contains the repo CA. Double quotes
        are backslash-escaped because the caller embeds the JSON inside a
        double-quoted bash argument.
        """
        newcerts = self.merge_certs(orig)
        if newcerts is None:
            return None
        patch_dict = {"data": {"caCerts": newcerts}}
        return json.dumps(patch_dict, separators=(',', ':')).replace('"', '\\"')

    def get_kapp_controller_cm(self):
        """Return configmap kapp-controller-config/tkg-system as a dict, or None on error (logged)."""
        try:
            return self.get_dict(self.GET_KAPP_CM_CMD)
        except Exception as e:
            logger.critical("failed to get configmap kapp-controller-config/tkg-system, err: %s" % repr(e))
            traceback.print_exc()
            return None

    def update_kapp_controller_cm(self):
        """Add the repo CA to kapp-controller-config's caCerts and restart kapp-controller.

        Does nothing when the configmap cannot be read or is already up to date.
        """
        kapp_cm = self.get_kapp_controller_cm()
        if kapp_cm is None:
            return
        data = kapp_cm.get('data') or {}
        patch_cm_json = self.generate_cm_cacerts_patch_json(data.get('caCerts'))

        if patch_cm_json is None:
            logger.info("cluster kapp-controller-config is up to date, skip")
            return

        cmd = "%s \"%s\"" % (self.PATCH_KAPP_CM_CMD, patch_cm_json)
        self.run_cmd(cmd)
        self.run_cmd(self.RESTART_KAPP_DEPLOYMENT_CMD)
        logger.info("update cluster kapp-controller-config successfully")

    def verify_kapp_controller_cm(self):
        """Log whether kapp-controller-config's caCerts already contains the repo CA."""
        kapp_cm = self.get_kapp_controller_cm()
        if kapp_cm is None:
            return
        if 'data' in kapp_cm and 'caCerts' in kapp_cm['data']:
            if self.merge_certs(kapp_cm['data']['caCerts']) is None:
                logger.info("configmap kapp-controller-config/tkg-system: up to date")
                return
        logger.error("configmap kapp-controller-config/tkg-system: out of date")

class MgmtClusterClient(TkgClusterClient):
    # kubectl fragments for configmap tkr-controller-config. It is handled
    # either in tkr-system (restart tkr-controller-manager) or in tkg-system
    # (restart tkr-source-controller-manager and
    # tkr-vsphere-resolver-webhook-manager). update_tkr_controller_cm probes
    # tkg-system first. GET_* values are resource arguments for get_dict(),
    # which adds "get" and "-o json".
    PATCH_TKR_CM_CMD_TKR_SYSTEM = "patch cm -n tkr-system tkr-controller-config -p"
    GET_TKR_CM_CMD_TKR_SYSTEM = "cm -n tkr-system tkr-controller-config"
    RESTART_TKR_DEPLOYMENT_CMD_TKR_SYSTEM = "rollout restart deployment -n tkr-system tkr-controller-manager"
    PATCH_TKR_CM_CMD_TKG_SYSTEM = "patch cm -n tkg-system tkr-controller-config -p"
    GET_TKR_CM_CMD_TKG_SYSTEM = "cm -n tkg-system tkr-controller-config"
    RESTART_TKR_SOURCE_DEPLOYMENT_CMD_TKG_SYSTEM = "rollout restart deployment -n tkg-system tkr-source-controller-manager"
    RESTART_TKR_VSPHERE_DEPLOYMENT_CMD_TKG_SYSTEM = "rollout restart deployment -n tkg-system tkr-vsphere-resolver-webhook-manager"
    # Cluster-wide listings fetched once and cached by get_kcp / get_kcts.
    GET_ALL_KCP_CMD = "kcp -A"
    GET_ALL_KCT_CMD = "kubeadmconfigtemplate -A"
    # Header prepended to the kapp-controller addon secret's values.yaml
    # before it is re-encoded (see update_kapp_secret).
    # NOTE(review): these look like ytt data-values/overlay annotations — confirm.
    TKG_ADDON_SECRET_HEADER = "#@data/values\n#@overlay/match-child-defaults missing_ok=True\n---\n"

    def __init__(self, kubeconfig_file, repo):
        """Create the client; kcp/kct listings are fetched lazily on first use.

        self.kcps caches the ``kcp -A`` items and self.kcts the
        ``kubeadmconfigtemplate -A`` items (see get_kcp / get_kcts).
        """
        super(MgmtClusterClient, self).__init__(kubeconfig_file, repo)
        self.kcps = self.kcts = None

    def IsClusterClass(self, clusterName):
        """Return True when Cluster *clusterName* in tkg-system has a spec.topology section."""
        cluster_cr = self.get_dict("cluster -n tkg-system %s" % clusterName)
        return "topology" in cluster_cr["spec"]

    def update_tkg_pkg_secret(self):
        """Point the CA fields of secret tkg-pkg-tkg-system-values at the repo CA.

        The secret's tkgpackagevalues.yaml is decoded. Each of
          configvalues.TKG_CUSTOM_IMAGE_REPOSITORY_CA_CERTIFICATE,
          configvalues.CUSTOM_TDNF_REPOSITORY_CERTIFICATE,
          tkrSourceControllerPackage.tkrSourceControllerPackageValues.caCerts
        that is present and differs from the repo's base64 CA is replaced.
        Absent fields are not added. The secret is patched only if something
        changed. A missing secret is skipped without logging.

        Raises:
            Exception: re-raised (after logging) if the values cannot be read.
        """
        secretName = "tkg-pkg-tkg-system-values"
        ns = "tkg-system"
        try:
            _ = self.get_dict("secret -n %s %s" % (ns, secretName))
        except NameError:
            # run_cmd maps kubectl NotFound to NameError: nothing to update.
            return
        get_cmd = "get secret %s -n %s -o jsonpath={.data.\"tkgpackagevalues\\.yaml\"} | base64 -d" % (secretName, ns)
        try:
            values_dict = yaml.safe_load(self.run_cmd(get_cmd))
        except Exception as e:
            logger.critical("failed to get tkgpackagevalues.yaml from secret [%s] in namespace [%s], err: %s"
                            % (secretName, ns, repr(e)))
            traceback.print_exc()
            raise
        expected = self.repo.get_base64_ca()
        need_patch = False
        for path in (('configvalues', 'TKG_CUSTOM_IMAGE_REPOSITORY_CA_CERTIFICATE'),
                     ('configvalues', 'CUSTOM_TDNF_REPOSITORY_CERTIFICATE'),
                     ('tkrSourceControllerPackage', 'tkrSourceControllerPackageValues', 'caCerts')):
            parent = values_dict
            for key in path[:-1]:
                parent = parent.get(key) if isinstance(parent, dict) else None
            leaf = path[-1]
            # Missing, null or non-mapping sections (an empty values file
            # loads as None) are skipped instead of raising TypeError.
            if isinstance(parent, dict) and leaf in parent and parent[leaf] != expected:
                parent[leaf] = expected
                need_patch = True
        if not need_patch:
            logger.info("secret %s/%s is up to date, skip" % (secretName, ns))
            return
        values = yaml.dump(values_dict)
        logger.debug("tkg-pkg secret new values: %s", values)
        base64_values = base64.b64encode(values.encode('utf-8')).decode('utf-8')
        patch_secret_dict = {'data': {'tkgpackagevalues.yaml': base64_values}}
        # Double quotes are escaped because the JSON is embedded in a
        # double-quoted bash argument.
        patch_secret_json = json.dumps(patch_secret_dict,
                                       separators=(',', ':')).replace('"', '\\"')
        patch_cmd = "patch secret %s -n %s --type=merge -p \"%s\"" % (
                    secretName, ns, patch_secret_json)
        self.run_cmd(patch_cmd)
        logger.info("update secret [%s] in namespace [%s] successfully"
                    % (secretName, ns))

    def verify_tkg_pkg_secret(self):
        """Log whether the CA fields of tkg-pkg-tkg-system-values match the repo CA.

        Checks, when present, configvalues.TKG_CUSTOM_IMAGE_REPOSITORY_CA_CERTIFICATE,
        configvalues.CUSTOM_TDNF_REPOSITORY_CERTIFICATE and
        tkrSourceControllerPackage.tkrSourceControllerPackageValues.caCerts
        against the repo's base64 CA. A missing secret is skipped.

        Raises:
            Exception: re-raised (after logging) if the values cannot be read.
        """
        secretName = "tkg-pkg-tkg-system-values"
        ns = "tkg-system"
        try:
            _ = self.get_dict("secret -n %s %s" % (ns, secretName))
        except NameError:
            return
        get_cmd = "get secret %s -n %s -o jsonpath={.data.\"tkgpackagevalues\\.yaml\"} | base64 -d" % (secretName, ns)
        try:
            values_dict = yaml.safe_load(self.run_cmd(get_cmd))
        except Exception as e:
            logger.critical("failed to get tkgpackagevalues.yaml from secret [%s] in namespace [%s], err: %s"
                            % (secretName, ns, repr(e)))
            traceback.print_exc()
            raise
        expected = self.repo.get_base64_ca()
        for path in (('configvalues', 'TKG_CUSTOM_IMAGE_REPOSITORY_CA_CERTIFICATE'),
                     ('configvalues', 'CUSTOM_TDNF_REPOSITORY_CERTIFICATE'),
                     ('tkrSourceControllerPackage', 'tkrSourceControllerPackageValues', 'caCerts')):
            parent = values_dict
            for key in path[:-1]:
                parent = parent.get(key) if isinstance(parent, dict) else None
            leaf = path[-1]
            # Null or non-mapping sections are treated as "field absent".
            if isinstance(parent, dict) and leaf in parent and parent[leaf] != expected:
                logger.error("secret %s/%s: out of date" % (secretName, ns))
                return
        logger.info("secret %s/%s: up to date" % (secretName, ns))

    def update_tkr_controller_secrets(self):
        """Set caCerts in the values.yaml of the TKR controller value secrets.

        Handles tkr-source-controller-values and tkr-vsphere-resolver-values in
        tkg-system. A secret is skipped when it does not exist, has no
        top-level caCerts entry, or already holds the repo's base64 CA.
        Otherwise caCerts is replaced and the secret is patched.

        Raises:
            Exception: re-raised (after logging) if a secret's values cannot be read.
        """
        secretList = ["tkr-source-controller-values", "tkr-vsphere-resolver-values"]
        ns = "tkg-system"
        expected = self.repo.get_base64_ca()
        for secret_name in secretList:
            try:
                _ = self.get_dict("secret -n %s %s" % (ns, secret_name))
            except NameError:
                logger.info("secret %s/%s not found, skip" % (secret_name, ns))
                continue
            try:
                values_dict = self.get_secret_datavalues(ns, secret_name)
            except Exception as e:
                logger.critical("failed to get secret [%s] in namespace [%s], err: %s"
                                % (secret_name, ns, repr(e)))
                traceback.print_exc()
                raise
            if not isinstance(values_dict, dict) or 'caCerts' not in values_dict:
                # Some versions have no caCerts setting here (e.g. tkg252
                # v1.28.11 tkr-vsphere-resolver-values). Patching would only
                # re-serialize the unchanged values, so leave the secret alone.
                logger.info("caCerts doesn't exist in secret [%s] in namespace [%s], skip"
                            % (secret_name, ns))
                continue
            if values_dict['caCerts'] == expected:
                logger.info("secret %s/%s is up to date, skip" % (secret_name, ns))
                continue
            values_dict['caCerts'] = expected
            values = yaml.dump(values_dict)
            logger.debug("tkr secret new values: %s", values)
            base64_values = base64.b64encode(values.encode('utf-8')).decode('utf-8')
            patch_secret_dict = {'data': {'values.yaml': base64_values}}
            # Escape double quotes: the JSON goes inside a double-quoted bash argument.
            patch_secret_json = json.dumps(patch_secret_dict,
                                           separators=(',', ':')).replace('"', '\\"')
            patch_cmd = "patch secret %s -n %s --type=merge -p \"%s\"" % (
                        secret_name, ns, patch_secret_json)
            self.run_cmd(patch_cmd)
            logger.info("update secret [%s] in namespace [%s] successfully"
                        % (secret_name, ns))

    def verify_tkr_controller_secrets(self):
        """Log, per TKR controller value secret, whether caCerts equals the repo CA.

        Every secret in the list is checked. A missing secret, a read error or
        an absent caCerts entry only affects that secret's report, not the
        remaining ones.
        """
        secretList = ["tkr-source-controller-values", "tkr-vsphere-resolver-values"]
        ns = "tkg-system"
        expected = self.repo.get_base64_ca()
        for secret_name in secretList:
            try:
                _ = self.get_dict("secret -n %s %s" % (ns, secret_name))
            except NameError:
                continue
            try:
                values_dict = self.get_secret_datavalues(ns, secret_name)
            except Exception as e:
                logger.critical("failed to get secret [%s] in namespace [%s], err: %s"
                                % (secret_name, ns, repr(e)))
                traceback.print_exc()
                continue
            # tkg252 v1.28.11 MC tkr-vsphere-resolver-values has no such config
            if not isinstance(values_dict, dict) or 'caCerts' not in values_dict:
                logger.info("caCerts doesn't exist in secret [%s] in namespace [%s]" % (
                            secret_name, ns))
                continue
            if values_dict['caCerts'] == expected:
                logger.info("secret [%s] in namespace [%s]: up to date" % (
                            secret_name, ns))
            else:
                logger.error("secret [%s] in namespace [%s]: out of date" % (
                             secret_name, ns))

    def update_tkr_controller_cm(self):
        """Set tkr-controller-config's caCerts to the repo CA and restart its controllers.

        If the ConfigMap exists in tkg-system, it is patched there and the
        tkr-source-controller and tkr-vsphere-resolver deployments are
        restarted. Otherwise the tkr-system copy is patched and
        tkr-controller-manager is restarted. Restarts happen on every call.
        """
        patch_dict = {"data": {"caCerts": self.repo.get_plain_ca()}}
        # Escape double quotes: the JSON goes inside a double-quoted bash argument.
        patch_cm_json = json.dumps(patch_dict, separators=(',', ':')).replace('"', '\\"')
        try:
            _ = self.get_dict(self.GET_TKR_CM_CMD_TKG_SYSTEM)
        except NameError:
            # Not in tkg-system (NotFound): use the tkr-system layout.
            self.run_cmd("%s \"%s\"" % (self.PATCH_TKR_CM_CMD_TKR_SYSTEM, patch_cm_json))
            self.run_cmd(self.RESTART_TKR_DEPLOYMENT_CMD_TKR_SYSTEM)
        else:
            # Only the lookup decides the layout. A NotFound raised by the patch
            # or a restart must surface instead of triggering the fallback.
            self.run_cmd("%s \"%s\"" % (self.PATCH_TKR_CM_CMD_TKG_SYSTEM, patch_cm_json))
            self.run_cmd(self.RESTART_TKR_SOURCE_DEPLOYMENT_CMD_TKG_SYSTEM)
            self.run_cmd(self.RESTART_TKR_VSPHERE_DEPLOYMENT_CMD_TKG_SYSTEM)
        logger.info("update management cluster tkr-controller-config successfully")

    def verify_tkr_controller_cm(self):
        """Log whether tkr-controller-config's caCerts equals the repo CA.

        Looks in tkg-system first, then tkr-system. Any lookup failure,
        including the ConfigMap being absent from both namespaces, is logged
        and the check ends without raising.
        """
        try:
            try:
                tkr_cm = self.get_dict(self.GET_TKR_CM_CMD_TKG_SYSTEM)
            except NameError:
                tkr_cm = self.get_dict(self.GET_TKR_CM_CMD_TKR_SYSTEM)
        except Exception as e:
            # Also covers a failure of the tkr-system fallback lookup.
            logger.critical("failed to get configmap tkr-controller-config, err: %s" % repr(e))
            traceback.print_exc()
            return
        expected_ca = self.repo.get_plain_ca()
        if ('data' in tkr_cm and 'caCerts' in tkr_cm['data']
                and tkr_cm['data']['caCerts'].strip() == expected_ca):
            logger.info("configmap tkr-controller-config: up to date")
        else:
            logger.error("configmap tkr-controller-config: out of date")

    def get_kcp(self, ns):
        """Return the kcp in namespace *ns*, or the whole kcp list when *ns* is None.

        The ``kcp -A`` listing is fetched once and cached in self.kcps.
        Returns None when *ns* has no kcp. If it has several, the first listed
        one is returned.
        """
        if self.kcps is None:
            self.kcps = self.get_dict(self.GET_ALL_KCP_CMD)['items']
        if ns is None:
            return self.kcps
        return next((kcp for kcp in self.kcps if kcp['metadata']['namespace'] == ns), None)

    def update_kcp(self, ns):
        """Make the kcp in *ns* carry the repo CA as /etc/containerd/<fqdn>.crt.

        The entry for that path in spec.kubeadmConfigSpec.files is created, or
        rewritten in place, with base64 encoding and the repo's base64 CA. The
        resulting files list is then merge-patched back. Nothing is patched
        when the entry already holds the repo CA. A namespace without a kcp is
        logged and skipped.
        """
        kcp = self.get_kcp(ns)
        if kcp is None:
            logger.error("kubecontrolplane cr in namespace[%s] not found, skip" % ns)
            return
        files = kcp['spec']['kubeadmConfigSpec'].get('files') or []
        name = kcp['metadata']['name']
        ns = kcp['metadata']['namespace']
        path = "/etc/containerd/%s.crt" % self.repo.get_fqdn()
        expected = self.repo.get_base64_ca()
        # Use .get(): file entries are not guaranteed to carry every key
        # (e.g. 'encoding').
        entry = next((f for f in files if f.get('path') == path), None)
        if entry is None:
            files.append({'path': path, 'encoding': 'base64', 'content': expected})
        elif entry.get('encoding') == 'base64' and entry.get('content') == expected:
            logger.info("kubecontrolplane cr[%s/%s]: up to date" % (name, ns))
            return
        else:
            # Rewrite the existing entry rather than appending a second file
            # with the same path.
            entry['encoding'] = 'base64'
            entry['content'] = expected

        patch_dict = {'spec': {'kubeadmConfigSpec': {'files': files}}}
        patch_json = json.dumps(patch_dict, separators=(',', ':'))
        # The command runs through bash. shlex.quote single-quotes the JSON and
        # escapes any single quote inside other files' contents.
        patch_cmd = "patch kcp -n %s %s --type=merge -p %s" % (ns, name,
                    shlex.quote(patch_json))
        self.run_cmd(patch_cmd)
        logger.info("update kcp %s airgap cacert file content successfully" % name)

    def verify_kcp(self, ns):
        """Log whether the kcp in *ns* carries the repo CA file.

        Up to date means a spec.kubeadmConfigSpec.files entry at
        /etc/containerd/<fqdn>.crt with base64 encoding and the repo's base64
        CA as content. Lookup errors and a missing kcp are logged, not raised.
        """
        try:
            kcp = self.get_kcp(ns)
        except Exception as e:
            logger.critical("failed to get kubecontrolplane in namespace[%s], err: %s" % (ns, repr(e)))
            traceback.print_exc()
            return
        if kcp is None:
            logger.error("kubecontrolplane cr in namespace[%s]: not found" % ns)
            return
        name = kcp['metadata']['name']
        path = "/etc/containerd/%s.crt" % self.repo.get_fqdn()
        expected = self.repo.get_base64_ca()
        kubeadm_spec = kcp['spec'].get('kubeadmConfigSpec') or {}
        for file in kubeadm_spec.get('files') or []:
            # .get(): an entry may lack 'encoding'.
            if (file.get('path') == path and file.get('encoding') == 'base64'
                    and file.get('content') == expected):
                logger.info("kubecontrolplane cr[%s/%s]: up to date" % (name, ns))
                return
        logger.error("kubecontrolplane cr[%s/%s]: out of date" % (name, ns))

    def get_kcts(self, ns):
        """Return the kubeadmconfigtemplates in namespace *ns* (possibly an empty list).

        The ``kubeadmconfigtemplate -A`` listing is fetched once and cached in
        self.kcts.
        """
        if self.kcts is None:
            self.kcts = self.get_dict(self.GET_ALL_KCT_CMD)['items']
        return [kct for kct in self.kcts if kct['metadata']['namespace'] == ns]

    def update_kcts(self, ns):
        """Make each kubeadmconfigtemplate in *ns* carry the repo CA file.

        For every template except tkg-vsphere-default-v1.0.0-md-config, the
        spec.template.spec.files entry at /etc/containerd/<fqdn>.crt is
        created, or rewritten in place, with base64 encoding and the repo's
        base64 CA. The template is then merge-patched. Templates already
        holding the repo CA are not patched.
        """
        kcts = self.get_kcts(ns)
        path = "/etc/containerd/%s.crt" % self.repo.get_fqdn()
        expected = self.repo.get_base64_ca()

        for kct in kcts:
            name = kct['metadata']['name']
            if name == 'tkg-vsphere-default-v1.0.0-md-config':
                # This template is skipped.
                # NOTE(review): the reason is not visible here — confirm.
                continue
            kct_ns = kct['metadata']['namespace']
            files = kct['spec']['template']['spec'].get('files') or []
            # .get(): file entries may lack 'encoding'.
            entry = next((f for f in files if f.get('path') == path), None)
            if entry is None:
                files.append({'path': path, 'encoding': 'base64', 'content': expected})
            elif entry.get('encoding') == 'base64' and entry.get('content') == expected:
                logger.info("kubeadmconfigtemplate cr[%s/%s]: up to date" % (name, kct_ns))
                continue
            else:
                # Rewrite in place instead of adding a duplicate path entry.
                entry['encoding'] = 'base64'
                entry['content'] = expected
            patch_dict = {'spec': {'template': {'spec': {'files': files}}}}
            patch_json = json.dumps(patch_dict, separators=(',', ':'))
            # shlex.quote keeps the bash command valid even if another file's
            # content contains a single quote.
            patch_cmd = "patch kubeadmconfigtemplate -n %s %s --type=merge -p %s" % (
                        kct_ns, name, shlex.quote(patch_json))
            self.run_cmd(patch_cmd)
            logger.info("update kubeadmconfigtemplate %s airgap cacert file content successfully" % name)

    def verify_kcts(self, ns):
        """Log, per kubeadmconfigtemplate in *ns*, whether it carries the repo CA file.

        Up to date means a spec.template.spec.files entry at
        /etc/containerd/<fqdn>.crt with base64 encoding and the repo's base64
        CA as content. tkg-vsphere-default-v1.0.0-md-config is not checked.
        Lookup errors are logged, not raised.
        """
        try:
            kcts = self.get_kcts(ns)
        except Exception as e:
            logger.critical("failed to get kubeadmconfigtemplates in namespace[%s], err: %s" % (ns, repr(e)))
            traceback.print_exc()
            return

        path = "/etc/containerd/%s.crt" % self.repo.get_fqdn()
        expected = self.repo.get_base64_ca()

        for kct in kcts:
            name = kct['metadata']['name']
            if name == 'tkg-vsphere-default-v1.0.0-md-config':
                continue
            kct_ns = kct['metadata']['namespace']
            files = kct['spec']['template']['spec'].get('files') or []
            # .get(): an entry may lack 'encoding'.
            if any(f.get('path') == path and f.get('encoding') == 'base64'
                   and f.get('content') == expected for f in files):
                logger.info("kubeadmconfigtemplate cr[%s/%s]: up to date"
                            % (name, kct_ns))
            else:
                logger.error("kubeadmconfigtemplate cr[%s/%s]: out of date"
                             % (name, kct_ns))

    def update_kapp_secret(self, ns):
        """Merge the repo CA into kappController.config.caCerts of <ns>-kapp-controller-addon.

        The secret's values.yaml is decoded and caCerts is merged with the repo
        CA via merge_certs(). If anything changed, the values are re-encoded
        with TKG_ADDON_SECRET_HEADER prepended and merge-patched back.

        Raises:
            Exception: re-raised (after logging) if the secret cannot be read.
        """
        secret_name = "%s-kapp-controller-addon" % ns
        try:
            values_dict = self.get_secret_datavalues(ns, secret_name)
        except Exception as e:
            logger.critical("failed to get secret [%s] in namespace [%s], err: %s"
                            % (secret_name, ns, repr(e)))
            traceback.print_exc()
            raise e
        kapp_config = values_dict['kappController']['config']
        current = kapp_config['caCerts'] if 'caCerts' in kapp_config else None
        new_cacerts = self.merge_certs(current)
        if new_cacerts is None:
            logger.info("caCerts in secret [%s] of namespace [%s] is up to date"
                        ", skip" % (secret_name, ns))
            return
        kapp_config['caCerts'] = new_cacerts
        values = self.TKG_ADDON_SECRET_HEADER + yaml.dump(values_dict)
        logger.debug("kapp secret new values: %s", values)
        encoded_values = base64.b64encode(values.encode('utf-8')).decode('utf-8')
        # Escape double quotes: the JSON goes inside a double-quoted bash argument.
        patch_json = json.dumps({'data': {'values.yaml': encoded_values}},
                                separators=(',', ':')).replace('"', '\\"')
        self.run_cmd("patch secret %s -n %s --type=merge -p \"%s\"" % (
                     secret_name, ns, patch_json))
        logger.info("update secret [%s] in namespace [%s] successfully"
                    % (secret_name, ns))

    def verify_kapp_secret(self, ns):
        """Log whether <ns>-kapp-controller-addon's caCerts already contains the repo CA.

        A read failure is logged and the check ends without raising.
        """
        secret_name = "%s-kapp-controller-addon" % ns
        try:
            values_dict = self.get_secret_datavalues(ns, secret_name)
        except Exception as e:
            logger.critical("failed to get secret [%s] in namespace [%s], err: %s"
                            % (secret_name, ns, repr(e)))
            traceback.print_exc()
            return
        kapp_config = values_dict['kappController']['config']
        # Up to date only when caCerts exists and merging the repo CA changes nothing.
        if 'caCerts' in kapp_config and self.merge_certs(kapp_config['caCerts']) is None:
            logger.info("secret [%s] in namespace [%s]: up to date" % (
                        secret_name, ns))
        else:
            logger.error("secret [%s] in namespace [%s]: out of date" % (
                         secret_name, ns))
            
    def cc_update(self, clusterName):
        """Update a ClusterClass-based management cluster to the current repository CA.

        First refreshes the kapp-controller configmap, the tkg package secret,
        and the tkr-controller secrets and configmap. It then applies the
        NodeConfig, using a fixed TDNF CA path. Finally it rewrites the CA
        embedded in the topology variables of Cluster ``clusterName`` (namespace
        tkg-system):

        * ``trust.additionalTrustedCAs[name=imageRepository].data``
        * ``customTDNFRepository.certificate``

        The Cluster is patched only when at least one of them differs from
        ``self.repo.get_base64_ca()``.

        :param clusterName: name of the Cluster object in namespace tkg-system.
        """
        self.update_kapp_controller_cm()
        self.update_tkg_pkg_secret()
        self.update_tkr_controller_secrets()
        self.update_tkr_controller_cm()
        self.update_nodeconfig("/etc/ssl/certs/private-tdnf-repository-ca.pem")
        GET_CLUSTER_CMD = "cluster -n tkg-system %s" % clusterName
        clusterCR = self.get_dict(GET_CLUSTER_CMD)
        # Loop-invariant: fetch the expected CA once instead of up to twice per
        # matching variable.
        expected_ca = self.repo.get_base64_ca()
        need_patch = False
        variables = clusterCR['spec']['topology']['variables']
        for var in variables:
            if var['name'] == 'trust' and 'additionalTrustedCAs' in var['value']:
                for ca in var['value']['additionalTrustedCAs']:
                    if ca['name'] == 'imageRepository' and ca['data'] != expected_ca:
                        ca['data'] = expected_ca
                        need_patch = True
            elif var['name'] == 'customTDNFRepository' and 'certificate' in var['value']:
                if var['value']['certificate'] != expected_ca:
                    var['value']['certificate'] = expected_ca
                    need_patch = True
        if need_patch:
            # A JSON merge patch replaces lists wholesale, so the full (edited)
            # variables list is sent.
            patch_dict = {'spec': {'topology': {'variables': variables}}}
            # NOTE(review): only double quotes are escaped for the double-quoted
            # shell argument. If run_cmd goes through a shell, a variable value
            # containing `$`, a backtick or a backslash would be interpreted.
            # Confirm run_cmd's quoting.
            patch_json = json.dumps(patch_dict, separators=(',', ':')).replace('"', '\\"')
            patch_cmd = "patch cluster -n tkg-system %s --type=merge -p \"%s\"" % (clusterName, patch_json)
            self.run_cmd(patch_cmd)
            logger.info("update clusterclass mgmt cluster [%s] in namespace tkg-system successfully" % clusterName)
        else:
            logger.info("clusterclass mgmt cluster [%s] in namespace tkg-system up to date" % clusterName)

    def cc_verify(self, clusterName):
        """Verify a ClusterClass-based management cluster against the repository CA.

        First checks the NodeConfig, the kapp-controller configmap, the tkg
        package secret, and the tkr-controller secrets and configmap. It then
        compares the CA carried in Cluster ``clusterName``'s topology variables
        (``trust.additionalTrustedCAs[name=imageRepository].data`` and
        ``customTDNFRepository.certificate``) with
        ``self.repo.get_base64_ca()``. An "out of date" error is logged for each
        mismatch, or "up to date" if there is none. Finally it verifies the
        cluster machines and nodes.

        :param clusterName: name of the Cluster object in namespace tkg-system.
        """
        verify_nodeconfig(self.repo, self)
        self.verify_kapp_controller_cm()
        self.verify_tkg_pkg_secret()
        self.verify_tkr_controller_secrets()
        self.verify_tkr_controller_cm()
        GET_CLUSTER_CMD = "cluster -n tkg-system %s" % clusterName
        clusterCR = self.get_dict(GET_CLUSTER_CMD)
        # Loop-invariant: fetch the expected CA once.
        expected_ca = self.repo.get_base64_ca()
        upToDate = True
        for var in clusterCR['spec']['topology']['variables']:
            if var['name'] == 'trust' and 'additionalTrustedCAs' in var['value']:
                for ca in var['value']['additionalTrustedCAs']:
                    if ca['name'] == 'imageRepository' and ca['data'] != expected_ca:
                        logger.error("cluster [%s] in namespace tkg-system: out of date" % clusterName)
                        upToDate = False
            elif var['name'] == 'customTDNFRepository' and 'certificate' in var['value']:
                if var['value']['certificate'] != expected_ca:
                    logger.error("cluster [%s] in namespace tkg-system: out of date" % clusterName)
                    upToDate = False
        if upToDate:
            logger.info("cluster [%s] in namespace tkg-system: up to date" % clusterName)

        self.verify_machine()

    def update(self):
        """Push the repository CA into every management-cluster resource that carries it.

        Runs, in order:

        1. The kapp-controller configmap, tkg package secret, and tkr-controller
           secret/configmap updates.
        2. The kcp and kcts updates for namespace tkg-system.
        3. The NodeConfig update, targeting ``/etc/ssl/certs/<fqdn>.pem``.
        """
        ssl_cert_file = "/etc/ssl/certs/%s.pem" % self.repo.get_fqdn()
        for step in (self.update_kapp_controller_cm,
                     self.update_tkg_pkg_secret,
                     self.update_tkr_controller_secrets,
                     self.update_tkr_controller_cm):
            step()
        namespace = "tkg-system"
        self.update_kcp(namespace)
        self.update_kcts(namespace)
        self.update_nodeconfig(ssl_cert_file)

    def verify(self):
        """Check every management-cluster resource that ``update`` modifies.

        Verifies the NodeConfig first, then the kapp-controller configmap, the
        tkg package secret, the tkr-controller secrets/configmap, and the
        kcp/kcts objects in tkg-system. Finally it checks that machines and
        nodes agree.
        """
        verify_nodeconfig(self.repo, self)
        for check in (self.verify_kapp_controller_cm,
                      self.verify_tkg_pkg_secret,
                      self.verify_tkr_controller_secrets,
                      self.verify_tkr_controller_cm):
            check()
        namespace = "tkg-system"
        self.verify_kcp(namespace)
        self.verify_kcts(namespace)
        self.verify_machine()
        
    def generate_nodeconfig_file(self, ssl_file, nodeconfig_yaml_file):
        """Write a NodeConfig manifest that injects the repository CA onto Linux nodes.

        The manifest (acm.vmware.com/v1alpha1 NodeConfig ``update-airgap-certs``
        in namespace tca-system, jobType ``profile``) carries a YAML
        ``fileInjection`` config. That config writes the same base64 CA content,
        with mode 0444, to ``/etc/containerd/<fqdn>.crt`` and to ``ssl_file``.

        :param ssl_file: second on-node path to receive the CA.
        :param nodeconfig_yaml_file: local path the manifest is written to.
        """
        containerd_cert_path = "/etc/containerd/%s.crt" % self.repo.get_fqdn()
        cert_payload = add_newline_char_cert(self.repo.get_plain_ca())

        injected_files = [
            {
                "path": target_path,
                "source": "base64Content",
                "fileMode": "0444",
                "content": cert_payload,
            }
            for target_path in (containerd_cert_path, ssl_file)
        ]
        # The NodeConfig expects its config as an embedded YAML string.
        embedded_config = yaml.dump({"fileInjection": injected_files},
                                    default_flow_style=False)

        manifest = {
            "apiVersion": "acm.vmware.com/v1alpha1",
            "kind": "NodeConfig",
            "metadata": {
                "name": "update-airgap-certs",
                "namespace": "tca-system",
            },
            "spec": {
                "config": embedded_config,
                "jobType": "profile",
                "nodeMatch": [{"kubernetes.io/os": "linux"}],
            },
        }
        with open(nodeconfig_yaml_file, 'w') as out:
            yaml.dump(manifest, out, default_flow_style=False)
        
    def update_nodeconfig(self, ssl_file):
        """Render the update-airgap-certs NodeConfig under TMP_DIR and apply it.

        :param ssl_file: on-node path that, besides the containerd cert path,
            receives the CA.
        """
        manifest_path = "%s/%s" % (TMP_DIR, "update-airgap-certs.yaml")
        self.generate_nodeconfig_file(ssl_file, manifest_path)
        self.apply_yaml(manifest_path)

    def save_cluster_kubeconfig(self, path, clusterName, ns):
        """Write the kubeconfig stored in secret ``<clusterName>-kubeconfig`` to ``path``.

        The secret value is parsed as YAML and re-dumped in block style. The
        file holds cluster credentials, so it is created (or truncated) with
        mode 0600 instead of the umask-derived default, which is usually
        world-readable.

        :param path: destination file path.
        :param clusterName: cluster whose kubeconfig secret is read.
        :param ns: namespace of that secret.
        """
        import os

        kubeconfig_secret_name = "%s-kubeconfig" % clusterName
        kubeconfig_str = self.get_secret_data_value(ns, kubeconfig_secret_name)
        kubeconfig_dict = yaml.safe_load(kubeconfig_str)
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'w') as file:
            # The O_CREAT mode only applies to newly created files; tighten a
            # pre-existing file as well.
            os.fchmod(file.fileno(), 0o600)
            yaml.dump(kubeconfig_dict, file, default_flow_style=False)
 
    def get_instance_replicas(self, cmd):
        """Return the sum of ``spec.replicas`` over the ``items`` returned by ``cmd``.

        :param cmd: resource query passed to ``get_dict`` (e.g. ``"kcp -n ns"``).
        :returns: total replica count; 0 when there are no items.
        """
        return sum(item['spec']['replicas'] for item in self.get_dict(cmd)['items'])
           
    def get_machines(self, ns):
        """Return the ``items`` list of ``machine`` objects in namespace ``ns``."""
        query = "machine -n %s" % ns
        result = self.get_dict(query)
        return result['items']
    
    def verify_machine(self):
        """Check that the tkg-system machines have joined as nodes in the expected number.

        Exits the process with status 1 in either of two cases:

        * a machine has no node of the same name (via ``compare_machine_node``);
        * the node count differs from the summed ``spec.replicas`` of the kcp
          and md objects in tkg-system.

        Otherwise logs that the cluster nodes are up to date.
        """
        ns = "tkg-system"
        machines = self.get_machines(ns)
        nodes = self.get_dict("node")['items']
        compare_machine_node(machines, nodes)

        expected = (self.get_instance_replicas("kcp -n %s" % ns)
                    + self.get_instance_replicas("md -n %s" % ns))
        if expected != len(nodes):
            logger.error("expected %s cluster node, but now there are %s, please check again later" % (expected, len(nodes)))
            # A bare sys.exit() reports success (status 0); signal failure instead.
            sys.exit(1)
        logger.info("cluster nodes: up to date")
  
def compare_machine_node(machines, nodes):
    """Exit with status 1 if any machine has no node of the same ``metadata.name``.

    :param machines: machine objects (dicts with ``metadata.name``).
    :param nodes: node objects (dicts with ``metadata.name``).
    :returns: None when every machine has a matching node.
    """
    # Set lookup instead of rescanning all nodes for every machine.
    node_names = {n['metadata']['name'] for n in nodes}
    for m in machines:
        machine_name = m['metadata']['name']
        if machine_name not in node_names:
            logger.error("machine %s might not join cluster. If it doesn't join cluster for a longer time, please ssh login and check if there is an exception in /var/log/cloud-init-output.log." % machine_name)
            # A bare sys.exit() reports success (status 0); signal failure instead.
            sys.exit(1)
    

class WrklClusterClient(ClusterClient):
    """Verification client for one workload cluster.

    The tkc, tkcp, tknp and machine objects describing the workload cluster
    are read through the inherited ClusterClient, built on
    ``mc_kubeconfig_file``. They are queried in the namespace named
    ``wName``. Nodes and NodeConfigs are read from the workload cluster itself
    through ``self.wc_client``, built on ``wc_kubeconfig_file``.
    """

    def __init__(self, mc_kubeconfig_file, wc_kubeconfig_file, wName, repo):
        super().__init__(mc_kubeconfig_file)
        self.mc_kubeconfig_file = mc_kubeconfig_file
        self.wName = wName
        self.repo = repo
        self.wc_client = ClusterClient(wc_kubeconfig_file)

    def verify(self):
        """Verify NodeConfig, control plane, node pools and machines of the workload cluster.

        The check is skipped, with a log line, when the TcaKubernetesCluster
        has no status or its phase is ``Deleting``.
        """
        cmd = "tkc -n %s %s" % (self.wName, self.wName)
        tkcCR = self.get_dict(cmd)
        # The object may carry no ``status`` key at all. Indexing would raise
        # KeyError instead of reaching the intended "no status" branch.
        status = tkcCR.get('status')
        if status is None:
            logger.error("workload cluster [%s] has no status field in TcaKubernetesCluster spec" % self.wName)
            return

        if status.get('phase') == "Deleting":
            logger.info("workload cluster [%s] is being deleted, ignore check" % self.wName)
            return

        verify_nodeconfig(self.repo, self.wc_client)
        self.verify_tkcp_status()
        self.verify_tknp_status()
        self.verify_machine()

    def _verify_provisioned(self, resource, kind):
        """Log, for each ``resource`` object in namespace ``wName``, whether its phase is Provisioned.

        :param resource: short resource name passed to ``get_dict`` (e.g. ``"tkcp"``).
        :param kind: human-readable kind used in the log messages.
        """
        items = self.get_dict("%s -n %s" % (resource, self.wName))['items']
        for item in items:
            name = item['metadata']['name']
            status = item.get('status')
            if status is None:
                logger.error("%s [%s] has no status field in spec" % (kind, name))
            elif status.get('phase') == "Provisioned":
                logger.info("update %s [%s] successfully" % (kind, name))
            else:
                logger.error("the phase of %s [%s] is %s" % (kind, name, status.get('phase')))

    def verify_tkcp_status(self):
        """Log the provisioning state of every TcaKubeControlPlane of the cluster."""
        self._verify_provisioned("tkcp", "TcaKubeControlPlane")

    def verify_tknp_status(self):
        """Log the provisioning state of every TcaNodePool of the cluster."""
        self._verify_provisioned("tknp", "TcaNodePool")

    def verify_machine(self):
        """Check that every machine in namespace ``wName`` has a node in the workload cluster."""
        cmd = "machine -n %s" % self.wName
        machines = self.get_dict(cmd)['items']
        nodes = self.wc_client.get_dict("node")['items']
        compare_machine_node(machines, nodes)

def verify_nodeconfig(repo, c_client):
    """Check that the update-airgap-certs NodeConfigs inject the current repository CA.

    For each NodeConfig in tca-system whose name starts with
    ``update-airgap-certs``, every ``fileInjection`` entry in its YAML
    ``spec.config`` must carry ``add_newline_char_cert(repo.get_plain_ca())``.
    The first mismatch is logged and the process exits with status 1. Each
    fully matching NodeConfig is logged as updated.

    :param repo: object providing ``get_plain_ca()``.
    :param c_client: cluster client used to list the NodeConfigs.
    """
    cmd = "nodeconfig -n tca-system"
    nodeConfigs = c_client.get_dict(cmd)['items']
    expected_content = add_newline_char_cert(repo.get_plain_ca())
    for nc in nodeConfigs:
        name = nc['metadata']['name']
        if not str(name).startswith("update-airgap-certs"):
            continue
        config_dict = yaml.safe_load(nc['spec']['config'])
        if any(f['content'] != expected_content for f in config_dict['fileInjection']):
            logger.error("nodeConfig [%s] is not updated yet" % name)
            # A bare sys.exit() reports success (status 0); signal failure instead.
            sys.exit(1)
        logger.info("nodeConfig [%s] is updated successfully" % name)

def add_newline_char_cert(ca_cert_plain_text):
    """Return the base64 (ASCII) encoding of ``ca_cert_plain_text`` plus a trailing newline.

    The CA certificate PEM content must end with a newline character;
    otherwise the certificate rehash function is affected. A newline is
    always appended, even if the input already ends with one.
    """
    pem_bytes = (ca_cert_plain_text + "\n").encode('ascii')
    encoded = base64.b64encode(pem_bytes)
    return encoded.decode('ascii')