KoboldAI-Client (mirror of https://github.com/KoboldAI/KoboldAI-Client.git)

Commit: Soft prompt support (6B Colabs not supported yet)

Changed file: aiserver.py (93 changed lines)
@@ -12,7 +12,8 @@ import tkinter as tk
 from tkinter import messagebox
 import json
 import collections
-from typing import Union
+import zipfile
+from typing import Union, Tuple

 import requests
 import html
@@ -103,7 +104,9 @@ class vars:
     formatoptns = {'frmttriminc': True, 'frmtrmblln': False, 'frmtrmspch': False, 'frmtadsnsp': False} # Container for state of formatting options
     importnum = -1     # Selection on import popup list
     importjs = {}      # Temporary storage for import data
-    loadselect = ""    # Temporary storage for filename to load
+    loadselect = ""    # Temporary storage for story filename to load
+    spselect = ""      # Temporary storage for soft prompt filename to load
+    sp = None          # Current soft prompt tensor (as a NumPy array)
     svowname = ""      # Filename that was flagged for overwrite confirm
     saveow = False     # Whether or not overwrite confirm has been displayed
     genseqs = []       # Temporary storage for generated sequences
@@ -113,6 +116,8 @@ class vars:
     bmsupported = False # Whether the breakmodel option is supported (GPT-Neo/GPT-J only, currently)
     smandelete = False  # Whether stories can be deleted from inside the browser
     smanrename = False  # Whether stories can be renamed from inside the browser
+    allowsp = False     # Whether we are allowed to use soft prompts (by default enabled if we're using GPT-2, GPT-Neo or GPT-J)
+    modeldim = -1       # Embedding dimension of your model (e.g. it's 4096 for GPT-J-6B and 2560 for GPT-Neo-2.7B)
     laststory = None    # Filename (without extension) of most recent story JSON file we loaded
     acregex_ai = re.compile(r'\n* *>(.|\n)*') # Pattern for matching adventure actions from the AI so we can remove them
     acregex_ui = re.compile(r'^ *(>.*)$', re.MULTILINE) # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses)
@@ -311,6 +316,7 @@ else:

 # If transformers model was selected & GPU available, ask to use CPU or GPU
 if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
+    vars.allowsp = True
     # Test for GPU support
     import torch
     print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
@ -503,11 +509,31 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
||||
if(not vars.noai):
|
||||
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||
from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
|
||||
|
||||
|
||||
# Patch transformers to use our soft prompt
|
||||
def patch_causallm(cls):
|
||||
old_forward = cls.forward
|
||||
def new_causallm_forward(self, *args, **kwargs):
|
||||
input_ids = kwargs.get('input_ids').to(self.device)
|
||||
assert input_ids is not None
|
||||
kwargs['input_ids'] = None
|
||||
inputs_embeds = self.transformer.wte(input_ids)
|
||||
if(vars.sp is not None):
|
||||
inputs_embeds = torch.cat((
|
||||
vars.sp.tile((inputs_embeds.shape[0], 1, 1)),
|
||||
inputs_embeds
|
||||
), dim=1).to(self.device)
|
||||
kwargs['inputs_embeds'] = inputs_embeds
|
||||
return old_forward(*args, **kwargs)
|
||||
cls.forward = new_causallm_forward
|
||||
for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
|
||||
patch_causallm(cls)
|
||||
|
||||
# If custom GPT Neo model was chosen
|
||||
if(vars.model == "NeoCustom"):
|
||||
model_config = open(vars.custmodpth + "/config.json", "r")
|
||||
js = json.load(model_config)
|
||||
vars.modeldim = int(js['hidden_size'])
|
||||
if("model_type" in js):
|
||||
model = AutoModelForCausalLM.from_pretrained(vars.custmodpth)
|
||||
else:
|
||||
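The patch above wraps the model's forward() so that, whenever a soft prompt is loaded, its embedding rows are prepended to the token embeddings before generation. A minimal, self-contained sketch of that prepend step (the toy shapes and the names soft_prompt/token_embeds are illustrative, not part of the patch):

import torch

batch_size, seq_len, embed_dim = 2, 5, 8     # toy dimensions
n_soft_tokens = 3                            # soft prompt length

token_embeds = torch.randn(batch_size, seq_len, embed_dim)  # stand-in for wte(input_ids)
soft_prompt  = torch.randn(n_soft_tokens, embed_dim)        # stand-in for vars.sp

# tile() repeats the soft prompt for every batch row; cat() prepends it along
# the sequence dimension, so the model sees n_soft_tokens extra "virtual" tokens.
combined = torch.cat((soft_prompt.tile((batch_size, 1, 1)), token_embeds), dim=1)
assert combined.shape == (batch_size, n_soft_tokens + seq_len, embed_dim)

Those extra virtual tokens are also why the token budget in calcsubmit() below is reduced by the soft prompt length.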
@ -525,6 +551,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
||||
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
|
||||
# If custom GPT2 model was chosen
|
||||
elif(vars.model == "GPT2Custom"):
|
||||
model_config = open(vars.custmodpth + "/config.json", "r")
|
||||
js = json.load(model_config)
|
||||
vars.modeldim = int(js['hidden_size'])
|
||||
model = GPT2LMHeadModel.from_pretrained(vars.custmodpth)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth)
|
||||
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
|
||||
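vars.modeldim records the embedding width that a loaded soft prompt must match; it is passed to fileops.getspfiles() and fileops.checksp() below. For custom models the diff reads it straight from config.json. A hypothetical stand-alone equivalent using the transformers config API (not what the diff does, just an illustration; the path is made up):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("/path/to/custom-model")  # illustrative path
modeldim = int(config.hidden_size)  # e.g. 2560 for GPT-Neo-2.7B, 4096 for GPT-J-6B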
@ -538,13 +567,17 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(vars.model)
|
||||
if(vars.hascuda):
|
||||
if(vars.usegpu):
|
||||
generator = pipeline('text-generation', model=vars.model, device=0)
|
||||
model = AutoModelForCausalLM.from_pretrained(vars.model, device=0)
|
||||
vars.modeldim = int(model.transformer.hidden_size)
|
||||
generator = pipeline('text-generation', model=model, device=0)
|
||||
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
|
||||
model = AutoModelForCausalLM.from_pretrained(vars.model)
|
||||
device_config(model)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(vars.model)
|
||||
generator = pipeline('text-generation', model=vars.model)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(vars.model)
|
||||
generator = pipeline('text-generation', model=vars.model)
|
||||
|
||||
# Suppress Author's Note by flagging square brackets (Old implementation)
|
||||
@@ -807,10 +840,16 @@ def get_message(msg):
         save()
     elif(msg['cmd'] == 'loadlistrequest'):
         getloadlist()
+    elif(msg['cmd'] == 'splistrequest'):
+        getsplist()
     elif(msg['cmd'] == 'loadselect'):
         vars.loadselect = msg["data"]
+    elif(msg['cmd'] == 'spselect'):
+        vars.spselect = msg["data"]
     elif(msg['cmd'] == 'loadrequest'):
         loadRequest(fileops.storypath(vars.loadselect))
+    elif(msg['cmd'] == 'sprequest'):
+        spRequest(vars.spselect)
     elif(msg['cmd'] == 'deletestory'):
         deletesave(msg['data'])
     elif(msg['cmd'] == 'renamestory'):
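The three new commands mirror the existing story-load flow: list, select, then load. A hypothetical message sequence as seen by get_message() (the websocket transport and the filename are illustrative only):

msgs = [
    {'cmd': 'splistrequest'},                           # server answers with a 'buildsp' list of soft prompts
    {'cmd': 'spselect', 'data': 'example_softprompt'},  # remembered in vars.spselect
    {'cmd': 'sprequest'},                               # loaded via spRequest(vars.spselect)
]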
@@ -846,6 +885,8 @@ def get_message(msg):
 #==================================================================#
 def setStartState():
     txt = "<span>Welcome to <span class=\"color_cyan\">KoboldAI</span>! You are running <span class=\"color_green\">"+getmodelname()+"</span>.<br/>"
+    if(vars.allowsp):
+        emit('from_server', {'cmd': 'allowsp', 'data': vars.allowsp}, broadcast=True)
     if(not vars.noai):
         txt = txt + "Please load a game or enter a prompt below to begin!</span>"
     else:
@@ -1123,10 +1164,12 @@ def calcsubmit(txt):
             anotetkns = tokenizer.encode(anotetxt)
             lnanote = len(anotetkns)

+        lnsp = vars.sp.shape[0] if vars.sp is not None else 0
+
         if(vars.useprompt):
-            budget = vars.max_length - lnprompt - lnmem - lnanote - lnwi - vars.genamt
+            budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt
         else:
-            budget = vars.max_length - lnmem - lnanote - lnwi - vars.genamt
+            budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt

         if(actionlen == 0):
             # First/Prompt action
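With a soft prompt loaded, its virtual tokens occupy part of the context window, so the budget for story text shrinks by lnsp. A worked example with made-up numbers:

# All values below are illustrative.
max_length = 1024   # vars.max_length: total context size in tokens
genamt     = 60     # vars.genamt: tokens reserved for the model's output
lnsp       = 20     # soft prompt length (vars.sp.shape[0])
lnmem, lnanote, lnwi, lnprompt = 100, 40, 80, 120

budget = max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - genamt   # useprompt case
print(budget)  # 604 tokens left for recent story actions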
@@ -2131,11 +2174,18 @@ def saveRequest(savpath):
     print("{0}Story saved to {1}!{2}".format(colors.GREEN, path.basename(savpath), colors.END))

 #==================================================================#
-# Load a saved story via file browser
+# Show list of saved stories
 #==================================================================#
 def getloadlist():
     emit('from_server', {'cmd': 'buildload', 'data': fileops.getstoryfiles()})

+#==================================================================#
+# Show list of soft prompts
+#==================================================================#
+def getsplist():
+    if(vars.allowsp):
+        emit('from_server', {'cmd': 'buildsp', 'data': fileops.getspfiles(vars.modeldim)})
+
 #==================================================================#
 # Load a saved story via file browser
 #==================================================================#
@@ -2221,6 +2271,35 @@ def loadRequest(loadpath):
     emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True)
     print("{0}Story loaded from {1}!{2}".format(colors.GREEN, path.basename(loadpath), colors.END))

+#==================================================================#
+# Load a soft prompt from a file
+#==================================================================#
+def spRequest(filename):
+    if(len(filename) == 0):
+        vars.sp = None
+        return
+
+    global np
+    if 'np' not in globals():
+        import numpy as np
+
+    z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
+    assert isinstance(z, zipfile.ZipFile)
+
+    with z.open('tensor.npy') as f:
+        tensor = np.load(f, allow_pickle=False)
+
+    # If the tensor is in bfloat16 format, convert it to float32
+    if(tensor.dtype == 'V2'):
+        tensor.dtype = np.uint16
+        tensor = np.uint32(tensor) << 16
+        tensor.dtype = np.float32
+
+    tensor = np.float16(tensor)
+    assert not np.isinf(tensor).any() and not np.isnan(tensor).any()
+
+    vars.sp = torch.from_numpy(tensor)
+
 #==================================================================#
 # Import an AIDungon game exported with Mimi's tool
 #==================================================================#
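spRequest() reads the raw tensor.npy out of the soft prompt archive and, if the file stores raw 2-byte values (dtype 'V2', used here for bfloat16), widens them to float32 by bit manipulation: bfloat16 is just the top 16 bits of an IEEE float32, so shifting the payload left by 16 and reinterpreting the bits recovers the value. A standalone sketch of that round trip (the sample value is illustrative):

import numpy as np

value = np.array([3.14159], dtype=np.float32)
bf16_bits = (value.view(np.uint32) >> 16).astype(np.uint16)      # truncate to the bfloat16 payload

restored = (bf16_bits.astype(np.uint32) << 16).view(np.float32)  # pad back to float32 bits
print(float(restored[0]))  # ~3.140625: the original value at bfloat16 precision

The loaded tensor is then cast to float16 and handed to torch, which is what patch_causallm() concatenates in front of the token embeddings.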