#!/usr/bin/env python3
#
# This script is responsible for creating/updating /etc/puppet/cosmos-modules.conf.
#
# If this script exits without creating that file, a default list of modules will be
# selected (by post-tasks.d/010cosmos-modules, the script that invokes this script).
#
# NOTES ABOUT THE IMPLEMENTATION:
#
#   - Avoid any third party modules. We want this script to be re-usable in all ops-repos.
#   - To make merging easier, try to keep all local alterations in the local_* functions.
#   - Format with black and isort. Line width 120.
#   - You probably ONLY want to change things in the local_get_modules_hook() function.
#

import argparse
import csv
import json
import logging
import logging.handlers
import os
import re
import socket
import sys
from pathlib import Path
from typing import Dict, NewType, Optional, cast

from pkg_resources import parse_version

# Module-level logger used by every function below.
logger = logging.getLogger(__name__)  # will be overwritten by _setup_logging()

# Set up types for data that is passed around in functions in this script.
# Use typing.Dict (not the builtin dict[...]) for these aliases: unlike function annotations they
# aren't stripped by strip-hints, and dict[...] doesn't work on Ubuntu <= 20.04.
# Arguments: the parsed command line (an argparse.Namespace).
Arguments = NewType("Arguments", argparse.Namespace)
# OSInfo: variables read from /etc/os-release, with lower-cased keys.
OSInfo = Dict[str, str]
# HostInfo: "hostname", "fqdn" and "domainname", plus whatever local_host_info_hook() adds.
HostInfo = Dict[str, Optional[str]]
# Modules: module name -> {"repo": ..., "upgrade": ..., "tag": ...}.
Modules = Dict[str, Dict[str, str]]


def parse_args() -> Arguments:
    """
    Read the command line options for this script.

    Two options are understood: ``--debug`` (a flag, off by default) and ``--filename``
    (where to write the result, default ``/etc/puppet/cosmos-modules.conf``).

    :return: The parsed options, with the attributes ``debug`` and ``filename``
    """
    cli = argparse.ArgumentParser(
        description="Setup cosmos-modules.conf",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # store_true already implies default=False and dest="debug"
    cli.add_argument("--debug", action="store_true", help="Enable debug operation")
    # Values are strings unless a type is given, and the dest is derived from the option name
    cli.add_argument("--filename", default="/etc/puppet/cosmos-modules.conf", help="Filename to write to")

    namespace = cli.parse_args()
    return cast(Arguments, namespace)


def get_os_info() -> OSInfo:
    """Collect information about the running OS (distro, release etc.).

    The variables in /etc/os-release (if that file exists) are loaded with their names
    lower-cased, and the result is then passed through local_os_info_hook().

    :return: Whatever the local hook returns
    """
    os_release = Path("/etc/os-release")
    collected: OSInfo = {}
    if os_release.exists():
        for name, value in _parse_bash_vars(str(os_release)).items():
            collected[name.lower()] = value
    result = local_os_info_hook(collected)
    logger.debug(f"OS info:\n{json.dumps(result, sort_keys=True, indent=4)}")
    return result


def get_host_info() -> HostInfo:
    """Load info about the current host (hostname, fqdn, domain name etc.)

    The hostname and FQDN are fetched using the socket module. The domain name is the part
    of the FQDN that follows "<hostname>.". If the FQDN does not start with the hostname
    (the resolver may return an unrelated name such as "localhost.localdomain"), the domain
    name falls back to everything after the first dot of the FQDN instead of blindly
    slicing the FQDN by the length of the hostname.

    The result is passed through local_host_info_hook() before it is returned.

    :return: Whatever the local hook returns. The hook receives an empty dictionary if the
             host names could not be determined.
    """
    host_info: HostInfo = {}
    try:
        fqdn = socket.getfqdn()
        hostname = socket.gethostname()
    except OSError:
        # Could not determine the host names, leave it to the local hook to fill something in
        pass
    else:
        _prefix = hostname + "."
        if fqdn.startswith(_prefix):
            _domainname = fqdn[len(_prefix) :]
        else:
            # The FQDN is not an extension of the hostname, so slicing it by the length of the
            # hostname would return garbage. Use what follows the first dot ("" if there is none).
            _domainname = fqdn.partition(".")[2]

        host_info = {
            "domainname": _domainname,
            "fqdn": fqdn,
            "hostname": hostname,
        }
    res = local_host_info_hook(host_info)
    logger.debug(f"Host info: {json.dumps(res, sort_keys=True, indent=4)}")
    return res


def _parse_bash_vars(path: str) -> Dict[str, str]:
    """
    Parses a bash script and returns a dictionary representing the
    variables declared in that script.

    Only lines of the form NAME=value (starting in the first column) are considered.
    Double quotes around a value are removed by the csv module. The value is everything
    after the first "=", so an unquoted value may itself contain "=" (for example
    ``OPTS=a=b``) without making the parsing fail.

    Source: https://dev.to/htv2012/how-to-parse-bash-variables-b4f

    :param path: The path to the bash script
    :return: Variables as a dictionary
    """
    # os-release files are specified to be UTF-8, so don't depend on the locale of the caller
    with open(path, encoding="utf-8") as stream:
        contents = stream.read().strip()

    var_declarations = re.findall(r"^[a-zA-Z0-9_]+=.*$", contents, flags=re.MULTILINE)
    bash_vars = {}
    for row in csv.reader(var_declarations, delimiter="="):
        if not row:
            continue
        # An unquoted "=" in the value makes the csv reader return more than two fields.
        # Glue them back together rather than failing in dict() on a row of the wrong length.
        name, *value_parts = row
        bash_vars[name] = "=".join(value_parts)
    return bash_vars


def get_modules(os_info: OSInfo, host_info: HostInfo) -> Modules:
    """Build the inventory of known modules and hand it to the local hook.

    The table below lists all the modules we have (name, repository, upgrade flag and tag).
    An OPS repo that doesn't want all of them, or that wants another tag for a module on
    some host/OS, makes those adjustments in local_get_modules_hook().

    The one adjustment made here is that the UFW module is dropped on Ubuntu >= 22.04.

    :param os_info: OS information, "name" and "version_id" are used here
    :param host_info: Host information, only passed on to the local hook
    :return: Whatever the local hook returns
    :raise ValueError: If a line in the table doesn't have exactly four columns
    """
    default_modules = """
        # name      repo                                                  upgrade  tag
        apparmor    https://github.com/SUNET/puppet-apparmor.git          yes      sunet-2*
        apt         https://github.com/SUNET/puppetlabs-apt.git           yes      sunet-2*
        augeas      https://github.com/SUNET/puppet-augeas.git            yes      sunet-2*
        bastion     https://github.com/SUNET/puppet-bastion.git           yes      sunet-2*
        concat      https://github.com/SUNET/puppetlabs-concat.git        yes      sunet-2*
        cosmos      https://github.com/SUNET/puppet-cosmos.git            yes      sunet-2*
        dhcp        https://github.com/SUNET/puppetlabs-dhcp.git          yes      sunet_dev-2*
        docker      https://github.com/SUNET/garethr-docker.git           yes      sunet-2*
        hiera-gpg   https://github.com/SUNET/hiera-gpg.git                yes      sunet-2*
        munin       https://github.com/SUNET/ssm-munin.git                yes      sunet-2*
        nagioscfg   https://github.com/SUNET/puppet-nagioscfg.git         yes      sunet-2*
        network     https://github.com/SUNET/attachmentgenie-network.git  yes      sunet-2*
        pound       https://github.com/SUNET/puppet-pound.git             yes      sunet-2*
        pyff        https://github.com/samlbits/puppet-pyff.git           yes      puppet-pyff-*
        python      https://github.com/SUNET/puppet-python.git            yes      sunet-2*
        stdlib      https://github.com/SUNET/puppetlabs-stdlib.git        yes      sunet-2*
        sunet       https://github.com/SUNET/puppet-sunet.git             yes      sunet-2*
        sysctl      https://github.com/SUNET/puppet-sysctl.git            yes      sunet-2*
        ufw         https://github.com/SUNET/puppet-module-ufw.git        yes      sunet-2*
        varnish     https://github.com/samlbits/puppet-varnish.git        yes      puppet-varnish-*
        vcsrepo     https://github.com/SUNET/puppetlabs-vcsrepo.git       yes      sunet-2*
        xinetd      https://github.com/SUNET/puppetlabs-xinetd.git        yes      sunet-2*
    """
    inventory: Modules = {}
    for raw_line in default_modules.splitlines():
        entry = raw_line.strip()
        # Skip blank lines and the comment line with the column headings
        if entry == "" or entry.startswith("#"):
            continue
        try:
            name, repo, upgrade, tag = raw_line.split()
        except ValueError:
            logger.error(f"Failed to parse line: {raw_line!r}")
            raise
        inventory[name] = {"repo": repo, "upgrade": upgrade, "tag": tag}

    # Remove the UFW module on Ubuntu >= 22.04 (nftables is used there instead)
    if os_info.get("name") == "Ubuntu":
        release = os_info.get("version_id")
        if not release:
            logger.debug("Unknown Ubuntu module version, keeping UFW module")
        elif parse_version(release) >= parse_version("22.04"):
            logger.debug("Removing UFW module for Ubuntu >= 22.04")
            del inventory["ufw"]
        else:
            logger.debug("Keeping UFW module for Ubuntu < 22.04")

    return local_get_modules_hook(os_info, host_info, inventory)


def local_os_info_hook(os_info: OSInfo) -> OSInfo:
    """Local hook to modify os_info in an OPS repo.

    There are no local changes in this repository, so the dictionary is returned as-is.

    :param os_info: The OS information collected by get_os_info()
    :return: The same dictionary that was passed in
    """
    # Start local changes in this repository
    # End local changes
    return os_info


def local_host_info_hook(host_info: HostInfo) -> HostInfo:
    """Local hook that adds eduID specific information to host_info.

    If host_info has a "hostname" that starts with <function>-<site>-<number>, the keys
    "function" and "site" are added. When the number is "1", "environment" is set to
    "staging" as well. In all other cases host_info is left untouched.

    :param host_info: Host information, updated in place
    :return: The same dictionary that was passed in
    """
    # Start local changes in this repository

    # Regular expression to tease apart an eduID hostname
    hostname_re = re.compile(
        r"""^
        (\w+)   # function ('idp', 'apps', ...)
        -
        (\w+)   # site ('tug', 'sthb', ...)
        -
        (\d+)   # 1 for staging, 3 for production
    """,
        re.VERBOSE,
    )
    candidate = host_info.get("hostname")
    found = hostname_re.match(candidate) if candidate else None
    if found is not None:
        host_info["function"] = found.group(1)
        host_info["site"] = found.group(2)
        # Only staging is recorded, no environment is set for other numbers
        if found.group(3) == "1":
            host_info["environment"] = "staging"

    # End local changes
    return host_info


def local_get_modules_hook(os_info: OSInfo, host_info: HostInfo, modules: Modules) -> Modules:
    """Local hook that narrows down and re-tags the default modules for eduID.

    Modules that eduID doesn't use are dropped, the "sunet" module gets an eduID specific
    tag, and hosts whose host_info says environment "staging" get development tags for
    "sunet" and "munin".

    :param os_info: OS information (not used by the eduID changes)
    :param host_info: Host information, "environment" is used here
    :param modules: The default set of modules
    :return: A new dictionary with the selected modules
    :raise KeyError: If there is no "sunet" module in modules
    """
    # Start local changes in this repository

    # The modules eduID actually uses, everything else is dropped
    wanted = (
        "apparmor",
        "apt",
        "augeas",
        "bastion",
        "concat",
        "docker",
        "munin",
        "stdlib",
        "sunet",
        "ufw",
    )
    selected: Modules = {}
    for name, spec in modules.items():
        if name in wanted:
            selected[name] = spec
    logger.debug(f"Adding modules: {json.dumps(selected, sort_keys=True, indent=4)}")

    is_staging = host_info.get("environment") == "staging"

    # Use eduID tag for puppet-sunet
    selected["sunet"]["tag"] = "eduid_dev-2*" if is_staging else "eduid-stable-2*"

    # use sunet_dev-2* for some modules in staging
    if is_staging:
        for dev_module in ("munin",):
            if dev_module in selected:
                selected[dev_module]["tag"] = "sunet_dev-2*"

    # End local changes
    return selected


def update_cosmos_modules(filename: str, modules: Modules) -> None:
    """Write the cosmos-modules.conf file, unless it already has the right content.

    The file gets a two line header followed by one line per module (sorted by name) with
    the columns name, repo, upgrade and tag.

    :param filename: The file to create or update
    :param modules: The modules to write to the file
    """
    lines = ["# This file is automatically generated by the setup_cosmos_modules script.\n# Do not edit it manually.\n"]
    for name in sorted(modules):
        spec = modules[name]
        lines.append(f"{name:15} {spec['repo']:55} {spec['upgrade']:5} {spec['tag']}\n")
    content = "".join(lines)

    target = Path(filename)
    # Leave a file with the right content alone, so that the timestamp of the file
    # at least indicates when the content was last updated
    if target.exists() and target.read_text() == content:
        logger.debug(f"{filename} is up to date")
        return

    # Write the content to a temporary file next to the target, and then rename it into place
    staging = target.with_suffix(".tmp")
    staging.write_text(content)
    staging.rename(target)
    logger.debug(f"Updated {filename}")


def _setup_logging(my_name: str, args: Arguments):
    """Configure logging to stderr and replace the module-level logger.

    With args.debug set, both the basic configuration and the new logger use level DEBUG.
    Without it the level is INFO, and if stderr is not a TTY the handlers of the root logger
    are restricted to ERROR.

    :param my_name: Name of the logger to use from now on
    :param args: The parsed command line arguments, only "debug" is used here
    """
    global logger
    logging.basicConfig(
        level=logging.DEBUG if args.debug else logging.INFO,
        stream=sys.stderr,
        format="{asctime} | {levelname:7} | {message}",
        style="{",
    )
    logger = logging.getLogger(my_name)

    stderr_is_tty = sys.stderr.isatty()
    if args.debug:
        logger.setLevel(logging.DEBUG)
    elif not stderr_is_tty:
        # If stderr is not a TTY, change the log level of the StreamHandler (stream = sys.stderr above) to ERROR
        for handler in logging.getLogger().handlers:
            handler.setLevel(logging.ERROR)


def main(my_name: str, args: Arguments) -> bool:
    """Set up logging, work out which modules to use and write them to args.filename.

    :param my_name: Name to use for the logger
    :param args: The parsed command line arguments
    :return: Always True (failures show up as exceptions)
    """
    _setup_logging(my_name, args)

    # The OS info is collected before the host info
    modules = get_modules(get_os_info(), get_host_info())
    update_cosmos_modules(args.filename, modules)

    return True


if __name__ == "__main__":
    # Run as a script: the logger is named after the file name the script was invoked with,
    # and the exit status is 0 when main() returns something true, 1 otherwise.
    my_name = os.path.basename(sys.argv[0])
    args = parse_args()
    res = main(my_name, args=args)
    sys.exit(0 if res else 1)