Modeling: Add seed parameter to raw_generate

Yahooo, decoupling from koboldai_vars! This makes the generation test
in `test_generation.py` pass, and makes full determinism work outside
of `core_generate`.
Author: somebody
Date: 2023-03-12 21:49:10 -05:00
Parent: 38c4edac40
Commit: cd8ccf0a5e
8 changed files with 51 additions and 20 deletions
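
To make the shape of the change concrete, here is a minimal sketch of how a seed parameter like this can drive deterministic sampling. Only the `raw_generate` name, the `seed`/`max_new` parameters, and the move away from `koboldai_vars` come from the commit itself; the method body, the `_generate` helper, and the `torch.manual_seed` call are illustrative assumptions.

```python
from typing import Optional

import torch


class InferenceModel:
    """Sketch only, not the actual KoboldAI class body."""

    def raw_generate(self, prompt: str, max_new: int, seed: Optional[int] = None):
        # The commit's point: the seed arrives as an explicit argument
        # instead of being read from the global koboldai_vars state.
        if seed is not None:
            # Reseeding the RNG before sampling is what makes two calls
            # with the same seed return identical tokens (assumed detail).
            torch.manual_seed(seed)
        return self._generate(prompt, max_new)

    def _generate(self, prompt: str, max_new: int):
        # Hypothetical backend hook standing in for the real sampler.
        raise NotImplementedError
```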

test_generation.py

@@ -10,6 +10,7 @@ model: InferenceModel
 TEST_MODEL_HF_ID = "EleutherAI/pythia-70m"
 TEST_PROMPT = "Once upon a time I found myself"
 TEST_GEN_TOKEN_COUNT = 20
+TEST_SEED = 1337
 
 
 def test_generic_hf_torch_load() -> None:
@@ -29,19 +30,15 @@ def test_generic_hf_torch_low_mem_load() -> None:
 
 
 def test_model_gen() -> None:
-    # This should probably be supported in the model interface!
-    koboldai_vars.full_determinism = True
-    koboldai_vars.seed = 1337
 
-    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT)
+    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     print(x.decoded)
 
     assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
     assert len(x.encoded[0]) == 20, "Wrong token amount (requested 20)"
 
-    y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT)
+    y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     assert torch.equal(
         x.encoded[0], y.encoded[0]
     ), f"Faulty full determinism! {x.decoded} vs {y.decoded}"
-    print(x)
     print(x)
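
The property the test asserts can be demonstrated in isolation. This self-contained sketch uses a multinomial draw as a stand-in for the model's sampler; the vocabulary size and helper name are made up for illustration:

```python
import torch


def sample_token_ids(seed: int, steps: int = 20) -> torch.Tensor:
    # Stand-in for seeded generation: reseed torch's RNG, then draw
    # `steps` token ids uniformly from a dummy vocabulary.
    torch.manual_seed(seed)
    vocab_weights = torch.ones(50_000)  # arbitrary stand-in vocab size
    return torch.multinomial(vocab_weights, num_samples=steps, replacement=True)


x = sample_token_ids(1337)
y = sample_token_ids(1337)
assert torch.equal(x, y), "same seed must yield identical tokens"
```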