mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Fix for GPU generation
This commit is contained in:
@@ -125,6 +125,17 @@ class HFTorchInferenceModel(HFInferenceModel):
|
|||||||
else:
|
else:
|
||||||
return "Unknown"
|
return "Unknown"
|
||||||
|
|
||||||
|
def get_auxilary_device(self):
    """Pick the device that auxiliary tensors (such as inputs) should live on.

    Returns:
        ``utils.koboldai_vars.gpu_device`` when CUDA is available and
        ``self.usegpu`` is set; ``breakmodel.primary_device`` when CUDA is
        available and ``self.breakmodel`` is set; otherwise the string
        ``"cpu"``.
    """
    # NOTE: TPU isn't a torch device, so TPU stuff gets sent to CPU.
    if not utils.koboldai_vars.hascuda:
        # Without CUDA neither GPU mode applies; fall back to the CPU.
        return "cpu"

    if self.usegpu:
        return utils.koboldai_vars.gpu_device

    if self.breakmodel:
        # Imported lazily so the module is only loaded when this branch runs.
        import breakmodel

        return breakmodel.primary_device

    return "cpu"
|
||||||
|
|
||||||
def _post_load(m_self) -> None:
|
def _post_load(m_self) -> None:
|
||||||
|
|
||||||
if not utils.koboldai_vars.model_type:
|
if not utils.koboldai_vars.model_type:
|
||||||
@@ -226,7 +237,7 @@ class HFTorchInferenceModel(HFInferenceModel):
|
|||||||
else:
|
else:
|
||||||
gen_in = prompt_tokens
|
gen_in = prompt_tokens
|
||||||
|
|
||||||
device = utils.get_auxilary_device()
|
device = self.get_auxilary_device()
|
||||||
gen_in = gen_in.to(device)
|
gen_in = gen_in.to(device)
|
||||||
|
|
||||||
additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []
|
additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []
|
||||||
|
@@ -4012,6 +4012,7 @@ function model_settings_checker() {
|
|||||||
//get an object of all the input settings from the user
|
//get an object of all the input settings from the user
|
||||||
data = {}
|
data = {}
|
||||||
settings_area = document.getElementById(document.getElementById("modelplugin").value + "_settings_area");
|
settings_area = document.getElementById(document.getElementById("modelplugin").value + "_settings_area");
|
||||||
|
if (settings_area) {
|
||||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||||
var element_data = element.value;
|
var element_data = element.value;
|
||||||
if (element.getAttribute("data_type") == "int") {
|
if (element.getAttribute("data_type") == "int") {
|
||||||
@@ -4023,6 +4024,7 @@ function model_settings_checker() {
|
|||||||
}
|
}
|
||||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
data = {...data, ...selected_model_data};
|
data = {...data, ...selected_model_data};
|
||||||
|
|
||||||
data['plugin'] = document.getElementById("modelplugin").value;
|
data['plugin'] = document.getElementById("modelplugin").value;
|
||||||
@@ -4259,6 +4261,8 @@ function selected_model_info(sent_data) {
|
|||||||
document.getElementById(document.getElementById("modelplugin").value + "_settings_area").classList.remove("hidden");
|
document.getElementById(document.getElementById("modelplugin").value + "_settings_area").classList.remove("hidden");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
model_settings_checker();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function getModelParameterCount(modelName) {
|
function getModelParameterCount(modelName) {
|
||||||
@@ -4371,6 +4375,7 @@ function load_model() {
|
|||||||
|
|
||||||
//get an object of all the input settings from the user
|
//get an object of all the input settings from the user
|
||||||
data = {}
|
data = {}
|
||||||
|
if (settings_area) {
|
||||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||||
var element_data = element.value;
|
var element_data = element.value;
|
||||||
if (element.getAttribute("data_type") == "int") {
|
if (element.getAttribute("data_type") == "int") {
|
||||||
@@ -4382,6 +4387,7 @@ function load_model() {
|
|||||||
}
|
}
|
||||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
data = {...data, ...selected_model_data};
|
data = {...data, ...selected_model_data};
|
||||||
|
|
||||||
data['plugin'] = document.getElementById("modelplugin").value;
|
data['plugin'] = document.getElementById("modelplugin").value;
|
||||||
|
@@ -1686,6 +1686,7 @@ function model_settings_checker() {
|
|||||||
//get an object of all the input settings from the user
|
//get an object of all the input settings from the user
|
||||||
data = {}
|
data = {}
|
||||||
settings_area = document.getElementById(document.getElementById("modelplugin").value + "_settings_area");
|
settings_area = document.getElementById(document.getElementById("modelplugin").value + "_settings_area");
|
||||||
|
if (settings_area) {
|
||||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||||
var element_data = element.value;
|
var element_data = element.value;
|
||||||
if (element.getAttribute("data_type") == "int") {
|
if (element.getAttribute("data_type") == "int") {
|
||||||
@@ -1697,6 +1698,7 @@ function model_settings_checker() {
|
|||||||
}
|
}
|
||||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
data = {...data, ...selected_model_data};
|
data = {...data, ...selected_model_data};
|
||||||
|
|
||||||
data['plugin'] = document.getElementById("modelplugin").value;
|
data['plugin'] = document.getElementById("modelplugin").value;
|
||||||
@@ -1965,6 +1967,7 @@ function load_model() {
|
|||||||
|
|
||||||
//get an object of all the input settings from the user
|
//get an object of all the input settings from the user
|
||||||
data = {}
|
data = {}
|
||||||
|
if (settings_area) {
|
||||||
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
|
||||||
var element_data = element.value;
|
var element_data = element.value;
|
||||||
if (element.getAttribute("data_type") == "int") {
|
if (element.getAttribute("data_type") == "int") {
|
||||||
@@ -1976,6 +1979,7 @@ function load_model() {
|
|||||||
}
|
}
|
||||||
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
data[element.id.split("|")[1].replace("_value", "")] = element_data;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
data = {...data, ...selected_model_data};
|
data = {...data, ...selected_model_data};
|
||||||
|
|
||||||
data['plugin'] = document.getElementById("modelplugin").value;
|
data['plugin'] = document.getElementById("modelplugin").value;
|
||||||
|
Reference in New Issue
Block a user