Compile TPU backend in background
This commit is contained in:
parent
38a3eddd57
commit
fbc3a73c0f
68
aiserver.py
68
aiserver.py
|
@ -7,10 +7,10 @@
|
||||||
|
|
||||||
# External packages
|
# External packages
|
||||||
import eventlet
|
import eventlet
|
||||||
eventlet.monkey_patch()
|
eventlet.monkey_patch(all=True, thread=False)
|
||||||
import os
|
import os
|
||||||
os.system("")
|
os.system("")
|
||||||
os.environ['EVENTLET_THREADPOOL_SIZE'] = '1'
|
os.environ['EVENTLET_THREADPOOL_SIZE'] = '50'
|
||||||
from eventlet import tpool
|
from eventlet import tpool
|
||||||
|
|
||||||
from os import path, getcwd
|
from os import path, getcwd
|
||||||
|
@ -21,6 +21,7 @@ import zipfile
|
||||||
import packaging
|
import packaging
|
||||||
import contextlib
|
import contextlib
|
||||||
import traceback
|
import traceback
|
||||||
|
import threading
|
||||||
from typing import Any, Callable, TypeVar, Union, Dict, Set, List
|
from typing import Any, Callable, TypeVar, Union, Dict, Set, List
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
@ -976,6 +977,29 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||||
from transformers import GPT2TokenizerFast
|
from transformers import GPT2TokenizerFast
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
|
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
|
||||||
else:
|
else:
|
||||||
|
def tpumtjgetsofttokens():
|
||||||
|
soft_tokens = None
|
||||||
|
if(vars.sp is None):
|
||||||
|
global np
|
||||||
|
if 'np' not in globals():
|
||||||
|
import numpy as np
|
||||||
|
tensor = np.zeros((1, tpu_mtj_backend.params["d_model"]), dtype=np.float32)
|
||||||
|
rows = tensor.shape[0]
|
||||||
|
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
|
||||||
|
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
|
||||||
|
tensor = tensor.reshape(
|
||||||
|
tpu_mtj_backend.params["cores_per_replica"],
|
||||||
|
-1,
|
||||||
|
tpu_mtj_backend.params["d_model"],
|
||||||
|
)
|
||||||
|
vars.sp = tensor
|
||||||
|
soft_tokens = np.arange(
|
||||||
|
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
|
||||||
|
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
|
||||||
|
dtype=np.uint32
|
||||||
|
)
|
||||||
|
return soft_tokens
|
||||||
|
|
||||||
# If we're running Colab or OAI, we still need a tokenizer.
|
# If we're running Colab or OAI, we still need a tokenizer.
|
||||||
if(vars.model == "Colab"):
|
if(vars.model == "Colab"):
|
||||||
from transformers import GPT2TokenizerFast
|
from transformers import GPT2TokenizerFast
|
||||||
|
@ -992,6 +1016,17 @@ else:
|
||||||
vars.allowsp = True
|
vars.allowsp = True
|
||||||
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
|
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
|
||||||
tokenizer = tpu_mtj_backend.tokenizer
|
tokenizer = tpu_mtj_backend.tokenizer
|
||||||
|
soft_tokens = tpumtjgetsofttokens()
|
||||||
|
threading.Thread( # Compile backend code in background
|
||||||
|
target=tpu_mtj_backend.infer,
|
||||||
|
args=(np.uint32((23403, 727, 20185)),),
|
||||||
|
kwargs={
|
||||||
|
"soft_embeddings": vars.sp,
|
||||||
|
"soft_tokens": soft_tokens,
|
||||||
|
"gen_len": 1,
|
||||||
|
"numseqs": vars.numseqs,
|
||||||
|
},
|
||||||
|
).start()
|
||||||
|
|
||||||
# Set up Flask routes
|
# Set up Flask routes
|
||||||
@app.route('/')
|
@app.route('/')
|
||||||
|
@ -1583,7 +1618,8 @@ def execute_outmod():
|
||||||
# Lua runtime startup
|
# Lua runtime startup
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
|
||||||
print(colors.PURPLE + "Initializing Lua Bridge... " + colors.END, end="")
|
print("", end="", flush=True)
|
||||||
|
print(colors.PURPLE + "Initializing Lua Bridge... " + colors.END, end="", flush=True)
|
||||||
|
|
||||||
# Set up Lua state
|
# Set up Lua state
|
||||||
vars.lua_state = lupa.LuaRuntime(unpack_returned_tuples=True)
|
vars.lua_state = lupa.LuaRuntime(unpack_returned_tuples=True)
|
||||||
|
@ -2864,26 +2900,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
||||||
if(vars.dynamicscan):
|
if(vars.dynamicscan):
|
||||||
raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
|
raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
|
||||||
|
|
||||||
soft_tokens = None
|
soft_tokens = tpumtjgetsofttokens()
|
||||||
if(vars.sp is None):
|
|
||||||
global np
|
|
||||||
if 'np' not in globals():
|
|
||||||
import numpy as np
|
|
||||||
tensor = np.zeros((1, tpu_mtj_backend.params["d_model"]), dtype=np.float32)
|
|
||||||
rows = tensor.shape[0]
|
|
||||||
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
|
|
||||||
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
|
|
||||||
tensor = tensor.reshape(
|
|
||||||
tpu_mtj_backend.params["cores_per_replica"],
|
|
||||||
-1,
|
|
||||||
tpu_mtj_backend.params["d_model"],
|
|
||||||
)
|
|
||||||
vars.sp = tensor
|
|
||||||
soft_tokens = np.arange(
|
|
||||||
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
|
|
||||||
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
|
|
||||||
dtype=np.uint32
|
|
||||||
)
|
|
||||||
|
|
||||||
genout = tpool.execute(
|
genout = tpool.execute(
|
||||||
tpu_mtj_backend.infer,
|
tpu_mtj_backend.infer,
|
||||||
|
@ -4335,8 +4352,9 @@ loadsettings()
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Final startup commands to launch Flask app
|
# Final startup commands to launch Flask app
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
print("", end="", flush=True)
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("{0}\nStarting webserver...{1}".format(colors.GREEN, colors.END))
|
print("{0}\nStarting webserver...{1}".format(colors.GREEN, colors.END), flush=True)
|
||||||
|
|
||||||
# Start Flask/SocketIO (Blocking, so this must be last method!)
|
# Start Flask/SocketIO (Blocking, so this must be last method!)
|
||||||
|
|
||||||
|
@ -4361,4 +4379,4 @@ if __name__ == "__main__":
|
||||||
socketio.run(app, port=5000)
|
socketio.run(app, port=5000)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print("{0}\nServer started in WSGI mode!{1}".format(colors.GREEN, colors.END))
|
print("{0}\nServer started in WSGI mode!{1}".format(colors.GREEN, colors.END), flush=True)
|
||||||
|
|
Loading…
Reference in New Issue