This commit is contained in:
Gnome Ann
2021-09-23 20:57:18 -04:00
parent 72669e0489
commit 4d9eab3785
2 changed files with 36 additions and 76 deletions

View File

@@ -413,12 +413,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
breakmodel.total_blocks = n_layers
model.half().to('cpu')
gc.collect()
model.transformer.wte.to(breakmodel.gpu_device)
model.transformer.ln_f.to(breakmodel.gpu_device)
model.transformer.wte.to(breakmodel.embedding_device)
model.transformer.ln_f.to(breakmodel.layernormfinal_device)
if(hasattr(model, 'lm_head')):
model.lm_head.to(breakmodel.gpu_device)
model.lm_head.to(breakmodel.embedding_device)
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
model.transformer.wpe.to(breakmodel.gpu_device)
model.transformer.wpe.to(breakmodel.positional_device)
gc.collect()
if(args.breakmodel_layers is not None):
breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers))
@@ -465,12 +465,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
breakmodel.total_blocks = n_layers
model.half().to('cpu')
gc.collect()
model.transformer.wte.to(breakmodel.gpu_device)
model.transformer.ln_f.to(breakmodel.gpu_device)
model.transformer.wte.to(breakmodel.embedding_device)
model.transformer.ln_f.to(breakmodel.layernormfinal_device)
if(hasattr(model, 'lm_head')):
model.lm_head.to(breakmodel.gpu_device)
model.lm_head.to(breakmodel.embedding_device)
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
model.transformer.wpe.to(breakmodel.gpu_device)
model.transformer.wpe.to(breakmodel.positional_device)
gc.collect()
if(args.breakmodel_layers is not None):
breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers))
@@ -1229,7 +1229,7 @@ def generate(txt, min, max):
# its first argument if we're using breakmodel, otherwise a string
# is fine
if(vars.hascuda and vars.breakmodel):
gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.gpu_device)
gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.embedding_device)
else:
gen_in = txt