From 5561cc1f220c0cf9d957bcbd3e535ad88502ab82 Mon Sep 17 00:00:00 2001 From: ebolam Date: Tue, 23 May 2023 08:33:19 -0400 Subject: [PATCH] Fix for GPU generation --- modeling/inference_models/hf_torch.py | 13 ++++++++- static/application.js | 42 +++++++++++++++------------ static/koboldai.js | 40 +++++++++++++------------ 3 files changed, 58 insertions(+), 37 deletions(-) diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py index 681d3ab1..2f575e73 100644 --- a/modeling/inference_models/hf_torch.py +++ b/modeling/inference_models/hf_torch.py @@ -125,6 +125,17 @@ class HFTorchInferenceModel(HFInferenceModel): else: return "Unknown" + def get_auxilary_device(self): + """Get device auxilary tensors like inputs should be stored on.""" + + # NOTE: TPU isn't a torch device, so TPU stuff gets sent to CPU. + if utils.koboldai_vars.hascuda and self.usegpu: + return utils.koboldai_vars.gpu_device + elif utils.koboldai_vars.hascuda and self.breakmodel: + import breakmodel + return breakmodel.primary_device + return "cpu" + def _post_load(m_self) -> None: if not utils.koboldai_vars.model_type: @@ -226,7 +237,7 @@ class HFTorchInferenceModel(HFInferenceModel): else: gen_in = prompt_tokens - device = utils.get_auxilary_device() + device = self.get_auxilary_device() gen_in = gen_in.to(device) additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else [] diff --git a/static/application.js b/static/application.js index ca445c5f..ca81f729 100644 --- a/static/application.js +++ b/static/application.js @@ -4012,16 +4012,18 @@ function model_settings_checker() { //get an object of all the input settings from the user data = {} settings_area = document.getElementById(document.getElementById("modelplugin").value + "_settings_area"); - for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) { - var element_data = element.value; - if (element.getAttribute("data_type") == "int") { - element_data = parseInt(element_data); - } else if (element.getAttribute("data_type") == "float") { - element_data = parseFloat(element_data); - } else if (element.getAttribute("data_type") == "bool") { - element_data = (element_data == 'on'); + if (settings_area) { + for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) { + var element_data = element.value; + if (element.getAttribute("data_type") == "int") { + element_data = parseInt(element_data); + } else if (element.getAttribute("data_type") == "float") { + element_data = parseFloat(element_data); + } else if (element.getAttribute("data_type") == "bool") { + element_data = (element_data == 'on'); + } + data[element.id.split("|")[1].replace("_value", "")] = element_data; } - data[element.id.split("|")[1].replace("_value", "")] = element_data; } data = {...data, ...selected_model_data}; @@ -4259,6 +4261,8 @@ function selected_model_info(sent_data) { document.getElementById(document.getElementById("modelplugin").value + "_settings_area").classList.remove("hidden"); } + model_settings_checker(); + } function getModelParameterCount(modelName) { @@ -4371,16 +4375,18 @@ function load_model() { //get an object of all the input settings from the user data = {} - for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) { - var element_data = element.value; - if (element.getAttribute("data_type") == "int") { - element_data = parseInt(element_data); - } else if (element.getAttribute("data_type") == "float") { - element_data = parseFloat(element_data); - } else if (element.getAttribute("data_type") == "bool") { - element_data = (element_data == 'on'); + if (settings_area) { + for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) { + var element_data = element.value; + if (element.getAttribute("data_type") == "int") { + element_data = parseInt(element_data); + } else if (element.getAttribute("data_type") == "float") { + element_data = parseFloat(element_data); + } else if (element.getAttribute("data_type") == "bool") { + element_data = (element_data == 'on'); + } + data[element.id.split("|")[1].replace("_value", "")] = element_data; } - data[element.id.split("|")[1].replace("_value", "")] = element_data; } data = {...data, ...selected_model_data}; diff --git a/static/koboldai.js b/static/koboldai.js index c4b2e160..f0a1f6f8 100644 --- a/static/koboldai.js +++ b/static/koboldai.js @@ -1686,16 +1686,18 @@ function model_settings_checker() { //get an object of all the input settings from the user data = {} settings_area = document.getElementById(document.getElementById("modelplugin").value + "_settings_area"); - for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) { - var element_data = element.value; - if (element.getAttribute("data_type") == "int") { - element_data = parseInt(element_data); - } else if (element.getAttribute("data_type") == "float") { - element_data = parseFloat(element_data); - } else if (element.getAttribute("data_type") == "bool") { - element_data = (element_data == 'on'); + if (settings_area) { + for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) { + var element_data = element.value; + if (element.getAttribute("data_type") == "int") { + element_data = parseInt(element_data); + } else if (element.getAttribute("data_type") == "float") { + element_data = parseFloat(element_data); + } else if (element.getAttribute("data_type") == "bool") { + element_data = (element_data == 'on'); + } + data[element.id.split("|")[1].replace("_value", "")] = element_data; } - data[element.id.split("|")[1].replace("_value", "")] = element_data; } data = {...data, ...selected_model_data}; @@ -1965,16 +1967,18 @@ function load_model() { //get an object of all the input settings from the user data = {} - for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) { - var element_data = element.value; - if (element.getAttribute("data_type") == "int") { - element_data = parseInt(element_data); - } else if (element.getAttribute("data_type") == "float") { - element_data = parseFloat(element_data); - } else if (element.getAttribute("data_type") == "bool") { - element_data = (element_data == 'on'); + if (settings_area) { + for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) { + var element_data = element.value; + if (element.getAttribute("data_type") == "int") { + element_data = parseInt(element_data); + } else if (element.getAttribute("data_type") == "float") { + element_data = parseFloat(element_data); + } else if (element.getAttribute("data_type") == "bool") { + element_data = (element_data == 'on'); + } + data[element.id.split("|")[1].replace("_value", "")] = element_data; } - data[element.id.split("|")[1].replace("_value", "")] = element_data; } data = {...data, ...selected_model_data};