mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Actually fix decoding with soft prompts
it really wants a tensor
This commit is contained in:
@@ -9,7 +9,6 @@ class GenericTokenizer:
|
||||
|
||||
def __init__(self, tokenizer: Union[Tokenizer, PreTrainedTokenizer]) -> None:
|
||||
self.tokenizer = tokenizer
|
||||
self.valid_tokens = set(self.tokenizer.vocab.values())
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
# Fall back to tokenizer for non-generic stuff
|
||||
@@ -29,16 +28,12 @@ class GenericTokenizer:
|
||||
return ret.ids
|
||||
|
||||
def decode(self, tokens: Union[int, List[int], torch.Tensor]) -> str:
|
||||
return self.tokenizer.decode(tokens)
|
||||
if isinstance(tokens, torch.Tensor):
|
||||
tokens = tokens.cpu().tolist()
|
||||
|
||||
if isinstance(tokens, int):
|
||||
tokens = [tokens]
|
||||
|
||||
# Sometimes soft token placeholders aren't in the vocab, which causes
|
||||
# errors on decode. Obviously we can't express these tokens as text so
|
||||
# we can probably slice 'em out without too much issue
|
||||
tokens = [t for t in tokens if t in self.valid_tokens]
|
||||
|
||||
if isinstance(tokens, list):
|
||||
tokens = torch.tensor(tokens)
|
||||
elif isinstance(tokens, torch.Tensor):
|
||||
tokens = tokens.cpu()
|
||||
|
||||
return self.tokenizer.decode(tokens)
|
||||
|
Reference in New Issue
Block a user