mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Model: Bugfixes/fix tokenizer hack
This commit is contained in:
17
model.py
17
model.py
@@ -526,12 +526,10 @@ class InferenceModel:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_tokenizer(self, location: str):
|
||||
# TODO: This newlinemode inference might need more scrutiny
|
||||
utils.koboldai_vars.newlinemode = "n"
|
||||
if "xglm" in location:
|
||||
if utils.koboldai_vars.model_type == "xglm":
|
||||
# Default to </s> newline mode if using XGLM
|
||||
utils.koboldai_vars.newlinemode = "s"
|
||||
if "opt" in location or "bloom" in location:
|
||||
elif utils.koboldai_vars.model_type in ["opt", "bloom"]:
|
||||
# Handle </s> but don't convert newlines if using Fairseq models that have newlines trained in them
|
||||
utils.koboldai_vars.newlinemode = "ns"
|
||||
|
||||
@@ -554,10 +552,10 @@ class InferenceModel:
|
||||
for i, try_get_tokenizer in enumerate(suppliers):
|
||||
try:
|
||||
return try_get_tokenizer()
|
||||
except Exception as e:
|
||||
except:
|
||||
# If we error on each attempt, raise the last one
|
||||
if i == len(suppliers) - 1:
|
||||
raise e
|
||||
raise
|
||||
|
||||
def core_generate(
|
||||
self,
|
||||
@@ -886,7 +884,7 @@ class InferenceModel:
|
||||
hook(self, input_ids)
|
||||
|
||||
|
||||
class HFMTJInferenceModel:
|
||||
class HFMTJInferenceModel(InferenceModel):
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
@@ -2270,11 +2268,10 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
|
||||
# torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
|
||||
tf_kwargs.pop("low_cpu_mem_usage", None)
|
||||
|
||||
self.tokenizer = self._get_tokenizer(self.get_local_model_path())
|
||||
|
||||
if self.get_local_model_path():
|
||||
# Model is stored locally, load it.
|
||||
self.model = self._get_model(self.get_local_model_path(), tf_kwargs)
|
||||
self.tokenizer = self._get_tokenizer(self.get_local_model_path())
|
||||
else:
|
||||
# Model not stored locally, we need to download it.
|
||||
|
||||
@@ -2299,6 +2296,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
|
||||
|
||||
torch._utils._rebuild_tensor = new_rebuild_tensor
|
||||
self.model = self._get_model(utils.koboldai_vars.model, tf_kwargs)
|
||||
self.tokenizer = self._get_tokenizer(utils.koboldai_vars.model)
|
||||
torch._utils._rebuild_tensor = old_rebuild_tensor
|
||||
|
||||
if save_model:
|
||||
@@ -2460,7 +2458,6 @@ class CustomGPT2HFTorchInferenceModel(HFTorchInferenceModel):
|
||||
with open(
|
||||
os.path.join(possible_config_path, "config.json"), "r"
|
||||
) as file:
|
||||
# Unused?
|
||||
self.model_config = json.load(file)
|
||||
model_path = possible_config_path
|
||||
break
|
||||
|
Reference in New Issue
Block a user