Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
XGLM breakmodel
aiserver.py (38 lines changed: 25 additions, 13 deletions)
@@ -374,18 +374,29 @@ def device_config(model):
             return
     model.half().to('cpu')
     gc.collect()
-    model.transformer.wte.to(breakmodel.primary_device)
-    model.transformer.ln_f.to(breakmodel.primary_device)
-    if(hasattr(model, 'lm_head')):
-        model.lm_head.to(breakmodel.primary_device)
-    if(hasattr(model.transformer, 'wpe')):
-        model.transformer.wpe.to(breakmodel.primary_device)
+    if(hasattr(model, "transformer")):
+        model.transformer.wte.to(breakmodel.primary_device)
+        model.transformer.ln_f.to(breakmodel.primary_device)
+        if(hasattr(model, 'lm_head')):
+            model.lm_head.to(breakmodel.primary_device)
+        if(hasattr(model.transformer, 'wpe')):
+            model.transformer.wpe.to(breakmodel.primary_device)
+    else:
+        model.model.embed_tokens.to(breakmodel.primary_device)
+        model.model.layer_norm.to(breakmodel.primary_device)
+        model.model.lm_head.to(breakmodel.primary_device)
+        model.model.embed_positions.to(breakmodel.primary_device)
     gc.collect()
-    GPTNeoModel.forward = breakmodel.new_forward
+    GPTNeoModel.forward = breakmodel.new_forward_neo
     if("GPTJModel" in globals()):
-        GPTJModel.forward = breakmodel.new_forward
+        GPTJModel.forward = breakmodel.new_forward_neo
+    if("XGLMModel" in globals()):
+        XGLMModel.forward = breakmodel.new_forward_xglm
     generator = model.generate
-    breakmodel.move_hidden_layers(model.transformer)
+    if(hasattr(model, "transformer")):
+        breakmodel.move_hidden_layers(model.transformer)
+    else:
+        breakmodel.move_hidden_layers(model.model, model.model.layers)
 
 #==================================================================#
 # Allow the models to override some settings
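The first hunk makes device_config architecture-aware. GPT-style models expose their decoder stack as model.transformer (token embedding wte, position embedding wpe, final norm ln_f), while XGLM exposes it as model.model (embed_tokens, embed_positions, layer_norm), so the modules that are not hidden layers are pinned to breakmodel.primary_device along whichever path exists. The forward patch is split the same way: GPT-Neo and GPT-J share new_forward_neo, while XGLM gets its own new_forward_xglm, since XGLM's decoder layers differ from the GPT-Neo/J ones. Below is a minimal standalone sketch of the layer-splitting idea that breakmodel.move_hidden_layers implements; the function body, the gpu_blocks parameter, and the toy layers are illustrative assumptions, not KoboldAI's actual breakmodel API:

    import torch
    import torch.nn as nn

    def move_hidden_layers(layers, gpu_blocks):
        # Breakmodel-style placement sketch: gpu_blocks[i] is how many
        # consecutive hidden layers to keep on cuda:i; the rest stay on CPU.
        device_map = []
        for gpu_index, count in enumerate(gpu_blocks):
            device_map.extend([f"cuda:{gpu_index}"] * count)
        for i, layer in enumerate(layers):
            layer.to(device_map[i] if i < len(device_map) else "cpu")

    # Toy stand-ins for transformer blocks:
    layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(6)])
    if torch.cuda.is_available():
        move_hidden_layers(layers, gpu_blocks=[4, 2])

Once the layers live on different devices, the stock forward pass would fail on the first cross-device operation, which is why the patched forwards have to move the hidden state onto each layer's device before calling it.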
@@ -723,10 +734,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
     if(not vars.noai):
         print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
         from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
-        try:
-            from transformers import GPTJModel
-        except:
-            pass
+        for m in ("GPTJModel", "XGLMModel"):
+            try:
+                globals()[m] = getattr(__import__("transformers"), m)
+            except:
+                pass
         import transformers.generation_utils
         from transformers import __version__ as transformers_version
 
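The second hunk generalizes the optional import: rather than one try/except around GPTJModel, a loop binds each class that the installed transformers version provides into globals(), leaving the name undefined otherwise; the if("XGLMModel" in globals()) test in the first hunk then gates the forward patch on that. A self-contained sketch of the same pattern (catching ImportError and AttributeError explicitly is a stricter variant of the bare except used upstream):

    # Bind optionally-available classes into the module namespace; a
    # transformers release that lacks one simply leaves the name unbound.
    for m in ("GPTJModel", "XGLMModel"):
        try:
            globals()[m] = getattr(__import__("transformers"), m)
        except (ImportError, AttributeError):
            pass

    # Later code can feature-test with a name check instead of a version check:
    if "XGLMModel" in globals():
        print("XGLM breakmodel support can be enabled")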