TPU fix Attempt

This commit is contained in:
ebolam 2022-06-07 09:05:51 -04:00
parent bf4af94abb
commit 1df88e1696
1 changed file with 4 additions and 2 deletions

View File

@@ -54,7 +54,7 @@ import utils
import structures import structures
import torch import torch
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
import tpu_mtj_backend global tpu_mtj_backend
if lupa.LUA_VERSION[:2] != (5, 4): if lupa.LUA_VERSION[:2] != (5, 4):
@@ -1892,6 +1892,8 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
loadsettings() loadsettings()
# Load the TPU backend if requested # Load the TPU backend if requested
elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")): elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
global tpu_mtj_backend
import tpu_mtj_backend
if(vars.model == "TPUMeshTransformerGPTNeoX"): if(vars.model == "TPUMeshTransformerGPTNeoX"):
vars.badwordsids = vars.badwordsids_neox vars.badwordsids = vars.badwordsids_neox
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END)) print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
@@ -2946,7 +2948,7 @@ def get_message(msg):
f.write(msg['gpu_layers']) f.write(msg['gpu_layers'])
f.close() f.close()
vars.colaburl = msg['url'] + "/request" vars.colaburl = msg['url'] + "/request"
load_model(use_gpu=msg['use_gpu'], gpu_layers=msg['gpu_layers'], online_model=msg['online_model'], url=msg['url']) load_model(use_gpu=msg['use_gpu'], gpu_layers=msg['gpu_layers'], online_model=msg['online_model'])
elif(msg['cmd'] == 'show_model'): elif(msg['cmd'] == 'show_model'):
print("Model Name: {}".format(getmodelname())) print("Model Name: {}".format(getmodelname()))
emit('from_server', {'cmd': 'show_model_name', 'data': getmodelname()}, broadcast=True) emit('from_server', {'cmd': 'show_model_name', 'data': getmodelname()}, broadcast=True)