#!/usr/bin/python3
# ==============================================================#
# File      :   ssh
# Desc      :   setup vagrant ssh config
# Ctime     :   2021-04-20
# Mtime     :   2025-09-11
# Path      :   vagrant/ssh
# License   :   Apache-2.0 @ https://pigsty.io/docs/about/license/
# Copyright :   2018-2026  Ruohang Feng / Vonng (rh@vonng.com)
# ==============================================================#
# This script generates an ssh config for the vagrant VMs and installs it under ~/.ssh,
# which allows you to ssh to the VMs by either node name or IP address

import os, re, sys, subprocess


def parse_vagrant_spec(vagrant_filepath):
    """Extract node names and IP addresses from the ``Specs`` array of a Vagrantfile.

    Only the lines between the first line starting with ``Specs`` and the
    next line starting with ``]`` are inspected. A node is recognised when a
    single line carries both ``"name" => "..."`` and ``"ip" => "..."``;
    any other line inside the array is ignored.

    Args:
        vagrant_filepath: path of the Vagrantfile to read.

    Returns:
        A ``(mapping, specs)`` tuple, where ``mapping`` is a dict of
        name -> ip and ``specs`` is a list of ``(ip, name)`` tuples in the
        order the nodes appear in the file.
    """
    name_pattern = re.compile(r'"name"\s*=>\s*"([^"]+)"')
    ip_pattern = re.compile(r'"ip"\s*=>\s*"([^"]+)"')

    specs = []
    result = {}
    inside_specs = False
    with open(vagrant_filepath, 'r') as f:
        for line in f:
            if line.startswith('Specs'):
                inside_specs = True
                continue
            if not inside_specs:
                # a "]" that closes some earlier array must not end the scan
                continue
            if line.startswith(']'):
                break

            name_match = name_pattern.search(line)
            ip_match = ip_pattern.search(line)
            if name_match is None or ip_match is None:
                # comments, blank lines, or entries not written as "key" => "value"
                continue
            name, ip = name_match.group(1), ip_match.group(1)
            result[name] = ip
            specs.append((ip, name))
    return result, specs


def get_vagrant_ssh_config(vagrant_dir):
    """Run ``vagrant ssh-config`` inside *vagrant_dir* and return its output.

    The process working directory is switched to *vagrant_dir* before the
    command is launched through the shell, and stays there afterwards.

    Args:
        vagrant_dir: directory holding the Vagrantfile.

    Returns:
        The command's standard output as text.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    vagrant_cmd = 'vagrant ssh-config'
    print("cd %s && vagrant ssh-config" % vagrant_dir)
    print("this may take several seconds....")
    os.chdir(vagrant_dir)
    return subprocess.check_output(vagrant_cmd, shell=True, universal_newlines=True)


def generate_ssh_config(config, mapping):
    """Build the final ssh config: the original per-name section plus a per-IP copy.

    The second section is *config* with every ``Host <name>`` line rewritten
    to ``Host <ip>`` for the names found in *mapping*. Whole host lines are
    matched in a single pass, so a name that is a prefix of another
    (``node-1`` / ``node-10``) cannot corrupt the longer one, and an
    already-substituted IP is never rewritten again.

    Args:
        config: text in ssh config format, e.g. the output of ``vagrant ssh-config``.
        mapping: dict of node name -> IP address.

    Returns:
        The combined config text (name section followed by IP section).
    """
    # "Host <alias>" alone on a line; the lookahead keeps trailing blanks intact
    host_line = re.compile(r'^(Host[ \t]+)(\S+)(?=[ \t]*$)', re.MULTILINE)

    def _alias_to_ip(match):
        alias = match.group(2)
        # hosts missing from the mapping are kept unchanged
        return match.group(1) + mapping.get(alias, alias)

    extra_config = host_line.sub(_alias_to_ip, config)
    final_config = '# ssh access via nodename\n\n\n' + config + '\n\n\n' + '# SSH Access via IP address\n\n' + extra_config
    return final_config


def write_ssh_config(config, name="pigsty"):
    """Write *config* to ``~/.ssh/<name>_config`` and include it from ``~/.ssh/config``.

    The ``~/.ssh`` directory is created when missing, and a missing
    ``~/.ssh/config`` is treated as empty, so the script also works on a
    fresh account. The ``Include`` line is appended only when the main
    config does not already contain it. Both files are chmod-ed to 0600.

    Args:
        config: full text of the extra ssh config.
        name: prefix of the extra config file name (``<name>_config``).
    """
    config_name = '%s_config' % name
    include_cmd = 'Include ~/.ssh/%s_config' % name
    # prefer $HOME as before; fall back to the account's home when it is unset
    home_dir = os.environ.get('HOME') or os.path.expanduser('~')
    ssh_dir = os.path.join(home_dir, '.ssh')
    main_config_path = os.path.join(ssh_dir, 'config')
    extra_config_path = os.path.join(ssh_dir, config_name)

    # a fresh account may not have ~/.ssh yet
    os.makedirs(ssh_dir, mode=0o700, exist_ok=True)

    # write extra config
    print("write extra ssh config [%s] to %s" % (name, extra_config_path))
    with open(extra_config_path, 'w') as f:
        f.write(config)
    os.chmod(extra_config_path, 0o600)

    # a missing main config simply means nothing includes the extra file yet
    try:
        with open(main_config_path, 'r') as f:
            ssh_config = f.read()
    except FileNotFoundError:
        ssh_config = ''
    if include_cmd in ssh_config:
        print("include cmd already exists in ~/.ssh/config")
        return

    # NOTE(review): ssh_config(5) says an Include placed inside a Host/Match
    # block is applied conditionally; appending at the end may scope it to the
    # last block of an existing config -- confirm whether it should be prepended.
    print("write include %s command to ~/.ssh/config" % name)
    with open(main_config_path, 'a') as f:
        f.write('\n\n# %s\n%s\n' % (config_name, include_cmd))
    os.chmod(main_config_path, 0o600)


def get_config_name():
    """Return the ssh config name taken from the command line.

    The first command-line argument is used when given, otherwise the
    default ``pigsty``. The name must consist solely of lowercase letters
    and digits; the whole string is checked, so a trailing newline is
    rejected as well. On an invalid name a message is printed and the
    process exits with status 1.

    Returns:
        The validated config name.
    """
    config_name = sys.argv[1] if len(sys.argv) > 1 else 'pigsty'
    # fullmatch: with re.match, "$" would still accept a trailing "\n"
    if not re.fullmatch('[a-z0-9]+', config_name):
        print("config name must match regexp: [a-z0-9]+")
        sys.exit(1)
    return config_name


def main():
    """Generate the vagrant ssh config and install it under ``~/.ssh``.

    Steps: resolve the config name, locate the Vagrantfile next to this
    script (exit with status 1 when it is absent), list the nodes parsed
    from it, collect ``vagrant ssh-config`` output, print the combined
    config, and write it out.
    """
    # config name comes from argv, pigsty by default
    profile = get_config_name()

    # the Vagrantfile is expected to sit next to this script
    script_dir = os.path.dirname(os.path.realpath(__file__))
    os.chdir(script_dir)
    vagrantfile_path = os.path.join(script_dir, 'Vagrantfile')
    if not os.path.exists(vagrantfile_path):
        print("Vagrantfile not found in %s" % vagrantfile_path)
        sys.exit(1)

    # show the nodes declared in the Vagrantfile
    name_to_ip, node_list = parse_vagrant_spec(vagrantfile_path)
    print("\nVagrant nodes:\n")
    for node_ip, node_name in node_list:
        print("%-16s %s" % (node_ip, node_name))
    print("\n\n")

    # build the combined (by name + by ip) config and show it
    raw_config = get_vagrant_ssh_config(script_dir)
    ssh_config = generate_ssh_config(raw_config, name_to_ip)
    print(ssh_config)

    # install it into ~/.ssh/
    write_ssh_config(ssh_config, profile)


# script entry point: run main() only when executed directly, not on import
if __name__ == "__main__":
    main()
