#  Copyright 2022 Quarkslab
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from __future__ import annotations

import itertools
import logging

import pypcode

import quokka
import quokka.analysis
from quokka.types import Any, Dict, List, Sequence, Type, Optional

logger: logging.Logger = logging.getLogger(__name__)


def get_arch_from_string(target_id: str) -> pypcode.ArchLanguage:
    """Look up the pypcode language whose identifier is `target_id`

    The languages of every architecture listed by `pypcode.Arch.enumerate()` are
    inspected in order and the first one whose `id` matches is returned.

    Arguments:
        target_id: Identifier of the language to find

    Raises:
        PypcodeError: if no language carries this identifier

    Returns:
        The matching language
    """
    # Flatten "architectures -> languages" into one lazy stream of languages
    every_language = itertools.chain.from_iterable(
        known_arch.languages for known_arch in pypcode.Arch.enumerate()
    )

    for language in every_language:
        if language.id == target_id:
            return language

    raise quokka.PypcodeError("Unable to find the appropriate arch: missing lang")


def get_pypcode_context(
    arch: Type[quokka.analysis.QuokkaArch],
) -> pypcode.Context:
    """Build the pypcode context matching a Quokka architecture

    Only the architectures listed in the mapping below are supported. The Quokka
    architecture is first mapped to a pypcode language identifier, then the
    language is resolved with `get_arch_from_string` and wrapped in a new
    `pypcode.Context`.

    Arguments:
        arch: Quokka program architecture

    Raises:
        PypcodeError: if the architecture has no known identifier, or if the
            identifier is not found among the pypcode languages

    Returns:
        A pypcode.Context instance
    """
    # Quokka architecture class -> pypcode language identifier
    language_ids: Dict[Type[quokka.analysis.arch.QuokkaArch], str] = {
        quokka.analysis.ArchX64: "x86:LE:64:default",
        quokka.analysis.ArchX86: "x86:LE:32:default",
        quokka.analysis.ArchARM: "ARM:LE:32:v8",
        quokka.analysis.ArchARM64: "AARCH64:LE:64:v8A",
        quokka.analysis.ArchARMThumb: "ARM:LE:32:v8T",
    }

    try:
        language_id = language_ids[arch]
    except KeyError:
        raise quokka.PypcodeError("Unable to find the appropriate arch: missing id")

    return pypcode.Context(get_arch_from_string(language_id))


def equality(self: pypcode.ContextObj, other: Any) -> bool:
    """Check if two pypcode objects are the same

    This function is monkey patched as `__eq__` on pypcode classes. Two objects
    are equal when `other` is an instance of the class of `self` and every field
    named in `self.__slots__`, except "cobj", compares equal.

    Arguments:
        self: First object
        other: Other variable

    Returns:
        Boolean for equality
    """
    if not isinstance(other, self.__class__):
        return False

    for field in self.__slots__:
        # "cobj" is deliberately left out of the comparison
        if field == "cobj":
            continue
        if not (getattr(other, field) == getattr(self, field)):
            return False

    return True


def object_hash(obj: pypcode.ContextObj) -> int:
    """Create a hash value for a pypcode object

    This allows to create set of values. The hash is computed from the same
    fields as `equality` (every name in `__slots__` except "cobj"), so objects
    that compare equal share the same hash.

    The fields are hashed together as a tuple instead of summing their individual
    hashes: a sum ignores which field holds which value, so two objects whose
    field values are merely swapped would always collide.

    Arguments:
        obj: Object to hash

    Returns:
        An integer for the hash
    """

    assert isinstance(obj, pypcode.ContextObj)
    return hash(
        tuple(getattr(obj, attr) for attr in obj.__slots__ if attr != "cobj")
    )


# Attach value-based equality and hashing to the pypcode classes so that their
# instances can be compared and stored in sets.
for _patched_class in (pypcode.Varnode, pypcode.AddrSpace, pypcode.PcodeOp):
    _patched_class.__eq__ = equality
    _patched_class.__hash__ = object_hash
del _patched_class


def combine_instructions(
    block: quokka.Block, translated_instructions: Sequence[pypcode.Translation]
) -> List[pypcode.PcodeOp]:
    """Attach the Pypcode translations to the Quokka instructions of a block

    IDA and Ghidra do not always agree on instruction boundaries. For example, a
    prefix such as LOCK is decoded as 2 instructions by Ghidra but only 1 by IDA.
    For each Quokka instruction, translations are consumed until their cumulated
    length covers the instruction size. The consumed operations are stored in the
    `_pcode_insts` attribute of the instruction.

    Arguments:
        block: Quokka block
        translated_instructions: Translated instructions by Pypcode

    Raises:
        PypcodeError: if the translations run out before every instruction is
            covered, or if a translation goes past the end of an instruction

    Returns:
        A list of Pypcode statements for the whole block
    """
    translations = iter(translated_instructions)
    block_ops: List[pypcode.PcodeOp] = []

    quokka_inst: quokka.Instruction
    for quokka_inst in block.instructions:
        quokka_inst._pcode_insts = []
        bytes_left: int = quokka_inst.size

        # Consume translations until the whole IDA instruction is covered
        while bytes_left > 0:
            try:
                translated: pypcode.Translation = next(translations)
            except StopIteration:
                logger.error(
                    f"Disassembly discrepancy between Pypcode / IDA: missing inst"
                )
                raise quokka.PypcodeError(
                    f"Decoding error for block at 0x{block.start:x}"
                )

            bytes_left -= translated.length
            quokka_inst._pcode_insts.extend(translated.ops)

            # A negative remainder means Pypcode decoded past the IDA boundary
            if bytes_left < 0:
                logger.error(
                    f"Disassembly discrepancy between Pypcode / IDA: sizes mismatch"
                )
                raise quokka.PypcodeError(
                    f"Decoding error for block at 0x{block.start:x}"
                )

            block_ops.extend(translated.ops)

    return block_ops


# Thumb decoding context, built on the first Thumb request and shared afterwards.
_THUMB_CONTEXT: Optional[pypcode.Context] = None


def update_pypcode_context(program: quokka.Program, is_thumb: bool) -> pypcode.Context:
    """Return an appropriate pypcode context for the decoding

    For ARM architecture, if the block starts with a Thumb instruction, we must use
    a different pypcode Context.

    Generating a context is costly, so the Thumb context is created only once, on
    the first Thumb request, and the same instance is returned by later calls.
    Every other request returns the context held by the program itself.

    Arguments:
        program: Program to consider
        is_thumb: Is the instruction a thumb one?

    Raises:
        PypcodeError: if the Thumb context cannot be created

    Returns:
        The correct pypcode context
    """
    global _THUMB_CONTEXT

    # Check the cheap flag first: non-Thumb requests use the program context.
    if is_thumb and program.arch in (
        quokka.analysis.ArchARM,
        quokka.analysis.ArchARM64,
        quokka.analysis.ArchARMThumb,
    ):
        if _THUMB_CONTEXT is None:
            _THUMB_CONTEXT = get_pypcode_context(quokka.analysis.ArchARMThumb)
        return _THUMB_CONTEXT

    return program.pypcode


def pypcode_decode_block(block: quokka.Block) -> List[pypcode.PcodeOp]:
    """Decode a whole block with Pypcode in a single call

    The bytes of the block are translated at once, which is faster than decoding
    each instruction separately. The decoding context is chosen from the first
    instruction of the block, and the translations are then matched to the Quokka
    instructions by `combine_instructions`.

    Arguments:
        block: Block to decode

    Raises:
        PypcodeError: if the translation reports an error or cannot be matched
            with the instructions of the block

    Returns:
        A list of pcode operations
    """

    # An empty block has no Pcode operation at all
    head: Optional[quokka.Instruction] = next(block.instructions, None)
    if head is None:
        return []

    # The first instruction decides which context is used for the whole block
    decoder: pypcode.Context = update_pypcode_context(block.program, head.thumb)

    # max_inst=0: no limit is requested on the number of decoded instructions
    outcome = decoder.translate(code=block.bytes, base=block.start, max_inst=0)

    if outcome.error:
        logger.error(outcome.error.explain)
        raise quokka.PypcodeError(f"Decoding error for block at 0x{block.start:x}")

    return combine_instructions(block, outcome.instructions)


def pypcode_decode_instruction(
    inst: quokka.Instruction,
) -> Sequence[pypcode.PcodeOp]:
    """Decode a single instruction using Pypcode

    The bytes of the instruction are translated with the context matching its
    Thumb state. One (binary) instruction usually yields several pcode
    operations; they are returned as one flat list.

    Arguments:
        inst: Instruction to translate

    Raises:
        PypcodeError: if the decoding fails

    Returns:
        A sequence of PcodeOp
    """

    decoder: pypcode.Context = update_pypcode_context(inst.program, inst.thumb)
    outcome = decoder.translate(code=inst.bytes, base=inst.address, max_inst=1)

    # Deal with the failure first so the nominal path stays flat
    if outcome.error:
        logger.error(outcome.error.explain)
        raise quokka.PypcodeError("Unable to decode instruction")

    decoded = outcome.instructions
    if len(decoded) > 1:
        logger.warning("Mismatch of instruction size IDA/Pypcode")

    # Flatten the operations of every decoded instruction into one list
    return [operation for translated in decoded for operation in translated.ops]
