From aede7ef192a99e3ad6381342a63f3b7b3c653eca Mon Sep 17 00:00:00 2001 From: vfbd Date: Mon, 22 Aug 2022 21:38:13 -0400 Subject: [PATCH] Fix typo in training routine of prompt_tuner.py --- prompt_tuner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prompt_tuner.py b/prompt_tuner.py index c6db1bfb..1c381f2b 100644 --- a/prompt_tuner.py +++ b/prompt_tuner.py @@ -965,7 +965,7 @@ class TrainerBase(abc.ABC): # Give the context to the model and compare the model's output logits with the labels to compute the loss logits = model(input_ids=input_ids, labels=input_ids).logits - loss: torch.Tensor = cross_entropy_loss(logits.view(-1, model.transformer.wte.weight.size(1)), labels.view(-1)) + loss: torch.Tensor = cross_entropy_loss(logits.view(-1, model.transformer.wte.weight.size(0)), labels.view(-1)) total_loss += loss.detach() # Compute the gradient of the loss function and add it to model.get_soft_params().grad (model.get_soft_params().grad += gradient)