#!/usr/bin/python
import ipaddr
import argparse
import re
import os
import sys
import string
import struct
import socket
import dpkt
import cumulus.porttab
import cumulus.platforms
from cumulus.ecmpcalc_util import modport2interface, asic2linux, asictrunkmbr2linux, linux2modport
class MissingFrameInformation(RuntimeError):
    """Raised when a frame lacks a field that the hash calculation needs.

    The exception argument is the set of descriptions of the missing fields
    (built by RTAG7Hash.hashable).
    """
    pass
class ParseError(RuntimeError):
    """Raised when a user-supplied value (frame field, address, interface
    name, hash bin) cannot be parsed or validated.

    The command-line entry point catches it around argument parsing and
    reports it through the argument parser.
    """
    pass
class UnsupportedPlatformError(RuntimeError):
    """Raised when a hardware lookup has no implementation for the detected
    chip (e.g. ECMP page or mask lookup in ECMPConfig)."""
    pass
_chip_id_map = {'56538' : 'Firebolt3',
                '56540' : 'Apollo2',
                '56634' : 'Triumph2',
                '56845' : 'Trident',
                '56846' : 'Trident+',
                '56850' : 'Trident2',
                '56854' : 'Trident2', 
                '56960' : 'Tomahawk',
                '56340' : 'Helix4',
                '56860' : 'Trident2+',
                '56864' : 'Trident2+'}
def chipname(bcm):
    """Return the chip family name of the ASIC behind *bcm*.

    Runs the ``show unit`` shell command, extracts the numeric chip id from
    the first "Unit <n> chip BCM<id>" occurrence and maps it through
    ``_chip_id_map``.

    Args:
        bcm: shell object whose ``run(command)`` returns the command output.

    Returns:
        The family name (e.g. ``'Trident2'``), or ``'Unknown(<id>)'`` when
        the id is not in ``_chip_id_map``.

    Raises:
        RuntimeError: if the output contains no "Unit <n> chip BCM<id>" text.
    """
    unit_str = bcm.run('show unit')
    # Raw string so '\d' is a regex escape, not a (deprecated) string escape.
    # search() rather than match(): tolerate leading text such as whitespace
    # before the "Unit" line; when match() would succeed the result is the same.
    m = re.search(r'Unit \d+ chip BCM(?P<chip_id>\d+)', unit_str)
    if not m:
        raise RuntimeError('could not determine the platform from unit string: %s\n' % unit_str)
    chip_id = m.group('chip_id')
    return _chip_id_map.get(chip_id, 'Unknown(%s)' % chip_id)
def is_chip_supported(bcm):
    """Report whether ECMP calculation supports the chip behind *bcm*.

    Returns a ``(flag, chip)`` pair where *flag* is 1 for a supported chip
    family and 0 otherwise, and *chip* is the name from chipname().
    """
    chip = chipname(bcm)
    supported = chip in ('Trident', 'Trident+', 'Trident2', 'Tomahawk', 'Trident2+')
    return (1 if supported else 0, chip)
def is_chip_pkt_trace_supported(chip):
    """Return True if *chip* is one of the chips with packet-trace support."""
    return chip in ('Tomahawk', 'Spectrum')
class LinuxInterface(str):
    """A Linux front-panel interface name; only names starting with "swp"
    are accepted."""

    @classmethod
    def from_string(cls, string):
        """Return *string* as a LinuxInterface, or raise ParseError when it
        does not start with "swp"."""
        if string.startswith('swp'):
            return cls(string)
        raise ParseError('%s is not a valid Linux interface name'
                         ', must start with "swp"' % string)
class FrameField:
    """Base class for one integer-valued frame/hash field.

    Subclasses only override ``name``, which is used in string output and
    in parse error messages.
    """
    name = 'Unknown Frame Field'

    def __init__(self, num):
        self.value = num

    @classmethod
    def from_int(cls, num):
        """Wrap an already-parsed integer."""
        return cls(num)

    @classmethod
    def from_string(cls, string):
        """Parse *string* as an int literal with auto-detected base (e.g.
        "0x10", "16"); raise ParseError if it is not one."""
        try:
            parsed = int(string, 0)
        except ValueError:
            raise ParseError('%s is not an int for frame field %s' %
                             (string, cls.name))
        return cls(parsed)

    def __str__(self):
        return ':'.join((self.name, str(self.value)))

    __repr__ = __str__

    def asInt(self):
        """Return the field's integer value."""
        return self.value
class ProtocolID(FrameField):
    """IP protocol number (IPv4 protocol / IPv6 next header) of a frame."""
    name = 'Protocol ID'
    # Symbolic protocol names accepted on the command line.
    protocolMap = { 'icmp' : 1,
                    'tcp'  : 6,
                    'udp'  : 17,
                    'icmp6': 58,
                  }

    @property
    def isTCPorUDP(self):
        """True when the protocol number is TCP or UDP."""
        return self.asInt() in (self.protocolMap['tcp'], self.protocolMap['udp'])

    @classmethod
    def from_string(cls, string):
        """Parse a protocol name (case-insensitive) or a decimal number.

        Raises:
            ParseError: if *string* is neither a known name nor an integer,
                or the number does not fit the 8-bit protocol field.
                (Previously an unknown name leaked a bare ValueError, unlike
                every other FrameField parser.)
        """
        key = string.strip().lower()
        if key in cls.protocolMap:
            return cls(cls.protocolMap[key])
        try:
            number = int(key)
        except ValueError:
            raise ParseError('%s is not a protocol name (%s) or an int for '
                             'frame field %s' %
                             (string, ', '.join(sorted(cls.protocolMap)), cls.name))
        # The IP protocol / next-header field is a single byte.
        if not 0 <= number <= 255:
            raise ParseError('%s is out of range (0-255) for frame field %s' %
                             (string, cls.name))
        return cls.from_int(number)
class L4Port(FrameField):
    """TCP/UDP source or destination port number (--sport / --dport)."""
    name = 'L4 Port'
class VLANID(FrameField):
    """VLAN ID of the frame (--vid)."""
    name = 'VLAN ID'
class HardwarePort(FrameField):
    """Hardware port ID (--sportid / --dportid)."""
    name = 'Hardware Port'
class HardwareModule(FrameField):
    """Hardware module ID (--smodid / --dmodid)."""
    name = 'Hardware Module'
class HashSeed(FrameField):
    """RTAG7 hash seed; RTAG7Hash packs it as the leading 32-bit big-endian
    word of the hash input."""
    name = 'Hash Seed'
class Frame:
    """The header fields of a frame whose ECMP path is being computed.

    Any field may be None when it was not supplied.  ``src``/``dst`` are
    used through ``high()``, ``low()`` and ``addrstr`` (IPAddress objects);
    the remaining fields are used through ``value`` (FrameField objects).
    """
    # Human-readable description of each field, used to report which fields
    # are missing.
    field_desc = { 'src'      : 'Source IP Address',
                   'dst'      : 'Destination IP Address',
                   'vid'      : 'VLAN ID',
                   'sport'    : 'Source Layer 4 Port',
                   'dport'    : 'Destination Layer 4 Port',
                   'protocol' : 'Layer 3 Protocol',
                   'sportid'  : 'Source Hardware Port ID',
                   'smodid'   : 'Source Hardware Module ID',
                   'dportid'  : 'Destination Hardware Port ID',
                   'dmodid'   : 'Destination Hardware Module ID',
                   }

    def __init__(self, src, dst, vid, sport, dport, protocol, sportid, smodid,
                 dportid, dmodid):
        self.src = src
        self.dst = dst
        self.vid = vid
        self.sport = sport
        self.dport = dport
        self.protocol = protocol
        self.sportid = sportid
        self.smodid = smodid
        self.dportid = dportid
        self.dmodid = dmodid

    @property
    def srcH(self):
        """High 16-bit hash bin of the source address, or None."""
        if self.src is None:
            return None
        return self.src.high()

    @property
    def srcL(self):
        """Low 16-bit hash bin of the source address, or None."""
        if self.src is None:
            return None
        return self.src.low()

    @property
    def dstH(self):
        """High 16-bit hash bin of the destination address, or None."""
        if self.dst is None:
            return None
        return self.dst.high()

    @property
    def dstL(self):
        """Low 16-bit hash bin of the destination address, or None."""
        if self.dst is None:
            return None
        return self.dst.low()

    def __str__(self):
        validFields = []
        for fieldName in ('src', 'dst', 'vid', 'sport', 'dport', 'protocol',
                          'sportid', 'smodid', 'dportid', 'dmodid'):
            fieldValue = getattr(self, fieldName)
            if fieldValue is not None:
                validFields.append(fieldName + ':(' + str(fieldValue) + ')')
        return 'IP Frame: ' + ' '.join(validFields)

    __repr__ = __str__

    def _buildEnetHdr(self, smac, dmac, vlan, version):
        """Return an Ethernet header as a hex string.

        The header is dmac, smac, an 802.1Q tag carrying *vlan* (omitted when
        *vlan* is falsy), then the IPv4 ethertype when *version* is 4 and the
        IPv6 ethertype otherwise.
        """
        data = dmac.replace(":", "") + smac.replace(":", "")
        if vlan:
            data += "%04x" % dpkt.ethernet.ETH_TYPE_8021Q
            data += "%04x" % vlan
        # '==' not 'is': int identity is a CPython caching detail.
        if version == 4:
            data += "%04x" % dpkt.ethernet.ETH_TYPE_IP
        else:
            data += "%04x" % dpkt.ethernet.ETH_TYPE_IP6
        return data

    def buildVisibilityPkt(self, smac, dmac, vid):
        """Build a TCP/UDP packet matching this frame, as a hex string.

        Args:
            smac, dmac: colon-separated MAC address strings.
            vid: VLAN to tag the packet with (falsy for untagged).

        Returns:
            The hex-encoded Ethernet frame, or None when a required field is
            missing (see visibilityPktMissingFields), the protocol is not TCP
            or UDP, or source and destination are of different IP versions.
        """
        # Missing fields used to surface as AttributeError on None; report
        # them through the documented None failure value instead.
        if self.visibilityPktMissingFields():
            return None
        if self.protocol.value == dpkt.ip.IP_PROTO_UDP:
            ipl = dpkt.udp.UDP(sport=self.sport.value, dport=self.dport.value)
        elif self.protocol.value == dpkt.ip.IP_PROTO_TCP:
            ipl = dpkt.tcp.TCP(sport=self.sport.value, dport=self.dport.value)
        else:
            return None
        dstAddr = ipaddr.IPAddress(self.dst.addrstr)
        srcAddr = ipaddr.IPAddress(self.src.addrstr)
        if srcAddr.version != dstAddr.version:
            # inet_pton below would raise on a mixed IPv4/IPv6 pair.
            return None
        if dstAddr.version == 4:
            ip = dpkt.ip.IP(
                    src=socket.inet_pton(socket.AF_INET, str(srcAddr)),
                    dst=socket.inet_pton(socket.AF_INET, str(dstAddr)),
                    p=self.protocol.value, data=ipl)
            ip.len += len(ipl)
        elif dstAddr.version == 6:
            ip = dpkt.ip6.IP6(
                    src=socket.inet_pton(socket.AF_INET6, str(srcAddr)),
                    dst=socket.inet_pton(socket.AF_INET6, str(dstAddr)),
                    nxt=self.protocol.value, data=ipl, plen=len(ipl),
                    hlim=255)
        else:
            return None
        ehdr = self._buildEnetHdr(smac, dmac, vid, dstAddr.version)
        # Python 2: str(ip) is the serialized packet; hex-encode it.
        pkt = ehdr + str(ip).encode("hex")
        return pkt

    def visibilityPktMissingFields(self):
        """Return the descriptions of fields a visibility packet needs but
        this frame lacks (empty set when complete)."""
        missing = set()
        for fieldName in ('src', 'dst', 'sport', 'dport', 'protocol'):
            fieldValue = getattr(self, fieldName)
            if fieldValue is None:
                missing.add(self.field_desc[fieldName])
        return missing
class Address(FrameField):
    """Base class for address-valued fields held as their text form.

    Subclasses supply ``name``, ``validate()`` (return True when valid) and
    ``asInt()``.
    """

    def __init__(self, string):
        self.addrstr = string
        verdict = self.validate()
        if verdict is not True:
            raise ParseError('"%s" is not a %s' % (self.addrstr, self.name))

    @classmethod
    def from_string(cls, string):
        """Construct (and thereby validate) an address from *string*."""
        return cls(string)

    def __str__(self):
        return ''.join((self.name, ':', self.addrstr))

    __repr__ = __str__

    def _checkchars(self, valid):
        """Return True if every character of the address is in *valid*;
        otherwise raise ParseError naming the first offending character."""
        offending = next((ch for ch in self.addrstr if ch not in valid), None)
        if offending is not None:
            raise ParseError('invalid character "%s" in %s: "%s"' %
                             (offending, self.name, self.addrstr))
        return True

    def validate(self):
        raise NotImplementedError

    def asInt(self):
        raise NotImplementedError
class MACAddress(Address):
    """Colon-separated 48-bit MAC address such as ``00:11:22:aa:bb:cc``."""
    name = 'MAC Address'

    def validate(self):
        """Return True for six colon-separated octets of 1-2 hex digits.

        Returns False for a wrong octet count or octet length; raises
        ParseError (via _checkchars) for a non-hex character.
        """
        octets = self.addrstr.lower().split(':')
        if len(octets) != 6:
            return False
        self._checkchars(string.hexdigits + ':')
        # The checks above alone accept strings such as ':::::' (empty
        # octets) or '123:4:5:6:7:8' (too-long octets).
        return all(1 <= len(octet) <= 2 for octet in octets)
class IPAddress(Address):
    """IPv4 or IPv6 address field, parsed with the ``ipaddr`` module."""
    name = 'IP Address'

    def __init__(self, string):
        Address.__init__(self, string)
        self.addr = ipaddr.IPAddress(string)

    def validate(self):
        """Return True if ``addrstr`` parses as an IP address, otherwise
        raise ParseError carrying the parser's message."""
        try:
            ipaddr.IPAddress(self.addrstr)
        # 'as' form works on Python 2.6+ and 3 (the comma form is 2-only).
        except ValueError as e:
            raise ParseError(e)
        return True

    def _fold6(self, addrint):
        """XOR-fold a 128-bit integer into 32 bits, one 32-bit word at a time."""
        return (((addrint >> 96) & 0xffffffff) ^
                ((addrint >> 64) & 0xffffffff) ^
                ((addrint >> 32) & 0xffffffff) ^
                ((addrint >> 0) & 0xffffffff))

    def asInt(self):
        """Return the address as a 32-bit integer: IPv4 as-is, IPv6 folded.

        Raises:
            RuntimeError: for an address that is neither IPv4 nor IPv6.
        """
        version = self.addr.version
        # '==' not 'is': int identity is an interpreter caching detail.
        if version == 4:
            return int(self.addr)
        if version == 6:
            return self._fold6(int(self.addr))
        raise RuntimeError('%s is not a IPv4 or IPv6 address' % str(self.addr))

    def high(self):
        """Return this address as the high 16-bit hash bin."""
        return IPAddressHigh(self.addrstr)

    def low(self):
        """Return this address as the low 16-bit hash bin."""
        return IPAddressLow(self.addrstr)
class IPAddressHigh(IPAddress):
    """IP address viewed as its high 16-bit hash bin (Frame.srcH/dstH)."""
    name = 'IP Address (High 16-bits)'

    def asInt(self):
        """Return bits 16-31 of the 32-bit (IPv6-folded) address value."""
        return (IPAddress.asInt(self) >> 16) & 0xffff
class IPAddressLow(IPAddress):
    """IP address viewed as its low 16-bit hash bin (Frame.srcL/dstL)."""
    name = 'IP Address (Low 16-bits)'

    def asInt(self):
        """Return bits 0-15 of the 32-bit (IPv6-folded) address value."""
        full = IPAddress.asInt(self)
        return full % 0x10000
class RTAG7HashBin:
    """One 16-bit field ("bin") of the RTAG7 hash input.

    ``name`` is also the Frame attribute that supplies the bin's value.
    """
    # Valid bin names, in the order their slots appear in the hash input.
    bins = ('srcH',
            'srcL',
            'dstH',
            'dstL',
            'vid',
            'sport',
            'dport',
            'protocol',
            'sportid',
            'smodid',
            'dportid',
            'dmodid')
    # Human-readable field descriptions; used when reporting missing frame
    # information.  sport/dport are Layer 4 ports (matching Frame.field_desc).
    descriptions = { 'srcH'     : 'Source IP Address',
                     'srcL'     : 'Source IP Address',
                     'dstH'     : 'Destination IP Address',
                     'dstL'     : 'Destination IP Address',
                     'vid'      : 'VLAN ID',
                     'sport'    : 'Source Layer 4 Port',
                     'dport'    : 'Destination Layer 4 Port',
                     'protocol' : 'Layer 3 Protocol',
                     'sportid'  : 'Source Hardware Port ID',
                     'smodid'   : 'Source Hardware Module ID',
                     'dportid'  : 'Destination Hardware Port ID',
                     'dmodid'   : 'Destination Hardware Module ID',
                   }
    # Bytes each bin occupies in the hash input.
    width = 2

    def __init__(self, name, description=None):
        self.name = name
        if description is None:
            # Always set a description: RTAG7Hash.hashable() reads it, and it
            # used to be missing entirely when not passed explicitly.
            description = self.descriptions.get(name, name)
        self.description = description

    @classmethod
    def from_string(cls, string):
        """Return the bin named *string*, or raise ParseError listing the
        valid names."""
        if string not in cls.bins:
            raise ParseError('"%s" is not a valid hash bin (%s)' %
                             (string, ','.join(cls.bins)))
        return cls(string, cls.descriptions[string])

    def __str__(self):
        return 'RTAG7HashBin: %s' % (self.name)

    __repr__ = __str__
class RTAG7CRC:
    """Bit-serial 16-bit CRC used to model the RTAG7 hash.

    update() consumes its input starting from the LAST byte, feeding each
    byte least-significant bit first.  ``value`` then shifts in 16 zero bits
    (augmentation) and returns the 16-bit remainder without modifying the
    running state.
    """

    def __init__(self, initial=0, poly=0x1021):
        self.initial = initial
        self.crc = initial
        self.poly = poly

    def update(self, data):
        """Feed *data* into the running CRC.

        *data* may be a str / sequence of 1-character strings (as built on
        Python 2) or bytes / a sequence of ints 0-255 (what struct.pack
        output yields on Python 3).
        """
        for byte in reversed(data):
            # Previously ord() was applied unconditionally, which fails for
            # int elements.
            octet = byte if isinstance(byte, int) else ord(byte)
            for shift in range(8):
                msBit = self.crc & 0x8000
                self.crc = (self.crc << 1) | ((octet >> shift) & 0x01)
                if msBit:
                    self.crc ^= self.poly
            # Bits above 15 never influence later steps, so masking once per
            # byte is sufficient.
            self.crc &= 0xffff

    @property
    def value(self):
        """The CRC of everything fed so far (state is left unchanged)."""
        final = self.crc
        for _ in range(16):
            msBit = final & 0x8000
            final = (final << 1)
            if msBit:
                final ^= self.poly
        return (final ^ 0x0000) & 0xffff

    def reset(self):
        """Restore the running CRC to its initial value."""
        self.crc = self.initial
class RTAG7Hash:
    """Software model of the RTAG7 hash.

    The hash input is a fixed-size buffer: the seed as a 32-bit big-endian
    word followed by one 16-bit big-endian slot per hash bin (slots of
    unselected bins stay zero).  The buffer is run through one of the CRC16
    engines in ``hashFunctions``.
    """
    # Bit position of each bin in the hardware field bitmap; also fixes the
    # bin's slot in the hash input (bit 11 -> first slot after the seed).
    bmapOffsets = {
        'srcH'     : 11,
        'srcL'     : 10,
        'dstH'     : 9,
        'dstL'     : 8,
        'vid'      : 7,
        'sport'    : 6,
        'dport'    : 5,
        'protocol' : 4,
        'sportid'  : 3,
        'smodid'   : 2,
        'dportid'  : 1,
        'dmodid'   : 0,
    }
    # Shared CRC engines; calculate() resets the engine before every use.
    hashFunctions = {
        'crc16-ccitt' : RTAG7CRC(poly=0x1021),
        'crc16-bisync' : RTAG7CRC(poly=0x8005),
    }
    # Hash input size in bytes (224 bits).  Floor division keeps it an int on
    # Python 3, where 224 / 8 == 28.0 would break list repetition.
    hashInputWidth = 224 // 8

    def __init__(self, func, seed, bins):
        """Args:
            func: key of ``hashFunctions``.
            seed: object with ``asInt()`` (e.g. HashSeed).
            bins: iterable of RTAG7HashBin selecting the hashed fields.

        Raises:
            RuntimeError: for an unknown hash function name.
        """
        self.func = func
        self.seed = seed
        self.bins = bins
        if func not in self.hashFunctions:
            raise RuntimeError('no such hash function: %s' % func)
        self.crc = self.hashFunctions[func]
        # _calcbmap() also rejects unknown bin names (KeyError); keep its
        # result rather than discarding it.
        self.bmap = self._calcbmap()

    def _calcbmap(self):
        """Return the field bitmap with the bit of every selected bin set."""
        bmap = 0
        for hbin in self.bins:
            bmap |= (1 << self.bmapOffsets[hbin.name])
        return bmap

    def _assembleHashInput(self, frame):
        """Return the hash input for *frame* as a list of byte elements."""
        hashInput = ['\x00'] * self.hashInputWidth
        hashInput[:4] = struct.pack('!I', self.seed.asInt())
        for hbin in self.bins:
            offset = 4 + (11 - self.bmapOffsets[hbin.name]) * hbin.width
            hashInput[offset:offset + hbin.width] = \
                struct.pack('!H', getattr(frame, hbin.name).asInt())
        return hashInput

    def hashable(self, frame):
        """Return True if *frame* has every field the selected bins need.

        Raises:
            MissingFrameInformation: with the set of missing field
                descriptions.
        """
        missing = set(hbin.description for hbin in self.bins
                      if getattr(frame, hbin.name) is None)
        if missing:
            raise MissingFrameInformation(missing)
        return True

    def calculate(self, frame):
        """Return the 16-bit RTAG7 hash of *frame*."""
        self.hashable(frame)
        hashInput = self._assembleHashInput(frame)
        self.crc.reset()
        self.crc.update(hashInput)
        return self.crc.value
class RTAG7Config:
    """RTAG7 hash configuration read from the switch ASIC registers."""
    chips_supported = ('Trident', 'Trident+', 'Trident2', 'Trident2+')
    # Bit position of each hash bin in the *_FIELD_BITMAP_A register fields.
    bmapOffsets = {
        'srcH'     : 11,
        'srcL'     : 10,
        'dstH'     : 9,
        'dstL'     : 8,
        'vid'      : 7,
        'sport'    : 6,
        'dport'    : 5,
        'protocol' : 4,
        'sportid'  : 3,
        'smodid'   : 2,
        'dportid'  : 1,
        'dmodid'   : 0,
    }
    # HASH_A0_FUNCTION_SELECT register value -> RTAG7Hash function name.
    hashFunctions = { 3 : 'crc16-bisync',
                      9 : 'crc16-ccitt',
                    }

    def __init__(self, bcmshell):
        self.bcm = bcmshell
        self.chip = chipname(self.bcm)
        self._check_chip_id()

    def _check_chip_id(self):
        if self.chip not in self.chips_supported:
            raise NotImplementedError('unsupported platform: %s' % self.chip)

    def load(self):
        """Read the hash registers into attributes named after them."""
        for register in ('hash_control',
                         'rtag7_hash_seed_a',
                         'rtag7_ipv4_tcp_udp_hash_field_bmap_2',
                         'rtag7_ipv4_tcp_udp_hash_field_bmap_1',
                         'rtag7_hash_field_bmap_1',
                         'rtag7_hash_control_3'):
            setattr(self, register, self.bcm.getreg(register, fields=True))
        if self.chip in ('Trident', 'Trident+'):
            for register in ('rtag7_hash_ecmp(0)', 'rtag7_hash_ecmp(1)'):
                setattr(self, register, self.bcm.getreg(register, fields=True))

    def _registersUnderstood(self):
        """Raise NotImplementedError unless the loaded configuration is one
        this model can reproduce."""
        if self.hash_control['ECMP_HASH_USE_RTAG7'] != 1:
            raise NotImplementedError('unsupported hardware configuration, '
                                      'hash_control.ECMP_HASH_USE_RTAG7 != 1')
        funcSelect = self.rtag7_hash_control_3['HASH_A0_FUNCTION_SELECT']
        if funcSelect not in self.hashFunctions:
            # The format string previously had no '%s', so building this
            # message raised TypeError instead of NotImplementedError.
            raise NotImplementedError('unsupported hardware configuration, '
                                      'rtag7_hash_control_3.HASH_A0_FUNCTION_SELECT '
                                      'not in: %s' % sorted(self.hashFunctions))
        if self.chip in ('Trident', 'Trident+'):
            # Sum each register separately: dict views cannot be added on
            # Python 3.
            ecmpTotal = (sum(getattr(self, 'rtag7_hash_ecmp(0)').values()) +
                         sum(getattr(self, 'rtag7_hash_ecmp(1)').values()))
            if ecmpTotal != 0:
                raise NotImplementedError('unsupported hardware configuration, '
                                          'rtag7_hash_ecmp({0,1}) must be zero')

    def getHashSeed(self, frame):
        """Return the HASH_SEED_A register value."""
        self._registersUnderstood()
        return self.rtag7_hash_seed_a['HASH_SEED_A']

    def getHashFunction(self, frame):
        """Return the RTAG7Hash function name selected in hardware."""
        self._registersUnderstood()
        return self.hashFunctions[self.rtag7_hash_control_3['HASH_A0_FUNCTION_SELECT']]

    def getHashFields(self, frame):
        """Return the RTAG7HashBin list hardware uses for *frame*.

        TCP/UDP frames use the TCP/UDP bitmap (a separate one when source
        and destination ports are equal); other frames use the plain
        IPv4 field bitmap.
        """
        self._registersUnderstood()
        if frame.protocol is not None and frame.protocol.isTCPorUDP:
            if frame.sport is None or frame.dport is None:
                # NOTE(review): 'parser' only exists when this file runs as a
                # script; library callers would get NameError here.
                parser.error('--sport and --dport required for TCP and UDP frames')
            if frame.sport.value == frame.dport.value:
                bmap = self.rtag7_ipv4_tcp_udp_hash_field_bmap_1[
                       'IPV4_TCP_UDP_SRC_EQ_DST_FIELD_BITMAP_A']
            else:
                bmap = self.rtag7_ipv4_tcp_udp_hash_field_bmap_2[
                       'IPV4_TCP_UDP_FIELD_BITMAP_A']
        else:
            bmap = self.rtag7_hash_field_bmap_1['IPV4_FIELD_BITMAP_A']
        hashFields = []
        for field, offset in self.bmapOffsets.items():
            if bmap & (1 << offset):
                hashFields.append(RTAG7HashBin.from_string(field))
        return hashFields
class EgressObject:
    """A route prefix paired with the egress object ID it points to.

    Despite its name, ``ipv4net`` can also hold an IPv6 network
    (EgressTable builds these from both the IPv4 and IPv6 route tables).
    """

    def __init__(self, ipv4net, egressId):
        self.ipv4net, self.egressId = ipv4net, egressId
class EgressTable:
    """Looks up, in the ASIC route tables, the egress object for a destination."""

    def __init__(self, bcmshell, verbose=False):
        self.bcm = bcmshell
        self.verbose = verbose

    def findEgressObjectByDst(self, dst):
        """Return the EgressObject of the longest prefix containing *dst*.

        Args:
            dst: an ``ipaddr`` address (its ``version`` selects the IPv4 or
                IPv6 route table).

        Raises:
            RuntimeError: for an IP version other than 4/6, or when no route
                contains *dst*.
        """
        # '==' not 'is': int identity is an interpreter caching detail.
        if dst.version == 4:
            table = self.bcm.run('l3 defip show')
        elif dst.version == 6:
            table = self.bcm.run('l3 ip6route show')
        else:
            raise RuntimeError('Error: invalid IP version: %s' % dst)
        intf = 0
        net = None
        for line in table.split('\n'):
            try:
                # Column 3 is the prefix and column 5 the egress object id;
                # lines that do not parse (e.g. headers) are skipped.
                (_, _, netaddr, _, intfstr, _) = line.split(None, 5)
                new_net = ipaddr.IPNetwork(netaddr)
                new_intf = int(intfstr)
            except ValueError:
                continue
            # All candidate prefixes contain dst, so they are nested and the
            # longest one is the most specific route.
            if dst in new_net and (intf == 0 or new_net.prefixlen > net.prefixlen):
                net = new_net
                intf = new_intf
        if intf == 0:
            raise RuntimeError('Error: no egress object for destination: %s' % dst)
        return EgressObject(net, intf)
class ECMPConfig:
    """Walks the ASIC ECMP/next-hop tables to find where a hash lands."""
    # Egress object IDs in [egressObjectOffset, egressObjectMax] are ECMP
    # groups; the L3_ECMP_GROUP index is the ID minus egressObjectOffset.
    egressObjectOffset = 200000
    egressObjectMax = 204095
    chips_supported = ('Trident', 'Trident+', 'Trident2', 'Trident2+')

    def __init__(self, bcmshell, verbose=False):
        self.bcm = bcmshell
        self.verbose = verbose
        self.chip = chipname(self.bcm)
        self._check_chip_id()

    def _check_chip_id(self):
        if self.chip not in self.chips_supported:
            raise NotImplementedError('unsupported platform: %s' % self.chip)

    @classmethod
    def objectIsECMP(cls, egressId):
        """Return True when *egressId* lies in the ECMP group ID range."""
        return cls.egressObjectOffset <= egressId <= cls.egressObjectMax

    def getECMPGroup(self, egressId):
        """Return the L3_ECMP_GROUP entry for ECMP egress object *egressId*.

        Raises:
            RuntimeError: if *egressId* is outside the ECMP range.
        """
        if not self.objectIsECMP(egressId):
            raise RuntimeError('egress object ID expected to be between %d and %d' %
                               (self.egressObjectOffset, self.egressObjectMax))
        groupOffset = egressId - self.egressObjectOffset
        group = self.bcm.gettable('L3_ECMP_GROUP',
                                  fields=True,
                                  start=groupOffset,
                                  entries=1)[0]
        if self.verbose:
            sys.stderr.write('ecmp group offset %d\n' % groupOffset)
        return group

    def pageLookup(self, egressId, pageNum):
        """Return (base pointer, member count) of the group for *egressId*.

        Trident has per-page BASE_PTR_<n>/COUNT_<n> fields selected by
        *pageNum*; later chips have a single BASE_PTR/COUNT pair.  The
        register count is one less than the member count, hence the +1.
        """
        group = self.getECMPGroup(egressId)
        # Compare chip names by value; 'is' only worked when the string
        # happened to be the same interned object.
        if self.chip == 'Trident':
            (base, count) = (group['BASE_PTR_%d' % pageNum],
                             group['COUNT_%d' % pageNum] + 1)
        elif self.chip in ('Trident+', 'Trident2', 'Trident2+'):
            (base, count) = (group['BASE_PTR'],
                             group['COUNT'] + 1)
        else:
            raise UnsupportedPlatformError('ecmp page is not available on %s' % self.chip)
        if self.verbose:
            sys.stderr.write('base,count: %d,%d\n' % (base, count))
        return (base, count)

    def getEcmpMask(self):
        """Return the mask applied to the hash before the modulo.

        Trident uses a fixed 0x3ff; later chips derive it from
        hash_control.ECMP_HASH_FIELD_UPPER_BITS_COUNT.
        """
        mask = {0: 0x3ff,
                1: 0x7ff,
                2: 0xfff,
                3: 0x1fff,
                4: 0x3fff,
                5: 0x7fff,
                6: 0xffff,
                7: 0}
        if self.chip == 'Trident':
            ecmpMask = 0x3ff
        elif self.chip in ('Trident+', 'Trident2', 'Trident2+'):
            hash_control = self.bcm.getreg('hash_control', fields=True)
            ecmpMask = mask[hash_control['ECMP_HASH_FIELD_UPPER_BITS_COUNT']]
        else:
            raise UnsupportedPlatformError('ecmp mask is not available on %s' % self.chip)
        return ecmpMask

    def getECMP(self, egressId, frameHash):
        """Return the L3_ECMP member entry selected by *frameHash*.

        Hash bits 10-11 pick the page (used on Trident); the member is
        (frameHash & mask) % count entries past the group's base pointer.
        """
        ecmpPage = (frameHash & 0x0c00) >> 10
        (ecmpBase, ecmpCount) = self.pageLookup(egressId, ecmpPage)
        ecmpMask = self.getEcmpMask()
        ecmpOffset = (frameHash & ecmpMask) % ecmpCount
        ecmpPtr = ecmpBase + ecmpOffset
        ecmp = self.bcm.gettable('L3_ECMP',
                                 fields=True,
                                 start=ecmpPtr,
                                 entries=1)[0]
        if self.verbose:
            sys.stderr.write('ecmpMask 0x%x L3_ECMP+%d\n' % (ecmpMask, ecmpPtr))
        return ecmp

    def getNextHop(self, egressId, frameHash):
        """Return the ING_L3_NEXT_HOP entry of the selected ECMP member."""
        ecmp = self.getECMP(egressId, frameHash)
        nextHop = self.bcm.gettable('ING_L3_NEXT_HOP',
                                    fields=True,
                                    start=ecmp['NEXT_HOP_INDEX'],
                                    entries=1)[0]
        if self.verbose:
            sys.stderr.write('ING_L3_NEXT_HOP+%d\n' % ecmp['NEXT_HOP_INDEX'])
        return nextHop

    def _nextHopToInterface(self, nextHop):
        """Map an ING_L3_NEXT_HOP entry to (isTrunkMbr, ASIC interface).

        For a trunk next hop ('T' set) the TRUNK_MEMBER entry at the trunk
        group's BASE_PTR is used.
        """
        if nextHop['T']:
            trunk = self.bcm.gettable('TRUNK_GROUP',
                                      fields=True,
                                      start=nextHop['TGID'],
                                      entries=1)[0]
            trunkMbr = self.bcm.gettable('TRUNK_MEMBER',
                                         fields=True,
                                         start=trunk['BASE_PTR'],
                                         entries=1)[0]
            return True, modport2interface(trunkMbr['MODULE_ID'],
                                           trunkMbr['PORT_NUM'], self.chip)
        return False, modport2interface(nextHop['MODULE_ID'],
                                        nextHop['PORT_NUM'], self.chip)

    def getPort(self, egressId, frameHash):
        """Return (isTrunkMbr, ASIC interface) for *frameHash* on *egressId*."""
        return self._nextHopToInterface(self.getNextHop(egressId, frameHash))

    def getRhPort(self, rh_cmd):
        """Run *rh_cmd* and resolve the next hop it reports to
        (isTrunkMbr, ASIC interface).

        The next-hop object ID is taken from the first column of the last
        output line whose second token matches 'destination'.
        """
        NHId = 0
        NHObjectOffset = 100000
        table = self.bcm.run(rh_cmd)
        for line in table.splitlines():
            tokens = line.strip().split(' ')
            if len(tokens) == 1:
                continue
            # NOTE(review): substring test, so any token contained in
            # 'destination' (e.g. 'in') also matches — confirm intent.
            if tokens[1] in 'destination':
                NHId = int(tokens[0])
        if self.verbose:
            sys.stderr.write('rh_cmd %s\n' % rh_cmd)
            sys.stderr.write('NextHopId %d NhIdOffset %d\n' % (NHId, NHObjectOffset))
        if NHId < NHObjectOffset:
            # NOTE(review): 'parser' only exists when run as a script.
            parser.error('Invalid parameter: unable to determine destination port')
        NHIdOffset = (NHId - NHObjectOffset)
        nextHop = self.bcm.gettable('ING_L3_NEXT_HOP',
                                    fields=True,
                                    start=NHIdOffset,
                                    entries=1)[0]
        if self.verbose:
            sys.stderr.write('NextHopId %d Port_num %d\n' %
                             (NHId, nextHop['PORT_NUM']))
        return self._nextHopToInterface(nextHop)

    def getLinuxIntf(self, isTrunkMbr, port):
        """Translate an ASIC interface name to its Linux name."""
        if self.verbose:
            sys.stderr.write('bcm interface %s\n' % port)
        return asictrunkmbr2linux(port) if isTrunkMbr else asic2linux(port)

    def getOffset(self, mcount, frameHash):
        """Return the member index for *frameHash* in a group of *mcount*
        members, using the low 10 hash bits."""
        return (frameHash & 0x3ff) % mcount
class VisibilityPktTrace:
    """Injects a visibility (trace) packet for a frame and records the result.

    The trace runs in the constructor.  getL3PktRes() then reports whether
    the packet was L3-switched and whether an ECMP egress was seen, and
    getOutLinuxIntf() resolves that egress to a Linux interface.
    """

    def __init__(self, bcmshell, frame, verbose=False):
        self.bcm = bcmshell
        self.verbose = verbose
        self.pt = cumulus.porttab.porttab()
        # The packet is injected with the frame's source port as the
        # visibility source port.
        self.logical_port = frame.sportid.value
        self.dport_name = self.pt.logical2dport(int(self.logical_port))
        self.knownL3UcPkt = False
        # -1 means "no ECMP egress found in the trace output".
        self.egressId = -1
        self._runTrace(frame)

    def _getPortVid(self):
        """Return the PORT_VID of the source port from the PORT table."""
        portEnt = self.bcm.gettable('PORT',
                                    fields=True,
                                    start=self.logical_port,
                                    entries=1)[0]
        pvid = portEnt['PORT_VID']
        if self.verbose:
            sys.stderr.write('logical_port %s has pvid %d\n' % (self.logical_port, pvid))
        return pvid

    def _getPortMac(self):
        """Return the station MAC (the "Stad(...)" value) of the source port.

        Raises:
            RuntimeError: if the port output contains no Stad(...) value
                (previously an AttributeError on a failed regex match).
        """
        port_str = self.bcm.run('port %s' % self.dport_name)
        obj = re.search(r'Stad\((?P<pmac>\S+)\)', port_str)
        if obj is None:
            raise RuntimeError('Error: unable to determine the MAC address of port %s'
                               % self.dport_name)
        pmac = obj.group('pmac')
        if self.verbose:
            sys.stderr.write('port %s has pmac %s\n' % (self.dport_name, pmac))
        return pmac

    def _runTrace(self, frame):
        """Build and transmit the trace packet, then parse the trace output.

        Raises:
            RuntimeError: if the packet cannot be built, or the reported
                egress id is below the SDK offset.
        """
        pmac = self._getPortMac()
        pvid = self._getPortVid()
        # The port's own MAC is used as both source and destination.
        pkt = frame.buildVisibilityPkt(smac=pmac, dmac=pmac, vid=pvid)
        if not pkt:
            raise RuntimeError('Error: visibility pkt trace generation failed')
        if self.verbose:
            sys.stderr.write('Trace pkt:\n  %s\n' % pkt)
        trace_str = self.bcm.run('tx 1 VisibilitySourcePort=%s DATA=%s' % (self.dport_name, pkt))
        if 'KnownL3UcPkt' not in trace_str:
            return
        self.knownL3UcPkt = True
        if self.verbose:
            sys.stderr.write('Trace packet was L3 switched\n')
        obj = re.search(r'ecmp_1_egress (?P<egressId>\d+)', trace_str)
        if obj:
            self.egressId = int(obj.group('egressId'))
            if self.verbose:
                sys.stderr.write('Trace packet was switched via egress %d\n' % self.egressId)
            # Strip the SDK's egress object offset to get the
            # ING_L3_NEXT_HOP index used by getOutLinuxIntf().
            sdk_offset = 100000
            if self.egressId < sdk_offset:
                raise RuntimeError('Error: egress id %d format is unexpected' % self.egressId)
            self.egressId = self.egressId - sdk_offset

    def getL3PktRes(self):
        """Return (packet was L3 switched, an ECMP egress was found)."""
        return self.knownL3UcPkt, self.egressId != -1

    def getOutLinuxIntf(self):
        """Resolve the traced egress (ING_L3_NEXT_HOP index) to a Linux
        interface name; for a trunk, the member at the group's BASE_PTR is
        used."""
        nextHop = self.bcm.gettable('ING_L3_NEXT_HOP',
                                    fields=True,
                                    start=self.egressId,
                                    entries=1)[0]
        if nextHop['T']:
            trunk = self.bcm.gettable('TRUNK_GROUP',
                                      fields=True,
                                      start=nextHop['TGID'],
                                      entries=1)[0]
            trunkMbr = self.bcm.gettable('TRUNK_MEMBER',
                                         fields=True,
                                         start=trunk['BASE_PTR'],
                                         entries=1)[0]
            port = trunkMbr['PORT_NUM']
            isTrunkMbr = True
        else:
            port = nextHop['PORT_NUM']
            isTrunkMbr = False
        port = self.pt.logical2dport(port)
        return asictrunkmbr2linux(port) if isTrunkMbr else asic2linux(port)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Determine the ECMP egress interface for a frame')
    parser.add_argument('-v', '--verbose',
                        required=False,
                        action='store_true',
                        help='Verbose output')
    parser.add_argument('-p', '--protocol',
                        required=False,
                        type=ProtocolID.from_string,
                        help='IP protocol ("tcp", "udp", or integer)')
    parser.add_argument('-s', '--src',
                        required=False,
                        type=IPAddress.from_string,
                        help='Source IPv4 or IPv6 address')
    parser.add_argument('--sport',
                        required=False,
                        type=L4Port.from_string,
                        help='Source L4 port')
    parser.add_argument('-d', '--dst',
                        required=False,
                        type=IPAddress.from_string,
                        help='Destination IPv4 or IPv6 address')
    parser.add_argument('--dport',
                        required=False,
                        type=L4Port.from_string,
                        help='Destination L4 port')
    parser.add_argument('--vid',
                        required=False,
                        type=VLANID.from_string,
                        help='VLAN ID')
    parser.add_argument('-i', '--in-interface',
                        required=False,
                        type=LinuxInterface.from_string,
                        help='Input interface')
    parser.add_argument('--sportid',
                        required=False,
                        type=HardwarePort.from_string,
                        help='Hardware source port ID')
    parser.add_argument('--smodid',
                        required=False,
                        type=HardwareModule.from_string,
                        help='Hardware source module ID')
    parser.add_argument('-o', '--out-interface',
                        required=False,
                        type=LinuxInterface.from_string,
                        help='Output interface')
    parser.add_argument('--dportid',
                        required=False,
                        type=HardwarePort.from_string,
                        help='Hardware destination port ID')
    parser.add_argument('--dmodid',
                        required=False,
                        type=HardwareModule.from_string,
                        help='Hardware destination module ID')
    parser.add_argument('--hardware',
                        required=False,
                        default=True,
                        action='store_true',
                        help='Force hardware query, ecmpcalc will read the '
                             'hash configuration from hardware')
    parser.add_argument('--nohardware',
                        required=False,
                        dest='hardware',
                        action='store_false',
                        help='Force no hardware query')
    parser.add_argument('-hs', '--hashseed',
                        required=False,
                        type=HashSeed.from_string,
                        help='RTAG7 hash seed')
    parser.add_argument('-hf', '--hashfields',
                        required=False,
                        nargs='+',
                        type=RTAG7HashBin.from_string,
                        help='list of RTAG7 hash fields, one or more of (%s)' %
                             ', '.join(RTAG7HashBin.bins))
    parser.add_argument('--hashfunction',
                        required=False,
                        choices=('crc16-ccitt', 'crc16-bisync'),
                        help='RTAG7 hash function')
    parser.add_argument('-e', '--egress',
                        required=False,
                        type=int,
                        help='multipath egress object as from "/usr/lib/cumulus/bcmcmd l3 defip show"')
    parser.add_argument('-c', '--mcount',
                        required=False,
                        type=int,
                        help='ECMP group member count')
    try:
        args = parser.parse_args()
    except ParseError, e:
        parser.error(str(e))
    if (os.geteuid() != 0):
        sys.stderr.write('root privileges are needed to run ecmpcalc\n')
        sys.exit(-1)
    platform_object = cumulus.platforms.probe()
    if platform_object.switch.chip.sw_base == 'mlx':
        mlx = True
        try:
            import cumulus.mlx_ecmpcalc
            asicShellAvailable = True
        except ImportError:
            asicShellAvailable = False
    else:
        mlx = False
        try:
            import bcmshell
            asicShellAvailable = True
        except ImportError:
            asicShellAvailable = False
    if args.hardware:
        if not asicShellAvailable:
            sys.stderr.write('ecmpcalc: asicTool not available, unable to query hardware\n')
            sys.exit(-1)
    if mlx:
        chip = "Spectrum"
        if not args.hardware:
            sys.stderr.write('ecmpcalc: only hardware mode supported on mellanox platforms\n')
            sys.exit(-1)
        if not args.in_interface:
            sys.stderr.write('ecmpcalc: --in-interface is required on mellanox platforms\n')
            sys.exit(-1)
    else:
        if args.hardware:
            try:
                bcmshell.bcmshell().run('echo foo')
            except IOError as e:
                sys.stderr.write(('%s\n' % e))
                sys.exit(-1)
            (rv, chip) = is_chip_supported(bcmshell.bcmshell())
            if rv == 0:
                sys.stderr.write('ecmpcalc is not supported on this chip (%s)\n' % chip)
                sys.exit(-1)
    if args.hardware:
        sys.stderr.write('ecmpcalc: will query hardware\n')
    else:
        sys.stderr.write('ecmpcalc: will NOT query hardware\n')
    if args.in_interface and (args.sportid is not None or
                              args.smodid is not None):
        parser.error('require only one of --in-interface OR (--sportid/--smodid)')
    if args.out_interface and (args.dportid is not None or
                               args.dmodid is not None):
        parser.error('require only one of --out-interface OR (--dportid/--dmodid)')
    if args.hardware is False and (args.in_interface is not None or
                                   args.out_interface is not None):
        parser.error('unable to query hardware for interface translation')
    if args.in_interface is not None:
        (mod, port) = linux2modport(args.in_interface, chip)
        args.smodid = HardwareModule.from_string(str(mod))
        args.sportid = HardwareModule.from_string(str(port))
    if args.out_interface is not None:
        (mod, port) = linux2modport(args.out_interface, chip)
        args.dmodid = HardwareModule.from_string(str(mod))
        args.dportid = HardwareModule.from_string(str(port))
    frame = Frame(args.src, args.dst, args.vid, args.sport, args.dport,
                  args.protocol, args.sportid, args.smodid, args.dportid,
                  args.dmodid)
    if args.verbose:
        sys.stderr.write('frame: %s\n' % frame)
    if args.hardware and is_chip_pkt_trace_supported(chip):
        if not args.sportid:
            parser.error('require --in-interface OR (--sportid/--smodid)')
        missing = frame.visibilityPktMissingFields()
        if missing:
            parser.error('Frame information is incomplete.\n'
                     'Please specify additional options for: %s' % ', '.join(missing))
        if args.verbose:
            sys.stderr.write('using packet trace...\n')
        if mlx:
            linux_intf = cumulus.mlx_ecmpcalc.getEcmpMbrLinuxIntf(args.in_interface, frame, args.verbose)
        else:
            pktTrace = VisibilityPktTrace(bcmshell.bcmshell(), frame,
                                  verbose=args.verbose)
            isL3, isEcmp = pktTrace.getL3PktRes()
            if not isL3:
                raise RuntimeError('Error: traffic to %s will not be L3 switched' % args.dst.addrstr)
            if not isEcmp:
                raise RuntimeError('Error: traffic to %s will not ECMP' % args.dst.addrstr)
            linux_intf = pktTrace.getOutLinuxIntf()
        sys.stdout.write('%s\n' % linux_intf)
        sys.exit(0)
    if mlx:
        raise UnsupportedPlatformError('ecmpcalc is not available on %s' % chip)
    if args.hardware:
        rtag7config = RTAG7Config(bcmshell.bcmshell())
        rtag7config.load()
        if args.hashseed is None:
            setattr(args, 'hashseed', HashSeed.from_int(rtag7config.getHashSeed(frame)))
            if args.verbose:
                sys.stderr.write('hardware hashseed    : 0x%08x\n' % args.hashseed.value)
        if args.hashfunction is None:
            setattr(args, 'hashfunction', rtag7config.getHashFunction(frame))
            if args.verbose:
                sys.stderr.write('hardware hashfunction: %s\n' % args.hashfunction)
        if args.hashfields is None:
            setattr(args, 'hashfields', rtag7config.getHashFields(frame))
            if args.verbose:
                sys.stderr.write('hardware hashfields  : %s\n' % args.hashfields)
    if args.hashseed is None:
        parser.error('--hashseed required')
    if args.hashfunction is None:
        parser.error('--hashfunction required')
    if args.hashfields is None:
        args.hashfields = []
    if args.protocol is not None and (args.protocol.value > 255 or
                                     args.protocol.value < 0):
        parser.error('Invalid protocol value %d\n' % (args.protocol.value))
    rtag7hash = RTAG7Hash(args.hashfunction, args.hashseed, args.hashfields)
    try:
        frameHash = rtag7hash.calculate(frame)
    except MissingFrameInformation, missing:
        parser.error('unable to calculate hash.  Frame information is incomplete.\n'
                     'Please specify additional options for: %s' % ', '.join(*missing))
    if args.egress and not asicShellAvailable:
        parser.error('unable to load multipath egress object from hardware')
    if args.dst is None and args.egress is None and args.mcount is None:
        parser.error('need a destination IP OR multipath egress object OR a '
                     'ECMP group member count')
    rhEnChipset = False
    if chip in ('Trident2', 'Trident2+') :
        rh_enable_path = '/cumulus/switchd/config/traffic/resilient_hash_enable'
        if os.path.exists(rh_enable_path):
            f = open(rh_enable_path)
            rh_enable = f.read()
            if 'TRUE' in rh_enable:
                rhEnChipset = True
            f.close()
    if (rhEnChipset and (args.src is None or args.dst is None or
            args.sport is None or args.dport is None or args.protocol is None
            or args.hardware is None or args.in_interface is None)):
        parser.error('please specify following options for'
                ' resilient hashing: -s, -d, --sport, --dport, -p, -i, '
                '--hardware')
    if args.egress is not None and not rhEnChipset:
        ecmp = ECMPConfig(bcmshell.bcmshell(), verbose=args.verbose)
        isTrunkMbr, port = ecmp.getPort(args.egress, frameHash)
        linux_intf = ecmp.getLinuxIntf(isTrunkMbr, port)
        sys.stdout.write('%s\n' % linux_intf)
    elif args.mcount is not None and not rhEnChipset:
        ecmp = ECMPConfig(None, verbose=args.verbose)
        ecmpOffset = ecmp.getOffset(args.mcount, frameHash)
        sys.stdout.write('%d\n' % ecmpOffset)
    elif args.dst is not None:
        egress = EgressTable(bcmshell.bcmshell(timeout=60), verbose=args.verbose)
        dstAddr = ipaddr.IPAddress(args.dst.addrstr)
        egressObj = egress.findEgressObjectByDst(dstAddr)
        ecmp = ECMPConfig(bcmshell.bcmshell(), verbose=args.verbose)
        if ecmp.objectIsECMP(egressObj.egressId) is False:
            raise RuntimeError('Error: traffic to %s will not egress ECMP' % args.dst.addrstr)
        if rhEnChipset:
            if dstAddr.version is 4:
                ethertype = 0x0800
                rh_cmd = ('hd dest get Group=ECMP GID=%s Port=%s SrcIp=%s '
                        'DestIp=%s L4SrcPort=%s L4DstPort=%s Ethertype=%s '
                        'Protocol=%s' % (egressObj.egressId, args.sportid.value,
                            args.src.addrstr, args.dst.addrstr, args.sport.value,
                            args.dport.value, ethertype, args.protocol.value))
            elif dstAddr.version is 6:
                ethertype = 0x86dd
                srcNet = ipaddr.IPNetwork(args.src.addrstr)
                dstNet = ipaddr.IPNetwork(args.dst.addrstr)
                rh_cmd = ('hd dest get Group=ECMP GID=%s Port=%s SIp6=%s '
                        'DIp6=%s L4SrcPort=%s L4DstPort=%s Ethertype=%s '
                        'Protocol=%s' % (egressObj.egressId, args.sportid.value,
                            srcNet.ip.exploded, dstNet.ip.exploded, args.sport.value,
                            args.dport.value, ethertype, args.protocol.value))
            isTrunkMbr, port = ecmp.getRhPort(rh_cmd)
        else:
            isTrunkMbr, port = ecmp.getPort(egressObj.egressId, frameHash)
        linux_intf = ecmp.getLinuxIntf(isTrunkMbr, port)
        sys.stdout.write('%s\n' % linux_intf)

