Fix a typo in tpu_mtj_backend.py
This commit is contained in:
parent
d877190258
commit
691febacd6
|
@ -276,7 +276,7 @@ def infer(
|
|||
) -> List[str]:
|
||||
maps.thread_resources.env = thread_resources_env
|
||||
total_batch = numseqs
|
||||
tokens = np.uint32(tokenizer.encode(context, max_length=params["seq"] - soft_tokens.shape[0] if soft_tokens is not None else 0, truncation=True))
|
||||
tokens = np.uint32(tokenizer.encode(context, max_length=params["seq"] - (soft_tokens.shape[0] if soft_tokens is not None else 0), truncation=True))
|
||||
if(soft_tokens is not None):
|
||||
tokens = np.uint32(np.concatenate((soft_tokens, tokens)))
|
||||
provided_ctx = tokens.shape[0]
|
||||
|
|
Loading…
Reference in New Issue