#!/usr/bin/python

#-------------------------------------------------------------------------------
#
# Copyright 2016 Cumulus Networks, inc  all rights reserved
#
#-------------------------------------------------------------------------------

#
#   Import the necessary modules
#
# Import everything inside a guard so that a missing module is reported with
# an explicit hint. The "except ... as" form is accepted by Python 2.6+ and
# Python 3, unlike the comma form which only parses on Python 2.
try:
    import argparse
    import logging
    import logging.handlers
    import os
    import subprocess
    import sys
    import xml.etree.ElementTree as ET
except ImportError as e:
    raise ImportError(str(e) + "- required module not found")

#-------------------------------------------------------------------------------
#
#   Global Variables
#
#-------------------------------------------------------------------------------

# The handle used for logging messages. It is shared by every function in
# this file; its level and output handler are attached by InitLogging().
logger = logging.getLogger("mlxfwupgrade")


#-------------------------------------------------------------------------------
#
#   Functions
#
#-------------------------------------------------------------------------------

def ParseCmdLine():
    """
    Build the argparse parser for this tool and parse sys.argv with it.

    Supported options (all optional):
      --loglevel / -l   logging level name, default "WARNING"
      --logfile  / -f   file to log to, default None (meaning syslog)
      --fwpath          firmware directory, default
                        "/usr/share/mellanox/firmware"

    Returns the argparse.Namespace holding loglevel, logfile and fwpath.
    """
    # One row per option: (flag spellings, default value, help text)
    optionTable = (
        (("--loglevel", "-l"), "WARNING",
         "The logging level for output"),
        (("--logfile", "-f"), None,
         "The file to log messages to. If not given syslog is used"),
        (("--fwpath",), "/usr/share/mellanox/firmware",
         "The directory where the firmware files exist"),
    )

    parser = argparse.ArgumentParser(description="Verify that SDK version matches firmware version")
    for flags, defaultValue, helpText in optionTable:
        parser.add_argument(*flags, default=defaultValue, help=helpText)

    return parser.parse_args()

def InitLogging(cmdargs):
    """
    Set up the logging according to the command line parameters provided.

    cmdargs.loglevel is the (case-insensitive) name of a standard logging
    level and is applied to the module-wide logger. Messages are written to
    the file named by cmdargs.logfile when one is given; otherwise they are
    sent to the local syslog socket at /dev/log.

    Raises ValueError if cmdargs.loglevel does not name a logging level.
    """
    logLevel = getattr(logging, cmdargs.loglevel.upper(), None)
    if not isinstance(logLevel, int):
        raise ValueError("Invalid log level: %s" % cmdargs.loglevel)
    logger.setLevel(logLevel)

    # Create only the handler that is going to be used. Building the syslog
    # handler opens a socket to /dev/log, so it must not be constructed (and
    # then thrown away) when the user asked for a log file instead.
    if cmdargs.logfile:
        lh = logging.FileHandler(cmdargs.logfile)
        # A plain file has no timestamps of its own, so add one
        lf = logging.Formatter(fmt="%(asctime)s %(name)s[%(process)d]: %(message)s")
    else:
        lh = logging.handlers.SysLogHandler(address="/dev/log")
        lf = logging.Formatter(fmt="%(name)s[%(process)d]: %(message)s")
    lh.setFormatter(lf)
    logger.addHandler(lh)

def GetFwInfo(fwDir):
    """
    Get the information about the firmware in the system by running
    "mlxfwmanager -D <fwDir> --query-format xml". The information is
    returned as an xml.etree.ElementTree.Element for the root. Example format
    of the data:
    <Devices>
        <Device pciName="/dev/mst/mt52100_pci_cr0" type="Spectrum" psid="MT_2860111033" partNumber="MSN2410-CxxxO_Ax">
          <Versions>
            <FW current="13.0350.0418" available="N/A"/>
          </Versions>
          <GUIDs port1="7cfe900300e8db00" port2=" "/>
          <Status>No matching image found</Status>
          <Description>Spectrum based 25GbE/100GbE 1U Open Ethernet switch with ONIE; 48 SFP28 ports; 8 QSFP28 ports; x86 dual core; RoHS6</Description>
        </Device>
    </Devices>

    fwDir is the directory holding the firmware images to compare against.

    This function does not raise on command or parse failures: if the
    command cannot be started, exits with a non-zero status, or prints
    something that is not valid XML, the error is logged and an empty
    <X/> element is returned instead.
    """
    cmd = ["/usr/bin/mlxfwmanager", "-D", fwDir, "--query-format", "xml"]
    logger.debug("Executing: %s", " ".join(cmd))
    try:
        fwStr = subprocess.check_output(cmd)
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError: the tool ran but returned a non-zero status.
        # OSError: the tool could not be executed at all (missing binary,
        # no permission). Both fall back to the empty placeholder document.
        logger.error("Unable to get firmware information. Command: %s Error: %s",
                     " ".join(cmd), e)
        fwStr = "<X/>"

    try:
        logger.debug("Converting string to ElementTree: %s", fwStr)
        root = ET.fromstring(fwStr)
    except ET.ParseError:
        logger.error("The data returned by mlxfwmanager is not valid XML: %s", fwStr)
        root = ET.Element("X")

    return root

def FwProgrammingNeeded(fwInfoXml, chip):
    """
    Go through the XML information for the specified chip and determine if the
    firmware version in the chip is not the same as in the firmware directory.

    fwInfoXml is the root element of the mlxfwmanager query output (its
    children are the Device elements). chip is the value of the Device
    "type" attribute to look at, e.g. "Spectrum"; devices of any other type
    are skipped.

    Returns True as soon as one matching device reports a current version
    that differs from a real available version, False otherwise (including
    when no matching device or no version information is present).
    """
    for device in fwInfoXml:
        if device.tag != "Device" or device.attrib.get("type") != chip:
            continue

        version = device.find("Versions")
        if version is None:
            logger.debug("Device element does not have a Versions sub-element")
            continue

        fw = version.find("FW")
        if fw is None:
            logger.debug("Versions element does not have a FW sub-element")
            continue

        current = fw.attrib.get("current")
        available = fw.attrib.get("available")
        logger.info("FW versions are: current = %s, available = %s", current, available)

        # The sample query output documented in GetFwInfo shows
        # available="N/A" together with "No matching image found". In that
        # case there is nothing to program, so it must not be treated as a
        # version mismatch.
        if current and available and available != "N/A" and current != available:
            return True

    return False

def UpgradeFirmware(fwDir):
    """
    Upgrade the firmware from the specified directory by running
    "mlxfwmanager -u -D <fwDir> -y -f --log" with its output discarded.

    Returns 0 when the command exits with status 0, and 1 when it exits
    with a non-zero status or cannot be started at all. Failures are logged;
    no exception is propagated for them.
    """
    cmd = ["/usr/bin/mlxfwmanager", "-u", "-D", fwDir, "-y", "-f", "--log"]
    logger.debug("Executing: %s", " ".join(cmd))
    try:
        # The context manager guarantees the null device handle is closed
        # again once the command has finished.
        with open(os.devnull, 'wb') as devnull:
            rc = subprocess.call(cmd, stdout=devnull, stderr=devnull)
    except (OSError, IOError) as e:
        # subprocess.call() reports a command that cannot be executed as
        # OSError; IOError covers a failing open() on Python 2.
        logger.error("Unable to upgrade the firmware. Command: %s Error: %s",
                     " ".join(cmd), e)
        return 1

    if rc:
        logger.error("There was an error upgrading the firmware. Check /var/log/mlxfwmanager")
        return 1
    return 0


#-------------------------------------------------------------------------------
#
#   Main program entry point
#
#-------------------------------------------------------------------------------

def main():
    """
    Program entry point: parse the options, set up logging, query the
    firmware state and upgrade the firmware when the versions differ.

    The process always ends through sys.exit(). The exit status is the
    return value of UpgradeFirmware() when an upgrade was attempted and 1
    otherwise.
    NOTE(review): the status is therefore 1 as well when no upgrade is
    needed - confirm that callers rely on this before changing it.
    """
    cmdargs = ParseCmdLine()
    InitLogging(cmdargs)
    logger.info("Beginning execution of mlxfwupgrade")

    # XXX: In the future more elaborate mechanism will be needed to support
    # multiple ASIC types. For now, hardcoded to "Spectrum"
    chipName = "Spectrum"
    chipFwDir = cmdargs.fwpath + "/" + chipName

    exitCode = 1
    upgradeWanted = FwProgrammingNeeded(GetFwInfo(chipFwDir), chipName)
    if upgradeWanted:
        logger.info("Programming the firmware")
        exitCode = UpgradeFirmware(chipFwDir)

    # All done, report the outcome to the caller
    logger.info("Execution of mlxfwupgrade is complete")
    sys.exit(exitCode)


#-------------------------------------------------------------------------------
#
#   Are we being executed or imported?
#
#-------------------------------------------------------------------------------

# Only run main() when the file is executed directly, not when it is imported.
if __name__ == '__main__':
    main()

