Modeling: Add seed parameter to raw_generate

Yahooo, decoupling from koboldai_vars. This makes the generation test
pass in `test_generation.py`, and makes full determinism outside of
core_generate work.
This commit is contained in:
somebody
2023-03-12 21:49:10 -05:00
parent 38c4edac40
commit cd8ccf0a5e
8 changed files with 51 additions and 20 deletions

View File

@@ -246,9 +246,6 @@ class InferenceModel:
start_time = time.time()
if utils.koboldai_vars.is_model_torch():
# Torch stuff
if utils.koboldai_vars.full_determinism:
torch.manual_seed(utils.koboldai_vars.seed)
if utils.koboldai_vars.sp is not None:
assert self.capabilties.embedding_manipulation
soft_tokens = torch.arange(
@@ -256,9 +253,6 @@ class InferenceModel:
self.model.config.vocab_size + utils.koboldai_vars.sp.shape[0],
)
gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
elif utils.koboldai_vars.use_colab_tpu:
if utils.koboldai_vars.full_determinism:
tpu_mtj_backend.set_rng_seed(utils.koboldai_vars.seed)
logger.debug(
"core_generate: Model Setup (SP, etc) time {}s".format(
@@ -329,6 +323,9 @@ class InferenceModel:
not utils.koboldai_vars.nogenmod
and utils.koboldai_vars.has_genmod
),
seed=utils.koboldai_vars.seed
if utils.koboldai_vars.full_determinism
else None,
)
logger.debug(
"core_generate: run raw_generate pass {} {}s".format(
@@ -481,6 +478,7 @@ class InferenceModel:
gen_settings (GenerationSettings): State to pass in single-generation setting overrides
single_line (bool, optional): Generate one line only. Defaults to False.
batch_count (int, optional): How big of a batch to generate. Defaults to 1.
seed (int, optional): If not None, this seed will be used to make reproducible generations. Defaults to None.
Returns:
GenerationResult: The model's output
@@ -501,6 +499,7 @@ class InferenceModel:
single_line: bool = False,
found_entries: set = (),
tpu_dynamic_inference: bool = False,
seed: Optional[int] = None,
**kwargs,
) -> GenerationResult:
"""A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.
@@ -562,6 +561,7 @@ class InferenceModel:
gen_settings=gen_settings,
single_line=single_line,
tpu_dynamic_inference=tpu_dynamic_inference,
seed=seed,
)
time_end = round(time.time() - time_start, 2)