Fix for --nobreakmodel forcing CPU

Put importing of colab packages into a if function so it doesn't error out
This commit is contained in:
ebolam
2023-06-02 12:58:59 -04:00
parent bda8c931f8
commit 5c4d580aac
3 changed files with 12 additions and 6 deletions

View File

@@ -1682,7 +1682,6 @@ class RestrictedUnpickler(pickle.Unpickler):
) )
def load(self, *args, **kwargs): def load(self, *args, **kwargs):
logger.info("Using safe unpickle")
self.original_persistent_load = getattr( self.original_persistent_load = getattr(
self, "persistent_load", pickle.Unpickler.persistent_load self, "persistent_load", pickle.Unpickler.persistent_load
) )

View File

@@ -250,7 +250,7 @@ class model_backend(HFTorchInferenceModel):
if utils.koboldai_vars.hascuda: if utils.koboldai_vars.hascuda:
if self.usegpu: if self.usegpu or self.nobreakmodel:
# Use just VRAM # Use just VRAM
self.model = self.model.half().to(utils.koboldai_vars.gpu_device) self.model = self.model.half().to(utils.koboldai_vars.gpu_device)
elif self.breakmodel: elif self.breakmodel:

View File

@@ -42,10 +42,17 @@ import utils
import torch import torch
import numpy as np import numpy as np
if utils.koboldai_vars.use_colab_tpu: try:
import jax ignore = utils.koboldai_vars.use_colab_tpu
import jax.numpy as jnp ok = True
import tpu_mtj_backend except:
ok = False
if ok:
if utils.koboldai_vars.use_colab_tpu:
import jax
import jax.numpy as jnp
import tpu_mtj_backend
def update_settings(): def update_settings():