#!/usr/bin/env python
#
# Copyright (c) 2017-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited
#

import argparse
import errno
import sys

from python_sdk_api.sx_api import *

def get_handle():
    """Open an SDK API session and return its handle.

    The handle and status code returned by ``sx_api_open`` are always echoed
    to stdout. When the status is not ``SX_STATUS_SUCCESS`` an error message
    is written to stderr and the process exits with ``errno.EACCES``.
    """
    status, api_handle = sx_api_open(None)
    print("sx_api_open handle:0x%x , rc %d " % (api_handle, status))
    if status == SX_STATUS_SUCCESS:
        return api_handle

    # No usable handle: report the problem and terminate the process.
    sys.stderr.write("Failed to open api handle.\nPlease check that SDK is running.")
    sys.exit(errno.EACCES)


def get_all_ports():
    """Print the logical port ids reported by the SDK.

    Writes a blank separator to stdout followed by one line holding the
    ``log_port`` of every returned port, formatted as ``0x<hex>`` and joined
    with commas.

    Raises:
        Exception: if ``sx_api_port_device_get`` does not return
            ``SX_STATUS_SUCCESS``.
    """
    port_num_max = 128  # capacity of the attribute array handed to the SDK
    port_attributes_list = new_sx_port_attributes_t_arr(port_num_max)
    port_cnt_p = new_uint32_t_p()
    uint32_t_p_assign(port_cnt_p, port_num_max)
    handle = get_handle()

    try:
        # NOTE(review): the literals 1 and 0 look like the device id and the
        # switch id -- confirm against the SDK documentation.
        rc = sx_api_port_device_get(handle, 1, 0, port_attributes_list, port_cnt_p)
    finally:
        # Release the API handle even if the SDK call raises.
        sx_api_close(handle)
    if rc != SX_STATUS_SUCCESS:
        raise Exception('Failed to get ports from SDK')

    # After the call the counter holds the number of entries to read back.
    port_cnt = uint32_t_p_value(port_cnt_p)
    ports = [
        "0x%x" % sx_port_attributes_t_arr_getitem(port_attributes_list, i).log_port
        for i in range(port_cnt)
    ]
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print("\n")
    print(",".join(ports))


def get_chip_type():
    """Print the ``dev_type`` of the first device reported by the SDK.

    Writes a blank separator to stdout followed by the device type value.

    Raises:
        Exception: if ``sx_api_port_device_list_get`` does not return
            ``SX_STATUS_SUCCESS`` or reports no devices.
    """
    device_info_cnt = 1  # only the first device entry is requested
    device_info_cnt_p = new_uint32_t_p()
    uint32_t_p_assign(device_info_cnt_p, device_info_cnt)
    device_info_list_p = new_sx_device_info_t_arr(device_info_cnt)
    handle = get_handle()

    try:
        rc = sx_api_port_device_list_get(handle,
                                         device_info_list_p,
                                         device_info_cnt_p)
    finally:
        # Release the API handle even if the SDK call raises.
        sx_api_close(handle)
    if rc != SX_STATUS_SUCCESS:
        raise Exception("sx_api_port_device_list_get failed, rc = %d" % (rc))

    # Presumably the SDK writes the number of returned devices back into the
    # counter (TODO confirm); avoid reading an entry that was never filled in.
    if uint32_t_p_value(device_info_cnt_p) < 1:
        raise Exception("sx_api_port_device_list_get returned no devices")

    device_info = sx_device_info_t_arr_getitem(device_info_list_p, 0)
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print("\n")
    print(device_info.dev_type)

def get_ecmp_count():
    """Print the ECMP object count reported by the SDK.

    Writes a blank separator to stdout followed by the count.

    Raises:
        Exception: if ``sx_api_router_ecmp_iter_get`` does not return
            ``SX_STATUS_SUCCESS``.
    """
    ecmp_cnt_p = new_uint32_t_p()
    uint32_t_p_assign(ecmp_cnt_p, 0)
    handle = get_handle()

    try:
        # NOTE(review): passing a count of 0 with no key/filter/list presumably
        # asks the SDK only for the total number of objects -- confirm.
        rc = sx_api_router_ecmp_iter_get(handle, SX_ACCESS_CMD_GET, 0, None, None, ecmp_cnt_p)
    finally:
        # Release the API handle even if the SDK call raises.
        sx_api_close(handle)
    if rc != SX_STATUS_SUCCESS:
        # Name the API that actually failed (the message previously named
        # sx_api_port_device_list_get).
        raise Exception("sx_api_router_ecmp_iter_get failed, rc = %d" % (rc))

    ecmp_cnt = uint32_t_p_value(ecmp_cnt_p)
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print("\n")
    print(ecmp_cnt)

def setup_arg_parser():
    """Build the command-line parser for this tool.

    Returns:
        An ``argparse.ArgumentParser`` exposing three boolean switches, each
        defaulting to ``False``: ``--get-logical-ports``, ``--get-chip-type``
        and ``--get-ecmp-count``.
    """
    switches = (
        ('--get-logical-ports', 'Get all logical port ids (lids)'),
        ('--get-chip-type', 'Get chip type'),
        ('--get-ecmp-count', 'Get ecmp object count'),
    )
    parser = argparse.ArgumentParser()
    for flag, description in switches:
        parser.add_argument(flag, action='store_true', default=False, help=description)
    return parser


def main(args):
    """Parse *args* and run the first requested query.

    Args:
        args: command-line arguments without the program name.

    Returns:
        0 in every case, including when no option was given; failures surface
        as exceptions (or a process exit) from the query helpers.
    """
    options = setup_arg_parser().parse_args(args)

    # Ordered by precedence: when several switches are set only the first runs.
    queries = (
        ('get_logical_ports', get_all_ports),
        ('get_chip_type', get_chip_type),
        ('get_ecmp_count', get_ecmp_count),
    )
    for option_name, run_query in queries:
        if getattr(options, option_name):
            run_query()
            break
    return 0

if __name__ == '__main__':
    # Script entry point: exit with main()'s return value, or with 1 after
    # writing the error text to stderr when main() raises an Exception.
    try:
        exit_code = main(sys.argv[1:])
    except Exception as err:
        sys.stderr.write(str(err))
        exit_code = 1
    sys.exit(exit_code)
