Compile TPU backend in background
This commit is contained in:
parent
38a3eddd57
commit
fbc3a73c0f
68
aiserver.py
68
aiserver.py
|
@ -7,10 +7,10 @@
|
|||
|
||||
# External packages
|
||||
import eventlet
|
||||
eventlet.monkey_patch()
|
||||
eventlet.monkey_patch(all=True, thread=False)
|
||||
import os
|
||||
os.system("")
|
||||
os.environ['EVENTLET_THREADPOOL_SIZE'] = '1'
|
||||
os.environ['EVENTLET_THREADPOOL_SIZE'] = '50'
|
||||
from eventlet import tpool
|
||||
|
||||
from os import path, getcwd
|
||||
|
@ -21,6 +21,7 @@ import zipfile
|
|||
import packaging
|
||||
import contextlib
|
||||
import traceback
|
||||
import threading
|
||||
from typing import Any, Callable, TypeVar, Union, Dict, Set, List
|
||||
|
||||
import requests
|
||||
|
@ -976,6 +977,29 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||
from transformers import GPT2TokenizerFast
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
|
||||
else:
|
||||
def tpumtjgetsofttokens():
|
||||
soft_tokens = None
|
||||
if(vars.sp is None):
|
||||
global np
|
||||
if 'np' not in globals():
|
||||
import numpy as np
|
||||
tensor = np.zeros((1, tpu_mtj_backend.params["d_model"]), dtype=np.float32)
|
||||
rows = tensor.shape[0]
|
||||
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
|
||||
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
|
||||
tensor = tensor.reshape(
|
||||
tpu_mtj_backend.params["cores_per_replica"],
|
||||
-1,
|
||||
tpu_mtj_backend.params["d_model"],
|
||||
)
|
||||
vars.sp = tensor
|
||||
soft_tokens = np.arange(
|
||||
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
|
||||
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
|
||||
dtype=np.uint32
|
||||
)
|
||||
return soft_tokens
|
||||
|
||||
# If we're running Colab or OAI, we still need a tokenizer.
|
||||
if(vars.model == "Colab"):
|
||||
from transformers import GPT2TokenizerFast
|
||||
|
@ -992,6 +1016,17 @@ else:
|
|||
vars.allowsp = True
|
||||
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
|
||||
tokenizer = tpu_mtj_backend.tokenizer
|
||||
soft_tokens = tpumtjgetsofttokens()
|
||||
threading.Thread( # Compile backend code in background
|
||||
target=tpu_mtj_backend.infer,
|
||||
args=(np.uint32((23403, 727, 20185)),),
|
||||
kwargs={
|
||||
"soft_embeddings": vars.sp,
|
||||
"soft_tokens": soft_tokens,
|
||||
"gen_len": 1,
|
||||
"numseqs": vars.numseqs,
|
||||
},
|
||||
).start()
|
||||
|
||||
# Set up Flask routes
|
||||
@app.route('/')
|
||||
|
@ -1583,7 +1618,8 @@ def execute_outmod():
|
|||
# Lua runtime startup
|
||||
#==================================================================#
|
||||
|
||||
print(colors.PURPLE + "Initializing Lua Bridge... " + colors.END, end="")
|
||||
print("", end="", flush=True)
|
||||
print(colors.PURPLE + "Initializing Lua Bridge... " + colors.END, end="", flush=True)
|
||||
|
||||
# Set up Lua state
|
||||
vars.lua_state = lupa.LuaRuntime(unpack_returned_tuples=True)
|
||||
|
@ -2864,26 +2900,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
|||
if(vars.dynamicscan):
|
||||
raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
|
||||
|
||||
soft_tokens = None
|
||||
if(vars.sp is None):
|
||||
global np
|
||||
if 'np' not in globals():
|
||||
import numpy as np
|
||||
tensor = np.zeros((1, tpu_mtj_backend.params["d_model"]), dtype=np.float32)
|
||||
rows = tensor.shape[0]
|
||||
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
|
||||
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
|
||||
tensor = tensor.reshape(
|
||||
tpu_mtj_backend.params["cores_per_replica"],
|
||||
-1,
|
||||
tpu_mtj_backend.params["d_model"],
|
||||
)
|
||||
vars.sp = tensor
|
||||
soft_tokens = np.arange(
|
||||
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
|
||||
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
|
||||
dtype=np.uint32
|
||||
)
|
||||
soft_tokens = tpumtjgetsofttokens()
|
||||
|
||||
genout = tpool.execute(
|
||||
tpu_mtj_backend.infer,
|
||||
|
@ -4335,8 +4352,9 @@ loadsettings()
|
|||
#==================================================================#
|
||||
# Final startup commands to launch Flask app
|
||||
#==================================================================#
|
||||
print("", end="", flush=True)
|
||||
if __name__ == "__main__":
|
||||
print("{0}\nStarting webserver...{1}".format(colors.GREEN, colors.END))
|
||||
print("{0}\nStarting webserver...{1}".format(colors.GREEN, colors.END), flush=True)
|
||||
|
||||
# Start Flask/SocketIO (Blocking, so this must be last method!)
|
||||
|
||||
|
@ -4361,4 +4379,4 @@ if __name__ == "__main__":
|
|||
socketio.run(app, port=5000)
|
||||
|
||||
else:
|
||||
print("{0}\nServer started in WSGI mode!{1}".format(colors.GREEN, colors.END))
|
||||
print("{0}\nServer started in WSGI mode!{1}".format(colors.GREEN, colors.END), flush=True)
|
||||
|
|
Loading…
Reference in New Issue