From 8d0bc404a5da54712b6c0cd4e8412d33c194d185 Mon Sep 17 00:00:00 2001
From: somebody
Date: Fri, 17 Mar 2023 15:36:44 -0500
Subject: [PATCH] Model: More Jax import fixes and formatting

---
 modeling/inference_model.py         |  7 -------
 modeling/inference_models/hf_mtj.py | 11 ++++-------
 modeling/inference_models/rwkv.py   |  5 +----
 modeling/test_generation.py         | 15 ++++++++++++---
 modeling/tokenizer.py               |  1 -
 5 files changed, 17 insertions(+), 22 deletions(-)

diff --git a/modeling/inference_model.py b/modeling/inference_model.py
index 4eb63618..010b9ddd 100644
--- a/modeling/inference_model.py
+++ b/modeling/inference_model.py
@@ -16,13 +16,6 @@ from modeling.tokenizer import GenericTokenizer
 
 import utils
 
-try:
-    import tpu_mtj_backend
-except ModuleNotFoundError as e:
-    # Not on TPU... hopefully
-    if utils.koboldai_vars.use_colab_tpu:
-        raise e
-
 # We only want to use logit manipulations and such on our core text model
 class use_core_manipulations:
     """Use in a `with` block to patch functions for core story model sampling."""
diff --git a/modeling/inference_models/hf_mtj.py b/modeling/inference_models/hf_mtj.py
index 4a856d7a..756c5fef 100644
--- a/modeling/inference_models/hf_mtj.py
+++ b/modeling/inference_models/hf_mtj.py
@@ -18,12 +18,9 @@ from modeling.inference_model import (
 )
 from modeling.inference_models.hf import HFInferenceModel
 
-try:
-    import tpu_mtj_backend
-except ModuleNotFoundError as e:
-    # Not on TPU... hopefully
-    if utils.koboldai_vars.use_colab_tpu:
-        raise e
+# This file shouldn't be imported unless using the TPU
+assert utils.koboldai_vars.use_colab_tpu
+import tpu_mtj_backend
 
 
 class HFMTJInferenceModel(HFInferenceModel):
@@ -39,7 +36,7 @@
             post_token_hooks=False,
             stopper_hooks=False,
             post_token_probs=False,
-            uses_tpu=True
+            uses_tpu=True,
         )
 
     def setup_mtj(self) -> None:
diff --git a/modeling/inference_models/rwkv.py b/modeling/inference_models/rwkv.py
index e6bef128..d6ebf357 100644
--- a/modeling/inference_models/rwkv.py
+++ b/modeling/inference_models/rwkv.py
@@ -217,10 +217,7 @@ class RWKVInferenceModel(InferenceModel):
 
         for _ in range(max_new):
             logits, state = self.model.forward([last_token], state)
-            last_token = self._sample_token(
-                logits,
-                context
-            )
+            last_token = self._sample_token(logits, context)
             out.append(last_token)
             add = torch.tensor([[last_token]]).to(aux_device)
             context = torch.cat((context, add), dim=-1)
diff --git a/modeling/test_generation.py b/modeling/test_generation.py
index 0f700d0b..b47b39bd 100644
--- a/modeling/test_generation.py
+++ b/modeling/test_generation.py
@@ -18,6 +18,7 @@ TEST_SEED = 1337
 
 # HF Torch
 
+
 def test_generic_hf_torch_load() -> None:
     global model
     model = GenericHFTorchInferenceModel(
@@ -38,7 +39,9 @@ def test_torch_inference() -> None:
     x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     print(x.decoded)
     assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
-    assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
+    assert (
+        len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT
+    ), f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
 
     y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
 
@@ -48,6 +51,7 @@
 
     print(x)
 
+
 # Horde
 def test_horde_load() -> None:
     global model
@@ -55,19 +59,24 @@
     model = HordeInferenceModel()
     model.load()
 
+
 def test_horde_inference() -> None:
     x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
-    assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
+    assert (
+        len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT
+    ), f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
     print(x)
 
+
 # API
 def test_api_load() -> None:
     global model
     model = APIInferenceModel()
     model.load()
 
+
 def test_api_inference() -> None:
     x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     # NOTE: Below test is flakey due to Horde worker-defined constraints
     # assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
-    print(x)
\ No newline at end of file
+    print(x)
diff --git a/modeling/tokenizer.py b/modeling/tokenizer.py
index f05398f2..6c41764b 100644
--- a/modeling/tokenizer.py
+++ b/modeling/tokenizer.py
@@ -28,4 +28,3 @@ class GenericTokenizer:
             tokens = [tokens]
 
         return self.tokenizer.decode(tokens)
-
\ No newline at end of file