Model: More Jax import fixes and formatting

somebody
2023-03-17 15:36:44 -05:00
parent 03af06638c
commit 8d0bc404a5
5 changed files with 17 additions and 22 deletions


@@ -16,13 +16,6 @@ from modeling.tokenizer import GenericTokenizer
import utils
-try:
-    import tpu_mtj_backend
-except ModuleNotFoundError as e:
-    # Not on TPU... hopefully
-    if utils.koboldai_vars.use_colab_tpu:
-        raise e
# We only want to use logit manipulations and such on our core text model
class use_core_manipulations:
    """Use in a `with` block to patch functions for core story model sampling."""


@@ -18,12 +18,9 @@ from modeling.inference_model import (
)
from modeling.inference_models.hf import HFInferenceModel
-try:
-    import tpu_mtj_backend
-except ModuleNotFoundError as e:
-    # Not on TPU... hopefully
-    if utils.koboldai_vars.use_colab_tpu:
-        raise e
+# This file shouldn't be imported unless using the TPU
+assert utils.koboldai_vars.use_colab_tpu
+import tpu_mtj_backend
class HFMTJInferenceModel(HFInferenceModel):
@@ -39,7 +36,7 @@ class HFMTJInferenceModel(HFInferenceModel):
            post_token_hooks=False,
            stopper_hooks=False,
            post_token_probs=False,
-            uses_tpu=True
+            uses_tpu=True,
        )
    def setup_mtj(self) -> None:
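The import changes above (in both files) replace the speculative try/except around tpu_mtj_backend with a hard assert in the TPU-only module: off-TPU code no longer attempts the import at all, and the TPU path asserts its precondition before importing. A minimal caller-side sketch of the new pattern; the module path and the surrounding if are illustrative assumptions, not part of this diff:

    import utils

    # Off-TPU code paths simply never import the TPU-only module, so the old
    # "Not on TPU... hopefully" fallback is no longer needed anywhere else.
    if utils.koboldai_vars.use_colab_tpu:
        # Inside the module, the new assert re-checks use_colab_tpu and then
        # imports tpu_mtj_backend unconditionally.
        from modeling.inference_models.hf_mtj import HFMTJInferenceModel  # module path assumed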


@@ -217,10 +217,7 @@ class RWKVInferenceModel(InferenceModel):
        for _ in range(max_new):
            logits, state = self.model.forward([last_token], state)
-            last_token = self._sample_token(
-                logits,
-                context
-            )
+            last_token = self._sample_token(logits, context)
            out.append(last_token)
            add = torch.tensor([[last_token]]).to(aux_device)
            context = torch.cat((context, add), dim=-1)


@@ -18,6 +18,7 @@ TEST_SEED = 1337
# HF Torch
def test_generic_hf_torch_load() -> None:
    global model
    model = GenericHFTorchInferenceModel(
@@ -38,7 +39,9 @@ def test_torch_inference() -> None:
    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
    print(x.decoded)
    assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
-    assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
+    assert (
+        len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT
+    ), f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
    y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
@@ -48,6 +51,7 @@ def test_torch_inference() -> None:
    print(x)
# Horde
def test_horde_load() -> None:
    global model
@@ -56,19 +60,24 @@ def test_horde_load() -> None:
    model = HordeInferenceModel()
    model.load()
def test_horde_inference() -> None:
    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
-    assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
+    assert (
+        len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT
+    ), f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
    print(x)
# API
def test_api_load() -> None:
    global model
    model = APIInferenceModel()
    model.load()
def test_api_inference() -> None:
    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
    # NOTE: Below test is flakey due to Horde worker-defined constraints
    # assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
-    print(x)
+    print(x)


@@ -28,4 +28,3 @@ class GenericTokenizer:
            tokens = [tokens]
        return self.tokenizer.decode(tokens)