Fix for non-rotary models whose config.json has no "rotary" key
This commit is contained in:
parent
56c9dc2c04
commit
8bfcf86a8b
|
@@ -386,7 +386,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
||||||
model.lm_head.to(breakmodel.gpu_device)
|
model.lm_head.to(breakmodel.gpu_device)
|
||||||
model.transformer.wte.to(breakmodel.gpu_device)
|
model.transformer.wte.to(breakmodel.gpu_device)
|
||||||
model.transformer.ln_f.to(breakmodel.gpu_device)
|
model.transformer.ln_f.to(breakmodel.gpu_device)
|
||||||
if(not model.config.rotary):
|
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
|
||||||
model.transformer.wpe.to(breakmodel.gpu_device)
|
model.transformer.wpe.to(breakmodel.gpu_device)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if(vars.bmsupported and args.breakmodel):
|
if(vars.bmsupported and args.breakmodel):
|
||||||
|
@@ -436,7 +436,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
||||||
model.lm_head.to(breakmodel.gpu_device)
|
model.lm_head.to(breakmodel.gpu_device)
|
||||||
model.transformer.wte.to(breakmodel.gpu_device)
|
model.transformer.wte.to(breakmodel.gpu_device)
|
||||||
model.transformer.ln_f.to(breakmodel.gpu_device)
|
model.transformer.ln_f.to(breakmodel.gpu_device)
|
||||||
if(not model.config.rotary):
|
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
|
||||||
model.transformer.wpe.to(breakmodel.gpu_device)
|
model.transformer.wpe.to(breakmodel.gpu_device)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if(vars.bmsupported and args.breakmodel):
|
if(vars.bmsupported and args.breakmodel):
|
||||||
|
|
|
@@ -378,7 +378,7 @@ def new_forward(
|
||||||
inputs_embeds[:, pos:pos+emb.shape[1]] = emb
|
inputs_embeds[:, pos:pos+emb.shape[1]] = emb
|
||||||
offset += emb.shape[1]
|
offset += emb.shape[1]
|
||||||
|
|
||||||
if self.rotary:
|
if hasattr(self, 'rotary') and self.rotary:
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
else:
|
else:
|
||||||
position_embeds = self.wpe(position_ids)
|
position_embeds = self.wpe(position_ids)
|
||||||
|
|
Loading…
Reference in New Issue