from dotenv import load_dotenv
from json import dumps, loads, dump
from os import getenv
from os.path import join
from requests import post
#
from utils import plg_utils



def strip_comments(code, start_marker="// ---- IDAxLM", end_marker="// ----"):
    """Drop marker-delimited sections from *code* and return the rest.

    A section opens on any line beginning with *start_marker* and closes on
    the next line beginning with *end_marker*. Both marker lines and every
    line between them are removed; an unterminated section swallows the rest
    of the input. The start marker is tested first because the default end
    marker is a prefix of the default start marker.

    :param code: source text to filter, split on line boundaries.
    :param start_marker: prefix identifying the line that opens a section.
    :param end_marker: prefix identifying the line that closes a section.
    :returns: the surviving lines joined with ``"\\n"``.
    """
    kept = []
    skipping = False
    for text_line in code.splitlines():
        if text_line.startswith(start_marker):
            skipping = True
            continue
        if skipping:
            # Stay in skip mode until the closing marker; drop that line too.
            if text_line.startswith(end_marker):
                skipping = False
            continue
        kept.append(text_line)
    return "\n".join(kept)


class LmDialer():
    """Build and send requests to a configured language-model provider.

    Provider settings come from a JSON config loaded once when the class is
    defined. ``config['providers'][<name>]`` must contain a ``models``
    collection; its ``endpoint``, ``headers`` and ``payload`` entries are
    templates in which the tags below are substituted.

    Call :meth:`set_data` before any other method: it is what sets the
    ``endpoint``, ``headers`` and ``payload`` instance attributes.
    """

    # Placeholder replaced by the provider API token (headers / endpoint).
    TOKEN_TAG = '<token>'
    # Placeholder replaced by the model name (payload / endpoint).
    MODEL_TAG = '<model>'
    # Placeholder replaced by the user query (payload).
    QUERY_TAG = '<query>'

    @staticmethod
    def _load_config_data(cfg_path='data/config.json'):
        """Return the configuration parsed by ``plg_utils.get_json``."""
        data = plg_utils.get_json(cfg_path)
        return data

    # Shared by all instances and read once at class-definition time.
    # ``__func__`` unwraps the staticmethod: staticmethod objects are only
    # directly callable from Python 3.10 on, so a bare call breaks on 3.9-.
    cfg_data = _load_config_data.__func__()

    @staticmethod
    def _json_escape(value):
        """Serialize *value* to JSON and drop its first and last character.

        For a ``str`` this strips the surrounding quotes, giving text that can
        be spliced inside an existing JSON string literal without breaking it.
        """
        return dumps(value)[1:-1]

    def __init__(self):
        # Instances are configured later through set_data().
        pass

    def set_data(self, pvd_name, mdl_name):
        """Configure this dialer for provider *pvd_name* and model *mdl_name*.

        Loads the provider's endpoint/headers/payload templates, then fills in
        the API token from the ``API_TOKEN_<PVD_NAME upper-cased>``
        environment variable (empty string when unset) and the model name.

        :raises KeyError: if *pvd_name* is not a configured provider.
        :raises ValueError: if *mdl_name* is not in the provider's models.
        """
        svc_data = LmDialer.cfg_data['providers'][pvd_name]

        if mdl_name not in svc_data['models']:
            raise ValueError("Model '{}' is not supported".format(mdl_name))

        # These may be references into the shared cfg_data; later steps must
        # rebind rather than mutate them.
        self.endpoint, self.headers, self.payload = map(
            svc_data.get, ['endpoint', 'headers', 'payload']
        )

        token = getenv("API_TOKEN_{}".format(pvd_name.upper()), "")

        self.loadToken(token)
        self.loadModel(mdl_name)

    def loadToken(self, token):
        """Substitute *token* for ``TOKEN_TAG`` in the headers and endpoint."""
        header_str = dumps(self.headers)
        if LmDialer.TOKEN_TAG in header_str:
            # Escape so a token containing a quote or backslash keeps the
            # serialized headers valid JSON.
            self.headers = loads(header_str.replace(
                LmDialer.TOKEN_TAG, LmDialer._json_escape(token)))
        if self.endpoint and LmDialer.TOKEN_TAG in self.endpoint:
            self.endpoint = self.endpoint.replace(LmDialer.TOKEN_TAG, token)

    def loadModel(self, mdl_name):
        """Substitute *mdl_name* for ``MODEL_TAG`` in the payload and endpoint."""
        payload_str = dumps(self.payload)
        if LmDialer.MODEL_TAG in payload_str:
            self.payload = loads(payload_str.replace(
                LmDialer.MODEL_TAG, LmDialer._json_escape(mdl_name)))
        if self.endpoint and LmDialer.MODEL_TAG in self.endpoint:
            self.endpoint = self.endpoint.replace(LmDialer.MODEL_TAG, mdl_name)

    def loadQuery(self, query):
        """Return a request payload with *query* filled into the template.

        :param query: a ``str``, spliced JSON-escaped in place of
            ``QUERY_TAG``, or a ``list``, which replaces the payload's
            ``messages`` entry when the template has one.
        :returns: a new payload dict. ``self.payload`` is never modified, so
            the template stays reusable across queries.
        :raises TypeError: if *query* is neither a ``str`` nor a ``list``.
        """
        if isinstance(query, list):
            if 'messages' in self.payload:
                # Shallow copy: self.payload may be the shared cfg_data dict,
                # and writing into it would clobber the template everywhere.
                return {**self.payload, 'messages': query}
            # NOTE(review): stripping the brackets off a dumped list leaves
            # unescaped quotes for str/dict items, which yields invalid JSON
            # when the tag sits inside a string; confirm the intended format.
            escaped_multiline_text = LmDialer._json_escape(query)
        elif isinstance(query, str):
            escaped_multiline_text = LmDialer._json_escape(query)
        else:
            raise TypeError(
                "Unsupported query type: {}".format(type(query).__name__))

        payload = dumps(self.payload)
        return loads(payload.replace(LmDialer.QUERY_TAG, escaped_multiline_text))

    def postData(self, query, is_dry=False, timeout=None):
        """Send *query* to the configured endpoint and return the reply.

        :param query: see :meth:`loadQuery`.
        :param is_dry: when true, skip the HTTP call and return a canned
            fenced Python snippet (the payload is still built and validated).
        :param timeout: optional timeout forwarded to ``requests.post``;
            ``None`` (the default) waits indefinitely, as before.
        :returns: on HTTP 200, ``choices[0].message.content`` from the JSON
            response (a ``str``); otherwise a dict with ``error`` (the
            status code) and ``message`` (the response body text).
        """
        payload = self.loadQuery(query)
        if is_dry:
            return "```python\nimport os\ndef func():\n\tpass\n\n```"
        response = post(self.endpoint, headers=self.headers, json=payload,
                        timeout=timeout)
        if response.status_code == 200:
            return response.json()['choices'][0]['message']['content']
        return {"error": response.status_code, "message": response.text}
