From 111028642e2b3b4679ee68c28c91f8edb76966b1 Mon Sep 17 00:00:00 2001
From: somebody
Date: Mon, 1 May 2023 19:42:52 -0500
Subject: [PATCH] Fix tokenizer fallback for llama

---
 modeling/inference_model.py | 3 ++-
 modeling/tokenizer.py       | 5 ++++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/modeling/inference_model.py b/modeling/inference_model.py
index 7417f558..8d0c5294 100644
--- a/modeling/inference_model.py
+++ b/modeling/inference_model.py
@@ -223,7 +223,8 @@ class InferenceModel:
         for i, try_get_tokenizer in enumerate(suppliers):
             try:
                 return GenericTokenizer(try_get_tokenizer())
-            except:
+            except Exception as e:
+                logger.warn(f"Tokenizer falling back due to {e}")
                 # If we error on each attempt, raise the last one
                 if i == len(suppliers) - 1:
                     raise

diff --git a/modeling/tokenizer.py b/modeling/tokenizer.py
index 84555501..8c2dacf2 100644
--- a/modeling/tokenizer.py
+++ b/modeling/tokenizer.py
@@ -9,7 +9,10 @@ class GenericTokenizer:
     def __init__(self, tokenizer: Union[Tokenizer, PreTrainedTokenizer]) -> None:
         self.tokenizer = tokenizer
 
-        self.valid_tokens = set(self.tokenizer.vocab.values())
+        try:
+            self.valid_tokens = set(self.tokenizer.vocab.values())
+        except AttributeError:
+            self.valid_tokens = set(self.tokenizer.get_vocab().values())
 
     def __getattr__(self, name: str) -> Any:
         # Fall back to tokenizer for non-generic stuff
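
Note: a minimal standalone sketch of the fallback pattern this patch introduces,
for reference. It assumes a transformers-style tokenizer API; the helper name
collect_valid_tokens is hypothetical and not part of the patch itself.

    def collect_valid_tokens(tokenizer) -> set:
        # Some tokenizers expose a `.vocab` mapping directly, while others
        # (e.g. sentencepiece-based ones such as the Llama tokenizer) only
        # provide get_vocab(). Try the attribute first and fall back on
        # AttributeError, mirroring the patched GenericTokenizer.__init__.
        try:
            return set(tokenizer.vocab.values())
        except AttributeError:
            return set(tokenizer.get_vocab().values())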