Update aux device to depend on primary device
@@ -182,6 +182,16 @@ class InferenceModel:
             setattr(self, parameter, parameters[parameter])
         return
 
+    def get_auxilary_device(self) -> Union[str, int, torch.device]:
+        """Get device auxilary tensors like inputs should be stored on."""
+
+        # NOTE: TPU isn't a torch device, so TPU stuff gets sent to CPU.
+        if utils.koboldai_vars.hascuda and utils.koboldai_vars.usegpu:
+            return utils.koboldai_vars.gpu_device
+        elif utils.koboldai_vars.hascuda:
+            return "cuda"
+        return "cpu"
+
     def load(self, save_model: bool = False, initial_load: bool = False) -> None:
         """User-facing load function. Do not override this; try `_load()` instead."""
 
@@ -301,7 +311,7 @@ class InferenceModel:
         )
 
         start_time = time.time()
-        gen_in = gen_in.to(utils.get_auxilary_device())
+        gen_in = gen_in.to(self.get_auxilary_device())
 
         logger.debug(
             "core_generate: gen_in to device time {}s".format(time.time() - start_time)
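For context, below is a minimal, self-contained sketch of the pattern this commit introduces: the model instance decides which device auxiliary tensors (such as generation inputs) live on, and call sites move tensors via that instance method instead of a module-level helper. The DummyModel class, its use_gpu/gpu_device attributes, and the use of torch.cuda.is_available() in place of utils.koboldai_vars.hascuda/usegpu are illustrative assumptions, not KoboldAI's actual code.

from typing import Union

import torch


class DummyModel:
    """Illustrative stand-in for InferenceModel; not KoboldAI's real class."""

    def __init__(self, use_gpu: bool, gpu_device: int = 0) -> None:
        # Simplified stand-ins for utils.koboldai_vars.usegpu / gpu_device.
        self.use_gpu = use_gpu
        self.gpu_device = gpu_device

    def get_auxilary_device(self) -> Union[str, int, torch.device]:
        """Pick the device that auxiliary tensors (e.g. inputs) are stored on."""
        if torch.cuda.is_available() and self.use_gpu:
            return self.gpu_device  # user-selected GPU index
        elif torch.cuda.is_available():
            return "cuda"  # default CUDA device
        return "cpu"  # CPU fallback (TPU setups also fall through to CPU)


model = DummyModel(use_gpu=False)
gen_in = torch.randint(0, 50_000, (1, 8))  # fake token ids
gen_in = gen_in.to(model.get_auxilary_device())  # mirrors the core_generate change
print(gen_in.device)  # prints "cpu" on a machine without CUDA

The point of the change is visible in the last two code lines: because the device choice is a method on the model object rather than a free function in utils, the call site in core_generate stays the same while subclasses or differently configured models can supply their own device.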