#!/usr/bin/env python
#
# Copyright 2018 SUNET. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are
# permitted provided that the following conditions are met:
#
#    1. Redistributions of source code must retain the above copyright notice, this list of
#       conditions and the following disclaimer.
#
#    2. Redistributions in binary form must reproduce the above copyright notice, this list
#       of conditions and the following disclaimer in the documentation and/or other materials
#       provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY SUNET ``AS IS'' AND ANY EXPRESS OR IMPLIED
# WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SUNET OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# The views and conclusions contained in the software and documentation are those of the
# authors and should not be interpreted as representing official policies, either expressed
# or implied, of SUNET.
#
# Author: Fredrik Thulin <fredrik@thulin.net>
#

"""
Toolbox for everything that needs to read the instance YAML config file.

The idea is to not have a lot of different scripts reading that config file,
to make it easier to update the YAML config file format going forward.

Main actions:

  - Print things in the YAML config file. Some things are handled specially,
    like 'print_ips', which extracts the frontend IP addresses and makes them
    convenient to work with from shell scripts.

  - Print haproxy config. Creates a Jinja2 context from the YAML file and generates a
    haproxy config from a template.

  - Print exabgp announce messages. Called whenever a status change is detected
    for the configured backends.
"""

import os
import sys
import glob
import time
import yaml
import pprint
import logging
import argparse

import logging.handlers

# Module-wide logger, assigned by main() once logging has been set up.
logger = None

# Fallback values for the command line options declared in parse_args().
_defaults = dict(
    syslog = False,
    debug = False,
    config_dir = '/opt/frontend/config',
    backends_dir = '/opt/frontend/api/backends',
    max_age = 600,
)


def parse_args(defaults, argv = None):
    """
    Parse the command line arguments.

    :param defaults: dict providing the default values for 'config_dir',
                     'backends_dir', 'max_age', 'debug' and 'syslog'
    :param argv: list of argument strings to parse; None (the default) makes
                 argparse read sys.argv[1:]
    :returns: the parsed arguments
    :rtype: argparse.Namespace
    """
    parser = argparse.ArgumentParser(description = 'SUNET frontend config toolbox',
                                     add_help = True,
                                     formatter_class = argparse.ArgumentDefaultsHelpFormatter,
    )

    # Positional arguments
    parser.add_argument('actions',
                        metavar = 'ACTION',
                        nargs = '+',
                        help = 'Actions to perform',
    )

    # Optional arguments
    parser.add_argument('--config_dir',
                        dest = 'config_dir',
                        metavar = 'DIR', type = str,
                        default = defaults['config_dir'],
                        help = 'Base directory for configuration data',
    )
    parser.add_argument('--backends_dir',
                        dest = 'backends_dir',
                        metavar = 'DIR', type = str,
                        default = defaults['backends_dir'],
                        help = 'Base directory for API backend registration data',
    )
    parser.add_argument('-c', '--config_fn',
                        dest = 'config_fn',
                        metavar = 'FILENAME', type = str,
                        default = None,
                        help = 'YAML config file to use',
    )
    # The help text must name the option as it is actually spelled (--config_dir)
    parser.add_argument('--haproxy_template',
                        dest = 'haproxy_template',
                        metavar = 'FILENAME', type = str,
                        default = None,
                        help = 'haproxy.j2 file to use, relative to --config_dir',
    )
    parser.add_argument('--instance',
                        dest = 'instance',
                        metavar = 'NAME', type = str,
                        default = None,
                        help = 'Instance',
    )
    parser.add_argument('--max_age',
                        dest = 'max_age',
                        metavar = 'SECONDS', type = int,
                        default = defaults['max_age'],
                        help = 'Max backend file age to allow',
    )
    parser.add_argument('--fqdn',
                        dest = 'fqdn',
                        metavar = 'NAME', type = str,
                        default = None,
                        help = 'Host FQDN (for print_exabgp_announce)',
    )
    parser.add_argument('--status',
                        dest = 'status',
                        metavar = 'UP/DOWN', type = str,
                        default = None,
                        help = 'Current backend status (for print_exabgp_announce)',
    )

    parser.add_argument('--debug',
                        dest = 'debug',
                        action = 'store_true', default = defaults['debug'],
                        help = 'Enable debug operation',
    )
    parser.add_argument('--syslog',
                        dest = 'syslog',
                        action = 'store_true', default = defaults['syslog'],
                        help = 'Enable syslog output',
    )
    return parser.parse_args(argv)


def main(myname = 'frontend-config', args = None, logger_in = None, defaults = _defaults):
    """
    Set up logging, load the YAML config file and run the requested actions in order.

    :param myname: name of the logger created when `logger_in' is not supplied
    :param args: already parsed arguments; the command line is parsed when not supplied
    :param logger_in: logger to use instead of configuring a new one
    :param defaults: default values handed to parse_args()
    :returns: True when every action succeeded, False as soon as one fails
    :rtype: bool
    """
    global logger

    if not args:
        args = parse_args(defaults)

    # initialize various components
    if logger_in:
        logger = logger_in
    else:
        # This is the root log level
        root_level = logging.DEBUG if args.debug else logging.INFO
        logging.basicConfig(level = root_level, stream = sys.stderr,
                            format='%(asctime)s: %(name)s: %(threadName)s %(levelname)s %(message)s')
        logger = logging.getLogger(myname)
        # Without a TTY on stderr (and without --debug), only let warnings and
        # worse through the handlers installed on the root logger
        if not (sys.stderr.isatty() or args.debug):
            for root_handler in logging.getLogger('').handlers:
                root_handler.setLevel(logging.WARNING)

    if args.syslog:
        syslog_handler = logging.handlers.SysLogHandler()
        syslog_handler.setFormatter(logging.Formatter('%(name)s: %(levelname)s %(message)s'))
        logger.addHandler(syslog_handler)

    # An explicit config file wins, otherwise derive the path from --instance
    config_path = args.config_fn
    if not config_path:
        if not args.instance:
            logger.error('Neither config file nor --instance supplied')
            return False
        config_path = os.path.join(args.config_dir, args.instance, 'config.yml')

    logger.debug('Loading configuration from {!r}'.format(config_path))

    with open(config_path, 'r') as config_file:
        config = yaml.safe_load(config_file)

    logger.debug('Config:\n{!s}'.format(pprint.pformat(config)))

    # Actions with a dedicated implementation
    special_actions = {'print_ips': print_ips,
                       'print_haproxy_config': print_haproxy_config,
                       'print_exabgp_announce': print_exabgp_announce,
                       }

    for action in args.actions:
        logger.debug('Processing action {!r}'.format(action))
        if not action.startswith('print_'):
            sys.stderr.write('Unknown action {!r}\n'.format(action))
            return False

        handler = special_actions.get(action)
        if handler is not None:
            succeeded = handler(config, args, logger)
        else:
            # Print generic value: print_<key> prints config[<key>]
            succeeded = print_generic(action[len('print_'):], config, args, logger)

        if not succeeded:
            return False

    return True


def print_generic(what, config, args, logger):
    """
    Print the value stored under the key `what' in the config.

    :returns: True if the key was found (and printed), False otherwise
    :rtype: bool
    """
    if what in config:
        print(config[what])
        return True
    return False


def print_ips(config, args, logger):
    """
    Print the 'ips' of all frontends in the config, sorted, one per line.

    :returns: False if the config has no 'frontends', True otherwise
    :rtype: bool
    """
    if 'frontends' not in config:
        return False

    addresses = []
    for frontend in config['frontends'].values():
        addresses.extend(frontend['ips'])
    addresses.sort()

    print('\n'.join(addresses))
    return True


def print_exabgp_announce(config, args, logger):
    """
    Print exabgp 'announce route' commands for every frontend IP address.

    The frontend names are sorted, rotated to start with --fqdn and then reversed,
    so the addresses of the frontend named by --fqdn get the highest local-preference.
    Preferences start at 1000 and grow by 100 per frontend when --status starts
    with 'UP'; otherwise (fallback routes) they start at 100 and grow by 10.

    :param config: loaded YAML config, with 'frontends' mapping frontend name
                   to a dict that has an 'ips' list
    :param args: parsed arguments, 'fqdn' and 'status' are required
    :param logger: logger for error and debug messages
    :returns: True if at least one route was printed, False otherwise
    :rtype: bool
    """
    if not args.fqdn:
        logger.error('Action print_exabgp_announce requires --fqdn')
        return False
    if not args.status:
        logger.error('Action print_exabgp_announce requires --status')
        return False

    # Fail with a proper error (instead of a traceback) when the config has no
    # frontends, or when --fqdn is not one of them
    frontends_cfg = config.get('frontends') or {}
    if args.fqdn not in frontends_cfg:
        logger.error('Action print_exabgp_announce: {!r} is not a configured frontend'.format(args.fqdn))
        return False

    frontends = _pref_sort_frontends(args.fqdn, sorted(frontends_cfg.keys()))
    # Preferences are handed out in increasing order, so put our own frontend last
    frontends.reverse()

    if args.status.startswith('UP'):
        bgp_pref, pref_step = 1000, 100
    else:
        # fallback route preference
        bgp_pref, pref_step = 100, 10

    # site_name is only used for the debug message, so don't insist on it being set
    site_name = config.get('site_name')
    num_addrs = sum(len(frontends_cfg[fe]['ips']) for fe in frontends)

    bgp = []
    count = 0
    for fe in frontends:
        for ip in sorted(frontends_cfg[fe]['ips']):
            count += 1
            # Host routes: /128 for IPv6, /32 for IPv4
            prefix_len = 128 if ':' in ip else 32

            logger.debug('Site {}, address {}/{} -> announce {} with BGP preference {}'.format(
                site_name, count, num_addrs, ip, bgp_pref))
            bgp.append('announce route {}/{} next-hop self local-preference {}'.format(
                ip, prefix_len, bgp_pref))
        bgp_pref += pref_step

    print('\n'.join(sorted(bgp)))
    return bool(bgp)


def _pref_sort_frontends(fqdn, frontends):
    """
    Rotate the list of frontend names so that it starts with `fqdn'.

    Every load balancer gets the same frontends in a different order this way:
    given A, B and C, host B gets B, C, A. print_exabgp_announce() derives the
    BGP preference of each frontend's addresses from its position in this list.

    The input list is not modified. A ValueError is raised if `fqdn' is not
    present in `frontends'.
    """
    start = frontends.index(fqdn)
    leading = frontends[:start]
    trailing = frontends[start:]
    return trailing + leading


def print_haproxy_config(config, args, logger):
    """
    Render the haproxy template using the config as Jinja2 context, and print the result.

    The context holds the top-level config entries except 'allow_ports',
    'frontends' and 'frontend_template', plus:

      'bind_ips': sorted list of the 'ips' of all frontends
      'backends': the result of _load_backends()

    The context is assembled in a new dict. This function does not add, remove or
    replace any key in `config' itself, so other actions processed in the same
    run still see e.g. 'frontends'.

    :param config: loaded YAML config
    :param args: parsed arguments, 'haproxy_template' and 'instance' are required.
                 The template is looked up relative to args.config_dir.
    :param logger: logger for error and debug messages
    :returns: True if the config was printed, False if a required argument is missing
    :rtype: bool
    """
    if not args.haproxy_template:
        logger.error('Action print_haproxy_config requires --haproxy_template')
        return False

    if not args.instance:
        logger.error('Action print_haproxy_config requires --instance')
        return False

    # Jinja2 is only needed for this action, so it is imported here
    from jinja2 import Environment, FileSystemLoader

    # remove things not meant for the haproxy template
    not_for_template = ('allow_ports', 'frontends', 'frontend_template')
    context = {key: value for key, value in config.items() if key not in not_for_template}

    bind_ips = []
    for fe in config['frontends'].values():
        bind_ips += fe['ips']
    context['bind_ips'] = sorted(bind_ips)

    context['backends'] = _load_backends(config, args, logger)

    logger.debug('Rendering haproxy template {} with context:\n{!s}'.format(args.haproxy_template, pprint.pformat(context)))

    env = Environment(loader = FileSystemLoader(args.config_dir))
    template = env.get_template(args.haproxy_template)
    print(template.render(**context))
    return True

def _load_backends(config, args, logger):
    """
    Load data written to disk by the API used by backends to register themselves.

    For every backend in config['backends'] whose name does not start with '_',
    each address listed under a host's 'ips' is looked up using
    _load_backend_data(). Addresses without registration data are skipped.

    Every registered address gets a server dict of its own, made up of the
    host's parameters from the config (except 'ips') updated with the
    registration data. The config itself is not modified.

    :param config: loaded YAML config, 'site_name' and 'backends' are used
    :param args: parsed arguments, passed on to _load_backend_data()
    :param logger: logger for debug messages
    :returns: list of {'name': '<site_name>__<backend>', 'servers': [...]} dicts,
              sorted on the backend name
    :rtype: list
    """
    res = []
    for be in sorted(config['backends'].keys()):
        # be is 'default', 'testing' or similar
        if be.startswith('_'):
            continue

        servers = []
        be_name = '{}__{}'.format(config['site_name'], be)
        for params in config['backends'][be].values():
            # A host without any parameters is treated as having no addresses
            params = params or {}
            # Parameters shared by all addresses of this host
            common = {key: value for key, value in params.items() if key != 'ips'}
            for addr in params.get('ips') or []:
                data = _load_backend_data(config['site_name'], addr, args, logger)
                logger.debug('Backend data for site {} addr {}: {}'.format(config['site_name'], addr, data))
                if data is None:
                    logger.debug('Not adding backend {} - probably not registered'.format(addr))
                    continue
                # Build a fresh dict per address, so that several addresses of
                # one host don't end up sharing (and overwriting) the same entry
                server = dict(common)
                server.update(data)
                servers.append(server)

        res.append({'name': be_name,
                    'servers': servers,
        })
    return res


def _load_backend_data(site_name, addr, args, logger):
    """
    Load data about a registered backend from a file.

    Files matching <args.backends_dir>/<site_name>/*_<addr>.conf are examined.
    Files older than args.max_age seconds are ignored, as are files whose
    action is not 'register' (including files without any action at all).
    The first acceptable file is used.

    Return data matching what the haproxy.j2 template wants for 'servers':

    {'name': 'server1.example.org',
     'ip': '192.0.2.1',
     'port': '443',
     'address_family': 'v4'
    }

    'address_family' is 'v6' if the ip contains a ':', 'v4' for any other
    ip and None if the file had no ip.

    :returns: dict as described above, or None if no usable file was found
    :rtype: dict | None
    """
    match = os.path.join(args.backends_dir, site_name, '*_{}.conf'.format(addr))
    fns = glob.glob(match)
    if not fns:
        logger.debug('Found no file(s) matching {!r}'.format(match))
        return None

    for this in fns:
        # Check file age before using this file
        if _get_fileage(this) > args.max_age:
            logger.info('Ignoring file {} that is older than max_age ({})'.format(this, args.max_age))
            continue

        logger.debug('Found backend file {}'.format(this))
        params = _load_backend_file(this)
        # A file lacking ACTION is skipped rather than raising KeyError
        if params.pop('action', None) != 'register':
            continue

        ip = params.get('ip', '')
        if ':' in ip:
            params['address_family'] = 'v6'
        elif ip:
            params['address_family'] = 'v4'
        else:
            params['address_family'] = None
        return params

    return None


def _get_fileage(filename):
    """
    Return the number of whole seconds since `filename' was last modified.

    :rtype: int
    """
    modified = os.stat(filename).st_mtime
    # The current time is truncated to whole seconds before subtracting
    now = int(time.time())
    return int(now - modified)


def _load_backend_file(filename):
    """
    Load a backend registration file, created by the sunet-frontend-api.

    Example file:

      ACTION=register
      BACKEND=www.dev.eduid.se
      SERVER=www-fre-1.eduid.se
      REMOTE_IP=130.242.130.200
      PORT=443

    Keys are lower-cased, 'remote_ip' is renamed to 'ip' and 'backend' to 'name'.
    Lines without a '=', or with an empty key or value, are ignored.

    :param filename: path of the file to load
    :returns: dict with normalized keys and their values
    :rtype: dict
    """
    renamed_keys = {'remote_ip': 'ip',
                    'backend': 'name',
                    }
    res = {}
    # The context manager makes sure the file is closed again
    with open(filename) as fd:
        for line in fd:
            # Only split on the first '=', the value might contain more of them
            (head, sep, tail) = line.rstrip('\n').partition('=')
            if head and tail:
                key = head.lower()
                res[renamed_keys.get(key, key)] = tail
    return res


if __name__ == '__main__':
    # Turn the result of main() into a process exit status:
    # True -> 0, False -> 1, anything else is converted with int().
    try:
        outcome = main(os.path.basename(sys.argv[0]))
        if outcome is True:
            exit_status = 0
        elif outcome is False:
            exit_status = 1
        else:
            exit_status = int(outcome)
        sys.exit(exit_status)
    except KeyboardInterrupt:
        # Being interrupted by the user is not treated as a failure
        sys.exit(0)