#! /usr/bin/python
# Copyright 2017 Cumulus Networks, Inc.
#
# ARP refresh ctl script

import argparse
import json
import select
import socket
import sys

import stream_lib

'''
neighmgrctl communicates with NeighMgrCliT to configure and display databases
maintained by neighmgrd.
'''

#XXX - need to add manpage

def show_nl_link(sockInfo, args):
    """Print neighmgrd's netlink link cache; return the command status code."""
    status = do_show_nl_link(sockInfo, args, retJson=False)[0]
    return status

def do_show_nl_link(sockInfo, args, retJson):
    """Fetch neighmgrd's netlink link cache and render it.

    :param sockInfo: connected stream to neighmgrd (see neighmgrd_connect)
    :param args: parsed CLI namespace; uses .command, .args and .json
    :param retJson: in JSON mode, hand the raw reply back instead of printing
    :returns: (rv, jout) where rv is neighmgrd_exec_cmd's status and jout is
              the raw JSON reply only when args.json and retJson are both set,
              otherwise None
    """
    rv, jout = neighmgrd_exec_cmd(sockInfo, args)
    if not jout:
        return rv, None

    if args.json:
        if retJson:
            return rv, jout
        print(jout)
        return rv, None

    intfs = json.loads(jout)
    if isinstance(intfs, dict):
        # neighmgrd_exec_cmd reports local failures as {"errorMsg": ...}.
        # Iterating that dict would yield key strings and crash on .get().
        print(intfs.get("errorMsg", jout))
        return rv, None

    # A successful reply is a list of per-interface objects.
    if intfs:
        print("Interface           IfIndex        Kind  BrPort           Master")
        print("---------           -------        ----  ------           ------")
    for intf in intfs:
        ifName = intf.get("ifName", "-")
        ifIndex = intf.get("ifIndex", 0)
        kind = intf.get("kind", "-")
        isBrPort = "yes" if intf.get("isBrPort", False) else "no"
        master = intf.get("master", "-")
        print("%-15s  %10d  %10s  %6s  %15s" % (ifName, ifIndex, kind, isBrPort, master))

    return rv, None

def show_addr(sockInfo, args):
    """Print the per-interface address cache; return the command status code."""
    status = do_show_addr(sockInfo, args, retJson=False)[0]
    return status

def do_show_addr(sockInfo, args, retJson):
    """Fetch neighmgrd's per-interface IP address cache and render it.

    :param sockInfo: connected stream to neighmgrd (see neighmgrd_connect)
    :param args: parsed CLI namespace; uses .command, .args and .json
    :param retJson: in JSON mode, hand the raw reply back instead of printing
    :returns: (rv, jout) where jout is the raw JSON reply only when args.json
              and retJson are both set, otherwise None
    """
    rv, jout = neighmgrd_exec_cmd(sockInfo, args)
    if not jout:
        return rv, None

    if args.json:
        if retJson:
            return rv, jout
        print(jout)
        return rv, None

    intfs = json.loads(jout)
    if isinstance(intfs, dict):
        # neighmgrd_exec_cmd reports local failures as {"errorMsg": ...}.
        print(intfs.get("errorMsg", jout))
        return rv, None

    # The header is printed only if at least one address exists.
    printHdr = True
    for intf in intfs:
        ifName = intf.get("ifName", "-")
        # Each interface lists its IPv4 networks first, then its IPv6 ones.
        # "or []" tolerates an explicit JSON null for either list.
        networks = (list(intf.get("ipv4Networks") or []) +
                    list(intf.get("ipv6Networks") or []))
        for net in networks:
            if printHdr:
                print("Interface                                         IP address")
                print("---------                                         ----------")
                printHdr = False
            print("%-15s  %43s" % (ifName, net))

    return rv, None

def show_invalid_neigh(sockInfo, args):
    """Print neighbors not in a valid state; return the command status code."""
    status = do_show_invalid_neigh(sockInfo, args, retJson=False)[0]
    return status

def do_show_invalid_neigh(sockInfo, args, retJson):
    """Fetch neighbors that neighmgrd holds in a non-valid state and render them.

    :param sockInfo: connected stream to neighmgrd (see neighmgrd_connect)
    :param args: parsed CLI namespace; uses .command, .args and .json
    :param retJson: in JSON mode, hand the raw reply back instead of printing
    :returns: (rv, jout) where jout is the raw JSON reply only when args.json
              and retJson are both set, otherwise None
    """
    rv, jout = neighmgrd_exec_cmd(sockInfo, args)
    if not jout:
        return rv, None

    if args.json:
        if retJson:
            return rv, jout
        print(jout)
        return rv, None

    neighs = json.loads(jout)
    if isinstance(neighs, dict):
        # neighmgrd_exec_cmd reports local failures as {"errorMsg": ...}.
        print(neighs.get("errorMsg", jout))
        return rv, None

    if neighs:
        print("Interface                                     IP address      F")
        print("---------                                     ----------  -----")
    for neigh in neighs:
        ip = neigh.get("ip", "-")
        ifName = neigh.get("ifName", "-")
        # The state is shown in hex (presumably the kernel NUD_* bits; confirm).
        # A missing or non-numeric state used to raise TypeError from "%x";
        # show "-" instead, padded to the same column width.
        try:
            state = "0x%03x" % neigh.get("state")
        except TypeError:
            state = "-"
        print("%-15s  %39s  %5s" % (ifName, ip, state))
    return rv, None

def show_nl_neigh(sockInfo, args):
    """Print neighmgrd's netlink neighbor cache; return the command status code."""
    status = do_show_nl_neigh(sockInfo, args, retJson=False)[0]
    return status

def do_show_nl_neigh(sockInfo, args, retJson):
    """Fetch neighmgrd's netlink neighbor cache and render it.

    In table mode the F column holds R (router), E (ext) and D (dampened)
    for each flag that is set on the entry.

    :param sockInfo: connected stream to neighmgrd (see neighmgrd_connect)
    :param args: parsed CLI namespace; uses .command, .args and .json
    :param retJson: in JSON mode, hand the raw reply back instead of printing
    :returns: (rv, jout) where jout is the raw JSON reply only when args.json
              and retJson are both set, otherwise None
    """
    rv, jout = neighmgrd_exec_cmd(sockInfo, args)
    if not jout:
        return rv, None

    if args.json:
        if retJson:
            return rv, jout
        print(jout)
        return rv, None

    neighs = json.loads(jout)
    if isinstance(neighs, dict):
        # neighmgrd_exec_cmd reports local failures as {"errorMsg": ...}.
        print(neighs.get("errorMsg", jout))
        return rv, None

    if neighs:
        print("Interface                                     IP address                MAC    F")
        print("---------                                     ----------                ---  ---")
    for neigh in neighs:
        ip = neigh.get("ip", "-")
        mac = neigh.get("mac", "-")
        ifName = neigh.get("ifName", "-")
        flags = "".join(letter for letter, isSet in (
            ("R", neigh.get("isRouter", False)),
            ("E", neigh.get("isExt", False)),
            ("D", neigh.get("isDampened", False)),
        ) if isSet)
        print("%-15s  %39s  %17s  %3s" % (ifName, ip, mac, flags))
    return rv, None

def show_snoop_link(sockInfo, args):
    """Print the snooper's per-link counters; return the command status code."""
    status = do_show_snoop_link(sockInfo, args, retJson=False)[0]
    return status

def do_show_snoop_link(sockInfo, args, retJson):
    """Fetch the snooper's per-link ARP/ND counters and render them.

    :param sockInfo: connected stream to neighmgrd (see neighmgrd_connect)
    :param args: parsed CLI namespace; uses .command, .args and .json
    :param retJson: in JSON mode, hand the raw reply back instead of printing
    :returns: (rv, jout) where jout is the raw JSON reply only when args.json
              and retJson are both set, otherwise None
    """
    rv, jout = neighmgrd_exec_cmd(sockInfo, args)
    if not jout:
        return rv, None

    if args.json:
        if retJson:
            return rv, jout
        print(jout)
        return rv, None

    intfs = json.loads(jout)
    if isinstance(intfs, dict):
        # neighmgrd_exec_cmd reports local failures as {"errorMsg": ...}.
        print(intfs.get("errorMsg", jout))
        return rv, None

    if intfs:
        print("Interface                  ARP Rx             ND Rx  v4Nbr Adds  v6Nbr Adds")
        print("---------                  ------             -----  ----------  ----------")
    for intf in intfs:
        print("%-15s  %16d  %16d  %10d  %10d" % (
            intf.get("ifName", "-"),
            intf.get("rxArpCnt", 0),
            intf.get("rxNdCnt", 0),
            intf.get("addNeighV4Cnt", 0),
            intf.get("addNeighV6Cnt", 0)))
    return rv, None

def show_support(sockInfo, args):
    """Dump every neighmgrd cache and counter set for postmortem analysis.

    Runs summary, showlink, showaddr, showneigh, showinvalidneigh and
    showsnooplink in turn over the same connection. In text mode each
    section is printed under a title. In JSON mode the replies are merged
    into one object keyed by section name and printed once at the end.

    args.command is overwritten for each sub-command.

    :returns: 0 if every sub-command succeeded, otherwise the first non-zero
              status returned by a sub-command
    """
    # Built inside the function because do_show_summary is defined after it.
    # (text-mode title, neighmgrd command, renderer, key in the JSON dump)
    sections = (
        ("Summary", "summary", do_show_summary, "summary"),
        ("\nLink cache", "showlink", do_show_nl_link, "links"),
        ("\nIntf address cache", "showaddr", do_show_addr, "addrs"),
        ("\nNeigh cache", "showneigh", do_show_nl_neigh, "neighbors"),
        ("\nNeigh (!NUD_VALID) cache", "showinvalidneigh",
         do_show_invalid_neigh, "invalidNeighbors"),
        ("\nSnooper links", "showsnooplink", do_show_snoop_link, "snooperLinks"),
    )

    jAllOut = {}
    status = 0
    for title, command, render, jsonKey in sections:
        if not args.json:
            print(title)
        args.command = command
        try:
            rv, jout = render(sockInfo, args, retJson=True)
        except Exception as e:
            # Keep going: a partial support dump is more useful than none.
            rv, jout = -1, None
            errMsg = "Command failed %r" % e
            if args.json:
                jAllOut[jsonKey] = {"errorMsg": errMsg}
            else:
                print(errMsg)
        if rv and not status:
            status = rv
        if args.json and jout:
            try:
                jAllOut[jsonKey] = json.loads(jout)
            except ValueError:
                # Best effort: leave out a section whose reply is not JSON.
                pass

    if args.json:
        print(json.dumps(jAllOut, sort_keys=True, indent=4))

    return status

def show_summary(sockInfo, args):
    """Print neighmgrd's summary counters; return the command status code."""
    status = do_show_summary(sockInfo, args, retJson=False)[0]
    return status

def do_show_summary(sockInfo, args, retJson=False):
    """Fetch neighmgrd's summary counters and render them as "label: value" lines.

    :param sockInfo: connected stream to neighmgrd (see neighmgrd_connect)
    :param args: parsed CLI namespace; uses .command, .args and .json
    :param retJson: in JSON mode, hand the raw reply back instead of printing
    :returns: (rv, jout) where jout is the raw JSON reply only when args.json
              and retJson are both set, otherwise None
    """
    rv, jout = neighmgrd_exec_cmd(sockInfo, args)
    if not jout:
        return rv, None

    if args.json:
        if retJson:
            return rv, jout
        print(jout)
        return rv, None

    summary = json.loads(jout)
    if "errorMsg" in summary:
        # A local failure from neighmgrd_exec_cmd. Without this check every
        # counter below would silently print as 0.
        print(summary["errorMsg"])
        return rv, None

    # (label, keys): the integer value of each key, missing keys counting as 0,
    # joined with '/' (e.g. local/ext pairs).
    counters = (
        ("IPv4 local neighs", ("neighIpv4",)),
        ("IPv4 ext neighs", ("neighExtIpv4",)),
        ("IPv4 vrr neighs", ("neighVrrIpv4",)),
        ("IPv6 local neighs", ("neighIpv6",)),
        ("IPv6 ext neighs", ("neighExtIpv6",)),
        ("IPv6 vrr neighs", ("neighVrrIpv6",)),
        ("IPv4 neigh adds (l/e)", ("neighAddIpv4", "neighAddExtIpv4")),
        ("IPv4 neigh mods (l/e)", ("neighModIpv4", "neighModExtIpv4")),
        ("IPv4 neigh dels (l/e)", ("neighDelIpv4", "neighDelExtIpv4")),
        ("IPv6 neigh adds (l/e)", ("neighAddIpv6", "neighAddExtIpv6")),
        ("IPv6 neigh mods (l/e)", ("neighModIpv6", "neighModExtIpv6")),
        ("IPv6 neigh dels (l/e)", ("neighDelIpv6", "neighDelExtIpv6")),
        ("IPv4 neigh vrr adds", ("neighAddVrrIpv4",)),
        ("IPv4 neigh vrr mods", ("neighModVrrIpv4",)),
        ("IPv4 neigh vrr dels", ("neighDelVrrIpv4",)),
        ("IPv6 neigh vrr adds", ("neighAddVrrIpv6",)),
        ("IPv6 neigh vrr mods", ("neighModVrrIpv6",)),
        ("IPv6 neigh vrr dels", ("neighDelVrrIpv6",)),
        ("IPv6 neigh deferred refreshes", ("deferredIpv6Refreshes",)),
        ("neighWorkQ len (curr/max)", ("neighWorkQLen", "neighWorkQMaxLen")),
        ("total neighWork", ("neighWorkTotal",)),
        ("max neighWork", ("neighWorkMax",)),
        ("neigh mods dampened", ("neighModDampens",)),
        ("neigh dampen threshold", ("dampenLimit",)),
        ("neigh dampen timeout", ("dampenTimeout",)),
        ("netlink resyncs", ("resyncs",)),
        ("snooper subnet checks", ("subnetChecks",)),
        ("config reloads", ("configReloads",)),
    )
    for label, keys in counters:
        values = "/".join("%d" % int(summary.get(key, 0)) for key in keys)
        print("%s: %s" % (label, values))
    # The source IPv4 address is free-form text (None when unset).
    print("setsrcipv4: %s" % summary.get("srcIPv4", None))
    return rv, None

def log_level_command(sockInfo, args):
    """Show or set neighmgrd's logger level.

    A reply carrying "cmdStatus" is treated as the result of a config command
    and is printed in text mode only when it is not "success". Any other
    reply is treated as a show result and its "logLevel" is printed. In JSON
    mode the raw reply is always printed.

    :returns: the status code from neighmgrd_exec_cmd
    """
    rv, jout = neighmgrd_exec_cmd(sockInfo, args)
    if not jout:
        return rv

    info = json.loads(jout)
    cmdStatus = info.get("cmdStatus")
    if args.json:
        print(jout)
    elif cmdStatus:
        # Config command: print only if something failed.
        if cmdStatus != "success":
            print(cmdStatus)
    elif "errorMsg" in info:
        # Local failure from neighmgrd_exec_cmd; previously shown as
        # "logLevel: Unknown", which hid the real error.
        print(info["errorMsg"])
    else:
        print("logLevel: %s" % info.get("logLevel", "Unknown"))
    return rv

def work_rate_command(sockInfo, args):
    """Show or set neighmgrd's work rate (count per second, per the help text).

    A reply carrying "cmdStatus" is treated as the result of a config command
    and is printed in text mode only when it is not "success". Any other
    reply is treated as a show result and its "workrate" is printed. In JSON
    mode the raw reply is always printed.

    :returns: the status code from neighmgrd_exec_cmd
    """
    rv, jout = neighmgrd_exec_cmd(sockInfo, args)
    if not jout:
        return rv

    info = json.loads(jout)
    cmdStatus = info.get("cmdStatus")
    if args.json:
        print(jout)
    elif cmdStatus:
        # Config command: print only if something failed.
        if cmdStatus != "success":
            print(cmdStatus)
    elif "errorMsg" in info:
        # Local failure from neighmgrd_exec_cmd; previously shown as
        # "workrate: -1", which hid the real error.
        print(info["errorMsg"])
    else:
        print("workrate: %d" % info.get("workrate", -1))
    return rv

def generic_set_command(sockInfo, args):
    """Run a config-only command and report its outcome.

    When the reply carries "cmdStatus", JSON mode prints the raw reply and
    text mode prints the status only if it is not "success". A reply with
    "errorMsg" instead (a local I/O failure from neighmgrd_exec_cmd) is
    also reported; previously such failures were silently dropped.

    :returns: the status code from neighmgrd_exec_cmd
    """
    rv, jout = neighmgrd_exec_cmd(sockInfo, args)
    if not jout:
        return rv

    info = json.loads(jout)
    cmdStatus = info.get("cmdStatus")
    errorMsg = info.get("errorMsg")
    if cmdStatus:
        if args.json:
            print(jout)
        elif cmdStatus != "success":
            # Print only if something failed.
            print(cmdStatus)
    elif errorMsg:
        print(jout if args.json else errorMsg)
    return rv

def profile_command(sockInfo, args):
    """Drive neighmgrd's profiler.

    The reply handling is exactly that of generic_set_command: in JSON mode
    the raw reply is echoed, and in text mode a non-"success" cmdStatus is
    printed. The work is therefore delegated to it.

    :returns: the status code from neighmgrd_exec_cmd
    """
    return generic_set_command(sockInfo, args)

# CLI dispatch table: command name -> (handler, one-line help for the epilog).
# Every handler is called as handler(sockInfo, args) and returns an exit code.
nmCtlCmds = {
    #Command            : ( Function, Help)
    "showlink"          : ( show_nl_link, "Dump netlink link cache"),
    "showaddr"          : ( show_addr, "Dump link addresses"),
    "showneigh"         : ( show_nl_neigh, "Dump netlink neigh cache"),
    "showinvalidneigh"  : ( show_invalid_neigh, "Dump neighs in a !VALID state"),
    "showsnooplink"     : ( show_snoop_link, "Dump snooper link cache"),
    "summary"           : ( show_summary, "Dump neighmgr state"),
    "showsupport"       : ( show_support, "Dump states for postmortem analysis"),
    "loglevel"          : ( log_level_command, "Logger level [info, debug, warning]"),
    "resync"            : ( generic_set_command, "Resync kernel cache"),
    "setdampenlimit"    : ( generic_set_command, "Set neigh dampening threshold"),
    "setdampentimeout"  : ( generic_set_command, "Set dampening duration in seconds"),
    "setsubnetchecks"   : ( generic_set_command, "Enable/Disable subnet checks in the snooper"),
    "reloadconfig"      : ( generic_set_command, "Reload config"),
    "cleardampen"       : ( generic_set_command, "Undampen neigh entries"),
    "clearstat"         : ( generic_set_command, "Clear statistics"),
    "workrate"          : ( work_rate_command, "Specify rate in count-per-second"),
    "profile"           : ( profile_command, "Yappi based profiling"),
    "setsrcipv4"        : ( generic_set_command, "IPv4 addr to be used as SIP while refreshing"),
}

# Path of neighmgrd's Unix-domain control socket.
NM_UDS_ADDR = "/var/run/neighmgrd/uds"

def neighmgrd_connect():
    """Open a stream connection to neighmgrd's control socket.

    :returns: a stream_lib.StreamInfo wrapping the connected socket, or None
              if the connection could not be established
    """
    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        sock.connect(NM_UDS_ADDR)
    except socket.error:
        # e.g. the socket path does not exist or the connection is refused.
        # Only socket errors are caught, so Ctrl-C still interrupts.
        sock.close()
        return None
    return stream_lib.StreamInfo(sock, None)

def neighmgrd_exec_cmd(sockInfo, args):
    """Send one command to neighmgrd and block until its reply arrives.

    The command text is args.command followed by any extra args, joined
    with single spaces.

    :param sockInfo: connected stream to neighmgrd (see neighmgrd_connect)
    :param args: parsed CLI namespace; uses .command and .args
    :returns: (rv, jout). rv is 0 once a complete reply has been received,
              otherwise -1. jout is the reply text, or on a local I/O failure
              a JSON object of the form {"errorMsg": ...}.
    """
    command = args.command
    if args.args:
        command += " " + " ".join(args.args)
    sockInfo.build_write_data(command)

    sock = sockInfo.sock
    rv = -1
    jout = ""
    keep_going = True
    while keep_going:
        inputs = [sock]
        # Only wait for writability while there is still data to send.
        outputs = [sock] if sockInfo.writeBuf else []
        try:
            readable, writeable, exceptional = select.select(
                inputs, outputs, inputs + outputs)
        except (select.error, socket.error, ValueError):
            jout = json.dumps({"errorMsg" : "select exception"})
            break

        if exceptional:
            keep_going = False
            jout = json.dumps({"errorMsg" : "connect exception"})

        if sock in readable:
            # Ask stream_lib how many bytes it wants next from this socket.
            readlen = sockInfo.get_read_data_len()
            try:
                data = sock.recv(readlen)
            except socket.error:
                # e.g. connection reset. Previously this escaped as an
                # unhandled exception, unlike the guarded send() below.
                data = None
            if data:
                cmd = sockInfo.proc_read_data(data)
                if cmd:
                    # A complete response has been received.
                    rv = 0
                    jout = cmd
                    keep_going = False
            else:
                # EOF (peer closed the connection) or a socket error.
                jout = json.dumps({"errorMsg" : "read failed"})
                keep_going = False

        if sock in writeable:
            try:
                numSent = sock.send(sockInfo.writeBuf)
            except socket.error:
                keep_going = False
                jout = json.dumps({"errorMsg" : "write failed"})
            else:
                sockInfo.remove_written_data(numSent)
    return rv, jout

if __name__ == "__main__":
    # Usage text shown after argparse's help and when no command is given:
    # one "<command>  <help>" line per entry of nmCtlCmds, sorted by name.
    epilogStr = "Usage neighmgrctl [-j] [command]:\n" + "".join(
        "%-20s  %s\n" % (name, nmCtlCmds[name][1]) for name in sorted(nmCtlCmds))

    parser = argparse.ArgumentParser(
        description="neighmgrd control",
        epilog=epilogStr,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("command", choices=nmCtlCmds.keys(),
                        nargs="?", metavar="command", help="command to execute")
    parser.add_argument("-j", "--json", action="store_true", default=False,
                        help="json output")
    # Everything after the command is forwarded verbatim to neighmgrd.
    parser.add_argument("args", nargs=argparse.REMAINDER)

    try:
        args = parser.parse_args()
    except Exception as e:
        parser.error(str(e))
        sys.exit(-1)

    def report_error(errMsg):
        """Print errMsg as {"errorMsg": ...} in JSON mode, verbatim otherwise."""
        if args.json:
            print(json.dumps({"errorMsg" : errMsg}))
        else:
            print(errMsg)

    if not args.command:
        print(epilogStr)
        sys.exit(-1)

    # Connect to neighmgrd before dispatching.
    sockInfo = neighmgrd_connect()
    if not sockInfo:
        report_error("Failed to connect to neighmgrd")
        sys.exit(-1)

    # Dispatch. Any exception escaping a handler becomes exit status -1.
    try:
        rv = nmCtlCmds[args.command][0](sockInfo, args)
    except Exception as e:
        rv = -1
        report_error("Command failed %r" % e)

    sockInfo.sock.close()
    del sockInfo
    sys.exit(rv)
