Fix for tokenizer stuff on pythia

This commit is contained in:
somebody
2023-04-09 18:23:58 -05:00
parent 3e8e3a18b0
commit 334c09606b

View File

@@ -22,9 +22,10 @@ class GenericTokenizer:
setattr(self.tokenizer, name, value)
def encode(self, text: str) -> list:
if isinstance(self.tokenizer, PreTrainedTokenizer):
return self.tokenizer.encode(text)
return self.tokenizer.encode(text).ids
ret = self.tokenizer.encode(text)
if isinstance(ret, list):
return ret
return ret.ids
def decode(self, tokens: Union[int, List[int], torch.Tensor]) -> str:
if isinstance(tokens, torch.Tensor):