Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Model: More Jax import fixes and formatting
@@ -16,13 +16,6 @@ from modeling.tokenizer import GenericTokenizer
 import utils
 
-try:
-    import tpu_mtj_backend
-except ModuleNotFoundError as e:
-    # Not on TPU... hopefully
-    if utils.koboldai_vars.use_colab_tpu:
-        raise e
-
 # We only want to use logit manipulations and such on our core text model
 class use_core_manipulations:
     """Use in a `with` block to patch functions for core story model sampling."""
 
@@ -18,12 +18,9 @@ from modeling.inference_model import (
 )
 from modeling.inference_models.hf import HFInferenceModel
 
-try:
-    import tpu_mtj_backend
-except ModuleNotFoundError as e:
-    # Not on TPU... hopefully
-    if utils.koboldai_vars.use_colab_tpu:
-        raise e
+# This file shouldn't be imported unless using the TPU
+assert utils.koboldai_vars.use_colab_tpu
+import tpu_mtj_backend
 
 
 class HFMTJInferenceModel(HFInferenceModel):
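
Note: the two hunks above are two halves of the same cleanup. The shared inference module no longer tries to import tpu_mtj_backend at all, and the MTJ-specific module, which only makes sense on a TPU, now asserts that the Colab TPU flag is set before importing the backend unconditionally. A minimal sketch of the before/after behaviour, using a stand-in USE_TPU flag in place of utils.koboldai_vars.use_colab_tpu (an assumption made only to keep the example self-contained):

    USE_TPU = False

    # Old style (previously in the shared module): the backend is optional and
    # a missing module is only fatal when the TPU path was actually requested.
    try:
        import tpu_mtj_backend
    except ModuleNotFoundError:
        if USE_TPU:
            raise

    # New style (TPU-only module): refuse to be imported without the flag,
    # then import the backend unconditionally so failures surface immediately.
    def require_tpu_backend():
        assert USE_TPU, "This module shouldn't be imported unless using the TPU"
        import tpu_mtj_backend
        return tpu_mtj_backend
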
@@ -39,7 +36,7 @@ class HFMTJInferenceModel(HFInferenceModel):
             post_token_hooks=False,
             stopper_hooks=False,
             post_token_probs=False,
-            uses_tpu=True
+            uses_tpu=True,
         )
 
     def setup_mtj(self) -> None:
@@ -217,10 +217,7 @@ class RWKVInferenceModel(InferenceModel):
         for _ in range(max_new):
 
             logits, state = self.model.forward([last_token], state)
-            last_token = self._sample_token(
-                logits,
-                context
-            )
+            last_token = self._sample_token(logits, context)
             out.append(last_token)
             add = torch.tensor([[last_token]]).to(aux_device)
             context = torch.cat((context, add), dim=-1)
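
Note: the RWKV hunk is pure reformatting (the self._sample_token call is collapsed onto one line), but the loop it sits in is the model's whole incremental generation scheme: feed only the last token plus the recurrent state, sample the next token from the returned logits, and append it to the running context tensor. A self-contained sketch of that loop shape, with the model's forward pass and sampler replaced by stand-ins (greedy argmax here; names such as max_new and aux_device come from the diff, everything else is illustrative):

    import torch

    def generate_rwkv_style(forward, context: torch.Tensor, max_new: int, aux_device="cpu"):
        # forward(tokens, state) -> (logits, state) is a stand-in for self.model.forward
        out = []
        state = None
        last_token = int(context[0, -1])
        for _ in range(max_new):
            logits, state = forward([last_token], state)       # one recurrent step per token
            last_token = int(torch.argmax(logits))             # stand-in for self._sample_token(logits, context)
            out.append(last_token)
            add = torch.tensor([[last_token]]).to(aux_device)  # shape [1, 1]
            context = torch.cat((context, add), dim=-1)        # grow the running context
        return out, context

    # Usage with a dummy forward over a 10-token vocabulary:
    dummy = lambda tokens, state: (torch.zeros(10), state)
    tokens, ctx = generate_rwkv_style(dummy, torch.tensor([[1, 2, 3]]), max_new=4)
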
@@ -18,6 +18,7 @@ TEST_SEED = 1337
 
 # HF Torch
 
+
 def test_generic_hf_torch_load() -> None:
     global model
     model = GenericHFTorchInferenceModel(
@@ -38,7 +39,9 @@ def test_torch_inference() -> None:
     x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     print(x.decoded)
     assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
-    assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
+    assert (
+        len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT
+    ), f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
 
     y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
 
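
Note: the reflowed assertions in the test hunks are black-style line-length fixes, not behaviour changes. The detail worth keeping in mind is that only the condition goes inside the parentheses while the failure message stays outside, after the comma; wrapping both in one pair of parentheses would turn the statement into an assert on a two-element tuple, which is always truthy and can never fail. A tiny illustration with generic values (not from the repo):

    n, expected = 3, 5

    # Correct form, matching the diff: parenthesize only the condition.
    try:
        assert (
            n == expected
        ), f"Wrong token amount (requested {expected})"
    except AssertionError as err:
        print("caught:", err)  # fires, since 3 != 5

    # Pitfall: this asserts a non-empty tuple, so it silently passes
    # (recent Python versions emit a SyntaxWarning about it).
    assert (n == expected, f"Wrong token amount (requested {expected})")
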
@@ -48,6 +51,7 @@ def test_torch_inference() -> None:
 
     print(x)
 
+
 # Horde
 def test_horde_load() -> None:
     global model
@@ -56,19 +60,24 @@ def test_horde_load() -> None:
     model = HordeInferenceModel()
     model.load()
 
+
 def test_horde_inference() -> None:
     x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
-    assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
+    assert (
+        len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT
+    ), f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
     print(x)
 
+
 # API
 def test_api_load() -> None:
     global model
     model = APIInferenceModel()
     model.load()
 
+
 def test_api_inference() -> None:
     x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     # NOTE: Below test is flakey due to Horde worker-defined constraints
     # assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
     print(x)
@@ -28,4 +28,3 @@ class GenericTokenizer:
             tokens = [tokens]
 
         return self.tokenizer.decode(tokens)
-
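
Note: the final hunk only trims a trailing blank line in GenericTokenizer, but the surviving lines show the decode contract: a bare token id is wrapped in a list before being handed to the wrapped tokenizer. A rough sketch of that normalization, assuming the guard preceding the shown lines is an isinstance check (the real class wraps a Hugging Face-style tokenizer; everything beyond the two lines in the diff is an assumption):

    from typing import List, Union

    class GenericTokenizerSketch:
        """Illustrative stand-in for modeling.tokenizer.GenericTokenizer.decode()."""

        def __init__(self, tokenizer) -> None:
            self.tokenizer = tokenizer  # expected to expose .decode(List[int]) -> str

        def decode(self, tokens: Union[int, List[int]]) -> str:
            if isinstance(tokens, int):
                tokens = [tokens]  # accept a bare token id as well as a list
            return self.tokenizer.decode(tokens)
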