diff --git a/prompt_tuner.py b/prompt_tuner.py
index c6db1bfb..1c381f2b 100644
--- a/prompt_tuner.py
+++ b/prompt_tuner.py
@@ -965,7 +965,7 @@ class TrainerBase(abc.ABC):
 
             # Give the context to the model and compare the model's output logits with the labels to compute the loss
             logits = model(input_ids=input_ids, labels=input_ids).logits
-            loss: torch.Tensor = cross_entropy_loss(logits.view(-1, model.transformer.wte.weight.size(1)), labels.view(-1))
+            loss: torch.Tensor = cross_entropy_loss(logits.view(-1, model.transformer.wte.weight.size(0)), labels.view(-1))
             total_loss += loss.detach()
 
             # Compute the gradient of the loss function and add it to model.get_soft_params().grad (model.get_soft_params().grad += gradient)