Fix typo in training routine of prompt_tuner.py

This commit is contained in:
vfbd 2022-08-22 21:38:13 -04:00
parent 1e9f0e68a0
commit aede7ef192
1 changed files with 1 additions and 1 deletions

View File

@ -965,7 +965,7 @@ class TrainerBase(abc.ABC):
# Give the context to the model and compare the model's output logits with the labels to compute the loss
logits = model(input_ids=input_ids, labels=input_ids).logits
loss: torch.Tensor = cross_entropy_loss(logits.view(-1, model.transformer.wte.weight.size(1)), labels.view(-1))
loss: torch.Tensor = cross_entropy_loss(logits.view(-1, model.transformer.wte.weight.size(0)), labels.view(-1))
total_loss += loss.detach()
# Compute the gradient of the loss function and add it to model.get_soft_params().grad (model.get_soft_params().grad += gradient)