From bf8b60ac2db2f9280645c0cdbd1e87aa5c5d6157 Mon Sep 17 00:00:00 2001 From: somebody Date: Mon, 13 Mar 2023 17:36:58 -0500 Subject: [PATCH] Model: Add GenericTokenizer Because Hugging Face doesn't have a consistent API across their own libraries --- modeling/inference_model.py | 5 +++-- modeling/tokenizer.py | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) create mode 100644 modeling/tokenizer.py diff --git a/modeling/inference_model.py b/modeling/inference_model.py index 72fa3314..6a08b0d5 100644 --- a/modeling/inference_model.py +++ b/modeling/inference_model.py @@ -12,6 +12,7 @@ from transformers import ( GPT2Tokenizer, AutoTokenizer, ) +from modeling.tokenizer import GenericTokenizer import utils @@ -180,7 +181,7 @@ class InferenceModel: selected device(s) and preparing it for inference should be implemented here.""" raise NotImplementedError - def _get_tokenizer(self, location: str) -> AutoTokenizer: + def _get_tokenizer(self, location: str) -> GenericTokenizer: """Returns the appropiate tokenizer for the location. Should be ran once and result stored in `tokenizer`. Args: @@ -214,7 +215,7 @@ class InferenceModel: for i, try_get_tokenizer in enumerate(suppliers): try: - return try_get_tokenizer() + return GenericTokenizer(try_get_tokenizer()) except: # If we error on each attempt, raise the last one if i == len(suppliers) - 1: diff --git a/modeling/tokenizer.py b/modeling/tokenizer.py new file mode 100644 index 00000000..2bf162d7 --- /dev/null +++ b/modeling/tokenizer.py @@ -0,0 +1,23 @@ +from typing import List, Union +from tokenizers import Tokenizer +import torch +from transformers import PreTrainedTokenizer + + +class GenericTokenizer: +    """Bridges the gap between Transformers tokenizers and Tokenizers tokenizers. 
Why they aren't the same, I don't know.""" + + def __init__(self, tokenizer: Union[Tokenizer, PreTrainedTokenizer]) -> None: + self.tokenizer = tokenizer + + # TODO: Get rid of this + self._koboldai_header = [] + + def encode(self, text: str) -> list: + return self.tokenizer.encode(text).ids + + def decode(self, tokens: Union[int, List[int], torch.Tensor]) -> str: + if isinstance(tokens, int): + tokens = [tokens] + + return self.tokenizer.decode(tokens)