diff --git a/aiserver.py b/aiserver.py index c0d784eb..ea1896af 100644 --- a/aiserver.py +++ b/aiserver.py @@ -383,9 +383,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): breakmodel.total_blocks = n_layers model.half().to('cpu') gc.collect() - model.lm_head.to(breakmodel.gpu_device) model.transformer.wte.to(breakmodel.gpu_device) model.transformer.ln_f.to(breakmodel.gpu_device) + if(hasattr(model), 'lm_head'): + model.lm_head.to(breakmodel.gpu_device) if(not hasattr(model.config, 'rotary') or not model.config.rotary): model.transformer.wpe.to(breakmodel.gpu_device) gc.collect() @@ -433,9 +434,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): breakmodel.total_blocks = n_layers model.half().to('cpu') gc.collect() - model.lm_head.to(breakmodel.gpu_device) model.transformer.wte.to(breakmodel.gpu_device) model.transformer.ln_f.to(breakmodel.gpu_device) + if(hasattr(model), 'lm_head'): + model.lm_head.to(breakmodel.gpu_device) if(not hasattr(model.config, 'rotary') or not model.config.rotary): model.transformer.wpe.to(breakmodel.gpu_device) gc.collect()