Added support for running model remotely on Google Colab
This commit is contained in:
parent
0b113a75b4
commit
3c0638bc73
75
aiserver.py
75
aiserver.py
|
@ -42,7 +42,8 @@ modellist = [
|
|||
["GPT-2 XL", "gpt2-xl", "16GB"],
|
||||
["InferKit API (requires API key)", "InferKit", ""],
|
||||
["Custom Neo (eg Neo-horni)", "NeoCustom", ""],
|
||||
["Custom GPT-2 (eg CloverEdition)", "GPT2Custom", ""]
|
||||
["Custom GPT-2 (eg CloverEdition)", "GPT2Custom", ""],
|
||||
["Google Colab", "Colab", ""]
|
||||
]
|
||||
|
||||
# Variables
|
||||
|
@ -69,6 +70,7 @@ class vars:
|
|||
mode = "play" # Whether the interface is in play, memory, or edit mode
|
||||
editln = 0 # Which line was last selected in Edit Mode
|
||||
url = "https://api.inferkit.com/v1/models/standard/generate" # InferKit API URL
|
||||
colaburl = "" # Ngrok url for Google Colab mode
|
||||
apikey = "" # API key to use for InferKit API calls
|
||||
savedir = getcwd()+"\stories"
|
||||
hascuda = False # Whether torch has detected CUDA on the system
|
||||
|
@ -134,7 +136,7 @@ print("{0}Welcome to the KoboldAI Client!\nSelect an AI model to continue:{1}\n"
|
|||
getModelSelection()
|
||||
|
||||
# If transformers model was selected & GPU available, ask to use CPU or GPU
|
||||
if(vars.model != "InferKit" and vars.hascuda):
|
||||
if((not vars.model in ["InferKit", "Colab"]) and vars.hascuda):
|
||||
print("{0}Use GPU or CPU for generation?: (Default GPU){1}\n".format(colors.CYAN, colors.END))
|
||||
print(" 1 - GPU\n 2 - CPU\n")
|
||||
genselected = False
|
||||
|
@ -185,6 +187,11 @@ if(vars.model == "InferKit"):
|
|||
finally:
|
||||
file.close()
|
||||
|
||||
# Ask for ngrok url if Google Colab was selected
|
||||
if(vars.model == "Colab"):
|
||||
print("{0}Please enter the ngrok.io URL displayed in Google Colab:{1}\n".format(colors.CYAN, colors.END))
|
||||
vars.colaburl = input("URL> ") + "/request"
|
||||
|
||||
# Set logging level to reduce chatter from Flask
|
||||
import logging
|
||||
log = logging.getLogger('werkzeug')
|
||||
|
@ -200,7 +207,7 @@ socketio = SocketIO(app)
|
|||
print("{0}OK!{1}".format(colors.GREEN, colors.END))
|
||||
|
||||
# Start transformers and create pipeline
|
||||
if(vars.model != "InferKit"):
|
||||
if(not vars.model in ["InferKit", "Colab"]):
|
||||
if(not vars.noai):
|
||||
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||
from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
|
||||
|
@ -236,6 +243,10 @@ if(vars.model != "InferKit"):
|
|||
else:
|
||||
# Import requests library for HTTPS calls
|
||||
import requests
|
||||
# If we're running Colab, we still need a tokenizer.
|
||||
if(vars.model == "Colab"):
|
||||
from transformers import GPT2Tokenizer
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
|
||||
|
||||
# Set up Flask routes
|
||||
@app.route('/')
|
||||
|
@ -604,11 +615,19 @@ def calcsubmit(txt):
|
|||
|
||||
# Send completed bundle to generator
|
||||
ln = len(tokens)
|
||||
|
||||
if(vars.model != "Colab"):
|
||||
generate (
|
||||
tokenizer.decode(tokens),
|
||||
ln+1,
|
||||
ln+vars.genamt
|
||||
)
|
||||
else:
|
||||
sendtocolab(
|
||||
tokenizer.decode(tokens),
|
||||
ln+1,
|
||||
ln+vars.genamt
|
||||
)
|
||||
# For InferKit web API
|
||||
else:
|
||||
|
||||
|
@ -685,6 +704,56 @@ def generate(txt, min, max):
|
|||
|
||||
set_aibusy(0)
|
||||
|
||||
#==================================================================#
|
||||
# Send transformers-style request to ngrok/colab host
|
||||
#==================================================================#
|
||||
def sendtocolab(txt, min, max):
|
||||
# Log request to console
|
||||
print("{0}Len:{1}, Txt:{2}{3}".format(colors.YELLOW, len(txt), txt, colors.END))
|
||||
|
||||
# Build request JSON data
|
||||
reqdata = {
|
||||
'text': txt,
|
||||
'min': min,
|
||||
'max': max,
|
||||
'rep_pen': vars.rep_pen,
|
||||
'temperature': vars.temp,
|
||||
'top_p': vars.top_p
|
||||
}
|
||||
|
||||
# Create request
|
||||
req = requests.post(
|
||||
vars.colaburl,
|
||||
json = reqdata
|
||||
)
|
||||
|
||||
# Deal with the response
|
||||
if(req.status_code == 200):
|
||||
genout = req.json()["data"]["text"]
|
||||
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
|
||||
|
||||
# Format output before continuing
|
||||
genout = applyoutputformatting(getnewcontent(genout))
|
||||
|
||||
# Add formatted text to Actions array and refresh the game screen
|
||||
vars.actions.append(genout)
|
||||
refresh_story()
|
||||
emit('from_server', {'cmd': 'texteffect', 'data': len(vars.actions)})
|
||||
|
||||
set_aibusy(0)
|
||||
else:
|
||||
# Send error message to web client
|
||||
er = req.json()
|
||||
if("error" in er):
|
||||
code = er["error"]["extensions"]["code"]
|
||||
elif("errors" in er):
|
||||
code = er["errors"][0]["extensions"]["code"]
|
||||
|
||||
errmsg = "InferKit API Error: {0} - {1}".format(req.status_code, code)
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg})
|
||||
set_aibusy(0)
|
||||
|
||||
|
||||
#==================================================================#
|
||||
# Replaces returns and newlines with HTML breaks
|
||||
#==================================================================#
|
||||
|
|
Loading…
Reference in New Issue