Fix soft prompt length calculation in `calcsubmitbudget()`

On TPU instances, `vars.sp.shape[0]` is not always the actual number of
tokens in the soft prompt, so we have to use `vars.sp_length` to get an
accurate token count.
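
A minimal sketch of why the leading dimension can differ from the token count. It assumes, purely for illustration, that the soft prompt is padded and split into one shard per core on TPU; the commit does not say this. `shard_soft_prompt`, `cores`, and `padded_rows` are hypothetical names, and only `sp_length` mirrors `vars.sp_length`.

```python
# Hypothetical illustration only: shard_soft_prompt, cores and padded_rows are
# made-up names; sp_length mirrors vars.sp_length from the commit message.

def shard_soft_prompt(rows, cores, padded_rows):
    """Zero-pad a list of embedding rows to padded_rows, then split it into
    `cores` equally sized shards (assumed layout, not taken from the source)."""
    if padded_rows < len(rows) or padded_rows % cores != 0:
        raise ValueError("padded_rows must be >= len(rows) and divisible by cores")
    width = len(rows[0])
    padded = rows + [[0.0] * width for _ in range(padded_rows - len(rows))]
    per_core = padded_rows // cores
    return [padded[i * per_core:(i + 1) * per_core] for i in range(cores)]

# A soft prompt with 3 tokens, each embedded as a vector of width 2.
soft_prompt = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]

# The real token count is recorded before any padding or sharding.
sp_length = len(soft_prompt)

sharded = shard_soft_prompt(soft_prompt, cores=4, padded_rows=8)

# After sharding, the leading dimension counts shards, not tokens.
leading_dim = len(sharded)

print("leading dimension:", leading_dim)
print("actual soft prompt tokens:", sp_length)
```

Under that assumed layout, a token budget derived from the leading dimension would reserve the wrong number of tokens, while a separately stored length stays correct.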
Gnome Ann 2022-01-17 13:17:20 -05:00
parent 74f79081d1
commit 9594b2db1c
1 changed file with 1 addition and 1 deletion

@@ -2494,7 +2494,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
     anotetkns = [] # Placeholder for Author's Note tokens
     lnanote = 0 # Placeholder for Author's Note length
-    lnsp = vars.sp.shape[0] if vars.sp is not None else 0
+    lnsp = vars.sp_length
     if("tokenizer" not in globals()):
         from transformers import GPT2TokenizerFast