Fix OOM when loading large model split across GPUs
@@ -139,10 +139,8 @@ class HFTorch4BitInferenceModel(HFTorchInferenceModel):
             self.gpu_layers_list = [int(l) for l in gpulayers.split(",")]
         except ValueError:
             self.gpu_layers_list = [utils.num_layers(self.model_config)]
-        self.offload_4bit = sum(self.gpu_layers_list) < utils.num_layers(self.model_config)
 
-        if self.offload_4bit:
-            utils.koboldai_vars.lazy_load = False
+        if sum(self.gpu_layers_list) < utils.num_layers(self.model_config):
             print("4-bit CPU offloader active")
 
         tf_kwargs = {
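A minimal sketch of the check this hunk inlines (the standalone helper name and signature are illustrative, not part of the commit; gpulayers is assumed to be a comma-separated string such as "28,4"):

def needs_cpu_offload(gpulayers: str, total_layers: int) -> bool:
    # Parse the per-GPU layer counts; fall back to "all layers on GPU"
    # when the string is not a valid comma-separated list of integers.
    try:
        gpu_layers_list = [int(l) for l in gpulayers.split(",")]
    except ValueError:
        gpu_layers_list = [total_layers]
    # Offloading is only needed when the assigned GPU layers do not
    # cover every layer of the model.
    return sum(gpu_layers_list) < total_layers

For example, needs_cpu_offload("28,4", 40) is True (8 layers stay on the CPU), while needs_cpu_offload("20,20", 40) is False.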
@@ -343,9 +341,6 @@ class HFTorch4BitInferenceModel(HFTorchInferenceModel):
 
         self.patch_embedding()
 
-        if not self.offload_4bit:
-            self.model = self.model.half().to(utils.koboldai_vars.gpu_device)
-
         self.model.kai_model = self
         utils.koboldai_vars.modeldim = self.get_hidden_size()
 
@@ -375,7 +370,7 @@ class HFTorch4BitInferenceModel(HFTorchInferenceModel):
         else:
             raise RuntimeError(f"4-bit load failed. Model type {utils.koboldai_vars.model_type} not supported in 4-bit")
 
-        return model.half() if not self.offload_4bit else model
+        return model
 
     def _get_tokenizer(self, location: str):
         if utils.koboldai_vars.model_type == "llama":
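For context on the OOM fix: the removed self.model.half().to(utils.koboldai_vars.gpu_device) call and the plain "return model" above no longer force the freshly loaded model onto a single GPU, which appears to be what ran out of memory when the layers were split across devices. A hedged sketch of that idea, using a hypothetical helper that is not part of the commit:

import torch

def finalize_model(model: torch.nn.Module, gpu_device) -> torch.nn.Module:
    # Devices the loaded parameters already occupy (e.g. several GPUs
    # and/or the CPU when layers were split at load time).
    devices = {p.device for p in model.parameters()}
    if len(devices) > 1:
        # The model is already sharded; moving it wholesale to one GPU
        # would require the full model's memory on that single device.
        return model
    # A model living on a single device can still be cast and moved.
    return model.half().to(gpu_device)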