Fix a typo in tpu_mtj_backend.py

This commit is contained in:
Gnome Ann 2021-11-22 12:53:19 -05:00
parent d877190258
commit 691febacd6
1 changed files with 1 additions and 1 deletions

View File

@ -276,7 +276,7 @@ def infer(
) -> List[str]:
maps.thread_resources.env = thread_resources_env
total_batch = numseqs
tokens = np.uint32(tokenizer.encode(context, max_length=params["seq"] - soft_tokens.shape[0] if soft_tokens is not None else 0, truncation=True))
tokens = np.uint32(tokenizer.encode(context, max_length=params["seq"] - (soft_tokens.shape[0] if soft_tokens is not None else 0), truncation=True))
if(soft_tokens is not None):
tokens = np.uint32(np.concatenate((soft_tokens, tokens)))
provided_ctx = tokens.shape[0]