mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Fix for GPU generation
This commit is contained in:
@@ -125,6 +125,17 @@ class HFTorchInferenceModel(HFInferenceModel):
|
|||||||
else:
|
else:
|
||||||
return "Unknown"
|
return "Unknown"
|
||||||
|
|
||||||
|
def get_auxilary_device(self):
|
||||||
|
"""Get device auxilary tensors like inputs should be stored on."""
|
||||||
|
|
||||||
|
# NOTE: TPU isn't a torch device, so TPU stuff gets sent to CPU.
|
||||||
|
if utils.koboldai_vars.hascuda and self.usegpu:
|
||||||
|
return utils.koboldai_vars.gpu_device
|
||||||
|
elif utils.koboldai_vars.hascuda and self.breakmodel:
|
||||||
|
import breakmodel
|
||||||
|
return breakmodel.primary_device
|
||||||
|
return "cpu"
|
||||||
|
|
||||||
def _post_load(m_self) -> None:
|
def _post_load(m_self) -> None:
|
||||||
|
|
||||||
if not utils.koboldai_vars.model_type:
|
if not utils.koboldai_vars.model_type:
|
||||||
@@ -226,7 +237,7 @@ class HFTorchInferenceModel(HFInferenceModel):
|
|||||||
else:
|
else:
|
||||||
gen_in = prompt_tokens
|
gen_in = prompt_tokens
|
||||||
|
|
||||||
device = utils.get_auxilary_device()
|
device = self.get_auxilary_device()
|
||||||
gen_in = gen_in.to(device)
|
gen_in = gen_in.to(device)
|
||||||
|
|
||||||
additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []
|
additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []
|
||||||
|
@@ -4012,16 +4012,18 @@ function model_settings_checker() {
|
|||||||
//get an object of all the input settings from the user
|
//get an object of all the input settings from the user
|
||||||
data = {}
|
data = {}
|
||||||
settings_area = document.getElementById(document.getElementById("modelplugin").value + "_settings_area");
|
settings_area = document.getElementById(document.getElementById("modelplugin").value + "_settings_area");
|
||||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
if (settings_area) {
|
||||||
var element_data = element.value;
|
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||||
if (element.getAttribute("data_type") == "int") {
|
var element_data = element.value;
|
||||||
element_data = parseInt(element_data);
|
if (element.getAttribute("data_type") == "int") {
|
||||||
} else if (element.getAttribute("data_type") == "float") {
|
element_data = parseInt(element_data);
|
||||||
element_data = parseFloat(element_data);
|
} else if (element.getAttribute("data_type") == "float") {
|
||||||
} else if (element.getAttribute("data_type") == "bool") {
|
element_data = parseFloat(element_data);
|
||||||
element_data = (element_data == 'on');
|
} else if (element.getAttribute("data_type") == "bool") {
|
||||||
|
element_data = (element_data == 'on');
|
||||||
|
}
|
||||||
|
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||||
}
|
}
|
||||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
|
||||||
}
|
}
|
||||||
data = {...data, ...selected_model_data};
|
data = {...data, ...selected_model_data};
|
||||||
|
|
||||||
@@ -4259,6 +4261,8 @@ function selected_model_info(sent_data) {
|
|||||||
document.getElementById(document.getElementById("modelplugin").value + "_settings_area").classList.remove("hidden");
|
document.getElementById(document.getElementById("modelplugin").value + "_settings_area").classList.remove("hidden");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
model_settings_checker();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function getModelParameterCount(modelName) {
|
function getModelParameterCount(modelName) {
|
||||||
@@ -4371,16 +4375,18 @@ function load_model() {
|
|||||||
|
|
||||||
//get an object of all the input settings from the user
|
//get an object of all the input settings from the user
|
||||||
data = {}
|
data = {}
|
||||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
if (settings_area) {
|
||||||
var element_data = element.value;
|
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||||
if (element.getAttribute("data_type") == "int") {
|
var element_data = element.value;
|
||||||
element_data = parseInt(element_data);
|
if (element.getAttribute("data_type") == "int") {
|
||||||
} else if (element.getAttribute("data_type") == "float") {
|
element_data = parseInt(element_data);
|
||||||
element_data = parseFloat(element_data);
|
} else if (element.getAttribute("data_type") == "float") {
|
||||||
} else if (element.getAttribute("data_type") == "bool") {
|
element_data = parseFloat(element_data);
|
||||||
element_data = (element_data == 'on');
|
} else if (element.getAttribute("data_type") == "bool") {
|
||||||
|
element_data = (element_data == 'on');
|
||||||
|
}
|
||||||
|
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||||
}
|
}
|
||||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
|
||||||
}
|
}
|
||||||
data = {...data, ...selected_model_data};
|
data = {...data, ...selected_model_data};
|
||||||
|
|
||||||
|
@@ -1686,16 +1686,18 @@ function model_settings_checker() {
|
|||||||
//get an object of all the input settings from the user
|
//get an object of all the input settings from the user
|
||||||
data = {}
|
data = {}
|
||||||
settings_area = document.getElementById(document.getElementById("modelplugin").value + "_settings_area");
|
settings_area = document.getElementById(document.getElementById("modelplugin").value + "_settings_area");
|
||||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
if (settings_area) {
|
||||||
var element_data = element.value;
|
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||||
if (element.getAttribute("data_type") == "int") {
|
var element_data = element.value;
|
||||||
element_data = parseInt(element_data);
|
if (element.getAttribute("data_type") == "int") {
|
||||||
} else if (element.getAttribute("data_type") == "float") {
|
element_data = parseInt(element_data);
|
||||||
element_data = parseFloat(element_data);
|
} else if (element.getAttribute("data_type") == "float") {
|
||||||
} else if (element.getAttribute("data_type") == "bool") {
|
element_data = parseFloat(element_data);
|
||||||
element_data = (element_data == 'on');
|
} else if (element.getAttribute("data_type") == "bool") {
|
||||||
|
element_data = (element_data == 'on');
|
||||||
|
}
|
||||||
|
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||||
}
|
}
|
||||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
|
||||||
}
|
}
|
||||||
data = {...data, ...selected_model_data};
|
data = {...data, ...selected_model_data};
|
||||||
|
|
||||||
@@ -1965,16 +1967,18 @@ function load_model() {
|
|||||||
|
|
||||||
//get an object of all the input settings from the user
|
//get an object of all the input settings from the user
|
||||||
data = {}
|
data = {}
|
||||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
if (settings_area) {
|
||||||
var element_data = element.value;
|
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||||
if (element.getAttribute("data_type") == "int") {
|
var element_data = element.value;
|
||||||
element_data = parseInt(element_data);
|
if (element.getAttribute("data_type") == "int") {
|
||||||
} else if (element.getAttribute("data_type") == "float") {
|
element_data = parseInt(element_data);
|
||||||
element_data = parseFloat(element_data);
|
} else if (element.getAttribute("data_type") == "float") {
|
||||||
} else if (element.getAttribute("data_type") == "bool") {
|
element_data = parseFloat(element_data);
|
||||||
element_data = (element_data == 'on');
|
} else if (element.getAttribute("data_type") == "bool") {
|
||||||
|
element_data = (element_data == 'on');
|
||||||
|
}
|
||||||
|
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||||
}
|
}
|
||||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
|
||||||
}
|
}
|
||||||
data = {...data, ...selected_model_data};
|
data = {...data, ...selected_model_data};
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user