commit 9bd445c2a8
parent 839d56ebf2
Author: ebolam
Date:   2023-05-23 20:33:55 -04:00

    gpt2 fixed

2 changed files with 9 additions and 4 deletions

@@ -59,7 +59,7 @@ class model_backend(HFTorchInferenceModel):
             "low_cpu_mem_usage": True,
         }
-        if utils.koboldai_vars.model_type == "gpt2":
+        if self.model_type == "gpt2":
             # We must disable low_cpu_mem_usage if using a GPT-2 model
             # because GPT-2 is not compatible with this feature yet.
             tf_kwargs.pop("low_cpu_mem_usage", None)
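This hunk swaps the global `utils.koboldai_vars.model_type` check for the backend's own `self.model_type`, so the fix no longer depends on global state, and it strips `low_cpu_mem_usage` from the kwargs for GPT-2. A minimal sketch of the resulting flow (the `from_pretrained` call and `model_path` are assumed from context, not part of the diff):

    tf_kwargs = {
        "low_cpu_mem_usage": True,
    }
    if self.model_type == "gpt2":
        # GPT-2 can't load with low_cpu_mem_usage yet, so drop the flag
        # before the kwargs reach transformers.
        tf_kwargs.pop("low_cpu_mem_usage", None)
    model = AutoModelForCausalLM.from_pretrained(model_path, **tf_kwargs)  # assumed call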

@@ -61,6 +61,7 @@ class HFInferenceModel(InferenceModel):
         else:
             self.model_config = AutoConfig.from_pretrained(model_name, revision=utils.koboldai_vars.revision, cache_dir="cache")
         layer_count = self.model_config["n_layer"] if isinstance(self.model_config, dict) else self.model_config.num_layers if hasattr(self.model_config, "num_layers") else self.model_config.n_layer if hasattr(self.model_config, "n_layer") else self.model_config.num_hidden_layers if hasattr(self.model_config, 'num_hidden_layers') else None
+        layer_count = None if hasattr(self, "get_model_type") and self.get_model_type() == "gpt2" else layer_count  # Skip layers if we're a GPT-2 model, as it doesn't support breakmodel
         if layer_count is not None and layer_count >= 0 and not self.nobreakmodel:
             if os.path.exists("settings/{}.generic_hf_torch.model_backend.settings".format(model_name.replace("/", "_"))) and 'base_url' not in vars(self):
                 with open("settings/{}.generic_hf_torch.model_backend.settings".format(model_name.replace("/", "_")), "r") as f:
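The added line forces `layer_count` to `None` for GPT-2, so the breakmodel branch below is skipped entirely. The chained conditional above it probes several attribute names because different architectures expose their layer count under different keys; an equivalent, easier-to-read form (a sketch only, `_layer_count` is a hypothetical helper, the attribute names come from the diff):

    def _layer_count(config):
        # dict-style configs keep the count under "n_layer"
        if isinstance(config, dict):
            return config["n_layer"]
        # object-style configs vary by architecture
        for attr in ("num_layers", "n_layer", "num_hidden_layers"):
            if hasattr(config, attr):
                return getattr(config, attr)
        return None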
@@ -143,15 +144,13 @@ class HFInferenceModel(InferenceModel):
         return requested_parameters

     def set_input_parameters(self, parameters):
-        if self.hf_torch:
+        if self.hf_torch and hasattr(self, "get_model_type") and self.get_model_type() != "gpt2":
             import breakmodel
             layer_count = self.model_config["n_layer"] if isinstance(self.model_config, dict) else self.model_config.num_layers if hasattr(self.model_config, "num_layers") else self.model_config.n_layer if hasattr(self.model_config, "n_layer") else self.model_config.num_hidden_layers if hasattr(self.model_config, 'num_hidden_layers') else None
             if layer_count is not None and layer_count >= 0 and not self.nobreakmodel:
                 gpu_count = torch.cuda.device_count()
                 layers = []
-                logger.info(parameters)
                 for i in range(gpu_count):
-                    logger.info(parameters["{}_Layers".format(i)])
                     if isinstance(parameters["{}_Layers".format(i)], str) and parameters["{}_Layers".format(i)].isnumeric():
                         layers.append(int(parameters["{}_Layers".format(i)]))
                     elif isinstance(parameters["{}_Layers".format(i)], str):
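Two changes here: the stray `logger.info` debug calls are removed, and breakmodel setup is now skipped up front for GPT-2 backends. The `hasattr` test guards against backends that never define `get_model_type`. The guard in isolation (a sketch):

    # Breakmodel GPU-layer setup only applies to supported architectures.
    is_gpt2 = hasattr(self, "get_model_type") and self.get_model_type() == "gpt2"
    if self.hf_torch and not is_gpt2:
        import breakmodel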
@@ -170,8 +169,13 @@ class HFInferenceModel(InferenceModel):
             self.usegpu = self.cpu_layers == 0 and breakmodel.disk_blocks == 0 and sum(self.layers)-self.layers[0] == 0
             self.model_type = self.get_model_type()
             self.breakmodel = ((self.model_type != 'gpt2') or self.model_type in ("gpt_neo", "gptj", "xglm", "opt")) and not self.nobreakmodel
+            self.lazy_load = True
+            logger.debug("Model type: {}".format(self.model_type))
         else:
+            logger.debug("Disabling breakmodel and lazyload")
             self.usegpu = parameters['use_gpu'] if 'use_gpu' in parameters else None
+            self.breakmodel = False
+            self.lazy_load = False
         self.model_name = parameters['custom_model_name'] if 'custom_model_name' in parameters else parameters['id']
         self.path = parameters['path'] if 'path' in parameters else None
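Worth noting: in the `self.breakmodel` expression the `or` clause is redundant, since any `model_type` in `("gpt_neo", "gptj", "xglm", "opt")` already satisfies `model_type != 'gpt2'`. The condition therefore reduces to the equivalent:

    self.breakmodel = self.model_type != "gpt2" and not self.nobreakmodel

The new `else` branch explicitly clears `breakmodel` and `lazy_load`, matching the GPT-2 exclusion added above.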
@@ -199,6 +203,7 @@ class HFInferenceModel(InferenceModel):
             pass

     def _post_load(self) -> None:
+        utils.koboldai_vars.badwordsids = koboldai_settings.badwordsids_default
         self.model_type = str(self.model_config.model_type)
         # These are model-specific tokenizer overrides if a model has bad defaults
         if self.model_type == "llama":
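The added `_post_load` line resets `badwordsids` to the library default before the model-specific overrides below run, presumably so a previously loaded model's token list doesn't leak into the next load. The pattern (a sketch; the override body is assumed):

    # Reset shared state to a known default, then apply per-model overrides.
    utils.koboldai_vars.badwordsids = koboldai_settings.badwordsids_default
    if self.model_type == "llama":
        ...  # llama-specific tokenizer overrides follow in the file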