Modeling: Add seed parameter to raw_generate
Yahooo, decoupling from koboldai_vars. This makes the generation test in `test_generation.py` pass, and makes full determinism work outside of core_generate.
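
The practical effect is that determinism becomes a per-call property of `raw_generate` instead of something each backend reads from global `koboldai_vars` state. A minimal usage sketch of the new parameter, assuming an already-loaded `InferenceModel` backend bound to `model` (the prompt, the `max_new` argument, and the `decoded` field on the returned `GenerationResult` are illustrative assumptions, not part of this diff):

    # Sketch only: `model` is assumed to be a loaded InferenceModel backend.
    # The same seed should reproduce the same generation.
    result_a = model.raw_generate("Niko the kobold", max_new=32, seed=1337)
    result_b = model.raw_generate("Niko the kobold", max_new=32, seed=1337)
    assert result_a.decoded == result_b.decoded

    # seed=None (the default) keeps the old non-deterministic behavior.
    result_c = model.raw_generate("Niko the kobold", max_new=32)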
@@ -246,9 +246,6 @@ class InferenceModel:
         start_time = time.time()
         if utils.koboldai_vars.is_model_torch():
             # Torch stuff
-            if utils.koboldai_vars.full_determinism:
-                torch.manual_seed(utils.koboldai_vars.seed)
-
             if utils.koboldai_vars.sp is not None:
                 assert self.capabilties.embedding_manipulation
                 soft_tokens = torch.arange(
@@ -256,9 +253,6 @@ class InferenceModel:
                     self.model.config.vocab_size + utils.koboldai_vars.sp.shape[0],
                 )
                 gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
-        elif utils.koboldai_vars.use_colab_tpu:
-            if utils.koboldai_vars.full_determinism:
-                tpu_mtj_backend.set_rng_seed(utils.koboldai_vars.seed)
 
         logger.debug(
             "core_generate: Model Setup (SP, etc) time {}s".format(
@@ -329,6 +323,9 @@ class InferenceModel:
                     not utils.koboldai_vars.nogenmod
                     and utils.koboldai_vars.has_genmod
                 ),
+                seed=utils.koboldai_vars.seed
+                if utils.koboldai_vars.full_determinism
+                else None,
             )
             logger.debug(
                 "core_generate: run raw_generate pass {} {}s".format(
@@ -481,6 +478,7 @@ class InferenceModel:
             gen_settings (GenerationSettings): State to pass in single-generation setting overrides
             single_line (bool, optional): Generate one line only. Defaults to False.
             batch_count (int, optional): How big of a batch to generate. Defaults to 1.
+            seed (int, optional): If not None, this seed will be used to make reproducible generations. Defaults to None.
 
         Returns:
             GenerationResult: The model's output
@@ -501,6 +499,7 @@ class InferenceModel:
         single_line: bool = False,
         found_entries: set = (),
         tpu_dynamic_inference: bool = False,
+        seed: Optional[int] = None,
         **kwargs,
     ) -> GenerationResult:
         """A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.
@@ -562,6 +561,7 @@ class InferenceModel:
             gen_settings=gen_settings,
             single_line=single_line,
             tpu_dynamic_inference=tpu_dynamic_inference,
+            seed=seed,
         )
 
         time_end = round(time.time() - time_start, 2)
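
The diff threads `seed` from `core_generate` through `raw_generate` down to `_raw_generate`, but doesn't show how a backend consumes it. A plausible sketch of the seeding step inside a torch-based `_raw_generate`, assuming it mirrors the `torch.manual_seed` call deleted from `core_generate` above (the helper name `_apply_seed` is hypothetical):

    from typing import Optional

    import torch

    def _apply_seed(seed: Optional[int]) -> None:
        # Hypothetical helper: seed torch's global RNG for a reproducible
        # generation pass; seed=None leaves the RNG untouched, preserving
        # the old non-deterministic behavior.
        if seed is not None:
            torch.manual_seed(seed)

Seeding from an explicit argument rather than from `utils.koboldai_vars.full_determinism` is what lets `test_generation.py` exercise determinism without building up the full application state.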