XGLM breakmodel

Gnome Ann
2022-02-01 12:49:07 -05:00
parent c14e6fe5d2
commit e7f65cee09
2 changed files with 260 additions and 39 deletions


@@ -374,18 +374,29 @@ def device_config(model):
         return
     model.half().to('cpu')
     gc.collect()
-    model.transformer.wte.to(breakmodel.primary_device)
-    model.transformer.ln_f.to(breakmodel.primary_device)
-    if(hasattr(model, 'lm_head')):
-        model.lm_head.to(breakmodel.primary_device)
-    if(hasattr(model.transformer, 'wpe')):
-        model.transformer.wpe.to(breakmodel.primary_device)
+    if(hasattr(model, "transformer")):
+        model.transformer.wte.to(breakmodel.primary_device)
+        model.transformer.ln_f.to(breakmodel.primary_device)
+        if(hasattr(model, 'lm_head')):
+            model.lm_head.to(breakmodel.primary_device)
+        if(hasattr(model.transformer, 'wpe')):
+            model.transformer.wpe.to(breakmodel.primary_device)
+    else:
+        model.model.embed_tokens.to(breakmodel.primary_device)
+        model.model.layer_norm.to(breakmodel.primary_device)
+        model.lm_head.to(breakmodel.primary_device)
+        model.model.embed_positions.to(breakmodel.primary_device)
     gc.collect()
-    GPTNeoModel.forward = breakmodel.new_forward
+    GPTNeoModel.forward = breakmodel.new_forward_neo
     if("GPTJModel" in globals()):
-        GPTJModel.forward = breakmodel.new_forward
+        GPTJModel.forward = breakmodel.new_forward_neo
+    if("XGLMModel" in globals()):
+        XGLMModel.forward = breakmodel.new_forward_xglm
     generator = model.generate
-    breakmodel.move_hidden_layers(model.transformer)
+    if(hasattr(model, "transformer")):
+        breakmodel.move_hidden_layers(model.transformer)
+    else:
+        breakmodel.move_hidden_layers(model.model, model.model.layers)
 
 #==================================================================#
 #  Allow the models to override some settings
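
Everything in the hunk above keys off whether the model exposes a GPT-style transformer attribute. A minimal sketch of that dispatch, outside the diff: the attribute names mirror the hunk, while place_low_vram and its primary_device default are hypothetical stand-ins for breakmodel's real interface, not KoboldAI's actual API.

import gc

# Sketch only: "place_low_vram" and "primary_device" are illustrative names;
# the module layout is the one the hunk above relies on.
def place_low_vram(model, primary_device="cuda:0"):
    if hasattr(model, "transformer"):
        # GPT-Neo/GPT-J layout: embeddings and final layer norm hang off
        # model.transformer; the hidden blocks live in model.transformer.h.
        model.transformer.wte.to(primary_device)
        model.transformer.ln_f.to(primary_device)
        if hasattr(model.transformer, "wpe"):
            model.transformer.wpe.to(primary_device)
        hidden_layers = model.transformer.h
    else:
        # XGLM layout: the decoder sits at model.model, with layer_norm in
        # place of ln_f and embed_positions in place of wpe.
        model.model.embed_tokens.to(primary_device)
        model.model.layer_norm.to(primary_device)
        model.model.embed_positions.to(primary_device)
        hidden_layers = model.model.layers
    if hasattr(model, "lm_head"):
        model.lm_head.to(primary_device)
    gc.collect()
    return hidden_layers  # the caller spreads these across GPUs and CPU

The same hasattr test picks which forward replacement gets patched in: new_forward_neo for GPT-style decoders, new_forward_xglm for XGLM, since the two architectures reach their hidden layers through different module paths.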
@@ -723,10 +734,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
     if(not vars.noai):
         print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
         from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
-        try:
-            from transformers import GPTJModel
-        except:
-            pass
+        for m in ("GPTJModel", "XGLMModel"):
+            try:
+                globals()[m] = getattr(__import__("transformers"), m)
+            except:
+                pass
         import transformers.generation_utils
         from transformers import __version__ as transformers_version
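
The replaced try/import now loops over class names and binds each into globals() only when the installed transformers release actually ships it. getattr(__import__("transformers"), m) is the working form here: GPTJModel is an attribute of the transformers package, not a submodule, so importing "transformers.GPTJModel" as a module would always fail and silently skip the class. A small standalone sketch of the pattern, using importlib.import_module as the more idiomatic spelling:

import importlib

# Bind optional model classes by name. On an older transformers release the
# getattr raises AttributeError, the name stays unbound, and later code can
# guard with checks like 'if "XGLMModel" in globals()'.
transformers_pkg = importlib.import_module("transformers")
for m in ("GPTJModel", "XGLMModel"):
    try:
        globals()[m] = getattr(transformers_pkg, m)
    except AttributeError:
        pass  # class not available in this transformers version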