Model: More Jax import fixes and formatting

somebody
2023-03-17 15:36:44 -05:00
parent 03af06638c
commit 8d0bc404a5
5 changed files with 17 additions and 22 deletions

View File

@@ -16,13 +16,6 @@ from modeling.tokenizer import GenericTokenizer
 import utils

-try:
-    import tpu_mtj_backend
-except ModuleNotFoundError as e:
-    # Not on TPU... hopefully
-    if utils.koboldai_vars.use_colab_tpu:
-        raise e
-
 # We only want to use logit manipulations and such on our core text model
 class use_core_manipulations:
     """Use in a `with` block to patch functions for core story model sampling."""

View File

@@ -18,12 +18,9 @@ from modeling.inference_model import (
 )
 from modeling.inference_models.hf import HFInferenceModel

-try:
-    import tpu_mtj_backend
-except ModuleNotFoundError as e:
-    # Not on TPU... hopefully
-    if utils.koboldai_vars.use_colab_tpu:
-        raise e
+# This file shouldn't be imported unless using the TPU
+assert utils.koboldai_vars.use_colab_tpu
+import tpu_mtj_backend

 class HFMTJInferenceModel(HFInferenceModel):
@@ -39,7 +36,7 @@ class HFMTJInferenceModel(HFInferenceModel):
             post_token_hooks=False,
             stopper_hooks=False,
             post_token_probs=False,
-            uses_tpu=True
+            uses_tpu=True,
         )

     def setup_mtj(self) -> None:

View File

@@ -217,10 +217,7 @@ class RWKVInferenceModel(InferenceModel):
         for _ in range(max_new):
             logits, state = self.model.forward([last_token], state)

-            last_token = self._sample_token(
-                logits,
-                context
-            )
+            last_token = self._sample_token(logits, context)
             out.append(last_token)
             add = torch.tensor([[last_token]]).to(aux_device)
             context = torch.cat((context, add), dim=-1)

View File

@@ -18,6 +18,7 @@ TEST_SEED = 1337

+
 # HF Torch
 def test_generic_hf_torch_load() -> None:
     global model

     model = GenericHFTorchInferenceModel(
@@ -38,7 +39,9 @@ def test_torch_inference() -> None:
     x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     print(x.decoded)

     assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
-    assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
+    assert (
+        len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT
+    ), f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"

     y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
@@ -48,6 +51,7 @@ def test_torch_inference() -> None:
     print(x)

+
 # Horde
 def test_horde_load() -> None:
     global model
@@ -56,19 +60,24 @@ def test_horde_load() -> None:
     model = HordeInferenceModel()
     model.load()

+
 def test_horde_inference() -> None:
     x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
-    assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
+    assert (
+        len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT
+    ), f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
     print(x)

+
 # API
 def test_api_load() -> None:
     global model

     model = APIInferenceModel()
     model.load()

+
 def test_api_inference() -> None:
     x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     # NOTE: Below test is flakey due to Horde worker-defined constraints
     # assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
     print(x)

View File

@@ -28,4 +28,3 @@ class GenericTokenizer:
             tokens = [tokens]

         return self.tokenizer.decode(tokens)
-