Model: Add GenericTokenizer

Because Hugging Face doesn't have a consistent API across their own
libraries
This commit is contained in:
somebody
2023-03-13 17:36:58 -05:00
parent 60793eb121
commit bf8b60ac2d
2 changed files with 26 additions and 2 deletions

View File

@@ -12,6 +12,7 @@ from transformers import (
GPT2Tokenizer, GPT2Tokenizer,
AutoTokenizer, AutoTokenizer,
) )
from modeling.tokenizer import GenericTokenizer
import utils import utils
@@ -180,7 +181,7 @@ class InferenceModel:
selected device(s) and preparing it for inference should be implemented here.""" selected device(s) and preparing it for inference should be implemented here."""
raise NotImplementedError raise NotImplementedError
def _get_tokenizer(self, location: str) -> AutoTokenizer: def _get_tokenizer(self, location: str) -> GenericTokenizer:
"""Returns the appropiate tokenizer for the location. Should be ran once and result stored in `tokenizer`. """Returns the appropiate tokenizer for the location. Should be ran once and result stored in `tokenizer`.
Args: Args:
@@ -214,7 +215,7 @@ class InferenceModel:
for i, try_get_tokenizer in enumerate(suppliers): for i, try_get_tokenizer in enumerate(suppliers):
try: try:
return try_get_tokenizer() return GenericTokenizer(try_get_tokenizer())
except: except:
# If we error on each attempt, raise the last one # If we error on each attempt, raise the last one
if i == len(suppliers) - 1: if i == len(suppliers) - 1:

23
modeling/tokenizer.py Normal file
View File

@@ -0,0 +1,23 @@
from typing import List, Union
from tokenizers import Tokenizer
import torch
from transformers import PreTrainedTokenizer
class GenericTokenizer:
    """Bridges the gap between Transformers tokenizers and Tokenizers tokenizers. Why they aren't the same, I don't know.

    `tokenizers.Tokenizer.encode` returns an `Encoding` object (token ids under
    `.ids`), while `transformers.PreTrainedTokenizer.encode` returns a plain
    `list[int]`. This wrapper normalizes both to plain lists of ints.
    """

    def __init__(self, tokenizer: "Union[Tokenizer, PreTrainedTokenizer]") -> None:
        # The wrapped tokenizer; may come from either the `tokenizers` or the
        # `transformers` library.
        self.tokenizer = tokenizer

        # TODO: Get rid of this
        self._koboldai_header = []

    def encode(self, text: str) -> List[int]:
        """Return the token ids for `text` as a plain list of ints.

        Handles both tokenizer flavors: `tokenizers.Tokenizer.encode` yields an
        `Encoding` (ids under `.ids`); `transformers` tokenizers already return
        a list of ids, which would raise AttributeError on `.ids`.
        """
        encoded = self.tokenizer.encode(text)
        return encoded.ids if hasattr(encoded, "ids") else encoded

    def decode(self, tokens: "Union[int, List[int], torch.Tensor]") -> str:
        """Decode a single token id, a list of ids, or a tensor of ids to text.

        Tensors are converted to plain lists first because
        `tokenizers.Tokenizer.decode` only accepts sequences of ints.
        """
        if isinstance(tokens, int):
            tokens = [tokens]
        elif isinstance(tokens, torch.Tensor):
            tokens = tokens.tolist()
        return self.tokenizer.decode(tokens)