mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Fix for GPU generation
This commit is contained in:
@@ -125,6 +125,17 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
else:
|
||||
return "Unknown"
|
||||
|
||||
def get_auxilary_device(self):
|
||||
"""Get device auxilary tensors like inputs should be stored on."""
|
||||
|
||||
# NOTE: TPU isn't a torch device, so TPU stuff gets sent to CPU.
|
||||
if utils.koboldai_vars.hascuda and self.usegpu:
|
||||
return utils.koboldai_vars.gpu_device
|
||||
elif utils.koboldai_vars.hascuda and self.breakmodel:
|
||||
import breakmodel
|
||||
return breakmodel.primary_device
|
||||
return "cpu"
|
||||
|
||||
def _post_load(m_self) -> None:
|
||||
|
||||
if not utils.koboldai_vars.model_type:
|
||||
@@ -226,7 +237,7 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
else:
|
||||
gen_in = prompt_tokens
|
||||
|
||||
device = utils.get_auxilary_device()
|
||||
device = self.get_auxilary_device()
|
||||
gen_in = gen_in.to(device)
|
||||
|
||||
additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []
|
||||
|
@@ -4012,6 +4012,7 @@ function model_settings_checker() {
|
||||
//get an object of all the input settings from the user
|
||||
data = {}
|
||||
settings_area = document.getElementById(document.getElementById("modelplugin").value + "_settings_area");
|
||||
if (settings_area) {
|
||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||
var element_data = element.value;
|
||||
if (element.getAttribute("data_type") == "int") {
|
||||
@@ -4023,6 +4024,7 @@ function model_settings_checker() {
|
||||
}
|
||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||
}
|
||||
}
|
||||
data = {...data, ...selected_model_data};
|
||||
|
||||
data['plugin'] = document.getElementById("modelplugin").value;
|
||||
@@ -4259,6 +4261,8 @@ function selected_model_info(sent_data) {
|
||||
document.getElementById(document.getElementById("modelplugin").value + "_settings_area").classList.remove("hidden");
|
||||
}
|
||||
|
||||
model_settings_checker();
|
||||
|
||||
}
|
||||
|
||||
function getModelParameterCount(modelName) {
|
||||
@@ -4371,6 +4375,7 @@ function load_model() {
|
||||
|
||||
//get an object of all the input settings from the user
|
||||
data = {}
|
||||
if (settings_area) {
|
||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||
var element_data = element.value;
|
||||
if (element.getAttribute("data_type") == "int") {
|
||||
@@ -4382,6 +4387,7 @@ function load_model() {
|
||||
}
|
||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||
}
|
||||
}
|
||||
data = {...data, ...selected_model_data};
|
||||
|
||||
data['plugin'] = document.getElementById("modelplugin").value;
|
||||
|
@@ -1686,6 +1686,7 @@ function model_settings_checker() {
|
||||
//get an object of all the input settings from the user
|
||||
data = {}
|
||||
settings_area = document.getElementById(document.getElementById("modelplugin").value + "_settings_area");
|
||||
if (settings_area) {
|
||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||
var element_data = element.value;
|
||||
if (element.getAttribute("data_type") == "int") {
|
||||
@@ -1697,6 +1698,7 @@ function model_settings_checker() {
|
||||
}
|
||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||
}
|
||||
}
|
||||
data = {...data, ...selected_model_data};
|
||||
|
||||
data['plugin'] = document.getElementById("modelplugin").value;
|
||||
@@ -1965,6 +1967,7 @@ function load_model() {
|
||||
|
||||
//get an object of all the input settings from the user
|
||||
data = {}
|
||||
if (settings_area) {
|
||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||
var element_data = element.value;
|
||||
if (element.getAttribute("data_type") == "int") {
|
||||
@@ -1976,6 +1979,7 @@ function load_model() {
|
||||
}
|
||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||
}
|
||||
}
|
||||
data = {...data, ...selected_model_data};
|
||||
|
||||
data['plugin'] = document.getElementById("modelplugin").value;
|
||||
|
Reference in New Issue
Block a user