from threading import Timer
import re
import shutil
import json
import subprocess
import tempfile
import requests
import requests.adapters
import time
from transformers import __version__ as transformers_version
from transformers import PreTrainedModel
import packaging.version
from tqdm.auto import tqdm
import os
import itertools
from typing import List, Optional

# True only when the installed transformers version is at least 4.20.0.dev0
# AND the `accelerate` package can be imported (see the try/except below).
HAS_ACCELERATE = packaging.version.parse(transformers_version) >= packaging.version.parse("4.20.0.dev0")
try:
    # Imported purely as an availability probe; a failed import clears the flag.
    import accelerate
except ImportError:
    HAS_ACCELERATE = False

# NOTE: this name shadows the builtin `vars`.  Functions in this module read
# attributes such as `newlinemode` and `aria2_port` from it, so it is presumably
# replaced with the application's settings object before they are called —
# TODO confirm where it is assigned.
vars = None
# The globals below are only initialised here; nothing visible in this part of
# the file writes to them.  Judging by their names they presumably hold state
# for sharded checkpoint loading / progress reporting — verify against callers.
num_shards: Optional[int] = None
current_shard = 0
from_pretrained_model_name = ""
from_pretrained_index_filename: Optional[str] = None
from_pretrained_kwargs = {}
bar = None

# Presumably caches for the results of the module-introspection helpers defined
# further down in this file — TODO confirm against callers.
layers_module_names: Optional[List[str]] = None
module_names: Optional[List[str]] = None
named_buffers: Optional[List[tuple]] = None

# NOTE(review): the meaning of the individual sampler ids is not visible here.
default_sampler_order = [6, 0, 1, 2, 3, 4, 5]

#==================================================================#
# Decorator to prevent a function's actions from being run until
# at least x seconds have passed without the function being called
#==================================================================#
def debounce(wait):
    """Return a decorator that postpones calls until activity has settled.

    Every call to the decorated function cancels the timer started by the
    previous call (if there was one) and starts a new ``threading.Timer`` for
    ``wait`` seconds.  The wrapped function therefore only runs once ``wait``
    seconds have passed without another call, using the arguments of the most
    recent call.  It runs on the timer's thread; its return value is discarded
    and the decorated function itself always returns ``None``.

    The currently pending timer is stored on the decorated function as the
    attribute ``t``.
    """
    import functools

    def decorator(fun):
        # Preserve the wrapped function's name/docstring on the wrapper
        @functools.wraps(fun)
        def debounced(*args, **kwargs):
            def call_it():
                fun(*args, **kwargs)

            # Cancel the previous pending call; on the very first call there
            # is no timer attribute yet.
            try:
                debounced.t.cancel()
            except AttributeError:
                pass

            debounced.t = Timer(wait, call_it)
            debounced.t.start()

        return debounced

    return decorator

#==================================================================#
# Replace fancy quotes and apostrophes with standard ones
#==================================================================#
def fixquotes(txt):
    """Return ``txt`` with curly double quotes, the curly apostrophe and
    backticks replaced by their plain ASCII quote equivalents."""
    plain_for = {
        "“": '"',
        "”": '"',
        "’": "'",
        "`": "'",
    }
    # One pass over the text using a character translation table
    return txt.translate(str.maketrans(plain_for))

#==================================================================#
# Trims text back to the last sentence-ending punctuation mark
#==================================================================#
def trimincompletesentence(txt):
    """Cut ``txt`` off after its last sentence-ending punctuation mark.

    The last ``.``, ``!`` or ``?`` is located; a double quote directly
    following it is kept as well (end of a quotation).  Everything after that
    point is dropped.  If ``txt`` contains none of these punctuation marks it
    is returned unchanged.
    """
    # Find last instance of punctuation (Borrowed from Clover-Edition by cloveranon)
    lastpunc = max(txt.rfind("."), txt.rfind("!"), txt.rfind("?"))
    if lastpunc < 0:
        # No punctuation at all: there is nothing to trim back to.  Without
        # this guard the quote check below would look at txt[0] and truncate
        # text starting with '"' down to a single quote character.
        return txt
    # Is this the end of a quote?  (Slicing is safe at the end of the string.)
    if txt[lastpunc + 1:lastpunc + 2] == '"':
        lastpunc += 1
    return txt[:lastpunc + 1]

#==================================================================#
# Replaces each pair of consecutive newlines with a single newline
#==================================================================#
def replaceblanklines(txt):
    """Return ``txt`` with every non-overlapping pair of consecutive newlines
    (scanned left to right) collapsed into a single newline."""
    pieces = txt.split("\n\n")
    return "\n".join(pieces)

#==================================================================#
# Strips special characters from the text
#==================================================================#
def removespecialchars(txt, vars=None):
    """Return ``txt`` with special characters removed.

    ``<`` and ``>`` are removed as well unless a ``vars`` object is supplied
    whose ``actionmode`` is non-zero; in that case they are kept.
    """
    keep_angle_brackets = vars is not None and vars.actionmode != 0
    if keep_angle_brackets:
        pattern = r"[#/@%{}+=~|\^]"
    else:
        pattern = r"[#/@%<>{}+=~|\^]"
    return re.sub(pattern, "", txt)

#==================================================================#
# If the next action follows a sentence closure, add a space
#==================================================================#
def addsentencespacing(txt, vars):
    """Prefix ``txt`` with a space unless the preceding text already ends in one.

    The preceding text is the most recent entry of ``vars.actions`` or, when
    there are no actions yet, ``vars.prompt``.  ``txt`` is returned untouched
    when it is empty, when it already starts with whitespace, or when the most
    recent action is empty.
    """
    # Empty submissions and ones that bring their own leading whitespace are left alone
    if not txt or txt.lstrip() != txt:
        return txt

    if len(vars.actions) > 0:
        previous = vars.actions[vars.actions.get_last_key()]
        if len(previous) == 0:
            # A blank last action should never happen; bail out if it does
            return txt
    else:
        previous = vars.prompt

    # previous[-1:] is "" for an empty prompt, otherwise its last character
    if previous[-1:] != " ":
        return " " + txt
    return txt
	
def singlelineprocessing(txt, vars):
    """Apply the single-line regex to ``txt`` and terminate it with a newline
    when the preceding text does not end in one.

    Every match of ``vars.regex_sl`` is removed from ``txt``.  The preceding
    text is the most recent entry of ``vars.actions`` or, when there are no
    actions yet, ``vars.prompt``.  If the most recent action is empty the
    regex-processed text is returned without further changes.
    """
    txt = vars.regex_sl.sub('', txt)

    if len(vars.actions) > 0:
        previous = vars.actions[vars.actions.get_last_key()]
        if len(previous) == 0:
            # A blank last action should never happen; bail out if it does
            return txt
    else:
        previous = vars.prompt

    # previous[-1:] is "" for an empty prompt, otherwise its last character
    if previous[-1:] != "\n":
        txt += "\n"
    return txt

#==================================================================#
#  Cleans string for use in file name
#==================================================================#
def cleanfilename(filename):
    """Return ``filename`` with all forward and backward slashes removed and
    trailing whitespace stripped."""
    for separator in ('/', '\\'):
        filename = filename.replace(separator, "")
    return filename.rstrip()
    
#==================================================================#
#  Newline substitution for fairseq models
#==================================================================#
def encodenewlines(txt):
    """Return ``txt`` with newlines replaced by ``</s>`` when the module-level
    ``vars.newlinemode`` is ``"s"``; otherwise return ``txt`` unchanged."""
    return txt.replace('\n', "</s>") if vars.newlinemode == "s" else txt

def decodenewlines(txt):
    """Undo the ``</s>`` newline substitution according to the module-level
    ``vars.newlinemode``.

    Mode ``"s"`` turns every ``</s>`` back into a newline, mode ``"ns"``
    removes every ``</s>``, and any other mode returns ``txt`` unchanged.
    """
    mode = vars.newlinemode
    if mode == "s":
        substitute = '\n'
    elif mode == "ns":
        substitute = ''
    else:
        return txt
    return txt.replace("</s>", substitute)

#==================================================================#
#  Returns number of layers given an HF model config
#==================================================================#
def num_layers(config):
    """Return the layer count stored in a model config.

    A ``dict`` config must provide the key ``"n_layer"``.  For any other
    object the attributes ``num_layers``, ``n_layer`` and
    ``num_hidden_layers`` are tried in that order; ``None`` is returned when
    none of them exists.
    """
    if isinstance(config, dict):
        return config["n_layer"]
    for attribute in ("num_layers", "n_layer", "num_hidden_layers"):
        if hasattr(config, attribute):
            return getattr(config, attribute)
    return None

#==================================================================#
#  Downloads huggingface checkpoints using aria2c if possible
#==================================================================#
from flask_socketio import emit
class Send_to_socketio(object):
    """Minimal writable object that forwards progress text to socket.io clients.

    ``aria2_hook`` in this file passes an instance as the ``file`` argument of
    ``tqdm`` so that progress bar output reaches the web UI.
    """

    def write(self, bar):
        """Print ``bar`` and broadcast it as a ``model_load_status`` message.

        Spaces are replaced by ``&nbsp;`` in the broadcast copy.  Failures
        while printing or emitting are ignored on purpose: reporting progress
        is best effort and must not interrupt the caller.
        """
        # NOTE(review): the reason for this short pause is not evident from this file
        time.sleep(0.01)
        try:
            print(bar)
            emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", "&nbsp;")}, broadcast=True)
        except Exception:
            # Deliberately swallowed (best effort), but KeyboardInterrupt and
            # SystemExit are no longer caught as they were by a bare except.
            pass
            
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
    """Download a model's PyTorch weight files into the transformers cache using aria2c.

    The function returns ``None`` without downloading anything when:
      * the ``aria2c`` executable is not found,
      * ``local_files_only`` is set,
      * ``pretrained_model_name_or_path`` is a local directory/file or a remote URL,
      * ``proxies`` is set (a warning is printed),
      * the installed transformers has no ``WEIGHTS_INDEX_NAME``,
      * neither the single weights file nor the shard index can be found, or
      * every file is already cached and ``force_download`` is false.

    Otherwise the weights file (or, for sharded checkpoints, every shard named
    in the index file) is downloaded by an aria2c subprocess that is monitored
    over its JSON-RPC interface on ``vars.aria2_port``.  Files are downloaded
    as ``kai-tempfile.<name>`` inside the cache directory, then renamed to
    ``<name>`` and accompanied by a ``<name>.json`` file holding the URL and ETag.

    Raises:
        EnvironmentError: ``use_auth_token`` is truthy but not a string and no
            stored Hugging Face token is available.
        OSError: aria2c exited with a non-zero code before the download was
            observed to finish.
    """
    import transformers
    import transformers.modeling_utils
    from huggingface_hub import HfFolder
    if shutil.which("aria2c") is None:  # Don't do anything if aria2 is not installed
        return
    if local_files_only:  # If local_files_only is true, we obviously don't need to download anything
        return
    if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index") or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path):
        return
    if proxies:
        print("WARNING:  KoboldAI does not support using aria2 to download models from huggingface.co through a proxy.  Disabling aria2 download mode.")
        return
    token = None
    if use_auth_token:
        if isinstance(use_auth_token, str):
            token = use_auth_token
        else:
            token = HfFolder.get_token()
            if token is None:
                raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
    _cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
    sharded = False
    headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
    if use_auth_token:
        # Use the resolved token: use_auth_token itself may simply be True
        headers["authorization"] = f"Bearer {token}"
    def is_cached(url):
        try:
            transformers.file_utils.get_from_cache(url, cache_dir=cache_dir, local_files_only=True)
        except (FileNotFoundError, transformers.file_utils.EntryNotFoundError):
            return False
        return True
    while True:  # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
        try:
            filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
        except AttributeError:
            return
        url = transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, filename, revision=revision, mirror=mirror)
        if is_cached(url) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
            break
        if sharded:
            return
        else:
            sharded = True
    if not sharded:  # If the model has a pytorch_model.bin file, that's the only file to download
        filenames = [transformers.modeling_utils.WEIGHTS_NAME]
    else:  # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
        map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
        with open(map_filename) as f:
            map_data = json.load(f)
        filenames = set(map_data["weight_map"].values())
    urls = [transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, n, revision=revision, mirror=mirror) for n in filenames]
    if not force_download:
        urls = [u for u in urls if not is_cached(u)]
        if not urls:
            return
    # Without following redirects we get the ETag used to name the cache entry...
    etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
    # ...and with redirects followed we get the headers carrying Content-Length.
    response_headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
    filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
    # Remove leftovers from earlier runs (and, on force_download, the cached copies)
    for n in filenames:
        path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
        if os.path.exists(path):
            os.remove(path)
        path = os.path.join(_cache_dir, "kai-tempfile." + n)
        if os.path.exists(path):
            os.remove(path)
        if force_download:
            path = os.path.join(_cache_dir, n + ".json")
            if os.path.exists(path):
                os.remove(path)
            path = os.path.join(_cache_dir, n)
            if os.path.exists(path):
                os.remove(path)
    total_length = sum(int(h["Content-Length"]) for h in response_headers)
    lengths = {}
    aria2_config = "\n".join(f"{u}\n  out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
    s = requests.Session()
    s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
    bar = None
    done = False
    p = None
    config_path = None  # aria2 input file; tracked separately so cleanup never touches a cache file
    secret = os.urandom(17).hex()
    try:
        with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
            config_path = f.name
            f.write(aria2_config)
            f.flush()
            command = [
                "aria2c", "-x", "10", "-s", "10", "-j", "10",
                "--enable-rpc=true", f"--rpc-secret={secret}",
                "--rpc-listen-port", str(vars.aria2_port),
                "--disable-ipv6", "--file-allocation=trunc",
                "--allow-overwrite", "--auto-file-renaming=false",
                "-d", _cache_dir, "-i", f.name,
                "-U", transformers.file_utils.http_user_agent(user_agent),
            ]
            if not force_download:
                command.append("-c")
            if use_auth_token:
                # No shell is involved, so the value must not be wrapped in
                # quotes: they would become part of the header sent by aria2c.
                command.append(f"--header=Authorization: Bearer {token}")
            p = subprocess.Popen(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            while p.poll() is None:
                r = s.post(f"http://localhost:{vars.aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
                if not r:
                    # No active downloads left: everything has finished
                    s.close()
                    if bar is not None:
                        bar.n = bar.total
                        bar.close()
                    p.terminate()
                    done = True
                    break
                if bar is None:
                    bar = tqdm(total=total_length, desc="[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio())
                visited = set()
                for x in r:
                    filename = x["files"][0]["path"]
                    lengths[filename] = (int(x["completedLength"]), int(x["totalLength"]))
                    visited.add(filename)
                # Files no longer reported as active are counted as complete
                for k, v in lengths.items():
                    if k not in visited:
                        lengths[k] = (v[1], v[1])
                bar.n = sum(v[0] for v in lengths.values())
                bar.update()
                time.sleep(0.1)
    except BaseException:
        # Don't leave aria2c running if anything goes wrong (including Ctrl-C);
        # p is still None when the failure happened before the process started.
        if p is not None:
            p.terminate()
        raise
    finally:
        if config_path is not None:
            try:
                os.remove(config_path)
            except OSError:
                pass
    code = p.wait()
    if not done and code:
        raise OSError(f"aria2 exited with exit code {code}")
    # Move the finished downloads into place and write the cache metadata files
    for u, t, n in zip(urls, etags, filenames):
        os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
        with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
            json.dump({"url": u, "etag": t}, f)

#==================================================================#
#  Given the path to a pytorch_model.bin.index.json, returns how many
#  shards there are in the model
#==================================================================#
def get_num_shards(filename):
    """Return the number of distinct shard files referenced by the
    ``"weight_map"`` of the JSON index file at ``filename``."""
    with open(filename) as index_file:
        index = json.load(index_file)
    shard_files = frozenset(index["weight_map"].values())
    return len(shard_files)

#==================================================================#
#  Given the name/path of a sharded model and the path to a
#  pytorch_model.bin.index.json, returns a list of weight names in the
#  sharded model.  Requires lazy loader to be enabled to work properly.
#==================================================================#
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
    """Return the keys of every shard of a sharded checkpoint as one flat list.

    The shard paths are resolved with
    ``transformers.modeling_utils.get_checkpoint_shard_files``; each shard is
    then opened with ``torch.load(..., map_location="cpu")`` and its keys are
    appended in shard order.
    """
    import transformers.modeling_utils
    import torch
    shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror)
    tensor_names = []
    for shard_path in shard_paths:
        tensor_names.extend(torch.load(shard_path, map_location="cpu").keys())
    return tensor_names

#==================================================================#
#  Given a PreTrainedModel, returns the list of module names that correspond
#  to the model's hidden layers.
#==================================================================#
def get_layers_module_names(model: PreTrainedModel) -> List[str]:
    """Return the dotted names of the modules regarded as hidden layers.

    The module tree is walked depth-first via ``named_children()``.  A child
    counts as a layer when its own name is numeric and its class name ends in
    ``"Block"`` or ``"Layer"``; such a child is recorded and not descended
    into.  All other children are searched recursively.
    """
    layer_class_suffixes = ("Block", "Layer")

    def walk(module, prefix):
        for child_name, child in module.named_children():
            qualified = prefix + child_name
            if child_name.isnumeric() and child.__class__.__name__.endswith(layer_class_suffixes):
                yield qualified
            else:
                yield from walk(child, qualified + ".")

    return list(walk(model, ""))

#==================================================================#
#  Given a PreTrainedModel, returns the module name that corresponds
#  to the model's input embeddings.
#==================================================================#
def get_input_embeddings_module_name(model: PreTrainedModel) -> Optional[str]:
    """Return the dotted module name of the model's input embeddings.

    The module returned by ``model.get_input_embeddings()`` is searched for
    (by identity) in the module tree, depth-first via ``named_children()``.

    Returns:
        The dotted name of the embeddings module, or ``None`` if that module
        is not found among the model's descendants.
    """
    embeddings = model.get_input_embeddings()
    def recurse(module, head=""):
        for child_name, child in module.named_children():
            name = head + child_name
            if child is embeddings:
                return name
            # Only stop when the subtree actually contains the embeddings;
            # otherwise keep looking through the remaining siblings.
            found = recurse(child, head=name + ".")
            if found is not None:
                return found
        return None
    return recurse(model)

#==================================================================#
#  Given a PreTrainedModel and a list of module names, returns a list
#  of module names such that the union of the set of modules given as input
#  and the set of modules returned as output contains all modules in the model.
#==================================================================#
def get_missing_module_names(model: PreTrainedModel, names: List[str]) -> List[str]:
    """Return the names of the modules not covered by ``names``.

    The module tree is walked depth-first via ``named_children()``.  A module
    is covered when its dotted name equals one of ``names`` or lies below one
    of them; covered modules are skipped together with their subtree.  Of the
    remaining modules, those without children are returned, while those with
    children are searched recursively.

    Matching respects the ``.`` separators, so e.g. ``"h.10"`` is not treated
    as covered by ``"h.1"``.
    """
    missing_names: List[str] = []
    # Appending "." to both sides makes a plain prefix test equivalent to
    # "equal to n, or a descendant of n".
    covered_prefixes = tuple(n + "." for n in names)
    def recurse(module, head=""):
        for child_name, child in module.named_children():
            name = head + child_name
            if (name + ".").startswith(covered_prefixes):
                continue
            if next(child.named_children(), None) is None:
                missing_names.append(name)
            else:
                recurse(child, head=name + ".")
    recurse(model)
    return missing_names