#! /usr/bin/python
#--------------------------------------------------------------------------
#
# Copyright 2013 Cumulus Networks, inc  all rights reserved
#
#--------------------------------------------------------------------------
import sys
import subprocess
import argparse
import os
import shutil
import re

# ovs-ctl-vtep startup script that main() hands to patch_ovs_ctl_vtep(),
# which rewrites it in place.
OVS_CTL_VTEP_DEST = '/usr/share/openvswitch/scripts/ovs-ctl-vtep'
# Default for --pki-path; passed to `ovs-pki ... --dir=`.
DEFAULT_PKI_DIR = '/var/lib/openvswitch/pki'
# Default for --controller_port: the PORT in the `ssl:IP:PORT` manager target,
# and the plain-TCP OVSDB listen port written when --no_encryption is given.
DEFAULT_MANAGER_PORT = 6640
# Service restarted (via systemctl) at the end of main().
VTEP_SERVICE = 'openvswitch-vtep'
# Process return code treated as success.
CMD_OK = 0


def validate(fn):
    """
    Decorator that reports the outcome of a command-running function.

    The wrapped function returns either a falsy value (nothing is reported)
    or a ``(returncode, stdout, stderr)`` tuple as produced by ``local``.
    A non-zero return code is written to stderr together with the captured
    stderr text; otherwise success is written to stdout together with the
    captured stdout text. The message label is the wrapped function's
    docstring (surrounding whitespace stripped), falling back to its name.
    The wrapped function's result is always returned unchanged.
    """
    import functools

    # Docstrings are triple-quoted and indented; strip them so the label
    # fits on one line of the report. Computed once, at decoration time.
    doc = (fn.__doc__ or "").strip() or fn.__name__

    @functools.wraps(fn)  # keep __name__/__doc__ of the wrapped function
    def _(*args, **kwargs):
        result = fn(*args, **kwargs)
        if result:
            if result[0] != 0:
                sys.stderr.write("Error: %s (%s).\n" % (doc, result[2]))
            else:
                sys.stdout.write("Executed: %s (%s).\n" % (doc, result[1]))
        return result

    return _


def setup_arg_parser():
    """
    Build the command-line parser for this script.

    Positional arguments: switch name, local tunnel (data-plane) IP and
    local management (control-plane) IP. Options select the controller,
    SSL vs. clear-text OVSDB, credential/PKI directories, the OVSDB HA role
    and addresses, and the bridge passed to ovs-vtepd.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('switch_name', help='Switch name')
    parser.add_argument('tunnel_ip',
                        help='local VTEP IP address for tunnel termination (data plane)')
    parser.add_argument('management_ip',
                        help='local management interface IP address for OVSDB connection (control plane)')
    parser.add_argument('--controller_ip', default='',
                        help='controller IP address; when given it is set as the OVSDB manager (ssl:IP:PORT)')
    parser.add_argument('--controller_port', type=int, default=DEFAULT_MANAGER_PORT,
                        help='controller port; also the plain-TCP OVSDB listen port with --no_encryption (default: %(default)s)')
    parser.add_argument('--no_encryption', help='clear text OVSDB connection', action="store_true")
    parser.add_argument('--credentials-path', default='.',
                        help='directory for the switch key/certificate and the controller CA certificate (default: %(default)s)')
    parser.add_argument('--pki-path', default=DEFAULT_PKI_DIR,
                        help='ovs-pki directory (default: %(default)s)')
    parser.add_argument('--db_ha', default='',
                        help="HA role of this OVSDB server: 'active' or 'standby'")
    parser.add_argument('--db_ha_vip', default='',
                        help='Virtual IP accessible from active and standby OVSDB server in HA, as HOST:PORT')
    parser.add_argument('--db_ha_repl_sv', default='',
                        help='Replication OVSDB server address used by backup to sync, as HOST:PORT')
    parser.add_argument('--bridge', default='', help='Bridge name which will be used to add the configuration')
    return parser


def clean_str(value):
    """
    Return ``value`` without surrounding whitespace, or ``""`` when
    ``value`` is falsy (``None``, an empty string, ...).
    """
    if not value:
        return ""
    return value.strip()


def command(args):
    """
    Run a command given as a whitespace-separated string or an argument
    sequence.

    Strings are split on whitespace. No shell quoting is honoured, so an
    argument containing spaces must be passed inside a sequence. Sequences
    (lists/tuples) are passed through unchanged.

    Returns whatever ``local`` returns: ``(returncode, stdout, stderr)``.
    """
    seq = args
    # Use a default for getattr: sequences have no ``split`` attribute, and
    # the two-argument form raised AttributeError for them.
    if callable(getattr(args, 'split', None)):
        seq = args.split()
    return local(seq)


def local(args):
    """
    Run the argument sequence ``args`` directly (no shell) and capture it.

    Falsy entries are dropped first, so callers may pass optional arguments
    as ``''``. Output is read as text (``universal_newlines=True``), so the
    captured values are ``str`` on both Python 2 and Python 3. They are then
    stripped with ``clean_str``.

    Returns (errorcode, out, error). If the program cannot be started at
    all (e.g. it is not installed), the OSError is reported as a non-zero
    error code plus its message instead of propagating, so it is handled
    like any other command failure.
    """
    args = [a for a in args if a]
    try:
        p = subprocess.Popen(args, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             universal_newlines=True)
    except OSError as e:
        return e.errno or 1, "", clean_str(str(e))
    (stdout, stderr) = p.communicate()
    return p.returncode, clean_str(stdout), clean_str(stderr)


def multiple_replace(s_value, d):
    """
    Apply every ``old -> new`` substitution in ``d`` to ``s_value``.

    s_value - string to replace in
    d - replacement map ``{old: new}``; ``None`` or empty means no changes.
        Substitutions are applied one after another in the dict's iteration
        order, so later replacements also see the output of earlier ones.

    Returns the resulting string.
    """
    out = s_value
    # .items() works on Python 2 and 3; .iteritems() exists only on Python 2.
    for old, new in (d or {}).items():
        out = out.replace(old, new)
    return out


# Commands ######
# Runs `/usr/bin/ovs-pki init`, adding `--dir=<pki_dir>` when pki_dir is set
# and `--force` when force is true. Returns command()'s
# (returncode, stdout, stderr) tuple; @validate reports it using the
# docstring below as the label (so that docstring is runtime text).
@validate
def create_cert(pki_dir=None, force=True):
    """
    create certificate on a switch, to be used for authentication with controller
    """
    pieces = ['/usr/bin/ovs-pki', 'init']
    if pki_dir:
        pieces.append("--dir=%s" % pki_dir)
    if force:
        pieces.append("--force")
    return command(" ".join(pieces))


# Runs `/usr/bin/ovs-pki req+sign <cert_name>` (plus `--dir=` / `--force`).
# If that succeeds and cert_dir is given, each of <cert_name>-cert.pem,
# <cert_name>-privkey.pem and <cert_name>-req.pem (paths relative to the
# current working directory) is copied into cert_dir and the local file is
# then deleted. When the copy raises shutil.Error (for example because
# cert_dir is the current directory, so source and destination are the same
# file) the file is left where it is. Returns the ovs-pki command result.
@validate
def sign_certificate(cert_name, cert_dir=None, pki_dir=None, force=True):
    """
    sign certificate
    """
    argv = ["/usr/bin/ovs-pki req+sign %s" % cert_name]
    if pki_dir:
        argv.append("--dir=%s" % pki_dir)
    if force:
        argv.append("--force")
    outcome = command(" ".join(argv))

    if outcome[0] != CMD_OK or not cert_dir:
        return outcome

    for kind in ('cert', 'privkey', 'req'):
        pem_file = "%s-%s.pem" % (cert_name, kind)
        try:
            shutil.copy2(pem_file, cert_dir)
        except shutil.Error:
            pass
        else:
            os.remove(pem_file)
    return outcome


@validate
def patch_ovs_ctl_vtep(dst, no_encryption, controller_port, private_key,
                       cert, bootstrap_cert, ha_state, vip, vip_port, repl_sv_ip,
                       repl_sv_port, bridge):
    """
    patch ovs-ctl-vtep
    """
    # Rewrites the ovs-ctl-vtep script at ``dst`` in two independent passes:
    #   1. the ovsdb-server arguments following the
    #      `set "$@" --remote=punix:"$DB_SOCK` line (TCP listener or SSL
    #      material, HA options) -- see _patch_ovsdb_server_args();
    #   2. the ovs-vtepd arguments following the `# Start ovs-vtepd` line
    #      (database target, VLAN-aware mode, bridge) -- see
    #      _patch_vtepd_args().
    # Each pass does nothing if ``dst`` is not a file or is empty, and
    # computes the complete new content before truncating the file.
    # Returns None, so @validate prints nothing for this step.
    _rewrite_file_lines(dst, lambda lines: _patch_ovsdb_server_args(
        lines, no_encryption, controller_port, private_key, cert,
        bootstrap_cert, ha_state, vip, vip_port, repl_sv_ip, repl_sv_port))
    _rewrite_file_lines(dst, lambda lines: _patch_vtepd_args(
        lines, ha_state, vip, vip_port, bridge))


def _rewrite_file_lines(path, transform):
    """
    Replace the lines of the file at ``path`` with ``transform(lines)``.

    Does nothing when ``path`` is not a regular file or is empty. The new
    lines are computed before the file is reopened for writing, so an
    exception raised by ``transform`` leaves the file untouched.
    """
    if not os.path.isfile(path):
        return
    with open(path, 'r') as f:
        lines = f.readlines()
    if not lines:
        return
    new_lines = transform(lines)
    with open(path, 'w') as f:
        f.writelines(new_lines)


def _patch_ovsdb_server_args(lines, no_encryption, controller_port,
                             private_key, cert, bootstrap_cert, ha_state,
                             vip, vip_port, repl_sv_ip, repl_sv_port):
    """
    Return ``lines`` with the ovsdb-server arguments regenerated.

    Every line containing `set "$@" --remote=punix:"$DB_SOCK` is kept and
    followed by generated `set "$@" ...` lines:
      * ``--remote=ptcp:<controller_port>`` when ``no_encryption == True``,
        otherwise ``--private-key``/``--certificate``/``--bootstrap-ca-cert``;
      * for ``ha_state == "active"``: listeners on vip_port and repl_sv_port,
        ``--db_vip`` and ``--ssl-protocols``;
      * for ``ha_state == "standby"``: ``--sync-from=tcp:IP:PORT``, a
        listener on vip_port, ``--db_vip`` and ``--ssl-protocols``.
    The original lines after the marker are dropped up to the
    `[ "$OVS_USER" != "" ] && set "$@" --user "$OVS_USER"` line, which is
    kept and ends the dropped region. If that line never appears, everything
    after the marker is dropped.
    """
    out = []
    skipping = False
    for line in lines:
        if 'set "$@" --remote=punix:"$DB_SOCK' in line:
            out.append(line)
            if no_encryption == True:
                out.append('        set "$@" --remote=ptcp:' + str(controller_port) + "\n")
            else:
                out.append('        set "$@" --private-key=' + private_key + "\n")
                out.append('        set "$@" --certificate=' + cert + "\n")
                out.append('        set "$@" --bootstrap-ca-cert=' + bootstrap_cert + "\n")
            if ha_state == "active":
                out.append('        set "$@" --remote=ptcp:' + vip_port + "\n")
                out.append('        set "$@" --remote=ptcp:' + repl_sv_port + "\n")
                out.append('        set "$@" --db_vip=' + vip + "\n")
                out.append('        set "$@" --ssl-protocols=TLSv1,TLSv1.1,TLSv1.2' + "\n")
            if ha_state == "standby":
                out.append('        set "$@" --sync-from=tcp:' + repl_sv_ip + ":" + repl_sv_port + "\n")
                out.append('        set "$@" --remote=ptcp:' + vip_port + "\n")
                out.append('        set "$@" --db_vip=' + vip + "\n")
                out.append('        set "$@" --ssl-protocols=TLSv1,TLSv1.1,TLSv1.2' + "\n")
            skipping = True
        elif skipping:
            if '[ "$OVS_USER" != "" ] && set "$@" --user "$OVS_USER"' in line:
                out.append(line)
                skipping = False
        else:
            out.append(line)
    return out


def _patch_vtepd_args(lines, ha_state, vip, vip_port, bridge):
    """
    Return ``lines`` with the ovs-vtepd arguments regenerated.

    The first line containing `# Start ovs-vtepd` is kept and followed by:
      * when ``ha_state`` is truthy: `set ovs-vtepd tcp:<vip>:<vip_port>`
        and `set "$@" --enable-vlan-aware-mode`;
      * otherwise: `set ovs-vtepd unix:"$DB_SOCK"`;
      * in both cases `set "$@" --bridge=<bridge>` when ``bridge`` is set.
    The original lines after the marker are dropped up to the
    `set "$@" -vconsole:emer -vsyslog:err -vfile:` line, which is kept and
    ends the dropped region. Lines inside that region that contain
    `set "$@" --enable-vlan-aware-mode` are also kept.
    NOTE(review): in HA mode that can duplicate --enable-vlan-aware-mode
    when the script already contains it (e.g. on a re-run).

    The skip state always starts cleared here. Previously it was inherited
    from the ovsdb-server pass, which could drop every line before the
    vtepd marker.
    """
    out = []
    edited = False
    skipping = False
    for line in lines:
        if '# Start ovs-vtepd' in line and not edited:
            out.append(line)
            if ha_state:
                out.append('        set ovs-vtepd tcp:' + vip + ":" + vip_port + "\n")
                out.append('        set "$@" --enable-vlan-aware-mode' + "\n")
            else:
                out.append('        set ovs-vtepd unix:"$DB_SOCK"' + "\n")
            if bridge:
                out.append('        set "$@" --bridge=' + bridge + "\n")
            edited = True
            skipping = True
        elif skipping:
            if 'set "$@" --enable-vlan-aware-mode' in line:
                out.append(line)
            if 'set "$@" -vconsole:emer -vsyslog:err -vfile:' in line:
                out.append(line)
                skipping = False
        else:
            out.append(line)
    return out


# Restarts ``service`` with systemctl. Returns command()'s
# (returncode, stdout, stderr) tuple for @validate to report.
@validate
def restart(service):
    """
    restart a service
    """
    systemctl_cmd = "systemctl restart %s" % service
    return command(systemctl_cmd)


# Points the VTEP database at the controller with
# `vtep-ctl set-manager ssl:<ip>:<port>`; controller_port must be an integer
# (it is formatted with %d). Returns command()'s result tuple.
@validate
def define_controller(controller_ip, controller_port):
    """
    define NSX controller IP address in OVSDB
    """
    manager_target = "ssl:%s:%d" % (controller_ip, controller_port)
    return command("/usr/bin/vtep-ctl set-manager " + manager_target)


# Sets the tunnel_ips column of the named Physical_switch row via vtep-ctl.
# Returns command()'s result tuple.
@validate
def define_tunnel(switch_name, tunnel_ip):
    """
    define local tunnel IP address on the switch
    """
    assignment = "%s tunnel_ips=%s" % (switch_name, tunnel_ip)
    return command("vtep-ctl set Physical_switch " + assignment)


# Sets the management_ips column of the named Physical_switch row via
# vtep-ctl. Returns command()'s result tuple.
@validate
def define_management(switch_name, management_ip):
    """
    define management IP address on the switch
    """
    assignment = "%s management_ips=%s" % (switch_name, management_ip)
    return command("vtep-ctl set Physical_switch " + assignment)


# Ensures a Physical_switch named ``switch_name`` exists. If
# `vtep-ctl list-ps` prints no switches, one is added. Otherwise the first
# listed switch is renamed. The list-ps result itself is not reported;
# the add/rename result tuple is returned.
@validate
def define_switch(switch_name):
    """
    define physical switch
    """
    listing = command("vtep-ctl list-ps")
    existing = listing[1].split()
    if existing:
        current = existing[0].strip()
        return command("vtep-ctl set Physical_switch %s name=%s" % (current, switch_name))
    return command("vtep-ctl add-ps %s" % switch_name)


# Removes the controller CA certificate file ``cert`` when it exists.
# Returns None when there is nothing to delete (so @validate prints nothing),
# otherwise a (returncode, stdout, stderr) tuple shaped like command()'s.
@validate
def delete_old_controller_cert(cert):
    """
    delete old controller certificate at the same path
    """
    if not os.path.exists(cert):
        return None
    # Delete in-process rather than via command("rm %s"). command() splits
    # on whitespace, so a path containing spaces would have become several
    # rm arguments, potentially deleting unrelated files.
    try:
        os.remove(cert)
    except OSError as e:
        return e.errno or 1, "", str(e)
    return CMD_OK, "", ""


def main(args):
    """
    Configure this switch as an OVSDB VTEP and restart the VTEP service.

    ``args`` is the argument list without the program name (as in
    ``sys.argv[1:]``). The steps are:
      1. create and sign SSL credentials, unless --no_encryption is given;
      2. patch the ovs-ctl-vtep script;
      3. register the physical switch;
      4. set the manager when --controller_ip is given;
      5. set the tunnel and management IPs;
      6. delete the old controller CA certificate;
      7. restart the service.
    Each step reports its own outcome through @validate. A malformed
    --db_ha_vip / --db_ha_repl_sv value exits through ``parser.error``.
    Returns CMD_OK; failures of individual steps are reported, not returned.
    """
    parser = setup_arg_parser()
    ns = parser.parse_args(args)

    private_key = certificate = bootstrap_ca_cert = ''
    vip = vip_port = ''
    repl_sv_ip = repl_sv_port = ''

    if not ns.no_encryption:
        credentials_path = os.path.abspath(ns.credentials_path)
        private_key = os.path.join(
            credentials_path, "%s-privkey.pem" % ns.switch_name)
        certificate = os.path.join(
            credentials_path, "%s-cert.pem" % ns.switch_name)
        bootstrap_ca_cert = os.path.join(credentials_path, "controller.cacert")

        create_cert(ns.pki_path)
        sign_certificate(ns.switch_name, ns.credentials_path, ns.pki_path)

    ha_state = ns.db_ha or ''

    if ns.db_ha_vip:
        vip, vip_port = _split_host_port(parser, '--db_ha_vip', ns.db_ha_vip)

    if ns.db_ha_repl_sv:
        repl_sv_ip, repl_sv_port = _split_host_port(
            parser, '--db_ha_repl_sv', ns.db_ha_repl_sv)

    bridge = ns.bridge or ''

    patch_ovs_ctl_vtep(OVS_CTL_VTEP_DEST, ns.no_encryption, ns.controller_port, private_key,
                       certificate, bootstrap_ca_cert, ha_state, vip, vip_port, repl_sv_ip,
                       repl_sv_port, bridge)
    define_switch(ns.switch_name)
    if ns.controller_ip:
        define_controller(ns.controller_ip, ns.controller_port)
    define_tunnel(ns.switch_name, ns.tunnel_ip)
    define_management(ns.switch_name, ns.management_ip)
    delete_old_controller_cert(bootstrap_ca_cert)
    restart(VTEP_SERVICE)
    return CMD_OK


def _split_host_port(parser, option, value):
    """
    Split a ``HOST:PORT`` option value on its last colon.

    Calls ``parser.error`` (usage message, exit status 2) when ``value`` has
    no colon or an empty host or port. Previously a ``split(":")`` tuple
    unpack raised a bare ValueError traceback, and an empty port was
    silently accepted.
    """
    host, sep, port = value.rpartition(":")
    if not (sep and host and port):
        parser.error("%s expects HOST:PORT, got %r" % (option, value))
    return host, port


if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    exit_status = main(sys.argv[1:])
    sys.exit(exit_status)
