Implement softprompt hack

one-some
2023-04-28 10:26:59 -05:00
parent fa6bb4b956
commit 455b8257a9


@@ -9,6 +9,7 @@ class GenericTokenizer:
     def __init__(self, tokenizer: Union[Tokenizer, PreTrainedTokenizer]) -> None:
         self.tokenizer = tokenizer
+        self.valid_tokens = set(self.tokenizer.vocab.values())
 
     def __getattr__(self, name: str) -> Any:
         # Fall back to tokenizer for non-generic stuff
@@ -28,12 +29,15 @@ class GenericTokenizer:
         return ret.ids
 
     def decode(self, tokens: Union[int, List[int], torch.Tensor]) -> str:
-        # TODO: Figure out why this breaks softprompts on some models
-        # if isinstance(tokens, int):
-        #     tokens = [tokens]
-        # if isinstance(tokens, list):
-        #     tokens = torch.tensor(tokens)
-        # elif isinstance(tokens, torch.Tensor):
-        #     tokens = tokens.cpu()
+        if isinstance(tokens, torch.Tensor):
+            tokens = tokens.cpu().tolist()
+
+        if isinstance(tokens, int):
+            tokens = [tokens]
+
+        # HACK: Sometimes soft token placeholders aren't in the vocab, which
+        # causes errors on decode. Obviously we can't express these tokens as
+        # text so we can probably slice 'em out without too much issue.
+        tokens = [t for t in tokens if t in self.valid_tokens]
         return self.tokenizer.decode(tokens)
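
For reference, a minimal, self-contained sketch of the decode path after this commit. The method bodies are reconstructed from the diff above; the imports, the PreTrainedTokenizer type, and the rest of GenericTokenizer are assumed from context and only the parts this commit touches are shown.

    from typing import Any, List, Union

    import torch
    from transformers import PreTrainedTokenizer


    class GenericTokenizer:
        # Sketch: only the members touched by this commit are shown.
        def __init__(self, tokenizer: PreTrainedTokenizer) -> None:
            self.tokenizer = tokenizer
            # Cache the set of real token ids so out-of-vocab soft prompt
            # placeholders can be detected cheaply at decode time.
            self.valid_tokens = set(self.tokenizer.vocab.values())

        def decode(self, tokens: Union[int, List[int], torch.Tensor]) -> str:
            # Normalize tensors and bare ints to a plain list of ids.
            if isinstance(tokens, torch.Tensor):
                tokens = tokens.cpu().tolist()
            if isinstance(tokens, int):
                tokens = [tokens]
            # Drop placeholder ids that have no vocab entry; they have no
            # textual form, so they are simply omitted from the output.
            tokens = [t for t in tokens if t in self.valid_tokens]
            return self.tokenizer.decode(tokens)

The design trade-off is that decoding now silently omits the soft prompt placeholders instead of raising on ids the vocab doesn't contain; the previously commented-out tensor handling is replaced by an explicit normalization to a plain Python list before the filter runs.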