Compile TPU backend in background

Gnome Ann 2022-01-07 13:47:21 -05:00
parent 38a3eddd57
commit fbc3a73c0f
1 changed file with 44 additions and 26 deletions


@@ -7,10 +7,10 @@
 # External packages
 import eventlet
-eventlet.monkey_patch()
+eventlet.monkey_patch(all=True, thread=False)
 import os
 os.system("")
-os.environ['EVENTLET_THREADPOOL_SIZE'] = '1'
+os.environ['EVENTLET_THREADPOOL_SIZE'] = '50'
 from eventlet import tpool
 from os import path, getcwd
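
The two changes above support the background-compile thread introduced later in this diff: thread=False leaves Python's threading module unpatched, so that thread runs as a real OS thread rather than an eventlet greenlet, and the larger EVENTLET_THREADPOOL_SIZE gives tpool.execute more OS worker threads for blocking calls. A minimal sketch (illustrative code, not part of this commit) of what thread=False buys:

    import eventlet
    eventlet.monkey_patch(all=True, thread=False)  # threading left unpatched

    import threading

    def heavy_startup_work():
        # Stand-in for a slow, CPU-bound compilation step; a real OS thread
        # can grind through this without yielding to eventlet's event loop.
        total = sum(i * i for i in range(10**7))
        print("warm-up done:", total, flush=True)

    threading.Thread(target=heavy_startup_work).start()
    print("startup continues while the warm-up runs", flush=True)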
@@ -21,6 +21,7 @@ import zipfile
 import packaging
 import contextlib
 import traceback
+import threading
 from typing import Any, Callable, TypeVar, Union, Dict, Set, List
 import requests
@@ -976,6 +977,29 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
     from transformers import GPT2TokenizerFast
     tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
 else:
+    def tpumtjgetsofttokens():
+        soft_tokens = None
+        if(vars.sp is None):
+            global np
+            if 'np' not in globals():
+                import numpy as np
+            tensor = np.zeros((1, tpu_mtj_backend.params["d_model"]), dtype=np.float32)
+            rows = tensor.shape[0]
+            padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
+            tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
+            tensor = tensor.reshape(
+                tpu_mtj_backend.params["cores_per_replica"],
+                -1,
+                tpu_mtj_backend.params["d_model"],
+            )
+            vars.sp = tensor
+        soft_tokens = np.arange(
+            tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
+            tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
+            dtype=np.uint32
+        )
+        return soft_tokens
+
     # If we're running Colab or OAI, we still need a tokenizer.
     if(vars.model == "Colab"):
         from transformers import GPT2TokenizerFast
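
The padding_amount expression in the new helper rounds the soft-prompt tensor's row count up to a multiple of cores_per_replica, using the fact that Python's modulo with a negative divisor yields a result in (-divisor, 0]. A worked example with made-up parameter values (the real ones come from tpu_mtj_backend.params):

    seq, cores, rows = 2050, 8, 1

    # 2050 % -8 == -6, so subtracting it rounds seq UP to a multiple of 8.
    rounded = seq - (seq % -cores)   # 2050 - (-6) == 2056
    assert rounded % cores == 0

    padding_amount = rounded - rows  # 2055 rows of zeros to append
    # After np.pad, the (2056, d_model) tensor reshapes cleanly into
    # (cores, -1, d_model) == (8, 257, d_model), one shard per TPU core.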
@@ -992,6 +1016,17 @@ else:
         vars.allowsp = True
         vars.modeldim = int(tpu_mtj_backend.params["d_model"])
         tokenizer = tpu_mtj_backend.tokenizer
+        soft_tokens = tpumtjgetsofttokens()
+        threading.Thread(  # Compile backend code in background
+            target=tpu_mtj_backend.infer,
+            args=(np.uint32((23403, 727, 20185)),),
+            kwargs={
+                "soft_embeddings": vars.sp,
+                "soft_tokens": soft_tokens,
+                "gen_len": 1,
+                "numseqs": vars.numseqs,
+            },
+        ).start()

 # Set up Flask routes
 @app.route('/')
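
The dummy infer call above makes the backend trace and compile its generation graph during startup instead of on the first user request. Passing the real vars.numseqs (and a real soft-token array) matters because XLA compiles one program per input shape, so warming up with different shapes would still leave the first request paying the compile cost. A minimal sketch of the same fire-and-forget pattern (the generate function here is a hypothetical stand-in for tpu_mtj_backend.infer):

    import threading
    import time

    def generate(tokens, gen_len, numseqs):
        # Stand-in: the first call with a given shape pays a long one-time
        # compilation cost before producing output.
        time.sleep(2)
        return [tokens[:gen_len]] * numseqs

    def warm_up():
        # Tiny dummy prompt, one generated token; the output is discarded.
        generate((23403, 727, 20185), gen_len=1, numseqs=1)

    threading.Thread(target=warm_up).start()
    print("server keeps initializing while the backend compiles", flush=True)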
@@ -1583,7 +1618,8 @@ def execute_outmod():
 # Lua runtime startup
 #==================================================================#
-print(colors.PURPLE + "Initializing Lua Bridge... " + colors.END, end="")
+print("", end="", flush=True)
+print(colors.PURPLE + "Initializing Lua Bridge... " + colors.END, end="", flush=True)

 # Set up Lua state
 vars.lua_state = lupa.LuaRuntime(unpack_returned_tuples=True)
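
The flush=True additions throughout this commit address a real buffering pitfall: a partial line printed with end="" can sit in Python's stdout buffer indefinitely, and now that a background thread may be compiling (and printing) during startup, status messages need to be forced out immediately. A small illustration:

    import time

    print("Initializing... ", end="")              # no newline: may stay buffered
    time.sleep(1)
    print("OK", flush=True)

    print("Initializing... ", end="", flush=True)  # visible right away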
@@ -2863,27 +2899,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
     try:
         if(vars.dynamicscan):
             raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
-        soft_tokens = None
-        if(vars.sp is None):
-            global np
-            if 'np' not in globals():
-                import numpy as np
-            tensor = np.zeros((1, tpu_mtj_backend.params["d_model"]), dtype=np.float32)
-            rows = tensor.shape[0]
-            padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
-            tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
-            tensor = tensor.reshape(
-                tpu_mtj_backend.params["cores_per_replica"],
-                -1,
-                tpu_mtj_backend.params["d_model"],
-            )
-            vars.sp = tensor
-        soft_tokens = np.arange(
-            tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
-            tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
-            dtype=np.uint32
-        )
+        soft_tokens = tpumtjgetsofttokens()

         genout = tpool.execute(
             tpu_mtj_backend.infer,
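
This hunk is pure deduplication: the inline soft-token setup is replaced by a call to the tpumtjgetsofttokens helper factored out above, so the startup warm-up and the real generation path build soft tokens identically. The surrounding tpool.execute is what keeps the server responsive here: it runs the blocking TPU call on eventlet's OS thread pool (whose size was raised to 50 at the top of this diff) and suspends only the calling green thread. A minimal sketch (illustrative function, not KoboldAI code):

    from eventlet import tpool

    def blocking_infer(prompt):
        # Stand-in for tpu_mtj_backend.infer: occupies an OS pool thread
        # with long-running work instead of blocking the event loop.
        total = sum(i * i for i in range(10**7))
        return prompt.upper()

    # Only the calling green thread waits for the result; eventlet's hub
    # keeps servicing other greenlets (e.g. Flask-SocketIO traffic).
    result = tpool.execute(blocking_infer, "hello")
    print(result, flush=True)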
@@ -4335,8 +4352,9 @@ loadsettings()
 #==================================================================#
 # Final startup commands to launch Flask app
 #==================================================================#
+print("", end="", flush=True)
 if __name__ == "__main__":
-    print("{0}\nStarting webserver...{1}".format(colors.GREEN, colors.END))
+    print("{0}\nStarting webserver...{1}".format(colors.GREEN, colors.END), flush=True)

     # Start Flask/SocketIO (Blocking, so this must be last method!)
@@ -4361,4 +4379,4 @@ if __name__ == "__main__":
     socketio.run(app, port=5000)
 else:
-    print("{0}\nServer started in WSGI mode!{1}".format(colors.GREEN, colors.END))
+    print("{0}\nServer started in WSGI mode!{1}".format(colors.GREEN, colors.END), flush=True)