Merge branch 'united' of https://github.com/ebolam/KoboldAI into united
aiserver.py: 55 lines changed
@@ -185,7 +185,7 @@ class vars:
     recentrngm  = None  # If a new random game was recently generated without Submitting after, this is the memory used (as a string), otherwise this is None
     useprompt   = False # Whether to send the full prompt with every submit action
     breakmodel  = False # For GPU users, whether to use both system RAM and VRAM to conserve VRAM while offering speedup compared to CPU-only
-    bmsupported = False # Whether the breakmodel option is supported (GPT-Neo/GPT-J only, currently)
+    bmsupported = False # Whether the breakmodel option is supported (GPT-Neo/GPT-J/XGLM only, currently)
     nobreakmodel = False # Something specifically requested Breakmodel to be disabled (For example a models config)
     smandelete  = False # Whether stories can be deleted from inside the browser
     smanrename  = False # Whether stories can be renamed from inside the browser
@@ -382,18 +382,29 @@ def device_config(model):
         return
     model.half().to('cpu')
     gc.collect()
-    model.transformer.wte.to(breakmodel.primary_device)
-    model.transformer.ln_f.to(breakmodel.primary_device)
-    if(hasattr(model, 'lm_head')):
-        model.lm_head.to(breakmodel.primary_device)
-    if(hasattr(model.transformer, 'wpe')):
-        model.transformer.wpe.to(breakmodel.primary_device)
+    if(hasattr(model, "transformer")):
+        model.transformer.wte.to(breakmodel.primary_device)
+        model.transformer.ln_f.to(breakmodel.primary_device)
+        if(hasattr(model, 'lm_head')):
+            model.lm_head.to(breakmodel.primary_device)
+        if(hasattr(model.transformer, 'wpe')):
+            model.transformer.wpe.to(breakmodel.primary_device)
+    else:
+        model.model.embed_tokens.to(breakmodel.primary_device)
+        model.model.layer_norm.to(breakmodel.primary_device)
+        model.model.embed_positions.to(breakmodel.primary_device)
     gc.collect()
-    GPTNeoModel.forward = breakmodel.new_forward
+    GPTNeoModel.forward = breakmodel.new_forward_neo
     if("GPTJModel" in globals()):
-        GPTJModel.forward = breakmodel.new_forward
+        GPTJModel.forward = breakmodel.new_forward_neo
+    if("XGLMModel" in globals()):
+        XGLMModel.forward = breakmodel.new_forward_xglm
     generator = model.generate
-    breakmodel.move_hidden_layers(model.transformer)
+    if(hasattr(model, "transformer")):
+        breakmodel.move_hidden_layers(model.transformer)
+    else:
+        breakmodel.move_hidden_layers(model.model, model.model.layers)
 
 #==================================================================#
 # Allow the models to override some settings
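The rewritten block above keys breakmodel's device placement off the module layout rather than the model class: GPT-Neo/GPT-J expose their embeddings and final layer norm under model.transformer (wte, wpe, ln_f), while XGLM keeps them under model.model (embed_tokens, embed_positions, layer_norm). A minimal standalone sketch of the same dispatch, with a plain primary_device argument standing in for breakmodel.primary_device:

def move_non_layer_modules(model, primary_device):
    # Sketch only: mirrors the hasattr() dispatch from the hunk above.
    if hasattr(model, "transformer"):
        # GPT-style layout (GPT-Neo, GPT-J): embeddings live on model.transformer.
        model.transformer.wte.to(primary_device)
        model.transformer.ln_f.to(primary_device)
        if hasattr(model, "lm_head"):
            model.lm_head.to(primary_device)
        if hasattr(model.transformer, "wpe"):
            model.transformer.wpe.to(primary_device)
    else:
        # XGLM-style layout: embeddings and layer norm live on model.model.
        model.model.embed_tokens.to(primary_device)
        model.model.layer_norm.to(primary_device)
        model.model.embed_positions.to(primary_device)
    return model

The hidden layers themselves are still handled by breakmodel.move_hidden_layers(), which the hunk now also calls with the XGLM layer list when model.transformer is absent.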
@@ -544,7 +555,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
     loadmodelsettings()
     print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
     vars.hascuda = torch.cuda.is_available()
-    vars.bmsupported = vars.model_type in ("gpt_neo", "gptj") and not vars.nobreakmodel
+    vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm") and not vars.nobreakmodel
     if(args.breakmodel is not None and args.breakmodel):
         print("WARNING: --breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --breakmodel_gpulayers is used (see --help for details).", file=sys.stderr)
     if(args.breakmodel_layers is not None):
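vars.bmsupported now whitelists the "xglm" model type as well. For reference, a hedged sketch of how the same check could be reproduced from a model's Hugging Face config, assuming vars.model_type mirrors config.model_type (the model name and nobreakmodel flag here are placeholders, not aiserver.py's actual plumbing):

from transformers import AutoConfig

def breakmodel_supported(model_name, nobreakmodel=False):
    # config.model_type is the string compared against the whitelist
    # ("gpt_neo", "gptj", "xglm", ...); it comes from the model's config.json.
    config = AutoConfig.from_pretrained(model_name)
    return config.model_type in ("gpt_neo", "gptj", "xglm") and not nobreakmodel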
@@ -736,10 +747,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
     if(not vars.noai):
         print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
         from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
-        try:
-            from transformers import GPTJModel
-        except:
-            pass
+        for m in ("GPTJModel", "XGLMModel"):
+            try:
+                globals()[m] = getattr(__import__("transformers"), m)
+            except:
+                pass
         import transformers.generation_utils
         from transformers import __version__ as transformers_version
 
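The single GPTJModel try/import above becomes a loop so one fallback path covers every class that only exists in newer transformers releases. The same pattern, written with importlib for readability (functionally equivalent to the __import__ call in the hunk):

import importlib

# Bind the optional model classes into the module namespace when this
# transformers release ships them; otherwise leave the names unbound so
# later code can test availability with `"GPTJModel" in globals()`.
for name in ("GPTJModel", "XGLMModel"):
    try:
        globals()[name] = getattr(importlib.import_module("transformers"), name)
    except (ImportError, AttributeError):
        pass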
@@ -753,7 +765,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             if(vars.sp is not None):
                 shifted_input_ids = input_ids - self.config.vocab_size
             input_ids.clamp_(max=self.config.vocab_size-1)
-            inputs_embeds = self.transformer.wte(input_ids)
+            if(hasattr(self, "transformer")):
+                inputs_embeds = self.transformer.wte(input_ids)
+            else:
+                inputs_embeds = self.model.embed_tokens(input_ids) * self.model.embed_scale
             if(vars.sp is not None):
                 vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
                 inputs_embeds = torch.where(
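The patched forward builds inputs_embeds by hand so token IDs at or above vocab_size can be routed to the soft-prompt tensor (vars.sp); the XGLM branch only differs in where the embedding table lives and in fairseq's embed_scale factor. A self-contained sketch of that lookup with toy dimensions (soft_prompt, vocab_size and embed_scale are made-up stand-ins, and the torch.where arguments are an assumption since the hunk truncates that call):

import torch
import torch.nn as nn

vocab_size, embed_dim = 100, 16
embed_tokens = nn.Embedding(vocab_size, embed_dim)
embed_scale = embed_dim ** 0.5                    # XGLM-style scaling factor
soft_prompt = torch.randn(4, embed_dim)           # stands in for vars.sp

# IDs >= vocab_size address soft-prompt slots instead of the vocabulary.
input_ids = torch.tensor([[5, 7, vocab_size, vocab_size + 1]])
shifted_input_ids = input_ids - vocab_size
input_ids = input_ids.clamp(max=vocab_size - 1)

inputs_embeds = embed_tokens(input_ids) * embed_scale
inputs_embeds = torch.where(
    (shifted_input_ids >= 0).unsqueeze(-1),        # soft-prompt positions
    soft_prompt[shifted_input_ids.clamp(min=0)],   # learned prompt embeddings
    inputs_embeds,                                  # ordinary token embeddings
)
print(inputs_embeds.shape)  # torch.Size([1, 4, 16])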
@@ -766,11 +781,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
         cls.forward = new_causallm_forward
     for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
         patch_causallm(cls)
-    try:
-        from transformers import GPTJForCausalLM
-        patch_causallm(GPTJForCausalLM)
-    except:
-        pass
+    for c in ("GPTJForCausalLM", "XGLMForCausalLM"):
+        try:
+            patch_causallm(getattr(__import__("transformers"), c))
+        except:
+            pass
 
 
     # Patch transformers to use our custom logit warpers
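As with the import loop earlier, patch_causallm is now applied by class name so transformers builds without GPTJForCausalLM or XGLMForCausalLM are skipped quietly. A hedged sketch of that wrap-and-replace pattern (patch_forward below is a placeholder, not the project's patch_causallm, which installs the soft-prompt-aware forward shown in the previous hunk):

import importlib

def patch_forward(cls):
    # Sketch: wrap cls.forward while keeping a reference to the original.
    old_forward = cls.forward
    def new_forward(self, *args, **kwargs):
        # A real patch would build inputs_embeds here; this sketch just delegates.
        return old_forward(self, *args, **kwargs)
    cls.forward = new_forward

for name in ("GPTJForCausalLM", "XGLMForCausalLM"):
    try:
        patch_forward(getattr(importlib.import_module("transformers"), name))
    except (ImportError, AttributeError):
        pass  # this transformers build doesn't ship the class; skip it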