From 38c4edac400a6bac53fc93355fe8d15f5c3a7e9d Mon Sep 17 00:00:00 2001
From: somebody
Date: Fri, 10 Mar 2023 18:36:22 -0600
Subject: [PATCH] Model: Fix eos/bos padding issue

Weird config None assignments
---
 aiserver.py        | 21 +--------------------
 tpu_mtj_backend.py |  4 ++--
 2 files changed, 3 insertions(+), 22 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index 799c6089..0577fe99 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -5493,26 +5493,7 @@ def final_startup():
 
     # Precompile TPU backend if required
     if isinstance(model, HFMTJInferenceModel):
-        import tpu_mtj_backend
-        soft_tokens = model.get_soft_tokens()
-        if(koboldai_vars.dynamicscan or (not koboldai_vars.nogenmod and koboldai_vars.has_genmod)):
-            tpool.execute(tpu_mtj_backend.infer_dynamic, np.tile(np.uint32((23403, 727, 20185)), (koboldai_vars.numseqs, 1)),
-                soft_embeddings= koboldai_vars.sp,
-                soft_tokens= soft_tokens,
-                gen_len= 1,
-                use_callback= False,
-                numseqs= koboldai_vars.numseqs,
-                excluded_world_info= list(set() for _ in range(koboldai_vars.numseqs))
-            )
-        else:
-            tpool.execute(
-                tpu_mtj_backend.infer_static,
-                np.uint32((23403, 727, 20185)),
-                soft_embeddings= koboldai_vars.sp,
-                soft_tokens= soft_tokens,
-                gen_len= 1,
-                numseqs= koboldai_vars.numseqs
-            )
+        model.raw_generate([23403, 727, 20185], max_new=1)
 
     # Set the initial RNG seed
     set_seed()
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index a20cb213..55f382f7 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -779,9 +779,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
 def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params, pad_token_id
 
-    if "pad_token_id" in kwargs:
+    if kwargs.get("pad_token_id"):
         pad_token_id = kwargs["pad_token_id"]
-    elif "eos_token_id" in kwargs:
+    elif kwargs.get("eos_token_id"):
         pad_token_id = kwargs["eos_token_id"]
 
     if not hasattr(koboldai_vars, "sampler_order") or not koboldai_vars.sampler_order:
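
Note on the tpu_mtj_backend.py hunk: switching from membership tests to kwargs.get() matters when a model config explicitly assigns None to pad_token_id (the "weird config None assignments" in the commit message). With the old check, "pad_token_id" in kwargs is still True for a None value, so pad_token_id is set to None and the eos_token_id fallback never runs. A minimal standalone sketch of the difference, using hypothetical kwargs values (not part of the patch):

    # Hypothetical config values: pad_token_id explicitly set to None.
    kwargs = {"pad_token_id": None, "eos_token_id": 50256}

    # Old behaviour: the membership test passes even for a None value,
    # so the padding id ends up as None.
    if "pad_token_id" in kwargs:
        old_pad = kwargs["pad_token_id"]
    elif "eos_token_id" in kwargs:
        old_pad = kwargs["eos_token_id"]
    print(old_pad)  # None

    # New behaviour: .get() returns None here, which is falsy, so the
    # check falls through to eos_token_id instead.
    if kwargs.get("pad_token_id"):
        new_pad = kwargs["pad_token_id"]
    elif kwargs.get("eos_token_id"):
        new_pad = kwargs["eos_token_id"]
    print(new_pad)  # 50256

One trade-off of the truthiness check: a legitimate token id of 0 is also falsy and would be skipped the same way.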