Merge pull request #250 from one-some/ui2-atnbiasfix

Fix attention bias aux device
This commit is contained in:
ebolam
2022-10-27 07:36:41 -04:00
committed by GitHub

View File

@@ -5054,7 +5054,8 @@ def calcsubmit(txt):
bias += [1] * (i - top_index) bias += [1] * (i - top_index)
bias[i] = b["multiplier"] bias[i] = b["multiplier"]
attention_bias.attention_bias = torch.Tensor(bias).to(breakmodel.primary_device) device = get_auxilary_device()
attention_bias.attention_bias = torch.Tensor(bias).to(device)
logger.info(f"Bias by {koboldai_vars.memory_attn_bias} -- {attention_bias.attention_bias}") logger.info(f"Bias by {koboldai_vars.memory_attn_bias} -- {attention_bias.attention_bias}")
logger.debug("Submit: experimental_features time {}s".format(time.time()-start_time)) logger.debug("Submit: experimental_features time {}s".format(time.time()-start_time))
@@ -5307,6 +5308,13 @@ class GenerationSettings:
overrides.get(setting, getattr(koboldai_vars, setting)) overrides.get(setting, getattr(koboldai_vars, setting))
) )
def get_auxilary_device():
# NOTE: Does not include TPU!
if koboldai_vars.hascuda and koboldai_vars.usegpu:
return koboldai_vars.gpu_device
elif koboldai_vars.hascuda and koboldai_vars.breakmodel:
return breakmodel.primary_device
return "cpu"
def raw_generate( def raw_generate(
# prompt is either a string (text) or a list (token ids) # prompt is either a string (text) or a list (token ids)
@@ -5469,11 +5477,7 @@ def torch_raw_generate(
else: else:
gen_in = prompt_tokens gen_in = prompt_tokens
device = "cpu" device = get_auxilary_device()
if koboldai_vars.hascuda and koboldai_vars.usegpu:
device = koboldai_vars.gpu_device
elif koboldai_vars.hascuda and koboldai_vars.breakmodel:
device = breakmodel.primary_device
gen_in = gen_in.to(device) gen_in = gen_in.to(device)
with torch.no_grad(): with torch.no_grad():