#! /usr/bin/python
# Copyright 2012 Cumulus Networks LLC, all rights reserved

#############################################################################
#
# This is the main script that handles eeprom encoding and decoding
#
try:
    import binascii
    import json
    import optparse
    import warnings
    import os
    import subprocess
    import sys
    import imp
    import eeprom_dts
    import glob
except ImportError as e:
    raise ImportError(str(e) + "- required module not found")


platform_detect = '/usr/bin/platform-detect'
platform_root = '/usr/share/cumulus-platform'
cache_root = '/run/cumulus/decode-syseeprom'


def main():
    """Detect the platform, load its eeprom module and run the requested command.

    Steps:
      1. Run ``platform-detect`` and parse its "<platform>,<model>" output.
      2. Load ``<platform_root>/<platform>/<model>/bin/eeprom.py``.
      3. Discover the eeproms present on the system.
      4. Parse the command line and instantiate the eeprom class for the
         selected target.
      5. Hand off to run().

    Returns:
        Whatever run() returns (0 on normal completion).

    Raises:
        OSError: platform-detect could not be executed.
        IOError: the platform eeprom module could not be loaded.
        RuntimeError: platform-detect output could not be parsed, the target
            is unknown, or the platform module lacks the target's class.
    """
    #
    # determine the platform
    #
    try:
        ph = subprocess.Popen([platform_detect], stdout=subprocess.PIPE,
                              shell=False, stderr=subprocess.STDOUT)
        # communicate() already waits for the process to exit.
        cmdout = ph.communicate()[0]
    except OSError:
        raise OSError("cannot detect platform")

    # stderr is merged into stdout, so an error message from platform-detect
    # would otherwise surface as an opaque unpacking ValueError.
    fields = cmdout.strip().split(',')
    if len(fields) != 2:
        raise RuntimeError("cannot parse output of %s: %r"
                           % (platform_detect, cmdout))
    platform, model = fields
    platform_path = '/'.join([platform_root, platform, model])

    #
    # load the target class file
    #
    module_path = '/'.join([platform_path, 'bin/eeprom.py'])
    try:
        m = imp.load_source('eeprom', module_path)
    except IOError:
        raise IOError("cannot load module: " + module_path)

    #
    # discover the eeproms and the paths
    # mapping: {target: (class name, content path, status path, read-only)}
    #
    eeproms = discover_eeprom(m)
    if eeproms is None:
        sys.stderr.write("decode-syseeprom: Unable to query EEPROMS\n")
        sys.exit(0)

    (opts, args) = get_cmdline_opts(eeproms, 'board')
    if opts.target not in eeproms:
        errstr = "unknown target, should be one of %s" \
                 % (', '.join(reversed(list(eeproms.keys()))))
        raise RuntimeError(errstr)

    class_name, content_path, status_path, read_only = eeproms[opts.target]
    class_ = getattr(m, class_name, None)
    if class_ is None:
        raise RuntimeError("platform eeprom module has no class %r"
                           % class_name)
    t = class_(opts.target, content_path, status_path, read_only)

    #
    # execute the command
    #
    return run(t, opts, args)


# -------------------------------------------------------------------------------
#
# discover eeprom paths etc.
#
def discover_eeprom(m):
    """Collect every eeprom on this system into one mapping.

    We have two methods for discovering eeproms.  The new style has the
    kernel promote the .dts label property to a sysfs file for each eeprom
    and create an eeprom_dev sysfs class to enumerate them
    (discover_eeprom_class).  The fallback scans MTD partitions
    (discover_eeprom_mtd).  We try the new way first and fall back to the
    MTD scan when the class is unavailable or reports no 'board' eeprom.
    Entries from the platform module's eeprom_sysfs_map are added last.

    Args:
        m: the platform-specific eeprom module, or None.

    Returns:
        dict mapping target name to (class name, content path, status path,
        read-only flag), or None if no discovery method found anything and
        the eeprom_dev class is unavailable.
    """
    eeprom = discover_eeprom_class()
    if eeprom is None or 'board' not in eeprom:
        eeprom_mtd = discover_eeprom_mtd()
        if eeprom is None:
            # No eeprom_dev class support: the MTD scan is the only source.
            # Keep None (nothing queryable) if it found nothing either.
            eeprom = eeprom_mtd or None
        else:
            eeprom.update(eeprom_mtd)

    if m is not None:
        eeprom_sysfs = discover_sysfs_devices(m)
        if eeprom_sysfs:
            if eeprom is None:
                eeprom = {}
            eeprom.update(eeprom_sysfs)

    return eeprom


def discover_sysfs_devices(m):
    """Return eeproms declared by the platform module's eeprom_sysfs_map.

    Elements in the dictionary must be in the following format:
        Format: {'name' : ['class_name', 'sysfs_path'], }
        E.g.  : {'psu1' : ['psu', '/sys/devices/platform/vendor_driver/'], }

    Args:
        m: the platform-specific eeprom module (any object); it is used only
           if it has a non-empty ``eeprom_sysfs_map`` attribute.

    Returns:
        dict mapping name to (class name, sysfs path, '', True).  These
        devices are always reported as read-only with no status path.
    """
    devices = {}
    s_dict = getattr(m, 'eeprom_sysfs_map', None)
    if not s_dict:
        return devices

    # .items() works on both Python 2 and 3 (iteritems() is Python 2 only).
    for name, s_list in s_dict.items():
        devices[name] = (s_list[0], s_list[1], '', True)
    return devices


def discover_eeprom_class(eeprom_dev='/sys/class/eeprom_dev'):
    """Enumerate eeproms through the kernel's eeprom_dev sysfs class.

    Each subdirectory of *eeprom_dev* that has a ``label`` file whose
    (whitespace-stripped) content starts with psu, fan, board or cpu
    produces one entry.  The key is the label up to its first '_' (e.g.
    "psu2_eeprom" -> "psu2").  The class name is that key with digits
    removed ("psu").  The content path is ``<dir>/device/eeprom``.

    Args:
        eeprom_dev: sysfs class directory to scan (overridable for testing).

    Returns:
        None if *eeprom_dev* is not a directory (no kernel support or module
        not loaded).  Otherwise a dict mapping key to
        (class name, content path, '', read-only flag), where read-only means
        the content path is not writable by this process.
    """
    cpld_path = ''
    eeprom = {}

    # eeprom_dev doesn't exist, no kernel support or module not loaded
    if not os.path.isdir(eeprom_dev):
        return None

    for eeprom_dir in os.listdir(eeprom_dev):
        label_path = os.path.join(eeprom_dev, eeprom_dir, 'label')
        if not os.path.isfile(label_path):
            continue

        with open(label_path, 'r') as label_file:
            label = label_file.read().strip()

        if not label.startswith(('psu', 'fan', 'board', 'cpu')):
            continue

        eeprom_path = os.path.join(eeprom_dev, eeprom_dir, 'device', 'eeprom')
        eeprom_key = label.split('_')[0]
        eeprom_class = ''.join(ch for ch in eeprom_key if not ch.isdigit())
        read_only = not os.access(eeprom_path, os.W_OK)
        eeprom[eeprom_key] = (eeprom_class, eeprom_path, cpld_path, read_only)

    return eeprom


def discover_eeprom_mtd(mtd_info='/proc/mtd', mtd_root='/dev'):
    """Find board eeproms exposed as MTD partitions.

    *mtd_info* is parsed line by line.  Lines are expected to look like
    ``mtd1: 00010000 00010000 "board_eeprom"``.  Only lines mentioning
    "eeprom" (case-insensitive) whose quoted label starts with "board" are
    used.  Every such line is considered, not just the last one.

    Args:
        mtd_info: partition table file to read (overridable for testing).
        mtd_root: directory containing the mtd device nodes.

    Returns:
        dict mapping key (label up to its first '_') to
        (class name, ``<mtd_root>/<partition>``, '', read-only flag).
        An empty dict if *mtd_info* cannot be read or lists no board eeprom.
    """
    cpld_path = ''
    eeprom = {}

    try:
        with open(mtd_info, 'r') as mtd_file:
            lines = mtd_file.readlines()
    except (IOError, OSError):
        # No MTD partition table on this system: nothing to discover.
        return eeprom

    for line in lines:
        # Same filter the original `grep -i eeprom` applied.
        if 'eeprom' not in line.lower():
            continue
        # Split into at most 4 fields so a label containing spaces stays whole.
        fields = line.strip().split(None, 3)
        if len(fields) != 4:
            continue
        partition = fields[0].split(':')[0]
        label = fields[3].strip('"')
        if not label.startswith('board'):
            continue

        eeprom_path = os.path.join(mtd_root, partition)
        eeprom_key = label.split('_')[0]
        eeprom_class = ''.join(ch for ch in eeprom_key if not ch.isdigit())
        read_only = not os.access(eeprom_path, os.W_OK)
        eeprom[eeprom_key] = (eeprom_class, eeprom_path, cpld_path, read_only)

    return eeprom

# -------------------------------------------------------------------------------
#
# parse the command line; returns the (options, args) pair from optparse
#
def get_cmdline_opts(dict, default_dev, argv=None):
    """Build the option parser and parse the command line.

    Args:
        dict: discovered eeproms keyed by target name; the keys are listed in
            the -t help text.  (The name shadows the builtin but is kept so
            keyword callers keep working.)
        default_dev: target used when -t is not given.
        argv: argument list to parse; None (the default) means sys.argv[1:].

    Returns:
        The (options, args) pair from optparse.OptionParser.parse_args().
    """
    target_str = ("select the target eeprom (" +
                  ', '.join(reversed(list(dict.keys()))) +
                  ") for read or write operation, default is " +
                  default_dev)
    optcfg = optparse.OptionParser(
        usage="usage: %s [-a][-r][-s [args]][-t <target>][-e][-m][-j][--init]"
              % sys.argv[0])
    optcfg.add_option("-a", dest="startmac", action="store_true", default=False,
                      help="print the base mac address for switch interfaces")
    optcfg.add_option("-r", dest="macrange", action="store_true", default=False,
                      help="print the number of macs allocated for switch interfaces")
    optcfg.add_option("-s", dest="set", action="store_true", default=False,
                      help="set the eeprom content if the eeprom is writable. args "
                           "can be supplied in command line in a comma separated "
                           "list of the form '<field>=<value>, ...'.  ',' and '=' "
                           "are illegal characters in field names and values. Fields "
                           "that are not specified will default to their current "
                           "values.  If args are supplied in command line, they will "
                           "be written without confirmation.  If args is empty, the "
                           "values will be prompted interactively.")
    optcfg.add_option("-j", "--json", action="store_true", default=False,
                      help="Display JSON output")
    optcfg.add_option("-t", dest="target", action="store", type="string",
                      default=default_dev, help=target_str)
    optcfg.add_option("-e", "--serial", dest="serial", action="store_true",
                      default=False, help="print device serial number")
    optcfg.add_option("-m", dest="mgmtmac", action="store_true", default=False,
                      help="print the base mac address for management interfaces")
    optcfg.add_option("--init", dest="init", action="store_true", default=False,
                      help="clear and initialize board eeprom cache")
    return optcfg.parse_args(argv)

# -------------------------------------------------------------------------------
#
# Run
#


def run(target, opts, args):
    """Execute the command selected by *opts* against the eeprom *target*.

    Flow:
      1. Check the device status.
      2. Prepare the cache directory, and clear it for --init.
      3. Read the eeprom and update the cache.
      4. Run exactly one action.  Checked in order: -a, -m, -r, -e, -s, -j,
         otherwise decode and verify the checksum.

    Args:
        target: eeprom object from the platform module.
        opts: parsed options from get_cmdline_opts().
        args: positional arguments; for -s, the '<field>=<value>' list.

    Returns:
        0 on normal completion.  Terminates via sys.exit() when the device is
        not ready (0), the eeprom is read-only for -s (0), or the checksum is
        invalid (1).

    Raises:
        RuntimeError: -s was requested by a non-root user.
    """
    status = target.check_status()
    if status != 'ok':
        sys.stderr.write("Device is not ready: " + status + "\n")
        sys.exit(0)

    if not os.path.exists(cache_root):
        try:
            os.makedirs(cache_root)
        except OSError:
            # Best effort: caching is optional, so a cache directory that
            # cannot be created is not fatal.
            pass
    if opts.init:
        for cache_file in glob.glob(os.path.join(cache_root, '*')):
            os.remove(cache_file)

    #
    # only the eeprom classes that inherit from eeprom_base
    # support caching. Others will work normally
    #
    try:
        target.set_cache_name(os.path.join(cache_root, opts.target))
    except Exception:
        # Deliberately best-effort, but no longer swallows KeyboardInterrupt.
        pass

    e = target.read_eeprom()
    if e is None:
        return 0

    try:
        target.update_cache(e)
    except Exception:
        pass

    if opts.init:
        return 0

    if opts.startmac:
        sa = target.switchaddrstr(e)
        if sa is not None:
            print(sa)
    elif opts.mgmtmac:
        mm = target.mgmtaddrstr(e)
        if mm is not None:
            print(mm)
    elif opts.macrange:
        sr = target.switchaddrrange(e)
        if sr is not None:
            print(sr)
    elif opts.serial:
        try:
            serial = target.serial_number_str(e)
        except NotImplementedError as err:
            # NotImplemented is a constant, not an exception class.  Catching
            # NotImplementedError under a separate name also keeps `e`
            # (the eeprom data) intact.
            print(err)
        else:
            print(serial or "Undefined.")
    elif opts.set:
        # To modify EEPROM you need root access
        if os.geteuid() != 0:
            raise RuntimeError("must be root to run")

        if target.is_read_only():
            sys.stderr.write("Device eeprom cannot be modified\n")
            sys.exit(0)

        new_e = target.set_eeprom(e, args)
        yn = 'y'
        if not args:
            # Interactive mode: show the result and ask before writing.
            print("\nPlease review eeprom content:\n")
            target.decode_eeprom(new_e)
            yn = raw_input("\nSave eeprom content? <y/n> ")

        if yn in ('y', 'Y'):
            target.write_eeprom(new_e)

    elif opts.json:
        data = target.decode_eeprom_dictionary(e)
        print(json.dumps(data, sort_keys=True, indent=4, encoding="iso8859_15"))

    else:
        target.decode_eeprom(e)
        (is_valid, _valid_crc) = target.is_checksum_valid(e)
        if is_valid is None:
            return 0
        if is_valid:
            print('(checksum valid)')
        else:
            print('(*** checksum invalid)')
            sys.exit(1)
    return 0


#
# formats warnings
#
def mywarn(message, category, filename, lineno, line=None):
    """Format a warning as '<file>:<line> : <Category> : <message>'.

    Installed as warnings.formatwarning.  The *line* argument (source text)
    is accepted to match that hook's signature but is not used.
    """
    fields = {
        'file': filename,
        'lineno': lineno,
        'kind': category.__name__,
        'text': message,
    }
    return '%(file)s:%(lineno)s : %(kind)s : %(text)s\n' % fields

# --------------------
#
# execution check
#
if __name__ == "__main__":
    # Script entry point.  Expected failures (RuntimeError, OSError, IOError)
    # are reported as a single error line with exit status 1 instead of a
    # traceback.  sys.exit() is used because the builtin exit() is a site
    # helper meant for the interactive interpreter.
    try:
        warnings.simplefilter("always")
        warnings.formatwarning = mywarn
        sys.exit(main())
    except KeyboardInterrupt:
        sys.stderr.write("\nInterrupted\n")
        sys.exit(1)
    except (RuntimeError, OSError, IOError) as errstr:
        sys.stderr.write("%s : ERROR : %s\n" % (sys.argv[0], str(errstr)))
        sys.exit(1)
