#!/usr/bin/python3
# Copyright 2023 VMware, Inc. All Rights Reserved.

'''
Functionality:

Update the CA certificate of an airgap repository on existing clusters.
Input: 1. airgap repository FQDN
       2. CA certificate file
       3. VC credentials

Output: a dump of the updated resources
'''

import argparse
import logging
import logging.handlers
import os
import json
import subprocess
import sys
import requests
import base64
import configparser
import yaml
import traceback
import pathlib
import psycopg2

# Module-wide logger; logger_setup() attaches its handler and sets its level.
logger = logging.getLogger(__name__)
# Scratch directory that main() creates; no other code shown here reads it.
TMP_DIR = "/tmp/update-ca"
# Directory containing this script; no other code shown here reads it.
CWD = os.path.dirname(os.path.realpath(__file__))
# Version string reported by the `version` subcommand.
CLI_VERSION = "v3.1.0"

def logger_setup(log_dst="console", log_level='info', log_file='/common/logs/update-cert.log'):
    """Attach one formatted handler to the module logger and set its level.

    Args:
        log_dst: "file" to write to *log_file*, "console" to write to stderr.
        log_level: one of "debug", "info", "warning", "error", "critical";
            any other value leaves the logger's current level unchanged.
        log_file: path used when *log_dst* is "file".

    Raises:
        ValueError: if *log_dst* is neither "file" nor "console".
    """
    log_levels = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "warning": logging.WARNING,
        "error": logging.ERROR,
        "critical": logging.CRITICAL,
    }

    if log_dst == "file":
        handler = logging.FileHandler(log_file)
    elif log_dst == "console":
        handler = logging.StreamHandler()
    else:
        # Fail with a clear message instead of calling setFormatter on None.
        raise ValueError("unsupported log destination: %r" % (log_dst,))

    handler.setFormatter(logging.Formatter('%(module)s[%(levelname)s]: %(message)s'))
    logger.addHandler(handler)

    if log_level in log_levels:
        logger.setLevel(log_levels[log_level])

def str_presenter(dumper, data):
    """YAML representer for ``str``: use literal block style for multi-line text.

    Strings that split into more than one line are emitted with ``style='|'``
    (keeps PEM certificates readable); everything else uses the dumper's default.
    """
    str_tag = 'tag:yaml.org,2002:str'
    is_multiline = len(data.splitlines()) > 1
    if not is_multiline:
        return dumper.represent_scalar(str_tag, data)
    return dumper.represent_scalar(str_tag, data, style='|')

def init_yaml():
    """Register ``str_presenter`` for ``str`` on the default and the safe YAML dumpers.

    The safe registration is what makes ``yaml.safe_dump`` use it.
    """
    registrars = (
        yaml.add_representer,
        yaml.representer.SafeRepresenter.add_representer,
    )
    for register in registrars:
        register(str, str_presenter)

class AirgapRepo:
    """The airgap repository being updated, together with its CA certificate.

    ``__new__`` hands back one shared instance per process. ``__init__`` still
    runs on every ``AirgapRepo(...)`` call and reloads the CA from *cafile*.
    """

    # True once verify() has reached the repository over HTTPS with the CA file.
    verified = False
    # Seconds verify() waits for the repository before giving up.
    request_timeout = 30

    def __new__(cls, fqdn, cafile):
        if not hasattr(cls, 'instance'):
            cls.instance = super(AirgapRepo, cls).__new__(cls)
        return cls.instance

    def __init__(self, fqdn, cafile):
        # The shared instance may still hold a verification result for a
        # different repo or CA file; forget it so the new target gets checked.
        if (self.__dict__.get('fqdn'), self.__dict__.get('cafile')) != (fqdn, cafile):
            self.verified = False
        self.fqdn = fqdn
        self.cafile = cafile
        self.ca_plain_text = None
        self.ca_base64 = None
        self.load_ca()

    def load_ca(self):
        """Read *cafile*; cache its stripped text and the base64 of that text.

        A missing or unreadable file is logged and the cached values are left as they were.
        """
        try:
            with open(self.cafile, 'r', encoding='ascii') as file:
                self.ca_plain_text = file.read().strip()
        except FileNotFoundError:
            logger.error("File is not found.")
            return
        except PermissionError:
            logger.error("You don't have permission to access this file.")
            return
        self.ca_base64 = base64.b64encode(self.ca_plain_text.encode('ascii')).decode('ascii')

    def get_base64_ca(self):
        return self.ca_base64

    def get_plain_ca(self):
        return self.ca_plain_text

    def get_fqdn(self):
        return self.fqdn

    def verify(self, timeout=None):
        """GET ``https://<fqdn>`` validating TLS against *cafile*.

        Args:
            timeout: seconds to wait; defaults to ``request_timeout``.

        Returns:
            ``(True, "")`` when the response status is OK, otherwise
            ``(False, reason)``. Network/TLS errors propagate to the caller.
        """
        if timeout is None:
            timeout = self.request_timeout
        repo_url = "https://%s" % self.fqdn
        # Without a timeout an unresponsive repo would hang the script forever.
        response = requests.get(repo_url, verify=self.cafile, timeout=timeout)
        self.verified = bool(response.ok)
        if response.ok:
            return True, ""
        return False, response.reason

    def has_verified(self):
        return self.verified

    def is_valid(self):
        """Like verify(), but converts any exception into ``(False, repr(exc))``."""
        try:
            return self.verify()
        except Exception as e:
            return False, repr(e)

    def dump(self):
        logger.info("dumpping the repo: %s" % self.fqdn)
        logger.info("repo cafile: %s" % self.cafile)
        logger.info("repo ca plain text: %s" % self.ca_plain_text)
        logger.info("repo ca base64: %s" % self.ca_base64)
        logger.info("repo is_valid: %s, %s" % self.is_valid())

def create_airgap_repo(fqdn, cafile):
    """Return an AirgapRepo for *fqdn*/*cafile*, verifying it unless already verified.

    Exits the process with status 1 when the repository cannot be reached
    with the given CA file.
    """
    repo = AirgapRepo(fqdn, cafile)
    if repo.has_verified():
        return repo
    valid, errmsg = repo.is_valid()
    if not valid:
        logger.error("airgap repo: %s is invalid: %s" % (fqdn, errmsg))
        # Non-zero so callers/automation see the failure (sys.exit() means 0).
        sys.exit(1)
    logger.info("airgap repo: %s is valid" % fqdn)
    return repo

def run_command(command):
    """Run *command* through the shell and return its stdout decoded as UTF-8.

    On a non-zero exit status the command's stderr is logged and the process
    exits with status 1. On success, any stderr output is logged at debug level.
    """
    # shell=True is required: callers pass pipelines (``| awk``, ``| base64 -d``).
    result = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    # stderr must be piped, otherwise there is nothing to report on failure.
    error = result.stderr.decode('utf-8', errors='replace').strip()
    if result.returncode:
        logger.error("run \"%s\" failed, err: %s" % (command, error))
        sys.exit(1)
    if error:
        logger.debug("run \"%s\" stderr: %s" % (command, error))
    return result.stdout.decode('utf-8')

def save_file(file_name, file_content):
    """Write ``str(file_content)`` to *file_name* as UTF-8, replacing existing content.

    Newly created files get mode 0o600 (before umask), since callers store
    kubeconfig credentials with this helper. An existing file keeps its mode.
    """
    fd = os.open(file_name, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    # The context manager closes the file even if the write fails.
    with os.fdopen(fd, "w", encoding="utf-8") as fh:
        fh.write(str(file_content))

def get_script_location(properties_file="/common/configs/appliance.properties"):
    """Return the ``applianceRole`` value from the appliance properties file.

    Callers compare the result against "Manager" / "ControlPlane" to decide
    whether a subcommand may run here.

    Args:
        properties_file: path of the ``key=value`` (or ``key:value``) file to read.

    Exits the process with status 1 if the file cannot be read or contains
    no ``applianceRole`` entry.
    """
    key = "applianceRole"
    try:
        with open(properties_file, encoding="utf-8", errors="replace") as fh:
            lines = fh.read().splitlines()
    except OSError as error:
        logger.error("failed to read %s, err: %s" % (properties_file, error))
        sys.exit(1)

    for raw_line in lines:
        line = raw_line.strip()
        if not line.startswith(key):
            continue
        rest = line[len(key):].lstrip()
        # Require a separator so e.g. "applianceRoleX=..." is not matched, and
        # tolerate spaces around it (a fixed-offset slice would not).
        if rest[:1] in ("=", ":"):
            return rest[1:].strip()

    logger.error("%s not found in %s" % (key, properties_file))
    sys.exit(1)

def update_cert_db(args):
    """Store the airgap repo's new CA (base64) in TCA-Manager's ``Extension`` table.

    Runs only on a TCA-Manager appliance. The Postgres pod IP and the
    ``tca_admin`` password are read from the ``tca-mgr`` namespace via
    kubectl. The first ``Extension`` row whose ``val`` text contains
    ``args.fqdn`` gets ``val.interfaceInfo.caCert`` replaced.

    Exits with status 1 on any database error or when no row matches.
    """
    location = get_script_location()
    if location != "Manager":
        logger.warning("the update_cert_db command only run in in TCA-Manager")
        return

    repo = create_airgap_repo(args.fqdn, args.cafile)
    base64_ca = repo.get_base64_ca()
    cmd = "kubectl get pods postgres-0 -n tca-mgr -o=custom-columns=IP:.status.podIP | sed -n 2p"
    postgres_ip = run_command(cmd).strip()
    cmd = "kubectl get secret -n tca-mgr tca-admin-cred-secret -o jsonpath='{.data.password}' | base64 -d"
    tca_password = run_command(cmd)

    conn = None
    try:
        conn = psycopg2.connect(
            host=postgres_ip,
            port=5432,
            user="tca_admin",
            password=tca_password,
            database="tca",
        )
        _update_extension_ca(conn, args.fqdn, base64_ca)
        logger.info("Successfully update cert db")
    except Exception as error:  # psycopg2.Error derives from Exception
        logger.error("Error while connecting to PostgreSQL %s" % error)
        sys.exit(1)
    finally:
        # conn stays None when connect() itself failed.
        if conn is not None:
            conn.close()


def _update_extension_ca(conn, fqdn, base64_ca):
    """Set ``interfaceInfo.caCert`` on the first Extension row mentioning *fqdn*; commit.

    Logs an error and exits with status 1 (rolling back) when no row matches.
    """
    # `with conn` commits on success and rolls back on any exception
    # (SystemExit included); it does not close the connection.
    with conn, conn.cursor() as cur:
        logger.info("########## Quering %s's id,val in Postgres ##########" % fqdn)
        # Bound parameters: psycopg2 quotes the values, so quotes inside the
        # FQDN or the JSON document cannot break (or inject into) the SQL.
        cur.execute(
            'select id,val from "Extension" where strpos(val::text,%s)>0;',
            (fqdn,),
        )
        row = cur.fetchone()
        if row is None:
            logger.error("No Extension record in Postgres references %s" % fqdn)
            sys.exit(1)
        ext_id, val = row
        # psycopg2 decodes json/jsonb columns; tolerate a plain text column too.
        if isinstance(val, str):
            val = json.loads(val)
        logger.debug("the interfaceInfo is %s" % val["interfaceInfo"])

        val["interfaceInfo"]["caCert"] = base64_ca
        logger.info("########## Updating %s's val by id in Postgres ##########" % fqdn)
        logger.info("the interfaceInfo is %s" % val["interfaceInfo"])

        cur.execute(
            'update "Extension" set val=%s where id = %s;',
            (json.dumps(val), ext_id),
        )

def get_mgmtclusters(k8s_service_ip, timeout=30):
    """Fetch all management-cluster records from the k8s-bootstrapper API.

    Args:
        k8s_service_ip: cluster IP of the k8s-bootstrapper service.
        timeout: seconds to wait for the API before giving up.

    Returns:
        The decoded JSON body (get_mgmtcluster iterates it as a list of dicts).

    Exits with status 1 on connection errors, HTTP error statuses or a
    non-JSON body.
    """
    url = "http://%s:8888/api/v1/managementclusters" % k8s_service_ip
    try:
        response = requests.get(url, timeout=timeout)
        # Otherwise an error document would be returned as the cluster list.
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as error:
        logger.error("Failed to get Managementclusters, err: %s" % error)
        sys.exit(1)

def get_mgmtcluster(mgmtclusters, mc_name):
    """Return the first record in *mgmtclusters* whose ``clusterName`` equals *mc_name*.

    Records without a ``clusterName`` key are skipped. Logs a warning and
    exits with status 1 when nothing matches.
    """
    match = next(
        (mc for mc in mgmtclusters if mc.get('clusterName') == mc_name),
        None,
    )
    if match is None:
        logger.warning("the Management Cluster <%s> don't exist", mc_name)
        sys.exit(1)
    return match

def get_tkg_context(k8s_service_ip, tkg_id, timeout=30):
    """Fetch the TKG context *tkg_id* (``plaintext=true``) from the k8s-bootstrapper API.

    Args:
        k8s_service_ip: cluster IP of the k8s-bootstrapper service.
        tkg_id: id of the TKG context (a management cluster's ``tkgID``).
        timeout: seconds to wait for the API before giving up.

    Returns:
        The decoded JSON body.

    Exits with status 1 on connection errors, HTTP error statuses or a
    non-JSON body.
    """
    url = "http://%s:8888/api/v1/tkgcontext/%s?plaintext=true" % (k8s_service_ip, tkg_id)
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        tkg_context = response.json()
    except (requests.RequestException, ValueError) as error:
        logger.error("Failed to get TkgContext, err: %s" % error)
        sys.exit(1)
    # Report success only after the body actually parsed.
    logger.info("Successfully get TkgContext")
    return tkg_context

def update_tkg_context(k8s_service_ip, tkg_id, tkg_context, timeout=30):
    """PUT *tkg_context* back to the k8s-bootstrapper API as TKG context *tkg_id*.

    Args:
        timeout: seconds to wait for the API before giving up.

    Exits with status 1 on a connection error or any status other than 200.
    """
    url = "http://%s:8888/api/v1/tkgcontext/%s" % (k8s_service_ip, tkg_id)
    try:
        response = requests.put(url, json=tkg_context, timeout=timeout)
    except requests.RequestException as error:
        logger.error("Failed to update tkgcontext %s, err: %s" % (tkg_id, error))
        sys.exit(1)
    response.encoding = 'utf-8'
    if response.status_code != 200:
        logger.error("Failed to update tkgcontext %s, err: %s" % (tkg_id, response.text))
        sys.exit(1)
    logger.info("Updated tkgcontext %s with response %s" % (tkg_id, response))
    
def update_mgmt_cluster(k8s_service_ip, mc_id, mgmt_cluster, timeout=30):
    """PUT *mgmt_cluster* back to the k8s-bootstrapper API as management cluster *mc_id*.

    Args:
        timeout: seconds to wait for the API before giving up.

    Exits with status 1 on a connection error or any status other than 200.
    """
    url = "http://%s:8888/api/v1/managementcluster/%s" % (k8s_service_ip, mc_id)
    try:
        response = requests.put(url, json=mgmt_cluster, timeout=timeout)
    except requests.RequestException as error:
        logger.error("Failed to update Managementcluster <%s>, err: %s" % (mc_id, error))
        sys.exit(1)
    response.encoding = 'utf-8'
    if response.status_code != 200:
        logger.error("Failed to update Managementcluster <%s>, err: %s" % (mc_id, response.text))
        sys.exit(1)
    logger.info("Updated Managementcluster <%s> with response %s" % (mc_id, response))

def get_k8s_service_ip():
    """Return the cluster IP of ``k8s-bootstrapper-service`` in namespace ``tca-cp-cn``.

    Exits with status 1 when kubectl fails or reports no IP.
    """
    # jsonpath instead of `| awk | sed`: a kubectl failure now surfaces as a
    # non-zero exit status instead of being masked by the last pipeline stage.
    cmd = "kubectl get svc k8s-bootstrapper-service -n tca-cp-cn -o jsonpath='{.spec.clusterIP}'"
    k8s_service_ip = run_command(cmd).strip()
    if not k8s_service_ip:
        logger.error("Failed to get the cluster IP of k8s-bootstrapper-service")
        sys.exit(1)
    return k8s_service_ip

def update_mgmtcluster(args):
    """Put the CA from ``args.cafile`` into management cluster ``args.name``'s TKG context.

    Runs only on a TCA-ControlPlane appliance. The airgap repo FQDN comes from
    the TKG context and is verified with the new CA. The updated context is
    PUT back, followed by a PUT of the (unchanged) management-cluster record.

    Exits with status 1 when the TKG context has no airgap repository.
    """
    location = get_script_location()
    if location != "ControlPlane":
        logger.warning("the update_mgmtcluster command only run in in TCA-ControlPlane")
        return

    k8s_service_ip = get_k8s_service_ip()
    mgmtcluster = get_mgmtcluster(get_mgmtclusters(k8s_service_ip), args.name)
    mc_id = mgmtcluster['id']
    tkg_id = mgmtcluster['tkgID']
    tkg_context = get_tkg_context(k8s_service_ip, tkg_id)

    # Fail with a clear message rather than a KeyError traceback.
    airgap = tkg_context.get('airgap')
    if not airgap or not airgap.get('fqdn'):
        logger.error("the management cluster <%s> has no airgap repository configured" % args.name)
        sys.exit(1)

    repo = create_airgap_repo(airgap['fqdn'], args.cafile)
    airgap['caCert'] = repo.get_base64_ca()

    update_tkg_context(k8s_service_ip, tkg_id, tkg_context)
    update_mgmt_cluster(k8s_service_ip, mc_id, mgmtcluster)

    logger.info("Updated management cluster <%s>" % args.name)

def get_mgmtcluster_kubeconfig(k8s_service_ip, mc_id, timeout=30):
    """Return management cluster *mc_id*'s kubeconfig text from the k8s-bootstrapper API.

    Args:
        timeout: seconds to wait for the API before giving up.

    Exits with status 1 on connection errors or HTTP error statuses.
    """
    url = "http://%s:8888/api/v1/managementcluster/%s/kubeconfig" % (k8s_service_ip, mc_id)
    try:
        response = requests.get(url, timeout=timeout)
        # Otherwise an error page would be saved and used as the kubeconfig.
        response.raise_for_status()
    except requests.RequestException as error:
        logger.error("Failed to get Managementcluster Kubeconfig, err: %s" % error)
        sys.exit(1)
    logger.info("Successfully get Managementcluster Kubeconfig")
    return response.text

def verify_workloadcluster(wc_name):
    """Exit with status 1 unless namespace *wc_name* exists in the management cluster.

    Uses the ``kubeconfig`` file in the current working directory, which the
    callers write just before calling this.
    """
    # argv list without a shell: wc_name reaches kubectl verbatim and cannot
    # inject shell syntax; `get ns <name>` fails when the namespace is absent.
    cmd = ["kubectl", "--kubeconfig=kubeconfig", "get", "ns", wc_name]
    try:
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except OSError as error:
        logger.error("failed to run kubectl, err: %s" % error)
        sys.exit(1)
    if result.returncode:
        logger.warning("workload cluster <%s> don't exist", wc_name)
        sys.exit(1)
    logger.info("workload cluster <%s> exists", wc_name)

def update_workloadcluster(args):
    """Copy the management cluster's airgap CA into workload cluster ``args.name``'s TKC.

    Runs only on a TCA-ControlPlane appliance. ``caCert`` is taken from the TKG
    context of management cluster ``args.mc`` and written to the TKC's
    ``spec.airgap.caCert``, which is then re-applied with kubectl.

    Exits with status 1 when either side has no airgap section.
    """
    location = get_script_location()
    if location != "ControlPlane":
        logger.warning("the update_workloadcluster command only runs in TCA-ControlPlane")
        return

    k8s_service_ip = get_k8s_service_ip()
    mgmtcluster = get_mgmtcluster(get_mgmtclusters(k8s_service_ip), args.mc)
    mc_id = mgmtcluster['id']
    tkg_context = get_tkg_context(k8s_service_ip, mgmtcluster['tkgID'])
    ca_cert = (tkg_context.get('airgap') or {}).get('caCert')
    if not ca_cert:
        logger.error("the management cluster <%s> has no airgap CA certificate; "
                     "run update-mgmtcluster first" % args.mc)
        sys.exit(1)

    save_file("kubeconfig", get_mgmtcluster_kubeconfig(k8s_service_ip, mc_id))
    verify_workloadcluster(args.name)

    cmd = 'kubectl --kubeconfig=kubeconfig get tkc -n %s %s -o yaml' % (args.name, args.name)
    data = yaml.safe_load(run_command(cmd))
    airgap = ((data or {}).get('spec') or {}).get('airgap')
    if airgap is None:
        logger.error("the workload cluster <%s> has no spec.airgap section" % args.name)
        sys.exit(1)
    airgap['caCert'] = ca_cert

    save_file("tkc.yaml", yaml.safe_dump(data))
    run_command('kubectl --kubeconfig=kubeconfig apply -f tkc.yaml')

    logger.info("Updated workload cluster <%s>" % args.name)

def get_mgmtcluster_status(k8s_service_ip, mc_id, timeout=30):
    """Fetch management cluster *mc_id*'s status document from the k8s-bootstrapper API.

    Args:
        timeout: seconds to wait for the API before giving up.

    Returns:
        The decoded JSON body; show_state_mgmtcluster reads its ``status`` key.

    Exits with status 1 on connection errors, HTTP error statuses or a
    non-JSON body.
    """
    url = "http://%s:8888/api/v1/managementcluster/%s/status" % (k8s_service_ip, mc_id)
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        mc_status = response.json()
    except (requests.RequestException, ValueError) as error:
        logger.error("Failed to get Managementcluster status, err: %s" % error)
        sys.exit(1)
    logger.info("Successfully get Managementcluster status")
    return mc_status

def show_state_mgmtcluster(args):
    """Log the status reported for management cluster ``args.name``.

    Runs only on a TCA-ControlPlane appliance.
    """
    location = get_script_location()
    if location != "ControlPlane":
        # The message previously named update_mgmtcluster, not this command.
        logger.warning("the show_state_mgmtcluster command only runs in TCA-ControlPlane")
        return

    k8s_service_ip = get_k8s_service_ip()
    mgmtcluster = get_mgmtcluster(get_mgmtclusters(k8s_service_ip), args.name)
    mc_status = get_mgmtcluster_status(k8s_service_ip, mgmtcluster['id'])
    logger.info("the management cluster <%s> status is %s" % (args.name, mc_status['status']))
            
def get_tkc_status(wc):
    """Return column 6 of the TcaKubernetesCluster row for *wc* (namespace *wc*).

    Logs the value, or a warning when it is empty; the empty string is returned as-is.
    """
    query = "kubectl --kubeconfig=kubeconfig get tkc %s -n %s | awk '{print $6}' | sed -n 2p" % (wc, wc)
    tkc_status = run_command(query).strip()
    if not tkc_status:
        logger.warning("%s can't get the TcaKubernetesCluster status." % wc)
        return tkc_status
    logger.info("the TcaKubernetesCluster status is %s" % tkc_status)
    return tkc_status

def get_tkcp_status(wc):
    """Return column 6 of the first TcaKubeControlPlane row in namespace *wc*.

    Logs the value, or a warning when it is empty; the empty string is returned as-is.
    """
    query = "kubectl --kubeconfig=kubeconfig get tkcp -n %s | awk '{print $6}' | sed -n 2p" % wc
    tkcp_status = run_command(query).strip()
    if not tkcp_status:
        logger.warning("%s can't get the TcaKubeControlPlane status." % wc)
        return tkcp_status
    logger.info("the TcaKubeControlPlane status is %s" % tkcp_status)
    return tkcp_status

def get_tknp_status(wc):
    """Return column 6 of the first TcaNodePool row in namespace *wc*.

    An empty result is logged as "no NodePool" and returned as the empty string.
    """
    query = "kubectl --kubeconfig=kubeconfig get tknp -n %s | awk '{print $6}' | sed -n 2p" % wc
    tknp_status = run_command(query).strip()
    if not tknp_status:
        logger.warning("the workload cluster <%s> don't have NodePool." % wc)
        return tknp_status
    logger.info("the TcaNodePool status is %s" % tknp_status)
    return tknp_status

def get_nodeprofilestatus(wc):
    cmd = "kubectl --kubeconfig=kubeconfig get secret %s-kubeconfig -n %s -o jsonpath='{.data.value}' | base64 -d" % (wc, wc)
    wc_kubeconfig = run_command(cmd)
    save_file("wckubeconfig", wc_kubeconfig)

    cmd = "kubectl --kubeconfig=wckubeconfig get nodeprofilestatus -n tca-system | wc -l"
    nodeprofilestatus_Nums = run_command(cmd)
    cmd = "kubectl --kubeconfig=wckubeconfig get nodeprofilestatus -n tca-system | grep Normal | wc -l"
    nodeprofilestatus_Normal_Nums = run_command(cmd)
    
    if int(nodeprofilestatus_Nums) == int(nodeprofilestatus_Normal_Nums) + 1:
        logger.info("the legacy workloadcluster <%s> nodeprofile status is Normal" % wc)
    else:
        logger.warning("the legacy workloadcluster nodeprofile status is not Normal")
        cmd = "kubectl --kubeconfig=wckubeconfig get nodeprofilestatus -n tca-system | grep -v Normal"
        results = run_command(cmd)
        logger.warning("\n%s", results)

def show_state_workloadcluster(args):
    """Log the provisioning state of workload cluster ``args.name`` under ``args.mc``.

    Runs only on a TCA-ControlPlane appliance. Combines the TKC, TKCP and TKNP
    statuses. For clusters the line-count check below classifies as legacy,
    it also reports nodeprofilestatus.
    """
    location = get_script_location()
    if location != "ControlPlane":
        # The message previously named update_mgmtcluster, not this command.
        logger.warning("the show_state_workloadcluster command only runs in TCA-ControlPlane")
        return

    k8s_service_ip = get_k8s_service_ip()
    mgmtcluster = get_mgmtcluster(get_mgmtclusters(k8s_service_ip), args.mc)
    save_file("kubeconfig", get_mgmtcluster_kubeconfig(k8s_service_ip, mgmtcluster['id']))
    verify_workloadcluster(args.name)

    tkc_status = get_tkc_status(args.name)
    tkcp_status = get_tkcp_status(args.name)
    tknp_status = get_tknp_status(args.name)

    # An empty TKNP status means the cluster has no node pool, which is acceptable.
    if tkc_status == "Provisioned" and tkcp_status == "Provisioned" and tknp_status in ("Provisioned", ""):
        logger.info("the workloadcluster status is Provisioned")
    else:
        logger.warning("the workloadcluster status is unProvisioned")

    cmd = "kubectl get tkc --kubeconfig=kubeconfig -n %s | grep -wv standard | wc -l" % args.name
    status = run_command(cmd)

    logger.info("the script will check the legacy workloadcluster nodeprofilestatus")
    # NOTE(review): counts `kubectl get tkc` output lines (header included)
    # that lack the word "standard"; exactly one line is treated as legacy.
    # Confirm this matches the TKC column layout.
    if int(status) != 1:
        logger.warning("the workloadcluster <%s> isn't the legacy workloadcluster, so can't get its nodeprofilestatus" % args.name)
    else:
        get_nodeprofilestatus(args.name)

def version(args):
    """Log the CLI version; *args* only satisfies the subcommand dispatch signature."""
    message = "the CLI VERSION is %s" % CLI_VERSION
    logger.info(message)

def parse_args(argv=None):
    """Build the CLI parser and parse *argv* (``sys.argv[1:]`` when None).

    Every subcommand stores its handler in ``func``; main() calls ``func(args)``.
    A subcommand is required, so a bare invocation exits with a usage error
    (status 2) instead of producing a namespace without ``func``.
    """
    parser = argparse.ArgumentParser(
        description='TCA CaaS update airgap repository trusted root certificate tool')

    parser.add_argument('--loglevel', choices=['debug', 'info', 'error', 'warning', 'critical'],
                        default='info', help='log level of script')
    parser.add_argument('--logdst', choices=['console', 'file'],
                        default='console', help='log destination, the default log file path is /common/logs/update-cert.log')

    # dest must be set for argparse to name the missing subcommand in its error.
    subparser = parser.add_subparsers(dest='command', help='-h for additional help')
    subparser.required = True

    sp = subparser.add_parser('update-cert-db', help='update airgap certificate in TCA-M database')
    ag = sp.add_argument_group('required arguments')
    ag.add_argument('--fqdn', required=True, help='FQDN of the airgap repository server')
    ag.add_argument('--cafile', required=True, help='new trusted root certificate file of the airgap repository server')
    sp.set_defaults(func=update_cert_db)

    sp = subparser.add_parser('update-mgmtcluster', help='update airgap trusted root certificate of specified management cluster')
    ag = sp.add_argument_group('required arguments')
    ag.add_argument('--cafile', required=True, help='new trusted root certificate file of the airgap repository server')
    ag.add_argument('--name', required=True, help='management cluster name')
    sp.set_defaults(func=update_mgmtcluster)

    sp = subparser.add_parser('update-workloadcluster', help='update airgap trusted root certificate of specified workload cluster')
    ag = sp.add_argument_group('required arguments')
    ag.add_argument('--mc', required=True, help='management cluster name')
    ag.add_argument('--name', required=True, help='workload cluster name')
    sp.set_defaults(func=update_workloadcluster)

    sp = subparser.add_parser('show-state-mgmtcluster', help='show status of renewing airgap certificate relevant to management cluster')
    ag = sp.add_argument_group('required arguments')
    ag.add_argument('--name', required=True, help='management cluster name')
    sp.set_defaults(func=show_state_mgmtcluster)

    sp = subparser.add_parser('show-state-workloadcluster', help='show status of renewing airgap certificate relevant to workload cluster')
    ag = sp.add_argument_group('required arguments')
    ag.add_argument('--mc', required=True, help='management cluster name')
    ag.add_argument('--name', required=True, help='workload cluster name')
    sp.set_defaults(func=show_state_workloadcluster)

    sp = subparser.add_parser('version', help='print the version of CLI')
    sp.set_defaults(func=version)

    return parser.parse_args(argv)

def main(argv=None):
    """Parse the CLI, configure logging and YAML, then run the chosen subcommand.

    Returns 2 (used as the exit status) when no subcommand handler was
    selected; otherwise returns None.

    NOTE(review): *argv* is accepted but not forwarded; parse_args() reads
    sys.argv.
    """
    pathlib.Path(TMP_DIR).mkdir(parents=True, exist_ok=True)
    args = parse_args()
    logger_setup(log_dst=args.logdst, log_level=args.loglevel, log_file="/common/logs/update-cert.log")
    init_yaml()
    # Without a subcommand argparse leaves `func` unset; report that instead
    # of crashing with an AttributeError.
    func = getattr(args, 'func', None)
    if func is None:
        logger.error("no command given, run with -h to list the commands")
        return 2
    func(args)

# Script entry point: main()'s return value becomes the process exit status
# (a None return exits with status 0).
if __name__== "__main__":
    sys.exit(main())
