From 1a7c2ddab0b582758456af292c439f177460df53 Mon Sep 17 00:00:00 2001
From: ebolam
Date: Wed, 24 May 2023 20:14:22 -0400
Subject: [PATCH] TPU Fix?

---
 tpu_mtj_backend.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index d5a4d1db..bf08f745 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -460,14 +460,14 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
     return carry
 
 class PenalizingCausalTransformer(CausalTransformer):
-    def __init__(self, config, **kwargs):
+    def __init__(self, badwordsids, config, **kwargs):
         # Initialize
         super().__init__(config, **kwargs)
         def generate_static(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
             compiling_callback()
             numseqs = numseqs_aux.shape[0]
             # These are the tokens that we don't want the AI to ever write
-            badwords = jnp.array(koboldai_vars.badwordsids).squeeze()
+            badwords = jnp.array(badwordsids).squeeze()
             @hk.transform
             def generate_sample(context, ctx_length):
                 # Give the initial context to the transformer
@@ -941,7 +941,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
     koboldai_vars.status_message = ""
 
 
-def load_model(path: str, model_type: str, driver_version="tpu_driver_20221109", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
+import koboldai_settings
+
+def load_model(path: str, model_type: str, badwordsids=koboldai_settings.badwordsids_default, driver_version="tpu_driver_20221109", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params, pad_token_id
 
     if kwargs.get("pad_token_id"):
@@ -1119,12 +1121,12 @@ def load_model(path: str, model_type: str, driver_version="tpu_driver_20221109",
 
     global badwords
     # These are the tokens that we don't want the AI to ever write
-    badwords = jnp.array(koboldai_vars.badwordsids).squeeze()
+    badwords = jnp.array(badwordsids).squeeze()
 
     if not path.endswith("/"):
         path += "/"
 
-    network = PenalizingCausalTransformer(params, dematerialized=True)
+    network = PenalizingCausalTransformer(badwordsids, params, dematerialized=True)
     if not hf_checkpoint and koboldai_vars.model != "TPUMeshTransformerGPTNeoX":
         network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
 
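
For reference, a minimal usage sketch of the load_model signature introduced above. The model paths, model types, and token-ID list are hypothetical placeholders; only load_model, its new badwordsids parameter, and koboldai_settings.badwordsids_default come from the patch itself.

# Usage sketch (illustrative, not part of the patch): after this change the
# bad-word token IDs flow into the backend as an argument rather than being
# read from the koboldai_vars global inside tpu_mtj_backend.
import tpu_mtj_backend

# Default behaviour: falls back to koboldai_settings.badwordsids_default.
tpu_mtj_backend.load_model("models/gpt-j-6b/", "gptj")

# Model-specific behaviour: pass the token IDs the model should never emit.
# The path, model type, and ID list here are hypothetical placeholders.
tpu_mtj_backend.load_model(
    "models/gpt-neox-20b/",
    "gpt_neox",
    badwordsids=[[6880], [50256]],  # hypothetical list of single-token IDs
)

Passing the IDs explicitly also means PenalizingCausalTransformer no longer reaches into koboldai_vars when it builds its badwords array at construction time.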