This commit is contained in:
ebolam
2023-05-24 20:14:22 -04:00
parent c9523a340e
commit 1a7c2ddab0

View File

@@ -460,14 +460,14 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
return carry
class PenalizingCausalTransformer(CausalTransformer):
def __init__(self, config, **kwargs):
def __init__(self, badwordsids, config, **kwargs):
# Initialize
super().__init__(config, **kwargs)
def generate_static(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
compiling_callback()
numseqs = numseqs_aux.shape[0]
# These are the tokens that we don't want the AI to ever write
badwords = jnp.array(koboldai_vars.badwordsids).squeeze()
badwords = jnp.array(badwordsids).squeeze()
@hk.transform
def generate_sample(context, ctx_length):
# Give the initial context to the transformer
@@ -941,7 +941,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
koboldai_vars.status_message = ""
def load_model(path: str, model_type: str, driver_version="tpu_driver_20221109", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
import koboldai_settings
def load_model(path: str, model_type: str, badwordsids=koboldai_settings.badwordsids_default driver_version="tpu_driver_20221109", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
global thread_resources_env, seq, tokenizer, network, params, pad_token_id
if kwargs.get("pad_token_id"):
@@ -1119,12 +1121,12 @@ def load_model(path: str, model_type: str, driver_version="tpu_driver_20221109",
global badwords
# These are the tokens that we don't want the AI to ever write
badwords = jnp.array(koboldai_vars.badwordsids).squeeze()
badwords = jnp.array(badwordsids).squeeze()
if not path.endswith("/"):
path += "/"
network = PenalizingCausalTransformer(params, dematerialized=True)
network = PenalizingCausalTransformer(badwordsids, params, dematerialized=True)
if not hf_checkpoint and koboldai_vars.model != "TPUMeshTransformerGPTNeoX":
network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])