#! /usr/bin/python
# Copyright 2016-2020 Cumulus Networks,  Inc., all rights reserved

try:
    import exceptions
    import os
    import sys
    import re
    import subprocess
    import imp
    import argparse
    import pprint
    import poe_base
    import json
except ImportError, e:
    raise ImportError (str(e) + "- required module not found")

# Tool version, rendered into the banner below.
POECTL_MAJOR_VERSION = 1
POECTL_MINOR_VERSION = 1

# Set by main() from the effective uid; gates the root-only options and output.
g_is_root = False

# Banner printed by --version and at the top of the help text.
poe_banner_str = ('Power Over Ethernet (POE) display and configuration tool v%d.%02d'
                  % (POECTL_MAJOR_VERSION, POECTL_MINOR_VERSION))

class Parse_Error(Exception):
    '''
    Error raised for command-line input (port list, priority) that cannot
    be parsed. The message is stored on the msg attribute and is exactly
    what str() returns.
    '''

    def __init__(self, msg):
        super(Parse_Error, self).__init__(msg)
        self.msg = msg

    def __str__(self):
        return self.msg

# Helper run by get_poe(); its output is split on ',' into platform and model.
platform_detect = '/usr/bin/platform-detect'
# Root of the per-platform trees; get_poe() looks for <platform>/<model>/bin/poe.py here.
platform_root = '/usr/share/cumulus-platform'
# Directory holding the generic poe_base.py fallback module.
base_root = '/usr/lib/python2.7/dist-packages'

def get_poe():
    '''
    Detect the platform and return an instance of its POE driver class.

    Runs platform_detect, whose output is expected to be "<platform>,<model>".
    The class named 'poe' is loaded from
    <platform_root>/<platform>/<model>/bin/poe.py when that file can be read,
    otherwise from the generic <base_root>/poe_base.py.

    Raises OSError if the platform cannot be detected (helper missing or
    unexpected output) and IOError if no POE module can be loaded.
    '''
    #
    # determine the platform
    #
    try:
        ph = subprocess.Popen([platform_detect], stdout=subprocess.PIPE,
                              shell=False, stderr=subprocess.STDOUT)
        # communicate() reads all output and waits for the process to exit.
        cmdout = ph.communicate()[0]
    except OSError as e:
        raise OSError("cannot detect platform: %s" % e)

    fields = [field.strip() for field in cmdout.strip().split(',')]
    if len(fields) != 2 or not all(fields):
        # Report bad helper output as OSError (handled by the entry point)
        # rather than letting a tuple-unpacking ValueError escape.
        raise OSError("cannot detect platform: unexpected output %r" % cmdout)
    platform, model = fields
    platform_path = '/'.join([platform_root, platform, model])
    #
    # load the poe class file and instantiate the object
    #
    try:
        m = imp.load_source('poe', '/'.join([platform_path, 'bin/poe.py']))
    except IOError:
        m = None
    if m is None:
        name = '/'.join([base_root, 'poe_base.py'])
        try:
            m = imp.load_source('poe', name)
        except IOError:
            raise IOError("cannot load module: " + name)
    return getattr(m, 'poe')()

def main():
    """Parse the poectl command line and run the single requested action.

    The option set depends on whether the effective uid is root: -e, -d, -r,
    -p, -u, --debug, --verbose, --diag-info, --save and --load are only
    registered for root. At most one action switch may be given; with none,
    --system is assumed. Requested ports are checked against
    poe.get_poe_params().ports before any action runs.

    Most paths end via sys.exit(); the root configuration actions (other
    than --diag-info) fall through and return None. Raises RuntimeError if a
    root-only action is reached by a non-root user.
    """
    global g_is_root
    # Root-ness decides which options are registered and what is displayed.
    g_is_root = (os.geteuid() == 0)
    # Fallback port range used when the poe object has no default_ports.
    default_ports = 'swp1-swp48'
    is_upgrade = False
    ports = []
    poe = get_poe()
    # NOTE(review): this parse_ports() call is outside the Parse_Error handler
    # below, so a malformed platform default_ports would escape unhandled.
    if hasattr(poe, 'default_ports'):
        ports = parse_ports(poe.default_ports)
        default_ports = poe.default_ports
    parser = argparse.ArgumentParser()
    # -i given without a value uses default_ports (const=).
    parser.add_argument("-i", "--info", nargs='?', const=default_ports, type=str)
    if g_is_root:
        parser.add_argument("-e", "--enable-ports", type=str)
        parser.add_argument("-d", "--disable-ports", type=str)
        parser.add_argument("-r", "--reset-ports", type=str)
        # Two values: PORT_LIST then PRIORITY (see args.priority[0]/[1] below).
        parser.add_argument("-p", "--priority", type=str, nargs=2)
        parser.add_argument("-u", "--upgrade", action='store')
    parser.add_argument("-a", "--all", action="store_true")
    parser.add_argument("-s", "--system", action="store_true")
    parser.add_argument("-j", "--json-output", action="store_true")
    parser.add_argument("-v", "--version", action="store_true")
    if g_is_root:
        parser.add_argument("--debug", action="store_true")
        parser.add_argument("--verbose", action="store_true")
        parser.add_argument("--diag-info", action="store_true")
        parser.add_argument("--save", action="store_true")
        parser.add_argument("--load", action="store_true")
    # Replace argparse's generated help/usage text with the hand-written
    # module-level versions, which also depend on g_is_root.
    parser.print_help = print_help
    parser.format_usage = format_usage
    args = parser.parse_args()

    ''' check for switches and parse port lists
    '''
    # Every action switch seen is recorded so that combinations can be
    # rejected. Switches taking a PORT_LIST also (re)set `ports`.
    switches = []
    try:
        if args.info:
            switches.append('--info')
            ports = parse_ports(args.info)
        if args.system:
            switches.append('--system')
        if args.all:
            switches.append('--all')
            ports = parse_ports(default_ports)
        if args.version:
            switches.append('--version')
        if g_is_root:
            if args.enable_ports:
                switches.append('--enable-ports')
                ports = parse_ports(args.enable_ports)
            if args.disable_ports:
                switches.append('--disable-ports')
                ports = parse_ports(args.disable_ports)
            if args.reset_ports:
                switches.append('--reset-ports')
                ports = parse_ports(args.reset_ports)
            if args.priority:
                switches.append('--priority')
                ports = parse_ports(args.priority[0])
                priority = parse_priority(args.priority[1])
            if args.save:
                switches.append('--save')
            if args.load:
                switches.append('--load')
            if args.diag_info:
                switches.append('--diag-info')
            if args.upgrade:
                switches.append('--upgrade')
                is_upgrade = True
        if len(switches) > 1:
            raise Parse_Error('incompatible switches specified %s' % (str(switches)))
        # No action switch given: default to the system summary.
        if len(switches) == 0:
            args.system = True
    except Parse_Error, e:
        print 'Input error: %s' % (str(e))
        sys.exit(1)

    poe_params = poe.get_poe_params()
    if g_is_root:
        # --debug is applied after --verbose, so it wins when both are given.
        if args.verbose:
            poe.log.set_level(poe_base.POE_VERBOSITY_VERBOSE)
        if args.debug:
            poe.log.set_level(poe_base.POE_VERBOSITY_DEBUG)

    if not poe_params.supported:
        if args.version:
            print poe_banner_str
        print 'This platform does not support POE'
        sys.exit(1)

    # Reject ports the platform does not expose. Ports are 0-based here,
    # hence +1 when naming the swp port in the message.
    if len(ports):
        for port in ports:
            if not port in poe_params.ports:
                print "Input error: invalid port in port list (swp%u)" % (port+1)
                sys.exit(1)

    if args.version:
        show_version(poe, args.json_output)
        sys.exit(0)

    # --upgrade skips this check; upgrade_controller_firmware() itself
    # refuses to run while poed is up.
    if not poed_running() and not is_upgrade:
        print 'Error: poed daemon is not running.'
        sys.exit(1)

    # Display actions available to every user.
    if args.info:
        show_ports(poe, ports, args.json_output)
        sys.exit(0)
    elif args.system:
        show_system(poe, args.json_output)
        sys.exit(0)
    elif args.all:
        show_all(poe, ports, args.json_output)
        sys.exit(0)

    if not g_is_root:
        ''' the remaining commands can only be run by root user
        '''
        raise RuntimeError("must be root to run")

    # Root-only actions; `priority` is bound above whenever args.priority is set.
    if args.enable_ports:
        enable_ports(poe, ports)
    elif args.disable_ports:
        disable_ports(poe, ports)
    elif args.reset_ports:
        reset_ports(poe, ports)
    elif args.priority:
        set_priority(poe, ports, priority)
    elif args.save:
        poe.save_running_configuration()
    elif args.load:
        poe.load_saved_configuration()
    elif args.upgrade:
        upgrade_controller_firmware(poe, args.upgrade)
    elif args.diag_info:
        show_diag_info(poe, args.json_output)
        sys.exit(0)

#--------------------
#
# check to see if an instance of poed is running
#
def poed_running():
    '''
    Return True if the poed daemon named in /var/run/poed.pid is alive.

    The pid is the first run of digits on the pidfile's first line. The
    daemon counts as running only if /proc/<pid> exists and that process's
    cmdline contains 'poed'. Problems reading or parsing the pidfile are
    printed and treated as "not running".
    '''
    poed_name = 'poed'
    poed_pid_file = "/var/run/%s.pid" % poed_name
    try:
        if not os.path.isfile(poed_pid_file):
            return False
        # 'with' guarantees the handles are closed (the old file() calls leaked them).
        with open(poed_pid_file, 'r') as pid_fh:
            match = re.search(r'\d+', pid_fh.readline())
        if not match:
            print("unable to validate pidfile %s: no pid found" % (poed_pid_file))
            return False
        pid = match.group(0)
        if not os.path.exists('/proc/%s' % pid):
            return False
        with open('/proc/%s/cmdline' % pid, 'r') as cmd_fh:
            return poed_name in cmd_fh.readline()
    except Exception as inst:
        print("unable to validate pidfile %s: %s" % (poed_pid_file, str(inst)))
    return False

def print_help():
    '''
    Print the hand-written poectl help text (installed as parser.print_help).

    Root-only options are listed only when g_is_root is set. The option
    names match the ones registered in main(): the port info option is
    --info (argparse rejects --port-info), and JSON output is --json-output.
    '''
    lines = [
        poe_banner_str,
        '',
        'optional arguments:',
        '  -h, --help                           Show this help message and exit',
        '  -i, --info PORT_LIST                 Return detailed information for the specified ports.',
        '                                         eg: -i swp1-swp5,swp10',
    ]
    if g_is_root:
        lines += [
            '  -p, --priority PORT_LIST PRIORITY    Set priority for the specified ports:',
            '                                         low, high, critical',
            '  -d, --disable-ports PORT_LIST        Disable POE operation on the specified ports.',
            '  -e, --enable-ports PORT_LIST         Enable POE operation on the specified ports.',
            '  -r, --reset-ports PORT_LIST          Perform hardware reset on the specified ports.',
            '  -u, --upgrade FILE_PATH              Upgrade firmware on controller',
        ]
    lines += [
        '  -s, --system                         Return POE status for the entire switch',
        '  -a, --all                            Return POE status for system and detailed information for ports.',
        '  -j, --json-output                    Return output in json format',
        '  -v, --version                        Display version info',
    ]
    if g_is_root:
        lines += [
            '  --save                               Save the current configuration. The saved configuration',
            '                                         is automatically loaded on system boot.',
            '  --load                               Load and apply the saved configuration.',
        ]
    print('\n'.join(lines))

def format_usage():
    '''
    Return the usage string (installed as parser.format_usage).

    The listed options match what main() registers for the current user.
    -p arguments are shown in the order main() consumes them: the port list
    first, then the priority level.
    '''
    prog = os.path.basename(sys.argv[0])
    if g_is_root:
        return 'usage: %s [-i PORT_LIST] [-e PORT_LIST] [-d PORT_LIST] [-r PORT_LIST]\n' \
               '                [-p PORT_LIST PRIORITY] [-u FILE_PATH] [--save] [--load] [-a] [-s] [-j] [-v]\n' % (prog)
    return 'usage: %s [-i PORT_LIST] [-a] [-s] [-j] [-v]\n' % (prog)

def show_version(poe, use_json):
    '''
    Print the tool banner and, for root only, the POE controller version.

    For root, the dict from poe.get_poe_version_info() is printed either as
    indented JSON (after clean_result() drops its 'error_str' and 'state'
    keys) or as a single 'Hardware version' line built from 'ver_str'.
    '''
    print(poe_banner_str)
    if g_is_root:
        info = poe.get_poe_version_info()
        if not use_json:
            print('Hardware version: %s' % (info.get('ver_str')))
        else:
            clean_result(info)
            print(json.dumps(info, indent=4))

def get_results(poe, ports):
    '''
    Collect per-port display data for the given 0-based ports.

    For each port, the root-only extra status from
    poe.get_port_extra_status() is overlaid with the port's entry from
    poe.get_running_state()['ports'], and a computed 'allocated_power' is
    added.

    Returns a dict keyed by 1-based port number (matching swpN naming), or
    None after printing an error when the running state is unavailable.
    '''
    result = poe.get_running_state()
    if not result:
        print('Error: unable to get POE status. Is poed daemon running?')
        return None
    # Tolerate a state without 'ports' (or with None) instead of crashing on None.get().
    port_results = result.get('ports') or {}
    extra_results = {}
    if g_is_root:
        extra_results = poe.get_port_extra_status(ports) or {}
    display_results = {}
    for port_num in ports:
        # Copy so the merge never mutates dicts owned by the poe object.
        merged = dict(extra_results.get(port_num, {}))
        merged.update(port_results.get(port_num, {}))
        merged['allocated_power'] = get_allocated_power(merged, poe)
        # Re-key to 1-based numbering to match 'swp' port names.
        display_results[port_num + 1] = merged
    return display_results

def show_all(poe, ports, use_json):
    '''
    Show both systemwide and port specific information.

    JSON output is a dict with 'ports' (1-based, from get_results()),
    'sys_consumed_power' and 'sys_power_limit'. Text output is the
    print_system() summary followed by the print_ports() table. If the POE
    status is unavailable, an error is printed and nothing else is shown.
    '''
    try:
        result = poe.get_running_state()
        if not result:
            print('Error: unable to get POE status. Is poed daemon running?')
            return
        port_display = get_results(poe, ports)
        if port_display is None:
            # get_results() has already reported the failure.
            return
        display_results = {
            'ports': port_display,
            'sys_consumed_power': result.get('sys_consumed_power', 0),
            'sys_power_limit': result.get('sys_power_limit', 0),
        }
        if use_json:
            clean_result(display_results)
            print(json.dumps(display_results, indent=4))
        else:
            print_system(result.get('ports', {}), result)
            print_ports(ports, port_display)
    except Exception as e:
        print(str(e))

def print_ports(ports, display_results):
    '''
    Print the port information table for both root and non-root users.

    Root sees status, allocation, priority, PD type/class, voltage, current
    and power. Other users see status and PD type only. `ports` holds
    0-based numbers, while display_results is keyed by 1-based numbers,
    hence the +1 lookup.
    '''
    if not g_is_root:
        print('Port          Status             PD type  ')
        print('-----   --------------------   --------------')
    else:
        print('Port          Status           Allocated   Priority      PD type       PD class   Voltage   Current     Power  ')
        print('-----   --------------------   ---------   --------   --------------   --------   -------   -------   ---------')
    root_fmt = '%-5s   %-20s  %11s   %-8s  %-14s       %-4s   %5s V   %4s mA    %7s'
    user_fmt = '%-5s   %-20s   %-11s'
    for port in ports:
        state = display_results.get(port + 1)
        swp_name = '%s' % (state.get('swp'))
        if g_is_root:
            allocated = '%.1f W  ' % (state.get('allocated_power', 0)/1000.0)
            power = '%s W' % (state.get('power', 0))
            row = root_fmt % (swp_name, state.get('hw_status'), allocated,
                              state.get('priority'), state.get('pd_type'),
                              state.get('pd_class'), state.get('voltage', 0),
                              state.get('current', 0), power)
        else:
            row = user_fmt % (swp_name, state.get('hw_status'), state.get('pd_type'))
        print(row)

def show_ports(poe, ports, use_json):
    '''
    Show per-port details for the given 0-based ports.

    What gets displayed by the non-root user is less than what the root user
    sees (see get_results()/print_ports()). When the POE status is
    unavailable, get_results() prints the error and nothing further is shown
    (previously JSON mode printed "null" and text mode crashed).
    '''
    try:
        display_results = get_results(poe, ports)
        if display_results is None:
            return
        if use_json:
            print(json.dumps(display_results, indent=4))
        else:
            print_ports(ports, display_results)
    except Exception as e:
        print(str(e))

def print_system(ports_result, result):
    '''
    Print system power totals and the connected / error port lists.

    ports_result maps port number -> port state dict (reads 'hw_status' and
    'swp'). result supplies 'sys_consumed_power' and 'sys_power_limit', both
    defaulting to 0. Ports are listed in ascending port-number order.
    '''
    connected_ports = []
    error_ports = []
    # Sort so the listing follows port order regardless of dict iteration order.
    for port in sorted(ports_result):
        port_state = ports_result.get(port)
        hw_status = port_state.get('hw_status')
        if hw_status == 'connected':
            connected_ports.append(port_state.get('swp'))
        elif hw_status in ('fault', 'power-denied'):
            error_ports.append(port_state.get('swp'))
    consumed = float(result.get('sys_consumed_power', 0))
    limit = float(result.get('sys_power_limit', 0))
    print('System power:')
    print('  Total:     %6s W' % ('%.1f' % limit))
    print('  Used:      %6s W' % ('%.1f' % consumed))
    print('  Available: %6s W' % ('%.1f' % (limit - consumed)))
    print('Connected ports:')
    # Join names directly: formatting via str(list) leaked u'' prefixes for
    # unicode names under Python 2 (e.g. "uswp1").
    if connected_ports:
        port_str = ','.join('%s' % name for name in connected_ports)
    else:
        port_str = 'None'
    print('  %s' % (port_str))
    if error_ports:
        print('Error ports:')
        print('  %s' % (','.join('%s' % name for name in error_ports)))

def show_system(poe, use_json):
    '''
    Show system-wide POE status.

    In JSON mode only the per-port states are printed, re-keyed to 1-based
    port numbers (clean_result() first drops 'error_str' and 'state' from the
    top-level result). Otherwise the print_system() summary is printed.
    Any exception is reported by printing its message.
    '''
    try:
        state = poe.get_running_state()
        if not state:
            print('Error: unable to get POE status. Is poed daemon running?')
            return
        port_states = state.get('ports', {})
        if not use_json:
            print_system(port_states, state)
            return
        clean_result(state)
        one_based = dict((num + 1, info) for num, info in state['ports'].items())
        print(json.dumps(one_based, indent=4))
    except Exception as err:
        print(str(err))

def enable_ports(poe, ports):
    '''
    Enable POE on the given 0-based ports, skipping already-connected ones.

    Enabling a port that is already enabled can cause it to stop powering
    momentarily, so only ports whose current hw_status is not 'connected'
    are passed to poe.set_port_enable_poe(). If the running state cannot be
    read, an error is printed and nothing is enabled. Errors from the poe
    object are printed rather than raised.
    '''
    try:
        state = poe.get_running_state()
        if not state:
            print('Error: unable to get POE status. Is poed daemon running?')
            return
        port_states = state.get('ports', {})
        to_enable = [port for port in ports
                     if port_states.get(port, {}).get('hw_status', '') != 'connected']
        poe.set_port_enable_poe(to_enable)
    except Exception as e:
        print(str(e))

def disable_ports(poe, ports):
    '''
    Disable POE on the given ports via poe.set_port_disable_poe().

    Any exception is reported by printing its message; nothing is raised.
    '''
    try:
        poe.set_port_disable_poe(ports)
    except Exception as err:
        print(str(err))

def reset_ports(poe, ports):
    '''
    Request a hardware reset of the given ports via poe.reset_ports().

    Any exception is reported by printing its message; nothing is raised.
    '''
    try:
        poe.reset_ports(ports)
    except Exception as err:
        print(str(err))

def set_priority(poe, ports, level):
    '''
    Apply priority `level` to the given ports via poe.set_port_priority().

    Any exception is reported by printing its message; nothing is raised.
    '''
    try:
        poe.set_port_priority(ports, level)
    except Exception as err:
        print(str(err))

def show_diag_info(poe, use_json):
    '''
    Print poe.get_poe_diag_info() as indented JSON.

    The output is always JSON; use_json is accepted for signature parity
    with the other show_* functions but is not consulted. Any exception is
    reported by printing its message.
    '''
    try:
        diag = poe.get_poe_diag_info()
        print(json.dumps(diag, indent=4))
    except Exception as err:
        print(str(err))

def parse_ports(port_list):
    '''
    Parse a PORT_LIST such as "swp1-swp5,swp10" into 0-based port numbers.

    Items are comma separated. Each item is either "swpN" or an inclusive
    range "swpN-swpM" with N <= M. Whitespace around items is ignored.
    Numbers are returned 0-based (swp1 -> 0), in the order given.

    Raises Parse_Error if any item is malformed (including trailing junk),
    names swp0, or is a reversed range.
    '''
    parse_error = 'could not parse PORT_LIST'
    # Anchored with '$' so trailing junk ("swp1x", "swp1-swp3-swp5") is
    # rejected instead of being silently ignored.
    item_re = re.compile(r'swp(\d+)(?:-swp(\d+))?$')
    ports = []
    for item in port_list.split(','):
        match = item_re.match(item.strip())
        if not match:
            raise Parse_Error(parse_error)
        start = int(match.group(1)) - 1
        if match.group(2) is None:
            end = start
        else:
            end = int(match.group(2)) - 1
        if start < 0 or end < start:
            raise Parse_Error(parse_error)
        ports.extend(range(start, end + 1))
    # split() always yields at least one item, and every accepted item adds
    # at least one port, so `ports` is never empty here.
    return ports

def clean_result(result):
    '''Remove the 'error_str' and 'state' keys from `result` in place, if present.'''
    for unwanted in ('error_str', 'state'):
        result.pop(unwanted, None)

def parse_priority(priority_str):
    '''
    Map a case-insensitive priority name (low, high, critical) to the
    matching poe_base.PORT_CFG_PRIORITY_* constant.

    Raises Parse_Error, quoting the lower-cased input, for any other name.
    '''
    level = priority_str.lower()
    # Attribute names are looked up lazily, so only the requested constant is read.
    constant_by_name = {
        'low': 'PORT_CFG_PRIORITY_LOW',
        'high': 'PORT_CFG_PRIORITY_HIGH',
        'critical': 'PORT_CFG_PRIORITY_CRITICAL',
    }
    if level not in constant_by_name:
        raise Parse_Error('could not parse priority level "%s"' % (level))
    return getattr(poe_base, constant_by_name[level])

def get_allocated_power(port_status, poe):
    '''
    Return the power allocated to a port as an int, or 0.

    Only ports whose 'hw_status' is 'connected' have an allocation. When
    both 'lldp_requested_power' and 'lldp_allocated_power' are non-zero,
    the LLDP allocation is used. Otherwise the platform is asked via
    poe.get_non_lldp_allocated_power(pd_class, power).

    This is best-effort display data: any error while reading or converting
    the values yields 0. Only Exception subclasses are suppressed, so
    KeyboardInterrupt and SystemExit still propagate (a bare except: would
    have swallowed them).
    '''
    try:
        if port_status.get('hw_status') != 'connected':
            return 0
        lldp_req = int(port_status.get('lldp_requested_power', 0))
        lldp_alloc = int(port_status.get('lldp_allocated_power', 0))
        if lldp_req != 0 and lldp_alloc != 0:
            return lldp_alloc
        return int(poe.get_non_lldp_allocated_power(port_status.get('pd_class'),
                                                    port_status.get('power')))
    except Exception:
        return 0

def upgrade_controller_firmware(poe, firmware_path):
    '''
    Upgrade the POE controller firmware from firmware_path (root only).

    Does nothing for non-root users. Refuses to run if the file does not
    exist or while the poed daemon is running. Reports success only when
    poe.upgrade_firmware() completed without raising; otherwise the error
    is printed.
    '''
    if not g_is_root:
        return
    if not os.path.isfile(firmware_path):
        print('Firmware file does not exist')
        return
    if poed_running():
        print('POE daemon is up, please do systemctl stop cumulus-poe before doing an upgrade.')
        return
    try:
        poe.upgrade_firmware(firmware_path)
    except Exception as upgrade_e:
        # str() works for every exception; BaseException.message is
        # deprecated since Python 2.6 and is often empty.
        print('ERROR: %s' % upgrade_e)
        return
    print('Upgrade firmware done')

if __name__ == '__main__':
    # Use sys.exit(): the bare exit() builtin is installed by the site module
    # and is not guaranteed to exist (e.g. under python -S).
    try:
        sys.exit(main())
    except KeyboardInterrupt:
        sys.stderr.write("\nInterrupted\n")
        sys.exit(1)
    except (RuntimeError, OSError, IOError, Parse_Error) as errstr:
        # Parse_Error can escape main() when the platform's default port list
        # is parsed outside main()'s own input-error handler.
        sys.stderr.write("%s : ERROR : %s\n" % (sys.argv[0], str(errstr)))
        sys.exit(1)
