Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
47 lines · 1.7 KiB · Python
from typing import Any, List, Union

from tokenizers import Tokenizer
import torch
from transformers import PreTrainedTokenizer


class GenericTokenizer:
    """Bridges the gap between Transformers tokenizers and Tokenizers tokenizers. Why they aren't the same, I don't know."""

    def __init__(self, tokenizer: Union[Tokenizer, PreTrainedTokenizer]) -> None:
        self.tokenizer = tokenizer
        try:
            self.valid_tokens = set(self.tokenizer.vocab.values())
        except AttributeError:
            self.valid_tokens = set(self.tokenizer.get_vocab().values())

    def __getattr__(self, name: str) -> Any:
        # Fall back to tokenizer for non-generic stuff
        return getattr(self.tokenizer, name)

    def __setattr__(self, name: str, value: Any) -> None:
        # To prevent infinite recursion on __init__ setting
        if name == "tokenizer":
            super().__setattr__(name, value)
            return
        setattr(self.tokenizer, name, value)

    def encode(self, text: str) -> list:
        ret = self.tokenizer.encode(text)
        if isinstance(ret, list):
            return ret
        return ret.ids

    def decode(self, tokens: Union[int, List[int], torch.Tensor]) -> str:
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.cpu().tolist()

        if isinstance(tokens, int):
            tokens = [tokens]

        # HACK: Sometimes soft token placeholders aren't in the vocab, which
        # causes errors on decode. Obviously we can't express these tokens as
        # text so we can probably slice 'em out without too much issue.
        tokens = [t for t in tokens if t in self.valid_tokens]

        return self.tokenizer.decode(tokens)
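

# --- Usage sketch (not part of the original file) ---
# A minimal example of how the wrapper is meant to be used, kept as comments so
# the module's behavior is unchanged. The "gpt2" checkpoint is only an
# illustrative assumption and loading it needs network access or a local cache.
#
#     from transformers import AutoTokenizer
#
#     tokenizer = GenericTokenizer(AutoTokenizer.from_pretrained("gpt2"))
#     ids = tokenizer.encode("Hello world")        # always a plain list of ints
#     text = tokenizer.decode(torch.tensor(ids))   # tensors and single ints work too
#     size = tokenizer.vocab_size                  # unknown attributes fall through
#                                                  # to the wrapped tokenizer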