Model: Bugfixes/fix tokenizer hack

somebody
2023-02-27 18:30:13 -06:00
parent af73527be0
commit 2e3ca6f769


@@ -526,12 +526,10 @@ class InferenceModel:
         raise NotImplementedError
 
     def _get_tokenizer(self, location: str):
-        # TODO: This newlinemode inference might need more scrutiny
-        utils.koboldai_vars.newlinemode = "n"
-        if "xglm" in location:
+        if utils.koboldai_vars.model_type == "xglm":
             # Default to </s> newline mode if using XGLM
             utils.koboldai_vars.newlinemode = "s"
-        if "opt" in location or "bloom" in location:
+        elif utils.koboldai_vars.model_type in ["opt", "bloom"]:
             # Handle </s> but don't convert newlines if using Fairseq models that have newlines trained in them
             utils.koboldai_vars.newlinemode = "ns"
@@ -554,10 +552,10 @@ class InferenceModel:
         for i, try_get_tokenizer in enumerate(suppliers):
             try:
                 return try_get_tokenizer()
-            except Exception as e:
+            except:
                 # If we error on each attempt, raise the last one
                 if i == len(suppliers) - 1:
-                    raise e
+                    raise
 
     def core_generate(
         self,
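
For context, this loop walks a list of tokenizer "suppliers" and only re-raises once every candidate has failed; earlier failures are swallowed so the next supplier gets a chance. A self-contained sketch of the same fallback pattern (the supplier functions here are hypothetical):

from typing import Callable, List

def first_successful(suppliers: List[Callable[[], object]]):
    # Try each supplier in order; re-raise only if the last one also fails.
    for i, try_get in enumerate(suppliers):
        try:
            return try_get()
        except Exception:
            if i == len(suppliers) - 1:
                raise

def fast_loader():
    raise RuntimeError("fast path unavailable")

def slow_loader():
    return "tokenizer"

assert first_successful([fast_loader, slow_loader]) == "tokenizer"
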
@@ -886,7 +884,7 @@ class InferenceModel:
             hook(self, input_ids)
 
 
-class HFMTJInferenceModel:
+class HFMTJInferenceModel(InferenceModel):
     def __init__(
         self,
         model_name: str,
@@ -2270,11 +2268,10 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
             # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
             tf_kwargs.pop("low_cpu_mem_usage", None)
 
-        self.tokenizer = self._get_tokenizer(self.get_local_model_path())
-
         if self.get_local_model_path():
             # Model is stored locally, load it.
             self.model = self._get_model(self.get_local_model_path(), tf_kwargs)
+            self.tokenizer = self._get_tokenizer(self.get_local_model_path())
 
         else:
             # Model not stored locally, we need to download it.
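
The reshuffle above moves tokenizer loading into the branch that knows where the model actually lives, instead of always passing the local path. A rough sketch of resolving one source for both model and tokenizer (the helper name and directory layout are assumptions for illustration):

import os

def resolve_model_source(model_id: str, models_dir: str = "models") -> str:
    # Prefer a local copy when one exists; otherwise return the hub id,
    # so the model and tokenizer end up loading from the same place.
    local_path = os.path.join(models_dir, model_id.replace("/", "_"))
    if os.path.isdir(local_path):
        return local_path
    return model_id

Loading both from one resolved source avoids the earlier situation where the tokenizer was fetched via get_local_model_path() even when the model still had to be downloaded.
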
@@ -2299,6 +2296,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
             torch._utils._rebuild_tensor = new_rebuild_tensor
             self.model = self._get_model(utils.koboldai_vars.model, tf_kwargs)
+            self.tokenizer = self._get_tokenizer(utils.koboldai_vars.model)
             torch._utils._rebuild_tensor = old_rebuild_tensor
 
             if save_model:
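
This branch temporarily swaps torch._utils._rebuild_tensor and restores the original afterwards; the commit simply adds the tokenizer load between those two steps. The same save-swap-restore idea written as a generic context manager (this helper is illustrative, not from the codebase):

from contextlib import contextmanager

@contextmanager
def patched_attr(obj, name, replacement):
    # Install a temporary replacement attribute, then restore the original
    # even if the wrapped block raises.
    original = getattr(obj, name)
    setattr(obj, name, replacement)
    try:
        yield
    finally:
        setattr(obj, name, original)

class Box:
    value = 1

with patched_attr(Box, "value", 2):
    assert Box.value == 2
assert Box.value == 1
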
@@ -2460,7 +2458,6 @@ class CustomGPT2HFTorchInferenceModel(HFTorchInferenceModel):
                 with open(
                     os.path.join(possible_config_path, "config.json"), "r"
                 ) as file:
-                    # Unused?
                     self.model_config = json.load(file)
                 model_path = possible_config_path
                 break
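
For reference, the loop this hunk trims probes candidate directories for a config.json and keeps the first one found, along with its path. A minimal standalone version of that probe (the function name and return shape are assumptions):

import json
import os

def find_model_config(candidate_dirs):
    # Return (parsed config, directory) for the first candidate that contains
    # a readable config.json, or (None, None) if none of them do.
    for possible_config_path in candidate_dirs:
        config_file = os.path.join(possible_config_path, "config.json")
        if not os.path.isfile(config_file):
            continue
        with open(config_file, "r") as file:
            return json.load(file), possible_config_path
    return None, None
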