Model: More Jax import fixes and formatting

This commit is contained in:
somebody
2023-03-17 15:36:44 -05:00
parent 03af06638c
commit 8d0bc404a5
5 changed files with 17 additions and 22 deletions

View File

@@ -18,12 +18,9 @@ from modeling.inference_model import (
)
from modeling.inference_models.hf import HFInferenceModel
try:
import tpu_mtj_backend
except ModuleNotFoundError as e:
# Not on TPU... hopefully
if utils.koboldai_vars.use_colab_tpu:
raise e
# This file shouldn't be imported unless using the TPU
assert utils.koboldai_vars.use_colab_tpu
import tpu_mtj_backend
class HFMTJInferenceModel(HFInferenceModel):
@@ -39,7 +36,7 @@ class HFMTJInferenceModel(HFInferenceModel):
post_token_hooks=False,
stopper_hooks=False,
post_token_probs=False,
uses_tpu=True
uses_tpu=True,
)
def setup_mtj(self) -> None: