Use torch.where to inject the soft prompt instead of torch.cat
This commit is contained in:
parent
248e0bd24b
commit
1556bd32a5
68
aiserver.py
68
aiserver.py
|
@ -515,17 +515,25 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
|||
def patch_causallm(cls):
|
||||
old_forward = cls.forward
|
||||
def new_causallm_forward(self, *args, **kwargs):
|
||||
num_embeddings = self.config.vocab_size
|
||||
if(vars.sp is not None):
|
||||
num_embeddings += vars.sp.shape[0]
|
||||
if(self.transformer.wte.num_embeddings != num_embeddings):
|
||||
self.resize_token_embeddings(num_embeddings)
|
||||
input_ids = kwargs.get('input_ids').to(self.device)
|
||||
assert input_ids is not None
|
||||
kwargs['input_ids'] = None
|
||||
inputs_embeds = self.transformer.wte(input_ids)
|
||||
input_ids -= self.config.vocab_size
|
||||
if(vars.sp is not None):
|
||||
inputs_embeds = torch.cat((
|
||||
vars.sp.tile((inputs_embeds.shape[0], 1, 1)),
|
||||
inputs_embeds
|
||||
), dim=1).to(self.device)
|
||||
vars.sp = vars.sp.to(inputs_embeds.device)
|
||||
inputs_embeds = torch.where(
|
||||
(input_ids >= 0)[:, :, None],
|
||||
vars.sp[input_ids.clamp(min=0)],
|
||||
inputs_embeds,
|
||||
)
|
||||
kwargs['inputs_embeds'] = inputs_embeds
|
||||
return old_forward(*args, **kwargs)
|
||||
return old_forward(self, *args, **kwargs)
|
||||
cls.forward = new_causallm_forward
|
||||
for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
|
||||
patch_causallm(cls)
|
||||
|
@ -543,13 +551,14 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
|||
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
|
||||
if(vars.hascuda):
|
||||
if(vars.usegpu):
|
||||
generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=0)
|
||||
model = model.to(0)
|
||||
generator = model.generate
|
||||
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
|
||||
device_config(model)
|
||||
else:
|
||||
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
|
||||
generator = model.generate
|
||||
else:
|
||||
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
|
||||
generator = model.generate
|
||||
# If custom GPT2 model was chosen
|
||||
elif(vars.model == "GPT2Custom"):
|
||||
model_config = open(vars.custmodpth + "/config.json", "r")
|
||||
|
@ -559,9 +568,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
|||
tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth)
|
||||
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
|
||||
if(vars.hascuda and vars.usegpu):
|
||||
generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=0)
|
||||
model = model.to(0)
|
||||
generator = model.generate
|
||||
else:
|
||||
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
|
||||
generator = model.generate
|
||||
# If base HuggingFace model was chosen
|
||||
else:
|
||||
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
|
||||
|
@ -570,16 +580,17 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
|||
if(vars.usegpu):
|
||||
model = AutoModelForCausalLM.from_pretrained(vars.model, device=0)
|
||||
vars.modeldim = int(model.transformer.hidden_size)
|
||||
generator = pipeline('text-generation', model=model, device=0)
|
||||
model = model.to(0)
|
||||
generator = model.generate
|
||||
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
|
||||
model = AutoModelForCausalLM.from_pretrained(vars.model)
|
||||
device_config(model)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(vars.model)
|
||||
generator = pipeline('text-generation', model=vars.model)
|
||||
generator = model.generate
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(vars.model)
|
||||
generator = pipeline('text-generation', model=vars.model)
|
||||
generator = model.generate
|
||||
|
||||
# Suppress Author's Note by flagging square brackets (Old implementation)
|
||||
#vocab = tokenizer.get_vocab()
|
||||
|
@ -1176,7 +1187,10 @@ def calcsubmit(txt):
|
|||
budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt
|
||||
else:
|
||||
budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt
|
||||
|
||||
|
||||
if(vars.sp is not None):
|
||||
budget -= vars.sp.shape[0]
|
||||
|
||||
if(actionlen == 0):
|
||||
# First/Prompt action
|
||||
subtxt = vars.memory + winfo + anotetxt + vars.prompt
|
||||
|
@ -1336,14 +1350,23 @@ def generate(txt, min, max):
|
|||
top_k = vars.top_k if vars.top_k > 0 else None
|
||||
tfs = vars.tfs if vars.tfs > 0.0 else None
|
||||
|
||||
# generator() only accepts a torch tensor of tokens (long datatype) as
|
||||
# its first argument if we're using breakmodel, otherwise a string
|
||||
# is fine
|
||||
if(vars.hascuda and vars.breakmodel):
|
||||
gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.primary_device)
|
||||
gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long()
|
||||
if(vars.sp is not None):
|
||||
soft_tokens = torch.arange(
|
||||
model.config.vocab_size,
|
||||
model.config.vocab_size + vars.sp.shape[0],
|
||||
)
|
||||
gen_in = torch.cat((soft_tokens[None], gen_in), dim=1)
|
||||
|
||||
if(vars.hascuda and vars.usegpu):
|
||||
gen_in = gen_in.to(0)
|
||||
elif(vars.hascuda and vars.breakmodel):
|
||||
gen_in = gen_in.to(breakmodel.primary_device)
|
||||
elif(vars.hascuda):
|
||||
gen_in = gen_in.to(0)
|
||||
else:
|
||||
gen_in = txt
|
||||
|
||||
gen_in = gen_in.to('cpu')
|
||||
|
||||
with torch.no_grad():
|
||||
genout = generator(
|
||||
gen_in,
|
||||
|
@ -1367,8 +1390,7 @@ def generate(txt, min, max):
|
|||
return
|
||||
|
||||
# Need to manually strip and decode tokens if we're not using a pipeline
|
||||
if(vars.hascuda and vars.breakmodel):
|
||||
genout = [{"generated_text": tokenizer.decode(tokens[len(gen_in[0])-len(tokens):])} for tokens in genout]
|
||||
genout = [{"generated_text": tokenizer.decode(tokens[len(gen_in[0])-len(tokens):])} for tokens in genout]
|
||||
|
||||
if(len(genout) == 1):
|
||||
genresult(genout[0]["generated_text"])
|
||||
|
|
Loading…
Reference in New Issue