#!/usr/bin/python3
#
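"""Nagios plugin (pynag-based) that attaches BCC/eBPF block-layer tracepoints,
samples block I/O for SAMPLE_TIME seconds and reports per-device read/write
latency histograms (plus percentage breakdowns) as plugin metrics."""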

from bcc import BPF
import time
import ctypes as ct
import subprocess, json
from pynag.Plugins import PluginHelper,ok,warning,critical,unknown
import re
import os

# yes, global variable, I know it's bad..
stuff_obj = {}        # flat {metric_name: value} dict filled while decoding the BPF histograms
log2_index_max = 65   # width of the per-device slot array (the size BCC uses for log2 histograms)
SAMPLE_TIME = 15      # how long to let the BPF program collect samples, in seconds

def get_lsblk(short=True):
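    """Return lsblk's JSON output; the long form additionally includes MAJ:MIN device numbers."""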
    columns = "NAME,MOUNTPOINT" if short else "NAME,MOUNTPOINT,MAJ:MIN"
    return subprocess.check_output(["/bin/lsblk", "--json", "-o", columns])

def list_block_devs(short=False):
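    """List the top-level block devices: names only when short, otherwise dicts with 'name' and 'maj:min'."""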
    blockdevices = json.loads(get_lsblk(short))
    if short:
        return [entry['name'] for entry in blockdevices['blockdevices']]
    bdevs = []
    for entry in blockdevices['blockdevices']:
        bdevs.append({'name': entry['name'], 'maj:min': entry['maj:min']})
    return bdevs


# Cache the device lists once at startup; raw_to_name() resolves device numbers reported by the BPF side against BDEVS_LONG
BDEVS_LONG = list_block_devs()
BDEVS_SHORT = list_block_devs(short=True)

def raw_to_majmin(devid):
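    """Convert a raw kernel device number into a "major:minor" string."""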
    return "%s:%s" % (os.major(devid),os.minor(devid))

def raw_to_name(devid):
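    """Map a raw device number to its lsblk device name, or "UNKNOWN" if it is not listed."""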
    devno = raw_to_majmin(devid)
    for dev in BDEVS_LONG:
        if dev['maj:min'] == devno:
            return dev['name']
    return "UNKNOWN"

def decode_c_struct(data, tmp, buckets, bucket_fn, bucket_sort_fn):
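    """Decode a struct-keyed BCC histogram into tmp ({device: per-slot counts}) and
    collect the seen device ids in buckets.

    bucket_fn and bucket_sort_fn are accepted for signature compatibility but unused.
    """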
    f1 = data.Key._fields_[0][0]
    f2 = data.Key._fields_[3][0]
    # The line above assumes that data.Key._fields_[3][0] holds the slot. A padding
    # member may have been inserted by ctypes (presumably why index 3, not 2, is the
    # one that works here); if the generated layout changes, this breaks.
    # TODO: this is a quick fix. Fixing/working around it in the BCC library itself
    # is the right thing to do.
    for k, v in data.items():
        bucket = getattr(k, f1)
        vals = tmp[bucket] = tmp.get(bucket, [0] * log2_index_max)
        slot = getattr(k, f2)
        vals[slot] = v.value
    buckets.extend(tmp.keys())


def print_json_hist(data, val_type="value", section_header="Bucket ptr",
                        section_print_fn=None, bucket_fn=None, bucket_sort_fn=None,
                        writes=False):
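    """Yield one histogram dict per block device decoded from a BCC histogram table.

    The parameters largely mirror BCC's print_log2_hist(); only the struct-keyed
    branch is exercised by the disk_key_t histograms in this script.
    """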
    if isinstance(data.Key(), ct.Structure):
        tmp = {}
        buckets = []
        decode_c_struct(data, tmp, buckets, bucket_fn, bucket_sort_fn)
        for drive in buckets:
            vals = tmp[drive]
            section_bucket = (section_header, drive)
            yield store_json_hist(vals, val_type, section_bucket, writes=writes)

    else:
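        # Plain (non-struct) keys: not exercised by the disk_key_t histograms used
        # here; store_json_hist() then falls back to the disk name "UNKNOWN"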
        vals = [0] * log2_index_max
        for k, v in data.items():
            vals[k.value] = v.value
        yield store_json_hist(vals, val_type, writes=writes)

def store_json_hist(vals, val_type, section_bucket=None, writes=False):
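    """Turn one device's slot array into <dev>_ms<N>_{reads,writes} metrics.

    Updates the global stuff_obj as a side effect and returns a small histogram
    dict (date, data, disk) for the device.
    """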
    max_nonzero_idx = 11    # only the first 12 log2 slots (up to the _ms2048 label) are exported
    index = 1               # bucket label in milliseconds: 1, 2, 4, ... 2048
    data_out = []
    disk = raw_to_name(section_bucket[1]) if section_bucket else "UNKNOWN"
    suffix = '_writes' if writes else '_reads'
    for i in range(len(vals)):
        if i <= max_nonzero_idx:
            key = disk + '_ms' + str(index) + suffix
            stuff_obj[key] = int(vals[i])
            data_out.append({key: int(vals[i])})
            index = index * 2
    histogram = {"date": str(time.time()).split('.')[0], "data": data_out, "disk": disk}
    return histogram

def generate_empty_json(device, write=False):
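    """Pre-seed stuff_obj with zeroed <device>_ms<N> metrics so devices that stay idle during the sample still report."""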
    max_idx = 11
    index = 1
    suffix = '_writes' if write else '_reads'
    for i in range(max_idx + 1):
        stuff_obj[device + '_ms' + str(index) + suffix] = 0
        index = index * 2
    histogram = {"date": str(time.time()).split('.')[0], "data": [], "disk": device}
    return histogram

def values_to_percents(data):
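    """For every <dev>_ms<N>_<op> metric, compute its percentage share of that device's total reads or writes (emitted as <key>_percent)."""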
    input_data = {}
    for dev in BDEVS_LONG:
        input_data[dev['name']] = {'reads':{}, 'writes':{}}
    for key in data.keys():
        dev = key.split('_')[0]
        oper_type = key.split('_')[2]
        input_data[dev][oper_type][key] = data[key]
    # ^ built a dict with separate 'reads' and 'writes' groups for each block device;
    # now compute the sum of each group and express every value as a percentage of
    # that sum (for each device and operation type separately)
    output_data = {}
    for dev in BDEVS_LONG:
        for oper_type in input_data[dev['name']].keys():
            temp = input_data[dev['name']][oper_type]
            oper_sum = sum(temp.values())
            for oper in temp.keys():
                if temp[oper] == 0:
                    output_data[oper + '_percent'] = 0
                else:
                    output_data[oper + '_percent'] = round(int(temp[oper])/oper_sum*100, 2)
    return output_data

def histogrammize_data(data):
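    """Rebuild the flat metrics into cumulative per-device histograms keyed as <dev>_ms<N>_<op>.

    Not called in the main flow at the moment (see the commented-out call near the bottom).
    """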
    output_data = {}
    for raw_oper_key in data.keys():
        if str(raw_oper_key).endswith("percent"):
            continue
        splitted = str(raw_oper_key).split('_')
        dev = str(splitted[0])
        buck = int(splitted[1].replace('ms', ''))
        op = str(splitted[2])
        amount = data[raw_oper_key]
        output_data.setdefault(dev, {})
        output_data[dev].setdefault(op, {})
        output_data[dev][op][buck] = amount
    for drive in output_data.keys():
        for operation in output_data[drive].keys():
            sorted_ops = sorted(output_data[drive][operation].keys())
            for buck in sorted_ops:
                # Walk the buckets in ascending order and add the previous (already
                # accumulated) bucket's count to the current one, turning per-bucket
                # counts into a cumulative histogram: the wider the latency window,
                # the more operations it should contain.
                if buck == 1:
                    # the smallest bucket has nothing below it
                    continue
                output_data[drive][operation][buck] += output_data[drive][operation][buck // 2]
    output = {}
    for drive in output_data.keys():
        for operation in output_data[drive].keys():
            for buck in output_data[drive][operation].keys():
                keyname = drive + "_ms" + str(buck) + "_" + operation
                output[keyname] = output_data[drive][operation][buck]
    return output

def regexp_match_metrics(data, regexp):
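    """Return the metric names from data that match the given regular expression."""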
    return [entry for entry in data if re.search(regexp, entry)]

def find_above_limit(data, limit):
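    """Return the metrics whose latency bucket lies above `limit` milliseconds, skipping likely false positives."""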
    output = {}
    for key in data.keys():
        parts = key.split('_')
        key_number = parts[1].replace('ms', '')
        if int(key_number) > int(limit):
            # Ugly but working hack: if the "above limit" count is higher than the
            # still-in-limit one we have to alarm; if they are equal it's our
            # histogram's fault and we treat it as a false positive.
            # The limit is expected to match one of the power-of-two bucket labels.
            limit_key = '_'.join([parts[0], 'ms' + str(limit)] + parts[2:])
            if data[limit_key] == data[key]:
                continue
            output[key] = data[key]
    return output

def divide_data():
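    """Turn the raw per-sample counters in stuff_obj into per-second rates by dividing by SAMPLE_TIME."""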
    for key in stuff_obj.keys():
        stuff_obj[key] = round(stuff_obj[key]/SAMPLE_TIME, 2)

#print("Starting program, running bcc")
# load BPF program
def retrieve_data():
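        """Attach the block tracepoint probes, sample I/O for SAMPLE_TIME seconds and
        decode the latency histograms into the global stuff_obj."""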
        reads = BPF(text="""
        #include <linux/blkdev.h>
        #include <linux/major.h>
      typedef struct disk_key {
          u32 dev;
          u8 op;
          u64 slot;
      } disk_key_t;
      // Max number of disks we expect to see on the host
      const u8 max_disks = 64;
      // 13 latency slots (bpf_log2l of the delta in milliseconds); slot 12 covers
      // 2.048s .. 4.095s and, after the cap below, anything slower
      const u8 max_latency_slot = 12;
      // Histograms to record latencies
      BPF_HISTOGRAM(write_latency, disk_key_t, max_latency_slot * max_disks);
      BPF_HISTOGRAM(read_latency, disk_key_t, max_latency_slot * max_disks);
      struct key_t {
          dev_t dev;
          sector_t sector;
      };
      struct val_t {
          u64 start;
          u64 bytes;
      };
      // Hash to temporarily hold the start time and size of each bio request,
      // sized for up to 1024000 in-flight requests
      BPF_HASH(start, struct key_t, struct val_t, 1024000);
      // Generates function tracepoint__block__block_rq_issue
      TRACEPOINT_PROBE(block, block_rq_issue) {
          // blkid generates these and we're not interested in them
          if (args->dev == 0) {
              return 0;
          }
          struct key_t key = {};
          key.dev = args->dev;
          key.sector = args->sector;
          if (key.sector == -1) {
             key.sector = 0;
          }
          struct val_t val = {};
          val.start = bpf_ktime_get_ns();
          val.bytes = args->bytes;
          start.update(&key, &val);
          return 0;
      }
      // Generates function tracepoint__block__block_rq_complete
      TRACEPOINT_PROBE(block, block_rq_complete) {
          struct key_t key = {};
          key.dev = args->dev;
          key.sector = args->sector;
          if (key.sector == -1) {
             key.sector = 0;
          }
          struct val_t *valp = start.lookup(&key);
          if (valp == 0) {
              return 0; // missed issue
          }
          // Delta in milliseconds
          u64 delta = (bpf_ktime_get_ns() - valp->start) / 1000000;
          // Latency histogram key
          u64 latency_slot = bpf_log2l(delta);
          // Cap latency bucket at max value
          if (latency_slot > max_latency_slot) {
              latency_slot = max_latency_slot;
          }
          disk_key_t latency_key = {};
          latency_key.slot = latency_slot;
          latency_key.dev = new_decode_dev(args->dev);
          if (args->rwbs[0] == 'W' || args->rwbs[0] == 'S' || args->rwbs[0] == 'F' || args->rwbs[1] == 'W' || args->rwbs[1] == 'S' || args->rwbs[1] == 'F') {
              latency_key.op = 2;
              //it's a WRITE op!
              write_latency.increment(latency_key);
          } else {
              latency_key.op = 1;
              read_latency.increment(latency_key);
          }
          start.delete(&key);
          return 0;
      }
        """)

        # BDEVS_SHORT already lists only whole drives (lsblk top level, no partitions)
        # Let the BPF program collect samples for the whole window
        time.sleep(SAMPLE_TIME)

        write_io_data = []
        read_io_data = []
        write_lat = reads.get_table("write_latency")
        read_lat = reads.get_table("read_latency")
        for bdev in BDEVS_SHORT:
            write_io_data.append(generate_empty_json(bdev, write=True))
            read_io_data.append(generate_empty_json(bdev, write=False))
        write_io_gen = print_json_hist(data=write_lat, val_type="msec", section_header="disk", writes=True)
        read_io_gen = print_json_hist(data=read_lat, val_type="msec", section_header="disk")
        # Consuming the generators fills the *_io_data lists and, as a side effect,
        # populates the global stuff_obj with the per-bucket metrics
        for entry in write_io_gen:
            write_io_data.append(entry)
        for entry in read_io_gen:
            read_io_data.append(entry)

helper = PluginHelper()
# ^ initialize the Nagios plugin helper class
# v add the plugin arguments - print to stdout, validate only metrics matching a regexp, alarm above a latency limit
helper.parser.add_option('-A', dest='stdout', help='Print data to stdout and exit', action='store_true')
helper.parser.add_option('--regex', dest='regex', help='Validate only metrics fulfilling regexp')
helper.parser.add_option('--below', dest='below', help='Latency (in milliseconds) we do not want to exceed')
helper.parse_arguments()
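# Example invocations (illustrative; use whatever path this plugin is installed under):
#   <plugin> -A                -> dump everything gathered to stdout
#   <plugin> --regex '^sda_'   -> only submit metrics matching the regexp
#   <plugin> --below 64        -> go critical on latencies above 64 ms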


retrieve_data()
divide_data()
perc = values_to_percents(stuff_obj)

# combine the raw (per-second) numbers with their percentage values
output_data = dict(stuff_obj)
#output_data = histogrammize_data(output_data)
output_data.update(perc)


if (helper.options.regex is None and helper.options.stdout is None and helper.options.below is None):
    for entry in output_data.keys():
        helper.add_metric(label=entry, value=output_data[entry])
    helper.check_all_metrics()
    helper.exit()

elif helper.options.stdout is not None:
    print("Dumping all info gathered:")
    print(output_data)

elif helper.options.regex is not None:
    helper.debug("Using regex %s" % helper.options.regex)
    metrics_regexped = regexp_match_metrics(output_data.keys(), helper.options.regex)
    for entry in metrics_regexped:
        helper.add_metric(label=entry, value = output_data[entry])
    helper.check_all_metrics()
    helper.exit()

elif helper.options.below is not None:
    limit = helper.options.below
    helper.debug("Will fail out if we exceed latency of %s" % limit)
    output = find_above_limit(output_data, limit)
    if len(output) > 0:
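        # Submit every metric as perfdata, but apply the critical threshold only to
        # the buckets that exceeded the limit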
        thresholds = [(critical,'1..inf')]
        for entry in output_data.keys():
            helper.add_metric(label=entry, value = output_data[entry])
        for entry in output:
            helper.check_metric(entry, thresholds)
    else:
        for entry in output_data.keys():
            helper.add_metric(label=entry, value=output_data[entry])
        helper.check_all_metrics()
    helper.exit()
