GenericTokenizer: Fall back to defined tokenizer

Shouldn't be relied on for model-agnostic code, but for loading
processes where you know the tokenizer class used, it should be okie
dokie.
somebody
2023-03-19 19:03:20 -05:00
parent 864f9ed8c3
commit 91bb433b5f
2 changed files with 12 additions and 6 deletions

@@ -137,8 +137,8 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
                 # Use save_pretrained to convert fp32 models to fp16,
                 # unless we are using disk cache because save_pretrained
                 # is not supported in that case
-                model = model.half()
-                model.save_pretrained(
+                self.model = self.model.half()
+                self.model.save_pretrained(
                     self.get_local_model_path(ignore_existance=True),
                     max_shard_size="500MiB",
                 )

@@ -1,4 +1,4 @@
-from typing import List, Union
+from typing import Any, List, Union
 from tokenizers import Tokenizer
 import torch
 from transformers import PreTrainedTokenizer
@@ -10,10 +10,16 @@ class GenericTokenizer:
     def __init__(self, tokenizer: Union[Tokenizer, PreTrainedTokenizer]) -> None:
         self.tokenizer = tokenizer

-        # TODO: Get rid of this
-        self._koboldai_header = []
-
-        self.get_vocab = tokenizer.get_vocab
+    def __getattr__(self, name: str) -> Any:
+        # Fall back to tokenizer for non-generic stuff
+        return getattr(self.tokenizer, name)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        # To prevent infinite recursion on __init__ setting
+        if name == "tokenizer":
+            super().__setattr__(name, value)
+            return
+        setattr(self.tokenizer, name, value)

     def encode(self, text: str) -> list:
         if isinstance(self.tokenizer, PreTrainedTokenizer):
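A small usage sketch of the new fallback behavior: attribute reads that GenericTokenizer does not define itself are delegated to the wrapped tokenizer through __getattr__, and attribute writes (other than tokenizer itself) are forwarded by __setattr__. The tokenizer name below is a placeholder, and the import of GenericTokenizer from its module is assumed:

    from transformers import AutoTokenizer

    # Placeholder tokenizer; any Hugging Face tokenizer works the same way
    wrapped = GenericTokenizer(AutoTokenizer.from_pretrained("gpt2"))

    # get_vocab is not defined on GenericTokenizer, so __getattr__
    # forwards the lookup to the underlying tokenizer
    vocab = wrapped.get_vocab()

    # Writes to anything except "tokenizer" are forwarded as well
    wrapped.padding_side = "left"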