#!/usr/bin/python
#

from bcc import BPF
import time
import ctypes as ct
import subprocess, json
from pynag.Plugins import PluginHelper,ok,warning,critical,unknown
import re
import os

# Global accumulator for all collected metrics, keyed as "<dev>_ms<bucket>_<op>".
# Known wart: module-level mutable state shared by several functions below.
stuff_obj = {}
log2_index_max = 65

def raw_to_majmin(devid):
    """Convert a raw device number into its "major:minor" string form."""
    major = os.major(devid)
    minor = os.minor(devid)
    return "{0}:{1}".format(major, minor)

def raw_to_name(devid):
    """Resolve a raw device number to its lsblk device name, or "UNKNOWN".

    Queries lsblk on every call (via list_block_devs) and matches on the
    "maj:min" field.
    """
    wanted = raw_to_majmin(devid)
    for candidate in list_block_devs():
        if candidate['maj:min'] == wanted:
            return candidate['name']
    return "UNKNOWN"

def decode_c_struct(data, tmp, buckets, bucket_fn, bucket_sort_fn):
    """Decode a BPF histogram table keyed by a ctypes struct.

    Fills *tmp* with {bucket: [count per slot]} (lists of length
    log2_index_max) and appends every bucket id to *buckets*.

    bucket_fn and bucket_sort_fn are accepted for interface compatibility
    with the upstream BCC helper this was copied from, but are unused here.
    """
    f1 = data.Key._fields_[0][0]
    f2 = data.Key._fields_[2][0]
    # The above code assumes that data.Key._fields_[2][0] holds the
    # slot. But a padding member may have been inserted here, which
    # breaks the assumption and leads to chaos.
    # TODO: this is a quick fix. Fixing/working around in the BCC
    # internal library is the right thing to do.
    for k, v in data.items():
        bucket = getattr(k, f1)
        vals = tmp.setdefault(bucket, [0] * log2_index_max)
        vals[getattr(k, f2)] = v.value
    buckets.extend(tmp.keys())


def print_json_hist(data, val_type="value", section_header="Bucket ptr",
                        section_print_fn=None, bucket_fn=None, bucket_sort_fn=None,
                        writes=False):
    """Generator yielding one histogram dict (from store_json_hist) per bucket.

    Struct-keyed tables are split per device via decode_c_struct; plain
    tables yield a single histogram with no section bucket.
    """
    if isinstance(data.Key(), ct.Structure):
        per_bucket = {}
        bucket_ids = []
        decode_c_struct(data, per_bucket, bucket_ids, bucket_fn, bucket_sort_fn)
        for drive in bucket_ids:
            section_bucket = (section_header, drive)
            yield store_json_hist(per_bucket[drive], val_type, section_bucket,
                                  writes=writes)
    else:
        slots = [0] * log2_index_max
        for key, val in data.items():
            slots[key.value] = val.value
        yield store_json_hist(slots, val_type, writes=writes)

def store_json_hist(vals, val_type, section_bucket=None, writes=False):
    """Record the first 12 latency buckets of *vals* into the global stuff_obj.

    Keys take the form "<device>_ms<bucket>_reads" / "_writes", where bucket
    doubles from 1ms up to 2048ms. section_bucket is a (header, raw devid)
    tuple; when absent the device is reported as "UNKNOWN" (the plain-table
    path in print_json_hist passes none). val_type is accepted for interface
    compatibility but unused. Returns a small histogram summary dict.
    """
    max_nonzero_idx = 11  # export only buckets 0..11 (1ms .. 2048ms)
    suffix = '_writes' if writes else '_reads'
    # Resolve the device name once, not per bucket (raw_to_name shells out
    # to lsblk each call).
    if section_bucket is not None:
        device = raw_to_name(section_bucket[1])
    else:
        device = "UNKNOWN"
    index = 1
    for i in range(len(vals)):
        if i <= max_nonzero_idx:
            stuff_obj[device + '_ms' + str(index) + suffix] = int(vals[i])
            index = index * 2
    histogram = {"date": str(time.time()).split('.')[0], "data": [], "disk": device}
    return histogram

def generate_empty_json(device, write=False):
    """Fill stuff_obj with zeroed latency buckets for a device with no data.

    Writes 12 keys "<device>_ms<bucket>_reads" / "_writes" (bucket doubling
    from 1 to 2048, mirroring store_json_hist) all set to 0, and returns the
    same histogram summary dict shape as store_json_hist.
    """
    suffix = '_writes' if write else '_reads'
    index = 1
    for _ in range(12):  # buckets 1ms .. 2048ms
        stuff_obj[device + '_ms' + str(index) + suffix] = 0
        index = index * 2
    histogram = {"date": str(time.time()).split('.')[0], "data": [], "disk": device}
    return histogram

def get_lsblk(short=True):
    """Run lsblk and return its JSON output (bytes).

    short=True lists only NAME and MOUNTPOINT; otherwise MAJ:MIN is
    included as well so callers can match raw device numbers.
    """
    columns = "NAME,MOUNTPOINT" if short else "NAME,MOUNTPOINT,MAJ:MIN"
    return subprocess.check_output(["/bin/lsblk", "--json", "-o", columns])

def list_block_devs(short=False):
    """List block devices reported by lsblk.

    Returns a list of device names when short=True, otherwise a list of
    {'name': ..., 'maj:min': ...} dicts.
    """
    entries = json.loads(get_lsblk(short))['blockdevices']
    if short:
        return [entry['name'] for entry in entries]
    return [{'name': entry['name'], 'maj:min': entry['maj:min']}
            for entry in entries]

def sum_all_values(data):
    """Return the sum of every value stored in the mapping *data*."""
    total = 0
    for value in data.values():
        total += value
    return total

def values_to_percents(data):
    """Convert raw bucket counts into per-device, per-operation percentages.

    *data* maps "<dev>_ms<bucket>_<reads|writes>" keys to counts. For each
    device the reads and writes are summed separately and each bucket's share
    is emitted under "<original key>_percent", rounded to 2 decimals.

    NOTE(review): key.split('_')[0] assumes device names contain no
    underscore — verify against the naming produced by store_json_hist.
    """
    bdevs = list_block_devs()
    grouped = {}
    for dev in bdevs:
        grouped[dev['name']] = {'reads': {}, 'writes': {}}
    for key in data.keys():
        dev = key.split('_')[0]
        oper_type = key.split('_')[2]
        grouped[dev][oper_type][key] = data[key]
    # ^ created dictionary with separate writes and reads for each block device
    # now will have to calculate sum for all separate groups for each dev
    # and percentages (also for each group of each dev)
    output_data = {}
    for dev in bdevs:
        for oper_type in grouped[dev['name']].keys():
            temp = grouped[dev['name']][oper_type]
            oper_sum = sum_all_values(temp)
            for oper in temp.keys():
                if temp[oper] == 0:
                    output_data[oper + '_percent'] = 0
                else:
                    # Multiply by 100.0 BEFORE dividing so the result stays a
                    # float even under Python 2's truncating integer division
                    # (the shebang is the unversioned /usr/bin/python).
                    output_data[oper + '_percent'] = round(100.0 * temp[oper] / oper_sum, 2)
    return output_data

def regexp_match_metrics(data, regexp):
    """Return the entries of *data* that match *regexp* (re.search semantics).

    The pattern is compiled once up front instead of re-parsed per entry.
    """
    pattern = re.compile(regexp)
    return [entry for entry in data if pattern.search(entry)]

def find_above_limit(data, limit):
    """Return the subset of *data* whose millisecond bucket exceeds *limit*.

    Keys are expected to look like "<dev>_ms<bucket>_...", so the second
    '_'-separated field carries the bucket size in milliseconds.
    """
    output = {}
    for key in data.keys():
        bucket_ms = key.split('_')[1].replace('ms', '')
        if int(bucket_ms) > int(limit):
            output[key] = data[key]
    return output

#print("Starting program, running bcc")
# load BPF program
def retrieve_data():
        """Attach BPF block tracepoints, sample for 15s, and record per-device
        read/write latency histograms into the global stuff_obj.

        Devices that produced no tracing output during the sampling window are
        filled in with zeroed buckets so every block device is always reported.
        """
        reads = BPF(text="""
        #include <linux/blkdev.h>
        #include <linux/major.h>
      typedef struct disk_key {
          u32 dev;
          u8 op;
          u64 slot;
      } disk_key_t;
      // Max number of disks we expect to see on the host
      const u8 max_disks = 64;
      // 27 buckets for latency, max range is 2.048s .. 4.096s
      const u8 max_latency_slot = 12;
      // Histograms to record latencies
      BPF_HISTOGRAM(write_latency, disk_key_t, max_latency_slot * max_disks);
      BPF_HISTOGRAM(read_latency, disk_key_t, max_latency_slot * max_disks);
      struct key_t {
          dev_t dev;
          sector_t sector;
      };
      struct val_t {
          u64 start;
          u64 bytes;
      };
      // Hash to temporily hold the start time of each bio request, max 10k in-flight by default
      BPF_HASH(start, struct key_t, struct val_t,1024000);
      // Generates function tracepoint__block__block_rq_issue
      TRACEPOINT_PROBE(block, block_rq_issue) {
          // blkid generates these and we're not interested in them
          if (args->dev == 0) {
              return 0;
          }
          struct key_t key = {};
          key.dev = args->dev;
          key.sector = args->sector;
          if (key.sector == -1) {
             key.sector = 0;
          }
          struct val_t val = {};
          val.start = bpf_ktime_get_ns();
          val.bytes = args->bytes;
          start.update(&key, &val);
          return 0;
      }
      // Generates function tracepoint__block__block_rq_complete
      TRACEPOINT_PROBE(block, block_rq_complete) {
          struct key_t key = {};
          key.dev = args->dev;
          key.sector = args->sector;
          if (key.sector == -1) {
             key.sector = 0;
          }
          struct val_t *valp = start.lookup(&key);
          if (valp == 0) {
              return 0; // missed issue
          }
          // Delta in milliseconds
          u64 delta = (bpf_ktime_get_ns() - valp->start) / 1000000;
          // Latency histogram key
          u64 latency_slot = bpf_log2l(delta);
          // Cap latency bucket at max value
          if (latency_slot > max_latency_slot) {
              latency_slot = max_latency_slot;
          }
          disk_key_t latency_key = {};
          latency_key.slot = latency_slot;
          latency_key.dev = new_decode_dev(args->dev);
          if (args->rwbs[0] == 'W' || args->rwbs[0] == 'S' || args->rwbs[0] == 'F' || args->rwbs[1] == 'W' || args->rwbs[1] == 'S' || args->rwbs[1] == 'F') {
              latency_key.op = 2;
              //it's a WRITE op!
              write_latency.increment(latency_key);
          } else {
              latency_key.op = 1;
              read_latency.increment(latency_key);
          }
          start.delete(&key);
          return 0;
      }
        """)

        # list block devices - only drives, without splitting to partitions
        bdevs = list_block_devs(short=True)
        # sampling window: let the tracepoints accumulate latency data
        time.sleep(15)

        write_lat = reads.get_table("write_latency")
        read_lat = reads.get_table("read_latency")
        write_io_gen = print_json_hist(data=write_lat, val_type="msec", section_header="disk", writes=True)
        read_io_gen = print_json_hist(data=read_lat, val_type="msec", section_header="disk")
        # Drain the generators: store_json_hist populates stuff_obj as a side
        # effect, and the returned summaries tell us which devices had data.
        write_io_data = list(write_io_gen)
        read_io_data = list(read_io_gen)
        # If a device produced no trace output yet (e.g. right after start),
        # insert zeroed histograms for it so the metric set stays complete.
        # The previous implementation appended an empty histogram for every
        # (device, entry) pair that did not match - and mutated the list while
        # iterating it - producing duplicate fill-ins; collect the devices
        # actually seen first, then fill in only the genuinely missing ones.
        _fill_missing_devices(write_io_data, bdevs, write=True)
        _fill_missing_devices(read_io_data, bdevs, write=False)


def _fill_missing_devices(io_data, bdevs, write):
    """Append zeroed histograms (and stuff_obj entries) for every block
    device in *bdevs* that does not appear in *io_data*."""
    seen = set()
    for entry in io_data:
        disk = entry["disk"]
        # Histogram disk names may arrive as bytes from the BPF table;
        # normalize to str, or keep as-is if it already is one.
        try:
            disk = disk.decode("utf-8")
        except (UnicodeDecodeError, AttributeError):
            pass
        seen.add(disk)
    for bdev in bdevs:
        if bdev not in seen:
            io_data.append(generate_empty_json(bdev, write=write))

helper = PluginHelper()
# ^ initialize nagios plugin class
# v add necessary arguments - print to stdout, validate only metrics based on regexp, etc.
helper.parser.add_option('-A', dest='stdout', help='Print data to stdout and exit', action='store_true')
helper.parser.add_option('--regex', dest='regex', help='Validate only metrics fulfilling regexp')
helper.parser.add_option('--below', dest='below', help='What latencies we don\'t want to exceed (in milliseconds)')
helper.parse_arguments()


# Sample latencies (populates the global stuff_obj) and derive percentages.
retrieve_data()
perc = values_to_percents(stuff_obj)

# concatenate both raw numbers and their percentage values
output_data = dict(stuff_obj)
output_data.update(perc)


if (helper.options.regex is None and helper.options.stdout is None and helper.options.below is None):
    # Default mode: report every collected metric to nagios.
    for entry in output_data.keys():
        helper.add_metric(label=entry, value=output_data[entry])
    helper.check_all_metrics()
    helper.exit()

elif helper.options.stdout is not None:
    # Debug mode: dump everything to stdout without nagios evaluation.
    print("Dumping all info gathered:")
    print(output_data)

elif helper.options.regex is not None:
    # Report only the metrics whose name matches the given regexp.
    helper.debug("Using regex %s" % helper.options.regex)
    metrics_regexped = regexp_match_metrics(output_data.keys(), helper.options.regex)
    for entry in metrics_regexped:
        helper.add_metric(label=entry, value=output_data[entry])
    helper.check_all_metrics()
    helper.exit()

elif helper.options.below is not None:
    # Fail CRITICAL if any latency bucket above the limit saw any I/O.
    limit = helper.options.below
    helper.debug("Will fail out if we exceed latency of %s" % limit)
    output = find_above_limit(output_data, limit)
    if len(output) > 0:
        thresholds = [(critical, '1..inf')]
        for entry in output_data.keys():
            helper.add_metric(label=entry, value=output_data[entry])
        for entry in output:
            helper.check_metric(entry, thresholds)
    helper.exit()
