mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Breakmodel support for GPTJModel
This commit is contained in:
@ -298,10 +298,12 @@ def device_config(model):
|
||||
model.transformer.ln_f.to(breakmodel.primary_device)
|
||||
if(hasattr(model, 'lm_head')):
|
||||
model.lm_head.to(breakmodel.primary_device)
|
||||
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
|
||||
if(hasattr(model.transformer, 'wpe')):
|
||||
model.transformer.wpe.to(breakmodel.primary_device)
|
||||
gc.collect()
|
||||
GPTNeoModel.forward = breakmodel.new_forward
|
||||
if("GPTJModel" in globals()):
|
||||
GPTJModel.forward = breakmodel.new_forward
|
||||
generator = model.generate
|
||||
breakmodel.move_hidden_layers(model.transformer)
|
||||
|
||||
|
Reference in New Issue
Block a user