mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Fix for GPU generation
This commit is contained in:
@@ -125,6 +125,17 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
else:
|
||||
return "Unknown"
|
||||
|
||||
def get_auxilary_device(self):
|
||||
"""Get device auxilary tensors like inputs should be stored on."""
|
||||
|
||||
# NOTE: TPU isn't a torch device, so TPU stuff gets sent to CPU.
|
||||
if utils.koboldai_vars.hascuda and self.usegpu:
|
||||
return utils.koboldai_vars.gpu_device
|
||||
elif utils.koboldai_vars.hascuda and self.breakmodel:
|
||||
import breakmodel
|
||||
return breakmodel.primary_device
|
||||
return "cpu"
|
||||
|
||||
def _post_load(m_self) -> None:
|
||||
|
||||
if not utils.koboldai_vars.model_type:
|
||||
@@ -226,7 +237,7 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
else:
|
||||
gen_in = prompt_tokens
|
||||
|
||||
device = utils.get_auxilary_device()
|
||||
device = self.get_auxilary_device()
|
||||
gen_in = gen_in.to(device)
|
||||
|
||||
additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []
|
||||
|
@@ -4012,16 +4012,18 @@ function model_settings_checker() {
|
||||
//get an object of all the input settings from the user
|
||||
data = {}
|
||||
settings_area = document.getElementById(document.getElementById("modelplugin").value + "_settings_area");
|
||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||
var element_data = element.value;
|
||||
if (element.getAttribute("data_type") == "int") {
|
||||
element_data = parseInt(element_data);
|
||||
} else if (element.getAttribute("data_type") == "float") {
|
||||
element_data = parseFloat(element_data);
|
||||
} else if (element.getAttribute("data_type") == "bool") {
|
||||
element_data = (element_data == 'on');
|
||||
if (settings_area) {
|
||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||
var element_data = element.value;
|
||||
if (element.getAttribute("data_type") == "int") {
|
||||
element_data = parseInt(element_data);
|
||||
} else if (element.getAttribute("data_type") == "float") {
|
||||
element_data = parseFloat(element_data);
|
||||
} else if (element.getAttribute("data_type") == "bool") {
|
||||
element_data = (element_data == 'on');
|
||||
}
|
||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||
}
|
||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||
}
|
||||
data = {...data, ...selected_model_data};
|
||||
|
||||
@@ -4259,6 +4261,8 @@ function selected_model_info(sent_data) {
|
||||
document.getElementById(document.getElementById("modelplugin").value + "_settings_area").classList.remove("hidden");
|
||||
}
|
||||
|
||||
model_settings_checker();
|
||||
|
||||
}
|
||||
|
||||
function getModelParameterCount(modelName) {
|
||||
@@ -4371,16 +4375,18 @@ function load_model() {
|
||||
|
||||
//get an object of all the input settings from the user
|
||||
data = {}
|
||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||
var element_data = element.value;
|
||||
if (element.getAttribute("data_type") == "int") {
|
||||
element_data = parseInt(element_data);
|
||||
} else if (element.getAttribute("data_type") == "float") {
|
||||
element_data = parseFloat(element_data);
|
||||
} else if (element.getAttribute("data_type") == "bool") {
|
||||
element_data = (element_data == 'on');
|
||||
if (settings_area) {
|
||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||
var element_data = element.value;
|
||||
if (element.getAttribute("data_type") == "int") {
|
||||
element_data = parseInt(element_data);
|
||||
} else if (element.getAttribute("data_type") == "float") {
|
||||
element_data = parseFloat(element_data);
|
||||
} else if (element.getAttribute("data_type") == "bool") {
|
||||
element_data = (element_data == 'on');
|
||||
}
|
||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||
}
|
||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||
}
|
||||
data = {...data, ...selected_model_data};
|
||||
|
||||
|
@@ -1686,16 +1686,18 @@ function model_settings_checker() {
|
||||
//get an object of all the input settings from the user
|
||||
data = {}
|
||||
settings_area = document.getElementById(document.getElementById("modelplugin").value + "_settings_area");
|
||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||
var element_data = element.value;
|
||||
if (element.getAttribute("data_type") == "int") {
|
||||
element_data = parseInt(element_data);
|
||||
} else if (element.getAttribute("data_type") == "float") {
|
||||
element_data = parseFloat(element_data);
|
||||
} else if (element.getAttribute("data_type") == "bool") {
|
||||
element_data = (element_data == 'on');
|
||||
if (settings_area) {
|
||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||
var element_data = element.value;
|
||||
if (element.getAttribute("data_type") == "int") {
|
||||
element_data = parseInt(element_data);
|
||||
} else if (element.getAttribute("data_type") == "float") {
|
||||
element_data = parseFloat(element_data);
|
||||
} else if (element.getAttribute("data_type") == "bool") {
|
||||
element_data = (element_data == 'on');
|
||||
}
|
||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||
}
|
||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||
}
|
||||
data = {...data, ...selected_model_data};
|
||||
|
||||
@@ -1965,16 +1967,18 @@ function load_model() {
|
||||
|
||||
//get an object of all the input settings from the user
|
||||
data = {}
|
||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||
var element_data = element.value;
|
||||
if (element.getAttribute("data_type") == "int") {
|
||||
element_data = parseInt(element_data);
|
||||
} else if (element.getAttribute("data_type") == "float") {
|
||||
element_data = parseFloat(element_data);
|
||||
} else if (element.getAttribute("data_type") == "bool") {
|
||||
element_data = (element_data == 'on');
|
||||
if (settings_area) {
|
||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||
var element_data = element.value;
|
||||
if (element.getAttribute("data_type") == "int") {
|
||||
element_data = parseInt(element_data);
|
||||
} else if (element.getAttribute("data_type") == "float") {
|
||||
element_data = parseFloat(element_data);
|
||||
} else if (element.getAttribute("data_type") == "bool") {
|
||||
element_data = (element_data == 'on');
|
||||
}
|
||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||
}
|
||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||
}
|
||||
data = {...data, ...selected_model_data};
|
||||
|
||||
|
Reference in New Issue
Block a user