Model: Bugfixes/fix tokenizer hack

somebody
2023-02-27 18:30:13 -06:00
parent af73527be0
commit 2e3ca6f769


@@ -526,12 +526,10 @@ class InferenceModel:
         raise NotImplementedError

     def _get_tokenizer(self, location: str):
+        # TODO: This newlinemode inference might need more scrutiny
         utils.koboldai_vars.newlinemode = "n"
-        if "xglm" in location:
+        if utils.koboldai_vars.model_type == "xglm":
             # Default to </s> newline mode if using XGLM
             utils.koboldai_vars.newlinemode = "s"
-        if "opt" in location or "bloom" in location:
+        elif utils.koboldai_vars.model_type in ["opt", "bloom"]:
             # Handle </s> but don't convert newlines if using Fairseq models that have newlines trained in them
             utils.koboldai_vars.newlinemode = "ns"
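For context, the three newlinemode values map roughly to the behavior described in the comments above: "s" substitutes </s> for newlines, while "n" and "ns" leave the text alone. A minimal sketch of that mapping, assuming the </s> substitution the XGLM comment describes; the helper functions are illustrative stand-ins, not KoboldAI's actual API:

    def apply_newline_mode(text: str, newlinemode: str) -> str:
        """Prepare prompt text under a given newline mode before tokenizing."""
        if newlinemode == "s":
            # XGLM-style: the model saw </s> where newlines would normally be.
            return text.replace("\n", "</s>")
        # "n" and "ns" leave newlines untouched; they differ in how the </s>
        # token is treated elsewhere in the pipeline, not in the text itself.
        return text

    def revert_newline_mode(text: str, newlinemode: str) -> str:
        """Map generated text back to plain newlines for display."""
        if newlinemode == "s":
            return text.replace("</s>", "\n")
        return text

    assert revert_newline_mode(apply_newline_mode("a\nb", "s"), "s") == "a\nb"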
@@ -554,10 +552,10 @@ class InferenceModel:
         for i, try_get_tokenizer in enumerate(suppliers):
             try:
                 return try_get_tokenizer()
-            except Exception as e:
+            except:
                 # If we error on each attempt, raise the last one
                 if i == len(suppliers) - 1:
-                    raise e
+                    raise

     def core_generate(
         self,
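The hunk above is a supplier-fallback pattern: try each tokenizer source in order and surface only the final failure. A self-contained sketch of the same idea with hypothetical supplier functions; note that a bare except also catches BaseException subclasses such as KeyboardInterrupt, which is the trade-off this change accepts:

    from typing import Callable, List

    def first_successful(suppliers: List[Callable[[], object]]) -> object:
        """Try each supplier in order; re-raise the last failure unchanged."""
        for i, supplier in enumerate(suppliers):
            try:
                return supplier()
            except:
                # A bare `raise` re-raises the active exception as-is, so
                # there is no need to bind it to a name with `as e`.
                if i == len(suppliers) - 1:
                    raise

    def load_local():
        raise OSError("no local tokenizer files")

    tokenizer = first_successful([load_local, lambda: "downloaded-tokenizer"])
    assert tokenizer == "downloaded-tokenizer"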
@@ -886,7 +884,7 @@ class InferenceModel:
         hook(self, input_ids)


-class HFMTJInferenceModel:
+class HFMTJInferenceModel(InferenceModel):
     def __init__(
         self,
         model_name: str,
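The one-line change above restores the inheritance chain. A minimal illustration of what it buys, using simplified stand-ins for both classes:

    class InferenceModel:
        def core_generate(self, prompt: str) -> str:
            # Stand-in for the shared generation pipeline in the real class.
            return f"<{type(self).__name__} output for {prompt!r}>"

    class HFMTJInferenceModel(InferenceModel):
        def __init__(self, model_name: str):
            self.model_name = model_name

    model = HFMTJInferenceModel("mtj-model")
    assert isinstance(model, InferenceModel)  # False before this fix
    print(model.core_generate("Hello"))       # inherited, not redefined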
@@ -2270,11 +2268,10 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
             # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
             tf_kwargs.pop("low_cpu_mem_usage", None)

-        self.tokenizer = self._get_tokenizer(self.get_local_model_path())
-
         if self.get_local_model_path():
             # Model is stored locally, load it.
             self.model = self._get_model(self.get_local_model_path(), tf_kwargs)
+            self.tokenizer = self._get_tokenizer(self.get_local_model_path())
         else:
             # Model not stored locally, we need to download it.
@@ -2299,6 +2296,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
                 torch._utils._rebuild_tensor = new_rebuild_tensor
                 self.model = self._get_model(utils.koboldai_vars.model, tf_kwargs)
+                self.tokenizer = self._get_tokenizer(utils.koboldai_vars.model)
                 torch._utils._rebuild_tensor = old_rebuild_tensor

                 if save_model:
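The two hunks above are the "tokenizer hack" fix itself: instead of fetching the tokenizer from the local path up front, even when the model still has to be downloaded, the tokenizer is now fetched after the model, from whichever location the model actually came from. A sketch of the corrected flow, with a hypothetical TorchModelLoader whose method bodies are stubs:

    class TorchModelLoader:
        """Hypothetical stand-in for GenericHFTorchInferenceModel._load."""

        def get_local_model_path(self):
            return None  # pretend nothing is cached locally

        def _get_model(self, location, tf_kwargs):
            return f"model loaded from {location}"

        def _get_tokenizer(self, location):
            return f"tokenizer loaded from {location}"

        def load(self, model_name, tf_kwargs):
            if self.get_local_model_path():
                # Model is stored locally: load from disk.
                location = self.get_local_model_path()
            else:
                # Not stored locally: the real code downloads from the hub.
                location = model_name
            self.model = self._get_model(location, tf_kwargs)
            # After the fix, the tokenizer is fetched *after* the model and
            # from the same location the model actually came from.
            self.tokenizer = self._get_tokenizer(location)

    loader = TorchModelLoader()
    loader.load("EleutherAI/gpt-neo-1.3B", {})
    assert loader.tokenizer == "tokenizer loaded from EleutherAI/gpt-neo-1.3B"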
@@ -2460,7 +2458,6 @@ class CustomGPT2HFTorchInferenceModel(HFTorchInferenceModel):
                 with open(
                     os.path.join(possible_config_path, "config.json"), "r"
                 ) as file:
-                    # Unused?
                     self.model_config = json.load(file)
                 model_path = possible_config_path
                 break
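The surrounding loop probes candidate folders for a config.json, loads the first one found, and remembers where it was. A small standalone sketch of that pattern, with illustrative paths:

    import json
    import os

    def find_model_config(candidates):
        """Return (config, path) for the first folder containing config.json."""
        for possible_config_path in candidates:
            config_file = os.path.join(possible_config_path, "config.json")
            if os.path.isfile(config_file):
                with open(config_file, "r") as file:
                    return json.load(file), possible_config_path
        return None, None

    config, model_path = find_model_config(["models/custom-gpt2", "."])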