Use torch.where to inject the soft prompt instead of torch.cat

Gnome Ann 2021-10-28 13:20:14 -04:00
parent 248e0bd24b
commit 1556bd32a5
1 changed file with 45 additions and 23 deletions

@@ -515,17 +515,25 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
     def patch_causallm(cls):
         old_forward = cls.forward
         def new_causallm_forward(self, *args, **kwargs):
+            num_embeddings = self.config.vocab_size
+            if(vars.sp is not None):
+                num_embeddings += vars.sp.shape[0]
+            if(self.transformer.wte.num_embeddings != num_embeddings):
+                self.resize_token_embeddings(num_embeddings)
             input_ids = kwargs.get('input_ids').to(self.device)
             assert input_ids is not None
             kwargs['input_ids'] = None
             inputs_embeds = self.transformer.wte(input_ids)
+            input_ids -= self.config.vocab_size
             if(vars.sp is not None):
-                inputs_embeds = torch.cat((
-                    vars.sp.tile((inputs_embeds.shape[0], 1, 1)),
-                    inputs_embeds
-                ), dim=1).to(self.device)
+                vars.sp = vars.sp.to(inputs_embeds.device)
+                inputs_embeds = torch.where(
+                    (input_ids >= 0)[:, :, None],
+                    vars.sp[input_ids.clamp(min=0)],
+                    inputs_embeds,
+                )
             kwargs['inputs_embeds'] = inputs_embeds
-            return old_forward(*args, **kwargs)
+            return old_forward(self, *args, **kwargs)
         cls.forward = new_causallm_forward
     for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
         patch_causallm(cls)
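
The injection trick in this hunk can be reproduced in isolation. The sketch below uses toy sizes and a random soft_prompt tensor standing in for vars.sp (none of these names or sizes come from KoboldAI): ids at or above the real vocabulary size mark soft-prompt positions, so after subtracting vocab_size the mask shifted >= 0 tells torch.where which rows to replace.

import torch

vocab_size, d_model, sp_len = 100, 8, 3                  # toy sizes, not KoboldAI's
wte = torch.nn.Embedding(vocab_size + sp_len, d_model)   # embedding table after resize_token_embeddings
soft_prompt = torch.randn(sp_len, d_model)               # stands in for vars.sp

# ids >= vocab_size mark soft-prompt positions, the rest are ordinary tokens
input_ids = torch.tensor([[100, 101, 102, 5, 7, 9]])

inputs_embeds = wte(input_ids)                # placeholder rows for the soft ids
shifted = input_ids - vocab_size              # soft positions become >= 0
inputs_embeds = torch.where(
    (shifted >= 0)[:, :, None],               # broadcast the mask over the embedding dim
    soft_prompt[shifted.clamp(min=0)],        # soft-prompt row where the mask is True
    inputs_embeds,                            # original embedding everywhere else
)
assert torch.equal(inputs_embeds[0, 0], soft_prompt[0])

Unlike the old torch.cat approach, the result keeps the same sequence length as input_ids, so the soft prompt now occupies real token positions instead of being glued on at the embedding level.
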
@@ -543,13 +551,14 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
         # Is CUDA available? If so, use GPU, otherwise fall back to CPU
         if(vars.hascuda):
             if(vars.usegpu):
-                generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=0)
+                model = model.to(0)
+                generator = model.generate
             elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
                 device_config(model)
             else:
-                generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
+                generator = model.generate
         else:
-            generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
+            generator = model.generate
     # If custom GPT2 model was chosen
     elif(vars.model == "GPT2Custom"):
         model_config = open(vars.custmodpth + "/config.json", "r")
@@ -559,9 +568,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
         tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth)
         # Is CUDA available? If so, use GPU, otherwise fall back to CPU
        if(vars.hascuda and vars.usegpu):
-            generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=0)
+            model = model.to(0)
+            generator = model.generate
         else:
-            generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
+            generator = model.generate
     # If base HuggingFace model was chosen
     else:
         # Is CUDA available? If so, use GPU, otherwise fall back to CPU
@@ -570,16 +580,17 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
             if(vars.usegpu):
                 model = AutoModelForCausalLM.from_pretrained(vars.model, device=0)
                 vars.modeldim = int(model.transformer.hidden_size)
-                generator = pipeline('text-generation', model=model, device=0)
+                model = model.to(0)
+                generator = model.generate
             elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
                 model = AutoModelForCausalLM.from_pretrained(vars.model)
                 device_config(model)
             else:
                 model = AutoModelForCausalLM.from_pretrained(vars.model)
-                generator = pipeline('text-generation', model=vars.model)
+                generator = model.generate
         else:
             model = AutoModelForCausalLM.from_pretrained(vars.model)
-            generator = pipeline('text-generation', model=vars.model)
+            generator = model.generate
     # Suppress Author's Note by flagging square brackets (Old implementation)
     #vocab = tokenizer.get_vocab()
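
The three hunks above drop the transformers text-generation pipeline in favour of calling model.generate directly, which takes a tensor of token ids rather than a string. A minimal standalone sketch of that switch (the "gpt2" checkpoint and the generation arguments are only examples, not what KoboldAI loads):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")        # example checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
if torch.cuda.is_available():
    model = model.to(0)                                  # mirrors model = model.to(0) above

generator = model.generate                               # bound method, used like the old pipeline
gen_in = tokenizer.encode("Hello", return_tensors="pt").to(model.device)
genout = generator(gen_in, do_sample=True, max_length=20)
print(tokenizer.decode(genout[0]))
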
@@ -1176,7 +1187,10 @@ def calcsubmit(txt):
             budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt
         else:
             budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt
+        if(vars.sp is not None):
+            budget -= vars.sp.shape[0]
         if(actionlen == 0):
             # First/Prompt action
             subtxt = vars.memory + winfo + anotetxt + vars.prompt
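
The calcsubmit change just reserves room in the token budget for the soft prompt, since its tokens now occupy real positions in the sequence passed to generate(). With made-up numbers (none of these are KoboldAI defaults):

max_length, genamt = 2048, 80                # illustrative settings
lnsp, lnmem, lnanote, lnwi = 0, 120, 40, 60  # illustrative token counts
sp_tokens = 20                               # vars.sp.shape[0] when a soft prompt is loaded

budget = max_length - lnsp - lnmem - lnanote - lnwi - genamt   # 1748
budget -= sp_tokens                                            # 1728 tokens left for the story text
print(budget)
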
@@ -1336,14 +1350,23 @@ def generate(txt, min, max):
     top_k = vars.top_k if vars.top_k > 0 else None
     tfs = vars.tfs if vars.tfs > 0.0 else None
-    # generator() only accepts a torch tensor of tokens (long datatype) as
-    # its first argument if we're using breakmodel, otherwise a string
-    # is fine
-    if(vars.hascuda and vars.breakmodel):
-        gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.primary_device)
+    gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long()
+    if(vars.sp is not None):
+        soft_tokens = torch.arange(
+            model.config.vocab_size,
+            model.config.vocab_size + vars.sp.shape[0],
+        )
+        gen_in = torch.cat((soft_tokens[None], gen_in), dim=1)
+    if(vars.hascuda and vars.usegpu):
+        gen_in = gen_in.to(0)
+    elif(vars.hascuda and vars.breakmodel):
+        gen_in = gen_in.to(breakmodel.primary_device)
+    elif(vars.hascuda):
+        gen_in = gen_in.to(0)
     else:
-        gen_in = txt
+        gen_in = gen_in.to('cpu')
     with torch.no_grad():
         genout = generator(
             gen_in,
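
The prepended soft_tokens in this hunk are simply the ids just past the real vocabulary, which is what makes them land on the rows that the patched forward pass swaps for vars.sp. A toy version with made-up sizes:

import torch

vocab_size, sp_len = 100, 3                      # toy sizes
gen_in = torch.tensor([[5, 7, 9]])               # encoded prompt as a long tensor

soft_tokens = torch.arange(vocab_size, vocab_size + sp_len)
gen_in = torch.cat((soft_tokens[None], gen_in), dim=1)
print(gen_in)    # -> [[100, 101, 102, 5, 7, 9]]
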
@@ -1367,8 +1390,7 @@ def generate(txt, min, max):
         return
     # Need to manually strip and decode tokens if we're not using a pipeline
-    if(vars.hascuda and vars.breakmodel):
-        genout = [{"generated_text": tokenizer.decode(tokens[len(gen_in[0])-len(tokens):])} for tokens in genout]
+    genout = [{"generated_text": tokenizer.decode(tokens[len(gen_in[0])-len(tokens):])} for tokens in genout]
     if(len(genout) == 1):
         genresult(genout[0]["generated_text"])
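
The slice in this last hunk keeps only the tokens that generate() appended after the input (prompt plus prepended soft tokens): len(gen_in[0]) - len(tokens) is negative, so it indexes from the end of the output. A quick illustration with made-up lengths:

prompt_len = 4                                  # stands in for len(gen_in[0])
tokens = list(range(10))                        # pretend generate() returned 10 token ids
new_tokens = tokens[prompt_len - len(tokens):]  # == tokens[-6:], i.e. the generated part
assert new_tokens == tokens[prompt_len:]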