#!/usr/bin/env python3
# -*- coding: utf-8 -*- #
# ==============================================================#
# File      :   validate
# Desc      :   validate pigsty config file
# Ctime     :   2023-04-21
# Mtime     :   2023-04-21
# Path      :   bin/validate
# Deps      :   yaml
# License   :   Apache-2.0 @ https://pigsty.io/docs/about/license/
# Copyright :   2018-2026  Ruohang Feng / Vonng (rh@vonng.com)
# ==============================================================#
__author__ = 'Vonng (rh@vonng.com)'

from argparse import ArgumentParser
import os, sys, re, json, yaml


def usage():
    """Print the command-line synopsis of this script to stdout."""
    synopsis = "bin/validate [../pigsty.yml]"
    print(synopsis)


#######################
# utils
#######################
def is_valid_ipv4(s):
    """
    Check whether `s` is a dotted-quad IPv4 address such as '10.10.10.10'.

    :param s: candidate value of any type
    :return: True if `s` is a string made of four decimal octets, each in
             the range 0..255; False otherwise (including non-string input)
    """
    if not isinstance(s, str):
        return False
    # fullmatch + [0-9]: reject a trailing newline and non-ASCII digits
    if not re.fullmatch(r'([0-9]{1,3}\.){3}[0-9]{1,3}', s):
        return False
    # the pattern only limits the digit count, so range-check every octet
    return all(int(octet) <= 255 for octet in s.split('.'))


def is_valid_ipv4_cidr(s):
    """
    Check whether `s` is an IPv4 address with prefix length, e.g. '10.10.10.2/24'.

    Host bits are allowed to be set, so an interface address such as
    '10.10.10.2/24' is accepted.

    :param s: candidate value of any type
    :return: True if `s` is a string '<a.b.c.d>/<n>' with every octet in
             0..255 and the prefix length in 0..32; False otherwise
             (including non-string input)
    """
    if not isinstance(s, str):
        return False
    # fullmatch + [0-9]: reject a trailing newline and non-ASCII digits
    matched = re.fullmatch(r'((?:[0-9]{1,3}\.){3}[0-9]{1,3})/([0-9]{1,2})', s)
    if not matched:
        return False
    address, prefix = matched.groups()
    # the pattern only limits digit counts, so range-check octets and prefix
    if any(int(octet) > 255 for octet in address.split('.')):
        return False
    return int(prefix) <= 32


#######################
# parse
#######################
def parse_conf(path):
    """
    Load the YAML config at `path` and check its top-level skeleton.

    The parsed document is round-tripped through JSON (dumps then loads),
    so loading fails for any value that `json.dumps` cannot serialize.
    A missing `all.vars` only triggers a warning and is replaced by an
    empty dict.

    :param path: path of the config file to load
    :return: the parsed config dict; `all`, `all.children` and `all.vars`
             are guaranteed to exist and to be dicts
    :raises Exception: if the file cannot be read or parsed (the original
                       error is chained as the cause), if the config is
                       empty, or if the skeleton is malformed
    """
    try:
        # context manager: the file handle is closed even if parsing fails
        with open(path, 'r') as f:
            conf = json.loads(json.dumps(yaml.safe_load(f)))
    except Exception as e:
        # keep the underlying reason visible instead of swallowing it
        raise Exception("fail to parse config %s: %s" % (path, e)) from e
    if not conf:
        raise Exception("empty config file")
    if not isinstance(conf, dict):
        # a YAML list or scalar at top level would break the key lookups below
        raise Exception("invalid config: top level is not dict")
    if 'all' not in conf:
        raise Exception("invalid config: `all` not found")
    if not isinstance(conf['all'], dict):
        raise Exception("invalid config: `all` is not dict")
    if 'children' not in conf['all']:
        raise Exception("invalid config: `all.children` not found")
    if not isinstance(conf['all']['children'], dict):
        raise Exception("invalid config: `all.children` is not dict")
    if 'vars' not in conf['all']:
        print("[WARN] global vars `all.vars` not found")
        conf['all']['vars'] = {}
    if not isinstance(conf['all']['vars'], dict):
        raise Exception("invalid config: `all.vars` is not dict")
    return conf


def validate_global_vars(conf):
    """
    Validate the global variables under `all.vars`.

    Only `admin_ip` is mandatory: it must be present and be an IPv4 address
    string. A missing `version`, or a missing or unrecognized `region`,
    only produces a warning. The overall result is printed as
    `[ OK ] all.vars` or `[FAIL] all.vars`.

    :param conf: config dict as returned by parse_conf()
    :return: True if the mandatory checks pass, False otherwise
    """
    success = True
    global_vars = conf['all']['vars']

    # admin_ip is mandatory; test the type first so that a non-string value
    # (e.g. a YAML number) is reported instead of crashing the regex helper
    if 'admin_ip' not in global_vars:
        print("[ERRO] 'all.vars.admin_ip' not found")
        success = False
    elif not isinstance(global_vars['admin_ip'], str) or not is_valid_ipv4(global_vars['admin_ip']):
        print("[ERRO] 'all.vars.admin_ip' is not a valid ipv4 address")
        success = False

    # warn about missing version
    if 'version' not in global_vars:
        print("[WARN] 'all.vars.version' not found, is this a pigsty config file?")

    # warn about missing or unrecognized region
    if 'region' not in global_vars:
        print("[WARN] 'all.vars.region' not found, is this pigsty file properly configured?")
    elif global_vars['region'] not in ('default', 'china', 'europe'):
        print("[WARN] 'all.vars.region' %s is not a valid value" % global_vars['region'])

    # print all.vars check result
    if success:
        print("[ OK ] all.vars")
    else:
        print("[FAIL] all.vars")

    return success


# module name -> {group name: [host keys]}, filled in by validate_children()
MODULES = {module: {} for module in ("infra", "pgsql", "redis", "minio", "etcd")}


def validate_children(conf):
    """
    Validate every group under `all.children` and record module membership.

    Per group this checks that the group is a non-empty dict, that `hosts`
    is a non-empty dict whose keys are IPv4 addresses, and that `vars` (when
    given) is a dict. Each group is then registered in MODULES: the group
    named `infra` under "infra", and any group whose vars contain
    `pg_cluster` / `etcd_cluster` / `minio_cluster` / `redis_cluster` under
    the matching module. A host listed in more than one group of the same
    module is an error.

    Side effects: MODULES is cleared and refilled; a group without `vars`
    gets `vars = {}`; a host whose value is null gets `{}`.

    :param conf: config dict as returned by parse_conf()
    :return: True when every check passes
    :raises Exception: when any check fails ("invalid config structure")
    """
    success = True

    # start from a clean registry so a repeated call never sees stale groups
    for module in MODULES.values():
        module.clear()

    for gname, gconf in conf['all']['children'].items():
        if not gconf:
            print("[ERRO] 'all.children.%s' is empty" % gname)
            success = False
            continue
        if not isinstance(gconf, dict):
            print("[ERRO] 'all.children.%s' is not a dict" % gname)
            success = False
            continue

        # ghosts stays empty unless `hosts` is a well-formed, non-empty dict
        ghosts = []
        if 'hosts' not in gconf:
            print("[ERRO] 'all.children.%s.hosts' not found" % gname)
            success = False
        elif not isinstance(gconf['hosts'], dict):
            print("[ERRO] 'all.children.%s.hosts' is not a dict" % gname)
            success = False
        elif len(gconf['hosts']) < 1:
            print("[ERRO] 'all.children.%s.hosts' is empty" % gname)
            success = False
        else:
            # keep definition order so the report is deterministic
            ghosts = list(gconf['hosts'])
            for h in ghosts:
                if not is_valid_ipv4(h):
                    print("[ERRO] 'all.children.%s.hosts.%s' is not a valid ipv4 address" % (gname, h))
                    success = False
                hvars = gconf['hosts'][h]
                if hvars is None:
                    # host declared without vars: normalize to an empty dict
                    gconf['hosts'][h] = {}
                elif not isinstance(hvars, dict):
                    print("[ERRO] 'all.children.%s.hosts.%s' is not a dict" % (gname, h))
                    success = False

        if 'vars' not in gconf:
            gconf['vars'] = {}
        gvars = gconf['vars']
        if not isinstance(gvars, dict):
            print("[ERRO] 'all.children.%s.vars' is not a dict" % gname)
            success = False
            gvars = {}  # malformed vars cannot mark the group for any module

        if gname == 'infra':
            MODULES["infra"][gname] = ghosts
        for mname, marker in (("pgsql", "pg_cluster"), ("etcd", "etcd_cluster"),
                              ("minio", "minio_cluster"), ("redis", "redis_cluster")):
            if marker in gvars:
                MODULES[mname][gname] = ghosts

    # one host can only belong to one group within the same module
    for mname, module in MODULES.items():
        reverse_dict = {}
        for gname, ghosts in module.items():
            for hname in ghosts:
                reverse_dict.setdefault(hname, []).append(gname)
        for host, groups in reverse_dict.items():
            if len(groups) > 1:
                print("[ERRO] host %s belongs to multiple groups: %s " % (host, groups))
                success = False

    # print all.children check result
    if success:
        print("[ OK ] all.children")
        return success
    else:
        print("[FAIL] all.children")
        raise Exception("invalid config structure")


def validate_infra_group(conf):
    """
    Validate the `infra` group: every host needs an integer `infra_seq` >= 0.

    A missing `infra` group only produces a warning. The group's `hosts`
    entry is read directly and is expected to be a dict.

    :param conf: config dict as returned by parse_conf()
    :return: True if the group is absent or every host passes, else False
    """
    success = True
    if 'infra' not in conf["all"]["children"]:
        print("[WARN] `all.children.infra` not found")
    else:
        gconf = conf["all"]["children"]["infra"]
        for hname, hvars in gconf['hosts'].items():
            # a host without a vars dict (e.g. null) cannot carry infra_seq
            if not isinstance(hvars, dict) or 'infra_seq' not in hvars:
                print("[ERRO] `all.children.infra.hosts.%s.infra_seq` not found" % (hname))
                success = False
            # bool is a subclass of int, so reject true/false explicitly
            elif isinstance(hvars['infra_seq'], bool) or not isinstance(hvars['infra_seq'], int):
                print("[ERRO] `all.children.infra.hosts.%s.infra_seq` is not an integer" % (hname))
                success = False
            elif hvars['infra_seq'] < 0:
                print("[ERRO] `all.children.infra.hosts.%s.infra_seq` should be an integer >= 0" % (hname))
                success = False
    if success:
        print("[ OK ] all.children.infra")
    else:
        print("[FAIL] all.children.infra")
    return success


def validate_etcd_group(conf):
    """
    Validate the `etcd` group: every host needs an integer `etcd_seq` >= 0.

    A missing `etcd` group is not an error; it produces a warning only when
    MODULES["pgsql"] is non-empty. The group's `hosts` entry is read
    directly and is expected to be a dict.

    :param conf: config dict as returned by parse_conf()
    :return: True if the group is absent or every host passes, else False
    """
    success = True
    if 'etcd' not in conf["all"]["children"]:
        if len(MODULES["pgsql"]) > 0:
            print("[WARN] `all.children.etcd` not found, which is required for pgsql")
    else:
        gconf = conf["all"]["children"]["etcd"]
        for hname, hvars in gconf['hosts'].items():
            # a host without a vars dict (e.g. null) cannot carry etcd_seq
            if not isinstance(hvars, dict) or 'etcd_seq' not in hvars:
                print("[ERRO] `all.children.etcd.hosts.%s.etcd_seq` not found" % (hname))
                success = False
            # bool is a subclass of int, so reject true/false explicitly
            elif isinstance(hvars['etcd_seq'], bool) or not isinstance(hvars['etcd_seq'], int):
                print("[ERRO] `all.children.etcd.hosts.%s.etcd_seq` is not an integer" % (hname))
                success = False
            elif hvars['etcd_seq'] < 0:
                print("[ERRO] `all.children.etcd.hosts.%s.etcd_seq` should be an integer >= 0" % (hname))
                success = False
    if success:
        print("[ OK ] all.children.etcd")
    else:
        print("[FAIL] all.children.etcd")
    return success


def validate_minio_group(conf):
    """
    Validate the `minio` group: every host needs an integer `minio_seq` >= 0.

    A missing `minio` group is not an error; it produces a warning only
    when `all.vars.pgbackrest_method` is "minio". The group's `hosts` entry
    is read directly and is expected to be a dict.

    :param conf: config dict as returned by parse_conf()
    :return: True if the group is absent or every host passes, else False
    """
    success = True
    if 'minio' not in conf["all"]["children"]:
        if conf["all"]["vars"].get("pgbackrest_method") == "minio":
            print("[WARN] `all.children.minio` not found, which is used by pgBackRest repo")
    else:
        gconf = conf["all"]["children"]["minio"]
        for hname, hvars in gconf['hosts'].items():
            # a host without a vars dict (e.g. null) cannot carry minio_seq
            if not isinstance(hvars, dict) or 'minio_seq' not in hvars:
                print("[ERRO] `all.children.minio.hosts.%s.minio_seq` not found" % (hname))
                success = False
            # bool is a subclass of int, so reject true/false explicitly
            elif isinstance(hvars['minio_seq'], bool) or not isinstance(hvars['minio_seq'], int):
                print("[ERRO] `all.children.minio.hosts.%s.minio_seq` is not an integer" % (hname))
                success = False
            elif hvars['minio_seq'] < 0:
                print("[ERRO] `all.children.minio.hosts.%s.minio_seq` should be an integer >= 0" % (hname))
                success = False
    if success:
        print("[ OK ] all.children.minio")
    else:
        print("[FAIL] all.children.minio")
    return success


def validate_pgsql_group(conf, cls):
    """
    Validate one PostgreSQL cluster group `all.children.<cls>`.

    Checks the `pg_cluster` name, then per host `pg_seq` (integer >= 0),
    `pg_role` (one of primary/replica/standby/offline/delayed) and the
    optional `pg_upstream` (IPv4 string), and finally the cluster-level
    users, databases, services, HBA rules and VIP settings.

    :param conf: config dict as returned by parse_conf()
    :param cls: group name; the group's vars must contain `pg_cluster`
    :return: True if every check passes, False otherwise
    """
    success = True
    gconf = conf["all"]["children"][cls]
    gvars = gconf.get('vars', {})
    pg_cluster = gvars["pg_cluster"]
    if pg_cluster != cls:
        print("[WARN] pg_cluster name %s is different from group name %s" % (pg_cluster, cls))

    # check pg_cluster name matching [a-zA-Z0-9-]+ ; test the type first so a
    # non-string name is reported instead of crashing the regex
    if not isinstance(pg_cluster, str) or not re.fullmatch(r'[a-zA-Z0-9-]+', pg_cluster):
        print("[ERRO] pg_cluster name %s is not valid for ^[a-zA-Z0-9-]+$" % (pg_cluster,))
        success = False

    # validate hosts
    for hname, hvars in gconf['hosts'].items():
        if not isinstance(hvars, dict):
            # a host without a vars dict (e.g. null) has none of the fields
            hvars = {}

        if 'pg_seq' not in hvars:
            print("[ERRO] `all.children.%s.hosts.%s.pg_seq` not found" % (cls, hname))
            success = False
        # bool is a subclass of int, so reject true/false explicitly
        elif isinstance(hvars['pg_seq'], bool) or not isinstance(hvars['pg_seq'], int):
            print("[ERRO] `all.children.%s.hosts.%s.pg_seq` is not an integer" % (cls, hname))
            success = False
        elif hvars['pg_seq'] < 0:
            print("[ERRO] `all.children.%s.hosts.%s.pg_seq` should be an integer >= 0" % (cls, hname))
            success = False

        if 'pg_role' not in hvars:
            print("[ERRO] `all.children.%s.hosts.%s.pg_role` not found" % (cls, hname))
            success = False
        elif not isinstance(hvars['pg_role'], str) or \
                hvars['pg_role'] not in ['primary', 'replica', 'standby', 'offline', 'delayed']:
            print("[ERRO] `all.children.%s.hosts.%s.pg_role` '%s' should be primary, replica,..." % (
                cls, hname, hvars['pg_role']))
            success = False

        if 'pg_upstream' in hvars:
            if not isinstance(hvars['pg_upstream'], str):
                print("[ERRO] `all.children.%s.hosts.%s.pg_upstream` should be a string" % (cls, hname))
                success = False
            elif not is_valid_ipv4(hvars['pg_upstream']):
                print("[ERRO] `all.children.%s.hosts.%s.pg_upstream` %s should be a valid IPv4" % (
                    cls, hname, hvars['pg_upstream']))
                success = False

    # validate pgsql cluster objects; the validator is the left operand so
    # every one of them runs even after an earlier failure
    success = validate_pgsql_user(gvars) and success
    success = validate_pgsql_db(gvars) and success
    success = validate_pgsql_svc(gvars) and success
    success = validate_pgsql_hba(gvars) and success
    success = validate_pgsql_vip(gvars) and success

    if success:
        print("[ OK ] all.children.%s" % cls)
    else:
        print("[FAIL] all.children.%s" % cls)
    return success


def validate_pgsql_groups(conf):
    """
    Run validate_pgsql_group() on every group whose vars define `pg_cluster`.

    :param conf: config dict as returned by parse_conf()
    :return: True if every PostgreSQL group passes (or none exists)
    """
    # collect the names first: groups marked with pg_cluster in their vars
    cluster_names = [
        name
        for name, group in conf['all']['children'].items()
        if 'pg_cluster' in group.get('vars', {})
    ]
    all_passed = True
    for name in cluster_names:
        # call the validator unconditionally so every group gets reported
        all_passed = validate_pgsql_group(conf, name) and all_passed
    return all_passed


def validate_type(v, t):
    """
    Check value `v` against the type name `t`.

    Supported names: 'str'/'string', 'int'/'integer', 'float', 'bool',
    'list', 'dict', 'date' (a 'YYYY-MM-DD' string), 'str[]'/'string[]',
    'dict[]' and 'enum:a,b,c' (value must equal one of the listed items).

    :param v: value to check
    :param t: type name, one of the names above
    :return: True if `v` conforms to `t`, False otherwise
    :raises Exception: if `t` is not a known type name
    """
    if t in ('string', 'str'):
        return isinstance(v, str)
    if t in ('integer', 'int'):
        # bool is a subclass of int; a YAML true/false is not an integer
        return isinstance(v, int) and not isinstance(v, bool)
    if t == 'float':
        return isinstance(v, float)
    if t == 'bool':
        return isinstance(v, bool)
    if t == 'list':
        return isinstance(v, list)
    if t == 'dict':
        return isinstance(v, dict)
    if t == 'date':
        # fullmatch: a trailing newline must not slip through
        return isinstance(v, str) and re.fullmatch(r'\d{4}-\d{2}-\d{2}', v) is not None
    if t in ('str[]', 'string[]'):
        return isinstance(v, list) and all(isinstance(i, str) for i in v)
    if t == 'dict[]':
        return isinstance(v, list) and all(isinstance(i, dict) for i in v)
    if t.startswith('enum:'):
        return v in t[5:].split(',')
    raise Exception("Unknown type '%s'" % t)


# Known keys of a `pg_users` entry, mapped to the type name that
# validate_type() checks each value against.
USER_FIELDS = {
    'name': 'str',
    'password': 'str',
    'login': 'bool',
    'superuser': 'bool',
    'createdb': 'bool',
    'createrole': 'bool',
    'inherit': 'bool',
    'replication': 'bool',
    'bypassrls': 'bool',
    'pgbouncer': 'bool',
    'connlimit': 'int',
    'expire_in': 'int',
    'expire_at': 'date',
    'comment': 'str',
    'roles': 'str[]',
    'parameters': 'dict',
    'pool_mode': 'enum:transaction,session,statement',
    'pool_connlimit': 'int',
}


def validate_pgsql_vip(gvars):
    """
    Validate the cluster VIP settings `pg_vip_enabled` / `pg_vip_address`.

    `pg_vip_address` is required when `pg_vip_enabled` is truthy, and
    whenever it is given it must be an IPv4 CIDR string such as
    '10.10.10.2/24', whether or not the VIP is enabled.

    :param gvars: cluster vars dict; must contain `pg_cluster`
    :return: True if the VIP settings are consistent, False otherwise
    """
    success = True
    cls = gvars["pg_cluster"]
    vip_enabled = bool(gvars.get("pg_vip_enabled"))
    if "pg_vip_address" not in gvars:
        if vip_enabled:
            print("[ERRO] `all.children.%s.vars.pg_vip_address` not given" % cls)
            success = False
    else:
        address = gvars["pg_vip_address"]
        # test the type first so a non-string address is reported, not a crash
        if not isinstance(address, str) or not is_valid_ipv4_cidr(address):
            print("[ERRO] `all.children.%s.vars.pg_vip_address` '%s' should be a valid IPv4 CIDR" % (
                cls, address))
            success = False
    return success


def validate_pgsql_user(gvars):
    """
    Validate the optional `pg_users` list of a PostgreSQL cluster.

    Every entry must be a dict with a 'name'. A missing 'password' and keys
    not listed in USER_FIELDS only produce warnings; a known key whose
    value has the wrong type is an error.

    :param gvars: cluster vars dict; must contain `pg_cluster`
    :return: True if `pg_users` is absent or valid, False otherwise
    """
    success = True
    cls = gvars["pg_cluster"]
    if "pg_users" not in gvars:
        return success
    if not isinstance(gvars["pg_users"], list):
        print("[ERRO] `all.children.%s.vars.pg_users` should be a list" % cls)
        return False

    for i, user in enumerate(gvars["pg_users"]):
        if not isinstance(user, dict):
            print("[ERRO] `pgsql.%s.vars.pg_users[%s]` should be a dict" % (cls, i))
            success = False
            continue
        if 'name' not in user:
            print("[ERRO] `pgsql.%s.vars.pg_users[%s]` should have a 'name'" % (cls, i))
            success = False
        if 'password' not in user:
            print("[WARN] `pgsql.%s.vars.pg_users[%s]` does not have a 'password'" % (cls, i))

        for k, v in user.items():
            if k not in USER_FIELDS:
                print("[WARN] `pgsql.%s.vars.pg_users[%s]` has an unknown key '%s'" % (cls, i, k))
            elif not validate_type(v, USER_FIELDS[k]):
                # same `[ERRO]` tag as every other error line in this script
                print("[ERRO] `pgsql.%s.vars.pg_users[%s].%s` should be a %s" % (cls, i, k, USER_FIELDS[k]))
                success = False
    return success


# Known keys of a `pg_databases` entry, mapped to the type name that
# validate_type() checks each value against.
DB_FIELDS = {
    'name': 'str',
    'baseline': 'str',
    'pgbouncer': 'bool',
    'schemas': 'str[]',
    'extensions': 'list',
    'comment': 'str',
    'owner': 'str',
    'template': 'str',
    'encoding': 'str',
    'locale': 'str',
    'lc_collate': 'str',
    'lc_ctype': 'str',
    'tablespace': 'str',
    'allowconn': 'bool',
    'revokeconn': 'bool',
    'register_datasource': 'bool',
    'connlimit': 'int',
    'pool_auth_user': 'str',
    'pool_mode': 'enum:transaction,session,statement',
    'pool_size': 'int',
    'pool_size_reserve': 'int',
    'pool_size_min': 'int',
    'pool_max_db_conn': 'int',
}


def validate_pgsql_db(gvars):
    """
    Validate the optional `pg_databases` list of a PostgreSQL cluster.

    Every entry must be a dict with a 'name'. Keys not listed in DB_FIELDS
    only produce a warning; a known key whose value has the wrong type is
    an error.

    :param gvars: cluster vars dict; must contain `pg_cluster`
    :return: True if `pg_databases` is absent or valid, False otherwise
    """
    success = True
    cls = gvars["pg_cluster"]
    if "pg_databases" not in gvars:
        return success
    if not isinstance(gvars["pg_databases"], list):
        print("[ERRO] `pgsql.%s.vars.pg_databases` should be a list" % cls)
        return False

    for i, db in enumerate(gvars["pg_databases"]):
        if not isinstance(db, dict):
            print("[ERRO] `pgsql.%s.vars.pg_databases[%s]` should be a dict" % (cls, i))
            success = False
            continue  # nothing else to inspect on a non-dict entry
        if 'name' not in db:
            print("[ERRO] `pgsql.%s.vars.pg_databases[%s]` should have a 'name'" % (cls, i))
            success = False
        for k, v in db.items():
            if k not in DB_FIELDS:
                print("[WARN] `pgsql.%s.vars.pg_databases[%s]` has an unknown key '%s'" % (cls, i, k))
            elif not validate_type(v, DB_FIELDS[k]):
                print("[ERRO] `pgsql.%s.vars.pg_databases[%s].%s` should be a %s" % (cls, i, k, DB_FIELDS[k]))
                # a reported error must also fail the validation
                success = False
    return success


# Known keys of a `pg_services` entry, mapped to the type name that
# validate_type() checks each value against.
SVC_FIELDS = {
    'name': 'str',
    'port': 'int',
    'ip': 'str',
    'selector': 'str',
    'backup': 'str',
    'dest': 'enum:default,postgres,pgbouncer',
    'check': 'str',
    'maxconn': 'int',
    'options': 'str',
    'balance': 'enum:roundrobin,leastconn',
}


def validate_pgsql_svc(gvars):
    """
    Validate the optional `pg_services` list of a PostgreSQL cluster.

    Every entry must be a dict with a 'name' and an integer 'port' in
    0..65535 that is not used by another entry of the same list. A port
    from the hard-coded reserved list and keys not listed in SVC_FIELDS
    only produce warnings; a known key whose value has the wrong type is
    an error.

    :param gvars: cluster vars dict; must contain `pg_cluster`
    :return: True if `pg_services` is absent or valid, False otherwise
    """
    success = True
    cls = gvars["pg_cluster"]
    if "pg_services" not in gvars:
        return success
    if not isinstance(gvars["pg_services"], list):
        print("[ERRO] `all.children.%s.vars.pg_services` should be a list" % cls)
        return False

    distinct_ports = set()
    for i, svc in enumerate(gvars["pg_services"]):
        if not isinstance(svc, dict):
            print("[ERRO] `all.children.%s.vars.pg_services[%s]` should be a dict" % (cls, i))
            success = False
            continue
        if 'name' not in svc:
            print("[ERRO] `all.children.%s.vars.pg_services[%s]` should have a 'name'" % (cls, i))
            success = False
        if 'port' not in svc:
            print("[ERRO] `all.children.%s.vars.pg_services[%s]` should have a 'port'" % (cls, i))
            success = False
        # bool is a subclass of int, so reject true/false explicitly
        elif isinstance(svc['port'], bool) or not isinstance(svc['port'], int):
            print("[ERRO] `all.children.%s.vars.pg_services[%s].port` should be an integer" % (cls, i))
            success = False
        elif svc['port'] < 0 or svc['port'] > 65535:
            print(
                "[ERRO] `all.children.%s.vars.pg_services[%s].port` should be an integer between 0 and 65535" % (
                    cls, i))
            success = False
        else:
            if svc['port'] in distinct_ports:
                print("[ERRO] `all.children.%s.vars.pg_services[%s].port` %s is duplicated" % (
                    cls, i, svc['port']))
                success = False
            elif svc['port'] in (5432, 6432, 5433, 5434, 5436, 5438, 9100, 9101, 9630, 9631, 9080):
                print("[WARN] `all.children.%s.vars.pg_services[%s].port` %s may be reserved" % (
                    cls, i, svc['port']))
            # record every valid port, reserved ones included, so a second
            # service on the same port is always reported as a duplicate
            distinct_ports.add(svc['port'])

        for k, v in svc.items():
            if k not in SVC_FIELDS:
                print("[WARN] `pgsql.%s.vars.pg_services[%s]` has an unknown key '%s'" % (cls, i, k))
            elif not validate_type(v, SVC_FIELDS[k]):
                print("[ERRO] `pgsql.%s.vars.pg_services[%s].%s` should be a %s" % (cls, i, k, SVC_FIELDS[k]))
                # a reported error must also fail the validation
                success = False
    return success


# `addr` values accepted by validate_hba() in addition to an IPv4 CIDR
HBA_ADDR_ALIAS = [
    'world', 'intra', 'infra', 'admin',
    'local', 'localhost', 'cluster',
]
# `auth` values accepted by validate_hba()
HBA_AUTH_ALIAS = [
    'deny', 'trust', 'pwd', 'sha', 'scram-sha-256', 'md5',
    'ssl', 'ssl-md5', 'ssl-sha',
    'os', 'ident', 'peer', 'cert',
]


def validate_hba(hba, cls, var_name, i):
    """
    Validate one HBA rule entry.

    An entry is either raw form (a 'rules' list of strings) or alias form,
    which requires 'addr', 'auth', 'user' and 'db'. In alias form 'addr'
    must be one of HBA_ADDR_ALIAS or an IPv4 CIDR, and 'auth' must be one
    of HBA_AUTH_ALIAS. A missing 'title' and unknown keys only produce
    warnings.

    :param hba: the rule entry to check
    :param cls: cluster name, used in messages only
    :param var_name: name of the variable holding the rule list, used in
                     messages only
    :param i: index of the entry within that list, used in messages only
    :return: True if the entry is valid, False otherwise
    """
    success = True
    where = (cls, var_name, i)  # shared prefix arguments for the messages
    if not isinstance(hba, dict):
        print("[ERRO] `pgsql.%s.vars.%s[%s]` should be a dict" % where)
        return False
    if 'title' not in hba:
        print("[WARN] `pgsql.%s.vars.%s[%s]` missing 'title' field" % where)
    for k in hba:
        if k not in ('title', 'rules', 'addr', 'auth', 'user', 'db'):
            print("[WARN] `pgsql.%s.vars.%s[%s]` has an unknown key '%s'" % (cls, var_name, i, k))

    if 'rules' in hba:
        if not validate_type(hba['rules'], 'str[]'):
            print("[ERRO] `pgsql.%s.vars.%s[%s].rules` should be a list of hba rule string" % where)
            success = False
    else:
        if 'addr' not in hba:
            print("[ERRO] `pgsql.%s.vars.%s[%s]` missing 'addr' field" % where)
            success = False
        # test the type first so a non-string value is reported, not a crash
        elif not isinstance(hba['addr'], str) or \
                (hba['addr'] not in HBA_ADDR_ALIAS and not is_valid_ipv4_cidr(hba['addr'])):
            print("[ERRO] `pgsql.%s.vars.%s[%s].addr` has invalid value: %s" % (cls, var_name, i, hba['addr']))
            # a reported error must also fail the validation
            success = False

        if 'auth' not in hba:
            print("[ERRO] `pgsql.%s.vars.%s[%s]` missing 'auth' field" % where)
            success = False
        elif not isinstance(hba['auth'], str) or hba['auth'] not in HBA_AUTH_ALIAS:
            print("[ERRO] `pgsql.%s.vars.%s[%s].auth` has invalid value: %s" % (cls, var_name, i, hba['auth']))
            # a reported error must also fail the validation
            success = False

        if 'user' not in hba:
            print("[ERRO] `pgsql.%s.vars.%s[%s]` missing 'user' field" % where)
            success = False
        if 'db' not in hba:
            print("[ERRO] `pgsql.%s.vars.%s[%s]` missing 'db' field" % where)
            success = False
    return success


def validate_pgsql_hba(gvars):
    """
    Validate the HBA rule lists of a PostgreSQL cluster.

    Checks `pg_hba_rules`, `pgb_hba_rules`, `pg_default_hba_rules` and
    `pgb_default_hba_rules` when present: each must be a list, and every
    entry must pass validate_hba().

    :param gvars: cluster vars dict; must contain `pg_cluster`
    :return: True if every present rule list is valid, False otherwise
    """
    success = True
    cls = gvars["pg_cluster"]

    for var_name in ("pg_hba_rules", "pgb_hba_rules", "pg_default_hba_rules", "pgb_default_hba_rules"):
        if var_name not in gvars:
            continue
        if not isinstance(gvars[var_name], list):
            print("[ERRO] `all.children.%s.vars.%s` should be a list" % (cls, var_name))
            success = False
            continue
        for i, hba in enumerate(gvars[var_name]):
            # fold the per-rule result in; the validator is the left operand
            # so every rule is still checked after an earlier failure
            success = validate_hba(hba, cls, var_name, i) and success
    return success


def validate_config(path):
    """
    Parse the config file at `path` and run every validator against it.

    :param path: path of the config file
    :return: True if all validators pass, False otherwise
    """
    conf = parse_conf(path)
    validators = (
        validate_global_vars,
        validate_children,
        validate_infra_group,
        validate_etcd_group,
        validate_minio_group,
        validate_pgsql_groups,
    )
    passed = True
    for validator in validators:
        # call first, then combine, so each validator runs regardless of
        # earlier results
        passed = validator(conf) and passed
    return passed


if __name__ == '__main__':
    # default target is pigsty.yml one level above this script's directory;
    # the first command-line argument overrides it
    bindir = os.path.dirname(os.path.abspath(__file__))
    default_path = os.path.abspath(os.path.join(bindir, '../pigsty.yml'))
    path = sys.argv[1] if len(sys.argv) > 1 else default_path
    # exit status: 0 when validation passes, 1 when it fails
    sys.exit(0 if validate_config(path) else 1)
