mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Model: More Jax import fixes and formatting
This commit is contained in:
@@ -18,12 +18,9 @@ from modeling.inference_model import (
|
||||
)
|
||||
from modeling.inference_models.hf import HFInferenceModel
|
||||
|
||||
try:
|
||||
import tpu_mtj_backend
|
||||
except ModuleNotFoundError as e:
|
||||
# Not on TPU... hopefully
|
||||
if utils.koboldai_vars.use_colab_tpu:
|
||||
raise e
|
||||
# This file shouldn't be imported unless using the TPU
|
||||
assert utils.koboldai_vars.use_colab_tpu
|
||||
import tpu_mtj_backend
|
||||
|
||||
|
||||
class HFMTJInferenceModel(HFInferenceModel):
|
||||
@@ -39,7 +36,7 @@ class HFMTJInferenceModel(HFInferenceModel):
|
||||
post_token_hooks=False,
|
||||
stopper_hooks=False,
|
||||
post_token_probs=False,
|
||||
uses_tpu=True
|
||||
uses_tpu=True,
|
||||
)
|
||||
|
||||
def setup_mtj(self) -> None:
|
||||
|
Reference in New Issue
Block a user