diff --git a/modeling/inference_models/generic_hf_torch.py b/modeling/inference_models/generic_hf_torch.py
index 0a15be2b..15f943d5 100644
--- a/modeling/inference_models/generic_hf_torch.py
+++ b/modeling/inference_models/generic_hf_torch.py
@@ -137,8 +137,8 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
                 # Use save_pretrained to convert fp32 models to fp16,
                 # unless we are using disk cache because save_pretrained
                 # is not supported in that case
-                model = model.half()
-                model.save_pretrained(
+                self.model = self.model.half()
+                self.model.save_pretrained(
                     self.get_local_model_path(ignore_existance=True),
                     max_shard_size="500MiB",
                 )
diff --git a/modeling/tokenizer.py b/modeling/tokenizer.py
index 6c41764b..0f6305e5 100644
--- a/modeling/tokenizer.py
+++ b/modeling/tokenizer.py
@@ -1,4 +1,4 @@
-from typing import List, Union
+from typing import Any, List, Union
 from tokenizers import Tokenizer
 import torch
 from transformers import PreTrainedTokenizer
@@ -10,10 +10,16 @@ class GenericTokenizer:
     def __init__(self, tokenizer: Union[Tokenizer, PreTrainedTokenizer]) -> None:
         self.tokenizer = tokenizer
 
-        # TODO: Get rid of this
-        self._koboldai_header = []
+    def __getattr__(self, name: str) -> Any:
+        # Fall back to the wrapped tokenizer for non-generic attributes
+        return getattr(self.tokenizer, name)
 
-        self.get_vocab = tokenizer.get_vocab
+    def __setattr__(self, name: str, value: Any) -> None:
+        # Prevent infinite recursion when __init__ assigns self.tokenizer
+        if name == "tokenizer":
+            super().__setattr__(name, value)
+            return
+        setattr(self.tokenizer, name, value)
 
     def encode(self, text: str) -> list:
         if isinstance(self.tokenizer, PreTrainedTokenizer):
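
Note on the tokenizer change: the new __getattr__/__setattr__ pair is standard
attribute delegation. Below is a minimal, self-contained sketch of the same
pattern, not part of this patch; the names Delegator, Inner, vocab_size, and
padding_side are illustrative stand-ins for the wrapped tokenizer's attributes.

    from typing import Any


    class Inner:
        """Stand-in for the wrapped tokenizer (hypothetical attribute below)."""

        vocab_size = 50257


    class Delegator:
        def __init__(self, inner: Inner) -> None:
            self.inner = inner  # routed through __setattr__ to the guard below

        def __getattr__(self, name: str) -> Any:
            # Only invoked after normal lookup fails, so reading self.inner
            # here (found in the instance __dict__) cannot recurse.
            return getattr(self.inner, name)

        def __setattr__(self, name: str, value: Any) -> None:
            # Store the wrapped object itself on this instance; forward
            # every other write, mirroring GenericTokenizer's guard.
            if name == "inner":
                super().__setattr__(name, value)
                return
            setattr(self.inner, name, value)


    d = Delegator(Inner())
    print(d.vocab_size)          # 50257: read falls through to Inner
    d.padding_side = "left"      # write is forwarded onto the wrapped object
    print(d.inner.padding_side)  # left

The guard is needed because __setattr__ intercepts every assignment, including
the first one in __init__. Without the special case, assigning self.tokenizer
would forward to setattr(self.tokenizer, ...), whose read of the not-yet-set
self.tokenizer falls into __getattr__ and recurses forever; __getattr__ itself
is safe because Python only consults it after normal attribute lookup fails.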