Automatically support soft prompts for all transformers models
parent cc56718a7e
commit 042cf3e560

aiserver.py | 74 changed lines
@@ -1210,8 +1210,37 @@ def get_oai_models(key):
         print("{0}ERROR!{1}".format(colors.RED, colors.END))
         print(req.json())
         emit('from_server', {'cmd': 'errmsg', 'data': req.json()})
 
 
+# Function to patch transformers to use our soft prompt
+def patch_causallm(cls):
+    if(getattr(cls, "_koboldai_patch_causallm_patched", False)):
+        return
+    old_forward = cls.forward
+    def new_causallm_forward(self, *args, **kwargs):
+        input_ids = kwargs.get('input_ids').to(self.device)
+        assert input_ids is not None
+        kwargs['input_ids'] = None
+        if(vars.sp is not None):
+            shifted_input_ids = input_ids - self.config.vocab_size
+            input_ids.clamp_(max=self.config.vocab_size-1)
+        inputs_embeds = self.get_input_embeddings()(input_ids)
+        if(vars.sp is not None):
+            vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
+            inputs_embeds = torch.where(
+                (shifted_input_ids >= 0)[..., None],
+                vars.sp[shifted_input_ids.clamp(min=0)],
+                inputs_embeds,
+            )
+        if(hasattr(self, "model") and hasattr(self.model, "embed_scale")):
+            inputs_embeds *= self.model.embed_scale
+        kwargs['inputs_embeds'] = inputs_embeds
+        return old_forward(self, *args, **kwargs)
+    cls.forward = new_causallm_forward
+    cls._koboldai_patch_causallm_patched = True
+    return cls
+
+
 def patch_transformers():
     global transformers
     old_from_pretrained = PreTrainedModel.from_pretrained.__func__
@@ -1259,42 +1288,6 @@ def patch_transformers():
             return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
         XGLMSinusoidalPositionalEmbedding.forward = new_forward
 
-    # Patch transformers to use our soft prompt
-    def patch_causallm(cls):
-        old_forward = cls.forward
-        def new_causallm_forward(self, *args, **kwargs):
-            input_ids = kwargs.get('input_ids').to(self.device)
-            assert input_ids is not None
-            kwargs['input_ids'] = None
-            if(vars.sp is not None):
-                shifted_input_ids = input_ids - self.config.vocab_size
-                input_ids.clamp_(max=self.config.vocab_size-1)
-            if(hasattr(self, "transformer")):
-                inputs_embeds = self.transformer.wte(input_ids)
-            elif(not hasattr(self.model, "decoder")):
-                inputs_embeds = self.model.embed_tokens(input_ids)
-            else:
-                inputs_embeds = self.model.decoder.embed_tokens(input_ids)
-            if(vars.sp is not None):
-                vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
-                inputs_embeds = torch.where(
-                    (shifted_input_ids >= 0)[..., None],
-                    vars.sp[shifted_input_ids.clamp(min=0)],
-                    inputs_embeds,
-                )
-            if(hasattr(self, "model") and hasattr(self.model, "embed_scale")):
-                inputs_embeds *= self.model.embed_scale
-            kwargs['inputs_embeds'] = inputs_embeds
-            return old_forward(self, *args, **kwargs)
-        cls.forward = new_causallm_forward
-    for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
-        patch_causallm(cls)
-    for c in ("GPTJForCausalLM", "XGLMForCausalLM", "OPTForCausalLM"):
-        try:
-            patch_causallm(getattr(__import__("transformers"), c))
-        except:
-            pass
-
     # Fix a bug in OPTForCausalLM where self.lm_head is the wrong size
     if(packaging.version.parse("4.19.0.dev0") <= packaging.version.parse(transformers_version) < packaging.version.parse("4.20.0")):
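The removed branch existed only to locate the embedding module by hard-coded attribute path (transformer.wte for the GPT-2 family, model.embed_tokens or model.decoder.embed_tokens for the others), and it had to be applied class by class. get_input_embeddings() resolves to that same module, which is what makes the generic version possible. A quick check along those lines (downloads the small "gpt2" checkpoint from the Hugging Face hub; shown as an illustration, not part of the commit):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    # The generic accessor and the old hard-coded path point at the same module.
    print(model.get_input_embeddings() is model.transformer.wte)   # True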
@@ -1796,6 +1789,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
                 else:
                     model = model.to('cpu').float()
                 generator = model.generate
+                patch_causallm(model.__class__)
             # Use the Generic implementation
             else:
                 lowmem = maybe_low_cpu_mem_usage()
@@ -1923,7 +1917,9 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
                     for filename in filenames:
                         shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, filename, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), filename))
                     shutil.rmtree("cache/")
 
+            patch_causallm(model.__class__)
+
             if(vars.hascuda):
                 if(vars.usegpu):
                     vars.modeldim = get_hidden_size_from_model(model)
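Both new call sites hand patch_causallm the class (model.__class__) rather than the instance, so forward is replaced once per architecture, and the _koboldai_patch_causallm_patched flag keeps repeated load_model calls from stacking wrappers. A toy sketch of that idempotent class-level monkey patch (ToyModel and the multiplier are illustrative only):

    class ToyModel:
        def forward(self, x):
            return x + 1

    def patch(cls):
        # Same guard pattern as patch_causallm: never wrap forward twice.
        if getattr(cls, "_patched", False):
            return
        old_forward = cls.forward
        def new_forward(self, x):
            return old_forward(self, x) * 10   # stand-in for the soft-prompt logic
        cls.forward = new_forward
        cls._patched = True

    patch(ToyModel)
    patch(ToyModel)                  # second call is a no-op thanks to the guard
    print(ToyModel().forward(1))     # prints 20, not 200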