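"""TPU (Mesh Transformer JAX) model backend for KoboldAI.

Wraps the ``tpu_mtj_backend`` module behind the standard HFInferenceModel
interface: it registers the Lua-bridge generation callbacks, loads Hugging
Face or native MTJ checkpoints onto the TPU, and exposes both a static and
a dynamic (world-info scanning) inference path.
"""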
from __future__ import annotations

import os
import torch
import numpy as np
from eventlet import tpool
from typing import List, Optional, Tuple, Union

import utils
import koboldai_settings
from logger import logger, Colors

from modeling import warpers
from modeling.inference_model import (
    GenerationResult,
    GenerationSettings,
    ModelCapabilities,
)
from modeling.inference_models.hf import HFInferenceModel
from modeling.tokenizer import GenericTokenizer

model_backend_name = "Huggingface MTJ"


class model_backend(HFInferenceModel):
    def __init__(
        self,
        #model_name: str,
    ) -> None:
        super().__init__()
        self.hf_torch = False
        self.model_config = None
        self.capabilties = ModelCapabilities(
            embedding_manipulation=False,
            post_token_hooks=False,
            stopper_hooks=False,
            post_token_probs=False,
            uses_tpu=True,
        )

    def is_valid(self, model_name, model_path, menu_path):
        # This file shouldn't be imported unless using the TPU
        return utils.koboldai_vars.use_colab_tpu and super().is_valid(model_name, model_path, menu_path)
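
    # setup_mtj() wires the Lua-bridge-aware callbacks into the tpu_mtj_backend
    # module before the model is loaded, so the TPU generation loop can report
    # logits, generated tokens, compilation status and sampler settings back to
    # the rest of KoboldAI.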
    def setup_mtj(self) -> None:
        import tpu_mtj_backend
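
        # Called by the TPU sampler with the raw logits. They are exposed to
        # the Lua bridge so user "genmod" scripts can modify them, then read
        # back into a numpy array with the original shape and dtype.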
        def mtj_warper_callback(scores) -> "np.array":
            scores_shape = scores.shape
            scores_list = scores.tolist()
            utils.koboldai_vars.lua_koboldbridge.logits = (
                utils.koboldai_vars.lua_state.table()
            )
            for r, row in enumerate(scores_list):
                utils.koboldai_vars.lua_koboldbridge.logits[
                    r + 1
                ] = utils.koboldai_vars.lua_state.table(*row)
            utils.koboldai_vars.lua_koboldbridge.vocab_size = scores_shape[-1]

            utils.koboldai_vars.lua_koboldbridge.execute_genmod()

            scores = np.array(
                tuple(
                    tuple(row.values())
                    for row in utils.koboldai_vars.lua_koboldbridge.logits.values()
                ),
                dtype=scores.dtype,
            )
            assert scores.shape == scores_shape

            return scores
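
        # Called after each generated token. Mirrors the new token of every
        # sequence into the Lua bridge, decides whether generation should halt,
        # and, when dynamic world-info scanning is enabled, requests a
        # regeneration as soon as freshly generated text triggers a WI entry.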
        def mtj_stopping_callback(
            generated, n_generated
        ) -> Tuple[bool, bool]:
            utils.koboldai_vars.generated_tkns += 1

            regeneration_required = (
                utils.koboldai_vars.lua_koboldbridge.regeneration_required
            )
            halt = (
                utils.koboldai_vars.abort
                or not utils.koboldai_vars.lua_koboldbridge.generating
                or utils.koboldai_vars.generated_tkns >= utils.koboldai_vars.genamt
            )
            utils.koboldai_vars.lua_koboldbridge.regeneration_required = False

            # Not sure what the deal is with this variable. It's been undefined
            # as far back as I can trace it.
            global past

            for i in range(utils.koboldai_vars.numseqs):
                utils.koboldai_vars.lua_koboldbridge.generated[i + 1][
                    utils.koboldai_vars.generated_tkns
                ] = int(
                    generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item()
                )

            if not utils.koboldai_vars.dynamicscan or halt:
                return regeneration_required, halt

            for i, t in enumerate(generated):
                decoded = utils.decodenewlines(
                    self.tokenizer.decode(past[i])
                ) + utils.decodenewlines(
                    self.tokenizer.decode(
                        t[
                            tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"]
                            + n_generated
                        ]
                    )
                )
                # _, found = checkworldinfo(decoded, force_use_txt=True, actions=koboldai_vars.actions)
                _, _, _, used_world_info = utils.koboldai_vars.calc_ai_text(
                    submitted_text=decoded
                )
                # found -= excluded_world_info[i]
                if used_world_info:
                    regeneration_required = True
                    break
            return regeneration_required, halt

        def mtj_compiling_callback() -> None:
            print(Colors.GREEN + "TPU backend compilation triggered" + Colors.END)
            utils.koboldai_vars.compiling = True

        def mtj_stopped_compiling_callback() -> None:
            print(Colors.GREEN + "TPU backend compilation stopped" + Colors.END)
            utils.koboldai_vars.compiling = False
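
        # Called whenever the backend needs the current sampler settings; the
        # values are read from koboldai_vars at call time, so they reflect the
        # latest UI state.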
        def mtj_settings_callback() -> dict:
            sampler_order = utils.koboldai_vars.sampler_order[:]
            # Add repetition penalty at the beginning if it's not present
            if len(sampler_order) < 7:
                sampler_order = [6] + sampler_order
            return {
                "sampler_order": sampler_order,
                "top_p": float(utils.koboldai_vars.top_p),
                "temp": float(utils.koboldai_vars.temp),
                "top_k": int(utils.koboldai_vars.top_k),
                "tfs": float(utils.koboldai_vars.tfs),
                "typical": float(utils.koboldai_vars.typical),
                "top_a": float(utils.koboldai_vars.top_a),
                "repetition_penalty": float(utils.koboldai_vars.rep_pen),
                "rpslope": float(utils.koboldai_vars.rep_pen_slope),
                "rprange": int(utils.koboldai_vars.rep_pen_range),
            }

        tpu_mtj_backend.socketio = utils.socketio

        if utils.koboldai_vars.model == "TPUMeshTransformerGPTNeoX":
            utils.koboldai_vars.badwordsids = utils.koboldai_vars.badwordsids_neox

        print(
            "{0}Initializing Mesh Transformer JAX, please wait...{1}".format(
                Colors.PURPLE, Colors.END
            )
        )
        if utils.koboldai_vars.model in (
            "TPUMeshTransformerGPTJ",
            "TPUMeshTransformerGPTNeoX",
        ) and (
            not utils.koboldai_vars.custmodpth
            or not os.path.isdir(utils.koboldai_vars.custmodpth)
        ):
            raise FileNotFoundError(
                f"The specified model path {repr(utils.koboldai_vars.custmodpth)} is not the path to a valid folder"
            )
        if utils.koboldai_vars.model == "TPUMeshTransformerGPTNeoX":
            tpu_mtj_backend.pad_token_id = 2
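
        # Register this instance's callbacks with the TPU backend module.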
        tpu_mtj_backend.koboldai_vars = utils.koboldai_vars
        tpu_mtj_backend.warper_callback = mtj_warper_callback
        tpu_mtj_backend.stopping_callback = mtj_stopping_callback
        tpu_mtj_backend.compiling_callback = mtj_compiling_callback
        tpu_mtj_backend.stopped_compiling_callback = mtj_stopped_compiling_callback
        tpu_mtj_backend.settings_callback = mtj_settings_callback
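
    # _load() prepares the MTJ callbacks, reads the Hugging Face model config,
    # and delegates the actual checkpoint loading to tpu_mtj_backend.load_model().
    # It then records the model dimension, wraps the backend tokenizer, and
    # derives a bad-words list of square-bracket tokens for model types other
    # than gpt2/gpt_neo/gptj.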
    def _load(self, save_model: bool, initial_load: bool) -> None:
        import tpu_mtj_backend

        self.setup_mtj()
        self.init_model_config()
        utils.koboldai_vars.allowsp = True
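
        # hf_checkpoint is True when loading a regular Hugging Face checkpoint
        # (rather than a native "TPUMeshTransformer*" checkpoint) onto the TPU.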
        logger.info(self.model_type)
        tpu_mtj_backend.load_model(
            utils.koboldai_vars.model,
            self.model_type,
            hf_checkpoint=utils.koboldai_vars.model
            not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
            and utils.koboldai_vars.use_colab_tpu,
            socketio_queue=koboldai_settings.queue,
            initial_load=initial_load,
            logger=logger,
            **self.model_config.to_dict(),
        )

        utils.koboldai_vars.modeldim = int(
            tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"])
        )
        self.tokenizer = GenericTokenizer(tpu_mtj_backend.tokenizer)

        if (
            utils.koboldai_vars.badwordsids is koboldai_settings.badwordsids_default
            and self.model_type not in ("gpt2", "gpt_neo", "gptj")
        ):
            utils.koboldai_vars.badwordsids = [
                [v]
                for k, v in self.tokenizer.get_vocab().items()
                if any(c in str(k) for c in "[]")
            ]
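
    # Soft prompts are addressed with pseudo token ids placed just past the
    # (padded) vocabulary. If no soft prompt is loaded, a zero-filled embedding
    # tensor is built, padded so its rows divide evenly across the TPU cores,
    # and sharded with shard_xmap() so those ids can always be resolved.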
    def get_soft_tokens(self) -> np.array:
        import tpu_mtj_backend

        soft_tokens = None

        if utils.koboldai_vars.sp is None:
            tensor = np.zeros(
                (
                    1,
                    tpu_mtj_backend.params.get(
                        "d_embed", tpu_mtj_backend.params["d_model"]
                    ),
                ),
                dtype=np.float32,
            )
            rows = tensor.shape[0]
            padding_amount = (
                tpu_mtj_backend.params["seq"]
                - (
                    tpu_mtj_backend.params["seq"]
                    % -tpu_mtj_backend.params["cores_per_replica"]
                )
                - rows
            )
            tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
            tensor = tensor.reshape(
                tpu_mtj_backend.params["cores_per_replica"],
                -1,
                tpu_mtj_backend.params.get(
                    "d_embed", tpu_mtj_backend.params["d_model"]
                ),
            )
            utils.koboldai_vars.sp = tpu_mtj_backend.shard_xmap(tensor)

        soft_tokens = np.arange(
            tpu_mtj_backend.params["n_vocab"]
            + tpu_mtj_backend.params["n_vocab_padding"],
            tpu_mtj_backend.params["n_vocab"]
            + tpu_mtj_backend.params["n_vocab_padding"]
            + utils.koboldai_vars.sp_length,
            dtype=np.uint32,
        )
        return soft_tokens

    def _raw_generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
        max_new: int,
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
        seed: Optional[int] = None,
        **kwargs,
    ) -> GenerationResult:
        import tpu_mtj_backend

        warpers.update_settings()

        soft_tokens = self.get_soft_tokens()

        dynamic_inference = kwargs.get("tpu_dynamic_inference", False)

        if seed is not None:
            tpu_mtj_backend.set_rng_seed(seed)
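
        # Two inference paths: infer_static generates the whole continuation in
        # a single call with all samplers applied, while infer_dynamic generates
        # incrementally so the stopping callback can re-scan the output for
        # world-info entries and restart with an updated context. Both run via
        # tpool.execute so the blocking TPU call doesn't stall the eventlet loop.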
        if not dynamic_inference:
            genout = tpool.execute(
                tpu_mtj_backend.infer_static,
                np.uint32(prompt_tokens),
                gen_len=max_new,
                temp=gen_settings.temp,
                top_p=gen_settings.top_p,
                top_k=gen_settings.top_k,
                tfs=gen_settings.tfs,
                typical=gen_settings.typical,
                top_a=gen_settings.top_a,
                numseqs=batch_count,
                repetition_penalty=gen_settings.rep_pen,
                rpslope=gen_settings.rep_pen_slope,
                rprange=gen_settings.rep_pen_range,
                soft_embeddings=utils.koboldai_vars.sp,
                soft_tokens=soft_tokens,
                sampler_order=gen_settings.sampler_order,
            )
            genout = np.array(genout)
        else:
            global past

            context = np.tile(
                np.uint32(prompt_tokens), (utils.koboldai_vars.numseqs, 1)
            )
            past = np.empty((utils.koboldai_vars.numseqs, 0), dtype=np.uint32)
            self.gen_state["wi_scanner_excluded_keys"] = set()
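
            # Dynamic generation loop: run the model, copy the tokens reported
            # through the Lua bridge into `past`, and keep looping with a
            # rebuilt context for as long as the stopping callback requests a
            # regeneration.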
            while True:
                genout, n_generated, regeneration_required, halt = tpool.execute(
                    tpu_mtj_backend.infer_dynamic,
                    context,
                    gen_len=max_new,
                    numseqs=utils.koboldai_vars.numseqs,
                    soft_embeddings=utils.koboldai_vars.sp,
                    soft_tokens=soft_tokens,
                )

                past = np.pad(past, ((0, 0), (0, n_generated)))
                for r in range(utils.koboldai_vars.numseqs):
                    for c in range(utils.koboldai_vars.lua_koboldbridge.generated_cols):
                        assert (
                            utils.koboldai_vars.lua_koboldbridge.generated[r + 1][c + 1]
                            is not None
                        )
                        past[r, c] = utils.koboldai_vars.lua_koboldbridge.generated[
                            r + 1
                        ][c + 1]

                if utils.koboldai_vars.abort or halt or not regeneration_required:
                    break
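
                # A regeneration was requested: decode the text generated so
                # far, re-encode it, and build a fresh left-padded context for
                # the next infer_dynamic call. (The world-info based context
                # rebuild below is still commented out / TODO.)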
                encoded = []
                for i in range(utils.koboldai_vars.numseqs):
                    txt = utils.decodenewlines(self.tokenizer.decode(past[i]))
                    # _, _, _, _found_entries = utils.koboldai_vars.calc_ai_text(
                    #     self.tokenizer.decode(prompt_tokens)
                    # )
                    # # utils.koboldai_vars.calc_ai_text()
                    # print(_found_entries)
                    # self.gen_state["wi_scanner_excluded_keys"].update(_found_entries)
                    encoded.append(
                        np.array(self.tokenizer.encode(txt), dtype=np.uint32)
                    )

                max_length = len(max(encoded, key=len))
                encoded = np.stack(
                    tuple(
                        np.pad(
                            e,
                            (max_length - len(e), 0),
                            constant_values=tpu_mtj_backend.pad_token_id,
                        )
                        for e in encoded
                    )
                )
                context = np.concatenate(
                    (
                        encoded,
                        past,
                    ),
                    axis=-1,
                )
            # genout = tpool.execute(
            #     tpu_mtj_backend.infer_dynamic,
            #     context=np.uint32(prompt_tokens),
            #     numseqs=batch_count,
            #     gen_len=max_new,
            #     soft_embeddings=utils.koboldai_vars.sp,
            #     soft_tokens=soft_tokens,
            #     # TODO: Fix Dynamic WI on TPU
            #     excluded_world_info=set(),
            #     use_callback=True
            # )
            # print(genout)
            # print(type(genout))
            genout = np.array(genout)

        return GenerationResult(
            self,
            out_batches=genout,
            prompt=prompt_tokens,
            is_whole_generation=True,
            single_line=single_line,
        )