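"""
Runtime patches applied to Hugging Face Transformers by KoboldAI: a hub download
helper that streams progress into the KoboldAI UI, a from_pretrained wrapper
that adds aria2 downloads and shard tracking, and a tweak to the bad-words
filter so the <|endoftext|> token can be banned.
"""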
from __future__ import annotations
import copy
import requests
from typing import Iterable, List
from tqdm.auto import tqdm

import transformers
from transformers import (
    PreTrainedModel,
    modeling_utils,
)

import utils


def patch_transformers_download():
    def http_get(
        url: str,
        temp_file,
        proxies=None,
        resume_size=0,
        headers=None,
        file_name=None,
    ):
        """
        Download remote file. Do not gobble up errors.
        """
        headers = copy.deepcopy(headers)
        if resume_size > 0:
            headers["Range"] = f"bytes={resume_size}-"
        r = requests.get(url, stream=True, proxies=proxies, headers=headers)
        transformers.utils.hub._raise_for_status(r)
        content_length = r.headers.get("Content-Length")
        total = (
            resume_size + int(content_length) if content_length is not None else None
        )

        # `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
        # and can be set using `utils.logging.enable/disable_progress_bar()`
        if url[-11:] != "config.json":
            progress = tqdm(
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
                total=total,
                initial=resume_size,
                desc=f"Downloading {file_name}"
                if file_name is not None
                else "Downloading",
                file=utils.UIProgressBarFile(),
            )
            utils.koboldai_vars.status_message = "Download Model"
            utils.koboldai_vars.total_download_chunks = total

        for chunk in r.iter_content(chunk_size=1024):
            if chunk:  # filter out keep-alive new chunks
                if url[-11:] != "config.json":
                    progress.update(len(chunk))
                    utils.koboldai_vars.downloaded_chunks += len(chunk)
                temp_file.write(chunk)

        if url[-11:] != "config.json":
            progress.close()

        utils.koboldai_vars.status_message = ""

    transformers.utils.hub.http_get = http_get
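
# Note: once this patch is installed, any file transformers fetches through its
# hub helper (for example during from_pretrained) has its byte counts mirrored
# into utils.koboldai_vars.downloaded_chunks / total_download_chunks, which is
# what utils.UIProgressBarFile appears to surface in the KoboldAI UI.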


def patch_transformers_loader() -> None:
    """
    Patch the Transformers loader to use aria2 and our shard tracking.
    Universal for TPU/MTJ and Torch.
    """
    old_from_pretrained = PreTrainedModel.from_pretrained.__func__

    @classmethod
    def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        utils.koboldai_vars.fp32_model = False
        utils.num_shards = None
        utils.current_shard = 0
        utils.from_pretrained_model_name = pretrained_model_name_or_path
        utils.from_pretrained_index_filename = None
        utils.from_pretrained_kwargs = kwargs
        utils.bar = None
        if not utils.args.no_aria2:
            utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
        return old_from_pretrained(
            cls, pretrained_model_name_or_path, *model_args, **kwargs
        )
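
    # Install the wrapper only once; the _kai_patched sentinel keeps repeated
    # calls to patch_transformers_loader() from wrapping from_pretrained twice.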
    if not hasattr(PreTrainedModel, "_kai_patched"):
        PreTrainedModel.from_pretrained = new_from_pretrained
        PreTrainedModel._kai_patched = True

    if hasattr(modeling_utils, "get_checkpoint_shard_files"):
        old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files

        def new_get_checkpoint_shard_files(
            pretrained_model_name_or_path, index_filename, *args, **kwargs
        ):
            utils.num_shards = utils.get_num_shards(index_filename)
            utils.from_pretrained_index_filename = index_filename
            return old_get_checkpoint_shard_files(
                pretrained_model_name_or_path, index_filename, *args, **kwargs
            )

        modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files


def patch_transformers_generation() -> None:
    # Not sure why this global is needed...
    global transformers
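    # (A plausible explanation: the `import transformers.generation.logits_process`
    # below binds the name `transformers`, and declaring it global keeps that
    # import from creating a function-local shadow of the module-level name.)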

    # Allow bad words filter to ban <|endoftext|> token
    import transformers.generation.logits_process

    def new_init(self, bad_words_ids: List[List[int]], eos_token_id: int):
        return new_init.old_init(self, bad_words_ids, -1)
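
    # In the transformers versions this targets, NoBadWordsLogitsProcessor drops
    # any banned sequence equal to [eos_token_id]; forcing eos_token_id to -1
    # keeps the <|endoftext|> token ID in the ban list.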
    new_init.old_init = (
        transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__
    )
    transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init


def patch_transformers() -> None:
    patch_transformers_download()
    patch_transformers_loader()

    # Doesn't do anything for TPU
    patch_transformers_generation()
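

# Typical usage (an illustrative sketch, not code from this repository): call
# patch_transformers() once at startup, before the first from_pretrained(), so
# downloads and shard loading are tracked through utils.koboldai_vars. The model
# name below is a placeholder.
#
#     patch_transformers()
#     model = transformers.AutoModelForCausalLM.from_pretrained("<model name>")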