diff --git a/aiserver.py b/aiserver.py
index a2f3daf3..2b90b54f 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -7,10 +7,10 @@
 
 # External packages
 import eventlet
-eventlet.monkey_patch()
+eventlet.monkey_patch(all=True, thread=False)
 import os
 os.system("")
-os.environ['EVENTLET_THREADPOOL_SIZE'] = '1'
+os.environ['EVENTLET_THREADPOOL_SIZE'] = '50'
 from eventlet import tpool
 
 from os import path, getcwd
@@ -21,6 +21,7 @@
 import zipfile
 import packaging
 import contextlib
 import traceback
+import threading
 from typing import Any, Callable, TypeVar, Union, Dict, Set, List
 
 import requests
@@ -976,6 +977,29 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
         from transformers import GPT2TokenizerFast
         tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
 else:
+    def tpumtjgetsofttokens():
+        soft_tokens = None
+        if(vars.sp is None):
+            global np
+            if 'np' not in globals():
+                import numpy as np
+            tensor = np.zeros((1, tpu_mtj_backend.params["d_model"]), dtype=np.float32)
+            rows = tensor.shape[0]
+            padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
+            tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
+            tensor = tensor.reshape(
+                tpu_mtj_backend.params["cores_per_replica"],
+                -1,
+                tpu_mtj_backend.params["d_model"],
+            )
+            vars.sp = tensor
+        soft_tokens = np.arange(
+            tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
+            tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
+            dtype=np.uint32
+        )
+        return soft_tokens
+
     # If we're running Colab or OAI, we still need a tokenizer.
     if(vars.model == "Colab"):
         from transformers import GPT2TokenizerFast
@@ -992,6 +1016,17 @@ else:
         vars.allowsp = True
         vars.modeldim = int(tpu_mtj_backend.params["d_model"])
         tokenizer = tpu_mtj_backend.tokenizer
+        soft_tokens = tpumtjgetsofttokens()
+        threading.Thread( # Compile backend code in background
+            target=tpu_mtj_backend.infer,
+            args=(np.uint32((23403, 727, 20185)),),
+            kwargs={
+                "soft_embeddings": vars.sp,
+                "soft_tokens": soft_tokens,
+                "gen_len": 1,
+                "numseqs": vars.numseqs,
+            },
+        ).start()
 
 # Set up Flask routes
 @app.route('/')
@@ -1583,7 +1618,8 @@ def execute_outmod():
 # Lua runtime startup
 #==================================================================#
 
-print(colors.PURPLE + "Initializing Lua Bridge... " + colors.END, end="")
+print("", end="", flush=True)
+print(colors.PURPLE + "Initializing Lua Bridge... " + colors.END, end="", flush=True)
" + colors.END, end="", flush=True) # Set up Lua state vars.lua_state = lupa.LuaRuntime(unpack_returned_tuples=True) @@ -2863,27 +2899,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): try: if(vars.dynamicscan): raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet") - - soft_tokens = None - if(vars.sp is None): - global np - if 'np' not in globals(): - import numpy as np - tensor = np.zeros((1, tpu_mtj_backend.params["d_model"]), dtype=np.float32) - rows = tensor.shape[0] - padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows - tensor = np.pad(tensor, ((0, padding_amount), (0, 0))) - tensor = tensor.reshape( - tpu_mtj_backend.params["cores_per_replica"], - -1, - tpu_mtj_backend.params["d_model"], - ) - vars.sp = tensor - soft_tokens = np.arange( - tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"], - tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length, - dtype=np.uint32 - ) + + soft_tokens = tpumtjgetsofttokens() genout = tpool.execute( tpu_mtj_backend.infer, @@ -4335,8 +4352,9 @@ loadsettings() #==================================================================# # Final startup commands to launch Flask app #==================================================================# +print("", end="", flush=True) if __name__ == "__main__": - print("{0}\nStarting webserver...{1}".format(colors.GREEN, colors.END)) + print("{0}\nStarting webserver...{1}".format(colors.GREEN, colors.END), flush=True) # Start Flask/SocketIO (Blocking, so this must be last method!) @@ -4361,4 +4379,4 @@ if __name__ == "__main__": socketio.run(app, port=5000) else: - print("{0}\nServer started in WSGI mode!{1}".format(colors.GREEN, colors.END)) + print("{0}\nServer started in WSGI mode!{1}".format(colors.GREEN, colors.END), flush=True)