diff --git a/aiserver.py b/aiserver.py index 60e1a4a5..8da79094 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1035,7 +1035,7 @@ else: assert len(excluded_world_info) == len(generated) regeneration_required = vars.lua_koboldbridge.regeneration_required - halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt + halt = vars.abort or not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt vars.lua_koboldbridge.regeneration_required = False global past @@ -1061,6 +1061,15 @@ else: def tpumtjgenerate_stopped_compiling_callback() -> None: vars.compiling = False + + def tpumtjgenerate_settings_callback() -> dict: + return { + "top_p": float(vars.top_p), + "temp": float(vars.temp), + "top_k": int(vars.top_k), + "tfs": float(vars.tfs), + "repetition_penalty": float(vars.rep_pen), + } # If we're running Colab or OAI, we still need a tokenizer. if(vars.model == "Colab"): @@ -3009,12 +3018,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): tpu_mtj_backend.infer_dynamic, context, gen_len = maximum-minimum+1, - temp=vars.temp, - top_p=vars.top_p, - top_k=vars.top_k, - tfs=vars.tfs, numseqs=vars.numseqs, - repetition_penalty=vars.rep_pen, soft_embeddings=vars.sp, soft_tokens=soft_tokens, excluded_world_info=found_entries, @@ -3026,7 +3030,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): assert vars.lua_koboldbridge.generated[r+1][c+1] is not None past[r, c] = vars.lua_koboldbridge.generated[r+1][c+1] - if(halt or not regeneration_required): + if(vars.abort or halt or not regeneration_required): break print("(regeneration triggered)") diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 70f27b32..3b3f48e7 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -26,6 +26,15 @@ def warper_callback(logits) -> np.array: def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]: raise NotImplementedError("`tpu_mtj_backend.stopping_callback()` needs to be defined") +def settings_callback() -> dict: + return { + "top_p": 0.9, + "temp": 0.5, + "top_k": 0, + "tfs": 1.0, + "repetition_penalty": 1.0, + } + def started_compiling_callback() -> None: pass @@ -541,7 +550,7 @@ class PenalizingCausalTransformer(CausalTransformer): out_axes=["shard", "batch", ...], axis_resources={'shard': 'mp', 'batch': 'dp'}, ) - def generate_dynamic(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None, excluded_world_info=None, use_callback=True): + def generate_dynamic(self, ctx, ctx_length, gen_length, numseqs, return_logits=False, soft_embeddings=None, excluded_world_info=None, use_callback=True): assert excluded_world_info is not None assert not return_logits assert gen_length.ndim == 1 @@ -560,7 +569,6 @@ class PenalizingCausalTransformer(CausalTransformer): ] for i in range(numseqs) ] - repetition_penalty = sampler_options.pop("repetition_penalty", 1.0) n_generated = 0 regeneration_required = False halt = False @@ -576,6 +584,8 @@ class PenalizingCausalTransformer(CausalTransformer): logits = warper_callback(logits) for i in range(numseqs): sample_data[i][2] = logits[i] + sampler_options = settings_callback() + repetition_penalty = sampler_options.pop("repetition_penalty", 1.0) sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options) n_generated += 1 for i in range(numseqs): @@ -611,11 +621,6 @@ class PenalizingCausalTransformer(CausalTransformer): def infer_dynamic( context: np.array, - top_p=0.9, - temp=0.5, - top_k=0, - tfs=1.0, - repetition_penalty=1.0, numseqs=1, gen_len=80, soft_embeddings: Optional[np.array] = None, @@ -634,19 +639,11 @@ def infer_dynamic( padded_tokens = np.pad(tokens, ((0, 0), (pad_amount, 0)), constant_values=pad_token_id) batched_tokens = np.array([padded_tokens] * total_batch) samples = [] - generator_params = { - "temp": float(temp), - "top_p": float(top_p), - "tfs": float(tfs), - "repetition_penalty": float(repetition_penalty), - "top_k": int(top_k), - } output = network.generate_dynamic( batched_tokens, np.ones(total_batch, dtype=np.uint32) * provided_ctx, np.ones(total_batch, dtype=np.uint32) * gen_len, numseqs, - generator_params, soft_embeddings=soft_embeddings, excluded_world_info=excluded_world_info, use_callback=use_callback,