Fix typo in previous commit

This commit is contained in:
Gnome Ann
2021-08-21 10:54:57 -04:00
parent a8bbfab87a
commit fae15b8a17

View File

@@ -385,7 +385,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
gc.collect() gc.collect()
model.transformer.wte.to(breakmodel.gpu_device) model.transformer.wte.to(breakmodel.gpu_device)
model.transformer.ln_f.to(breakmodel.gpu_device) model.transformer.ln_f.to(breakmodel.gpu_device)
if(hasattr(model), 'lm_head'): if(hasattr(model, 'lm_head')):
model.lm_head.to(breakmodel.gpu_device) model.lm_head.to(breakmodel.gpu_device)
if(not hasattr(model.config, 'rotary') or not model.config.rotary): if(not hasattr(model.config, 'rotary') or not model.config.rotary):
model.transformer.wpe.to(breakmodel.gpu_device) model.transformer.wpe.to(breakmodel.gpu_device)
@@ -436,7 +436,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
gc.collect() gc.collect()
model.transformer.wte.to(breakmodel.gpu_device) model.transformer.wte.to(breakmodel.gpu_device)
model.transformer.ln_f.to(breakmodel.gpu_device) model.transformer.ln_f.to(breakmodel.gpu_device)
if(hasattr(model), 'lm_head'): if(hasattr(model, 'lm_head')):
model.lm_head.to(breakmodel.gpu_device) model.lm_head.to(breakmodel.gpu_device)
if(not hasattr(model.config, 'rotary') or not model.config.rotary): if(not hasattr(model.config, 'rotary') or not model.config.rotary):
model.transformer.wpe.to(breakmodel.gpu_device) model.transformer.wpe.to(breakmodel.gpu_device)