#!/usr/share/venvs/netq-apps/bin/python
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited
#
"""
This will be called a lot so it needs to be fast. Our job is very simple:
- connect to netqd
- TX sys.argv to netqd
- RX the reply from netqd
- print the reply

netqd does all of the heavy lifting.

"""

from json import dumps as json_dumps, loads as json_loads
from network_docopt import get_bash_prompt, get_network_docopt_info
from select import select
from socket import socket as socket_socket, error as socket_error
from socket import AF_UNIX, SOCK_STREAM
from subprocess import Popen, STDOUT
from time import sleep
import os
import re
import sys
import tempfile
import getpass
from collections import OrderedDict

from netq_lib.common.utils import get_logger, subprocess_wrapper
from netq_lib.common import configuration
from netq_apps.cmd.netq import display_stream
from netq_apps.cmd.netq_config_daemonless import NetqDaemonlessCLI, is_netqd_running
from netq_apps.lib.installers.installer import ClusterNode
from netq_apps.lib.installers.master import MasterNode

# "netq" uses netqd for config and logging
_PROC_NAME = 'netqd'


def posix_get_window_size():
    """
    Return the console size as a (width, height) tuple on a POSIX system.

    The size is read from the terminal attached to stdout with the TIOCGWINSZ
    ioctl. If that raises IOError (e.g. no console is allocated, or stdout is
    redirected), (0, 0) is returned.
    """
    # POSIX-only modules, imported at call time.
    # Reference: http://www.kernel.org/doc/man-pages/online/pages/man4/tty_ioctl.4.html
    from array import array
    from fcntl import ioctl
    from termios import TIOCGWINSZ

    # struct winsize: ws_row, ws_col, ws_xpixel, ws_ypixel (unsigned shorts).
    size_buf = array("H", (0, 0, 0, 0))
    try:
        ioctl(sys.stdout.fileno(), TIOCGWINSZ, size_buf)
    except IOError:
        # e.g. "[Errno 25] Inappropriate ioctl for device" when the output is
        # redirected; size_buf then keeps its zeros.
        pass

    height, width = size_buf[0], size_buf[1]
    return (width, height)


def rx_reply(sock, argv):
    """
    RX the reply from netqd, display it, close the socket and exit.

    netqd replies with line-terminated JSON messages, handled as follows:
    - A message with a non-zero 'rc' aborts with its 'output'.
    - A message whose 'data' is a dict starts a new stream for
      data['module'] and becomes that stream's header entry.
    - A message whose 'data' is a list appends its elements to the current
      stream, which is then displayed.
    - Any message without 'data' is handed to display_reply() as raw output.

    This function never returns. It calls sys.exit() with the last non-zero
    rc reported by display_reply() (0 otherwise). It exits with 1 on a
    receive timeout, an unreachable cluster master, a parse error, an error
    message from netqd, or Ctrl-C.

    :param sock: socket already connected to netqd
    :param argv: the command line sent to netqd. It selects the receive
                 timeout, whether --completions output gets the final
                 display pass, and whether the netq-apps upgrade step runs
                 afterwards.
    """
    import codecs

    if 'bootstrap' in argv:
        timeout = 60 * 30  # 30 minutes for bootstrapping node for clustering
    elif 'install' in argv or 'upgrade' in argv:
        timeout = 60 * 120  # 120 minutes for install or upgrade
    else:
        timeout = 60 * 15  # default timeout for all other commands
    envelope = ''
    sock.setblocking(0)
    stream_data = OrderedDict()
    print_data = OrderedDict()
    header_dict = dict()
    header_module = 'default'
    # Received text that has not been line-terminated yet.
    line = ''
    # recv() hands back arbitrary byte chunks, which can end in the middle of a
    # multi-byte UTF-8 character. An incremental decoder holds such a partial
    # sequence until the rest arrives, instead of raising UnicodeDecodeError.
    decoder = codecs.getincrementaldecoder('utf-8')()
    first = True
    last = False
    final_rc = 0
    streaming = False
    printing = False
    is_bootstrapped = ClusterNode.is_bootstrapped()
    is_install = 'install' in argv or 'upgrade' in argv

    while True:
        try:
            if is_bootstrapped and is_install:
                # Install/upgrade on a bootstrapped node: poll every minute and
                # keep waiting for as long as the master answers a ping.
                timeout = 60
                master_ip = ''
                with open(ClusterNode.MASTER_IP_FILE, 'r') as fd:
                    master_ip = fd.read()
                ready = select([sock], [], [], timeout)
                if not ready[0]:
                    if MasterNode.ping(master_ip):
                        continue
                    else:
                        print('Error: web socket connection broken to master {}'.format(master_ip))
                        sock.close()
                        sys.exit(1)

            else:
                ready = select([sock], [], [], timeout)
                if not ready[0]:
                    print('Error: did not receive a response within %d seconds' %
                          timeout)
                    sock.close()
                    sys.exit(1)
        except KeyboardInterrupt:
            # Tell netqd to stop the command before giving up.
            sock.send('SIGKILL'.encode())
            sock.close()
            sys.exit(1)

        rx_data = sock.recv(4096)
        if not rx_data:
            # netqd closed the connection, so the reply is complete.
            # NOTE(review): unterminated trailing text left in `line` is not
            # displayed. This assumes netqd terminates every message.
            last = True
            break
        line += decoder.decode(rx_data)
        # Split only on the terminators a message may end with. The last piece
        # is the unterminated remainder ('' if the data ended with a
        # terminator) and is kept for the next recv(). Every earlier piece is
        # a complete message and must be processed now, even when this chunk
        # ended part-way through the following message.
        use_lines = re.split(r'\r\n|\r|\n', line)
        line = use_lines.pop()
        if not use_lines:
            continue

        for use_line in use_lines:
            if len(use_line) < 2:
                continue  # too short to be a JSON message
            try:
                json_data = json_loads(use_line)
                if 'rc' in json_data and json_data['rc'] != 0:
                    # netqd reported an error: show it and fail.
                    if 'output' in json_data:
                        print(json_data['output'])
                    sock.close()
                    sys.exit(1)
                # Reset per message: only this message's content is displayed.
                stream_data = OrderedDict()
                print_data = OrderedDict()
                envelope = ''
                if 'data' in json_data:
                    if isinstance(json_data['data'], dict):
                        # Header of a new stream. It is remembered so that later
                        # batches for this module are prefixed with it.
                        header_module = json_data['data']['module']
                        stream_data[header_module] = []
                        print_data[header_module] = []
                        header_dict = json_data['data']
                        stream_data[header_module].append(header_dict)
                        first = True
                    else:
                        # A batch of entries for the current stream.
                        for ele in json_data['data']:
                            if not stream_data:
                                stream_data[header_module] = []
                                print_data[header_module] = []
                                stream_data[header_module].append(header_dict)
                            stream_data[header_module].append(ele)
                        rc = display_reply(stream_data, print_data, envelope,
                                           first, last)
                        final_rc = rc if rc else final_rc
                        first = False
                        streaming = True
                else:
                    # Plain output message; display_reply() decodes it.
                    if not print_data:
                        print_data[header_module] = []
                    if not stream_data:
                        stream_data[header_module] = []
                    print_data[header_module].append(use_line)
                    rc = display_reply(stream_data, print_data, envelope, first,
                                       last)
                    final_rc = rc if rc else final_rc
                    if "output" in json_data and json_data["output"]:
                        printing = True
            except ValueError:
                print('Error: Parsing error in data received')
                sock.close()
                sys.exit(1)

    # Final display pass. It runs after streamed output (except for
    # --completions requests) and when the reply produced no output at all.
    blacklist = ["--completions"]
    if ((not any(ele in blacklist for ele in argv) and streaming) or
            (not streaming and not printing and first and last)):
        if not print_data:
            print_data[header_module] = []
        if not stream_data:
            stream_data[header_module] = []
        stream_data[header_module].append(header_dict)
        rc = display_reply(stream_data, print_data, envelope, first, last)
        final_rc = rc if rc else final_rc

    # After a non-bootstrap install/upgrade, the upgrade file may name a
    # netq-apps .deb (presumably left there by netqd). If so, hand the file to
    # _dpkg_netq_apps and then empty it. Failures are reported but not fatal.
    upgrade_file = '/tmp/netqd_needs_upgrade'
    if ('bootstrap' not in argv) and ('install' in argv or 'upgrade' in argv):
        if os.path.exists(upgrade_file):
            try:
                with open(upgrade_file, 'r') as file:
                    pkg = file.read()
                    pkg = pkg.strip()
                    if 'netq-apps' in pkg and pkg.endswith('deb'):
                        subprocess_wrapper(['/usr/bin/_dpkg_netq_apps', upgrade_file])
                with open(upgrade_file, 'w') as file:
                    file.truncate(0)
            except Exception as exc:
                print('Upgrading netq-apps failed')
                # `logger` is the module-level global assigned under __main__.
                logger.info('Upgrading netq-apps failed: {}'.format(exc))

    sock.close()
    sys.exit(final_rc)


def get_keywords(reply_with_help_text):
    """
    Pull the completion keywords out of a reply that carries help text.

    Each non-blank line (after stripping surrounding whitespace) yields one
    keyword. That keyword is the text before the first ':' when the line has
    one, otherwise the whole stripped line.
    """
    stripped_lines = (raw.strip() for raw in reply_with_help_text.splitlines())
    return [text.partition(':')[0] for text in stripped_lines if text]


def display_reply(stream_data, print_data, envelope, first_entry, last_entry):
    """
    Render one batch of netqd output.

    For every key in ``stream_data``:
    - The raw lines in ``print_data[key]`` are joined. If they parse as a
      JSON object, its 'output' text is shown and its 'rc' remembered;
      otherwise the raw text is shown.
    - Tab-completion replies (marked with 'USE_STDERR') are written as
      completion candidates.
    - Long plain-text replies go through $PAGER when stdout is a terminal
      they do not fit in.
    - Any entries in ``stream_data[key]`` are passed to display_stream().

    :param stream_data: mapping of module name -> list of entries to stream
                        (built by rx_reply)
    :param print_data: mapping of module name -> list of raw reply lines;
                       must contain every key of ``stream_data``
    :param envelope: if non-empty, the JSON key that wraps all keys' streamed
                     output as '{"<envelope>": [ ... ]}'
    :param first_entry: forwarded to display_stream()
    :param last_entry: forwarded to display_stream()
    :returns: the 'rc' of the last JSON reply parsed (0 if none), or None when
              ``print_data`` is empty
    """
    if not print_data:
        return                  # Should never happen

    rc = 0
    print_trailer = False
    # True only for the first key, where the envelope header (if any) opens.
    first = True
    for key in stream_data:
        # This is to print the "," between individual modules when we're
        # dumping JSON and there's an envelope enclosing the individual key
        # outputs. This piece is here to avoid printing a trailing "," for
        # the last key.
        if print_trailer:
            print(',')
        output = ''.join(print_data[key])
        try:
            json_out = json_loads(output)
            rc = json_out['rc']
            reply = json_out['output']
        except ValueError:
            reply = output

        pager = os.getenv('PAGER')
        is_json = False
        tab_options = 'USE_STDERR' in reply

        # Determine if the reply is json.
        if reply.startswith('{'):
            try:
                json_loads(reply)
                is_json = True
            except ValueError:
                pass  # Not json

        # Page only when all of these hold:
        # - a pager is configured;
        # - the reply is plain text (not json, not tab-completion options);
        # - stdout is a terminal;
        # - the reply does not fit in that terminal.
        # posix_get_window_size() reports 0 rows when no terminal is attached,
        # in which case paging is pointless and the reply is printed as-is.
        use_pager = False
        if pager and not tab_options and not is_json:
            rows = posix_get_window_size()[1]
            use_pager = rows > 0 and len(reply.splitlines()) >= rows

        if not use_pager:
            if tab_options:
                reply = reply.replace('USE_STDERR', '')
                stdout_env = os.getenv('NETQ_TAB_STDOUT')

                stdout_env = ((stdout_env is not None) and
                              (stdout_env == '1' or
                               stdout_env.lower() == 'true'))

                if stdout_env:
                    try:
                        sys.stdout.write(' '.join(get_keywords(reply)))
                    except IOError as e:
                        sys.stderr.write('Failed to write output due to {0}\n'.
                                         format(e))

                else:
                    # Printing TAB complete options with help text.
                    cmd = os.path.basename(sys.argv[0])  # get exec name

                    sys.stderr.write(reply)
                    argv = sys.argv[1:-2]

                    if argv:
                        ended_with_space = get_network_docopt_info(sys.argv)[1]
                        sys.stderr.write('\n{0}{1} {2}{3}'.format(
                            get_bash_prompt(), cmd, ' '.join(argv),
                            ' ' if ended_with_space else ''))
                    else:
                        sys.stderr.write('\n{0}{1} '.format(get_bash_prompt(),
                                                            cmd))

            elif reply:
                # If the user pipes the output to something like "head -n 10"
                # and there are more than 10 lines of output, we get a
                # "Broken pipe" error because STDOUT was closed. It is safe to
                # silently ignore this error.
                try:
                    sys.stdout.write(reply + '\n')
                except IOError:
                    pass

        else:
            # The reply is taller than the screen. Write it to a temp file and
            # point the pager (usually 'less') at it.
            # - mkstemp gives each concurrent netq invocation its own file.
            # - mkstemp returns an OS-level file descriptor, not a file object,
            #   so it is wrapped with os.fdopen().
            # - The file is removed even if writing or paging fails.
            fd, filename = tempfile.mkstemp()
            try:
                with os.fdopen(fd, 'w') as f_tmp:
                    f_tmp.write(reply)
                Popen([pager, filename], stdout=None, stdin=None).wait()
            finally:
                os.remove(filename)

        if stream_data[key]:
            try:
                if first and envelope:
                    print('{{\n "{}": ['.format(envelope))
                    print_trailer = True
                display_stream(stream_data[key], first_entry, last_entry)
            except (IOError, KeyboardInterrupt):
                # e.g. output pipe closed or user interrupt: stop quietly.
                sys.exit(rc)
        first = False

    if print_trailer:
        print(']\n}')
    return rc


if __name__ == '__main__':
    conf_dict = configuration.get_config(_PROC_NAME)
    # Module-level on purpose: rx_reply() logs through this global.
    logger = get_logger(_PROC_NAME, conf_dict)

    if is_netqd_running():

        # Connect to netqd's unix domain socket. Right after netqd starts, a
        # netq command can fail because netqd is still loading its parser
        # (currently 6-7 seconds), so retry with a growing back-off.
        max_retries = 4
        for i in range(max_retries):
            # A socket whose connect() failed is in an unspecified state
            # (POSIX), so every attempt uses a fresh one.
            sock = socket_socket(AF_UNIX, SOCK_STREAM)
            try:
                sock.connect('/var/run/netqd/uds')
                break
            except socket_error:
                sock.close()
                if i + 1 != max_retries:  # skip sleep in last attempt
                    sleep(i + 1)  # 1s, 2s, 3s
        else:
            sys.stderr.write("""
ERROR: netq could not connect to netqd

Try starting netqd with:
sudo systemctl start netqd

To configure netqd to start when the box boots:
sudo systemctl enable netqd
""")
            sys.exit(1)

        # If the user enters something crazy like "netq show '' config" we must
        # remove the '', it confuses network-docopt
        sys.argv = [x for x in sys.argv if x != '']
        # Completion requests fire on every TAB press; keep them at debug level.
        if '--completions' in sys.argv:
            logger.debug("user: %s command: %s", getpass.getuser(), ' '.join(sys.argv))
        else:
            logger.info("user: %s command: %s", getpass.getuser(), ' '.join(sys.argv))

        # TX our command line args, RX the reply, and display the reply.
        # sendall() keeps sending until every byte is written; send() may
        # write only part of the buffer.
        sock.sendall(json_dumps(sys.argv).encode('utf-8'))
        rx_reply(sock, sys.argv)
    else:
        nd = NetqDaemonlessCLI(False, logger)
        (print_options, ended_with_space, sys.argv) = get_network_docopt_info(sys.argv)
        if print_options:
            nd.print_options(ended_with_space, 'bash', sys.argv)
        else:
            nd.run(sys.argv)
            # Intentionally checked again, so that we don't print this message
            # again if restart cli is issued.
            if sys.argv != ['netq', 'config', 'restart', 'cli']:
                print("\033[93mWARN: NetQ Daemon is not running\033[0m\n\n" +
                      "\033[90mPlease configure netq cli:\n"
                      "sudo netq config add cli (config arguments)\n\n"
                      "and, start netqd:\n"
                      "sudo netq config restart cli\033[0m")
        sys.exit(0)
