diff --git a/modeling/tokenizer.py b/modeling/tokenizer.py
index 26b619f6..e6e235d7 100644
--- a/modeling/tokenizer.py
+++ b/modeling/tokenizer.py
@@ -9,7 +9,6 @@ class GenericTokenizer:
 
     def __init__(self, tokenizer: Union[Tokenizer, PreTrainedTokenizer]) -> None:
         self.tokenizer = tokenizer
-        self.valid_tokens = set(self.tokenizer.vocab.values())
 
     def __getattr__(self, name: str) -> Any:
         # Fall back to tokenizer for non-generic stuff
@@ -29,16 +28,12 @@ class GenericTokenizer:
         return ret.ids
 
     def decode(self, tokens: Union[int, List[int], torch.Tensor]) -> str:
-        return self.tokenizer.decode(tokens)
-        if isinstance(tokens, torch.Tensor):
-            tokens = tokens.cpu().tolist()
-
         if isinstance(tokens, int):
             tokens = [tokens]
-
-        # Sometimes soft token placeholders aren't in the vocab, which causes
-        # errors on decode. Obviously we can't express these tokens as text so
-        # we can probably slice 'em out without too much issue
-        tokens = [t for t in tokens if t in self.valid_tokens]
+
+        if isinstance(tokens, list):
+            tokens = torch.tensor(tokens)
+        elif isinstance(tokens, torch.Tensor):
+            tokens = tokens.cpu()
 
         return self.tokenizer.decode(tokens)