Alt generate mode is enabled. Each generation will run sequentially, allowing for lower VRAM usage at the cost of potentially slower generation time.

ebolam committed on 2022-12-06 10:06:24 -05:00
commit 93890b9034 (parent 4b51d0abd8)
2 changed files with 16 additions and 11 deletions
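For context: the existing multi-sequence path asks the model for all koboldai_vars.numseqs candidate sequences in a single batched call, so peak VRAM grows with the number of sequences, while alt_multi_gen produces the same sequences one at a time. A minimal sketch of that trade-off, under those assumptions (generate_one is a hypothetical stand-in for one model call, not the project's actual API):

    # Sketch only: generate_one() is a hypothetical helper representing one model call.
    def generate_sequences(prompt, numseqs, alt_multi_gen):
        if not alt_multi_gen:
            # One batched call: fastest, but peak VRAM scales with numseqs.
            return generate_one(prompt, num_return_sequences=numseqs)
        results = []
        for _ in range(numseqs):
            # One call per sequence: lower peak VRAM, more wall-clock time.
            results.extend(generate_one(prompt, num_return_sequences=1))
        return results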

@@ -2358,7 +2358,7 @@ def patch_transformers():
         self.halt = not koboldai_vars.lua_koboldbridge.generating
         koboldai_vars.lua_koboldbridge.regeneration_required = False
-        for i in range(koboldai_vars.numseqs):
+        for i in range(koboldai_vars.numseqs) if not koboldai_vars.alt_multi_gen else range(1):
             koboldai_vars.lua_koboldbridge.generated[i+1][koboldai_vars.generated_tkns] = int(input_ids[i, -1].item())
         return self.regeneration_required or self.halt
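With alt_multi_gen active only one sequence is live per pass, so this Lua bridge bookkeeping loop runs once instead of koboldai_vars.numseqs times. The conditional expression used as the loop's iterable is valid Python, though the same logic reads more plainly with an explicit variable (a sketch of an equivalent form, not the committed code):

    # Equivalent, more explicit loop header (sketch only):
    seq_range = range(1) if koboldai_vars.alt_multi_gen else range(koboldai_vars.numseqs)
    for i in seq_range:
        koboldai_vars.lua_koboldbridge.generated[i+1][koboldai_vars.generated_tkns] = int(input_ids[i, -1].item())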
@@ -5281,7 +5281,7 @@ def core_generate(text: list, _min: int, _max: int, found_entries: set, is_core:
             already_generated += len(genout[0])
             try:
-                assert already_generated <= koboldai_vars.genamt
+                assert already_generated <= koboldai_vars.genamt * koboldai_vars.numseqs if koboldai_vars.alt_multi_gen else 1
             except AssertionError:
                 print("AlreadyGenerated", already_generated)
                 print("genamt", koboldai_vars.genamt)
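One note on how the new assert parses: in Python the conditional expression binds more loosely than the comparison, so the line reads as (already_generated <= koboldai_vars.genamt * koboldai_vars.numseqs) if koboldai_vars.alt_multi_gen else 1, and when alt_multi_gen is off the assert checks the constant 1, which always passes. If the intent is a per-mode token budget, a parenthesized form would express that directly (a sketch, not the committed change):

    # Sketch of an explicitly parenthesized per-mode limit:
    limit = (koboldai_vars.genamt * koboldai_vars.numseqs
             if koboldai_vars.alt_multi_gen
             else koboldai_vars.genamt)
    assert already_generated <= limit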
@@ -6059,7 +6059,6 @@ def generate(txt, minimum, maximum, found_entries=None):
     try:
         start_time = time.time()
         genout, already_generated = tpool.execute(core_generate, txt, minimum, maximum, found_entries)
-        print(genout)
         logger.debug("Generate: core_generate time {}s".format(time.time()-start_time))
     except Exception as e:
         if(issubclass(type(e), lupa.LuaError)):