Integrate TPU backend

This commit puts the TPU backend code directly in to the KoboldAI code
to make it easier to modify.
This commit is contained in:
Gnome Ann
2021-11-19 18:06:57 -05:00
parent b926170fb0
commit a65c4de840
2 changed files with 393 additions and 6 deletions

View File

@ -179,7 +179,7 @@ def getmodelname():
if(args.configname):
modelname = args.configname
return modelname
if(vars.model == "NeoCustom" or vars.model == "GPT2Custom"):
if(vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ")):
modelname = os.path.basename(os.path.normpath(vars.custmodpth))
return modelname
else:
@ -340,7 +340,7 @@ else:
getModelSelection()
# If transformers model was selected & GPU available, ask to use CPU or GPU
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
vars.allowsp = True
# Test for GPU support
import torch
@ -530,7 +530,7 @@ socketio = SocketIO(app)
print("{0}OK!{1}".format(colors.GREEN, colors.END))
# Start transformers and create pipeline
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
if(not vars.noai):
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
@ -692,6 +692,13 @@ else:
elif(vars.model == "OAI"):
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Load the TPU backend if requested
elif(vars.model == "TPUMeshTransformerGPTJ"):
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
import tpu_mtj_backend
tpu_mtj_backend.load_model(vars.custmodpth)
tokenizer = tpu_mtj_backend.tokenizer
# Set up Flask routes
@app.route('/')
@ -1357,19 +1364,23 @@ def calcsubmit(txt):
if(vars.model != "InferKit"):
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions)
if(actionlen == 0):
if(not vars.model in ["Colab", "OAI"]):
if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"):
sendtocolab(subtxt, min, max)
elif(vars.model == "OAI"):
oairequest(subtxt, min, max)
elif(vars.model == "TPUMeshTransformerGPTJ"):
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
else:
if(not vars.model in ["Colab", "OAI"]):
if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"):
sendtocolab(subtxt, min, max)
elif(vars.model == "OAI"):
oairequest(subtxt, min, max)
elif(vars.model == "TPUMeshTransformerGPTJ"):
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
# For InferKit web API
else:
@ -1658,7 +1669,49 @@ def sendtocolab(txt, min, max):
print("{0}{1}{2}".format(colors.RED, errmsg, colors.END))
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
set_aibusy(0)
#==================================================================#
# Send text to TPU mesh transformer backend
#==================================================================#
def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
if(found_entries is None):
found_entries = set()
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, txt, colors.END))
# Submit input text to generator
try:
if(vars.sp is not None):
raise ValueError("Softprompts are not supported by the TPU backend yet")
if(vars.dynamicscan):
raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
genout = tpu_mtj_backend.infer(
txt,
gen_len = maximum-minimum+1,
temp=vars.temp,
top_p=vars.top_p,
top_k=vars.top_k,
tfs=vars.tfs,
numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
)
except Exception as e:
emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
print("{0}{1}{2}".format(colors.RED, e, colors.END))
set_aibusy(0)
return
genout = [{"generated_text": txt} for txt in genout]
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
else:
genselect(genout)
set_aibusy(0)
#==================================================================#
# Replaces returns and newlines with HTML breaks