From d9d24902aef2480aa6c4dd2d58d0d5a77b48db73 Mon Sep 17 00:00:00 2001
From: somebody
Date: Thu, 22 Sep 2022 20:10:23 -0500
Subject: [PATCH] Work on mtj

---
 aiserver.py          | 24 +++++++++++++-----------
 koboldai_settings.py | 11 ++++++++++-
 2 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index 3e6220ab..95d92bec 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -4841,23 +4841,25 @@ def calcsubmit(txt):
         # Send it!
         ikrequest(subtxt)
 
-def __debug(text):
-    print(f"[DBG] {text}")
+def __debug(*args):
+    print("[DBG] ", *args)
 
 def core_generate(text: list, min: int, max: int, found_entries: set):
     # This generation function is tangled with koboldai_vars intentionally. It
     # is meant for the story and nothing else.
 
-    if koboldai_vars.full_determinism:
-        torch.manual_seed(koboldai_vars.seed)
+    if koboldai_vars.is_model_torch():
+        # Torch stuff
+        if koboldai_vars.full_determinism:
+            torch.manual_seed(koboldai_vars.seed)
 
-    gen_in = torch.tensor(text, dtype=torch.long)[None]
-    if koboldai_vars.sp is not None:
-        soft_tokens = torch.arange(
-            model.config.vocab_size,
-            model.config.vocab_size + koboldai_vars.sp.shape[0],
-        )
-        gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
+        gen_in = torch.tensor(text, dtype=torch.long)[None]
+        if koboldai_vars.sp is not None:
+            soft_tokens = torch.arange(
+                model.config.vocab_size,
+                model.config.vocab_size + koboldai_vars.sp.shape[0],
+            )
+            gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
 
     assert gen_in.shape[-1] + koboldai_vars.genamt <= koboldai_vars.max_length
 
diff --git a/koboldai_settings.py b/koboldai_settings.py
index 9aa539e3..2279b532 100644
--- a/koboldai_settings.py
+++ b/koboldai_settings.py
@@ -287,6 +287,15 @@ class koboldai_vars(object):
         self.context = context
 
         return tokens, used_tokens, used_tokens+self.genamt
+
+    def is_model_torch(self) -> bool:
+        if self.use_colab_tpu:
+            return False
+
+        if self.model in ["Colab", "API", "CLUSTER", "ReadOnly", "OAI"]:
+            return False
+
+        return True
 
     def __setattr__(self, name, value):
         if name[0] == "_" or name == "tokenizer":
@@ -443,7 +452,7 @@ class model_settings(settings):
         self.uid_presets = []
         self.default_preset = {}
         self.cluster_requested_models = [] # The models which we allow to generate during cluster mode
-        
+
 
 #dummy class to eat the tqdm output
 class ignore_tqdm(object):
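
Reviewer note, not part of the patch: below is a minimal, self-contained
sketch of the soft-prompt handling that this change moves under the new
torch-only gate in core_generate. The vocabulary size, soft-prompt shape,
and token ids are made-up stand-ins for model.config.vocab_size and
koboldai_vars.sp; only the arange/cat pattern mirrors the hunk above.

import torch

vocab_size = 50400          # stand-in for model.config.vocab_size
sp = torch.zeros(20, 4096)  # stand-in soft prompt: 20 virtual tokens

text = [318, 428, 257, 1332]                         # story token ids
gen_in = torch.tensor(text, dtype=torch.long)[None]  # shape (1, 4)

if sp is not None:
    # Soft-prompt "tokens" are given ids just past the real vocabulary and
    # prepended to the input, so an embedding lookup that special-cases
    # ids >= vocab_size can resolve them to the tuned soft-prompt rows.
    soft_tokens = torch.arange(vocab_size, vocab_size + sp.shape[0])
    gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)

print(gen_in.shape)  # torch.Size([1, 24]): 20 soft tokens + 4 text tokens

The gate itself is the new koboldai_vars.is_model_torch(): TPU/MTJ runs
(use_colab_tpu) and the remote backends ("Colab", "API", "CLUSTER",
"ReadOnly", "OAI") return False and skip the local torch path. Note that
when the gate is False, gen_in is never assigned before the unconditional
assert that follows it, so the non-torch branch presumably gets filled in
later in this work-in-progress branch.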