Work on mtj

This commit is contained in:
somebody
2022-09-22 20:10:23 -05:00
parent ae90c39f72
commit d9d24902ae
2 changed files with 23 additions and 12 deletions

View File

@@ -4841,23 +4841,25 @@ def calcsubmit(txt):
# Send it!
ikrequest(subtxt)
def __debug(text):
print(f"[DBG] {text}")
def __debug(*args):
print("[DBG] ", *args)
def core_generate(text: list, min: int, max: int, found_entries: set):
# This generation function is tangled with koboldai_vars intentionally. It
# is meant for the story and nothing else.
if koboldai_vars.full_determinism:
torch.manual_seed(koboldai_vars.seed)
if koboldai_vars.is_model_torch():
# Torch stuff
if koboldai_vars.full_determinism:
torch.manual_seed(koboldai_vars.seed)
gen_in = torch.tensor(text, dtype=torch.long)[None]
if koboldai_vars.sp is not None:
soft_tokens = torch.arange(
model.config.vocab_size,
model.config.vocab_size + koboldai_vars.sp.shape[0],
)
gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
gen_in = torch.tensor(text, dtype=torch.long)[None]
if koboldai_vars.sp is not None:
soft_tokens = torch.arange(
model.config.vocab_size,
model.config.vocab_size + koboldai_vars.sp.shape[0],
)
gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
assert gen_in.shape[-1] + koboldai_vars.genamt <= koboldai_vars.max_length

View File

@@ -287,6 +287,15 @@ class koboldai_vars(object):
self.context = context
return tokens, used_tokens, used_tokens+self.genamt
def is_model_torch(self) -> bool:
if self.use_colab_tpu:
return False
if self.model in ["Colab", "API", "CLUSTER", "ReadOnly", "OAI"]:
return False
return True
def __setattr__(self, name, value):
if name[0] == "_" or name == "tokenizer":
@@ -443,7 +452,7 @@ class model_settings(settings):
self.uid_presets = []
self.default_preset = {}
self.cluster_requested_models = [] # The models which we allow to generate during cluster mode
#dummy class to eat the tqdm output
class ignore_tqdm(object):