Fix typo in training routine of prompt_tuner.py
This commit is contained in:
parent
1e9f0e68a0
commit
aede7ef192
|
@ -965,7 +965,7 @@ class TrainerBase(abc.ABC):
|
|||
|
||||
# Give the context to the model and compare the model's output logits with the labels to compute the loss
|
||||
logits = model(input_ids=input_ids, labels=input_ids).logits
|
||||
loss: torch.Tensor = cross_entropy_loss(logits.view(-1, model.transformer.wte.weight.size(1)), labels.view(-1))
|
||||
loss: torch.Tensor = cross_entropy_loss(logits.view(-1, model.transformer.wte.weight.size(0)), labels.view(-1))
|
||||
total_loss += loss.detach()
|
||||
|
||||
# Compute the gradient of the loss function and add it to model.get_soft_params().grad (model.get_soft_params().grad += gradient)
|
||||
|
|
Loading…
Reference in New Issue