mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Merge pull request #250 from one-some/ui2-atnbiasfix
Fix attention bias aux device
This commit is contained in:
16
aiserver.py
16
aiserver.py
@@ -5054,7 +5054,8 @@ def calcsubmit(txt):
|
||||
bias += [1] * (i - top_index)
|
||||
bias[i] = b["multiplier"]
|
||||
|
||||
attention_bias.attention_bias = torch.Tensor(bias).to(breakmodel.primary_device)
|
||||
device = get_auxilary_device()
|
||||
attention_bias.attention_bias = torch.Tensor(bias).to(device)
|
||||
logger.info(f"Bias by {koboldai_vars.memory_attn_bias} -- {attention_bias.attention_bias}")
|
||||
logger.debug("Submit: experimental_features time {}s".format(time.time()-start_time))
|
||||
|
||||
@@ -5307,6 +5308,13 @@ class GenerationSettings:
|
||||
overrides.get(setting, getattr(koboldai_vars, setting))
|
||||
)
|
||||
|
||||
def get_auxilary_device():
|
||||
# NOTE: Does not include TPU!
|
||||
if koboldai_vars.hascuda and koboldai_vars.usegpu:
|
||||
return koboldai_vars.gpu_device
|
||||
elif koboldai_vars.hascuda and koboldai_vars.breakmodel:
|
||||
return breakmodel.primary_device
|
||||
return "cpu"
|
||||
|
||||
def raw_generate(
|
||||
# prompt is either a string (text) or a list (token ids)
|
||||
@@ -5469,11 +5477,7 @@ def torch_raw_generate(
|
||||
else:
|
||||
gen_in = prompt_tokens
|
||||
|
||||
device = "cpu"
|
||||
if koboldai_vars.hascuda and koboldai_vars.usegpu:
|
||||
device = koboldai_vars.gpu_device
|
||||
elif koboldai_vars.hascuda and koboldai_vars.breakmodel:
|
||||
device = breakmodel.primary_device
|
||||
device = get_auxilary_device()
|
||||
gen_in = gen_in.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
|
Reference in New Issue
Block a user