From 8937e7f6df9304ad2152503d88249825a82b9907 Mon Sep 17 00:00:00 2001
From: somebody
Date: Thu, 9 Mar 2023 21:02:12 -0600
Subject: [PATCH] Model: Add basic tests

We now run some basic tests for:
- HF Torch loading (normal, lazy, low-mem)
- HF Torch generation (output batch count, generated token count, full determinism)

Currently the full determinism test is failing; hooray, the tests work!

All of the tests initially failed (note that the test environment behaves
differently from the aiserver environment, since aiserver does a lot of
initialization work on startup; we are working on phasing that out), but now
only one test fails. Very useful for finding bugs!
---
 modeling/test_generation.py | 47 +++++++++++++++++++++++++++++++++++++
 1 file changed, 47 insertions(+)
 create mode 100644 modeling/test_generation.py

diff --git a/modeling/test_generation.py b/modeling/test_generation.py
new file mode 100644
index 00000000..0d9173b5
--- /dev/null
+++ b/modeling/test_generation.py
@@ -0,0 +1,47 @@
+# We have to go through aiserver to initialize koboldai_vars :(
+import torch
+from aiserver import GenericHFTorchInferenceModel
+from aiserver import koboldai_vars
+from modeling.inference_model import InferenceModel
+
+model: InferenceModel
+
+# Preferably teeny tiny
+TEST_MODEL_HF_ID = "EleutherAI/pythia-70m"
+TEST_PROMPT = "Once upon a time I found myself"
+TEST_GEN_TOKEN_COUNT = 20
+
+
+def test_generic_hf_torch_load() -> None:
+    global model
+    model = GenericHFTorchInferenceModel(
+        TEST_MODEL_HF_ID, lazy_load=False, low_mem=False
+    )
+    model.load()
+
+
+def test_generic_hf_torch_lazy_load() -> None:
+    GenericHFTorchInferenceModel(TEST_MODEL_HF_ID, lazy_load=True, low_mem=False).load()
+
+
+def test_generic_hf_torch_low_mem_load() -> None:
+    GenericHFTorchInferenceModel(TEST_MODEL_HF_ID, lazy_load=False, low_mem=True).load()
+
+
+def test_model_gen() -> None:
+    # This should probably be supported in the model interface!
+    koboldai_vars.full_determinism = True
+    koboldai_vars.seed = 1337
+
+    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT)
+    print(x.decoded)
+    assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
+    assert len(x.encoded[0]) == 20, "Wrong token amount (requested 20)"
+
+    y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT)
+
+    assert torch.equal(
+        x.encoded[0], y.encoded[0]
+    ), f"Faulty full determinism! {x.decoded} vs {y.decoded}"
+
+    print(x)
\ No newline at end of file
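
Reviewer note (not part of the patch): a minimal sketch of how these tests might be invoked, assuming pytest is the test runner used for this repository (the patch itself does not say). Because test_model_gen() relies on the module-level `model` set by test_generic_hf_torch_load(), the whole file should be run so the tests execute in file order rather than individually.

    # Hypothetical runner sketch; assumes pytest is installed. Running the file as a
    # whole keeps pytest's default collection order, so test_generic_hf_torch_load()
    # populates the module-level `model` before test_model_gen() uses it.
    # The -s flag disables output capture so the print(x.decoded) output is visible.
    import pytest

    exit_code = pytest.main(["-s", "modeling/test_generation.py"])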