This commit is contained in:
ebolam 2022-06-06 21:47:15 -04:00
parent 1b35b55d86
commit afb894f5a0
1 changed files with 1 additions and 2 deletions

View File

@ -54,6 +54,7 @@ import utils
import structures import structures
import torch import torch
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
import tpu_mtj_backend
if lupa.LUA_VERSION[:2] != (5, 4): if lupa.LUA_VERSION[:2] != (5, 4):
@ -934,7 +935,6 @@ def general_startup():
#==================================================================# #==================================================================#
def tpumtjgetsofttokens(): def tpumtjgetsofttokens():
import tpu_mtj_backend
soft_tokens = None soft_tokens = None
if(vars.sp is None): if(vars.sp is None):
global np global np
@ -5611,7 +5611,6 @@ def final_startup():
# Precompile TPU backend if required # Precompile TPU backend if required
if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")): if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
import tpu_mtj_backend
soft_tokens = tpumtjgetsofttokens() soft_tokens = tpumtjgetsofttokens()
if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)): if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
threading.Thread( threading.Thread(