# KoboldAI-Client/modeling/inference_models/hf_mtj/class.py

from __future__ import annotations
import os
import torch
import numpy as np
from eventlet import tpool
from typing import List, Optional, Tuple, Union
import utils
import koboldai_settings
from logger import logger, Colors
from modeling import warpers
from modeling.inference_model import (
    GenerationResult,
    GenerationSettings,
    ModelCapabilities,
)
from modeling.inference_models.hf import HFInferenceModel
from modeling.tokenizer import GenericTokenizer

model_backend_name = "Huggingface MTJ"


class model_backend(HFInferenceModel):
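    """Inference backend that runs generation on TPU through Mesh Transformer JAX
    (the ``tpu_mtj_backend`` module) rather than the Hugging Face torch path."""
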
    def __init__(
        self,
        # model_name: str,
    ) -> None:
        super().__init__()
        self.hf_torch = False
        self.model_config = None
        self.capabilties = ModelCapabilities(
            embedding_manipulation=False,
            post_token_hooks=False,
            stopper_hooks=False,
            post_token_probs=False,
            uses_tpu=True,
        )

    def is_valid(self, model_name, model_path, menu_path):
        # This file shouldn't be imported unless using the TPU
        return utils.koboldai_vars.use_colab_tpu and super().is_valid(
            model_name, model_path, menu_path
        )

    def setup_mtj(self) -> None:
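        """Register this backend's callbacks with the ``tpu_mtj_backend`` module
        and validate the selected model path before it is loaded."""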
        import tpu_mtj_backend

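        # Hands the raw logits to the Lua bridge so userscript generation
        # modifiers can edit them, then rebuilds the score array from the result.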
        def mtj_warper_callback(scores) -> "np.array":
            scores_shape = scores.shape
            scores_list = scores.tolist()
            utils.koboldai_vars.lua_koboldbridge.logits = (
                utils.koboldai_vars.lua_state.table()
            )
            for r, row in enumerate(scores_list):
                utils.koboldai_vars.lua_koboldbridge.logits[
                    r + 1
                ] = utils.koboldai_vars.lua_state.table(*row)
            utils.koboldai_vars.lua_koboldbridge.vocab_size = scores_shape[-1]

            utils.koboldai_vars.lua_koboldbridge.execute_genmod()

            scores = np.array(
                tuple(
                    tuple(row.values())
                    for row in utils.koboldai_vars.lua_koboldbridge.logits.values()
                ),
                dtype=scores.dtype,
            )
            assert scores.shape == scores_shape

            return scores

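        # Runs after each generated token: reports progress to the Lua bridge and
        # decides whether generation should halt or be regenerated (e.g. when
        # dynamic world-info scanning finds a newly triggered entry).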
        def mtj_stopping_callback(
            generated, n_generated
        ) -> Tuple[bool, bool]:
            utils.koboldai_vars.generated_tkns += 1

            regeneration_required = (
                utils.koboldai_vars.lua_koboldbridge.regeneration_required
            )
            halt = (
                utils.koboldai_vars.abort
                or not utils.koboldai_vars.lua_koboldbridge.generating
                or utils.koboldai_vars.generated_tkns >= utils.koboldai_vars.genamt
            )
            utils.koboldai_vars.lua_koboldbridge.regeneration_required = False

            # Not sure what the deal is with this variable. It's been undefined
            # as far back as I can trace it.
            global past

            for i in range(utils.koboldai_vars.numseqs):
                utils.koboldai_vars.lua_koboldbridge.generated[i + 1][
                    utils.koboldai_vars.generated_tkns
                ] = int(
                    generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item()
                )

            if not utils.koboldai_vars.dynamicscan or halt:
                return regeneration_required, halt

            for i, t in enumerate(generated):
                decoded = utils.decodenewlines(
                    self.tokenizer.decode(past[i])
                ) + utils.decodenewlines(
                    self.tokenizer.decode(
                        t[
                            tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params[
                                "seq"
                            ]
                            + n_generated
                        ]
                    )
                )
                # _, found = checkworldinfo(decoded, force_use_txt=True, actions=koboldai_vars.actions)
                _, _, _, used_world_info = utils.koboldai_vars.calc_ai_text(
                    submitted_text=decoded
                )
                # found -= excluded_world_info[i]
                if used_world_info:
                    regeneration_required = True
                    break
            return regeneration_required, halt

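        # Toggle koboldai_vars.compiling while the TPU backend is compiling.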
        def mtj_compiling_callback() -> None:
            print(Colors.GREEN + "TPU backend compilation triggered" + Colors.END)
            utils.koboldai_vars.compiling = True

        def mtj_stopped_compiling_callback() -> None:
            print(Colors.GREEN + "TPU backend compilation stopped" + Colors.END)
            utils.koboldai_vars.compiling = False

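        # Gives tpu_mtj_backend the current sampler settings whenever it asks.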
        def mtj_settings_callback() -> dict:
            sampler_order = utils.koboldai_vars.sampler_order[:]
            if (
                len(sampler_order) < 7
            ):  # Add repetition penalty at beginning if it's not present
                sampler_order = [6] + sampler_order
            return {
                "sampler_order": sampler_order,
                "top_p": float(utils.koboldai_vars.top_p),
                "temp": float(utils.koboldai_vars.temp),
                "top_k": int(utils.koboldai_vars.top_k),
                "tfs": float(utils.koboldai_vars.tfs),
                "typical": float(utils.koboldai_vars.typical),
                "top_a": float(utils.koboldai_vars.top_a),
                "repetition_penalty": float(utils.koboldai_vars.rep_pen),
                "rpslope": float(utils.koboldai_vars.rep_pen_slope),
                "rprange": int(utils.koboldai_vars.rep_pen_range),
            }

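        # Hand shared state and the callbacks defined above over to tpu_mtj_backend.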
        tpu_mtj_backend.socketio = utils.socketio

        if utils.koboldai_vars.model == "TPUMeshTransformerGPTNeoX":
            utils.koboldai_vars.badwordsids = utils.koboldai_vars.badwordsids_neox

        print(
            "{0}Initializing Mesh Transformer JAX, please wait...{1}".format(
                Colors.PURPLE, Colors.END
            )
        )

        if utils.koboldai_vars.model in (
            "TPUMeshTransformerGPTJ",
            "TPUMeshTransformerGPTNeoX",
        ) and (
            not utils.koboldai_vars.custmodpth
            or not os.path.isdir(utils.koboldai_vars.custmodpth)
        ):
            raise FileNotFoundError(
                f"The specified model path {repr(utils.koboldai_vars.custmodpth)} is not the path to a valid folder"
            )
        if utils.koboldai_vars.model == "TPUMeshTransformerGPTNeoX":
            tpu_mtj_backend.pad_token_id = 2

        tpu_mtj_backend.koboldai_vars = utils.koboldai_vars
        tpu_mtj_backend.warper_callback = mtj_warper_callback
        tpu_mtj_backend.stopping_callback = mtj_stopping_callback
        tpu_mtj_backend.compiling_callback = mtj_compiling_callback
        tpu_mtj_backend.stopped_compiling_callback = mtj_stopped_compiling_callback
        tpu_mtj_backend.settings_callback = mtj_settings_callback

    def _load(self, save_model: bool, initial_load: bool) -> None:
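        """Load the model weights through tpu_mtj_backend and set up the tokenizer."""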
        import tpu_mtj_backend

        self.setup_mtj()
        self.init_model_config()
        utils.koboldai_vars.allowsp = True

        logger.info(self.model_type)

        tpu_mtj_backend.load_model(
            self.model,
            self.model_type,
            hf_checkpoint=self.model
            not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
            and utils.koboldai_vars.use_colab_tpu,
            socketio_queue=koboldai_settings.queue,
            initial_load=initial_load,
            logger=logger,
            **self.model_config.to_dict(),
        )

        utils.koboldai_vars.modeldim = int(
            tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"])
        )

        self.tokenizer = GenericTokenizer(tpu_mtj_backend.tokenizer)

        # If the default badwords list is still in place and this model type does
        # not share the GPT-2 vocabulary, ban every vocab token containing a
        # square bracket instead.
        if (
            utils.koboldai_vars.badwordsids is koboldai_settings.badwordsids_default
            and self.model_type not in ("gpt2", "gpt_neo", "gptj")
        ):
            utils.koboldai_vars.badwordsids = [
                [v]
                for k, v in self.tokenizer.get_vocab().items()
                if any(c in str(k) for c in "[]")
            ]

    def get_soft_tokens(self) -> np.array:
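        """Return the placeholder token IDs that index into the soft prompt
        embeddings, creating an empty sharded soft prompt if none is loaded."""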
        import tpu_mtj_backend

        soft_tokens = None
        if utils.koboldai_vars.sp is None:
            tensor = np.zeros(
                (
                    1,
                    tpu_mtj_backend.params.get(
                        "d_embed", tpu_mtj_backend.params["d_model"]
                    ),
                ),
                dtype=np.float32,
            )
            rows = tensor.shape[0]
            # Pad the dummy soft prompt so its row count is the sequence length
            # rounded up to a multiple of cores_per_replica, then reshape it so
            # it can be sharded across the TPU cores.
            padding_amount = (
                tpu_mtj_backend.params["seq"]
                - (
                    tpu_mtj_backend.params["seq"]
                    % -tpu_mtj_backend.params["cores_per_replica"]
                )
                - rows
            )
            tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
            tensor = tensor.reshape(
                tpu_mtj_backend.params["cores_per_replica"],
                -1,
                tpu_mtj_backend.params.get(
                    "d_embed", tpu_mtj_backend.params["d_model"]
                ),
            )
            utils.koboldai_vars.sp = tpu_mtj_backend.shard_xmap(tensor)

        soft_tokens = np.arange(
            tpu_mtj_backend.params["n_vocab"]
            + tpu_mtj_backend.params["n_vocab_padding"],
            tpu_mtj_backend.params["n_vocab"]
            + tpu_mtj_backend.params["n_vocab_padding"]
            + utils.koboldai_vars.sp_length,
            dtype=np.uint32,
        )
        return soft_tokens

    def _raw_generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
        max_new: int,
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
        seed: Optional[int] = None,
        **kwargs,
    ) -> GenerationResult:
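        """Generate tokens on the TPU, using ``infer_static`` by default or the
        ``infer_dynamic`` loop when ``tpu_dynamic_inference`` is passed."""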
        import tpu_mtj_backend

        warpers.update_settings()

        soft_tokens = self.get_soft_tokens()
        dynamic_inference = kwargs.get("tpu_dynamic_inference", False)

        if seed is not None:
            tpu_mtj_backend.set_rng_seed(seed)

        if not dynamic_inference:
            genout = tpool.execute(
                tpu_mtj_backend.infer_static,
                np.uint32(prompt_tokens),
                gen_len=max_new,
                temp=gen_settings.temp,
                top_p=gen_settings.top_p,
                top_k=gen_settings.top_k,
                tfs=gen_settings.tfs,
                typical=gen_settings.typical,
                top_a=gen_settings.top_a,
                numseqs=batch_count,
                repetition_penalty=gen_settings.rep_pen,
                rpslope=gen_settings.rep_pen_slope,
                rprange=gen_settings.rep_pen_range,
                soft_embeddings=utils.koboldai_vars.sp,
                soft_tokens=soft_tokens,
                sampler_order=gen_settings.sampler_order,
            )
            genout = np.array(genout)
        else:
            global past

            context = np.tile(
                np.uint32(prompt_tokens), (utils.koboldai_vars.numseqs, 1)
            )
            past = np.empty((utils.koboldai_vars.numseqs, 0), dtype=np.uint32)
            self.gen_state["wi_scanner_excluded_keys"] = set()

            while True:
                genout, n_generated, regeneration_required, halt = tpool.execute(
                    tpu_mtj_backend.infer_dynamic,
                    context,
                    gen_len=max_new,
                    numseqs=utils.koboldai_vars.numseqs,
                    soft_embeddings=utils.koboldai_vars.sp,
                    soft_tokens=soft_tokens,
                )

                past = np.pad(past, ((0, 0), (0, n_generated)))
                for r in range(utils.koboldai_vars.numseqs):
                    for c in range(utils.koboldai_vars.lua_koboldbridge.generated_cols):
                        assert (
                            utils.koboldai_vars.lua_koboldbridge.generated[r + 1][c + 1]
                            is not None
                        )
                        past[r, c] = utils.koboldai_vars.lua_koboldbridge.generated[
                            r + 1
                        ][c + 1]

                if utils.koboldai_vars.abort or halt or not regeneration_required:
                    break

                # Rebuild the context from the sequences generated so far, then
                # loop back around for another round of dynamic inference.
                encoded = []
                for i in range(utils.koboldai_vars.numseqs):
                    txt = utils.decodenewlines(self.tokenizer.decode(past[i]))
                    # _, _, _, _found_entries = utils.koboldai_vars.calc_ai_text(
                    #     self.tokenizer.decode(prompt_tokens)
                    # )
                    # # utils.koboldai_vars.calc_ai_text()
                    # print(_found_entries)
                    # self.gen_state["wi_scanner_excluded_keys"].update(_found_entries)
                    # Re-encode the decoded text into token IDs; `txt` is a string
                    # here, so it cannot be turned into a uint32 array directly.
                    encoded.append(
                        np.array(self.tokenizer.encode(txt), dtype=np.uint32)
                    )
                max_length = len(max(encoded, key=len))
                encoded = np.stack(
                    tuple(
                        np.pad(
                            e,
                            (max_length - len(e), 0),
                            constant_values=tpu_mtj_backend.pad_token_id,
                        )
                        for e in encoded
                    )
                )
                context = np.concatenate(
                    (
                        encoded,
                        past,
                    ),
                    axis=-1,
                )

            # genout = tpool.execute(
            #     tpu_mtj_backend.infer_dynamic,
            #     context=np.uint32(prompt_tokens),
            #     numseqs=batch_count,
            #     gen_len=max_new,
            #     soft_embeddings=utils.koboldai_vars.sp,
            #     soft_tokens=soft_tokens,
            #     # TODO: Fix Dynamic WI on TPU
            #     excluded_world_info=set(),
            #     use_callback=True
            # )
            # print(genout)
            # print(type(genout))
            genout = np.array(genout)

        return GenerationResult(
            self,
            out_batches=genout,
            prompt=prompt_tokens,
            is_whole_generation=True,
            single_line=single_line,
        )