diff --git a/modeling/inference_model.py b/modeling/inference_model.py index 7417f558..8d0c5294 100644 --- a/modeling/inference_model.py +++ b/modeling/inference_model.py @@ -223,7 +223,8 @@ class InferenceModel: for i, try_get_tokenizer in enumerate(suppliers): try: return GenericTokenizer(try_get_tokenizer()) - except: + except Exception as e: + logger.warning(f"Tokenizer falling back due to {e}") # If we error on each attempt, raise the last one if i == len(suppliers) - 1: raise diff --git a/modeling/tokenizer.py b/modeling/tokenizer.py index 84555501..8c2dacf2 100644 --- a/modeling/tokenizer.py +++ b/modeling/tokenizer.py @@ -9,7 +9,10 @@ class GenericTokenizer: def __init__(self, tokenizer: Union[Tokenizer, PreTrainedTokenizer]) -> None: self.tokenizer = tokenizer - self.valid_tokens = set(self.tokenizer.vocab.values()) + try: + self.valid_tokens = set(self.tokenizer.vocab.values()) + except AttributeError: + self.valid_tokens = set(self.tokenizer.get_vocab().values()) def __getattr__(self, name: str) -> Any: # Fall back to tokenizer for non-generic stuff