Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Work on mtj
aiserver.py: 24 changed lines
@@ -4841,23 +4841,25 @@ def calcsubmit(txt):
         # Send it!
         ikrequest(subtxt)
 
-def __debug(text):
-    print(f"[DBG] {text}")
+def __debug(*args):
+    print("[DBG] ", *args)
 
 def core_generate(text: list, min: int, max: int, found_entries: set):
     # This generation function is tangled with koboldai_vars intentionally. It
     # is meant for the story and nothing else.
 
-    if koboldai_vars.full_determinism:
-        torch.manual_seed(koboldai_vars.seed)
+    if koboldai_vars.is_model_torch():
+        # Torch stuff
+        if koboldai_vars.full_determinism:
+            torch.manual_seed(koboldai_vars.seed)
 
-    gen_in = torch.tensor(text, dtype=torch.long)[None]
-    if koboldai_vars.sp is not None:
-        soft_tokens = torch.arange(
-            model.config.vocab_size,
-            model.config.vocab_size + koboldai_vars.sp.shape[0],
-        )
-        gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
+        gen_in = torch.tensor(text, dtype=torch.long)[None]
+        if koboldai_vars.sp is not None:
+            soft_tokens = torch.arange(
+                model.config.vocab_size,
+                model.config.vocab_size + koboldai_vars.sp.shape[0],
+            )
+            gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
 
-    assert gen_in.shape[-1] + koboldai_vars.genamt <= koboldai_vars.max_length
+        assert gen_in.shape[-1] + koboldai_vars.genamt <= koboldai_vars.max_length
 
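Note on the soft-prompt block above: soft-prompt vectors are addressed by token ids placed past the end of the real vocabulary, so they can ride along in the same input tensor (the embedding layer is patched elsewhere to resolve those ids). A minimal sketch of that id arithmetic, with a hypothetical vocabulary size and soft-prompt shape standing in for model.config.vocab_size and koboldai_vars.sp:

import torch

# Hypothetical stand-ins; the real values come from the loaded model and
# the loaded soft prompt (koboldai_vars.sp).
vocab_size = 50257            # assumed, e.g. a GPT-2-sized vocabulary
sp = torch.zeros(20, 4096)    # stand-in soft prompt: 20 learned vectors

text = [15496, 2159]          # some prompt token ids
gen_in = torch.tensor(text, dtype=torch.long)[None]      # shape (1, 2)

# Ids vocab_size .. vocab_size+19 do not exist in the real vocabulary;
# they index into the soft-prompt matrix instead.
soft_tokens = torch.arange(vocab_size, vocab_size + sp.shape[0])
gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)  # shape (1, 22)

print(gen_in.shape)  # torch.Size([1, 22]) -> soft-prompt ids come first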
@@ -287,6 +287,15 @@ class koboldai_vars(object):
 
         self.context = context
         return tokens, used_tokens, used_tokens+self.genamt
 
+    def is_model_torch(self) -> bool:
+        if self.use_colab_tpu:
+            return False
+
+        if self.model in ["Colab", "API", "CLUSTER", "ReadOnly", "OAI"]:
+            return False
+
+        return True
+
     def __setattr__(self, name, value):
         if name[0] == "_" or name == "tokenizer":
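The new is_model_torch() helper centralizes the "is this a locally loaded torch model?" check that core_generate() now relies on: TPU/MTJ runs and remote backends must not take the torch path. A minimal usage sketch; DummyVars is a hypothetical stand-in that mimics only the two attributes the method consults:

# DummyVars is illustrative only; the real object is koboldai_vars.
class DummyVars:
    def __init__(self, model, use_colab_tpu=False):
        self.model = model
        self.use_colab_tpu = use_colab_tpu

    def is_model_torch(self) -> bool:
        if self.use_colab_tpu:
            return False
        if self.model in ["Colab", "API", "CLUSTER", "ReadOnly", "OAI"]:
            return False
        return True

print(DummyVars("gpt2").is_model_torch())                       # True: local torch model
print(DummyVars("API").is_model_torch())                        # False: remote backend
print(DummyVars("gpt2", use_colab_tpu=True).is_model_torch())   # False: TPU/MTJ path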
@@ -443,7 +452,7 @@ class model_settings(settings):
         self.uid_presets = []
         self.default_preset = {}
-
+        self.cluster_requested_models = [] # The models which we allow to generate during cluster mode
 
 
 #dummy class to eat the tqdm output
 class ignore_tqdm(object):
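The diff cuts off at the ignore_tqdm class body. The usual way to "eat" tqdm output is a file-like object whose write and flush are no-ops, passed to tqdm via its file argument; the sketch below is an assumed implementation of that pattern, not the repo's exact body:

from tqdm import tqdm

# Assumed implementation of the "dummy class to eat the tqdm output"
# pattern; the exact body is not shown in this diff.
class ignore_tqdm(object):
    def write(self, bla):   # swallow everything tqdm tries to print
        pass
    def flush(self):        # tqdm may call flush(); make it a no-op
        pass

for _ in tqdm(range(3), file=ignore_tqdm()):
    pass  # no progress bar is printed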