Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Modeling: Add seed parameter to raw_generate
Yahooo, decoupling from koboldai_vars. This makes the generation test in `test_generation.py` pass, and makes full determinism work outside of core_generate.
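For context, a per-call seed typically works by reseeding the sampler's RNG right before generation. Below is a minimal, illustrative sketch of that idea against a plain Hugging Face backend; the seeded_generate helper and its signature are hypothetical and are not KoboldAI's actual raw_generate implementation.

    # Illustrative only: how a per-call `seed` might be honored by a Hugging
    # Face-style generation backend. `seeded_generate` is a hypothetical helper,
    # not KoboldAI's raw_generate.
    from typing import Optional

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
    model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")

    def seeded_generate(prompt: str, max_new: int, seed: Optional[int] = None) -> str:
        if seed is not None:
            # Reseeding the global torch RNGs makes repeated sampling calls with
            # the same seed reproduce the same tokens, without touching any
            # application-wide settings object.
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
        inputs = tokenizer(prompt, return_tensors="pt")
        output = model.generate(**inputs, do_sample=True, max_new_tokens=max_new)
        return tokenizer.decode(output[0], skip_special_tokens=True)

Passing the seed per call, instead of reading it from koboldai_vars, is what lets a test (or any caller outside core_generate) get reproducible output without mutating global state.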
test_generation.py:

@@ -10,6 +10,7 @@ model: InferenceModel
 TEST_MODEL_HF_ID = "EleutherAI/pythia-70m"
 TEST_PROMPT = "Once upon a time I found myself"
 TEST_GEN_TOKEN_COUNT = 20
+TEST_SEED = 1337
 
 
 def test_generic_hf_torch_load() -> None:
@@ -29,19 +30,15 @@ def test_generic_hf_torch_low_mem_load() -> None:
 
 
 def test_model_gen() -> None:
-    # This should probably be supported in the model interface!
-    koboldai_vars.full_determinism = True
-    koboldai_vars.seed = 1337
-
-    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT)
+    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     print(x.decoded)
     assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
     assert len(x.encoded[0]) == 20, "Wrong token amount (requested 20)"
 
-    y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT)
+    y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
 
     assert torch.equal(
         x.encoded[0], y.encoded[0]
     ), f"Faulty full determinism! {x.decoded} vs {y.decoded}"
 
     print(x)
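For reference, the determinism this test asserts can also be exercised with the hypothetical seeded_generate sketch above (prompt, token count, and seed taken from the test constants):

    a = seeded_generate("Once upon a time I found myself", max_new=20, seed=1337)
    b = seeded_generate("Once upon a time I found myself", max_new=20, seed=1337)
    # Same seed, same prompt: the decoded continuations should match exactly.
    assert a == b, "same seed should reproduce the same continuation"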