#!/usr/bin/env python3
# -*- coding: utf-8 -*- #
#==============================================================#
# File      :   inventory_load
# Desc      :   load pigsty config into cmdb structure
# Ctime     :   2022-05-05
# Mtime     :   2026-01-11
# Path      :   bin/inventory_cmdb
# Deps      :   argparse, os, json, yaml, platform, psycopg2
# License   :   Apache-2.0 @ https://pigsty.io/docs/about/license/
# Copyright :   2018-2026  Ruohang Feng / Vonng (rh@vonng.com)
#==============================================================#
__author__ = 'Vonng (rh@vonng.com)'

from argparse import ArgumentParser
import os, json, yaml, platform


def usage():
    """Print the command-line synopsis of the config loader to stdout."""
    # leading and trailing blanks inside the quotes are part of the emitted text
    synopsis_lines = [
        "",
        "    bin/load_conf [ -p | --path = ${PIGSTY_HOME}/pigsty.yml ]",
        "                  [ -d | --data = ${METADB_URL} ] ",
        "    Load config into cmdb pigsty schema     ",
        "    ",
    ]
    print("\n".join(synopsis_lines))

# SQL statements executed by load_conf(); every %s is a driver-side bind parameter.

# empty all inventory tables before a reload
SQL_CLEAN_UP = (
    'TRUNCATE pigsty.group, pigsty.host, pigsty.group_var, '
    'pigsty.host_var, pigsty.global_var;'
)

# upsert one global variable: (key, value)
SQL_INSERT_GLOBAL_VARS = (
    'INSERT INTO pigsty.global_var (key, value) VALUES (%s, %s) '
    'ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value, mtime = now();'
)

# upsert one group: (cls,)
SQL_INSERT_GROUP_NAMES = (
    'INSERT INTO pigsty.group (cls) VALUES (%s) '
    'ON CONFLICT (cls) DO UPDATE SET mtime = now();'
)

# upsert one group variable: (cls, key, value)
SQL_INSERT_GROUP_VARS = (
    'INSERT INTO pigsty.group_var (cls, key, value) VALUES (%s, %s, %s) '
    'ON CONFLICT (cls, key) DO UPDATE SET value = EXCLUDED.value, mtime = now();'
)

# upsert one host: (cls, ip)
SQL_INSERT_HOST_NAMES = (
    'INSERT INTO pigsty.host (cls, ip) VALUES (%s ,%s) '
    'ON CONFLICT (cls, ip) DO UPDATE SET mtime = now();'
)

# upsert one host variable: (cls, ip, key, value)
SQL_INSERT_HOST_VARS = (
    'INSERT INTO pigsty.host_var (cls, ip, key, value) VALUES (%s, %s, %s, %s) '
    'ON CONFLICT (cls, ip, key) DO UPDATE SET value = EXCLUDED.value, mtime = now();'
)

###########################
# parse arguments
###########################
# CMDB connection string: taken from $METADB_URL, empty when unset
DEFAULT_PGURL = os.environ.get('METADB_URL', '')

# pigsty home: $PIGSTY_HOME wins, then $HOME/pigsty, otherwise empty
if 'PIGSTY_HOME' in os.environ:
    PIGSTY_HOME = os.environ['PIGSTY_HOME']
elif 'HOME' in os.environ:
    PIGSTY_HOME = os.path.join(os.environ['HOME'], 'pigsty')
else:
    PIGSTY_HOME = ''

# default config file sits directly under the pigsty home; empty when the home is unknown
DEFAULT_CONFIG_PATH = os.path.join(PIGSTY_HOME, 'pigsty.yml') if PIGSTY_HOME else ''

# command-line interface; defaults come from the environment-derived values above
parser = ArgumentParser(description="load config arguments")
parser.add_argument(
    '-n', "--name",
    default='pgsql',
    help="config profile name, pgsql by default",
)
parser.add_argument(
    '-p', "--path",
    default=DEFAULT_CONFIG_PATH,
    help="config path, ${PIGSTY_HOME}/pigsty.yml by default",
)
parser.add_argument(
    '-d', "--data",
    default=DEFAULT_PGURL,
    help="postgres cmdb pgurl, ${METADB_URL} by default",
)


###########################
# parse config
###########################
def parse_conf(path):
    """Read the YAML config at *path* and return it as plain JSON-compatible data.

    The parsed document is round-tripped through JSON, so the result contains
    only JSON-representable values (mapping keys come back as strings) and a
    document that cannot be serialized as JSON fails here rather than later.

    :param path: path of the YAML config file
    :return: the parsed document (``None`` for an empty file)
    :raises OSError: if the file cannot be opened
    :raises yaml.YAMLError: if the file is not valid YAML
    :raises TypeError: if the document holds values JSON cannot represent
    """
    # context manager closes the handle even when parsing fails
    with open(path, 'r') as src:
        document = yaml.safe_load(src)
    return json.loads(json.dumps(document))


def _set_ansible_inventory(ansible_cfg, inventory):
    """Point every ``inventory...`` setting in *ansible_cfg* at *inventory*.

    On each line, the text from the first occurrence of ``inventory`` to the
    end of that line is replaced with ``inventory = <inventory>`` (the same
    substitution as ``sed 's/inventory.*/inventory = .../g'``). The file is
    edited as bytes so its encoding is irrelevant, and no shell is involved,
    so the path may contain spaces or shell metacharacters.

    :raises FileNotFoundError: if *ansible_cfg* does not exist
    :raises OSError: if the file cannot be read or written
    """
    import re  # local import: only this helper needs it

    if not os.path.exists(ansible_cfg):
        raise FileNotFoundError("%s not exists" % ansible_cfg)

    with open(ansible_cfg, 'rb') as src:
        content = src.read()
    replacement = ('inventory = %s' % inventory).encode()
    # callable replacement: the new text is inserted literally, never parsed for escapes
    content = re.sub(rb'inventory.*', lambda _match: replacement, content)
    with open(ansible_cfg, 'wb') as dst:
        dst.write(content)


def use_dynamic_inventory(ansible_cfg):
    """Switch *ansible_cfg* to the dynamic inventory script ``inventory.sh``.

    :param ansible_cfg: path of the ansible config file to edit in place
    :raises FileNotFoundError: if *ansible_cfg* does not exist
    :raises OSError: if the file cannot be read or written
    """
    _set_ansible_inventory(ansible_cfg, 'inventory.sh')


def use_static_inventory(ansible_cfg):
    """Switch *ansible_cfg* back to the static inventory file ``pigsty.yml``.

    :param ansible_cfg: path of the ansible config file to edit in place
    :raises FileNotFoundError: if *ansible_cfg* does not exist
    :raises OSError: if the file cannot be read or written
    """
    _set_ansible_inventory(ansible_cfg, 'pigsty.yml')


def write_inventory_sh(path):
    """Create the dynamic inventory script at *path* and make it executable.

    The script is a two-line bash wrapper that prints the ``text`` column of
    ``pigsty.inventory`` via ``psql service=meta``. An existing file at *path*
    is overwritten.

    :param path: destination path of the script
    :raises OSError: if the file cannot be written or its mode changed
    """
    # must be opened for writing; the default mode is read-only
    with open(path, 'w') as dst:
        dst.write("#!/bin/bash\npsql service=meta -AXtwc 'SELECT text FROM pigsty.inventory;'")
    os.chmod(path, 0o755)


############################
# upsert config and activate
############################


def load_conf(conf, pgurl):
    """Replace the CMDB inventory tables with the content of *conf*.

    Within one transaction the inventory tables are truncated and then filled
    from ``conf['all']``: ``vars`` become global vars, each entry of
    ``children`` becomes a group with its ``vars`` and ``hosts``, and each
    host's mapping becomes host vars. Values are stored wrapped in psycopg2
    ``Json``. Missing or empty ``vars`` / ``hosts`` sections and hosts without
    variables are treated as empty.

    :param conf: parsed config; nothing happens unless it is a mapping with an ``all`` key
    :param pgurl: connection string handed to ``psycopg2.connect``
    :raises Exception: any database error is re-raised after a rollback
    """
    # an empty config file parses to None; treat it like a config without 'all'
    if not isinstance(conf, dict) or 'all' not in conf:
        return
    import psycopg2
    from psycopg2.extras import Json

    inventory = conf['all'] or {}
    groups = inventory.get('children') or {}
    global_vars = [(k, Json(v)) for k, v in (inventory.get('vars') or {}).items()]
    group_names = [(cls,) for cls in groups]

    conn = psycopg2.connect(pgurl)
    try:
        with conn.cursor() as cur:
            cur.execute(SQL_CLEAN_UP)
            print("[INFO] truncate pigsty schema")
            cur.executemany(SQL_INSERT_GLOBAL_VARS, global_vars)
            print("[INFO] load all.vars")
            cur.executemany(SQL_INSERT_GROUP_NAMES, group_names)
            for cls, group in groups.items():
                print("[INFO] load group %s" % cls)
                group = group or {}
                hosts = group.get('hosts') or {}  # a group may define no hosts of its own
                group_vars = [(cls, k, Json(v)) for k, v in (group.get('vars') or {}).items()]
                cur.executemany(SQL_INSERT_GROUP_VARS, group_vars)  # insert group vars
                cur.executemany(SQL_INSERT_HOST_NAMES, [(cls, ip) for ip in hosts])  # insert host names
                for ip, host_conf in hosts.items():
                    # a host declared without variables parses to None
                    host_vars = [(cls, ip, k, Json(v)) for k, v in (host_conf or {}).items()]
                    cur.executemany(SQL_INSERT_HOST_VARS, host_vars)  # insert host vars
            conn.commit()
    except Exception:
        print("[ERRO] fail to load inventory")
        conn.rollback()
        raise
    finally:
        # always release the connection, on success and on failure
        conn.close()


############################
# OS-specific default vars
############################
# mapping: xxx_default -> xxx in pigsty.default_var
# maps each "<name>_default" key of an OS vars file to its "<name>" key in pigsty.default_var
OS_DEFAULT_PARAMS = {
    '%s_default' % name: name
    for name in (
        'repo_packages',
        'repo_extra_packages',
        'node_packages',
        'infra_packages',
        'repo_upstream',
    )
}

# overwrite one default variable: (value, key)
SQL_UPDATE_DEFAULT_VAR = "UPDATE pigsty.default_var SET value = %s WHERE key = %s;"
# empty the package map before it is re-filled
SQL_TRUNCATE_PACKAGE_MAP = "TRUNCATE pigsty.package_map;"
# add one package map entry: (key, value)
SQL_INSERT_PACKAGE_MAP = "INSERT INTO pigsty.package_map (key, value) VALUES (%s, %s);"


def detect_os(os_release='/etc/os-release'):
    """Detect the OS family, major version and CPU architecture.

    :param os_release: path of the os-release file to inspect
    :return: ``(prefix, fallback, arch)``. ``prefix`` is ``el`` / ``d`` / ``u``
        followed by the major version taken from ``VERSION_ID``; ``fallback``
        is the fixed code of the same family (``el9``, ``d12``, ``u22``).
        When the file is missing or unreadable, or the ``ID`` is not a
        recognized vendor, the result is ``(None, None, arch)``; the
        architecture is always reported.
    """
    arch = platform.machine()
    if arch == 'arm64':
        arch = 'aarch64'  # macOS uses arm64

    if not os.path.exists(os_release):
        return (None, None, arch)

    os_info = {}
    try:
        with open(os_release, 'r') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    continue  # blank line or comment
                key, sep, value = line.partition('=')
                if sep:
                    # values may be wrapped in double or single quotes
                    os_info[key] = value.strip('"').strip("'")
    except (OSError, ValueError):
        # unreadable or undecodable file: same result as a missing file
        return (None, None, arch)

    vendor = os_info.get('ID', '').lower()
    # keep the major version only, e.g. "9.3" -> "9"
    version = os_info.get('VERSION_ID', '').split('.')[0]

    # determine OS prefix and fallback
    if vendor in ('rhel', 'centos', 'rocky', 'alma', 'almalinux', 'fedora', 'ol', 'anolis', 'openeuler'):
        return ('el' + version, 'el9', arch)
    if vendor == 'debian':
        return ('d' + version, 'd12', arch)
    if vendor == 'ubuntu':
        return ('u' + version, 'u22', arch)
    return (None, None, arch)


def load_os_defaults(pgurl):
    """Load OS-specific default parameters into the CMDB (best effort).

    Looks for ``<prefix>.<arch>.yml`` under ``roles/node_id/vars`` of the
    pigsty home and, if that file is absent, for ``<fallback>.<arch>.yml``.
    Keys listed in ``OS_DEFAULT_PARAMS`` update ``pigsty.default_var`` and a
    ``package_map`` mapping replaces the content of ``pigsty.package_map``.
    Every failure is reported as a message and never raised.

    :param pgurl: connection string handed to ``psycopg2.connect``; nothing
        happens when it or the pigsty home is empty
    """
    if not pgurl or not PIGSTY_HOME:
        return

    prefix, fallback, arch = detect_os()
    if prefix is None:
        return  # silently skip on non-Linux or unknown OS

    # exact match first, then the family fallback
    vars_dir = os.path.join(PIGSTY_HOME, 'roles', 'node_id', 'vars')
    candidates = [os.path.join(vars_dir, '%s.%s.yml' % (name, arch)) for name in (prefix, fallback)]
    vars_file = next((path for path in candidates if os.path.exists(path)), None)
    if vars_file is None:
        print("[INFO] skip os-specific defaults: %s.%s.yml not found" % (prefix, arch))
        return

    # parse vars file
    try:
        with open(vars_file, 'r') as f:
            os_vars = yaml.safe_load(f)
    except Exception as e:
        print("[WARN] fail to parse %s: %s" % (vars_file, e))
        return
    # only a non-empty mapping carries usable settings
    if not isinstance(os_vars, dict) or not os_vars:
        return

    # collect default_var updates as (json value, target key)
    updates = [
        (json.dumps(os_vars[src_key]), dst_key)
        for src_key, dst_key in OS_DEFAULT_PARAMS.items()
        if src_key in os_vars
    ]

    # collect package_map entries as (key, text value)
    package_map = os_vars.get('package_map')
    pkg_map = [(k, str(v)) for k, v in package_map.items()] if isinstance(package_map, dict) else []

    if not updates and not pkg_map:
        return

    # apply updates to database
    try:
        import psycopg2
        conn = psycopg2.connect(pgurl)
        try:
            with conn.cursor() as cur:
                if updates:
                    cur.executemany(SQL_UPDATE_DEFAULT_VAR, updates)
                if pkg_map:
                    cur.execute(SQL_TRUNCATE_PACKAGE_MAP)
                    cur.executemany(SQL_INSERT_PACKAGE_MAP, pkg_map)
            conn.commit()
            print("[INFO] load os-specific defaults from %s" % os.path.basename(vars_file))
        finally:
            # close even when creating the cursor or a statement fails
            conn.close()
    except Exception as e:
        print("[WARN] fail to update os defaults: %s" % e)


def main():
    """Parse the command line, load the inventory, then load the OS-specific defaults."""
    opts = parser.parse_args()
    print("[INFO] load pigsty config from %s into %s" % (opts.path, opts.data))
    inventory = parse_conf(opts.path)
    load_conf(inventory, opts.data)
    load_os_defaults(opts.data)


if __name__ == '__main__':
    main()
