#! /usr/bin/python3

import binascii
import sys
import struct
import re
import argparse
import json
import matplotlib.pyplot as plt
import math

# Matches the start of one trace line of the form
#   (<time>) <interface> <identifier>#<data>
# and captures the four fields in that order; <data> may be empty.
# A raw string is used so that the backslashes reach the regex engine
# untouched instead of being treated as (invalid) string escapes.
full_line = re.compile(r"\(([0-9.]+)\)\s+(\w+)\s+(\w+)#(\w*)")

# Intensity table used by generate_color: scale[0] is the full intensity of
# the six primary/secondary colors, scale[1] holds the three dimmer levels
# used for the following eighteen colors.
scale = [255,
         [64, 128, 192]]


def generate_color(n):
    """Return the n-th color of a fixed palette as a "#RRGGBB" string.

    The palette has 24 distinct entries: indexes below 6 use the full
    intensity on the selected channels, indexes 6 to 23 use the dimmer
    levels, rotated between the channels every six entries.  Any index
    of 24 or more yields black.
    """
    # (n % 6) + 1 runs from 1 to 6; its three low bits select red, green, blue.
    selector = (n % 6) + 1
    channels = [1 if selector & bit else 0 for bit in (4, 2, 1)]

    if n < 6:
        levels = [scale[0]] * 3
    elif n < 18 + 6:
        shade = (n - 6) // 6
        levels = [scale[1][(shade + offset) % 3] for offset in range(3)]
    else:
        levels = [0, 0, 0]
    # end if

    red, green, blue = [flag * level for flag, level in zip(channels, levels)]
    return "#%02X%02X%02X" % (red, green, blue)

def create_view(data,title,last=True):
    """Draw a bar chart of at most the first 20 entries of data.

    Each entry is indexed as entry[0] (bar height), entry[1] (legend
    label) and entry[2] (error bar size).  Bars are 9 units wide and
    placed every 10 units, each with its own palette color.

    title is used as the chart title.  When last is true the pending
    figures are displayed with plt.show().
    """
    shown = data[0:20] if len(data) > 20 else data
    fig, ax = plt.subplots()

    for rank, entry in enumerate(shown):
        ax.bar(rank * 10, entry[0], 9,
               label=entry[1],
               color=generate_color(rank),
               yerr=entry[2],
               tick_label="")
    # end for

    ax.set_ylabel('Count')
    ax.set_title(title)
    ax.legend()

    if not last:
        return
    # end if
    plt.show()
    

def parse_line(line):
    """Split one trace line into its fields.

    Returns a dict with the keys "time", "interface", "identifier" and
    "data" (all strings, "data" possibly empty) when the line matches
    the full_line pattern, and None otherwise.
    """
    found = full_line.match(line)
    if found is None:
        return None
    # end if

    payload = found.group(4)
    return {"time": found.group(1),
            "interface": found.group(2),
            "identifier": found.group(3),
            "data": "" if payload is None else payload}

def parse_files(parser):
    """Count the identifiers, payloads and messages found in the trace files.

    parser is the parsed command line; this function reads:
      - parser.files: the paths of the trace files to process;
      - parser.blacklist: None, or a readable stream holding a JSON object
        {"kind": "identifiers" | "messages", "data": [...]} listing what
        must be ignored.  With kind "messages" the entries are compared
        with "<identifier>#<data>", otherwise with the identifier alone.
        Any other kind filters nothing.

    Lines that do not look like a trace line are skipped.

    Returns a dict with three tables:
      - "identifiers_counter": identifier -> {file: number of lines};
      - "identifiers_content": identifier -> {data: number of lines};
      - "messages_counter": "<identifier>#<data>" -> {file: number of lines}.

    Raises KeyError when the blacklist JSON lacks the "kind" key, or the
    "data" key for a recognized kind.
    """
    interest = {"identifiers_counter": {},
                "messages_counter": {},
                "identifiers_content": {}}

    # A set gives constant-time membership tests for every trace line.
    blacklist = frozenset()
    blacklist_identifier = True
    if parser.blacklist is not None:
        black = json.loads(parser.blacklist.read())
        if black["kind"] == "messages":
            blacklist = frozenset(black["data"])
            blacklist_identifier = False
        elif black["kind"] == "identifiers":
            blacklist = frozenset(black["data"])
            blacklist_identifier = True
        # end if
    # end if

    for file in parser.files:
        # The context manager closes each trace file, even on error.
        with open(file) as trace:
            for raw_line in trace:
                fields = parse_line(raw_line.strip())
                if fields is None:
                    # Blank or malformed line: nothing to count.
                    continue
                # end if

                identifier = fields["identifier"]
                payload = fields["data"]
                message = identifier + "#" + payload

                filtered = identifier if blacklist_identifier else message
                if filtered in blacklist:
                    continue
                # end if

                for entry, key_name, info in ((identifier, "identifiers_counter", file),
                                              (identifier, "identifiers_content", payload),
                                              (message, "messages_counter", file)):
                    occurrences = interest[key_name].setdefault(entry, {})
                    occurrences[info] = occurrences.get(info, 0) + 1
                # end for
            # end for
        # end with
    # end for
    return interest
    
def _write_selection(stream, kind, names):
    """Write names as a {"kind": ..., "data": [...]} JSON document and close stream.

    Does nothing when stream is None.
    """
    if stream is None:
        return
    # end if
    stream.write(json.dumps({"kind": kind, "data": list(names)}))
    stream.close()


def _signed_byte_statistics(content_counts):
    """Compute per-byte statistics over hexadecimal payloads.

    content_counts maps a hexadecimal payload string to the number of
    times it was seen.  Every byte is read as a signed 8-bit value
    (bytes above 127 become negative).  Returns one
    [mean, byte position, standard deviation] entry per byte position.

    Payloads may have different lengths: each position is averaged over
    the messages that actually contain it.  Payloads that are not valid
    hexadecimal are ignored.
    """
    totals = []
    square_totals = []
    samples = []
    for payload, count in content_counts.items():
        try:
            content = binascii.unhexlify(payload)
        except ValueError:
            # binascii.Error is a ValueError: odd length or non-hex digits.
            continue
        # end try

        # Grow the accumulators when this payload is the longest so far.
        missing = len(content) - len(totals)
        if missing > 0:
            totals.extend([0] * missing)
            square_totals.extend([0] * missing)
            samples.extend([0] * missing)
        # end if

        for position, byte in enumerate(content):
            value = byte - 256 if byte > 127 else byte
            totals[position] += value * count
            square_totals[position] += value * value * count
            samples[position] += count
        # end for
    # end for

    results = []
    for position, nb_samples in enumerate(samples):
        mean = totals[position] / nb_samples
        # Defensive clamp: never hand a negative value to math.sqrt.
        variance = max(0.0, square_totals[position] / nb_samples - mean * mean)
        results.append([mean, position, math.sqrt(variance)])
    # end for
    return results


def view(parser):
    """Plot statistics about the trace files and export what was seen.

    Three bar charts are produced through create_view:
      1. the number of messages per identifier;
      2. the number of different contents per identifier;
      3. for the identifier with the most different contents, the mean
         and standard deviation of every payload byte (signed).

    When parser.output_identifiers / parser.output_messages are not None,
    the identifiers / messages seen are written to them as JSON
    ({"kind": ..., "data": [...]}) and the streams are closed.

    When no message is left after parsing and blacklisting, the output
    files are still written (with empty lists) and nothing is plotted.
    """
    interest = parse_files(parser)
    per_identifier = interest["identifiers_counter"]
    contents = interest["identifiers_content"]
    messages = interest["messages_counter"]

    if not per_identifier:
        # Nothing to rank or plot; still leave valid JSON in the outputs.
        _write_selection(parser.output_identifiers, "identifiers", per_identifier)
        _write_selection(parser.output_messages, "messages", messages)
        print("No message to display")
        return
    # end if

    # Chart 1: total number of messages per identifier, all files together.
    totals = sorted(([sum(per_file.values()), identifier, 0]
                     for identifier, per_file in per_identifier.items()),
                    reverse=True)
    create_view(totals,"Messages count by identifiers",last=False)

    _write_selection(parser.output_identifiers, "identifiers", per_identifier)

    # Chart 2: number of distinct payloads per identifier.
    variety = sorted(([len(contents[identifier]), identifier, 0]
                      for identifier in per_identifier),
                     reverse=True)
    create_view(variety,"Different contents by identifiers",last=False)

    # The identifier with the most varied content gets the detailed chart.
    first_id = variety[0][1]

    _write_selection(parser.output_messages, "messages", messages)

    # Chart 3: per-byte distribution; last=True makes create_view display
    # all the pending figures.
    create_view(_signed_byte_statistics(contents[first_id]),
                "Content distribution for identifier " + first_id,
                last=True)
    
def count_messages(parser):
    """Handler of the 'count_messages' command.

    Forwards the parsed command line to list_messagesOrId with the
    flags (False, True).

    NOTE(review): list_messagesOrId is not defined in the visible part
    of this file; unless it is provided elsewhere this call raises
    NameError.  The flags presumably mean (identifiers only, count) --
    confirm against its definition.
    """
    list_messagesOrId(parser,False,True)

def count_identifiers(parser):
    """Handler of the 'count_identifiers' command.

    Forwards the parsed command line to list_messagesOrId with the
    flags (True, True).

    NOTE(review): list_messagesOrId is not defined in the visible part
    of this file; unless it is provided elsewhere this call raises
    NameError.  The flags presumably mean (identifiers only, count) --
    confirm against its definition.
    """
    list_messagesOrId(parser,True,True)

def list_messages(parser):
    """Handler of the 'list_messages' command.

    Forwards the parsed command line to list_messagesOrId with the
    flags (False, False).

    NOTE(review): list_messagesOrId is not defined in the visible part
    of this file; unless it is provided elsewhere this call raises
    NameError.  The flags presumably mean (identifiers only, count) --
    confirm against its definition.
    """
    list_messagesOrId(parser,False,False)

def list_identifiers(parser):
    """Handler of the 'list_identifiers' command.

    Forwards the parsed command line to list_messagesOrId with the
    flags (True, False).

    NOTE(review): list_messagesOrId is not defined in the visible part
    of this file; unless it is provided elsewhere this call raises
    NameError.  The flags presumably mean (identifiers only, count) --
    confirm against its definition.
    """
    list_messagesOrId(parser,True,False)

def time_correlation(args):
    """Count, per trace file, how often each message (or identifier) occurs.

    args is the parsed command line; args.files lists the trace files to
    read.  Lines that do not look like a trace line are skipped.

    Returns a dict mapping "<identifier>#<data>" (or the identifier alone
    when args carries a true idOnly attribute) to {file: occurrences}.

    TODO: the time correlation itself is not implemented yet.  The draft
    reserved an unused window of 1.0 (presumably seconds) and never used
    the timestamps.
    """
    # No command-line option defines idOnly; without one, keep the full
    # message so that no information is merged away.
    id_only = getattr(args, "idOnly", False)

    interest = dict()
    for file in args.files:
        # The context manager closes each trace file, even on error.
        with open(file) as trace:
            for raw_line in trace:
                fields = parse_line(raw_line.strip())
                if fields is None:
                    # Blank or malformed line: nothing to count.
                    continue
                # end if

                if id_only:
                    data = fields["identifier"]
                else:
                    data = fields["identifier"] + "#" + fields["data"]
                # end if

                per_file = interest.setdefault(data, {})
                per_file[file] = per_file.get(file, 0) + 1
            # end for
        # end with
    # end for
    return interest

# Command name -> handler description.  Every handler is called with the
# parsed command line as its only argument.
commands = {'view' : { 'function' : view},
            'list_messages' : { 'function' : list_messages},
            'list_identifiers' : { 'function' : list_identifiers},
            'count_messages' : { 'function' : count_messages},
            'count_identifiers' : { 'function' : count_identifiers},
            'correlation' : { 'function' : time_correlation}}

parser = argparse.ArgumentParser(description='Rnet traces statistical analyzer')
# The help text is built from the dispatch table so that it always lists
# the commands that are really accepted.
parser.add_argument('command', help='|'.join(commands))
parser.add_argument('--output_identifiers', default=None, type=argparse.FileType('w'),
                    help='File identifiers are written in')
parser.add_argument('--output_messages', default=None, type=argparse.FileType('w'),
                    help='File messages are written in')
parser.add_argument('--blacklist', default=None, type=argparse.FileType('r'),
                    help='JSON file containing identifiers or messages to ignore')
parser.add_argument('files', nargs = '+', help='Dump file(s) to process')

args = parser.parse_args()


print("Command = %s" % (args.command))


if args.command in commands:
    commands[args.command]["function"](args)
else:
    # Unknown command: list the accepted ones.
    print("Known commands")
    for k in commands:
        print("%s" % k)
    # end for
# end if