Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Model: More Jax import fixes and formatting
@@ -16,13 +16,6 @@ from modeling.tokenizer import GenericTokenizer
 import utils
 
-try:
-    import tpu_mtj_backend
-except ModuleNotFoundError as e:
-    # Not on TPU... hopefully
-    if utils.koboldai_vars.use_colab_tpu:
-        raise e
-
 
 # We only want to use logit manipulations and such on our core text model
 class use_core_manipulations:
     """Use in a `with` block to patch functions for core story model sampling."""
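For context, `use_core_manipulations` is a context manager that temporarily patches sampling-related functions while the core story model generates, then restores them on exit. The following is a minimal illustrative sketch of that pattern only, not the repo's implementation; `patch_function`, `some_module`, and `custom_sample` are hypothetical names.

    import contextlib

    @contextlib.contextmanager
    def patch_function(module, name, replacement):
        # Swap a module-level function for a patched one, restoring the
        # original even if generation raises part-way through.
        original = getattr(module, name)
        setattr(module, name, replacement)
        try:
            yield
        finally:
            setattr(module, name, original)

    # Hypothetical usage, mirroring the "with" pattern the docstring describes:
    # with patch_function(some_module, "sample", custom_sample):
    #     run_core_generation()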
@@ -18,12 +18,9 @@ from modeling.inference_model import (
 )
 from modeling.inference_models.hf import HFInferenceModel
 
-try:
+# This file shouldn't be imported unless using the TPU
+assert utils.koboldai_vars.use_colab_tpu
 import tpu_mtj_backend
-except ModuleNotFoundError as e:
-    # Not on TPU... hopefully
-    if utils.koboldai_vars.use_colab_tpu:
-        raise e
 
 
 class HFMTJInferenceModel(HFInferenceModel):
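The hunk above appears to swap a try/except around the TPU backend import for an assertion that the TPU path was actually requested, so the Jax-heavy import is never attempted on machines that cannot satisfy it. A minimal sketch of the same guarded-import idea, with a hypothetical helper name:

    import importlib

    def load_optional_backend(module_name: str, enabled: bool):
        # Fail fast with a clear message instead of a confusing ImportError
        # (or a half-initialized Jax) deeper in the stack.
        assert enabled, f"{module_name} was requested but is not enabled here"
        return importlib.import_module(module_name)

On a TPU Colab this would be called with enabled=True; everywhere else the assertion stops the import before any Jax code runs.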
@@ -39,7 +36,7 @@ class HFMTJInferenceModel(HFInferenceModel):
             post_token_hooks=False,
             stopper_hooks=False,
             post_token_probs=False,
-            uses_tpu=True
+            uses_tpu=True,
         )
 
     def setup_mtj(self) -> None:
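The keyword arguments in that hunk (post_token_hooks, stopper_hooks, post_token_probs, uses_tpu) read as feature flags passed to a capabilities object; the only change is the trailing comma, which keeps future flag additions to one-line diffs. As a rough illustration only, such a flag container could look like the sketch below; the name `Capabilities` is hypothetical, not the repo's actual class.

    from dataclasses import dataclass

    @dataclass
    class Capabilities:
        # Hypothetical container mirroring the keyword arguments above.
        post_token_hooks: bool = True
        stopper_hooks: bool = True
        post_token_probs: bool = True
        uses_tpu: bool = False

    caps = Capabilities(
        post_token_hooks=False,
        stopper_hooks=False,
        post_token_probs=False,
        uses_tpu=True,  # trailing comma: the next flag added touches one line
    )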
@@ -217,10 +217,7 @@ class RWKVInferenceModel(InferenceModel):
         for _ in range(max_new):
 
             logits, state = self.model.forward([last_token], state)
-            last_token = self._sample_token(
-                logits,
-                context
-            )
+            last_token = self._sample_token(logits, context)
             out.append(last_token)
             add = torch.tensor([[last_token]]).to(aux_device)
             context = torch.cat((context, add), dim=-1)
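The RWKV hunk above is a standard autoregressive decode loop: feed the last token plus the recurrent state to the model, sample the next token, append it, and extend the running context tensor. The sketch below shows the same loop shape with a greedy argmax stand-in for the repo's `_sample_token`; the callable signature `model(tokens, state) -> (logits, state)` and all names here are assumptions for illustration.

    import torch

    def greedy_decode(model, context: torch.Tensor, max_new: int, device: str = "cpu"):
        # context: shape (1, seq_len) tensor of token ids.
        out = []
        state = None
        last_token = int(context[0, -1])
        for _ in range(max_new):
            # Recurrent step: only the newest token and the carried state go in.
            logits, state = model([last_token], state)
            last_token = int(torch.argmax(logits, dim=-1))
            out.append(last_token)
            add = torch.tensor([[last_token]]).to(device)
            context = torch.cat((context, add), dim=-1)
        return out, context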
@@ -18,6 +18,7 @@ TEST_SEED = 1337
 
 # HF Torch
 
+
 def test_generic_hf_torch_load() -> None:
     global model
     model = GenericHFTorchInferenceModel(
@@ -38,7 +39,9 @@ def test_torch_inference() -> None:
     x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     print(x.decoded)
     assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
-    assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
+    assert (
+        len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT
+    ), f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
 
     y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
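The second `raw_generate` call with the same seed at the end of this hunk suggests the surrounding test compares the two outputs for reproducibility. A hedged sketch of such a check, using only the `raw_generate(prompt, max_new=..., seed=...)` signature visible in the diff; the helper name is hypothetical.

    def check_seeded_reproducibility(model, prompt: str, tokens: int, seed: int) -> None:
        # Two generations with identical settings and seed should match
        # token-for-token on backends that honour the seed.
        x = model.raw_generate(prompt, max_new=tokens, seed=seed)
        y = model.raw_generate(prompt, max_new=tokens, seed=seed)
        assert list(x.encoded[0]) == list(y.encoded[0]), "Seeded generation was not reproducible"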
@@ -48,6 +51,7 @@ def test_torch_inference() -> None:
 
     print(x)
 
+
# Horde
 def test_horde_load() -> None:
     global model
@@ -56,17 +60,22 @@ def test_horde_load() -> None:
     model = HordeInferenceModel()
     model.load()
 
+
 def test_horde_inference() -> None:
     x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
-    assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
+    assert (
+        len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT
+    ), f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
     print(x)
 
+
 # API
 def test_api_load() -> None:
     global model
     model = APIInferenceModel()
     model.load()
 
+
 def test_api_inference() -> None:
     x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     # NOTE: Below test is flakey due to Horde worker-defined constraints
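The same two-line shape assertion now appears in test_torch_inference, test_horde_inference, and test_api_inference. If the duplication ever becomes a nuisance, it could be folded into a shared helper along these lines; this is a sketch, and `assert_single_batch` is not an existing helper in the repo.

    def assert_single_batch(result, expected_tokens: int) -> None:
        # result.encoded is assumed to be a batch of token sequences,
        # as in the assertions above.
        assert len(result.encoded) == 1, "Bad output shape (too many batches!)"
        assert (
            len(result.encoded[0]) == expected_tokens
        ), f"Wrong token amount (requested {expected_tokens})"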
@@ -28,4 +28,3 @@ class GenericTokenizer:
             tokens = [tokens]
-
         return self.tokenizer.decode(tokens)
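The tokenizer hunk shows `decode` normalizing a bare token id into a list before delegating to the wrapped tokenizer. A self-contained sketch of that wrapper pattern follows; the class name and the isinstance check are assumptions, since only the wrapping line and the return are visible in the diff.

    from typing import List, Union

    class TokenizerWrapper:
        def __init__(self, tokenizer):
            self.tokenizer = tokenizer

        def decode(self, tokens: Union[int, List[int]]) -> str:
            # Accept either a single token id or a sequence of ids.
            if isinstance(tokens, int):
                tokens = [tokens]
            return self.tokenizer.decode(tokens)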