#!/usr/share/venvs/netq-apps/bin/python
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited
#



import subprocess
import atexit
import argparse
import ujson as json
import os
import socket
import sys
import signal
from io import StringIO
import struct
from daemon import DaemonContext
from socketserver import ForkingMixIn, UnixStreamServer, BaseRequestHandler
from netq_lib.common import configuration
from network_docopt import (
    get_network_docopt_info
)

from netq_lib.common import utils
from netq_lib.common.enums import (
    ConfKw,
    BgColors,
    DefaultPaths
)
from netq_apps.cmd.netq import NetqCLI
from netq_apps.lib.netq import NetQException
from netq_lib.common import netq_config
from netq_lib.common.auth import GatewayAuth

# Placeholder module-level JSON value (not read in this file).
_JSON = None
# Literal JSON payload: a list holding one empty object (not read in this file).
_EMPTY_JSON = '[{}]'
# Process return codes: success / failure.
_RC_SUCCESS, _RC_FAIL = 0, 1
# Process name; used for the logger, pidfile path, config lookup and argparse prog.
_PROC_NAME = 'netqd'


class Capturing(list):
    """Context manager that collects everything printed to ``sys.stdout``.

    While the ``with`` block runs, ``sys.stdout`` is redirected to an
    in-memory buffer. On exit the buffer is split into lines, the lines are
    appended to this list, and the previous ``sys.stdout`` is restored.
    Exceptions raised in the block are not suppressed.
    """

    def __enter__(self):
        saved, buffer = sys.stdout, StringIO()
        self._stdout = saved
        self._stringio = buffer
        sys.stdout = buffer
        return self

    def __exit__(self, *args):
        captured_text = self._stringio.getvalue()
        self.extend(captured_text.splitlines())
        sys.stdout = self._stdout


class NetQRequestHandler(BaseRequestHandler):
    '''Handle each command issued by a front end CLI.

    One handler instance serves one connection accepted by NetQDaemonServer.
    The client sends a JSON-encoded argv list; the handler either prints
    completion options or runs the command through ``self.server.cli`` and
    replies with a newline-terminated JSON object
    ``{'rc': <return code>, 'output': <captured stdout>}``.
    '''

    def handle(self):
        """Receive one command, run it, and send the JSON reply."""
        # RX the request from the 'net' instance.
        self._logger = self.server._logger
        connection = self.request
        cli = self.server.cli
        # NOTE(review): a single recv() assumes the whole request arrives in
        # one read of at most 4096 bytes; a longer or fragmented request will
        # fail to parse below and be dropped.
        data = connection.recv(4096)
        if not data:
            return
        try:
            argv = json.loads(data)
        except ValueError as ex:
            self._logger.exception(ex)
            return

        # Only the peer's uid/gid are forwarded to the CLI.
        _peer_pid, uid, gid = self.get_pid_uid_gid(connection)
        self._logger.debug("RXed: command '%s'", ' '.join(argv))

        # bash model
        if '--completions' in argv:
            # the last 2 args are related to completions
            sys.argv = argv[:-2]
            (print_options, ended_with_space, argv) = (
                get_network_docopt_info(argv))
            options_style = 'bash'

        # fish model
        elif '--fish-completion' in argv:
            # For "foo show <TAB>" argv will be
            # ['foo', '--fish-completion', 'foo show ']
            # the last 2 args are related to completions
            sys.argv = argv[:-2]
            last_arg = argv[-1]
            ended_with_space = last_arg.endswith(' ')
            argv = last_arg.strip().split()
            print_options = True
            options_style = 'fish'
        else:
            print_options = False

        # Fork a watcher that blocks reading the connection while the command
        # runs. If the front end sends 'SIGKILL' (user interrupt), the watcher
        # removes the installer lock and SIGKILLs this handler process.
        # Closing the socket from the front end unblocks recv() when no
        # message is received, and the watcher simply exits.
        watcher_pid = os.fork()
        if watcher_pid == 0:  # child process: the watcher
            try:
                if connection.recv(4096).decode() == 'SIGKILL':
                    self._logger.warning('Received KeyboardInterrupt')
                    if os.path.exists(DefaultPaths.INSTALLER_LOCK):
                        os.remove(DefaultPaths.INSTALLER_LOCK)
                    os.kill(os.getppid(), signal.SIGKILL)
            except Exception:
                # Best effort only: the watcher must always reach os._exit().
                pass
            finally:
                # os._exit() skips atexit handlers and other cleanup
                # inherited from the parent process.
                os._exit(os.EX_OK)

        # parent process: run the command with stdout captured into `reply`.
        rc = 0
        with Capturing() as reply:
            if print_options:
                cli.print_options(ended_with_space, options_style, argv)
            else:
                # This is where the magic happens
                # (the command gets executed).
                try:
                    rc = cli.run(argv, uid, gid, connection)
                except (Exception, NetQException, RuntimeError) as ex:
                    self._logger.error(ex)
                    # Guard the index: an empty argv must still produce an
                    # error reply instead of raising IndexError here.
                    if argv and argv[-1] == 'json':
                        print(json.dumps([{'Error': '{}'.format(ex)}], indent=4))
                    else:
                        print(utils.color_wrap(BgColors.RED, str(ex)))
                    rc = 1
        output = {'rc': rc, 'output': '\n'.join(reply)}
        self.tx_reply(connection, json.dumps(output))

    def get_pid_uid_gid(self, uds):
        """
        Obtain the pid, effective user and group IDs of the process on the
        other end of a socket. SO_PEERCRED is used so the information
        returned is trustworthy. (We are OK with root saying they are
        somebody less privileged.)

        :param uds: connected socket
        :return: (pid, uid, gid); each element is None when unavailable,
                 including when the socket is not AF_UNIX or the lookup fails.
        """
        pid = None
        uid = None
        gid = None

        try:
            # rely on pid, uid, gid being None if family is not AF_UNIX
            if uds.family != socket.AF_UNIX:
                return (pid, uid, gid)
        except AttributeError:
            pass

        try:
            credentials = uds.getsockopt(socket.SOL_SOCKET,
                                         self.server.SO_PEERCRED,
                                         struct.calcsize('3i'))
            pid, uid, gid = struct.unpack('3i', credentials)

            # -1 marks a field the kernel could not fill in.
            if pid == -1:
                pid = None

            if uid == -1:
                uid = None

            if gid == -1:
                gid = None

        except Exception as e:
            self._logger.error("socket get credentials failed: {0}".format(e))

        return (pid, uid, gid)

    def tx_reply(self, connection, reply):
        """Send *reply* (None means empty) as UTF-8, newline-terminated.

        Send failures are logged and otherwise ignored.
        """
        if reply is None:
            reply = ''

        if not reply.endswith('\n'):
            reply += '\n'

        # TX the reply to the 'net' instance
        try:
            connection.sendall(reply.encode('utf-8'))
        except socket.error:
            self._logger.info("TX reply failed")


class NetQDaemonServer(ForkingMixIn, UnixStreamServer):
    """Unix-domain stream server that forks a child per connection.

    Exposes ``cli``, ``_logger`` and ``SO_PEERCRED`` for the request handler
    (NetQRequestHandler reads all three via ``self.server``).
    """

    def __init__(self, local_file, handler, netq_cli, logger):
        """Bind *handler* to the Unix socket at *local_file*.

        :param local_file: filesystem path of the socket; a stale entry at
                           that path is removed first
        :param handler: request handler class
        :param netq_cli: CLI object handed to each request via ``self.cli``
        :param logger: logger handed to each request via ``self._logger``
        """
        self.cli = netq_cli
        self.daemon = True
        self._logger = logger
        # Remove a socket left over from a previous run; bind() fails if the
        # path still exists.
        try:
            os.remove(local_file)
        except FileNotFoundError:
            pass
        UnixStreamServer.__init__(self, local_file, handler)

        # The numeric values are architecture-dependent; prefer the socket
        # module's constants and fall back to the common Linux values.
        self.SO_PASSCRED = getattr(socket, 'SO_PASSCRED', 16)
        self.SO_PEERCRED = getattr(socket, 'SO_PEERCRED', 17)

        self.socket.setsockopt(socket.SOL_SOCKET, self.SO_PASSCRED, 1)
        # Any local user may connect; callers are identified per connection
        # through SO_PEERCRED rather than by socket file permissions.
        os.chmod(local_file, 0o777)


class NetqDaemon(object):
    """The netqd process: builds the CLI object, serves it over a Unix socket
    and provides the signal handlers installed by the ``__main__`` block."""

    def __init__(self, conf):
        """Set up logging, gateway auth and the CLI from *conf*.

        Failures while building auth/CLI are logged and printed rather than
        raised so the process can still start; ``main()`` calls
        ``init_cli()`` again after the socket server is created.
        """
        # NOTE(review): the *_with_edit / *_with_show maps are not read in
        # this file — confirm whether anything still uses them.
        self.users_with_edit = {}
        self.groups_with_edit = {}
        self.users_with_show = {}
        self.groups_with_show = {}
        self.daemon = True
        self.config = conf
        self._logger = utils.get_logger(_PROC_NAME, conf)
        self.uds = None
        self.cli = None
        self.auth = None

        try:
            self.init_auth()
            self.init_cli()
        except Exception as ex:
            # Same message text as before, now logged at error level with
            # the traceback attached.
            self._logger.exception('An exception of type %s occured', type(ex))
            print('An exception of type %s occured %s' % (type(ex), ex))

    @staticmethod
    def rm_token():
        """Delete the cached token file /tmp/token.aes if it exists."""
        try:
            os.remove('/tmp/token.aes')
        except FileNotFoundError:
            pass

    def sigchld_handler(self, signal, frame):
        """SIGCHLD: reap one finished child without blocking.

        subprocess.check_output() always returns 0 if SIGCHLD is set to
        SIG_IGN. Since the parent has many children and since waitpid()
        returns the PID and the exit status of any child that finishes,
        there is no guarantee that the wait() call in subprocess.py will
        handle the child process before waitpid() in netqd does, hence raise
        a CalledProcessError here or else subprocess calls will silently
        fail.
        """
        low_byte_mask = 255
        high_byte_shift = 8
        any_child = -1
        null_signal = 0

        # prevents zombie processes
        try:
            pid, status = os.waitpid(any_child, os.WNOHANG)
        except ChildProcessError:
            # No child left to reap (e.g. it was already waited on
            # elsewhere). Raising here would surface in whatever frame the
            # signal interrupted, so there is nothing to do.
            return
        if pid == 0:
            # Children exist but none has exited yet.
            return

        return_code, signal_num = status >> high_byte_shift, status & low_byte_mask
        # NOTE(review): for a terminated child the wait status carries either
        # an exit code (high byte, low byte 0) or a signal (low byte, high
        # byte 0), so this combined condition looks unreachable — confirm
        # the intended check before changing it.
        if signal_num != null_signal and return_code > 0:
            output = 'netqd sigchld_handler() received a non-zero exit status'
            raise subprocess.CalledProcessError(returncode=return_code, cmd=['netq'], output=output)

    def signal_handler(self, signal, frame):
        """SIGINT/SIGTERM: stop serve_forever() and remove the installer lock.

        serve_forever() runs in this same thread, so BaseServer.shutdown()
        would deadlock here; the server's private shutdown flag is set
        directly instead and the serve loop exits on its next iteration.
        """
        # Defensive: only log if a logger is present.
        if self._logger:
            self._logger.info("received SIGINT or SIGTERM")
            self._logger.info('netqd shutdown initiated')
        if self.uds:
            self.uds._BaseServer__shutdown_request = True
        try:
            if os.path.exists(DefaultPaths.INSTALLER_LOCK):
                os.remove(DefaultPaths.INSTALLER_LOCK)
        except OSError:
            pass

    def reload_handler(self, signal, frame):
        """SIGHUP handler."""
        # NOTE(review): init_cli() returns early once self.cli is set, so as
        # written SIGHUP only logs and does not rebuild the CLI — confirm
        # whether a real reload is intended.
        self._logger.info("received SIGHUP")
        self.init_cli()

    def init_auth(self):
        """Create the GatewayAuth from the NETQ_CLI section's server/port."""
        netq_cli = self.config.get(ConfKw.NETQ_CLI)
        self.auth = GatewayAuth(logger=self._logger,
                                gateway_ip=netq_cli.get(ConfKw.SERVER),
                                gateway_port=netq_cli.get(ConfKw.PORT))

    def init_cli(self):
        """Create the CLI object once and hand it to the server if running."""
        if self.cli:
            return
        self.cli = NetqCLI(self.daemon, self._logger, self.auth)
        if self.uds:
            self.uds.cli = self.cli

    def __delpid(self):
        """ Removes the pidfile and the cached token when the process exits
        gracefully.

        Each removal is attempted on its own so a missing pidfile does not
        leave the token behind; errors are ignored because the process is
        shutting down anyway.
        """
        try:
            os.remove('/var/run/%s.pid' % _PROC_NAME)
        except OSError:
            pass
        try:
            NetqDaemon.rm_token()
        except OSError:
            pass

    def main(self, daemon, conf):
        """Lock the pidfile, start the Unix-socket server and serve until shutdown.

        :param daemon: stored on ``self.daemon`` before serving starts
        :param conf: not used here
        Exits the process with status 1 if setting up or running the server
        raises; returns normally after a signal-initiated shutdown.
        """
        self.__pidfd = utils.Pidfile('/var/run/%s.pid' % _PROC_NAME, _PROC_NAME)
        self.uds = None
        try:
            # Log the process ID upon start-up.
            try:
                self.netqd_pid = os.getpid()
                # Make sure that I am not already running.
                self.__pidfd.lock()
                self.__pidfd.write(self.netqd_pid)
                atexit.register(self.__delpid)
                self._logger.info('Starting (pid %d) ...', self.netqd_pid)
                atexit.register(self._logger.info, 'Terminating (pid %d)' % self.netqd_pid)
            except OSError as e:
                # NOTE(review): start-up continues even if the pidfile cannot
                # be locked or written; a second instance would then replace
                # the first one's socket file. Confirm this is intended.
                self._logger.exception(e)
                self.netqd_pid = 'Error'

            self.server_address = '/var/run/netqd/uds'

            self.uds = NetQDaemonServer(self.server_address,
                                        NetQRequestHandler,
                                        self.cli, self._logger)
            self._logger.info('netqd started with process ID (pid) {0}.'
                              ''.format(self.netqd_pid))

            self.init_cli()
            self.daemon = daemon
            self.uds.serve_forever()
            self.uds.server_close()

        except Exception as e:
            self._logger.exception(e)
            if self.uds:
                self.uds.server_close()
                self.uds = None
            sys.exit(1)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(prog=_PROC_NAME,
                                     description='Netq CLI Daemon')
    parser.add_argument('-s', '--server',
                        help='Add Backend server name or IP to the configuration file')
    parser.add_argument('-d', '--daemon', help='run as a daemon',
                        action='store_true')

    args = parser.parse_args()
    conf_dict = configuration.get_config(_PROC_NAME)
    if args.server:
        # One-shot mode: record the backend server, then exit with whatever
        # save_backend_servers() returns instead of starting the daemon.
        servers = netq_config.load_backend_servers()
        servers = netq_config.add_backend_server(servers, args.server)
        sys.exit(netq_config.save_backend_servers(servers))

    netqd = NetqDaemon(conf_dict)
    NetqDaemon.rm_token()

    # Directory that holds the daemon's Unix socket.
    os.makedirs('/var/run/netqd/', mode=0o755, exist_ok=True)

    if args.daemon:
        # NOTE(review): DaemonContext closes inherited file descriptors not
        # listed in files_preserve; confirm the logger created in
        # NetqDaemon.__init__ still has a usable handler after daemonizing.
        context = DaemonContext(
            working_directory='/var/run/netqd/',
            umask=0o022,
            signal_map={
                signal.SIGTERM: netqd.signal_handler,
                signal.SIGINT: netqd.signal_handler,
                signal.SIGHUP: netqd.reload_handler,
                signal.SIGCHLD: netqd.sigchld_handler,
            }
        )

        # Entering the context calls DaemonContext.open(), so no separate
        # open() call is needed.
        with context:
            netqd.main(True, conf_dict)

    else:
        signal.signal(signal.SIGINT, netqd.signal_handler)
        signal.signal(signal.SIGTERM, netqd.signal_handler)
        signal.signal(signal.SIGHUP, netqd.reload_handler)
        signal.signal(signal.SIGCHLD, netqd.sigchld_handler)
        netqd.main(False, conf_dict)
