Merge pull request #20 from VE-FORBRYDERNE/sp

Soft prompt support for PyTorch models
Authored by henk717 on 2021-10-30 00:35:44 +02:00; committed by GitHub
5 changed files with 377 additions and 28 deletions
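For orientation: a soft prompt is a tensor of learned embedding vectors ("virtual tokens") that gets prepended to the real prompt's embeddings instead of being produced by the tokenizer. A minimal, standalone sketch of the idea follows; the sizes and names (n_virtual, soft_prompt, wte) are illustrative and do not appear in the patch.

import torch

n_virtual, model_dim, vocab_size = 5, 8, 100       # toy sizes (e.g. GPT-J-6B is 4096-dim)
soft_prompt = torch.zeros(n_virtual, model_dim)    # stands in for a trained soft-prompt tensor
wte = torch.nn.Embedding(vocab_size, model_dim)    # stands in for the model's token embedding table

token_ids = torch.tensor([[3, 17, 42]])            # an ordinary tokenized prompt, shape (1, 3)
inputs_embeds = torch.cat((soft_prompt[None], wte(token_ids)), dim=1)
# inputs_embeds.shape == (1, n_virtual + 3, model_dim): the model sees the soft
# prompt as extra positions in front of the real prompt.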


@@ -12,7 +12,8 @@ import tkinter as tk
from tkinter import messagebox
import json
import collections
from typing import Union
import zipfile
from typing import Union, Tuple
import requests
import html
@@ -103,7 +104,9 @@ class vars:
formatoptns = {'frmttriminc': True, 'frmtrmblln': False, 'frmtrmspch': False, 'frmtadsnsp': False, 'singleline': False} # Container for state of formatting options
importnum = -1 # Selection on import popup list
importjs = {} # Temporary storage for import data
loadselect = "" # Temporary storage for filename to load
loadselect = "" # Temporary storage for story filename to load
spselect = "" # Temporary storage for soft prompt filename to load
sp = None # Current soft prompt tensor (loaded from a NumPy file and stored as a PyTorch tensor)
svowname = "" # Filename that was flagged for overwrite confirm
saveow = False # Whether or not overwrite confirm has been displayed
genseqs = [] # Temporary storage for generated sequences
@@ -113,6 +116,8 @@ class vars:
bmsupported = False # Whether the breakmodel option is supported (GPT-Neo/GPT-J only, currently)
smandelete = False # Whether stories can be deleted from inside the browser
smanrename = False # Whether stories can be renamed from inside the browser
allowsp = False # Whether we are allowed to use soft prompts (by default enabled if we're using GPT-2, GPT-Neo or GPT-J)
modeldim = -1 # Embedding dimension of your model (e.g. it's 4096 for GPT-J-6B and 2560 for GPT-Neo-2.7B)
laststory = None # Filename (without extension) of most recent story JSON file we loaded
regex_sl = re.compile(r'\n*(?<=.) *\n(.|\n)*') # Pattern for limiting the output to a single line
acregex_ai = re.compile(r'\n* *>(.|\n)*') # Pattern for matching adventure actions from the AI so we can remove them
@@ -312,6 +317,7 @@ else:
# If transformers model was selected & GPU available, ask to use CPU or GPU
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
vars.allowsp = True
# Test for GPU support
import torch
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
@@ -506,11 +512,41 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
if(not vars.noai):
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
# Patch transformers to use our soft prompt
def patch_causallm(cls):
old_forward = cls.forward
def new_causallm_forward(self, *args, **kwargs):
input_ids = kwargs.get('input_ids')
assert input_ids is not None
input_ids = input_ids.to(self.device)
kwargs['input_ids'] = None
if(vars.sp is not None):
shifted_input_ids = input_ids - self.config.vocab_size
input_ids.clamp_(max=self.config.vocab_size-1)
inputs_embeds = self.transformer.wte(input_ids)
if(vars.sp is not None):
vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
inputs_embeds = torch.where(
(shifted_input_ids >= 0)[:, :, None],
vars.sp[shifted_input_ids.clamp(min=0)],
inputs_embeds,
)
kwargs['inputs_embeds'] = inputs_embeds
return old_forward(self, *args, **kwargs)
cls.forward = new_causallm_forward
for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
patch_causallm(cls)
try:
from transformers import GPTJForCausalLM
patch_causallm(GPTJForCausalLM)
except ImportError:
pass
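The patch above replaces forward()'s input_ids with inputs_embeds: IDs below vocab_size go through the normal wte embedding table, while IDs at or above vocab_size select rows of vars.sp via torch.where. A standalone toy illustration of that indexing trick (toy sizes; not part of the patch itself):

import torch

vocab_size, model_dim = 10, 4                      # toy sizes
wte = torch.nn.Embedding(vocab_size, model_dim)    # normal embedding table
sp = torch.ones(3, model_dim)                      # stands in for vars.sp (3 virtual tokens)

input_ids = torch.tensor([[10, 11, 12, 2, 5]])     # first three IDs are virtual tokens
shifted = input_ids - vocab_size                   # >= 0 only for the virtual tokens
clamped = input_ids.clamp(max=vocab_size - 1)      # keep real IDs valid for wte
embeds = torch.where(
    (shifted >= 0)[:, :, None],                    # mask broadcast over the embedding dim
    sp[shifted.clamp(min=0)],                      # soft-prompt rows for virtual tokens
    wte(clamped),                                  # normal embeddings otherwise
)
# embeds[0, :3] are rows of sp; embeds[0, 3:] come from wte.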
# If custom GPT Neo model was chosen
if(vars.model == "NeoCustom"):
model_config = open(vars.custmodpth + "/config.json", "r")
js = json.load(model_config)
vars.modeldim = int(js['hidden_size'])
if("model_type" in js):
model = AutoModelForCausalLM.from_pretrained(vars.custmodpth)
else:
@@ -519,36 +555,46 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
if(vars.hascuda):
if(vars.usegpu):
generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=0)
model = model.to(0)
generator = model.generate
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
device_config(model)
else:
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
generator = model.generate
else:
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
generator = model.generate
# If custom GPT2 model was chosen
elif(vars.model == "GPT2Custom"):
model_config = open(vars.custmodpth + "/config.json", "r")
js = json.load(model_config)
vars.modeldim = int(js['hidden_size'])
model = GPT2LMHeadModel.from_pretrained(vars.custmodpth)
tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth)
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
if(vars.hascuda and vars.usegpu):
generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=0)
model = model.to(0)
generator = model.generate
else:
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
generator = model.generate
# If base HuggingFace model was chosen
else:
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
tokenizer = GPT2Tokenizer.from_pretrained(vars.model)
if(vars.hascuda):
if(vars.usegpu):
generator = pipeline('text-generation', model=vars.model, device=0)
model = AutoModelForCausalLM.from_pretrained(vars.model)
vars.modeldim = int(model.config.hidden_size)
model = model.to(0)
generator = model.generate
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
model = AutoModelForCausalLM.from_pretrained(vars.model)
device_config(model)
else:
generator = pipeline('text-generation', model=vars.model)
model = AutoModelForCausalLM.from_pretrained(vars.model)
generator = model.generate
else:
generator = pipeline('text-generation', model=vars.model)
model = AutoModelForCausalLM.from_pretrained(vars.model)
generator = model.generate
# Suppress Author's Note by flagging square brackets (Old implementation)
#vocab = tokenizer.get_vocab()
@@ -815,10 +861,16 @@ def get_message(msg):
save()
elif(msg['cmd'] == 'loadlistrequest'):
getloadlist()
elif(msg['cmd'] == 'splistrequest'):
getsplist()
elif(msg['cmd'] == 'loadselect'):
vars.loadselect = msg["data"]
elif(msg['cmd'] == 'spselect'):
vars.spselect = msg["data"]
elif(msg['cmd'] == 'loadrequest'):
loadRequest(fileops.storypath(vars.loadselect))
elif(msg['cmd'] == 'sprequest'):
spRequest(vars.spselect)
elif(msg['cmd'] == 'deletestory'):
deletesave(msg['data'])
elif(msg['cmd'] == 'renamestory'):
@@ -854,6 +906,8 @@ def get_message(msg):
#==================================================================#
def setStartState():
txt = "<span>Welcome to <span class=\"color_cyan\">KoboldAI</span>! You are running <span class=\"color_green\">"+getmodelname()+"</span>.<br/>"
if(vars.allowsp):
emit('from_server', {'cmd': 'allowsp', 'data': vars.allowsp}, broadcast=True)
if(not vars.noai):
txt = txt + "Please load a game or enter a prompt below to begin!</span>"
else:
@@ -1131,15 +1185,17 @@ def calcsubmit(txt):
anotetkns = tokenizer.encode(anotetxt)
lnanote = len(anotetkns)
if(vars.useprompt):
budget = vars.max_length - lnprompt - lnmem - lnanote - lnwi - vars.genamt
else:
budget = vars.max_length - lnmem - lnanote - lnwi - vars.genamt
lnsp = vars.sp.shape[0] if vars.sp is not None else 0
if(vars.useprompt):
budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt
else:
budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt
if(actionlen == 0):
# First/Prompt action
subtxt = vars.memory + winfo + anotetxt + vars.prompt
lnsub = lnmem + lnwi + lnprompt + lnanote
lnsub = lnsp + lnmem + lnwi + lnprompt + lnanote
if(not vars.model in ["Colab", "OAI"]):
generate(subtxt, lnsub+1, lnsub+vars.genamt)
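To make the revised budget arithmetic concrete, a worked example with illustrative numbers (none of these values come from this commit): the soft prompt's length (lnsp) now consumes part of the context window just like memory, world info and the author's note, and it is also counted in the min/max lengths passed to generate().

max_length, genamt = 1024, 80                       # context window and requested output length
lnsp, lnmem, lnwi, lnanote, lnprompt = 20, 60, 30, 40, 120

budget_with_prompt    = max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - genamt  # 674
budget_without_prompt = max_length - lnsp - lnmem - lnanote - lnwi - genamt             # 794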
@@ -1295,14 +1351,23 @@ def generate(txt, min, max):
top_k = vars.top_k if vars.top_k > 0 else None
tfs = vars.tfs if vars.tfs > 0.0 else None
# generator() only accepts a torch tensor of tokens (long datatype) as
# its first argument if we're using breakmodel, otherwise a string
# is fine
if(vars.hascuda and vars.breakmodel):
gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.primary_device)
gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long()
if(vars.sp is not None):
soft_tokens = torch.arange(
model.config.vocab_size,
model.config.vocab_size + vars.sp.shape[0],
)
gen_in = torch.cat((soft_tokens[None], gen_in), dim=1)
if(vars.hascuda and vars.usegpu):
gen_in = gen_in.to(0)
elif(vars.hascuda and vars.breakmodel):
gen_in = gen_in.to(breakmodel.primary_device)
else:
gen_in = txt
gen_in = gen_in.to('cpu')
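Since generate() always receives a token tensor now, the soft prompt has to be expressed as token IDs; the code above does that by prepending "virtual" IDs starting at vocab_size, which the patched forward() later swaps for rows of vars.sp. A small shape check of that concatenation (illustrative values only):

import torch

vocab_size, sp_len = 50257, 20                                # illustrative values
gen_in = torch.tensor([[464, 3290, 318]])                     # tokenized prompt, shape (1, 3)
soft_tokens = torch.arange(vocab_size, vocab_size + sp_len)   # IDs just past the real vocabulary
gen_in = torch.cat((soft_tokens[None], gen_in), dim=1)        # shape (1, 23)
# The patched forward() recognises IDs >= vocab_size and substitutes rows of vars.sp for them.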
with torch.no_grad():
genout = generator(
gen_in,
@@ -1326,8 +1391,7 @@ def generate(txt, min, max):
return
# Need to manually strip and decode tokens if we're not using a pipeline
if(vars.hascuda and vars.breakmodel):
genout = [{"generated_text": tokenizer.decode(tokens[len(gen_in[0])-len(tokens):])} for tokens in genout]
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
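The decode slice above works because len(gen_in[0]) - len(tokens) is negative by exactly the number of newly generated tokens, so only the continuation is decoded and the out-of-vocabulary soft-prompt IDs never reach tokenizer.decode. A tiny plain-Python check (toy numbers):

prompt_len = 23                                  # prompt tokens, including prepended soft-prompt IDs
tokens = list(range(30))                         # stands in for one generated sequence (23 prompt + 7 new)
new_tokens = tokens[prompt_len - len(tokens):]   # 23 - 30 == -7, i.e. tokens[-7:]
assert new_tokens == tokens[23:]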
@@ -2143,11 +2207,18 @@ def saveRequest(savpath):
print("{0}Story saved to {1}!{2}".format(colors.GREEN, path.basename(savpath), colors.END))
#==================================================================#
# Load a saved story via file browser
# Show list of saved stories
#==================================================================#
def getloadlist():
emit('from_server', {'cmd': 'buildload', 'data': fileops.getstoryfiles()})
#==================================================================#
# Show list of soft prompts
#==================================================================#
def getsplist():
if(vars.allowsp):
emit('from_server', {'cmd': 'buildsp', 'data': fileops.getspfiles(vars.modeldim)})
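Passing vars.modeldim here implies that candidate soft prompts are filtered by embedding dimension. The real validation lives in fileops.getspfiles/fileops.checksp, which this diff does not show; the following is only a rough, hypothetical sketch of the kind of check that argument implies.

import numpy as np

def soft_prompt_matches_model(path, modeldim):
    # Hypothetical helper, not the actual fileops.checksp implementation.
    # rows = number of virtual tokens; columns must equal the model's embedding size.
    with np.load(path, allow_pickle=False) as f:
        tensor = f['tensor.npy']
    return tensor.ndim == 2 and tensor.shape[1] == modeldim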
#==================================================================#
# Load a saved story via file browser
#==================================================================#
@@ -2233,6 +2304,37 @@ def loadRequest(loadpath):
emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True)
print("{0}Story loaded from {1}!{2}".format(colors.GREEN, path.basename(loadpath), colors.END))
#==================================================================#
# Load a soft prompt from a file
#==================================================================#
def spRequest(filename):
if(len(filename) == 0):
vars.sp = None
return
global np
if 'np' not in globals():
import numpy as np
z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
assert isinstance(z, zipfile.ZipFile)
z.close()
with np.load(fileops.sppath(filename), allow_pickle=False) as f:
tensor = f['tensor.npy']
# If the tensor is in bfloat16 format, convert it to float32
if(tensor.dtype == 'V2'):
tensor.dtype = np.uint16
tensor = np.uint32(tensor) << 16
tensor.dtype = np.float32
if(tensor.dtype != np.float16):
tensor = np.float32(tensor)
assert not np.isinf(tensor).any() and not np.isnan(tensor).any()
vars.sp = torch.from_numpy(tensor)
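The bfloat16 branch above relies on bfloat16 being the upper 16 bits of an IEEE float32: widening the raw 16-bit pattern, shifting it left by 16 and reinterpreting the bytes recovers the float32 value. A standalone check of that bit trick (numpy only; the sample values are chosen for illustration):

import numpy as np

raw = np.array([0x3FC0, 0xC000], dtype=np.uint16)  # bfloat16 bit patterns for 1.5 and -2.0
widened = raw.astype(np.uint32) << 16              # same widening cast as np.uint32(tensor) above
widened.dtype = np.float32                         # reinterpret the bytes as float32, as above
assert widened.tolist() == [1.5, -2.0]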
#==================================================================#
# Import an AIDungon game exported with Mimi's tool
#==================================================================#