#!/usr/bin/python

#-------------------------------------------------------------------------------
#
# Copyright 2016 Cumulus Networks, inc  all rights reserved
#
#-------------------------------------------------------------------------------

"""
This script has three functions: IPv4 arp refresh, IPv6 neighbor refresh, and
VRR MAC address refresh. These three functions run as separate threads. The ARP
and neighbor refreshes are important to keep those entries in the "reachable"
state, and not have them churn between "reachable" and "stale" or other states
which can cause lots of Netlink messages and significant work by the recipients
of those netlink messages, like switchd.

IPv4 ARP Refresh - The goal is to keep IPv4 neighbors which are still alive from
transitioning out of the reachable state. So this thread periodically sends
arp requests to each of the IPv4 neighbors.

IPv6 Neighbor Refresh - Similar to the IPv4 ARP Refresh, this thread
periodically sends neighbor solicitation requests to each of the IPv6 neighbors.

VRR MAC address refresh - Although servers send data to the VRR MAC address,
the switch never sends data using the VRR MAC address in the source address
field of a packet. Therefore, any switches between the server and this switch
would flood packets to the VRR MAC address. This thread sends gratuitous ARPs
with the VRR MAC as the source address, to keep the MACs in the forwarding
database of those switches.
"""

#-------------------------------------------------------------------------------
#
#   Import the necessary modules
#
#-------------------------------------------------------------------------------

try:
    import argparse
    import fcntl
    import ipaddr
    import logging
    import logging.handlers
    import nlmanager.nlmanager
    import nlmanager.nlpacket
    import signal
    import socket
    import struct
    import sys
    import threading
    import traceback
    import time
except ImportError, e:
    raise ImportError (str(e) + "- required module not found")

#-------------------------------------------------------------------------------
#
#   Define constants
#
#-------------------------------------------------------------------------------

# Packed six-byte Ethernet addresses: the broadcast address and the all-zero
# address (used as the ARP target hardware address in broadcast requests).
_ETH_BROADCAST      = "\xFF\xFF\xFF\xFF\xFF\xFF"
_ETH_NULL           = "\x00\x00\x00\x00\x00\x00"

# EtherType placed in the Ethernet header of ARP frames; also the protocol
# the raw sockets are bound with.
_ARP_ETHTYPE        = 0x0806

# ARP header hardware type (Ethernet) and protocol type (IPv4)
_ARP_HW_TYPE_ETH    = 0x0001
_ARP_PROT_TYPE_IP   = 0x0800

# ARP header hardware and protocol address lengths, in bytes
_ARP_HW_ETH_LEN     = 6
_ARP_PROT_IP_LEN    = 4

# ARP header operation codes
_ARP_REQUEST        = 1
_ARP_REPLY          = 2

# The first 13 bytes of an IPv6 multicast destination (ff02::1:ff...). The
# sender appends the last 3 bytes of the target address to complete the
# 16-byte destination.
_IPV6_SN_MCAST_IP   = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff"

# procfs files read to obtain the default neighbor base reachable time (ms)
_IPV4_BASE_REACHABLE_MS = "/proc/sys/net/ipv4/neigh/default/base_reachable_time_ms"
_IPV6_BASE_REACHABLE_MS = "/proc/sys/net/ipv6/neigh/default/base_reachable_time_ms"

#-------------------------------------------------------------------------------
#
#   Global Variables
#
#-------------------------------------------------------------------------------

# A dictionary with interface name as key and a 4-item tuple as data: a six-byte
# packed MAC address of interface, a 4-byte packed IP address of interface, a
# 16-byte IPv6 address of interface, and raw socket bound to interface. The
# IPv4 and IPv6 items are None when no address could be found. Example:
#
#   {
#       "swp1"  : ( "\x70\x72\xcf\xe9\xf0\x76", "\x0b\x00\x04\x06", ip6, raw_sock1 ),
#       "bond7" : ( "\x70\x72\xcf\xe9\xf0\x77", "\x0b\x00\x04\x07", ip6, raw_sock2 ),
#       "br100" : ( "\x70\x72\xcf\xe9\xf0\x78", "\x0b\x00\x04\x08", ip6, raw_sock3 ),
#   }
#
# intfInfoLock must be held while reading or writing intfInfo, because the
# dictionary is shared by the refresh threads and periodically replaced by the
# cache-clearing thread.
#
intfInfoLock = threading.Lock()
intfInfo     = {}

# A cached dictionary of IPv6 addresses on interfaces. The key is the interface
# name and the value is a 16-byte packed string of an IPv6 address for the
# interface. ip6addrsLock must be held while reading or writing ip6addrs.
#
ip6addrsLock = threading.Lock()
ip6addrs = {}

# The handle used for logging messages
logger = logging.getLogger("arp_refresh")

# When set, exit all threads and get out.
stopEvent = threading.Event()

# The netlink manager. It is created in main() and shared by all threads;
# nlmLock serializes the dump requests made through it.
nlmLock = threading.Lock()
nlm = None

#-------------------------------------------------------------------------------
#
#   Functions
#
#-------------------------------------------------------------------------------

def ParseCmdLine():
    """
    Build the command line parser for this script and parse sys.argv with it.

    Returns the argparse namespace, which carries three attributes: count
    (int, defaults to 1), loglevel (defaults to "WARNING") and logfile
    (defaults to None).
    """
    # One entry per option: the option strings, then the add_argument keywords
    options = (
        (("--count", "-c"),
         dict(type=int, default=1,
              help="The number of packets to send to each host")),
        (("--loglevel", "-l"),
         dict(default="WARNING",
              help="The logging level for output")),
        (("--logfile", "-f"),
         dict(default=None,
              help="The file to log messages to. If not given syslog is used")),
    )

    cmdline = argparse.ArgumentParser(description="Send ARP request to a list of hosts")
    for flags, settings in options:
        cmdline.add_argument(*flags, **settings)
    return cmdline.parse_args()

def InitLogging(cmdargs):
    """
    Set up the logging according to the command line parameters provided.

    cmdargs.loglevel is the name of a standard logging level (matched
    case-insensitively). cmdargs.logfile, when set, is the path of the file
    that receives the messages; otherwise messages are sent to the local
    syslog socket at /dev/log.

    Raises ValueError if cmdargs.loglevel does not name a logging level.
    """
    logLevel = getattr(logging, cmdargs.loglevel.upper(), None)
    if not isinstance(logLevel, int):
        raise ValueError("Invalid log level: %s" % cmdargs.loglevel)
    logger.setLevel(logLevel)

    # Only build the handler that is actually going to be used, so that asking
    # for a log file neither opens nor depends on the syslog socket.
    if cmdargs.logfile:
        lh = logging.FileHandler(cmdargs.logfile)
    else:
        lh = logging.handlers.SysLogHandler(address="/dev/log")
    logger.addHandler(lh)

def SignalHandler(signum, frame):
    """
    Signal callback. A SIGTERM sets stopEvent, which asks every thread to
    finish; any other signal number is ignored.
    """
    if signum != signal.SIGTERM:
        return
    stopEvent.set()

def GetIpv4BaseReachableMs():
    """
    Returns the integer on the first line of the IPv4 base_reachable_time_ms
    procfs file. If the file cannot be read or does not hold an integer, four
    hours expressed in milliseconds is returned instead.
    """
    try:
        with open(_IPV4_BASE_REACHABLE_MS) as procfile:
            first_line = procfile.readline()
        reachable_ms = int(first_line.strip())
    except (IOError, ValueError, TypeError):
        reachable_ms = 4 * 60 * 60 * 1000
    return reachable_ms

def GetIpv6BaseReachableMs():
    """
    Returns the integer on the first line of the IPv6 base_reachable_time_ms
    procfs file. If the file cannot be read or does not hold an integer, four
    hours expressed in milliseconds is returned instead.
    """
    try:
        with open(_IPV6_BASE_REACHABLE_MS) as procfile:
            first_line = procfile.readline()
        reachable_ms = int(first_line.strip())
    except (IOError, ValueError, TypeError):
        reachable_ms = 4 * 60 * 60 * 1000
    return reachable_ms

def GetIpAddress(intf):
    """
    Retrieves the IP address assigned to an interface using IOCTL interface.

    intf is the interface name. Returns the IPv4 address as a 4-byte packed
    string, or None (after logging at info level) when the ioctl fails.
    """
    # ioctl request 0x8915 is SIOCGIFADDR on Linux
    SIOCGIFADDR = 0x8915

    # The request is a 16-byte interface name followed by a socket address
    # (2-byte family and 14 bytes of data), 32 bytes in total.
    ifreq = struct.pack('16sH14s', intf, socket.AF_INET, '\x00'*14)

    # The socket is only a handle to issue the ioctl on; always close it so
    # that repeated lookups do not accumulate open file descriptors.
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        res = fcntl.ioctl(s.fileno(), SIOCGIFADDR, ifreq)
    except (IOError, OSError):
        logger.info("Could not get IP Address for interface %s" % (intf,))
        return None
    finally:
        s.close()

    # Skip the name, the family and two more bytes; the next four bytes are
    # the address.
    return struct.unpack('16sH2x4s8x', res)[2]

def GetIp6Address(intf):
    """
    Retrieves the IPv6 address assigned to an interface using /proc/net. Sample:
    fe800000000000007272cffffee9f076 58 40 20 80  br0.110
    1. IPv6 address displayed in 32 hexadecimal chars without colons as separator
    2. Netlink device number (interface index) in hexadecimal
    3. Prefix length in hexadecimal
    4. Scope value
    5. Interface flags
    6. Device name

    The file is only parsed when the ip6addrs cache is empty; the first
    address listed for each interface is the one that is cached. Returns the
    cached 16-byte packed address for intf, or None if there is none.
    """
    global ip6addrs

    # The with-statement guarantees the lock is released even if parsing a
    # malformed line raises, so other threads are never left blocked on it.
    with ip6addrsLock:
        if not ip6addrs:
            try:
                with open("/proc/net/if_inet6") as f:
                    for inet6line in f:
                        fields = inet6line.split()
                        if len(fields) == 6 and fields[5] not in ip6addrs:
                            ip6addrs[fields[5]] = fields[0].decode('hex')
            except IOError:
                # Best effort: without the file there are simply no addresses
                pass

        return ip6addrs.get(intf)

def GetMacAddress(intf):
    """
    Retrieves the MAC address assigned to an interface using IOCTL interface.

    intf is the interface name, of which only the first 15 characters are
    used. Returns the MAC address as a 6-byte packed string, or None (after
    logging at info level) when the ioctl fails.
    """
    # ioctl request 0x8927 is SIOCGIFHWADDR on Linux
    SIOCGIFHWADDR = 0x8927

    ifreq = struct.pack('256s', intf[:15])

    # The socket is only a handle to issue the ioctl on; always close it so
    # that repeated lookups do not accumulate open file descriptors.
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        res = fcntl.ioctl(s.fileno(), SIOCGIFHWADDR, ifreq)
    except (IOError, OSError):
        logger.info("Could not get MAC Address for interface %s" % (intf,))
        return None
    finally:
        s.close()

    # The six address bytes follow the 16-byte name and a 2-byte family
    return res[18:24]

def GetRawSocket(intf):
    """
    Creates a raw socket on the specified interface for sending ARP packets.

    The socket is an AF_PACKET/SOCK_RAW socket bound to intf with the ARP
    EtherType. Returns the bound socket, or None (after logging at info
    level) when the bind fails.
    """
    s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW)
    try:
        s.bind((intf, _ARP_ETHTYPE))
    except (IOError, OSError):
        logger.info("Could not bind a raw socket to interface %s" % (intf,))
        # Do not leave the unusable socket open
        s.close()
        return None
    return s

def GetInfoForIntf(intf):
    """
    Collects the IPv4 address, IPv6 address, MAC address and a raw socket for
    the named interface and stores them in the global intfInfo dict as the
    tuple (mac, ip, ip6, sock).

    Returns False, storing nothing, when the MAC address cannot be read or
    the raw socket cannot be created; returns True otherwise. Either IP
    address may be None in the stored tuple.
    """
    global intfInfo
    ip4addr = GetIpAddress(intf)
    ip6addr = GetIp6Address(intf)

    # A socket is only opened once the MAC address is known
    hwaddr = GetMacAddress(intf)
    rawsock = GetRawSocket(intf) if hwaddr is not None else None
    if hwaddr is None or rawsock is None:
        return False

    with intfInfoLock:
        intfInfo[intf] = (hwaddr, ip4addr, ip6addr, rawsock)
    return True

def toHexStr(inStr):
    """
    Renders a packed string as lower-case, two-digit hex bytes joined by
    colons, e.g. "70:72:cf". Returns None for an empty string or None.
    """
    if inStr:
        octets = ["%02x" % ord(ch) for ch in inStr]
        return ":".join(octets)
    return None

def checksum(pkt):
    """
    Calculates the traditional one's complement checksum used in many IETF
    protocols.

    pkt is a packed string. It is treated as a sequence of big-endian 16-bit
    words, padded with one zero byte if its length is odd. Returns the 16-bit
    one's complement of the one's complement sum of those words, as an int.
    """
    if len(pkt) % 2 == 1:
        pkt += "\0"

    # Sum the high bytes and the low bytes separately, then combine them
    s = (sum(ord(x) for x in pkt[0::2]) << 8) + sum(ord(x) for x in pkt[1::2])

    # Fold the carries back into the low 16 bits until none remain. A fixed
    # two folds is only enough while the raw sum fits in 32 bits.
    while s >> 16:
        s = (s >> 16) + (s & 0xffff)

    return ~s & 0xffff

def NsSend(sock, smac, sip, dip):
    """
    Builds an Ethernet frame holding an ICMPv6 Neighbor Solicitation for the
    target dip and sends it on sock.

    smac is the 6-byte packed source MAC, sip and dip are 16-byte packed IPv6
    addresses. The frame is addressed to the multicast group formed from
    _IPV6_SN_MCAST_IP and the last three bytes of dip. Returns True when the
    frame was sent and False when the send raised a socket error (which is
    logged).
    """
    # Multicast destination: fixed prefix plus the low three bytes of the target
    target_tail = dip[13:16]
    mcast_dip = _IPV6_SN_MCAST_IP + target_tail

    # ICMPv6 option: type 1, length 1, followed by the source MAC
    lladdr_opt = '\x01\x01' + smac

    # ICMPv6 message split around its checksum field: type 0x87 and code 0
    # before it, four reserved bytes and the target address after it
    ns_head = '\x87\x00'
    ns_tail = '\x00\x00\x00\x00' + dip
    payload_len = struct.pack('!H', len(ns_head) + 2 + len(ns_tail) + len(lladdr_opt))

    # The checksum covers a pseudo header (source, destination, 4-byte length,
    # three zero bytes, next header 0x3a) and the message with a zero checksum
    pseudo = sip + mcast_dip + '\x00\x00' + payload_len + '\x00\x00\x00\x3a'
    csum = struct.pack('!H', checksum(pseudo + ns_head + '\x00\x00' + ns_tail + lladdr_opt))

    # IPv6 header: version/class/flow, payload length, next header, hop limit
    ip6_hdr = '\x60\x00\x00\x00' + payload_len + '\x3a\xff' + sip + mcast_dip
    # Ethernet header: destination MAC, source MAC, EtherType
    eth_hdr = '\x33\x33\xff' + target_tail + smac + '\x86\xdd'

    frame = eth_hdr + ip6_hdr + ns_head + csum + ns_tail + lladdr_opt

    logger.debug("About to send a packet (%s)" % (toHexStr(frame),))
    try:
        sock.send(frame)
    except socket.error as e:
        logger.info("Unable to send packet: %s (%s)" % (toHexStr(frame), str(e)))
        return False
    return True

def ArpSend(op, sock, smac, sip, dmac, dip):
    """
    Builds an Ethernet frame holding an ARP packet and sends it on sock.

    op is the ARP operation code, smac/dmac are 6-byte packed MAC addresses
    and sip/dip are 4-byte packed IPv4 addresses. Returns 1 when the frame
    was sent and 0 when the send raised a socket error (which is logged).
    """
    # A broadcast frame carries an all-zero target hardware address
    if dmac == _ETH_BROADCAST:
        tmac = _ETH_NULL
    else:
        tmac = dmac

    # EtherType followed by the fixed part of the ARP header
    fixed_hdr = struct.pack('!HHHBBH', _ARP_ETHTYPE, _ARP_HW_TYPE_ETH,
                            _ARP_PROT_TYPE_IP, _ARP_HW_ETH_LEN,
                            _ARP_PROT_IP_LEN, op)
    frame = "".join((dmac, smac, fixed_hdr, smac, sip, tmac, dip))

    logger.debug("About to send a packet (%s)" % (toHexStr(frame),))
    try:
        sock.send(frame)
    except socket.error as e:
        logger.info("Unable to send packet: %s (%s)" % (toHexStr(frame), str(e)))
        return 0
    return 1

def request_link_dump(debug):
    """
    A frontend to the Netlink Manager's request_dump function which makes it
    threadsafe and extracts the relevant data.

    debug is passed through to request_dump. Returns a dictionary keyed by
    interface index whose values are 3-item tuples holding the link's
    IFLA_IFNAME, IFLA_LINK and IFLA_MASTER attribute values.
    """
    retLinks = {}
    # Hold the lock with a with-statement so that it is released even when
    # the dump raises; otherwise every other thread would block forever.
    with nlmLock:
        for link in nlm.request_dump(nlmanager.nlpacket.RTM_GETLINK, socket.AF_UNSPEC, debug):
            retLinks[link.ifindex] = (
                link.get_attribute_value(nlmanager.nlpacket.Link.IFLA_IFNAME),
                link.get_attribute_value(nlmanager.nlpacket.Link.IFLA_LINK),
                link.get_attribute_value(nlmanager.nlpacket.Link.IFLA_MASTER))
    return retLinks

def request_neigh_dump(family, debug):
    """
    A frontend to the Netlink Manager's request_dump function which makes it
    threadsafe and extracts the relevant data.

    family is the address family to dump and debug is passed through to
    request_dump. Returns a list of (ifindex, destination) tuples, one for
    each neighbor that is in the STALE, PROBE, DELAY or REACHABLE state and
    has an NDA_DST attribute.
    """
    retNeighs = []
    # Hold the lock with a with-statement so that it is released even when
    # the dump raises; otherwise every other thread would block forever.
    with nlmLock:
        neighs = nlm.request_dump(nlmanager.nlpacket.RTM_GETNEIGH, family, debug)
        for neigh in neighs:
            if neigh.state in [neigh.NUD_STALE, neigh.NUD_PROBE,
                               neigh.NUD_DELAY, neigh.NUD_REACHABLE]:
                dst = neigh.get_attribute_value(nlmanager.nlpacket.Neighbor.NDA_DST)
                if dst:
                    retNeighs.append((neigh.ifindex, dst))
    return retNeighs


#-------------------------------------------------------------------------------
#
#   Threads
#
#-------------------------------------------------------------------------------

def IPv4ArpRefresh():
    """
    This function refreshes the IPv4 Neighbor table entries by sending out an
    ARP request packet to each neighbor.

    It loops until stopEvent is set. Each pass dumps the links and the IPv4
    neighbors, drops the neighbors on interfaces whose name starts with "eth"
    (or whose interface is unknown), and sends one broadcast ARP request per
    remaining neighbor. The requests are spaced so that a whole pass takes a
    quarter of the base reachable time.
    """
    logger.debug("Beginning execution of the thread IPv4ArpRefresh")
    while not stopEvent.is_set():

        # Get the current list of interfaces
        ifByIndex = request_link_dump(False)

        # Get the current list of IPv4 neighbors in the refreshable states that are not mgmt interfaces
        neighs = [ n for n in request_neigh_dump(socket.AF_INET, False)
                    if not ifByIndex.get(n[0], ("eth", None, None))[0].startswith("eth") ]
        if not neighs:
            pktDelay = GetIpv4BaseReachableMs() / (4.0 * 1000.0)
            logger.debug("IPv4: Waiting %f seconds" % (pktDelay,))
            stopEvent.wait(pktDelay)
            continue

        for (ifindex, dip) in neighs:
            (ifname, iflower, ifmaster) = ifByIndex.get(ifindex, ("", None, None))

            # For VRR interfaces, send ARP through the lower interface
            (lifname, liflower, lifmaster) = ifByIndex.get(iflower, ("", None, None))
            (base, sep, idx) = ifname.partition("-v")
            if base == lifname.replace(".", "-") and idx.isdigit():
                ifname = lifname

            # If we haven't retrieved information for this interface, get it
            # now. The cache may be emptied by another thread at any moment,
            # so a missing entry is never treated as an error, even right
            # after it has been populated.
            with intfInfoLock:
                info = intfInfo.get(ifname)
            if info is None and GetInfoForIntf(ifname):
                with intfInfoLock:
                    info = intfInfo.get(ifname)
                if info is not None:
                    logger.debug("Interface data has been populated. intf=(%s), smac=(%s), sip=(%s), sip6=(%s)" % (ifname, toHexStr(info[0]), toHexStr(info[1]), toHexStr(info[2])))

            # Send the ARP
            if info is not None:
                (smac, sip, sip6, sock) = info
                if sip:
                    logger.info("Sending ARP to refresh %s on interface %s" % (str(dip), ifname))
                    ArpSend(_ARP_REQUEST, sock, smac, sip, _ETH_BROADCAST, dip.packed)
                else:
                    logger.info("Interface %s does not have an IPv4 addresses. Skipping..." % (ifname,))

            # Figure out how long to wait until the next packet. Delay is 1/4 of
            # base_reachable divided by the number of neighbors. The wait also
            # applies when the interface data could not be obtained, so that
            # the loop never spins without pausing.
            pktDelay = GetIpv4BaseReachableMs() / (4.0 * 1000.0 * len(neighs))
            logger.debug("IPv4: Waiting %f seconds" % (pktDelay,))
            if stopEvent.wait(pktDelay):
                break

    logger.debug("Finished execution of the thread IPv4ArpRefresh")

def IPv6NeighborSolicit():
    """
    This function refreshes the IPv6 Neighbor table entries by sending out a
    Neighbor Solicitation packet to each neighbor.

    It loops until stopEvent is set. Each pass dumps the links and the IPv6
    neighbors, drops the neighbors on interfaces whose name starts with "eth"
    (or whose interface is unknown), and sends one Neighbor Solicitation per
    remaining neighbor. The packets are spaced so that a whole pass takes a
    quarter of the base reachable time.
    """
    logger.debug("Beginning execution of the thread IPv6NeighborSolicit")
    while not stopEvent.is_set():

        # Get the current list of interfaces
        ifByIndex = request_link_dump(False)

        # Get the current list of IPv6 neighbors in the refreshable states that are not mgmt interfaces
        neighs = [ n for n in request_neigh_dump(socket.AF_INET6, False)
                    if not ifByIndex.get(n[0], ("eth", None, None))[0].startswith("eth") ]
        if not neighs:
            pktDelay = GetIpv6BaseReachableMs() / (4.0 * 1000.0)
            logger.debug("IPv6: Waiting %f seconds" % (pktDelay,))
            stopEvent.wait(pktDelay)
            continue

        for (ifindex, dip) in neighs:
            (ifname, iflower, ifmaster) = ifByIndex.get(ifindex, ("", None, None))

            # For VRR interfaces, send neighbor solicitation through the lower interface
            (lifname, liflower, lifmaster) = ifByIndex.get(iflower, ("", None, None))
            (base, sep, idx) = ifname.partition("-v")
            if base == lifname.replace(".", "-") and idx.isdigit():
                ifname = lifname

            # If we haven't retrieved information for this interface, get it
            # now. The cache may be emptied by another thread at any moment,
            # so a missing entry is never treated as an error, even right
            # after it has been populated.
            with intfInfoLock:
                info = intfInfo.get(ifname)
            if info is None and GetInfoForIntf(ifname):
                with intfInfoLock:
                    info = intfInfo.get(ifname)
                if info is not None:
                    logger.debug("Interface data has been populated. intf=(%s), smac=(%s), sip=(%s), sip6=(%s)" % (ifname, toHexStr(info[0]), toHexStr(info[1]), toHexStr(info[2])))

            # Send the neighbor solicitation
            if info is not None:
                (smac, sip, sip6, sock) = info
                if sip6:
                    logger.info("Sending Neighbor Solicitation to refresh %s on interface %s" % (str(dip), ifname))
                    NsSend(sock, smac, sip6, dip.packed)
                else:
                    logger.info("Interface %s does not have an IPv6 addresses. Skipping..." % (ifname,))

            # Figure out how long to wait until the next packet. Delay is 1/4 of
            # base_reachable divided by the number of neighbors. The wait also
            # applies when the interface data could not be obtained, so that
            # the loop never spins without pausing.
            pktDelay = GetIpv6BaseReachableMs() / (4.0 * 1000.0 * len(neighs))
            logger.debug("IPv6: Waiting %f seconds" % (pktDelay,))
            if stopEvent.wait(pktDelay):
                break

    logger.debug("Finished execution of the thread IPv6NeighborSolicit")

def VrrMacRefresh():
    """
    This function refreshes the VRR MAC addresses by sending out a gratuitous
    ARP packet for each MAC address.

    It loops until stopEvent is set. An interface is treated as a VRR
    interface when its name is "<base>-v<digits>" and <base> equals the name
    of its lower interface with "." replaced by "-". For each such interface
    a gratuitous ARP is sent if it has an IPv4 address, otherwise a Neighbor
    Solicitation for its own IPv6 address if it has one. The packets are
    spaced so that a whole pass takes 150 seconds.
    """
    logger.debug("Beginning execution of the thread VrrMacRefresh")
    while not stopEvent.is_set():

        # Get the current list of interfaces
        ifByIndex = request_link_dump(False)

        # Create a list of VRR interfaces
        vrrIf = []
        for (ifname, iflower, ifmaster) in ifByIndex.values():
            # Is this a VRR interface?
            (lifname, liflower, lifmaster) = ifByIndex.get(iflower, ("", None, None))
            (base, sep, idx) = ifname.partition("-v")
            if base == lifname.replace(".", "-") and idx.isdigit():
                vrrIf.append(ifname)

        if not vrrIf:
            logger.debug("VRR: Waiting %f seconds" % (300.0/2.0))
            stopEvent.wait(300.0/2.0)
            continue

        for ifname in vrrIf:
            # If we haven't retrieved information for this interface, get it
            # now. The cache may be emptied by another thread at any moment,
            # so a missing entry is never treated as an error, even right
            # after it has been populated.
            with intfInfoLock:
                info = intfInfo.get(ifname)
            if info is None and GetInfoForIntf(ifname):
                with intfInfoLock:
                    info = intfInfo.get(ifname)
                if info is not None:
                    logger.debug("Interface data has been populated. intf=(%s), smac=(%s), sip=(%s), sip6=(%s)" % (ifname, toHexStr(info[0]), toHexStr(info[1]), toHexStr(info[2])))

            # Send the grat ARP or neighbor solicitation
            if info is not None:
                (smac, sip, sip6, sock) = info
                if sip:
                    logger.info("Sending gratitous ARP to refresh %s on interface %s" % (toHexStr(smac), ifname))
                    ArpSend(_ARP_REQUEST, sock, smac, sip, _ETH_BROADCAST, sip)
                elif sip6:
                    logger.info("Sending gratitous Neighbor Solicitation to refresh %s on interface %s" % (toHexStr(smac), ifname))
                    NsSend(sock, smac, sip6, sip6)

            # Wait to send the next one. We wait 50% of the default bridge
            # aging time. The wait also applies when the interface data could
            # not be obtained, so that the loop never spins without pausing.
            pktDelay = 300.0/(2.0*len(vrrIf))
            logger.debug("VRR: Waiting %f seconds" % (pktDelay,))
            if stopEvent.wait(pktDelay):
                break

    logger.debug("Finished execution of the thread VrrMacRefresh")


def ClearCachedInfo():
    """
    Thread body that empties the interface and IPv6 address caches every 60
    seconds until stopEvent is set, so that changes to the system state
    (interface ips, etc.) get picked up the next time the data is needed.
    """
    global intfInfo
    global ip6addrs
    logger.debug("Beginning execution of the thread ClearCachedInfo")
    while True:
        # Sleep for a minute; a set stopEvent ends both the sleep and the loop
        if stopEvent.wait(60):
            break

        logger.debug("Clearing cached information")
        with intfInfoLock:
            intfInfo = {}
        with ip6addrsLock:
            ip6addrs = {}

    logger.debug("Finished execution of the thread ClearCachedInfo")


#-------------------------------------------------------------------------------
#
#   Threading objects
#
#-------------------------------------------------------------------------------

def IPv4ArpRefreshT():
    """
    Thread target for IPv4ArpRefresh. If it raises, the exception is logged
    with its traceback and stopEvent is set so that all threads exit.
    """
    try:
        IPv4ArpRefresh()
    except Exception as err:
        # A failed thread takes the whole program down with it
        logger.exception(err)
        stopEvent.set()

def IPv6NeighborSolicitT():
    """
    Thread target for IPv6NeighborSolicit. If it raises, the exception is
    logged with its traceback and stopEvent is set so that all threads exit.
    """
    try:
        IPv6NeighborSolicit()
    except Exception as err:
        # A failed thread takes the whole program down with it
        logger.exception(err)
        stopEvent.set()

def VrrMacRefreshT():
    """
    Thread target for VrrMacRefresh. If it raises, the exception is logged
    with its traceback and stopEvent is set so that all threads exit.
    """
    try:
        VrrMacRefresh()
    except Exception as err:
        # A failed thread takes the whole program down with it
        logger.exception(err)
        stopEvent.set()

def ClearCachedInfoT():
    """
    Thread target for ClearCachedInfo. If it raises, the exception is logged
    with its traceback and stopEvent is set so that all threads exit.
    """
    try:
        ClearCachedInfo()
    except Exception as err:
        # A failed thread takes the whole program down with it
        logger.exception(err)
        stopEvent.set()

# The threads run by main(). The key is the thread name and the value is a
# two-item list: the function the thread executes, and the threading.Thread
# object (None until main() creates it). Every entry uses the exception-safe
# "...T" wrapper so that a failure in any thread is logged and stops the
# program.
Threads = {
    "IPv4ArpRefresh"       : [IPv4ArpRefreshT,      None],
    "IPv6NeighborSolicit"  : [IPv6NeighborSolicitT, None],
    "VrrMacRefresh"        : [VrrMacRefreshT,       None],
    "ClearCachedInfo"      : [ClearCachedInfoT,     None],
}


#-------------------------------------------------------------------------------
#
#   The Main Program
#
#-------------------------------------------------------------------------------

def main():
    """
    Runs the program: parses the command line, sets up logging and the
    netlink manager, starts every thread listed in Threads, and waits until
    stopEvent is set. It then joins the threads, closes the cached raw
    sockets and exits.

    The process exit status is 0 on a normal stop (including a keyboard
    interrupt), 2 if a RuntimeError was raised while waiting, and 3 for any
    other exception raised while waiting.
    """
    global nlm

    # Parse the command line parameters and initialize logging.
    cmdargs = ParseCmdLine()
    InitLogging(cmdargs)
    logger.info("Beginning execution of arp_refresh")

    nlm = nlmanager.nlmanager.NetlinkManager()

    # Create the threads
    for (threadName, entry) in Threads.items():
        entry[1] = threading.Thread(None, entry[0], threadName)
        entry[1].daemon = True

    # Start the threads
    for entry in Threads.values():
        entry[1].start()

    # Wait around until it's time to leave
    status = 0
    try:
        while not stopEvent.wait(60*60*24):
            pass
    except RuntimeError as e:
        logger.error(str(e))
        stopEvent.set()
        status = 2
    except KeyboardInterrupt:
        stopEvent.set()
    except Exception as e:
        logger.exception(e)
        stopEvent.set()
        status = 3

    # Wait for all threads to stop
    for entry in Threads.values():
        if entry[1]:
            entry[1].join()

    # Close any open sockets. All threads have finished, so the cache can be
    # read without taking its lock.
    for (mac, ip, ip6, sock) in intfInfo.values():
        sock.close()

    # Get out of here
    logger.info("Execution of arp_refresh is complete")
    sys.exit(status)


# Only run when executed as a script. The SIGTERM handler is installed before
# main() starts the threads, so a termination request sets stopEvent and lets
# main() shut down in an orderly way.
if __name__ == '__main__':
    signal.signal(signal.SIGTERM, SignalHandler)
    main()

