Fix tokenizer fallback for llama

somebody
2023-05-01 19:42:52 -05:00
parent f6b5548131
commit 111028642e
2 changed files with 6 additions and 2 deletions


@@ -223,7 +223,8 @@ class InferenceModel:
         for i, try_get_tokenizer in enumerate(suppliers):
             try:
                 return GenericTokenizer(try_get_tokenizer())
-            except:
+            except Exception as e:
+                logger.warn(f"Tokenizer falling back due to {e}")
                 # If we error on each attempt, raise the last one
                 if i == len(suppliers) - 1:
                     raise
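
For context, a minimal runnable sketch of the supplier-fallback pattern this hunk implements: each supplier callable is tried in turn, failures are logged, and only the last failure is re-raised. GenericTokenizer and the supplier functions below are hypothetical stand-ins, not the project's real definitions.

import logging

logger = logging.getLogger(__name__)

class GenericTokenizer:
    # Stand-in for the project's wrapper class.
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

def get_tokenizer(suppliers):
    """Try each tokenizer supplier in order, logging each failure."""
    for i, try_get_tokenizer in enumerate(suppliers):
        try:
            return GenericTokenizer(try_get_tokenizer())
        except Exception as e:
            # `except Exception` (unlike the old bare `except:`) lets
            # KeyboardInterrupt and SystemExit propagate untouched.
            logger.warning("Tokenizer falling back due to %s", e)
            # If we error on each attempt, raise the last one
            if i == len(suppliers) - 1:
                raise

# Usage: the first supplier fails, so the loop falls through to the second.
def broken_supplier():
    raise OSError("no local vocab file")

def working_supplier():
    return "fallback-tokenizer"

tok = get_tokenizer([broken_supplier, working_supplier])
assert tok.tokenizer == "fallback-tokenizer"

The sketch uses logger.warning; the commit's logger.warn is a deprecated alias for the same method in the stdlib logging module.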


@@ -9,7 +9,10 @@ class GenericTokenizer:
     def __init__(self, tokenizer: Union[Tokenizer, PreTrainedTokenizer]) -> None:
         self.tokenizer = tokenizer
-        self.valid_tokens = set(self.tokenizer.vocab.values())
+        try:
+            self.valid_tokens = set(self.tokenizer.vocab.values())
+        except AttributeError:
+            self.valid_tokens = set(self.tokenizer.get_vocab().values())
 
     def __getattr__(self, name: str) -> Any:
         # Fall back to tokenizer for non-generic stuff
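
This second hunk is the actual llama fix: some tokenizer objects expose a .vocab mapping, while others (such as LLaMA's sentencepiece-backed tokenizer) only provide a get_vocab() method, so the vocab lookup now falls back on AttributeError. A small self-contained sketch of the same pattern, using hypothetical stand-in classes rather than the real transformers/tokenizers types:

# Stand-ins for the two tokenizer shapes the patch has to handle.
class VocabAttrTokenizer:
    # Exposes the vocabulary as a .vocab attribute.
    vocab = {"<s>": 0, "hello": 1}

class GetVocabTokenizer:
    # Exposes the vocabulary only via get_vocab().
    def get_vocab(self):
        return {"<s>": 0, "hello": 1}

def valid_token_ids(tokenizer):
    # Mirrors the patched __init__: try the attribute first, then
    # fall back to the method on AttributeError.
    try:
        return set(tokenizer.vocab.values())
    except AttributeError:
        return set(tokenizer.get_vocab().values())

assert valid_token_ids(VocabAttrTokenizer()) == {0, 1}
assert valid_token_ids(GetVocabTokenizer()) == {0, 1}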