TPU fix Attempt
parent bf4af94abb
commit 1df88e1696
@@ -54,7 +54,7 @@ import utils
 import structures
 import torch
 from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
-import tpu_mtj_backend
+global tpu_mtj_backend


 if lupa.LUA_VERSION[:2] != (5, 4):
@@ -1892,6 +1892,8 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
         loadsettings()
     # Load the TPU backend if requested
     elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
+        global tpu_mtj_backend
+        import tpu_mtj_backend
         if(vars.model == "TPUMeshTransformerGPTNeoX"):
             vars.badwordsids = vars.badwordsids_neox
         print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
@@ -2946,7 +2948,7 @@ def get_message(msg):
         f.write(msg['gpu_layers'])
         f.close()
         vars.colaburl = msg['url'] + "/request"
-        load_model(use_gpu=msg['use_gpu'], gpu_layers=msg['gpu_layers'], online_model=msg['online_model'], url=msg['url'])
+        load_model(use_gpu=msg['use_gpu'], gpu_layers=msg['gpu_layers'], online_model=msg['online_model'])
     elif(msg['cmd'] == 'show_model'):
         print("Model Name: {}".format(getmodelname()))
         emit('from_server', {'cmd': 'show_model_name', 'data': getmodelname()}, broadcast=True)
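
The two lines added inside load_model rely on the fact that an import statement is a name-binding statement: because global tpu_mtj_backend is declared first, the deferred import binds the module object to the module-level name, so the rest of the file can keep referring to tpu_mtj_backend while the JAX/TPU dependencies are only loaded when a TPU model is actually selected. A minimal sketch of that pattern follows; it assumes it runs inside a source tree where tpu_mtj_backend.py is importable, and the function name and parameters are simplified stand-ins rather than the real aiserver.py signature:

tpu_mtj_backend = None  # module-level placeholder; rebound by the deferred import below

def load_model_sketch(use_colab_tpu, model_name):
    # 'global' makes the import below rebind the module-level name instead of
    # creating a function-local one, mirroring the two lines added in load_model.
    global tpu_mtj_backend
    if use_colab_tpu or model_name in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
        import tpu_mtj_backend  # JAX/TPU dependencies are only pulled in here
    return tpu_mtj_backend

Non-TPU code paths never reach the import, which is the point of moving it out of the top-level import block.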