Model: More Jax import fixes and formatting

This commit is contained in:
somebody
2023-03-17 15:36:44 -05:00
parent 03af06638c
commit 8d0bc404a5
5 changed files with 17 additions and 22 deletions

View File

@@ -18,6 +18,7 @@ TEST_SEED = 1337
# HF Torch
def test_generic_hf_torch_load() -> None:
global model
model = GenericHFTorchInferenceModel(
@@ -38,7 +39,9 @@ def test_torch_inference() -> None:
x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
print(x.decoded)
assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
assert (
len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT
), f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
@@ -48,6 +51,7 @@ def test_torch_inference() -> None:
print(x)
# Horde
def test_horde_load() -> None:
global model
@@ -56,19 +60,24 @@ def test_horde_load() -> None:
model = HordeInferenceModel()
model.load()
def test_horde_inference() -> None:
x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
assert (
len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT
), f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
print(x)
# API
def test_api_load() -> None:
global model
model = APIInferenceModel()
model.load()
def test_api_inference() -> None:
x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
# NOTE: Below test is flakey due to Horde worker-defined constraints
# assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
print(x)
print(x)