Fix for GPU generation

This commit is contained in:
ebolam
2023-05-23 08:33:19 -04:00
parent a16b540c34
commit 5561cc1f22
3 changed files with 58 additions and 37 deletions

View File

@@ -125,6 +125,17 @@ class HFTorchInferenceModel(HFInferenceModel):
else: else:
return "Unknown" return "Unknown"
def get_auxilary_device(self):
"""Get device auxilary tensors like inputs should be stored on."""
# NOTE: TPU isn't a torch device, so TPU stuff gets sent to CPU.
if utils.koboldai_vars.hascuda and self.usegpu:
return utils.koboldai_vars.gpu_device
elif utils.koboldai_vars.hascuda and self.breakmodel:
import breakmodel
return breakmodel.primary_device
return "cpu"
def _post_load(m_self) -> None: def _post_load(m_self) -> None:
if not utils.koboldai_vars.model_type: if not utils.koboldai_vars.model_type:
@@ -226,7 +237,7 @@ class HFTorchInferenceModel(HFInferenceModel):
else: else:
gen_in = prompt_tokens gen_in = prompt_tokens
device = utils.get_auxilary_device() device = self.get_auxilary_device()
gen_in = gen_in.to(device) gen_in = gen_in.to(device)
additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else [] additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []

View File

@@ -4012,16 +4012,18 @@ function model_settings_checker() {
//get an object of all the input settings from the user //get an object of all the input settings from the user
data = {} data = {}
settings_area = document.getElementById(document.getElementById("modelplugin").value + "_settings_area"); settings_area = document.getElementById(document.getElementById("modelplugin").value + "_settings_area");
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) { if (settings_area) {
var element_data = element.value; for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
if (element.getAttribute("data_type") == "int") { var element_data = element.value;
element_data = parseInt(element_data); if (element.getAttribute("data_type") == "int") {
} else if (element.getAttribute("data_type") == "float") { element_data = parseInt(element_data);
element_data = parseFloat(element_data); } else if (element.getAttribute("data_type") == "float") {
} else if (element.getAttribute("data_type") == "bool") { element_data = parseFloat(element_data);
element_data = (element_data == 'on'); } else if (element.getAttribute("data_type") == "bool") {
element_data = (element_data == 'on');
}
data[element.id.split("|")[1].replace("_value", "")] = element_data;
} }
data[element.id.split("|")[1].replace("_value", "")] = element_data;
} }
data = {...data, ...selected_model_data}; data = {...data, ...selected_model_data};
@@ -4259,6 +4261,8 @@ function selected_model_info(sent_data) {
document.getElementById(document.getElementById("modelplugin").value + "_settings_area").classList.remove("hidden"); document.getElementById(document.getElementById("modelplugin").value + "_settings_area").classList.remove("hidden");
} }
model_settings_checker();
} }
function getModelParameterCount(modelName) { function getModelParameterCount(modelName) {
@@ -4371,16 +4375,18 @@ function load_model() {
//get an object of all the input settings from the user //get an object of all the input settings from the user
data = {} data = {}
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) { if (settings_area) {
var element_data = element.value; for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
if (element.getAttribute("data_type") == "int") { var element_data = element.value;
element_data = parseInt(element_data); if (element.getAttribute("data_type") == "int") {
} else if (element.getAttribute("data_type") == "float") { element_data = parseInt(element_data);
element_data = parseFloat(element_data); } else if (element.getAttribute("data_type") == "float") {
} else if (element.getAttribute("data_type") == "bool") { element_data = parseFloat(element_data);
element_data = (element_data == 'on'); } else if (element.getAttribute("data_type") == "bool") {
element_data = (element_data == 'on');
}
data[element.id.split("|")[1].replace("_value", "")] = element_data;
} }
data[element.id.split("|")[1].replace("_value", "")] = element_data;
} }
data = {...data, ...selected_model_data}; data = {...data, ...selected_model_data};

View File

@@ -1686,16 +1686,18 @@ function model_settings_checker() {
//get an object of all the input settings from the user //get an object of all the input settings from the user
data = {} data = {}
settings_area = document.getElementById(document.getElementById("modelplugin").value + "_settings_area"); settings_area = document.getElementById(document.getElementById("modelplugin").value + "_settings_area");
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) { if (settings_area) {
var element_data = element.value; for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
if (element.getAttribute("data_type") == "int") { var element_data = element.value;
element_data = parseInt(element_data); if (element.getAttribute("data_type") == "int") {
} else if (element.getAttribute("data_type") == "float") { element_data = parseInt(element_data);
element_data = parseFloat(element_data); } else if (element.getAttribute("data_type") == "float") {
} else if (element.getAttribute("data_type") == "bool") { element_data = parseFloat(element_data);
element_data = (element_data == 'on'); } else if (element.getAttribute("data_type") == "bool") {
element_data = (element_data == 'on');
}
data[element.id.split("|")[1].replace("_value", "")] = element_data;
} }
data[element.id.split("|")[1].replace("_value", "")] = element_data;
} }
data = {...data, ...selected_model_data}; data = {...data, ...selected_model_data};
@@ -1965,16 +1967,18 @@ function load_model() {
//get an object of all the input settings from the user //get an object of all the input settings from the user
data = {} data = {}
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) { if (settings_area) {
var element_data = element.value; for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
if (element.getAttribute("data_type") == "int") { var element_data = element.value;
element_data = parseInt(element_data); if (element.getAttribute("data_type") == "int") {
} else if (element.getAttribute("data_type") == "float") { element_data = parseInt(element_data);
element_data = parseFloat(element_data); } else if (element.getAttribute("data_type") == "float") {
} else if (element.getAttribute("data_type") == "bool") { element_data = parseFloat(element_data);
element_data = (element_data == 'on'); } else if (element.getAttribute("data_type") == "bool") {
element_data = (element_data == 'on');
}
data[element.id.split("|")[1].replace("_value", "")] = element_data;
} }
data[element.id.split("|")[1].replace("_value", "")] = element_data;
} }
data = {...data, ...selected_model_data}; data = {...data, ...selected_model_data};