#!/usr/bin/python3
# Copyright 2025 VMware by Broadcom, Inc. All rights reserved.

'''
Functionality:

Update the CA certificate of an airgap repository on an existing cluster.
input: 1. airgap repository FQDN
       2. CA certificate file
       3. cluster name

Show the cluster update status.
output: dump of the state of the related resources
'''

import argparse
import base64
import json
import os
import pathlib
import shlex
import subprocess
import sys
import traceback

import psycopg2
import yaml

from cluster import client

CLI_VERSION = "v3.3.0"

# Scratch directory for downloaded kubeconfigs, decoded certs and rendered specs.
TMP_DIR = "/tmp/update-ca"
# %-format pattern; the single %s is filled with a cluster name.
KUBE_CONFIG_FILE_PATTEN = TMP_DIR + "/%s-kubeconfig.yaml"
# Directory holding this script, with symlinks resolved.
# NOTE(review): not referenced in the visible code — confirm before removing.
CWD = os.path.dirname(os.path.realpath(__file__))

def str_presenter(dumper, data):
    """Represent *data* as a YAML string scalar.

    A string that ``str.splitlines`` breaks into more than one line is
    emitted in literal block style (``|``). Any other string is handed to
    ``represent_scalar`` without an explicit style.
    """
    tag = 'tag:yaml.org,2002:str'
    spans_lines = len(data.splitlines()) > 1
    if not spans_lines:
        return dumper.represent_scalar(tag, data)
    return dumper.represent_scalar(tag, data, style='|')

def init_yaml():
    """Install ``str_presenter`` as the ``str`` representer for PyYAML.

    Two registrations are made, in this order:
    - through ``yaml.add_representer``, the default dumper used by ``yaml.dump``;
    - on ``SafeRepresenter``, so ``yaml.safe_dump`` uses it as well.
    """
    for registry in (yaml, yaml.representer.SafeRepresenter):
        registry.add_representer(str, str_presenter)

def create_airgap_repo(fqdn, cafile):
    """Build an airgap repository client and make sure it is usable.

    :param fqdn: FQDN of the airgap repository server.
    :param cafile: path to the CA certificate file for the repository.
    :returns: the ``client.AirgapRepo`` instance.

    If ``repo.has_verified()`` is truthy the repo is returned immediately.
    Otherwise ``repo.is_valid()`` is consulted. An invalid repo is logged
    and the process exits with status 1.
    """
    repo = client.AirgapRepo(fqdn, cafile)
    if repo.has_verified():
        return repo
    valid, errmsg = repo.is_valid()
    if not valid:
        client.logger.error("airgap repo: %s is invalid: %s", fqdn, errmsg)
        # Bare sys.exit() would report success (status 0) to the shell and to
        # any automation wrapping this script.
        sys.exit(1)
    client.logger.info("airgap repo: %s is valid", fqdn)
    return repo

def run_command(command):
    """Run *command* through the shell and return its stdout as text.

    :param command: shell command line. It runs with ``shell=True``, so pipes
        are allowed and callers must quote any untrusted pieces.
    :returns: standard output decoded as UTF-8.

    stderr is captured. On a non-zero exit status the command and its stderr
    are logged and the process exits with status 1. On success any stderr
    output is logged at debug level.
    """
    proc = subprocess.Popen(command, stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE, shell=True)
    # stderr must be PIPE'd above: otherwise communicate() returns None for it
    # and the error branch could never report the command's actual error.
    out, err = proc.communicate()
    err_text = err.decode('utf-8', errors='replace').strip() if err else ""

    if proc.returncode:
        client.logger.error("run \"%s\" failed, err: %s", command, err_text)
        sys.exit(1)
    if err_text:
        client.logger.debug("run \"%s\" stderr: %s", command, err_text)
    return out.decode('utf-8')

def save_file(file_name, file_content):
    """Write ``str(file_content)`` to *file_name*, replacing any existing content.

    The ``with`` block closes the file even if the write raises.
    NOTE(review): callers store kubeconfigs here under /tmp with default
    permissions; consider restricting them.
    """
    with open(file_name, "w") as f:
        f.write(str(file_content))

def get_cluster_kubeconfig_file(name):
    """Return the scratch-file path used for the kubeconfig of cluster *name*."""
    path = KUBE_CONFIG_FILE_PATTEN % name
    return path

def save_base64_airgpap_cert_file(airgap_cert_file, base64_cert):
    """Decode a base64 CA certificate and write the raw bytes to a file.

    :param airgap_cert_file: destination path; it is overwritten.
    :param base64_cert: base64-encoded certificate. Callers in this file pass
        the TKG context's ``airgap.caCert`` value.
    :raises Exception: any decode or I/O error is logged with a traceback and
        re-raised unchanged.
    """
    try:
        decoded_bytes = base64.b64decode(base64_cert)
        # The with-statement closes the file; no explicit close() is needed.
        with open(airgap_cert_file, 'wb') as cert_file:
            cert_file.write(decoded_bytes)
    except Exception as e:
        client.logger.critical("failed to save Airgap CA cert into the file [%s], err: %s",
                               airgap_cert_file, repr(e))
        traceback.print_exc()
        # A bare raise keeps the original traceback untouched.
        raise

def get_script_location(properties_file="/common/configs/appliance.properties"):
    """Return the appliance role recorded in the appliance properties file.

    The file is scanned for the first ``applianceRole=<value>`` entry, and
    ``<value>`` is returned with surrounding whitespace stripped. Blank lines
    and lines starting with ``#`` or ``!`` are skipped. Callers compare the
    result with "Manager" or "ControlPlane".

    :param properties_file: properties file to read. The default is the
        path this script has always used.

    If the file cannot be read, or has no ``applianceRole`` entry, an error
    is logged and the process exits with status 1.
    """
    key = "applianceRole"
    try:
        with open(properties_file, encoding="utf-8") as props:
            for raw_line in props:
                line = raw_line.strip()
                if not line or line.startswith(("#", "!")):
                    continue
                # Parse key=value instead of slicing a fixed offset, so
                # spacing around '=' and unrelated matches don't corrupt it.
                name, sep, value = line.partition("=")
                if sep and name.strip() == key:
                    return value.strip()
    except OSError as e:
        client.logger.error("failed to read %s, err: %s", properties_file, e)
        sys.exit(1)
    client.logger.error("no %s entry found in %s", key, properties_file)
    sys.exit(1)

def update_cert_db(args):
    """Replace the airgap CA certificate stored in the TCA-Manager database.

    This only acts on the TCA-Manager appliance; elsewhere it logs a warning.
    It looks up the first row of the ``"Extension"`` table whose ``val`` (cast
    to text) contains ``args.fqdn``. It then sets
    ``val["interfaceInfo"]["caCert"]`` to the base64 CA obtained from
    ``args.cafile`` and writes the row back.

    :param args: parsed CLI namespace providing ``fqdn`` and ``cafile``.

    Database errors, and rows whose ``val`` lacks the expected keys, are
    logged rather than raised.
    """
    location = get_script_location()
    if location != "Manager":
        client.logger.warning("the update_cert_db command only run in in TCA-Manager")
        return

    repo = create_airgap_repo(args.fqdn, args.cafile)
    base64_ca = repo.get_base64_ca()
    cmd = "kubectl get pods postgres-0 -n tca-mgr -o=custom-columns=IP:.status.podIP | sed -n 2p"
    postgres_ip = run_command(cmd).strip()
    cmd = "kubectl get secret -n tca-mgr tca-admin-cred-secret -o jsonpath='{.data.password}' | base64 -d"
    tca_password = run_command(cmd)

    # Bound up front so the finally clause is safe even if connect() raises.
    conn = None
    try:
        conn = psycopg2.connect(
            host=postgres_ip,
            port=5432,
            user="tca_admin",
            password=tca_password,
            database="tca",
        )
        with conn.cursor() as cur:
            client.logger.info("########## Quering %s's id,val in Postgres ##########", args.fqdn)

            # Parameterized queries: neither the FQDN nor the JSON payload is
            # spliced into SQL text, so quotes in either cannot break the query.
            cur.execute('select id,val from "Extension" where strpos(val::text, %s)>0;',
                        (args.fqdn,))
            row = cur.fetchone()
            if row is None:
                client.logger.error("no Extension record references airgap repo %s", args.fqdn)
                return
            # NOTE(review): only the first matching row is updated.
            ext_id, ext_val = row
            client.logger.debug("the interfaceInfo is %s", ext_val["interfaceInfo"])

            ext_val["interfaceInfo"]["caCert"] = base64_ca
            client.logger.info("########## Updating %s's val by id in Postgres ##########", args.fqdn)
            client.logger.info("the interfaceInfo is %s", ext_val["interfaceInfo"])

            cur.execute('update "Extension" set val=%s where id = %s;',
                        (json.dumps(ext_val), ext_id))
        conn.commit()

        client.logger.info("Successfully update cert db")
    except Exception as error:  # psycopg2.Error is an Exception subclass
        client.logger.error("Error while connecting to PostgreSQL %s", error)
    finally:
        if conn is not None:
            conn.close()

def get_k8s_service_ip():
    """Return the k8s-bootstrapper-service address from ``kubectl get svc``.

    The value is the third column of the service's row (the second output
    line) in namespace ``tca-cp-cn``, with whitespace stripped.
    """
    return run_command(
        "kubectl get svc k8s-bootstrapper-service -n tca-cp-cn | awk '{print $3}' | sed -n 2p"
    ).strip()

def update_mgmtcluster(args):
    """Push a new airgap CA certificate to a management cluster.

    This only acts on the TCA-ControlPlane appliance; elsewhere it logs a
    warning. The CA from ``args.cafile`` is first stored in the cluster's TKG
    context. It is then applied to management cluster ``args.name`` through
    ``Kbs.update_mgmt_cluster``.

    :param args: parsed CLI namespace providing ``cafile`` and ``name``.
    """
    if get_script_location() != "ControlPlane":
        client.logger.warning("the update_mgmtcluster command only run in in TCA-ControlPlane")
        return

    # Client for the k8s-bootstrapper service, reached at the IP returned by
    # get_k8s_service_ip().
    kbs = client.Kbs(get_k8s_service_ip())

    # 1. update the airgap CA cert stored in the TKG context
    mgmt_cluster = kbs.get_mgmt_cluster(args.name)
    context = kbs.get_tkg_context(mgmt_cluster['tkgID'])
    repo = create_airgap_repo(context['airgap']['fqdn'], args.cafile)
    context['airgap']['caCert'] = repo.get_base64_ca()
    kbs.update_tkg_context(context)

    # 2. update the airgap CA cert stored in the management cluster itself
    kubeconfig_text = kbs.get_mgmtcluster_kubeconfig(mgmt_cluster['id'])
    kubeconfig_path = get_cluster_kubeconfig_file(args.name)
    save_file(kubeconfig_path, kubeconfig_text)
    kbs.update_mgmt_cluster(mgmt_cluster, repo, kubeconfig_path)

    client.logger.info("Updated management cluster <%s>", args.name)

def show_state_mgmtcluster(args):
    """Report the airgap CA certificate state of a management cluster.

    This only acts on the TCA-ControlPlane appliance; elsewhere it logs a
    warning. It rebuilds the airgap repo from the CA stored in the TKG
    context, saves the management cluster kubeconfig, and then runs
    ``cc_verify`` (for ClusterClass clusters) or ``verify``.

    :param args: parsed CLI namespace providing ``name``.
    """
    if get_script_location() != "ControlPlane":
        client.logger.warning("the show_state_mgmtcluster command only run in in TCA-ControlPlane")
        return

    kbs = client.Kbs(get_k8s_service_ip())
    mgmt_cluster = kbs.get_mgmt_cluster(args.name)

    # Rebuild the airgap repo from the CA currently stored in the TKG context.
    context = kbs.get_tkg_context(mgmt_cluster['tkgID'])
    cert_path = "%s/%s-cert.yaml" % (TMP_DIR, args.name)
    save_base64_airgpap_cert_file(cert_path, context['airgap']['caCert'])
    repo = create_airgap_repo(context['airgap']['fqdn'], cert_path)

    # Verify the management cluster using its kubeconfig.
    kubeconfig_text = kbs.get_mgmtcluster_kubeconfig(mgmt_cluster['id'])
    kubeconfig_path = get_cluster_kubeconfig_file(args.name)
    save_file(kubeconfig_path, kubeconfig_text)
    mc_client = client.MgmtClusterClient(kubeconfig_path, repo)
    if not mc_client.IsClusterClass(args.name):
        mc_client.verify()
    else:
        mc_client.cc_verify(args.name)

def verify_workloadcluster(mc_kubeconfig, wc_name):
    """Exit the process unless namespace *wc_name* exists in the management cluster.

    Namespaces are listed with the management-cluster kubeconfig and matched
    as a whole-line literal against *wc_name*. A matching namespace is taken
    as evidence that the workload cluster exists.

    :param mc_kubeconfig: path to the management cluster kubeconfig file.
    :param wc_name: workload cluster name.
    """
    # Quote every interpolated value; this string is run by a shell.
    cmd = "kubectl --kubeconfig=%s get ns | awk '{print $1}' | grep -Fwx -- %s" % (
        shlex.quote(mc_kubeconfig), shlex.quote(wc_name))
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, shell=True)
    proc.communicate()
    # The pipeline's status is grep's: non-zero means no matching line.
    if proc.returncode:
        client.logger.warning("workload cluster <%s> don't exist", wc_name)
        # Non-zero status: a missing cluster is a failure for every caller.
        sys.exit(1)
    client.logger.info("workload cluster <%s> exists", wc_name)

def update_workloadcluster(args):
    """Propagate the airgap CA certificate to a workload cluster's TKC object.

    This only acts on the TCA-ControlPlane appliance; elsewhere it logs a
    warning. It copies the ``caCert`` stored in the management cluster's TKG
    context into ``spec.airgap.caCert`` of the workload cluster's ``tkc``
    object, then re-applies that object through the management cluster.

    :param args: parsed CLI namespace providing ``mc`` (management cluster
        name) and ``name`` (workload cluster name).
    """
    location = get_script_location()
    if location != "ControlPlane":
        client.logger.warning("the update_workloadcluster command only run in in TCA-ControlPlane")
        return

    # 1. query the management cluster kubeconfig and save it. The file is keyed
    #    by the management cluster name, as in show_state_workloadcluster,
    #    because that function writes the workload cluster's own kubeconfig to
    #    the path keyed by the workload cluster name.
    kbs = client.Kbs(get_k8s_service_ip())
    mc = kbs.get_mgmt_cluster(args.mc)
    mc_kubeconfig_str = kbs.get_mgmtcluster_kubeconfig(mc['id'])
    mc_kubeconfig_file = get_cluster_kubeconfig_file(args.mc)
    save_file(mc_kubeconfig_file, mc_kubeconfig_str)

    # 2. validate that the workload cluster exists (exits the process otherwise)
    verify_workloadcluster(mc_kubeconfig_file, args.name)

    # 3. get the tkc object of the workload cluster. Its namespace equals the
    #    cluster name. Shell arguments are quoted because run_command uses a shell.
    kubeconfig_arg = shlex.quote(mc_kubeconfig_file)
    wc_arg = shlex.quote(args.name)
    cmd = 'kubectl --kubeconfig=%s get tkc -n %s %s -o yaml' % (kubeconfig_arg, wc_arg, wc_arg)
    data = yaml.safe_load(run_command(cmd))

    # 4. update the airgap CA cert in the spec and render it to a file
    tkg_context = kbs.get_tkg_context(mc['tkgID'])
    data['spec']['airgap']['caCert'] = tkg_context['airgap']['caCert']
    wc_spec_file = "%s/%s-tkc.yaml" % (TMP_DIR, args.name)
    save_file(wc_spec_file, yaml.safe_dump(data))

    # 5. apply the new tkc spec
    cmd = 'kubectl --kubeconfig=%s apply -f %s' % (kubeconfig_arg, shlex.quote(wc_spec_file))
    run_command(cmd)

    client.logger.info("Updated workload cluster <%s>", args.name)

def show_state_workloadcluster(args):
    """Report the airgap CA certificate state of a workload cluster.

    This only acts on the TCA-ControlPlane appliance; elsewhere it logs a
    warning. In order, it:
    - saves the management cluster kubeconfig;
    - checks that the workload cluster exists;
    - rebuilds the airgap repo from the CA in the TKG context;
    - fetches the workload cluster kubeconfig;
    - runs ``WrklClusterClient.verify()``.

    :param args: parsed CLI namespace providing ``mc`` and ``name``.
    """
    if get_script_location() != "ControlPlane":
        client.logger.warning("the show_state_workloadcluster command only run in in TCA-ControlPlane")
        return

    kbs = client.Kbs(get_k8s_service_ip())

    # 1. fetch the management cluster kubeconfig and save it
    mgmt = kbs.get_mgmt_cluster(args.mc)
    mgmt_kubeconfig = kbs.get_mgmtcluster_kubeconfig(mgmt['id'])
    mgmt_kubeconfig_path = get_cluster_kubeconfig_file(args.mc)
    save_file(mgmt_kubeconfig_path, mgmt_kubeconfig)

    # 2. stop here if the workload cluster does not exist
    verify_workloadcluster(mgmt_kubeconfig_path, args.name)

    # 3. rebuild the airgap repo from the CA stored in the TKG context
    context = kbs.get_tkg_context(mgmt['tkgID'])
    cert_path = "%s/%s-cert.yaml" % (TMP_DIR, args.name)
    save_base64_airgpap_cert_file(cert_path, context['airgap']['caCert'])
    repo = create_airgap_repo(context['airgap']['fqdn'], cert_path)

    # 4. fetch the workload cluster kubeconfig and save it
    mgmt_client = client.MgmtClusterClient(mgmt_kubeconfig_path, repo)
    wc_kubeconfig_path = get_cluster_kubeconfig_file(args.name)
    mgmt_client.save_cluster_kubeconfig(wc_kubeconfig_path, args.name, args.name)

    # 5. verify
    client.WrklClusterClient(mgmt_kubeconfig_path, wc_kubeconfig_path, args.name, repo).verify()

def version(args):
    """Log the CLI version; *args* is accepted to fit the sub-command interface and is unused."""
    client.logger.info("the CLI VERSION is %s", CLI_VERSION)

def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    :param argv: argument list to parse. ``None`` (the default) means
        ``sys.argv[1:]``, which is the previous behavior.
    :returns: ``argparse.Namespace``. ``func`` holds the handler of the chosen
        sub-command and ``command`` holds its name.
    """
    parser = argparse.ArgumentParser(
        description='TCA CaaS update airgap repository trusted root certificate tool')

    parser.add_argument('--loglevel', choices=['debug','info','error','warning','critical'],
                        default='info', help='log level of script')
    parser.add_argument('--logdst', choices=['console', 'file'],
                        default='console', help='log destination, the default log file path is /common/logs/update-cert.log')

    subparser = parser.add_subparsers(dest='command', help='-h for additional help')
    # Without this, running the script with no sub-command leaves args.func
    # unset and main() dies with AttributeError instead of a usage error.
    # dest is set so argparse can name the missing argument in that error.
    subparser.required = True

    sp = subparser.add_parser('update-cert-db', help='update airgap certificate in TCA-M database')
    ag = sp.add_argument_group('required arguments')
    ag.add_argument('--fqdn', required=True, help='FQDN of the airgap repository server')
    ag.add_argument('--cafile', required=True, help='new trusted root certificate file of the airgap repository server')
    sp.set_defaults(func=update_cert_db)

    sp = subparser.add_parser('update-mgmtcluster', help='update airgap trusted root certificate of specified management cluster')
    ag = sp.add_argument_group('required arguments')
    ag.add_argument('--cafile', required=True, help='new trusted root certificate file of the airgap repository server')
    ag.add_argument('--name', required=True, help='management cluster name')
    sp.set_defaults(func=update_mgmtcluster)

    sp = subparser.add_parser('update-workloadcluster', help='update airgap trusted root certificate of specified workload cluster')
    ag = sp.add_argument_group('required arguments')
    ag.add_argument('--mc', required=True, help='management cluster name')
    ag.add_argument('--name', required=True, help='workload cluster name')
    sp.set_defaults(func=update_workloadcluster)

    sp = subparser.add_parser('show-state-mgmtcluster', help='show status of renewing airgap certificate relevant to management cluster')
    ag = sp.add_argument_group('required arguments')
    ag.add_argument('--name', required=True, help='management cluster name')
    sp.set_defaults(func=show_state_mgmtcluster)

    sp = subparser.add_parser('show-state-workloadcluster', help='show status of renewing airgap certificate relevant to workload cluster')
    ag = sp.add_argument_group('required arguments')
    ag.add_argument('--mc', required=True, help='management cluster name')
    ag.add_argument('--name', required=True, help='workload cluster name')
    sp.set_defaults(func=show_state_workloadcluster)

    sp = subparser.add_parser('version', help='print the version of CLI')
    sp.set_defaults(func=version)

    return parser.parse_args(argv)

def main(argv=None):
    """Script entry point.

    Steps, in order:
    - create the scratch directory;
    - parse the CLI;
    - configure logging and YAML output;
    - dispatch to the selected sub-command handler.

    NOTE(review): *argv* is currently unused; parse_args() is called without
    it, so arguments come from sys.argv.
    """
    os.makedirs(TMP_DIR, exist_ok=True)
    args = parse_args()
    client.logger_setup(
        log_dst=args.logdst,
        log_level=args.loglevel,
        log_file="/common/logs/update-cert.log",
    )
    init_yaml()
    handler = args.func
    handler(args)

if __name__ == "__main__":
    # Equivalent to sys.exit(main()): main()'s return value becomes the exit status.
    raise SystemExit(main())
