From 455b8257a96547e1f12b07beca4c864887d930a3 Mon Sep 17 00:00:00 2001
From: one-some
Date: Fri, 28 Apr 2023 10:26:59 -0500
Subject: [PATCH] Implement softprompt hack

---
 modeling/tokenizer.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/modeling/tokenizer.py b/modeling/tokenizer.py
index 323c3885..84555501 100644
--- a/modeling/tokenizer.py
+++ b/modeling/tokenizer.py
@@ -9,6 +9,7 @@ class GenericTokenizer:
 
     def __init__(self, tokenizer: Union[Tokenizer, PreTrainedTokenizer]) -> None:
         self.tokenizer = tokenizer
+        self.valid_tokens = set(self.tokenizer.vocab.values())
 
     def __getattr__(self, name: str) -> Any:
         # Fall back to tokenizer for non-generic stuff
@@ -28,12 +29,15 @@ class GenericTokenizer:
         return ret.ids
 
     def decode(self, tokens: Union[int, List[int], torch.Tensor]) -> str:
-        # TODO: Figure out why this breaks softprompts on some models
-        # if isinstance(tokens, int):
-        #     tokens = [tokens]
-        # if isinstance(tokens, list):
-        #     tokens = torch.tensor(tokens)
-        # elif isinstance(tokens, torch.Tensor):
-        #     tokens = tokens.cpu()
+        if isinstance(tokens, torch.Tensor):
+            tokens = tokens.cpu().tolist()
+
+        if isinstance(tokens, int):
+            tokens = [tokens]
+
+        # HACK: Sometimes soft token placeholders aren't in the vocab, which
+        # causes errors on decode. Obviously we can't express these tokens as
+        # text so we can probably slice 'em out without too much issue.
+        tokens = [t for t in tokens if t in self.valid_tokens]
 
         return self.tokenizer.decode(tokens)
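
A minimal sketch of the same filtering idea outside the class, assuming a Hugging Face tokenizer; the "gpt2" checkpoint, the out-of-range placeholder ID, and the get_vocab() call are illustrative assumptions (the patch reads the same mapping via tokenizer.vocab), not part of the patch itself:

# Illustrative sketch only: "gpt2" and the placeholder ID 50257 are assumptions.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
valid_tokens = set(tokenizer.get_vocab().values())  # mirrors self.valid_tokens in the patch

# Pretend generation output that ends with a soft-prompt placeholder ID which
# has no entry in the vocab (GPT-2's vocab covers IDs 0..50256).
tokens = tokenizer.encode("Hello world") + [50257]

# Decoding the raw list can error out on some tokenizers/models; dropping
# unknown IDs first is the hack the patch applies inside decode().
safe = [t for t in tokens if t in valid_tokens]
print(tokenizer.decode(safe))  # "Hello world", with the placeholder dropped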