import collections
import re

# IDA imports that are neutral to version changes.
import ida_hexrays
import idaapi
import idautils
import idc
import ida_typeinf

# Version-specific IDA modules.
try:
    import ida_dirtree
except ImportError:
    pass

from shims import ida_shims
from utils import qt_utils
from shims.qt_shims import QtWidgets

# Prefix logic: the '%' and '_' prefixes are opposites of each other:
# 1. '%' always has a single occurrence, '_' may repeat;
# 2. '%' cannot appear at the very beginning of a function name, '_' can;
# 3. '%' is a purely internal prefix representation, '_' is the human-readable one;
# 4. '%' prefixes are added automatically, '_' prefixes are added manually.

def is_pfx_valid(pfx):
    """
    Tell whether a candidate prefix string is acceptable.

    A prefix is rejected when it is empty, when it starts with a digit,
    or when it contains any of the characters '@', '$', '?', '-', '+'.

    :param pfx: candidate prefix string
    :return:    True if the prefix passes all checks, False otherwise
    """
    if pfx == '':
        return False
    if re.match('^[0-9]+', pfx):
        return False
    return not any(mark in pfx for mark in ('@', '$', '?', '-', '+'))

def get_func_prefs(func_name, is_dummy=True):
    """
    Split a function name into the list of prefixes it carries.

    The characters '%', ':' and '_' act as separators. Every time a separator
    is met at a non-zero index, the text accumulated since the previous
    separator is recorded as a prefix (if `is_pfx_valid` accepts it). The
    text after the last separator is never recorded. Trailing '_', '%' and
    ':' characters are stripped before scanning.

    Names that start with '_' and names that start with '?' and contain '@'
    yield no prefixes at all.

    :param func_name: function name to analyze
    :param is_dummy:  when False, the first occurrence of the dummy prefix
                      'sub' (as produced by names like 'sub_XXXX') is
                      removed from the result
    :return:          list of prefix strings, in order of appearance
    """
    if ((func_name.startswith('?') and '@' in func_name) or
        func_name.startswith('_')):
        return []
    # Separators are never accumulated into a prefix, so the dummy name
    # 'sub_XXXX' is recorded as 'sub' (not 'sub_').
    pfx_dummy = 'sub'
    prefs = []
    pfx = ''

    idx = 0
    func_name = func_name.rstrip('_%:')
    while idx < len(func_name):
        char = func_name[idx]
        if char in ['%', ':', '_']:
            # Swallow the run of '_' / ':' that follows the separator.
            pfx_len = 1
            while (idx+pfx_len) < len(func_name) and func_name[idx+pfx_len] in ['_', ':']:
                pfx_len += 1

            if idx != 0:
                # NOTE: if the separator tail is ever appended to pfx here,
                # pfx_dummy above must be changed accordingly.
                if is_pfx_valid(pfx):
                    prefs.append(pfx)
                pfx = ''

            idx += pfx_len-1
        else:
            pfx += char

        idx += 1

    if not is_dummy and pfx_dummy in prefs:
        prefs.remove(pfx_dummy)
    return prefs

def get_external_pfx(func_name, is_dummy=True):
    """
    Return the first prefix of a function name.

    :param func_name: function name to analyze
    :param is_dummy:  forwarded unchanged to `get_func_prefs`
    :return:          the first prefix, or None when the name has none
    """
    prefs = get_func_prefs(func_name, is_dummy)
    if not prefs:
        return None
    return prefs[0]

def get_parent_dir(func_addr):
    """
    Find the non-root folder that lists the given function.

    :param func_addr: function address to look up
    :return:          path of the first non-root folder whose function list
                      contains `func_addr`, or None if there is no such folder
    """
    folder_map = get_dir_funcs(get_dirs('/'))
    return next(
        (folder for folder, addrs in folder_map.items()
         if func_addr in addrs and folder != '/'),
        None,
    )

def get_same_folder_funcs(func_addr):
    """
    Return all functions that share a non-root folder with the given one.

    :param func_addr: function address to look up
    :return:          function list of the first non-root folder containing
                      `func_addr`, or None if there is no such folder
    """
    folder_map = get_dir_funcs(get_dirs('/'))
    return next(
        (addrs for folder, addrs in folder_map.items()
         if func_addr in addrs and folder != '/'),
        None,
    )

def get_folder_funcs(func_dir):
    """
    Return the functions listed directly in one folder.

    :param func_dir: folder path
    :return:         list of function addresses, empty when the folder
                     has no function entries
    """
    return get_dir_funcs([func_dir]).get(func_dir, [])

def get_prefix_funcs(func_pfx):
    """
    Collect every function whose first prefix equals `func_pfx`.

    :param func_pfx: prefix to match against `get_external_pfx(name, True)`
    :return:         list of matching function addresses
    """
    return [
        func_addr for func_addr in idautils.Functions()
        if get_external_pfx(idaapi.get_func_name(func_addr), True) == func_pfx
    ]

def get_color_funcs(func_col):
    """
    Collect every function whose color equals `func_col`.

    :param func_col: color value as returned by `get_func_color`
    :return:         list of matching function addresses
    """
    return [
        func_addr for func_addr in idautils.Functions()
        if get_func_color(func_addr) == func_col
    ]

def get_func_color(func_addr):
    """
    Return the function-level color of `func_addr`.

    :param func_addr: function address
    :return:          the color value, or None when it equals 0xFFFFFF
                      (presumably the "no custom color" value - confirm)
    """
    col = ida_shims.get_color(func_addr, idc.CIC_FUNC)
    if col == 0xFFFFFF:
        return None
    return col

def refresh_ui():
    """
    Refresh the disassembly view and, if the current widget has a
    decompiler view attached, its pseudocode text as well.
    """
    ida_shims.refresh_idaview_anyway()
    vdui = ida_shims.get_widget_vdui(ida_shims.get_current_widget())
    if not vdui:
        return
    vdui.refresh_ctext()

def convert_dict(input_dict):
    """
    Invert an {address: level} mapping into {level: [addresses]}.

    Addresses keep the order in which they appear in `input_dict`; the
    resulting dictionary is ordered by ascending level.

    :param input_dict: mapping of address -> level
    :return:           mapping of level -> list of addresses
    """
    by_level = {}
    for address, level in input_dict.items():
        by_level.setdefault(level, []).append(address)
    return {level: by_level[level] for level in sorted(by_level)}

def build_dependency_graph(entities):
    """
    Build a successor graph and an indegree table from a dependency mapping.

    :param entities: mapping of entity -> iterable of entities it depends on
    :return:         tuple (graph, indegree) where graph maps each dependency
                     to the list of entities that depend on it, and indegree
                     maps every node to its number of dependencies
    """
    graph = collections.defaultdict(list)
    indegree = collections.defaultdict(int)

    for entity, dependencies in entities.items():
        for dependency in dependencies:
            graph[dependency].append(entity)
            indegree[entity] += 1

        # Ensure the entity is registered even if it has no dependencies.
        if entity not in indegree:
            indegree[entity] = 0

    # Nodes that only ever appear as a dependency (never as a key of
    # `entities`) have no dependencies of their own. Register them last so
    # the ordering for fully-specified inputs stays exactly as before;
    # without this they are invisible to the topological sort, which then
    # reports a bogus cycle.
    for dependency in graph:
        if dependency not in indegree:
            indegree[dependency] = 0

    return graph, indegree

def get_dependency_order(graph, indegree):
    """
    Topologically sort the nodes of a dependency graph (Kahn's algorithm).

    Neither `graph` nor `indegree` is modified; `graph` may be a plain dict
    that omits nodes without successors.

    :param graph:    mapping of node -> list of successor nodes
    :param indegree: mapping of node -> number of incoming edges
    :return:         list of nodes, dependencies before their dependents
    :raises ValueError: if not every node could be ordered (cycle)
    """
    # Work on a copy so the caller's table survives the sort.
    remaining = indegree.copy()
    queue = collections.deque(
        node for node, degree in remaining.items() if degree == 0
    )
    topo_order = []

    while queue:
        node = queue.popleft()
        topo_order.append(node)

        # .get() avoids both a KeyError on plain dicts and the silent key
        # insertion a defaultdict would perform.
        for successor in graph.get(node, ()):
            remaining[successor] -= 1
            if remaining[successor] == 0:
                queue.append(successor)

    # If the topo_order doesn't include all nodes, there's a cycle
    if len(topo_order) != len(remaining):
        raise ValueError("The graph contains cycles, which indicates a circular dependency.")

    return topo_order

def graph_down_simple(ea, graph=None, path=None):
    """
    Creates a downgraph of xrefs FROM this function.
    Calling it recursively allow us to get infinite depth.

    :param ea:    address of ROOT NODE
    :param graph: dictionary to fill; a fresh one is created when omitted
    :param path:  set of already visited addresses; a fresh one is created
                  when omitted
    :return:      Dictionary of function ea's and child *addresses* { ea : [c1_ea, c2_ea, ...] }
    """
    # Fresh containers per top-level call: mutable defaults would be shared
    # between unrelated calls and accumulate stale nodes.
    if graph is None:
        graph = {}
    if path is None:
        path = set()

    graph[ea] = list()
    path.add(ea)

    # Iterate through all function instructions and take only call instructions.
    for x in [x for x in idautils.FuncItems(ea) if idaapi.is_call_insn(x)]:
        for xref in idautils.XrefsFrom(x, idaapi.XREF_FAR):
            if not xref.iscode:
                continue

            # Eliminate recursions: each target is descended into only once.
            if xref.to not in path:
                graph[ea].append(xref.to)
                graph_down_simple(xref.to, graph, path)

    return graph

def get_order_recur_defs(func_ea):
    """
    Return the functions reachable from `func_ea` in dependency order.

    :param func_ea: address of the root function
    :return:        list of addresses as ordered by `get_order_func_defs`
    """
    # graph_down_simple lives in this module; explicit fresh containers keep
    # every invocation independent of previous ones.
    nodes = graph_down_simple(func_ea, {}, set())
    return get_order_func_defs(nodes)

def get_order_func_defs(nodes):
    """
    Order the nodes of a { node : [dependencies] } mapping topologically.

    :param nodes: mapping passed to `build_dependency_graph`
    :return:      list of nodes produced by `get_dependency_order`
    """
    return get_dependency_order(*build_dependency_graph(nodes))

def get_func_set(nodes):
    """
    Return the unique addresses mentioned anywhere in a graph mapping.

    :param nodes: mapping of ea -> iterable of child eas
    :return:      list of distinct eas taken from both keys and values
    """
    eas = [*nodes.keys()]
    eas.extend(ea for deps in nodes.values() for ea in deps)
    return list(set(eas))

def graph_down(ea, depth=0, path=None, convert=False):
    """
    Record the call depth of every function reachable from `ea`.

    Call targets that are not inside a function, or whose function is
    flagged as thunk or library, are skipped. When an already recorded
    function is reached again at a greater depth, its recorded depth is
    raised (its own callees are not revisited).

    :param ea:      address of the function to start from
    :param depth:   depth assigned to `ea`
    :param path:    mapping ea -> depth to fill; a fresh one is created
                    when omitted
    :param convert: when True, return `convert_dict(path)` instead of `path`
    :return:        mapping ea -> depth, or depth -> [eas] if `convert`
    """
    # A mutable default here would be shared between unrelated calls and
    # keep depths from earlier traversals.
    if path is None:
        path = dict()

    path[ea] = depth
    call_instructions = []
    for address in idautils.FuncItems(ea):
        if not ida_shims.decode_insn(address):
            continue
        if not idaapi.is_call_insn(address):
            continue
        call_instructions.append(address)

    for x in call_instructions:
        for r in idautils.XrefsFrom(x, idaapi.XREF_FAR):
            if not r.iscode:
                continue
            func = idaapi.get_func(r.to)
            if not func:
                continue
            if (func.flags & (idaapi.FUNC_THUNK | idaapi.FUNC_LIB)) != 0:
                continue
            if r.to not in path:
                graph_down(r.to, depth + 1, path)
            else:
                if path[r.to] < depth + 1:
                    path[r.to] = depth + 1  # update max depth
    return convert_dict(path) if convert else path

def get_func_ea_by_ref(func_ref):
    """
    Resolve a function reference to an address.

    :param func_ref: a string (hex literal starting with '0x', or a name
                     looked up via `ida_shims.get_name_ea_simple`), an
                     integer address, or an `idaapi.func_t` object
    :return:         the address, or None for unsupported reference types
    """
    if isinstance(func_ref, str):
        if func_ref.startswith('0x'):
            return int(func_ref, 16)
        return ida_shims.get_name_ea_simple(func_ref)
    if isinstance(func_ref, int):
        return func_ref
    # func_t is only reachable through the idaapi module in this file.
    if isinstance(func_ref, idaapi.func_t):
        return func_ref.start_ea
    return None

def get_func_item_eas(func_ref):
    """
    Yield the addresses of the code items of a function.

    :param func_ref: function reference accepted by `get_func_ea_by_ref`
    :return:         generator of item addresses whose flags mark them as code
    """
    func_ea = get_func_ea_by_ref(func_ref)
    for item_ea in list(idautils.FuncItems(func_ea)):
        # Test the flags of the item itself, not of the function start.
        if idaapi.is_code(ida_shims.get_full_flags(item_ea)):
            yield item_ea

def get_func_item_eas_once(func_ref):
    """
    Return the result of `get_func_item_eas` as a list.

    :param func_ref: function reference accepted by `get_func_item_eas`
    :return:         list of item addresses
    """
    return list(get_func_item_eas(func_ref))


def is_function_leaf(func_ref):
    """
    Tell whether a function neither calls nor tail-jumps anywhere.

    A function is reported as non-leaf when any of its items has the
    mnemonic 'call', or when its last item has the mnemonic 'jmp'.

    :param func_ref: function reference accepted by `get_func_ea_by_ref`
    :return:         True for a leaf function, False otherwise
    :raises IndexError: if the function yields no items at all
    """
    func_ea = get_func_ea_by_ref(func_ref)
    item_eas = list(get_func_item_eas(func_ea))
    if any(ida_shims.ua_mnem(item_ea) == 'call' for item_ea in item_eas):
        return False
    return ida_shims.ua_mnem(item_eas[-1]) != 'jmp'

def get_func_dirs(root_dir):
    """
    List `root_dir` and every folder nested below it in the functions
    directory tree.

    This is an alias kept for existing callers; the traversal itself is
    implemented once, in `get_dirs`.

    :param root_dir: folder path to start from
    :return:         list of folder paths, `root_dir` first
    """
    return get_dirs(root_dir)

def get_dir_funcs(folders, is_root=True):
    """
    Map each folder to the function addresses listed directly inside it.

    :param folders: list of folder paths to inspect
    :param is_root: when False, functions located in '/' are left out
    :return:        dictionary { folder_path : [func_addr, ...] }; folders
                    without function entries do not appear as keys
    """
    dirtree = ida_dirtree.get_std_dirtree(ida_dirtree.DIRTREE_FUNCS)
    cursor_it = ida_dirtree.dirtree_iterator_t()

    funcs = {}
    for curr_path in folders:
        dirtree.chdir(curr_path)
        has_entry = dirtree.findfirst(cursor_it, "*")

        while has_entry:
            entry_name = dirtree.get_entry_name(dirtree.resolve_cursor(cursor_it.cursor))
            func_addr = ida_shims.get_name_ea(0, entry_name)
            is_file = dirtree.isfile(dirtree.get_abspath(cursor_it.cursor))
            # Root entries are dropped only when the caller explicitly
            # passed is_root=False.
            skip_root = is_root == False and curr_path == '/'
            if is_file and not skip_root:
                funcs.setdefault(curr_path, []).append(func_addr)
            has_entry = dirtree.findnext(cursor_it)

    return funcs

def get_func_name(func_ref):
    """
    Resolve a function reference to a function name.

    :param func_ref: a name (returned as is) or an integer address
                     (resolved through `ida_shims.get_func_name`)
    :return:         the function name
    :raises ValueError: if `func_ref` is neither a string nor an integer
    """
    if isinstance(func_ref, str):
        return func_ref
    if isinstance(func_ref, int):
        return ida_shims.get_func_name(func_ref)
    raise ValueError("Invalid func reference")

def get_folder_norm(folder):
    """
    Normalize a folder path for joining: the root '/' becomes an empty
    string, any other value is returned unchanged.
    """
    if folder == '/':
        return ''
    return folder

def set_func_folder(func_ref, folder_src, folder_dst):
    """
    Move a function entry from one folder to another in the functions
    directory tree.

    :param func_ref:   function reference accepted by `get_func_name`
    :param folder_src: folder currently holding the function
    :param folder_dst: folder to move the function into
    """
    func_name = get_func_name(func_ref)
    src_path, dst_path = [
        '{}/{}'.format(get_folder_norm(folder), func_name)
        for folder in (folder_src, folder_dst)
    ]

    dirtree = ida_dirtree.get_std_dirtree(ida_dirtree.DIRTREE_FUNCS)
    dirtree.chdir('/')
    dirtree.rename(src_path, dst_path)

def is_in_interval(addr, func_ivals, is_strict):
    """
    Tell whether `addr` falls into any of the given intervals.

    :param addr:       address to test
    :param func_ivals: iterable of (begin, end) pairs
    :param is_strict:  when true, both bounds are exclusive; otherwise both
                       bounds are inclusive
    :return:           True if at least one interval contains `addr`
    """
    for beg, end in func_ivals:
        inside = (beg < addr < end) if is_strict else (beg <= addr <= end)
        if inside:
            return True
    return False

def get_func_ivals(func_addr):
    """
    Return the (begin, end) pairs produced by `ida_shims.get_chunk_eas`
    for the given function, as a list of tuples.
    """
    ivals = []
    for chunk_beg, chunk_end in ida_shims.get_chunk_eas(func_addr):
        ivals.append((chunk_beg, chunk_end))
    return ivals

def get_chunk_count(func_addr):
    """Return how many intervals `get_func_ivals` reports for the function."""
    return len(get_func_ivals(func_addr))

def is_addr_func(addr, func_addr, is_chunks, is_strict):
    """
    Tell whether `addr` lies within the function at `func_addr`.

    :param addr:      address to test
    :param func_addr: function address
    :param is_chunks: when true, test against every interval returned by
                      `get_func_ivals`; otherwise against the single range
                      from `func_addr` to the function's FUNCATTR_END
    :param is_strict: forwarded to `is_in_interval`
    :return:          result of `is_in_interval`
    """
    if is_chunks:
        ivals = get_func_ivals(func_addr)
    else:
        ivals = [(func_addr, idc.get_func_attr(func_addr, idc.FUNCATTR_END))]
    return is_in_interval(addr, ivals, is_strict)

def is_func_thunk(func_addr):
    """
    Return the function flags masked with FUNC_THUNK (non-zero when the
    thunk flag is set, zero otherwise).
    """
    return ida_shims.get_func_flags(func_addr) & idaapi.FUNC_THUNK

def get_code_refs_to(addr):
    """Return the set of addresses yielded by `idautils.CodeRefsTo(addr, 0)`."""
    return set(idautils.CodeRefsTo(addr, 0))

def get_data_refs_to(addr):
    """Return the set of addresses yielded by `idautils.DataRefsTo(addr)`."""
    return set(idautils.DataRefsTo(addr))

def get_refs_to(addr):
    """
    Return an iterator over the union of code and data references to `addr`.
    """
    all_refs = get_code_refs_to(addr)
    all_refs = all_refs.union(get_data_refs_to(addr))
    return iter(all_refs)

def is_arch64():
    """
    Tell whether the first segment of the database has bitness 2
    (the value this module treats as 64-bit).
    """
    first_seg = idaapi.getseg(ida_shims.get_first_seg())
    return bool(first_seg.bitness == 2)

def get_ptr_type():
    """
    Return the data-type flag matching the pointer width:
    FF_QWORD on 64-bit databases, FF_DWORD otherwise.
    """
    # The FF_* constants are not imported as bare names in this module;
    # reach them through idc.
    return idc.FF_QWORD if is_arch64() else idc.FF_DWORD

def get_ref_off():
    """
    Return the offset reference type matching the pointer width:
    REF_OFF64 on 64-bit databases, REF_OFF32 otherwise.
    """
    # The REF_* constants are not imported as bare names in this module;
    # reach them through idc.
    return idc.REF_OFF64 if is_arch64() else idc.REF_OFF32

def get_ptr_size():
    """Return the pointer size in bytes: 8 when `is_arch64()`, else 4."""
    if is_arch64():
        return 8
    return 4

def get_ptr(addr):
    """
    Read a pointer-sized value at `addr`, using `idaapi.get_64bit` when
    `is_arch64()` and `idaapi.get_32bit` otherwise.
    """
    readers = (idaapi.get_32bit, idaapi.get_64bit)
    read_ptr = readers[1] if is_arch64() else readers[0]
    return read_ptr(addr)

def is_vtable(addr):
    """
    Heuristically tell whether `addr` looks like the start of a vtable.

    The address must be non-zero and cross-referenced, and the pointer
    stored at it must land in a code segment exactly on the start of a
    function.

    :param addr: address to examine
    :return:     True when all checks pass, False otherwise
    """
    if not (addr and has_xref(addr)):
        return False

    func_ea = get_ptr(addr)
    if not (func_ea and idaapi.getseg(func_ea)):
        return False

    if ida_shims.get_segm_attr(func_ea, idc.SEGATTR_TYPE) != idc.SEG_CODE:
        return False

    func_desc = idaapi.get_func(func_ea)
    return bool(func_desc and func_ea == ida_shims.start_ea(func_desc))

def has_xref(addr):
    """Return `ida_shims.has_xref` evaluated on the full flags of `addr`."""
    flags = ida_shims.get_full_flags(addr)
    return ida_shims.has_xref(flags)


def get_dirs(root_dir):
    """
    List `root_dir` and every folder nested below it in the functions
    directory tree.

    :param root_dir: folder path to start from
    :return:         list of folder paths without duplicates, `root_dir` first
    """
    dirtree = ida_dirtree.get_std_dirtree(ida_dirtree.DIRTREE_FUNCS)
    cursor_it = ida_dirtree.dirtree_iterator_t()

    pending = [root_dir]   # folders still to be scanned (used as a stack)
    found = [root_dir]     # every folder seen so far, in discovery order

    while pending:
        curr_path = pending.pop()
        dirtree.chdir(curr_path)
        # Children of the root are spelled '/name', not '//name'.
        parent = '' if curr_path == '/' else curr_path
        has_entry = dirtree.findfirst(cursor_it, "*")

        while has_entry:
            entry_name = dirtree.get_entry_name(dirtree.resolve_cursor(cursor_it.cursor))
            if dirtree.isdir(dirtree.get_abspath(cursor_it.cursor)):
                sub_dir = '{}/{}'.format(parent, entry_name)
                pending.append(sub_dir)
                if sub_dir not in found:
                    found.append(sub_dir)
            has_entry = dirtree.findnext(cursor_it)

    return found

def get_dir_funcs_B(folders):
    """
    Map each function address to the non-root folder that lists it.

    :param folders: list of folder paths to inspect
    :return:        dictionary { func_addr : folder_path }; functions located
                    in '/' are not included
    """
    dirtree = ida_dirtree.get_std_dirtree(ida_dirtree.DIRTREE_FUNCS)
    cursor_it = ida_dirtree.dirtree_iterator_t()

    funcs = {}
    for curr_path in folders:
        dirtree.chdir(curr_path)
        has_entry = dirtree.findfirst(cursor_it, "*")

        while has_entry:
            entry_name = dirtree.get_entry_name(dirtree.resolve_cursor(cursor_it.cursor))
            func_addr = ida_shims.get_name_ea(0, entry_name)
            is_file = dirtree.isfile(dirtree.get_abspath(cursor_it.cursor))
            if is_file and curr_path != '/':
                funcs[func_addr] = curr_path
            has_entry = dirtree.findnext(cursor_it)

    return funcs

def update_ui(loc):
    """
    Re-decompile the function at `loc` and refresh its pseudocode text.

    :param loc: address inside the function to refresh
    """
    cur_func = ida_hexrays.decompile(loc)
    # decompile() yields None when nothing can be decompiled at `loc`
    # (e.g. the address is not inside a function); nothing to refresh then.
    if cur_func is None:
        return
    cur_func.refresh_func_ctext()

def get_all_funcs():
    """Return the addresses yielded by `idautils.Functions()` as a list."""
    return [func_ea for func_ea in idautils.Functions()]

def get_selected_funcs():
    """
    Return the list of function addresses selected in the Functions widget.

    The Functions window keeps its visible function metadata in one of two
    Qt views - a QTableView or a QTreeView - and the one that currently has
    focus is the one that gets scraped.

    If the "show folders" feature was activated (on-startup/on-demand),
    a QTreeView is spawned as a child for the current session as well.

    The selected names are read from the focused view and then mapped back
    to function addresses with `match_funcs`.

    :return: list of function addresses; empty when the Functions window
             is closed or neither view has focus
    """
    twidget = idaapi.find_widget("Functions window")

    # The 'Functions' widget in IDA is closed.
    if not twidget:
        return []

    func_wgt = qt_utils.get_ptr_wgt(int(twidget))
    view_wgts = [QtWidgets.QTableView, QtWidgets.QTreeView]
    table_view, tree_view = list(qt_utils.get_parent_wgts(func_wgt, view_wgts))

    if table_view.hasFocus():
        rows = table_view.selectionModel().selectedRows()
        scraped_names = [str(row.data()) for row in rows]
    elif tree_view.hasFocus():
        indexes = tree_view.selectionModel().selectedIndexes()
        # Only column 0 is scraped; duplicates are collapsed.
        scraped_names = list({
            str(index.data()) for index in indexes if index.column() == 0
        })
    else:
        return []

    # Re-map the scraped names as they appear in the function view,
    # to their true names as they are saved in the IDB.
    return match_funcs(scraped_names)

def match_funcs(qt_funcs):
    """
    Map function names scraped from Qt back to function addresses.

    The names shown in the Functions window table use the underscore
    character ('_') as a substitute for a variety of different characters.
    For example, a function named foo%bar in the IDB appears as foo_bar
    in the Functions window table.

    Each scraped name is therefore compared against every function name in
    the IDB, treating '_' in the scraped name as a wildcard for exactly one
    character; names of different length never match. The first function
    that matches is taken and the search for that name stops.

    TODO: rewrite this to be more efficient for larger idbs
    TODO: takes first matching function, may want to change it to make the requirements more strict

    :param qt_funcs: iterable of names as they appear in the Functions window
    :return:         list of distinct addresses of the matched functions
    """
    matched = set()
    func_eas = get_all_funcs()
    for qt_name in qt_funcs:
        for ea in func_eas:
            idb_name = idaapi.get_func_name(ea)
            if len(qt_name) != len(idb_name):
                continue

            if all(q == c or q == '_' for q, c in zip(qt_name, idb_name)):
                matched.add(ea)
                break

    return list(matched)

def create_folder(merge_name):
    """
    Create a folder in the functions directory tree unless it already exists.

    :param merge_name: folder path, resolved from the root '/'
    :return:           True if `mkdir` was invoked, False if the folder
                       already existed
    """
    dirtree = ida_dirtree.get_std_dirtree(ida_dirtree.DIRTREE_FUNCS)
    dirtree.chdir('/')
    if dirtree.isdir(merge_name):
        return False
    dirtree.mkdir(merge_name)
    return True

def get_type_def(type_name):
    """
    Render a textual definition of a structure type.

    The output has a header line with the structure name and size, one line
    per member (offset, type, name) and a closing line carrying the
    structure size as its offset column.

    :param type_name: name of the structure
    :return:          multi-line string with the rendered definition
    """
    type_id = idc.get_struc_id(type_name)
    type_size = idc.get_struc_size(type_id)

    lines = [
        "{:08X} struct {} // sizeof=0x{:X}".format(0, type_name, type_size),
        "{:08X} {{".format(0),
    ]

    for m_offset, m_name, _m_size in idautils.StructMembers(type_id):
        member_id = idc.get_member_id(type_id, m_offset)
        udm = ida_typeinf.udm_t()
        ida_typeinf.get_udm_by_fullname(udm, idc.get_struc_name(member_id))
        lines.append("{:08X}     {} {};".format(m_offset, udm.type, m_name))

    lines.append("{:08X} }};".format(type_size))
    return "\n".join(lines)
