Merge pull request #36 from VE-FORBRYDERNE/sp

Fix a typo in tpu_mtj_backend.py
This commit is contained in:
henk717 2021-11-23 14:23:10 +01:00 committed by GitHub
commit c0df03fc55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -276,7 +276,7 @@ def infer(
) -> List[str]:
maps.thread_resources.env = thread_resources_env
total_batch = numseqs
tokens = np.uint32(tokenizer.encode(context, max_length=params["seq"] - soft_tokens.shape[0] if soft_tokens is not None else 0, truncation=True))
tokens = np.uint32(tokenizer.encode(context, max_length=params["seq"] - (soft_tokens.shape[0] if soft_tokens is not None else 0), truncation=True))
if(soft_tokens is not None):
tokens = np.uint32(np.concatenate((soft_tokens, tokens)))
provided_ctx = tokens.shape[0]