Update aux device to depend on primary device

This commit is contained in:
somebody
2023-07-03 19:36:31 -05:00
parent 6f7e6422ef
commit bce1a907e5
4 changed files with 16 additions and 14 deletions

View File

@@ -182,6 +182,16 @@ class InferenceModel:
setattr(self, parameter, parameters[parameter])
return
def get_auxilary_device(self) -> Union[str, int, torch.device]:
"""Get device auxilary tensors like inputs should be stored on."""
# NOTE: TPU isn't a torch device, so TPU stuff gets sent to CPU.
if utils.koboldai_vars.hascuda and utils.koboldai_vars.usegpu:
return utils.koboldai_vars.gpu_device
elif utils.koboldai_vars.hascuda:
return "cuda"
return "cpu"
def load(self, save_model: bool = False, initial_load: bool = False) -> None:
"""User-facing load function. Do not override this; try `_load()` instead."""
@@ -301,7 +311,7 @@ class InferenceModel:
)
start_time = time.time()
gen_in = gen_in.to(utils.get_auxilary_device())
gen_in = gen_in.to(self.get_auxilary_device())
logger.debug(
"core_generate: gen_in to device time {}s".format(time.time() - start_time)