TPU fix Attempt

ebolam 2022-06-07 09:05:51 -04:00
parent bf4af94abb
commit 1df88e1696
1 changed file with 4 additions and 2 deletions


@@ -54,7 +54,7 @@ import utils
 import structures
 import torch
 from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
-import tpu_mtj_backend
+global tpu_mtj_backend
 if lupa.LUA_VERSION[:2] != (5, 4):
@@ -1892,6 +1892,8 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
         loadsettings()
     # Load the TPU backend if requested
     elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
+        global tpu_mtj_backend
+        import tpu_mtj_backend
         if(vars.model == "TPUMeshTransformerGPTNeoX"):
             vars.badwordsids = vars.badwordsids_neox
         print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
@@ -2946,7 +2948,7 @@ def get_message(msg):
             f.write(msg['gpu_layers'])
             f.close()
         vars.colaburl = msg['url'] + "/request"
-        load_model(use_gpu=msg['use_gpu'], gpu_layers=msg['gpu_layers'], online_model=msg['online_model'], url=msg['url'])
+        load_model(use_gpu=msg['use_gpu'], gpu_layers=msg['gpu_layers'], online_model=msg['online_model'])
     elif(msg['cmd'] == 'show_model'):
         print("Model Name: {}".format(getmodelname()))
         emit('from_server', {'cmd': 'show_model_name', 'data': getmodelname()}, broadcast=True)
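Note: the last hunk removes a keyword argument that load_model() does not accept. The signature shown in the hunk header at line 1892 has no url parameter, so passing url=msg['url'] raises a TypeError before any loading happens; the URL is instead handed over through vars.colaburl on the preceding line. A small self-contained reproduction of the failure mode, using a simplified stand-in for the real function:

    # Simplified signature matching the hunk header above.
    def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model=""):
        pass

    try:
        # The call removed by this commit: 'url' is not a parameter.
        load_model(use_gpu=False, online_model="", url="https://example.com")
    except TypeError as err:
        print(err)  # load_model() got an unexpected keyword argument 'url'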