#! /usr/bin/python
#--------------------------------------------------------------------------
#
# Copyright 2013 Cumulus Networks, inc  all rights reserved
#
#--------------------------------------------------------------------------
import argparse
import functools
import os
import re
import shutil
import subprocess
import sys

# ovs-ctl-vtep script that patch_ovs_ctl_vtep() rewrites in place (see main()).
OVS_CTL_VTEP_DEST = '/usr/share/openvswitch/scripts/ovs-ctl-vtep'
# Default value of --pki-path; handed to "ovs-pki ... --dir=<path>".
DEFAULT_PKI_DIR = '/var/lib/openvswitch/pki'
# OVSDB port used both for the controller manager target (ssl:<ip>:<port>)
# and for the clear-text passive listener (ptcp:<port>).
DEFAULT_MANAGER_PORT = 6632
# Service restarted as the last step of main().
VTEP_SERVICE = 'openvswitch-vtep'
# Success status: compared against command return codes and returned by main().
CMD_OK = 0


def validate(fn):
    """
    Decorator that reports the outcome of a command-style function.

    The wrapped function is expected to return a ``(returncode, stdout,
    stderr)`` tuple (as produced by ``local``/``command``), or a falsy value
    when there is nothing to report.  A non-zero return code is written to
    stderr together with the captured stderr; a zero return code is written
    to stdout together with the captured stdout.  The function's docstring
    (or its name when it has none) is used as the human-readable
    description.

    The wrapped function's result is returned unchanged.
    """
    @functools.wraps(fn)  # keep __name__/__doc__ of the wrapped function
    def _(*args, **kwargs):
        result = fn(*args, **kwargs)
        # Docstrings are usually multi-line and indented; collapse all
        # whitespace so the message is printed on a single readable line.
        doc = " ".join((fn.__doc__ or "").split()) or fn.__name__
        if result:
            if result[0] != 0:
                sys.stderr.write("Error: %s (%s).\n" % (doc, result[2]))
            else:
                sys.stdout.write("Executed: %s (%s).\n" % (doc, result[1]))
        return result

    return _


def setup_arg_parser():
    """
    Build the command-line parser for this script.

    Positional arguments: switch_name, tunnel_ip, management_ip.
    Options: --controller_ip, --no_encryption, --credentials-path, --pki-path.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('switch_name', help='Switch name')
    parser.add_argument('tunnel_ip',
                        help='local VTEP IP address for tunnel termination (data plane)')
    parser.add_argument('management_ip',
                        help='local management interface IP address for OVSDB connection (control plane)')
    parser.add_argument('--controller_ip', default='',
                        help='controller IP address to register as the OVSDB manager (optional)')
    parser.add_argument('--no_encryption', action='store_true',
                        help='clear text OVSDB connection')
    parser.add_argument('--credentials-path', default='.',
                        help='directory holding the switch certificate/private key and '
                             'the controller CA certificate (default: %(default)s)')
    parser.add_argument('--pki-path', default=DEFAULT_PKI_DIR,
                        help='ovs-pki directory (default: %(default)s)')
    return parser


def clean_str(value):
    """Return *value* stripped of surrounding whitespace, or "" for a falsy value."""
    if not value:
        return ""
    return value.strip()


def command(args):
    """
    Run a command and return ``(returncode, stdout, stderr)``.

    *args* is either a command-line string, which is split on whitespace
    (no shell quoting is honoured, so individual arguments must not contain
    spaces), or an already tokenised sequence of arguments, which is passed
    through unchanged.
    """
    seq = args
    # Only strings have .split(); a list/tuple is used as-is.  The getattr
    # default matters: without it a list argument raises AttributeError.
    if callable(getattr(args, 'split', None)):
        seq = args.split()
    return local(seq)


def local(args):
    """
    Run a program directly (no shell) and capture its output.

    *args* is a sequence of program + arguments; falsy entries are dropped
    first.  Returns (errorcode, out, error), where *out* and *error* are the
    whitespace-stripped stdout and stderr text.  If nothing is left to run,
    or the program cannot be started at all (e.g. it is not installed), a
    non-zero error code and a description are returned instead of raising.
    """
    args = [a for a in args if a]
    if not args:
        return 1, "", "empty command"
    try:
        # universal_newlines=True yields text rather than bytes on Python 3,
        # so callers can format/split the output as plain strings.
        p = subprocess.Popen(args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             universal_newlines=True)
    except OSError as e:
        return (e.errno or 1), "", str(e)
    (stdout, stderr) = p.communicate()
    return p.returncode, clean_str(stdout), clean_str(stderr)


def multiple_replace(s_value, d):
    """
    Apply every ``old -> new`` substitution in *d* to *s_value*.

    s_value - string to replace in
    d - replacement map (old substring -> new substring); None or an empty
        mapping means no replacement

    Substitutions are applied one after another in the mapping's iteration
    order, so the output of one replacement can be affected by a later key.
    """
    out = s_value
    # .items() works on both Python 2 and 3 (iteritems is Python 2 only).
    for old, new in (d or {}).items():
        out = out.replace(old, new)
    return out


# Commands ######
@validate
def create_cert(pki_dir=None, force=True):
    """
    create certificate on a switch, to be used for authentication with controller
    """
    # The docstring above is printed by @validate, so it is kept verbatim.
    # Runs: /usr/bin/ovs-pki init [--dir=<pki_dir>] [--force]
    options = []
    if pki_dir:
        options.append("--dir=%s" % pki_dir)
    if force:
        options.append("--force")
    return command(" ".join(["/usr/bin/ovs-pki init"] + options))


@validate
def sign_certificate(cert_name, cert_dir=None, pki_dir=None, force=True):
    """
    sign certificate
    """
    # The docstring above is printed by @validate, so it is kept verbatim.
    #
    # Runs "ovs-pki req+sign <cert_name> [--dir=<pki_dir>] [--force]".  On
    # success, the <cert_name>-{cert,privkey,req}.pem files are looked up in
    # the current directory and moved into cert_dir, which is created if it
    # does not exist yet.  A file that cannot be moved is reported on stderr
    # and skipped.  Returns the (returncode, out, error) tuple of ovs-pki.
    cmd = "/usr/bin/ovs-pki req+sign %s" % cert_name
    if pki_dir:
        cmd += " --dir=%s" % pki_dir
    if force:
        cmd += " --force"
    r = command(cmd)
    if r[0] == CMD_OK and cert_dir:
        if not os.path.isdir(cert_dir):
            os.makedirs(cert_dir)
        for suffix in ('cert', 'privkey', 'req'):
            src = "%s-%s.pem" % (cert_name, suffix)
            # Move to an explicit file path: copying into a cert_dir that is
            # not an existing directory would write all three files to the
            # same path.  When cert_dir is the current directory this is a
            # rename onto itself, i.e. a no-op, so nothing is lost.
            dst = os.path.join(cert_dir, os.path.basename(src))
            try:
                shutil.move(src, dst)
            except (shutil.Error, EnvironmentError) as e:
                sys.stderr.write("Warning: could not move %s to %s (%s).\n"
                                 % (src, cert_dir, e))
    return r


def _patch_ovs_ctl_vtep_lines(lines, inserted):
    """
    Return a patched copy of the ovs-ctl-vtep *lines*, or None.

    After every line containing the ``--remote=db:Global,managers`` marker,
    *inserted* is added, and all following lines up to (not including) the
    ``start_daemon "$OVSDB_SERVER_PRIORITY" "$@" || return 1`` line are
    dropped, so re-running the patch replaces previously inserted options.
    Returns None when the start marker never appears, or when the end marker
    is missing after it (the rest of the script would otherwise be lost).
    """
    out = []
    skipping = False
    found = False
    for line in lines:
        if 'set "$@" --remote=db:Global,managers' in line:
            out.append(line)
            out.extend(inserted)
            skipping = True
            found = True
        elif skipping:
            if 'start_daemon "$OVSDB_SERVER_PRIORITY" "$@" || return 1' in line:
                out.append(line)
                skipping = False
        else:
            out.append(line)
    if not found or skipping:
        return None
    return out


@validate
def patch_ovs_ctl_vtep(dst, no_encryption, private_key, cert, bootstrap_cert):
    """
    patch ovs-ctl-vtep
    """
    # The docstring above is printed by @validate, so it is kept verbatim.
    #
    # Inserts the ovsdb-server remote/SSL options into the ovs-ctl-vtep
    # script at *dst*.  The script is rewritten only when the patch markers
    # are found; otherwise it is left untouched.  Returns a
    # (returncode, out, error) tuple like command() so @validate reports it.
    if not os.path.isfile(dst):
        return 1, "", "%s not found" % dst
    with open(dst, 'r') as f:
        lines = f.readlines()
    if no_encryption:
        # Clear-text OVSDB: listen passively on the manager port.
        inserted = ['\tset "$@" --remote=ptcp:' + str(DEFAULT_MANAGER_PORT) + "\n"]
    else:
        inserted = ['\tset "$@" --private-key=' + private_key + "\n",
                    '\tset "$@" --certificate=' + cert + "\n",
                    '\tset "$@" --bootstrap-ca-cert=' + bootstrap_cert + "\n"]
    patched = _patch_ovs_ctl_vtep_lines(lines, inserted)
    if patched is None:
        return 1, "", "patch markers not found in %s" % dst
    with open(dst, 'w') as f:
        f.writelines(patched)
    return CMD_OK, dst, ""


@validate
def restart(service):
    """
    restart a service
    """
    # The docstring above is printed by @validate, so it is kept verbatim.
    # Equivalent to running "service <service> restart".
    cmdline = "service {0} restart".format(service)
    return command(cmdline)


@validate
def define_controller(controller_ip):
    """
    define NSX controller IP address in OVSDB
    """
    # The docstring above is printed by @validate, so it is kept verbatim.
    # NOTE(review): the manager target always uses "ssl:", even when the
    # script runs with --no_encryption; confirm this is intended.
    target = "ssl:{0}:{1:d}".format(controller_ip, DEFAULT_MANAGER_PORT)
    return command("/usr/bin/vtep-ctl set-manager " + target)


@validate
def define_tunnel(switch_name, tunnel_ip):
    """
    define local tunnel IP address on the switch
    """
    # The docstring above is printed by @validate, so it is kept verbatim.
    # Sets the tunnel_ips column of the named Physical_switch row.
    cmdline = "vtep-ctl set Physical_switch {0} tunnel_ips={1}".format(
        switch_name, tunnel_ip)
    return command(cmdline)


@validate
def define_management(switch_name, management_ip):
    """
    define management IP address on the switch
    """
    # The docstring above is printed by @validate, so it is kept verbatim.
    # Sets the management_ips column of the named Physical_switch row.
    cmdline = "vtep-ctl set Physical_switch {0} management_ips={1}".format(
        switch_name, management_ip)
    return command(cmdline)


@validate
def define_switch(switch_name):
    """
    define physical switch
    """
    # The docstring above is printed by @validate, so it is kept verbatim.
    # If "vtep-ctl list-ps" reports any switch, the first one is renamed to
    # switch_name; otherwise a new switch called switch_name is added.
    existing = command("vtep-ctl list-ps")[1].split()
    if not existing:
        return command("vtep-ctl add-ps %s" % switch_name)
    current = existing[0].strip()
    return command("vtep-ctl set Physical_switch %s name=%s" % (current, switch_name))


@validate
def delete_old_controller_cert(cert):
    """
    delete old controller certificate at the same path
    """
    # The docstring above is printed by @validate, so it is kept verbatim.
    #
    # Returns None (nothing reported) when *cert* does not exist, e.g. the
    # empty path main() passes with --no_encryption; otherwise a
    # (returncode, out, error) tuple like command().
    if not os.path.exists(cert):
        return None
    try:
        # os.remove instead of shelling out to "rm": command() splits its
        # string on whitespace, so a path containing spaces would have been
        # handed to rm as several separate paths.
        os.remove(cert)
    except OSError as e:
        return (e.errno or 1), "", str(e)
    return CMD_OK, cert, ""


def main(args):
    """
    Configure this switch for OVSDB VTEP operation.

    *args* are the command-line arguments without the program name.  Unless
    --no_encryption is given, a PKI is initialised and a signed switch
    certificate is stored under --credentials-path.  Then ovs-ctl-vtep is
    patched, the physical switch, optional controller, tunnel IP and
    management IP are defined, any old controller CA certificate is
    deleted, and the VTEP service is restarted.

    Every step is attempted even if an earlier one fails.  Returns CMD_OK
    when no step reported a non-zero return code, 1 otherwise.
    """
    parser = setup_arg_parser()
    ns = parser.parse_args(args)
    private_key = ''
    certificate = ''
    bootstrap_ca_cert = ''
    results = []
    if not ns.no_encryption:
        credentials_path = os.path.abspath(ns.credentials_path)
        private_key = os.path.join(
            credentials_path, "%s-privkey.pem" % ns.switch_name)
        certificate = os.path.join(
            credentials_path, "%s-cert.pem" % ns.switch_name)
        bootstrap_ca_cert = os.path.join(credentials_path, "controller.cacert")

        results.append(create_cert(ns.pki_path))
        results.append(sign_certificate(ns.switch_name, ns.credentials_path, ns.pki_path))
    results.append(patch_ovs_ctl_vtep(OVS_CTL_VTEP_DEST, ns.no_encryption,
                                      private_key, certificate, bootstrap_ca_cert))
    results.append(define_switch(ns.switch_name))
    if ns.controller_ip:
        results.append(define_controller(ns.controller_ip))
    results.append(define_tunnel(ns.switch_name, ns.tunnel_ip))
    results.append(define_management(ns.switch_name, ns.management_ip))
    results.append(delete_old_controller_cert(bootstrap_ca_cert))
    results.append(restart(VTEP_SERVICE))
    # Steps may return None (nothing to report) or a (code, out, err) tuple.
    failed = any(r and r[0] != CMD_OK for r in results)
    return 1 if failed else CMD_OK


if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    exit_status = main(sys.argv[1:])
    sys.exit(exit_status)
