From 1df88e1696a1baf4b75b5c3eed407ecb3a48390f Mon Sep 17 00:00:00 2001
From: ebolam
Date: Tue, 7 Jun 2022 09:05:51 -0400
Subject: [PATCH] TPU fix Attempt

---
 aiserver.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index cb443a6a..5827b96a 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -54,7 +54,7 @@ import utils
 import structures
 import torch
 from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
-import tpu_mtj_backend
+global tpu_mtj_backend
 
 
 if lupa.LUA_VERSION[:2] != (5, 4):
@@ -1892,6 +1892,8 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model=""):
         loadsettings()
     # Load the TPU backend if requested
     elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
+        global tpu_mtj_backend
+        import tpu_mtj_backend
         if(vars.model == "TPUMeshTransformerGPTNeoX"):
             vars.badwordsids = vars.badwordsids_neox
         print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
@@ -2946,7 +2948,7 @@ def get_message(msg):
             f.write(msg['gpu_layers'])
             f.close()
         vars.colaburl = msg['url'] + "/request"
-        load_model(use_gpu=msg['use_gpu'], gpu_layers=msg['gpu_layers'], online_model=msg['online_model'], url=msg['url'])
+        load_model(use_gpu=msg['use_gpu'], gpu_layers=msg['gpu_layers'], online_model=msg['online_model'])
     elif(msg['cmd'] == 'show_model'):
         print("Model Name: {}".format(getmodelname()))
         emit('from_server', {'cmd': 'show_model_name', 'data': getmodelname()}, broadcast=True)
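
The change defers the import of tpu_mtj_backend until the TPU code path in load_model() is actually taken, and drops the url= keyword from the load_model() call so it matches the signature shown in the hunk header. Below is a minimal, runnable sketch of the deferred-import pattern, assuming a stand-in module (json) in place of the heavy tpu_mtj_backend and an illustrative function name that is not from aiserver.py:

# Deferred import: the module is only loaded when the branch that needs it runs.
# "json" stands in for a heavy backend such as tpu_mtj_backend; load_backend()
# is an illustrative name, not a function from aiserver.py.

def load_backend(use_tpu: bool = False) -> None:
    if use_tpu:
        global json   # rebind the module-level name, as the patch does for tpu_mtj_backend
        import json   # the heavy import happens only on this code path
        print(json.dumps({"backend": "initialized"}))
    else:
        print("non-TPU path: backend module never imported")

load_backend(use_tpu=False)  # no import performed
load_backend(use_tpu=True)   # backend imported on first TPU use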