mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Merge pull request #250 from one-some/ui2-atnbiasfix
Fix attention bias aux device
This commit is contained in:
16
aiserver.py
16
aiserver.py
@@ -5054,7 +5054,8 @@ def calcsubmit(txt):
|
|||||||
bias += [1] * (i - top_index)
|
bias += [1] * (i - top_index)
|
||||||
bias[i] = b["multiplier"]
|
bias[i] = b["multiplier"]
|
||||||
|
|
||||||
attention_bias.attention_bias = torch.Tensor(bias).to(breakmodel.primary_device)
|
device = get_auxilary_device()
|
||||||
|
attention_bias.attention_bias = torch.Tensor(bias).to(device)
|
||||||
logger.info(f"Bias by {koboldai_vars.memory_attn_bias} -- {attention_bias.attention_bias}")
|
logger.info(f"Bias by {koboldai_vars.memory_attn_bias} -- {attention_bias.attention_bias}")
|
||||||
logger.debug("Submit: experimental_features time {}s".format(time.time()-start_time))
|
logger.debug("Submit: experimental_features time {}s".format(time.time()-start_time))
|
||||||
|
|
||||||
@@ -5307,6 +5308,13 @@ class GenerationSettings:
|
|||||||
overrides.get(setting, getattr(koboldai_vars, setting))
|
overrides.get(setting, getattr(koboldai_vars, setting))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_auxilary_device():
    """Return the device chosen by the current CUDA/GPU/breakmodel settings.

    Returns:
        ``koboldai_vars.gpu_device`` when ``hascuda`` and ``usegpu`` are both
        set; ``breakmodel.primary_device`` when ``hascuda`` and ``breakmodel``
        are set instead; the string ``"cpu"`` in every other case.
    """
    # NOTE: Does not include TPU!
    if not koboldai_vars.hascuda:
        # Without CUDA neither GPU branch can apply.
        return "cpu"

    # Single-GPU mode is checked first, so it wins over breakmodel.
    if koboldai_vars.usegpu:
        return koboldai_vars.gpu_device

    if koboldai_vars.breakmodel:
        return breakmodel.primary_device

    return "cpu"
|
||||||
|
|
||||||
def raw_generate(
|
def raw_generate(
|
||||||
# prompt is either a string (text) or a list (token ids)
|
# prompt is either a string (text) or a list (token ids)
|
||||||
@@ -5469,11 +5477,7 @@ def torch_raw_generate(
|
|||||||
else:
|
else:
|
||||||
gen_in = prompt_tokens
|
gen_in = prompt_tokens
|
||||||
|
|
||||||
device = "cpu"
|
device = get_auxilary_device()
|
||||||
if koboldai_vars.hascuda and koboldai_vars.usegpu:
|
|
||||||
device = koboldai_vars.gpu_device
|
|
||||||
elif koboldai_vars.hascuda and koboldai_vars.breakmodel:
|
|
||||||
device = breakmodel.primary_device
|
|
||||||
gen_in = gen_in.to(device)
|
gen_in = gen_in.to(device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
Reference in New Issue
Block a user