Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Model: More Jax import fixes and formatting
@@ -18,6 +18,7 @@ TEST_SEED = 1337
 # HF Torch
 def test_generic_hf_torch_load() -> None:
     global model
     model = GenericHFTorchInferenceModel(
@@ -38,7 +39,9 @@ def test_torch_inference() -> None:
     x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     print(x.decoded)
     assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
-    assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
+    assert (
+        len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT
+    ), f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"

     y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
@@ -48,6 +51,7 @@ def test_torch_inference() -> None:
     print(x)


 # Horde
 def test_horde_load() -> None:
     global model
@@ -56,19 +60,24 @@ def test_horde_load() -> None:
     model = HordeInferenceModel()
     model.load()


 def test_horde_inference() -> None:
     x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
-    assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
+    assert (
+        len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT
+    ), f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
     print(x)


 # API
 def test_api_load() -> None:
     global model
     model = APIInferenceModel()
     model.load()


 def test_api_inference() -> None:
     x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     # NOTE: Below test is flakey due to Horde worker-defined constraints
     # assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
     print(x)
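For context only: aside from the import fixes, the formatting change in these hunks rewrites long single-line asserts so the condition sits on its own line inside parentheses, with the failure message after the closing parenthesis. Below is a minimal standalone sketch of that before/after pattern; FakeGenerationResult and its fields are hypothetical stand-ins for the real generation result object, not code from this repository.

# Illustrative sketch only -- FakeGenerationResult is a made-up stub; the
# point is the assert style used before and after this commit.
from dataclasses import dataclass
from typing import List

TEST_GEN_TOKEN_COUNT = 20

@dataclass
class FakeGenerationResult:
    encoded: List[List[int]]  # one list of token ids per generated batch
    decoded: List[str]

x = FakeGenerationResult(encoded=[[0] * TEST_GEN_TOKEN_COUNT], decoded=["example"])

# Before: one long assert line.
assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"

# After: the same check, wrapped so the condition is on its own line.
assert (
    len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT
), f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"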