#! /usr/bin/python
# Copyright 2015 Cumulus Networks LLC, all rights reserved

try:
    import exceptions
    import os
    import sys
    import re
    import subprocess
    import imp
    import argparse
    import pprint
    import poe_base
    import json
except ImportError, e:
    raise ImportError (str(e) + "- required module not found")

# Tool version; rendered into poe_banner_str as "<major>.<minor:02>".
POECTL_MAJOR_VERSION = 1
POECTL_MINOR_VERSION = 1

# Set by main() from the effective UID; gates root-only switches,
# help entries and output columns.
g_is_root = False

poe_banner_str = 'Power Over Ethernet (POE) display and configuration tool v{0:d}.{1:02d}'.format(
    POECTL_MAJOR_VERSION, POECTL_MINOR_VERSION)

class Parse_Error(Exception):
    '''Command-line input error (bad port list, priority or switch mix).

    The message is kept on .msg and is exactly what str() returns.
    '''

    def __init__(self, msg):
        super(Parse_Error, self).__init__(msg)
        self.msg = msg

    def __str__(self):
        return self.msg

# Filesystem locations consulted by get_poe():
#   platform_detect - helper whose output get_poe() splits on ',' into
#                     (platform, model)
#   platform_root   - get_poe() first tries
#                     <platform_root>/<platform>/<model>/bin/poe.py
#   base_root       - directory holding the fallback poe_base.py
platform_detect, platform_root, base_root = (
    '/usr/bin/platform-detect',
    '/usr/share/cumulus-platform',
    '/usr/lib/python2.7/dist-packages',
)

def get_poe():
    '''Locate, load and instantiate the platform's POE object.

    Runs platform_detect and splits its output on ',' into (platform,
    model).  It then loads <platform_root>/<platform>/<model>/bin/poe.py
    as module 'poe'.  If that file cannot be read, it falls back to
    <base_root>/poe_base.py.  The loaded module's 'poe' callable is
    invoked and its result returned.

    Raises:
        OSError: platform_detect could not be run, or its output was not
            of the form "<platform>,<model>".
        IOError: neither the platform module nor the fallback could be read.
    '''
    try:
        ph = subprocess.Popen((platform_detect), stdout=subprocess.PIPE,
                              shell=False, stderr=subprocess.STDOUT)
        # communicate() already waits for the child to exit.
        cmdout = ph.communicate()[0]
    except OSError:
        raise OSError("cannot detect platform")

    # Anything other than exactly two non-empty fields (e.g. an error
    # message merged in from stderr) used to escape as an unhandled
    # ValueError from the unpacking; report it as a detection failure.
    fields = cmdout.strip().split(',')
    if len(fields) != 2 or not all(fields):
        raise OSError("cannot detect platform")
    platform, model = fields
    platform_path = '/'.join([platform_root, platform, model])

    # Prefer the platform-specific driver; fall back to the generic base.
    try:
        m = imp.load_source('poe', '/'.join([platform_path, 'bin/poe.py']))
    except IOError:
        name = '/'.join([base_root, 'poe_base.py'])
        try:
            m = imp.load_source('poe', name)
        except IOError:
            raise IOError("cannot load module: " + name)
    return m.poe()

def main():
    '''Parse the command line and run the single requested action.

    Root users (effective UID 0) get extra switches: -e/-d/-r/-p, --save,
    --load, --diag-info, --debug and --verbose.  At most one action switch
    may be given; with none, the system summary (-s) is shown.  Input
    errors print a message and exit with status 1; display actions exit 0.

    Raises:
        RuntimeError: a root-only action was reached without root.
    '''
    global g_is_root
    g_is_root = (os.geteuid() == 0)

    parser = argparse.ArgumentParser()
    # print_help() advertises "--port-info"; accept it as an alias of
    # "--info" (dest stays "info") so the documented spelling works.
    parser.add_argument("-i", "--info", "--port-info", nargs='?', const='swp1-swp48', type=str)
    if g_is_root:
        parser.add_argument("-e", "--enable-ports", type=str)
        parser.add_argument("-d", "--disable-ports", type=str)
        parser.add_argument("-r", "--reset-ports", type=str)
        parser.add_argument("-p", "--priority", type=str, nargs=2)
    parser.add_argument("-s", "--system", action="store_true")
    parser.add_argument("-j", "--json-output", action="store_true")
    parser.add_argument("-v", "--version", action="store_true")
    if g_is_root:
        parser.add_argument("--debug", action="store_true")
        parser.add_argument("--verbose", action="store_true")
        parser.add_argument("--diag-info", action="store_true")
        parser.add_argument("--save", action="store_true")
        parser.add_argument("--load", action="store_true")
    # Swap argparse's generated help/usage for the hand-written versions.
    parser.print_help = print_help
    parser.format_usage = format_usage
    args = parser.parse_args()

    # Collect the action switches given and parse their port lists.  Only
    # one action is allowed (checked below), so 'ports' is never ambiguous.
    switches = []
    ports = []
    priority = None
    try:
        if args.info:
            switches.append('--info')
            ports = parse_ports(args.info)
        if args.system:
            switches.append('--system')
        if args.version:
            switches.append('--version')
        if g_is_root:
            if args.enable_ports:
                switches.append('--enable-ports')
                ports = parse_ports(args.enable_ports)
            if args.disable_ports:
                switches.append('--disable-ports')
                ports = parse_ports(args.disable_ports)
            if args.reset_ports:
                switches.append('--reset-ports')
                ports = parse_ports(args.reset_ports)
            if args.priority:
                # -p takes PORT_LIST then PRIORITY.
                switches.append('--priority')
                ports = parse_ports(args.priority[0])
                priority = parse_priority(args.priority[1])
            if args.save:
                switches.append('--save')
            if args.load:
                switches.append('--load')
            if args.diag_info:
                switches.append('--diag-info')
        if len(switches) > 1:
            raise Parse_Error('incompatible switches specified %s' % (str(switches)))
        if len(switches) == 0:
            # No action requested: default to the system summary.
            args.system = True
    except Parse_Error as e:
        print('Input error: %s' % (str(e)))
        sys.exit(1)

    poe = get_poe()
    poe_params = poe.get_poe_params()
    if g_is_root:
        # --debug is applied last, after --verbose.
        if args.verbose:
            poe.log.set_level(poe_base.POE_VERBOSITY_VERBOSE)
        if args.debug:
            poe.log.set_level(poe_base.POE_VERBOSITY_DEBUG)

    if not poe_params.supported:
        if args.version:
            print(poe_banner_str)
        print('This platform does not support POE')
        sys.exit(1)

    # Reject ports this platform does not have (ports are zero-based here).
    for port in ports:
        if port not in poe_params.ports:
            print("Input error: invalid port in port list (swp%u)" % (port + 1))
            sys.exit(1)

    if args.version:
        show_version(poe, args.json_output)
        sys.exit(0)

    if not poed_running():
        print('Error: poed daemon is not running.')
        sys.exit(1)

    if args.info:
        show_ports(poe, ports, args.json_output)
        sys.exit(0)
    elif args.system:
        show_system(poe, args.json_output)
        sys.exit(0)

    # The remaining commands can only be run by the root user.
    if not g_is_root:
        raise RuntimeError("must be root to run")

    if args.enable_ports:
        enable_ports(poe, ports)
    elif args.disable_ports:
        disable_ports(poe, ports)
    elif args.reset_ports:
        reset_ports(poe, ports)
    elif args.priority:
        set_priority(poe, ports, priority)
    elif args.save:
        poe.save_running_configuration()
    elif args.load:
        poe.load_saved_configuration()
    elif args.diag_info:
        show_diag_info(poe, args.json_output)
        sys.exit(0)

#--------------------
#
# check to see if an instance of poed is running
#
def poed_running():
    '''Return True if the poed daemon appears to be running.

    Reads the first line of /var/run/poed.pid and takes its first run of
    digits as the PID.  It then checks that /proc/<pid> exists and that
    the process cmdline contains 'poed'.  Any failed check returns False.
    An unexpected error is printed and also yields False.
    '''
    poed_name = 'poed'
    poed_pid_file = "/var/run/%s.pid" % poed_name
    try:
        if not os.path.isfile(poed_pid_file):
            return False
        # Use context managers so both handles are closed promptly.
        with open(poed_pid_file, 'r') as pid_fh:
            first_line = pid_fh.readline()
        # Raises IndexError (reported below) if the line has no digits.
        oldpid = re.findall(r'\D*(\d+).*', first_line)[0]
        proc_dir = '/proc/%s' % oldpid
        if not os.path.exists(proc_dir):
            return False
        # Guards against a stale pid file whose PID was reused.
        with open(proc_dir + '/cmdline', 'r') as cmd_fh:
            cmdline = cmd_fh.readline()
        return poed_name in cmdline
    except Exception as inst:
        print("unable to validate pidfile %s: %s" % (poed_pid_file, str(inst)))
    return False

def print_help():
    '''Write the hand-written help text to stdout.

    main() installs this as parser.print_help in place of argparse's
    generated help.  Entries for switches that main() only registers for
    root are included only when g_is_root is true.
    '''
    root_only_actions = [
        '  -p, --priority PORT_LIST PRIORITY    Set priority for the specified ports:',
        '                                         low, high, critical',
        '  -d, --disable-ports PORT_LIST        Disable POE operation on the specified ports.',
        '  -e, --enable-ports PORT_LIST         Enable POE operation on the specified ports.',
        '  -r, --reset-ports PORT_LIST          Perform hardware reset on the specified ports.',
    ]
    root_only_config = [
        '  --save                               Save the current configuration. The saved configuration',
        '                                         is automatically loaded on system boot.',
        '  --load                               Load and apply the saved configuration.',
    ]
    text = [
        poe_banner_str,
        '',
        'optional arguments:',
        '  -h, --help                           Show this help message and exit',
        '  -i, --port-info PORT_LIST            Return detailed information for the specified ports.',
        '                                         eg: -i swp1-swp5,swp10',
    ]
    if g_is_root:
        text.extend(root_only_actions)
    text.extend([
        '  -s, --system                         Return POE status for the entire switch',
        '  -j, --json                           Return output in json format',
        '  -v, --version                        Display version info',
    ])
    if g_is_root:
        text.extend(root_only_config)
    for line in text:
        print(line)

def format_usage():
    '''Return the usage summary that replaces argparse's format_usage.

    The -p operands are shown as PORT_LIST PRIORITY.  That is the order
    main() consumes them (args.priority[0] is the port list, [1] the
    level) and the order the help text documents.  Root-only switches
    appear only when g_is_root is true.
    '''
    prog = os.path.basename(sys.argv[0])
    if g_is_root:
        return 'usage: %s [-i PORT_LIST] [-e PORT_LIST] [-d PORT_LIST] [-r PORT_LIST]\n' \
               '                [-p PORT_LIST PRIORITY] [--save] [--load] [-s] [-j] [-v]\n' % (prog)
    return 'usage: %s [-i PORT_LIST] [-s] [-j] [-v]\n' % (prog)

def show_version(poe, use_json):
    '''Print the tool banner; root users also get the hardware version.

    The hardware details come from poe.get_poe_version_info().  With
    use_json the whole result is dumped as JSON; otherwise only its
    'ver_str' entry is shown.
    '''
    print(poe_banner_str)
    if g_is_root:
        version_info = poe.get_poe_version_info()
        if use_json:
            print(json.dumps(version_info))
        else:
            print('Hardware version: %s' % version_info.get('ver_str'))

def show_ports(poe, ports, use_json):
    '''
    Print detailed POE status for each port in ports.

    What gets displayed by the non-root user is less than what the root-user sees.
    Root additionally gets priority, PD class, voltage, current and power,
    merged in from poe.get_port_extra_status(ports).  With use_json the
    merged per-port dicts are printed as one JSON object keyed by port.
    Any exception is caught and its text printed.
    '''
    try:
        result = poe.get_running_state()
        if not result:
            print('Error: unable to get POE status. Is poed daemon running?')
            return
        port_results = result.get('ports')
        extra_results = {}
        if g_is_root:
            extra_results = poe.get_port_extra_status(ports)
        display_results = {}
        for port_num in ports:
            # Build a fresh dict per port.  Updating the one returned by
            # get_port_extra_status() in place would mutate data owned by
            # the poe object.  Running-state keys win on clashes.
            merged = dict(extra_results.get(port_num, {}))
            merged.update(port_results.get(port_num, {}))
            display_results[port_num] = merged
        if use_json:
            print(json.dumps(display_results))
            return
        if g_is_root:
            print('Port          Status               LLDP      Priority     PD type     PD class   Voltage   Current     Power  ')
            print('-----   --------------------   -----------   --------   -----------   --------   -------   -------   ---------')
        else:
            print('Port          Status               LLDP        PD type  ')
            print('-----   --------------------   -----------   -----------')
        for port in ports:
            state = display_results.get(port)
            if g_is_root:
                out_str = '%-5s   %-20s   %-11s   %-8s   %-11s     %-4s     %5s V   %4s mA    %7s' % (
                    state.get('swp'), get_status_string(state), state.get('lldp_status'),
                    state.get('priority'), state.get('pd_type'), state.get('pd_class'),
                    state.get('voltage', 0), state.get('current', 0),
                    '%s W' % (state.get('power', 0)))
            else:
                out_str = '%-5s   %-20s   %-11s   %-11s' % (
                    state.get('swp'), get_status_string(state),
                    state.get('lldp_status'), state.get('pd_type'))
            print(out_str)
    except Exception as e:
        print(str(e))

def show_system(poe, use_json):
    '''Print the switch-wide POE summary.

    With use_json, the state from poe.get_running_state() is dumped as
    JSON unchanged.  Otherwise it prints:
      - the total, used and available power budget;
      - the ports whose hw_status is 'connected';
      - only if there are any, the ports whose status is 'fault'.
    Any exception is caught and its text printed.
    '''
    try:
        result = poe.get_running_state()
        if not result:
            print('Error: unable to get POE status. Is poed daemon running?')
            return
        ports_result = result.get('ports', {})
        if use_json:
            print(json.dumps(result))
            return
        connected_ports = []
        error_ports = []
        # Walk ports in key order so the lists are stable from run to run;
        # plain dict iteration order is arbitrary on Python 2.
        for port in sorted(ports_result):
            port_state = ports_result[port]
            if port_state.get('hw_status') == 'connected':
                connected_ports.append(port_state.get('swp'))
            if port_state.get('status') == 'fault':
                error_ports.append(port_state.get('swp'))
        consumed = float(result.get('sys_consumed_power', 0))
        limit = float(result.get('sys_power_limit', 0))
        print('System power:')
        print('  Total:     %6s W' % ('%.1f' % limit))
        print('  Used:      %6s W' % ('%.1f' % consumed))
        print('  Available: %6s W' % ('%.1f' % (limit - consumed)))
        print('Connected ports:')
        # Join names explicitly.  The old str(list).strip('[]') trick leaked
        # repr artifacts, such as a u'' prefix for unicode names on Py2.
        if connected_ports:
            port_str = ', '.join(str(name) for name in connected_ports)
        else:
            port_str = 'None'
        print('  %s' % (port_str))
        if error_ports:
            print('Error ports:')
            print('  %s' % (', '.join(str(name) for name in error_ports)))
    except Exception as e:
        print(str(e))

def enable_ports(poe, ports):
    '''Turn POE on for the given ports via poe.set_port_enable_poe().

    ports is passed through unchanged.  Any exception is printed, not raised.
    '''
    try:
        poe.set_port_enable_poe(ports)
    except Exception as failure:
        message = str(failure)
        print(message)

def disable_ports(poe, ports):
    '''Turn POE off for the given ports via poe.set_port_disable_poe().

    ports is passed through unchanged.  Any exception is printed, not raised.
    '''
    try:
        poe.set_port_disable_poe(ports)
    except Exception as failure:
        message = str(failure)
        print(message)

def reset_ports(poe, ports):
    '''Ask the POE object to hardware-reset the given ports.

    ports is passed through unchanged.  Any exception is printed, not raised.
    '''
    try:
        poe.reset_ports(ports)
    except Exception as failure:
        message = str(failure)
        print(message)

def set_priority(poe, ports, level):
    '''Apply a priority level to the given ports via poe.set_port_priority().

    level is passed through unchanged (main() obtains it from
    parse_priority()).  Any exception is printed, not raised.
    '''
    try:
        poe.set_port_priority(ports, level)
    except Exception as failure:
        message = str(failure)
        print(message)

def show_diag_info(poe, use_json):
    '''Print poe.get_poe_diag_info() as JSON.

    use_json is accepted for parity with the other show_* helpers but is
    not consulted: the output is always JSON.  Any exception is printed.
    '''
    try:
        diag = poe.get_poe_diag_info()
        encoded = json.dumps(diag)
        print(encoded)
    except Exception as failure:
        print(str(failure))

def get_status_string(status_result):
    '''Return the text shown in a port's Status column.

    This is normally the port's 'hw_status'.  For a 'fault' it is the
    port's 'error_str' instead.  Missing keys yield None.
    '''
    hw_status = status_result.get('hw_status')
    if hw_status != 'fault':
        return hw_status
    return status_result.get('error_str')

def parse_ports(port_list):
    '''Parse a PORT_LIST string into a list of zero-based port indices.

    PORT_LIST is comma-separated.  Each item is a single port ("swp5") or an
    inclusive range ("swp1-swp4").  Port names are one-based; the returned
    indices are one less (swp1 -> 0).
      - Whitespace around items is ignored.
      - Each item must match completely, so trailing text is rejected.
      - Duplicate ports are dropped, keeping first-seen order.

    Raises:
        Parse_Error: an item is malformed, a port number is 0, a range is
            reversed, or no ports result.
    '''
    parse_error = 'could not parse PORT_LIST'
    # Whole-item match for "swpN" or "swpN-swpM".  The trailing $ rejects
    # leftovers such as "swp1junk" or "swp1-swp5-swp9", which the old
    # start-anchored re.match accepted silently.
    item_re = re.compile(r'^swp(\d+)(?:-swp(\d+))?$')
    ports = []
    seen = set()
    for item in port_list.split(','):
        match = item_re.match(item.strip())
        if not match:
            raise Parse_Error(parse_error)
        start = int(match.group(1)) - 1
        if match.group(2) is None:
            end = start
        else:
            end = int(match.group(2)) - 1
        if start < 0 or end < start:
            raise Parse_Error(parse_error)
        for port in range(start, end + 1):
            # Skip repeats so a port is not listed or operated on twice.
            if port not in seen:
                seen.add(port)
                ports.append(port)
    if not ports:
        raise Parse_Error(parse_error)
    return ports

def parse_priority(priority_str):
    '''Translate a priority keyword into the matching poe_base constant.

    Matching is case-insensitive; accepted words are low, high, critical.
    Anything else raises Parse_Error, quoting the lower-cased input.
    '''
    constant_names = {
        'low': 'PORT_CFG_PRIORITY_LOW',
        'high': 'PORT_CFG_PRIORITY_HIGH',
        'critical': 'PORT_CFG_PRIORITY_CRITICAL',
    }
    keyword = priority_str.lower()
    if keyword not in constant_names:
        raise Parse_Error('could not parse priority level "%s"' % (keyword))
    # Look the constant up lazily so only the requested one is touched.
    return getattr(poe_base, constant_names[keyword])

if __name__ == '__main__':
    # Use sys.exit() rather than the exit() builtin: exit() is injected by
    # the site module for interactive use and is absent under "python -S".
    try:
        sys.exit(main())
    except KeyboardInterrupt:
        sys.stderr.write("\nInterrupted\n")
        sys.exit(1)
    except (RuntimeError, OSError, IOError) as errstr:
        # Report anticipated failures as one line instead of a traceback.
        # Examples: not root, platform detection, driver module loading.
        sys.stderr.write("%s : ERROR : %s\n" % (sys.argv[0], str(errstr)))
        sys.exit(1)
