Fix some remaining problems in prompt_tuner.py
parent 07eb2b5c4f
commit 624f916dc6
@@ -532,7 +532,7 @@ class TrainerBase(abc.ABC):
         with zipfile.ZipFile(output_file, "w", compression=zipfile.ZIP_LZMA) as z:
             with z.open("tensor.npy", "w") as f:
-                np.save(f, tensor, allow_pickle=False)
+                np.save(f, tensor.detach().cpu().numpy(), allow_pickle=False)
         with zipfile.ZipFile(output_file, "a", compression=zipfile.ZIP_STORED) as z:
             with z.open("meta.json", "w") as f:
                 f.write(json.dumps(meta, indent=2).encode("utf-8"))
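This hunk matters because np.save with allow_pickle=False only accepts data it can view as a plain NumPy array; a tensor that is still attached to the autograd graph, or that lives on a GPU, cannot be converted implicitly and raises a RuntimeError. A minimal sketch of the conversion, with a hypothetical soft-prompt shape:

    import numpy as np
    import torch

    tensor = torch.randn(20, 768, requires_grad=True)  # e.g. a trained soft prompt

    # np.save(f, tensor, allow_pickle=False) would fail here: NumPy cannot
    # implicitly convert a tensor that requires grad or sits on a GPU.
    with open("tensor.npy", "wb") as f:
        np.save(f, tensor.detach().cpu().numpy(), allow_pickle=False)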
@@ -555,7 +555,7 @@ class TrainerBase(abc.ABC):
             {
                 "metadata": {
                     "step": _step,
-                    "loss": float(z["loss"].item()),
+                    "loss": float(z["loss"]),
                     "uuid": str(uuid.uuid4()),
                     "name": soft_prompt_name,
                     "description": soft_prompt_description,
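Dropping .item() assumes z["loss"] is already a plain Python number at this point (e.g. parsed back from saved metadata) rather than a tensor: floats have no .item() method, while float() accepts floats, NumPy scalars, and single-element tensors alike.

    z = {"loss": 0.4213}  # hypothetical metrics dict holding a plain float

    # z["loss"].item() -> AttributeError: 'float' object has no attribute 'item'
    loss = float(z["loss"])  # works for float, NumPy scalar, or 0-d tensor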
@@ -563,7 +563,7 @@ class TrainerBase(abc.ABC):
                 },
                 "tensor": base64.b64encode(
                     pickle.dumps(
-                        tensor,
+                        tensor.detach().cpu(),
                         protocol=4,
                     ),
                 ).decode("ascii"),
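Same theme as the first hunk: pickling the live tensor would bake its device and autograd state into the checkpoint, so a file written on a GPU machine could not be unpickled on a CPU-only one. A sketch:

    import base64
    import pickle
    import torch

    tensor = torch.randn(20, 768, requires_grad=True)

    # detach().cpu() strips autograd metadata and moves the data to host
    # memory, so the resulting pickle loads anywhere, GPU or not.
    blob = base64.b64encode(pickle.dumps(tensor.detach().cpu(), protocol=4)).decode("ascii")

    restored = pickle.loads(base64.b64decode(blob))
    assert restored.device.type == "cpu" and not restored.requires_grad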
@@ -695,7 +695,7 @@ class TrainerBase(abc.ABC):
         if breakmodel_gpulayers is None:
             breakmodel_gpulayers = []
         if breakmodel_primary_device is None:
-            breakmodel_primary_device = 0 if sum(x if x >= 0 else 0 for x in breakmodel_gpulayers) else "cpu"
+            breakmodel_primary_device = 0 if sum(x if x >= 0 else 1 for x in breakmodel_gpulayers) else "cpu"

         if self.data.params is not None and "max_batch_size" not in self.data.params:
             self.data.params["max_batch_size"] = 2048
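In KoboldAI's breakmodel settings a negative layer count (conventionally -1) means "put all remaining layers on this GPU". The old `else 0` branch made a config like [-1] sum to zero, so the primary device silently fell back to "cpu"; counting negative entries as 1 makes any in-use GPU select device 0. The predicate in isolation (function name hypothetical):

    def pick_primary_device(gpulayers):
        # Negative entries mean "all remaining layers on this GPU", so they
        # must still count as GPU usage, hence `else 1` instead of `else 0`.
        return 0 if sum(x if x >= 0 else 1 for x in gpulayers) else "cpu"

    assert pick_primary_device([]) == "cpu"      # breakmodel unused
    assert pick_primary_device([0, 0]) == "cpu"  # GPUs listed, zero layers
    assert pick_primary_device([28, 4]) == 0     # explicit layer counts
    assert pick_primary_device([-1]) == 0        # old code returned "cpu" here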
@@ -744,13 +744,14 @@ class TrainerBase(abc.ABC):
             assert len(breakmodel_gpulayers) <= torch.cuda.device_count()
             assert sum(breakmodel_gpulayers) + breakmodel_disklayers <= n_layers

             breakmodel.gpu_blocks = breakmodel_gpulayers
             breakmodel.disk_blocks = breakmodel_disklayers
             disk_blocks = breakmodel.disk_blocks
             gpu_blocks = breakmodel.gpu_blocks
             ram_blocks = ram_blocks = n_layers - sum(gpu_blocks)
             cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))

-            device_list(n_layers, primary=breakmodel.primary_device)
+            device_list(ram_blocks, primary=breakmodel.primary_device)

             def lazy_load_callback(model_dict: Dict[str, Union[torch_lazy_loader.LazyTensor, torch.Tensor]], f, **_):
                 if lazy_load_callback.nested:
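Reading device_list as in the matching call in KoboldAI's aiserver.py, its first argument is the number of layers left in system RAM, not the model's total layer count, so passing n_layers overstated the RAM-resident layers whenever anything was offloaded. With hypothetical numbers:

    n_layers = 32                            # total transformer layers
    gpu_blocks = [20, 8]                     # layers assigned to GPU 0 and GPU 1
    ram_blocks = n_layers - sum(gpu_blocks)

    # The old call reported all 32 layers as RAM-resident; the fixed call
    # reports only the layers breakmodel actually keeps in RAM.
    assert ram_blocks == 4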
@@ -883,11 +884,11 @@ class TrainerBase(abc.ABC):
                 if("out of memory" in traceback.format_exc().lower()):
                     raise RuntimeError("One of your GPUs ran out of memory when KoboldAI tried to load your model.")
             model = GPTNeoPromptTuningLM.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")

         if(hascuda):
             if(usegpu):
                 model = model.half().to(gpu_device)
-            elif(breakmodel): # Use both RAM and VRAM (breakmodel)
+            elif(use_breakmodel): # Use both RAM and VRAM (breakmodel)
                 move_model_to_devices(model, usegpu, gpu_device)
             elif(__import__("breakmodel").disk_blocks > 0):
                 move_model_to_devices(model, usegpu, gpu_device)
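The rename fixes a name collision: the boolean flag breakmodel shadowed the breakmodel module in this scope, which is also why the following branch has to reach the module through __import__("breakmodel"). A runnable sketch of the hazard, using a stand-in module since the real one ships with KoboldAI:

    import types

    # Stand-in for KoboldAI's breakmodel module, just to keep this runnable.
    breakmodel = types.ModuleType("breakmodel")
    breakmodel.disk_blocks = 0

    def load(breakmodel):
        # The parameter shadows the module: breakmodel.disk_blocks would
        # raise AttributeError on the bool, forcing __import__("breakmodel").
        return bool(breakmodel)

    def load_fixed(use_breakmodel):
        # Distinct flag name, so the module stays reachable by its own name.
        return breakmodel.disk_blocks if use_breakmodel else 0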
@@ -1068,9 +1069,9 @@ class BasicTrainer(TrainerBase):
                 k for k in range(model.get_input_embeddings().weight.shape[-2]) if k not in special_tokens
             ]
             sample = rng.choice(sample_space, self.data.soft_in_dim, False)
-            return SoftPrompt.from_inputs_embeds(model.get_input_embeddings()(torch.tensor(sample, dtype=torch.int32)))
+            return SoftPrompt.from_inputs_embeds(model.get_input_embeddings()(torch.tensor(sample, dtype=torch.int32, device=model.get_input_embeddings().weight.device)))
         elif self.data.prompt_method == "tokens":
-            return SoftPrompt.from_inputs_embeds(model.get_input_embeddings()(torch.tensor(self.data.initial_softprompt, dtype=torch.int32)))
+            return SoftPrompt.from_inputs_embeds(model.get_input_embeddings()(torch.tensor(self.data.initial_softprompt, dtype=torch.int32, device=model.get_input_embeddings().weight.device)))
         self.raise_configuration_error(
             f"Unknown prompt method {repr(self.data.prompt_method)}", code=104
         )
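Both changed lines fix the same device mismatch: once the model has been moved to a GPU, indexing its embedding matrix with a freshly built CPU tensor raises "Expected all tensors to be on the same device". Building the index tensor on the embedding's own device works on GPU and CPU-only machines alike; a sketch with hypothetical sizes:

    import torch

    emb = torch.nn.Embedding(50257, 768)  # stand-in for model.get_input_embeddings()

    # device=emb.weight.device puts the indices wherever the weights live,
    # so the lookup works whether or not the model was moved to a GPU.
    ids = torch.tensor([10, 20, 30], dtype=torch.int32, device=emb.weight.device)
    vectors = emb(ids)  # (3, 768): one embedding row per sampled token id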