import idaapi
import idc
import idautils
import os
import pandas as pd
import joblib
import json
import math
import networkx as nx
import pandas as pd
from collections import Counter
from statistics import mean, stdev, median



class MnemocryptPlugin(idaapi.plugin_t):
    """Plugin descriptor whose `run` hook launches the Mnemocrypt analysis."""

    # Identification and activation
    wanted_name = "Mnemocrypt"
    wanted_hotkey = "Ctrl-Alt-M"

    # Descriptive texts
    comment = "Mnemocrypt: Classify and display crypto functions"
    help = "Identifies cryptographic functions in the binary using a trained model"

    # Plugin flags
    flags = idaapi.PLUGIN_UNL

    def init(self):
        """Report `idaapi.PLUGIN_OK`; no other initialization is performed."""
        status = idaapi.PLUGIN_OK
        return status

    def run(self, arg):
        """Start the feature extraction and classification (`arg` is not used)."""
        compute_features()

    def term(self):
        """Nothing to release when the plugin terminates."""
        return None



def compute_features():
    """
    Compute the Mnemocrypt features of every relevant function of the current
    database, then hand them over to `classify_crypto_functions`.

    For each function returned by `idautils.Functions()`:
      * functions whose name starts with an unwanted prefix, or that carry the
        `FUNC_LIB` flag, are skipped;
      * every code instruction is mapped to a (category, root) pair using the
        roots loaded from "mnemocrypt_roots.json" (located next to this file);
      * a function containing an instruction of category "crypto" is recorded in
        the set of immediate crypto functions and gets no feature row;
      * a function flagged as immediate non-crypto (floating-point related
        mnemonic) gets no feature row either;
      * single-block functions with fewer than 3 meaningful instructions are
        discarded;
      * otherwise per-basic-block counters are aggregated into a feature row
        (counts, ratios, mean / standard deviation / density, entropy).

    The rows are gathered in a pandas DataFrame (columns sorted, NaN replaced by
    0) and passed, together with the immediate crypto functions, to
    `classify_crypto_functions`.

    Returns:
        None.
    """
    idaapi.auto_wait()
    print("Auto analysis finished. Starting Mnemocrypt...")

    # Constants declaration
    basename = idc.get_root_filename().split('.', 1)[0]
    computed_features = {}
    # We assume that if a function starts with one of these sequences, it is a priori non-cryptographic
    # (kept as a tuple so that it can be given directly to str.startswith)
    unwanted_prefixes = ("__", "___", "@", "std::")
    round_precision = 3 # Compromise between precision and complexity
    # We assume that if a function is made of a single basic block with less than 3 instructions (without counting ones with pop, push, nop or ud mnemonics)
    # and not containing any mnemonic from cryptographic extension sets, then it is a priori non-cryptographic
    small_func_nb_mnemonics_threshold = 3
    # Mnemonics that are never counted (they tend to pollute the statistics if ever kept)
    ignored_mnemonics = {"nop", "fnop"}
    ignored_mnemonic_prefixes = ("pop", "push", "ud")
    # Substrings marking a mnemonic as floating-point / non-cryptographic
    # NOTE(review): the comparison is case-sensitive and every other mnemonic test here is lowercase,
    # so "fYL2X" presumably never matches - confirm before changing it (it could shift the results)
    non_crypto_mnemonic_parts = ("f16", "f32", "f64", "f128", "fYL2X", "sin", "cos", "tan", "tst")
    # Roots related to mnemonics used in Caballero heuristics, enriched with their respective undirected ngrams with "mov" root
    adjusted_caballero_roots = {"sh", "sa", "sra", "sla", "srl", "sll", "and", "or", "xor"}
    asymmetric_caballero_roots = {"mul", "clmul", "div", "add", "adc", "xadd", "madd"}
    important_roots = {"mov", "add", "xor", "and", "sh", "mul", "div", "ro"}
    important_non_mov_roots = important_roots - {"mov"}
    ngram_tuples = {("mov", "add"), ("mov", "xor"), ("mov", "and"), ("mov", "sh"), ("mov", "mul"), ("mov", "div"), ("mov", "ro")}
    ngram_mapping = {frozenset(ngram_tuple): f"{ngram_tuple[0]}_{ngram_tuple[1]}" for ngram_tuple in ngram_tuples}

    full_stat_related_features = {"nb_instr", "data_transfer", "arithmetic", "logic", "string_manipulation", "control_transfer", "process_control"}
    partial_stat_related_features = important_non_mov_roots | set(ngram_mapping.values())
    stat_related_features = full_stat_related_features | partial_stat_related_features
    immediate_crypto_functions = set() # Set of (function name, identification tag) tuples
    crypto_instructions_sets = {"aes": "AES-NI", "sha": "Intel SHA extensions"}

    # Load the instruction categories with respective roots and variants (used in mnemonics recognition)
    roots_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "mnemocrypt_roots.json")
    with open(roots_path, 'r', encoding="utf-8") as f:
        prepared_roots = json.load(f)

    # Main part
    for func_ea in idautils.Functions():
        func_name = idc.get_func_name(func_ea)

        # Some functions can be directly removed
        # Filtering on first symbols in functions' names
        if func_name.startswith(unwanted_prefixes):
            continue
        # Filtering on recognition by FLIRT
        if (idc.get_func_flags(func_ea) & idaapi.FUNC_LIB) != 0:
            continue

        cfg = nx.DiGraph() # Future control flow graph of the current function (used for cyclomatic complexity and maximum loop depth features)
        func = idaapi.get_func(func_ea)
        if func is None:
            # No function object to build a flowchart from
            continue
        try:
            flowchart = idaapi.FlowChart(func)
        except Exception:
            print(f"Function {func_name} exceeds the maximum number of nodes set in IDA, skipped for stability reasons.")
            # Without a flowchart nothing below can be computed, so the function really has to be skipped
            continue
        nb_bb = flowchart.size
        func_info, raw_features = {}, {}
        unique_data_refs, unique_func_calls = set(), set()
        nb_loops, nb_instr, nb_data_refs, nb_mov_instr, nb_edges, nb_default_caballero, nb_adjusted_caballero, nb_asymmetric_caballero = 0, 0, 0, 0, 0, 0, 0, 0
        mnemonic_category_counts = Counter() # To store category frequencies for entropy computation
        is_immediate_crypto, is_immediate_non_crypto = False, False
        default_caballero_ratio, adjusted_caballero_ratio, asymmetric_caballero_ratio = 0, 0, 0 # Value kept iff there is no non-mov instruction in the function (which almost never happens in practice)

        for bb in flowchart:
            # The features of immediate crypto functions (i.e. containing mnemonics from cryptographic sets)
            # are never used, so the remaining basic blocks do not need to be analyzed.
            # (immediate_crypto_functions holds (name, tag) tuples, hence the test on the flag rather than on the name)
            if is_immediate_crypto:
                break

            for succ in bb.succs():
                cfg.add_edge(bb.id, succ.id)
                nb_edges += 1

            # Empty basic blocks are not counted
            if bb.start_ea == bb.end_ea:
                nb_bb -= 1
                continue

            hex_bb_address = hex(bb.start_ea)
            bb_raw_features = {raw_feature: 0 for raw_feature in stat_related_features}
            previous_root = None

            for head_ea in idautils.Heads(bb.start_ea, bb.end_ea):
                if not idc.is_code(idc.get_full_flags(head_ea)):
                    continue

                mnemonic = idc.print_insn_mnem(head_ea)

                # Discarding meaningless mnemonics for the tool (more importantly, they tend to pollute the statistics if ever kept)
                if (mnemonic in ignored_mnemonics) or mnemonic.startswith(ignored_mnemonic_prefixes):
                    continue

                nb_instr += 1
                bb_raw_features["nb_instr"] += 1
                nb_parts_mnem = len(mnemonic.split(" "))
                # Initialization values that normally should change
                retained_category, retained_root = None, None

                # call $+5 is a commonly used way to push EIP value on the stack (so equivalent to lea + push, push being ignored in our case)
                if (mnemonic == "call" and idc.print_operand(head_ea, 0) == "$+5"):
                    mnemonic = "lea"
                    retained_category, retained_root = "data_transfer", "lea"

                # xor reg, reg is typically used to set reg to 0 (so equivalent to mov reg, 0)
                elif (mnemonic == "xor" and idc.print_operand(head_ea, 0) == idc.print_operand(head_ea, 1)):
                    mnemonic = "mov"
                    retained_category, retained_root = "data_transfer", "mov"

                # Dealing with so-called "hint instructions"
                elif nb_parts_mnem != 1:
                    retained_category, retained_root = "control_transfer", "j"
                else:
                    # Getting the category and root associated to the currently analyzed code instruction's mnemonic
                    for root, category, variants in prepared_roots:
                        for variant in variants:
                            if mnemonic.startswith(variant):
                                retained_category, retained_root = category, root
                                break
                        if retained_root: # Stop searching once a match is found (possible thanks to sorted roots)
                            break

                # Unrecognized mnemonic: counted in nb_instr but in no category
                if not retained_root:
                    continue

                if retained_category == "crypto":
                    is_immediate_crypto = True
                    # Fall back on the root itself if the extension set has no registered display name
                    crypto_tag = crypto_instructions_sets.get(retained_root, retained_root)
                    immediate_crypto_functions.add((func_name, crypto_tag))
                    break

                if retained_root == "cmps" and mnemonic[-1] in {"s", "d", "h"}:
                    retained_category, retained_root = "arithmetic", "cmp"
                elif (retained_root == "movs" and mnemonic[-1] in {"s", "d", "h"}) or retained_root in {"movsx", "movzx"}:
                    retained_category, retained_root = "data_transfer", "mov"

                if retained_root in {"mov", "cmov"}:
                    nb_mov_instr += 1

                if nb_parts_mnem == 1:
                    idx_root_beginning = mnemonic.index(retained_root)
                    if "f" in mnemonic[:idx_root_beginning] or any(part in mnemonic for part in non_crypto_mnemonic_parts):
                        is_immediate_non_crypto = True
                        # Only the rest of this basic block is skipped: the following blocks are still scanned,
                        # so an instruction from a cryptographic extension set found later still flags the function
                        break

                bb_raw_features[retained_category] += 1
                mnemonic_category_counts[retained_category] += 1

                # Caballero heuristics related computations
                if retained_category in {"arithmetic", "logic"}:
                    nb_default_caballero += 1
                    if retained_root in adjusted_caballero_roots:
                        nb_adjusted_caballero += 1
                    if retained_root in asymmetric_caballero_roots:
                        nb_asymmetric_caballero += 1

                if retained_root == "call":
                    called_function = idc.get_operand_value(head_ea, 0)
                    unique_func_calls.add(called_function)

                # Normalize retained_root for n-gram processing
                if retained_root in {"sa", "sra", "sla", "srl", "sll"}:
                    retained_root = "sh"
                elif retained_root in {"adc", "xadd"}:
                    retained_root = "add"
                elif retained_root == "rc":
                    retained_root = "ro"
                elif retained_root == "madd":
                    retained_root = "mul"

                # mov related roots are not taken into account for density computation because they are extremely frequent
                if retained_root in important_non_mov_roots:
                    bb_raw_features[retained_root] += 1

                # Undirected n-grams recognition
                if (previous_root != retained_root) and (((previous_root == "mov") and (retained_root in important_roots)) or ((retained_root == "mov") and (previous_root in important_roots))):
                    bb_raw_features[ngram_mapping[frozenset({previous_root, retained_root})]] += 1
                previous_root = retained_root

                # Update number of data references and unique data references
                for ref_ea in idautils.DataRefsFrom(head_ea):
                    segment_name = idc.get_segm_name(ref_ea)
                    if segment_name and ('data' in segment_name or 'bss' in segment_name or segment_name == 'ds'):
                        nb_data_refs += 1
                        unique_data_refs.add(ref_ea)

                # Update number of loops
                if retained_root == "j":
                    for operand_index in range(2): # We assume that there are at most 2 operands
                        jump_operand = idc.print_operand(head_ea, operand_index)
                        if jump_operand.startswith("loc_"):
                            try:
                                target_address = int(jump_operand[4:], 16) # Get and convert the target address from hex to decimal
                            except ValueError:
                                # The label is not a bare hexadecimal address (e.g. it carries an offset): ignore this operand
                                continue
                            if target_address < head_ea: # Backward jump => loop
                                nb_loops += 1
                                break

            raw_features[hex_bb_address] = bb_raw_features

        if not is_immediate_crypto and not is_immediate_non_crypto:
            nb_non_mov_instr = nb_instr - nb_mov_instr
            if nb_non_mov_instr != 0:
                default_caballero_ratio = round(nb_default_caballero / nb_non_mov_instr, round_precision)
                adjusted_caballero_ratio = round(nb_adjusted_caballero / nb_non_mov_instr, round_precision)
                asymmetric_caballero_ratio = round(nb_asymmetric_caballero / nb_non_mov_instr, round_precision)

            # Discarding functions containing at most one non-empty basic block and having less than [threshold] "meaningful" instructions.
            # This is done at the end because some functions may have empty blocks: nb_bb is only final here, and nb_instr
            # already counts the meaningful instructions of all the non-empty blocks (so wherever the single block is located).
            if nb_bb <= 1 and nb_instr < small_func_nb_mnemonics_threshold:
                continue

            nb_nodes = cfg.number_of_nodes()
            sccs = list(nx.strongly_connected_components(cfg))
            func_info["nb_bb"] = nb_bb
            func_info["nb_instr"] = nb_instr
            func_info["nb_loops"] = nb_loops
            func_info["nb_data_refs"] = nb_data_refs
            func_info["nb_unique_data_refs"] = len(unique_data_refs)
            func_info["nb_unique_func_calls"] = len(unique_func_calls)
            # NOTE(review): both graph features below are computed from strongly connected components;
            # presumably the trained model expects exactly these definitions, so they are kept untouched
            func_info["cyclomatic_complexity"] = nb_edges - nb_nodes + 2*len(sccs)
            func_info["max_loop_depth"] = sum(1 for scc in sccs if len(scc) > 1)
            func_info["default_caballero_ratio"] = default_caballero_ratio
            func_info["adjusted_caballero_ratio"] = adjusted_caballero_ratio
            func_info["asymmetric_caballero_ratio"] = asymmetric_caballero_ratio
            func_info["crypto"] = 0 # Default value (1 only for some functions from training set)

            # Compute statistics-related features (one value per non-empty basic block)
            for feature in stat_related_features:
                values = [block_features[feature] for block_features in raw_features.values()]

                if (feature in important_non_mov_roots):
                    func_info[f"density_{feature}"] = round(sum(values) / nb_non_mov_instr, round_precision) if values and (nb_non_mov_instr > 0) else 0
                func_info[f"mean_{feature}"] = round(mean(values), round_precision) if values else 0

                if feature in full_stat_related_features:
                    # stdev needs at least two data points
                    func_info[f"std_dev_{feature}"] = round(stdev(values), round_precision) if len(values) > 1 else 0

            values = [block_features["nb_instr"] for block_features in raw_features.values()]
            func_info["max_nb_instr"] = max(values, default=0)
            func_info["median_nb_instr"] = round(median(values), round_precision) if values else 0

            # Compute entropy of mnemonic categories
            total_mnemonics = sum(mnemonic_category_counts.values())
            entropy = -sum((count / total_mnemonics) * math.log2(count / total_mnemonics) for count in mnemonic_category_counts.values()) if total_mnemonics > 0 else 0
            func_info["entropy_mnemonics_categories"] = round(entropy, round_precision)

            # Update the features of the binary under analysis with the current function
            computed_features[func_name] = func_info

    # Convert computed_features directly into a DataFrame
    data = [{'binary_name': basename, 'function_name': func_name, **func_computed_features} for func_name, func_computed_features in computed_features.items()]
    if data:
        computed_features_df = pd.DataFrame(data)
    else:
        # No function kept: still provide the columns read by the classification step
        computed_features_df = pd.DataFrame(columns=["binary_name", "crypto", "function_name"])
    # Ensure consistent column ordering
    computed_features_df = computed_features_df[sorted(computed_features_df.columns)]
    # Fill NaN values with zeros for robustness
    computed_features_df = computed_features_df.fillna(0)
    # Classify functions with the trained model, using their features and keeping track of a priori cryptographic functions
    classify_crypto_functions(computed_features_df, immediate_crypto_functions)


def classify_crypto_functions(computed_features, immediate_crypto_functions):
    """
    Score the functions with the pre-trained model, then export and display the results.

    Args:
        computed_features: pandas DataFrame with one row per function. It holds the
            columns "function_name", "binary_name" and "crypto"; every other column
            is given to the model as a feature. It may be empty.
        immediate_crypto_functions: iterable of (function name, identification tag)
            tuples for functions already known to be cryptographic; they are
            reported with a confidence score of 1.0.

    Returns:
        None. Results are written to "<basename>_mnemocrypt.csv" and shown through
        `display_results`. If the model cannot be loaded or the prediction fails,
        an error is printed and nothing is exported nor displayed.
    """
    # Configuration
    basename = idc.get_root_filename().split('.', 1)[0]
    min_crypto_confidence_score = 0.5

    classified_functions = []

    # With no feature row there is nothing to predict (and the expected columns may be missing);
    # the a priori cryptographic functions must still be reported
    if not computed_features.empty:
        # Load pre-trained model
        # NOTE(review): joblib.load unpickles the file, which can execute arbitrary code;
        # "mnemocrypt_trained.pkl" must come from a trusted source
        try:
            model = joblib.load(os.path.join(os.path.dirname(os.path.abspath(__file__)), "mnemocrypt_trained.pkl"))
        except Exception as e:
            print(f"[Mnemocrypt] Error loading model: {e}")
            return

        function_names = computed_features["function_name"]
        X_test = computed_features.drop(columns=["function_name", "binary_name", "crypto"])

        # Predict probabilities with the model
        try:
            probabilities = model.predict_proba(X_test)[:, 1]  # Probability for class '1' (crypto)
        except Exception as e:
            print(f"[Mnemocrypt] Error during prediction: {e}")
            return

        # Keep functions with confidence > min_crypto_confidence_score
        for func_name, current_probability in zip(function_names, probabilities):
            if current_probability > min_crypto_confidence_score:
                tags = [] # Feature cryptographic identification tags of functions
                # float() keeps the stored score a plain Python number
                classified_functions.append((func_name, round(float(current_probability), 2), ", ".join(tags)))

    for func_name, crypto_ident_tag in immediate_crypto_functions: # immediate_crypto_functions and function_names are disjoint by construction
        tags = [crypto_ident_tag]
        classified_functions.append((func_name, 1.0, ", ".join(tags))) # Probability of being crypto directly set to 1.0 for a priori cryptographic functions

    # Highest confidence first
    sorted_functions = sorted(classified_functions, key=lambda x: x[1], reverse=True)

    if not sorted_functions:
        print("No cryptographic functions have been detected.")

    idaapi.auto_wait() # To avoid any potential interference

    # Export results to avoid relaunching the plugin on same binaries (comment the line, if don't want any export)
    save_results_to_csv(f"{basename}_mnemocrypt.csv", sorted_functions)

    display_results(sorted_functions)


def save_results_to_csv(filepath, records):
    """
    Write the classified functions to a CSV file.

    `records` is a sequence of (function name, confidence score, identification
    tag) triples. The file at `filepath` is written from scratch on every call
    (an existing file is replaced), with a header row and no index column.
    A failure while writing is reported on the console instead of being raised.
    """
    column_names = ["Function Name", "Confidence Score", "Identification Tag"]
    # Tabular view of the records, one row per function
    results_table = pd.DataFrame(records, columns=column_names)

    try:
        results_table.to_csv(filepath, index=False)
        print(f"[Mnemocrypt] Results saved to {filepath}")
    except Exception as e:
        print(f"[Mnemocrypt] Error saving CSV: {e}")


def display_results(sorted_functions):
    """
    Show the classified functions in an IDA chooser window.

    `sorted_functions` is a list of (function name, confidence score,
    identification tag) triples, displayed in the given order. Selecting a
    row jumps to the address of the corresponding function, and each row gets
    a background color chosen from fixed confidence ranges.
    """
    column_layout = [["Function Name", 30], ["Confidence Score", 10], ["Identification Tag", 50]]

    class ResultsChooser(idaapi.Choose):
        def __init__(self, title, items):
            idaapi.Choose.__init__(self, title, column_layout)
            self.items = items
            self.icon = 41  # Default icon

        def OnGetSize(self):
            return len(self.items)

        def OnGetLine(self, index):
            # The score is kept as a number in self.items and only turned into text here
            row = self.items[index]
            return [row[0], str(row[1]), row[2]]

        def OnSelectLine(self, n):
            """
            Jump to the function of the selected row (looked up by name).
            """
            selected_name = self.items[n][0]
            target_ea = idc.get_name_ea_simple(selected_name)
            if target_ea == idc.BADADDR:
                print(f"[Mnemocrypt] Function not found: {selected_name}")
                return
            idc.jumpto(target_ea)

        def OnGetLineAttr(self, n):
            """
            Pick the row background from fixed confidence ranges
            (>= 0.95, >= 0.75, > 0.5, anything else).
            """
            _name, confidence, _tags = self.items[n]

            # Start from the default and let each threshold reached override it,
            # so the highest one reached wins.
            # Color names below follow the original annotations of these values.
            background = 0xFFFFFF  # Default (white)
            if confidence > 0.5:
                background = 0x00FFFF  # Yellow
            if confidence >= 0.75:
                background = 0x007FFF  # Orange
            if confidence >= 0.95:
                background = 0x0000FF  # Red

            return [background, 0x000000]

    # Scores stay numeric in the chooser items
    results_window = ResultsChooser("Mnemocrypt", sorted_functions)
    results_window.Show()

# Register the plugin with IDA
def PLUGIN_ENTRY():
    """Build and return the plugin object (a fresh `MnemocryptPlugin` instance)."""
    plugin_instance = MnemocryptPlugin()
    return plugin_instance
