TPU fix attempt: import tpu_mtj_backend inside load_model's TPU branch instead of at module top level
This commit is contained in:
parent
bf4af94abb
commit
1df88e1696
|
@@ -54,7 +54,7 @@ import utils
|
||||||
import structures
|
import structures
|
||||||
import torch
|
import torch
|
||||||
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
|
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
|
||||||
import tpu_mtj_backend
|
global tpu_mtj_backend
|
||||||
|
|
||||||
|
|
||||||
if lupa.LUA_VERSION[:2] != (5, 4):
|
if lupa.LUA_VERSION[:2] != (5, 4):
|
||||||
|
@@ -1892,6 +1892,8 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
|
||||||
loadsettings()
|
loadsettings()
|
||||||
# Load the TPU backend if requested
|
# Load the TPU backend if requested
|
||||||
elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
|
elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
|
||||||
|
global tpu_mtj_backend
|
||||||
|
import tpu_mtj_backend
|
||||||
if(vars.model == "TPUMeshTransformerGPTNeoX"):
|
if(vars.model == "TPUMeshTransformerGPTNeoX"):
|
||||||
vars.badwordsids = vars.badwordsids_neox
|
vars.badwordsids = vars.badwordsids_neox
|
||||||
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
|
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||||
|
@@ -2946,7 +2948,7 @@ def get_message(msg):
|
||||||
f.write(msg['gpu_layers'])
|
f.write(msg['gpu_layers'])
|
||||||
f.close()
|
f.close()
|
||||||
vars.colaburl = msg['url'] + "/request"
|
vars.colaburl = msg['url'] + "/request"
|
||||||
load_model(use_gpu=msg['use_gpu'], gpu_layers=msg['gpu_layers'], online_model=msg['online_model'], url=msg['url'])
|
load_model(use_gpu=msg['use_gpu'], gpu_layers=msg['gpu_layers'], online_model=msg['online_model'])
|
||||||
elif(msg['cmd'] == 'show_model'):
|
elif(msg['cmd'] == 'show_model'):
|
||||||
print("Model Name: {}".format(getmodelname()))
|
print("Model Name: {}".format(getmodelname()))
|
||||||
emit('from_server', {'cmd': 'show_model_name', 'data': getmodelname()}, broadcast=True)
|
emit('from_server', {'cmd': 'show_model_name', 'data': getmodelname()}, broadcast=True)
|
||||||
|
|
Loading…
Reference in New Issue