Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Model: Bugfixes/fix tokenizer hack
model.py · 17 changed lines
```diff
@@ -526,12 +526,10 @@ class InferenceModel:
         raise NotImplementedError
 
     def _get_tokenizer(self, location: str):
-        # TODO: This newlinemode inference might need more scrutiny
-        utils.koboldai_vars.newlinemode = "n"
-        if "xglm" in location:
+        if utils.koboldai_vars.model_type == "xglm":
             # Default to </s> newline mode if using XGLM
             utils.koboldai_vars.newlinemode = "s"
-        if "opt" in location or "bloom" in location:
+        elif utils.koboldai_vars.model_type in ["opt", "bloom"]:
             # Handle </s> but don't convert newlines if using Fairseq models that have newlines trained in them
             utils.koboldai_vars.newlinemode = "ns"
 
```
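Why this is a fix: the old code sniffed the model *path* (`"xglm" in location`), which silently misfires when a checkpoint lives in an arbitrarily named local folder; the new code keys off the detected `model_type` instead. Below is a minimal standalone sketch of the resulting selection logic, assuming the old `"n"` default still applies when no rule matches (the function name is illustrative, not from model.py):

```python
def infer_newline_mode(model_type: str) -> str:
    """Map a detected model type to a newline-handling mode.

    Modes, per the diff: "s" = represent newlines as </s> (XGLM),
    "ns" = handle </s> but keep literal newlines (OPT/BLOOM),
    "n" = plain newlines (assumed default, from the removed code).
    """
    if model_type == "xglm":
        # Default to </s> newline mode if using XGLM
        return "s"
    elif model_type in ["opt", "bloom"]:
        # Handle </s> but don't convert newlines; these Fairseq-style
        # models were trained with real newlines in the data
        return "ns"
    return "n"


assert infer_newline_mode("xglm") == "s"
assert infer_newline_mode("bloom") == "ns"
```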
```diff
@@ -554,10 +552,10 @@ class InferenceModel:
         for i, try_get_tokenizer in enumerate(suppliers):
             try:
                 return try_get_tokenizer()
-            except Exception as e:
+            except:
                 # If we error on each attempt, raise the last one
                 if i == len(suppliers) - 1:
-                    raise e
+                    raise
 
     def core_generate(
         self,
```
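Two things change here: the bare `except:` also catches loader failures that don't subclass `Exception`, and the bare `raise` re-raises the in-flight exception with its original traceback rather than re-raising the bound `e`. The pattern in isolation, as a hedged sketch (names are illustrative; only the loop/except/raise shape comes from model.py):

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def first_success(suppliers: List[Callable[[], T]]) -> T:
    """Try each zero-argument supplier in order; return the first
    result, or re-raise the final failure with its traceback intact."""
    if not suppliers:
        raise ValueError("no suppliers given")
    for i, try_get in enumerate(suppliers):
        try:
            return try_get()
        except:
            # If we error on each attempt, raise the last one;
            # a bare raise keeps the original traceback.
            if i == len(suppliers) - 1:
                raise
```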
```diff
@@ -886,7 +884,7 @@ class InferenceModel:
             hook(self, input_ids)
 
 
-class HFMTJInferenceModel:
+class HFMTJInferenceModel(InferenceModel):
     def __init__(
         self,
         model_name: str,
```
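The missing base class was a bug: without inheriting `InferenceModel`, the MTJ backend did not share `_get_tokenizer`, the supplier fallback, or `core_generate` above. A sketch of what the subclass relationship buys, assuming a `super().__init__()` chain (the constructor body beyond `model_name` is not shown in this diff):

```python
class HFMTJInferenceModel(InferenceModel):
    def __init__(self, model_name: str) -> None:
        # Initialize shared InferenceModel state before
        # backend-specific setup.
        super().__init__()
        self.model_name = model_name

    # _get_tokenizer, core_generate, etc. are now inherited rather
    # than duck-typed, so isinstance(model, InferenceModel) holds.
```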
```diff
@@ -2270,11 +2268,10 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
         # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
         tf_kwargs.pop("low_cpu_mem_usage", None)
 
-        self.tokenizer = self._get_tokenizer(self.get_local_model_path())
-
         if self.get_local_model_path():
             # Model is stored locally, load it.
             self.model = self._get_model(self.get_local_model_path(), tf_kwargs)
+            self.tokenizer = self._get_tokenizer(self.get_local_model_path())
         else:
             # Model not stored locally, we need to download it.
 
```
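The hunk above moves the tokenizer fetch out of the common path and into the local branch, after the existence check, so a tokenizer is only resolved once the model is known to be on disk. Schematically (a sketch, not the full method; the download branch is elided just as in the hunk):

```python
def _load_local_or_remote(self, tf_kwargs: dict) -> None:
    # Schematic of the reordering in this hunk.
    local_path = self.get_local_model_path()
    if local_path:
        self.model = self._get_model(local_path, tf_kwargs)
        # Tokenizer now loads after the path check, from the same path.
        self.tokenizer = self._get_tokenizer(local_path)
    else:
        # Download path: see the next hunk, which pairs the tokenizer
        # load with the freshly downloaded model.
        ...
```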
```diff
@@ -2299,6 +2296,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
 
                 torch._utils._rebuild_tensor = new_rebuild_tensor
                 self.model = self._get_model(utils.koboldai_vars.model, tf_kwargs)
+                self.tokenizer = self._get_tokenizer(utils.koboldai_vars.model)
                 torch._utils._rebuild_tensor = old_rebuild_tensor
 
                 if save_model:
```
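This is the download branch's counterpart of the previous hunk: the tokenizer load is now paired with the model load, and both happen while `torch._utils._rebuild_tensor` is temporarily monkey-patched. A sketch of that swap/restore pattern, assuming only what the context lines show (the patched function's real body lies outside this diff; the try/finally is added here for safety, where the original restores inline):

```python
import torch

# Save the original so it can be restored afterwards.
old_rebuild_tensor = torch._utils._rebuild_tensor


def new_rebuild_tensor(storage, storage_offset, size, stride):
    # Hypothetical stand-in: the real patched function in model.py is
    # defined outside this diff; here we just delegate unchanged.
    return old_rebuild_tensor(storage, storage_offset, size, stride)


torch._utils._rebuild_tensor = new_rebuild_tensor
try:
    pass  # load model + tokenizer while the patch is active
finally:
    # Restore the original so the patch never leaks past loading.
    torch._utils._rebuild_tensor = old_rebuild_tensor
```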
```diff
@@ -2460,7 +2458,6 @@ class CustomGPT2HFTorchInferenceModel(HFTorchInferenceModel):
             with open(
                 os.path.join(possible_config_path, "config.json"), "r"
             ) as file:
-                # Unused?
                 self.model_config = json.load(file)
             model_path = possible_config_path
             break
```
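Context for the dropped `# Unused?` comment: the commit keeps the `json.load` and drops the stale question, since the loaded `self.model_config` and the presence of a readable `config.json` are what let the loop accept `possible_config_path` and `break`. A self-contained sketch of that probe, with an assumed candidate list (only the `config.json` detection comes from model.py):

```python
import json
import os
from typing import List, Tuple


def find_model_dir(candidates: List[str]) -> Tuple[str, dict]:
    """Return (model_path, model_config) for the first candidate
    directory containing a readable config.json."""
    for possible_config_path in candidates:
        config_file = os.path.join(possible_config_path, "config.json")
        if not os.path.isfile(config_file):
            continue
        with open(config_file, "r") as file:
            model_config = json.load(file)
        return possible_config_path, model_config
    raise FileNotFoundError("no config.json under any candidate path")
```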