Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Integrate TPU backend
This commit puts the TPU backend code directly into the KoboldAI code to make it easier to modify.
 aiserver.py | 65
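For orientation: the diff below assumes a local tpu_mtj_backend module with roughly the surface sketched here. The function and argument names are taken from the calls added in this commit; the type hints, the GPT-2 tokenizer choice, and the placeholder bodies are assumptions, not part of the change.

# Hypothetical stub of the tpu_mtj_backend surface that aiserver.py relies on after this commit.
from typing import List

from transformers import GPT2Tokenizer

# Module-level tokenizer, read by aiserver.py as tpu_mtj_backend.tokenizer
# (assumption: GPT-J reuses the GPT-2 BPE vocabulary).
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

def load_model(path: str) -> None:
    # Load the model config and checkpoint from the local directory given in vars.custmodpth.
    raise NotImplementedError

def infer(txt: str, gen_len: int, temp: float, top_p: float, top_k: int,
          tfs: float, numseqs: int, repetition_penalty: float) -> List[str]:
    # Generate numseqs continuations of txt; the caller expects a list of plain strings.
    raise NotImplementedError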
@@ -179,7 +179,7 @@ def getmodelname():
     if(args.configname):
         modelname = args.configname
         return modelname
-    if(vars.model == "NeoCustom" or vars.model == "GPT2Custom"):
+    if(vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ")):
         modelname = os.path.basename(os.path.normpath(vars.custmodpth))
         return modelname
     else:
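As a side effect of the hunk above, a TPUMeshTransformerGPTJ model is named after its checkpoint directory, the same way NeoCustom and GPT2Custom models already are. A small illustrative example (the path is hypothetical):

import os

custmodpth = "/content/gpt-j-6b/"                        # hypothetical checkpoint directory
print(os.path.basename(os.path.normpath(custmodpth)))    # prints "gpt-j-6b"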
@@ -340,7 +340,7 @@ else:
     getModelSelection()
 
 # If transformers model was selected & GPU available, ask to use CPU or GPU
-if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
+if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
     vars.allowsp = True
     # Test for GPU support
     import torch
@@ -530,7 +530,7 @@ socketio = SocketIO(app)
 print("{0}OK!{1}".format(colors.GREEN, colors.END))
 
 # Start transformers and create pipeline
-if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
+if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
     if(not vars.noai):
         print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
         from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
@@ -692,6 +692,13 @@ else:
 elif(vars.model == "OAI"):
     from transformers import GPT2Tokenizer
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+# Load the TPU backend if requested
+elif(vars.model == "TPUMeshTransformerGPTJ"):
+    print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
+    assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
+    import tpu_mtj_backend
+    tpu_mtj_backend.load_model(vars.custmodpth)
+    tokenizer = tpu_mtj_backend.tokenizer
 
 # Set up Flask routes
 @app.route('/')
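Worth noting about the branch above: the TPU path only accepts a local model. The assert requires vars.custmodpth to point at an existing directory before tpu_mtj_backend.load_model() is called. A small illustration of that precondition (the path is hypothetical):

import os

custmodpth = "/content/gpt-j-6b"                              # hypothetical local checkpoint directory
if custmodpth and os.path.isdir(custmodpth):
    print("ok: would call tpu_mtj_backend.load_model(custmodpth)")
else:
    print("not a local directory: the assert in the hunk above would fail here")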
@@ -1357,19 +1364,23 @@ def calcsubmit(txt):
     if(vars.model != "InferKit"):
         subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions)
         if(actionlen == 0):
-            if(not vars.model in ["Colab", "OAI"]):
+            if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
                 generate(subtxt, min, max, found_entries=found_entries)
             elif(vars.model == "Colab"):
                 sendtocolab(subtxt, min, max)
             elif(vars.model == "OAI"):
                 oairequest(subtxt, min, max)
+            elif(vars.model == "TPUMeshTransformerGPTJ"):
+                tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
         else:
-            if(not vars.model in ["Colab", "OAI"]):
+            if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
                 generate(subtxt, min, max, found_entries=found_entries)
             elif(vars.model == "Colab"):
                 sendtocolab(subtxt, min, max)
             elif(vars.model == "OAI"):
                 oairequest(subtxt, min, max)
+            elif(vars.model == "TPUMeshTransformerGPTJ"):
+                tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
 
     # For InferKit web API
     else:
@@ -1658,7 +1669,49 @@ def sendtocolab(txt, min, max):
         print("{0}{1}{2}".format(colors.RED, errmsg, colors.END))
         emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
         set_aibusy(0)
 
 
+#==================================================================#
+#  Send text to TPU mesh transformer backend
+#==================================================================#
+def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
+    if(found_entries is None):
+        found_entries = set()
+    found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
+
+    print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, txt, colors.END))
+
+    # Submit input text to generator
+    try:
+        if(vars.sp is not None):
+            raise ValueError("Softprompts are not supported by the TPU backend yet")
+        if(vars.dynamicscan):
+            raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
+        genout = tpu_mtj_backend.infer(
+            txt,
+            gen_len = maximum-minimum+1,
+            temp=vars.temp,
+            top_p=vars.top_p,
+            top_k=vars.top_k,
+            tfs=vars.tfs,
+            numseqs=vars.numseqs,
+            repetition_penalty=vars.rep_pen,
+        )
+
+    except Exception as e:
+        emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
+        print("{0}{1}{2}".format(colors.RED, e, colors.END))
+        set_aibusy(0)
+        return
+
+    genout = [{"generated_text": txt} for txt in genout]
+
+    if(len(genout) == 1):
+        genresult(genout[0]["generated_text"])
+    else:
+        genselect(genout)
+
+    set_aibusy(0)
+
+
 #==================================================================#
 #  Replaces returns and newlines with HTML breaks
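Two details of the new tpumtjgenerate() above are easy to miss: gen_len = maximum-minimum+1 turns the [minimum, maximum] output-length range from calcsubmitbudget() into the number of new tokens to request, and the plain strings returned by infer() are wrapped into {"generated_text": ...} dicts so the existing genresult()/genselect() handlers can consume TPU output the same way they consume a transformers pipeline result. A minimal sketch under those assumptions (the numbers and strings are made up):

minimum, maximum = 513, 592               # hypothetical budget from calcsubmitbudget()
gen_len = maximum - minimum + 1           # 80 new tokens requested from the TPU backend

raw = ["First continuation...", "Second continuation..."]   # hypothetical infer() output
genout = [{"generated_text": txt} for txt in raw]           # same shape as a transformers pipeline result

if len(genout) == 1:
    print(genout[0]["generated_text"])                      # single sequence: displayed directly
else:
    print([g["generated_text"] for g in genout])            # multiple sequences: user picks one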