Actually fix decoding with soft prompts

The wrapped tokenizer's decode really wants a tensor.
somebody
2023-04-27 21:01:12 -05:00
parent ffa7b22734
commit 2eee535540


@@ -9,7 +9,6 @@ class GenericTokenizer:
     def __init__(self, tokenizer: Union[Tokenizer, PreTrainedTokenizer]) -> None:
         self.tokenizer = tokenizer
-        self.valid_tokens = set(self.tokenizer.vocab.values())
 
     def __getattr__(self, name: str) -> Any:
         # Fall back to tokenizer for non-generic stuff
         return getattr(self.tokenizer, name)
@@ -29,16 +28,12 @@ class GenericTokenizer:
         return ret.ids
 
     def decode(self, tokens: Union[int, List[int], torch.Tensor]) -> str:
-        if isinstance(tokens, torch.Tensor):
-            tokens = tokens.cpu().tolist()
-
         if isinstance(tokens, int):
             tokens = [tokens]
-
-        # Sometimes soft token placeholders aren't in the vocab, which causes
-        # errors on decode. Obviously we can't express these tokens as text so
-        # we can probably slice 'em out without too much issue
-        tokens = [t for t in tokens if t in self.valid_tokens]
+        if isinstance(tokens, list):
+            tokens = torch.tensor(tokens)
+        elif isinstance(tokens, torch.Tensor):
+            tokens = tokens.cpu()
 
         return self.tokenizer.decode(tokens)
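For context, a minimal self-contained sketch of the decode path this commit leaves behind. DummyTokenizer is a hypothetical stand-in for the wrapped Tokenizers/Transformers tokenizer (not part of the repo); the standalone decode function mirrors the normalization in the new GenericTokenizer.decode:

from typing import List, Union

import torch

class DummyTokenizer:
    # Hypothetical stand-in for the wrapped tokenizer, illustration only.
    vocab = {"hello": 0, "world": 1}

    def decode(self, tokens: torch.Tensor) -> str:
        # A real tokenizer maps ids back to text; this stub inverts its tiny vocab.
        inverse = {v: k for k, v in self.vocab.items()}
        return " ".join(inverse[int(t)] for t in tokens)

def decode(tokenizer, tokens: Union[int, List[int], torch.Tensor]) -> str:
    # Same shape as the new GenericTokenizer.decode: whatever comes in,
    # a CPU tensor goes out to the wrapped tokenizer.
    if isinstance(tokens, int):
        tokens = [tokens]
    if isinstance(tokens, list):
        tokens = torch.tensor(tokens)
    elif isinstance(tokens, torch.Tensor):
        tokens = tokens.cpu()
    return tokenizer.decode(tokens)

print(decode(DummyTokenizer(), 0))                     # hello
print(decode(DummyTokenizer(), [0, 1]))                # hello world
print(decode(DummyTokenizer(), torch.tensor([1, 0])))  # world hello

The int and list branches exist because, per the commit message, the underlying decode wants a tensor rather than raw Python ids, and the .cpu() call keeps a CUDA-resident tensor from reaching it. Note the valid_tokens filtering from the parent commit is gone entirely: out-of-vocab soft-prompt placeholder ids are no longer sliced out before decoding.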