Breakmodel support for GPTJModel

This commit is contained in:
Gnome Ann
2021-11-25 18:09:16 -05:00
parent f8bcc3411b
commit 25c9be5d02
2 changed files with 13 additions and 21 deletions

View File

@@ -298,10 +298,12 @@ def device_config(model):
model.transformer.ln_f.to(breakmodel.primary_device)
if(hasattr(model, 'lm_head')):
model.lm_head.to(breakmodel.primary_device)
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
if(hasattr(model.transformer, 'wpe')):
model.transformer.wpe.to(breakmodel.primary_device)
gc.collect()
GPTNeoModel.forward = breakmodel.new_forward
if("GPTJModel" in globals()):
GPTJModel.forward = breakmodel.new_forward
generator = model.generate
breakmodel.move_hidden_layers(model.transformer)