From 91bb433b5f302fe6f6e7c7de5cb78a3b17ad04ba Mon Sep 17 00:00:00 2001
From: somebody
Date: Sun, 19 Mar 2023 19:03:20 -0500
Subject: [PATCH] GenericTokenizer: Fall back to defined tokenizer

Shouldn't be relied on for model-agnostic code, but for loading paths where
the tokenizer class in use is known, it should be fine.
---
 modeling/inference_models/generic_hf_torch.py |  4 ++--
 modeling/tokenizer.py                         | 14 ++++++++++----
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/modeling/inference_models/generic_hf_torch.py b/modeling/inference_models/generic_hf_torch.py
index 0a15be2b..15f943d5 100644
--- a/modeling/inference_models/generic_hf_torch.py
+++ b/modeling/inference_models/generic_hf_torch.py
@@ -137,8 +137,8 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
                 # Use save_pretrained to convert fp32 models to fp16,
                 # unless we are using disk cache because save_pretrained
                 # is not supported in that case
-                model = model.half()
-                model.save_pretrained(
+                self.model = self.model.half()
+                self.model.save_pretrained(
                     self.get_local_model_path(ignore_existance=True),
                     max_shard_size="500MiB",
                 )
diff --git a/modeling/tokenizer.py b/modeling/tokenizer.py
index 6c41764b..0f6305e5 100644
--- a/modeling/tokenizer.py
+++ b/modeling/tokenizer.py
@@ -1,4 +1,4 @@
-from typing import List, Union
+from typing import Any, List, Union
 from tokenizers import Tokenizer
 import torch
 from transformers import PreTrainedTokenizer
@@ -10,10 +10,16 @@ class GenericTokenizer:
     def __init__(self, tokenizer: Union[Tokenizer, PreTrainedTokenizer]) -> None:
         self.tokenizer = tokenizer
 
-        # TODO: Get rid of this
-        self._koboldai_header = []
+    def __getattr__(self, name: str) -> Any:
+        # Fall back to tokenizer for non-generic stuff
+        return getattr(self.tokenizer, name)
 
-        self.get_vocab = tokenizer.get_vocab
+    def __setattr__(self, name: str, value: Any) -> None:
+        # To prevent infinite recursion on __init__ setting
+        if name == "tokenizer":
+            super().__setattr__(name, value)
+            return
+        setattr(self.tokenizer, name, value)
 
     def encode(self, text: str) -> list:
         if isinstance(self.tokenizer, PreTrainedTokenizer):
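
For reference, here is a minimal, self-contained sketch (not part of the patch) of the
delegation behaviour the modeling/tokenizer.py hunk introduces. The _DummyTokenizer class
and the padding_side attribute are hypothetical stand-ins for a real Hugging Face
tokenizer; the real GenericTokenizer additionally keeps its encode/decode wrappers.

from typing import Any


class _DummyTokenizer:
    """Stand-in for a real tokenizer (hypothetical, for illustration only)."""

    padding_side = "right"

    def get_vocab(self) -> dict:
        return {"hello": 0, "world": 1}


class GenericTokenizer:
    def __init__(self, tokenizer: Any) -> None:
        self.tokenizer = tokenizer

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup on the wrapper fails, so generic
        # methods defined on GenericTokenizer still win; everything else
        # is proxied to the wrapped tokenizer.
        return getattr(self.tokenizer, name)

    def __setattr__(self, name: str, value: Any) -> None:
        # "tokenizer" itself must live on the wrapper, otherwise
        # __getattr__ would recurse forever looking for self.tokenizer.
        if name == "tokenizer":
            super().__setattr__(name, value)
            return
        setattr(self.tokenizer, name, value)


tok = GenericTokenizer(_DummyTokenizer())
print(tok.get_vocab())             # proxied read -> {'hello': 0, 'world': 1}
tok.padding_side = "left"          # proxied write lands on the wrapped tokenizer
print(tok.tokenizer.padding_side)  # -> "left"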