Potential fix

This commit is contained in:
somebody
2023-04-27 19:51:10 -05:00
parent b256a8fbc7
commit 4559112551
2 changed files with 1 addition and 3 deletions

View File

@@ -122,7 +122,6 @@ class Stoppers:
         input_ids: torch.LongTensor,
     ) -> bool:
-        print(f"[stop_sequence_stopper] Input ids: {input_ids}")
         data = [model.tokenizer.decode(x) for x in input_ids]
         # null_character = model.tokenizer.encode(chr(0))[0]
         if "completed" not in model.gen_state:

View File

@@ -28,11 +28,10 @@ class GenericTokenizer:
         return ret.ids

     def decode(self, tokens: Union[int, List[int], torch.Tensor]) -> str:
-        print(f"[decode] Tokens: {tokens}")
         if isinstance(tokens, torch.Tensor):
             tokens = tokens.cpu().tolist()
         if isinstance(tokens, int):
             tokens = [tokens]
-        return self.tokenizer.decode(tokens)
+        return self.tokenizer.decode(tokens, skip_special_tokens=True)