mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Make sure that soft_tokens is on the correct device
This commit is contained in:
@ -1471,6 +1471,7 @@ def generate(txt, min, max):
|
|||||||
soft_tokens = torch.arange(
|
soft_tokens = torch.arange(
|
||||||
model.config.vocab_size,
|
model.config.vocab_size,
|
||||||
model.config.vocab_size + vars.sp.shape[0],
|
model.config.vocab_size + vars.sp.shape[0],
|
||||||
|
device=genout.device,
|
||||||
)
|
)
|
||||||
gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
|
gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
|
||||||
diff = gen_in.shape[-1] - genout.shape[-1]
|
diff = gen_in.shape[-1] - genout.shape[-1]
|
||||||
|
Reference in New Issue
Block a user