Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Merge pull request #290 from one-some/ui2-logit-things-are-just-too-invasive-get-them-away-NOW
Add context manager for sandboxing invasive transformers patches
aiserver.py: 148 changed lines
@@ -113,6 +113,34 @@ def new_pretrainedtokenizerbase_from_pretrained(cls, *args, **kwargs):
    return tokenizer

PreTrainedTokenizerBase.from_pretrained = new_pretrainedtokenizerbase_from_pretrained


# We only want to use logit manipulations and such on our core text model
class use_core_manipulations:
    # These must be set by wherever they get setup
    get_logits_processor: callable
    sample: callable
    get_stopping_criteria: callable

    # We set these automatically
    old_get_logits_processor: callable
    old_sample: callable
    old_get_stopping_criteria: callable

    def __enter__(self):
        use_core_manipulations.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
        transformers.generation_utils.GenerationMixin._get_logits_processor = use_core_manipulations.get_logits_processor

        use_core_manipulations.old_sample = transformers.generation_utils.GenerationMixin.sample
        transformers.generation_utils.GenerationMixin.sample = use_core_manipulations.sample

        use_core_manipulations.old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
        transformers.generation_utils.GenerationMixin._get_stopping_criteria = use_core_manipulations.get_stopping_criteria
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        transformers.generation_utils.GenerationMixin._get_logits_processor = use_core_manipulations.old_get_logits_processor
        transformers.generation_utils.GenerationMixin.sample = use_core_manipulations.old_sample
        transformers.generation_utils.GenerationMixin._get_stopping_criteria = use_core_manipulations.old_get_stopping_criteria

#==================================================================#
# Variables & Storage
#==================================================================#
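A standalone sketch of the pattern this hunk introduces, with hypothetical names (CoreModel, use_patches, kobold_generate) standing in for GenerationMixin, use_core_manipulations and the Kobold hooks: replacements are registered as class attributes, __enter__ swaps them in, __exit__ restores the captured originals.

class CoreModel:
    def generate(self) -> str:
        return "stock output"


def kobold_generate(self) -> str:
    return "output with custom logit processing"


class use_patches:
    # must be set by whoever defines the replacement, mirroring
    # use_core_manipulations.get_logits_processor / sample / get_stopping_criteria
    generate: callable

    # captured automatically on __enter__
    old_generate: callable

    def __enter__(self):
        use_patches.old_generate = CoreModel.generate
        CoreModel.generate = use_patches.generate
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        CoreModel.generate = use_patches.old_generate


use_patches.generate = kobold_generate

model = CoreModel()
print(model.generate())      # stock output
with use_patches():
    print(model.generate())  # output with custom logit processing
print(model.generate())      # stock output again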
@@ -1910,8 +1938,6 @@ def patch_transformers_download():

def patch_transformers():
    global transformers
    global old_transfomers_functions
    old_transfomers_functions = {}

    patch_transformers_download()
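A minimal sketch (hypothetical names) of the old_transfomers_functions idea this hunk touches: originals are stored in a dict keyed by a dotted name so the unpatched behaviour stays reachable later.

old_functions = {}

def stock_tokenize(text):
    return text.split()

old_functions["tokenizer.tokenize"] = stock_tokenize

def patched_tokenize(text):
    return [t.lower() for t in stock_tokenize(text)]

tokenize = patched_tokenize  # the patched version is what normal code calls

print(tokenize("Hello World"))                             # ['hello', 'world']
print(old_functions["tokenizer.tokenize"]("Hello World"))  # ['Hello', 'World']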
@@ -1933,7 +1959,6 @@ def patch_transformers():
        PreTrainedModel._kai_patched = True
    if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
        old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
        old_transfomers_functions['modeling_utils.get_checkpoint_shard_files'] = old_get_checkpoint_shard_files
        def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
            utils.num_shards = utils.get_num_shards(index_filename)
            utils.from_pretrained_index_filename = index_filename
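A sketch (stand-in names) of the wrapping pattern used here for get_checkpoint_shard_files: the wrapper records side information from its arguments, then delegates to the saved original.

import types

utils = types.SimpleNamespace(num_shards=None, from_pretrained_index_filename=None)

def original_get_shard_files(model_name, index_filename):
    return [f"{model_name}-shard-{i}.bin" for i in range(2)]

old_get_shard_files = original_get_shard_files

def new_get_shard_files(model_name, index_filename, *args, **kwargs):
    # capture information from the arguments before handing off to the original
    utils.from_pretrained_index_filename = index_filename
    return old_get_shard_files(model_name, index_filename, *args, **kwargs)

get_shard_files = new_get_shard_files
print(get_shard_files("gpt-j-6B", "pytorch_model.bin.index.json"))
print(utils.from_pretrained_index_filename)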
@@ -1961,7 +1986,6 @@ def patch_transformers():
            if max_pos > self.weights.size(0):
                self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
            return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
        old_transfomers_functions['XGLMSinusoidalPositionalEmbedding.forward'] = XGLMSinusoidalPositionalEmbedding.forward
        XGLMSinusoidalPositionalEmbedding.forward = new_forward
@@ -1981,7 +2005,6 @@ def patch_transformers():
            self.model = OPTModel(config)
            self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
            self.post_init()
        old_transfomers_functions['OPTForCausalLM.__init__'] = OPTForCausalLM.__init__
        OPTForCausalLM.__init__ = new_init
@@ -2170,8 +2193,8 @@ def patch_transformers():
        processors.append(PhraseBiasLogitsProcessor())
        processors.append(ProbabilityVisualizerLogitsProcessor())
        return processors
    use_core_manipulations.get_logits_processor = new_get_logits_processor
    new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
    transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor

    class KoboldLogitsWarperList(LogitsProcessorList):
        def __init__(self, beams: int = 1, **kwargs):
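A plain-Python sketch of the _get_logits_processor hook above, with hypothetical processor names: call the stock builder, append extra processors to the list it returns, and keep the stock builder reachable as an attribute of the wrapper itself.

def stock_get_processors():
    return ["temperature", "top_k"]

def new_get_processors():
    processors = new_get_processors.old_get_processors()
    processors.append("phrase_bias")
    processors.append("probability_visualizer")
    return processors

new_get_processors.old_get_processors = stock_get_processors

print(new_get_processors())  # ['temperature', 'top_k', 'phrase_bias', 'probability_visualizer']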
@@ -2204,9 +2227,9 @@ def patch_transformers():
            kwargs["eos_token_id"] = -1
            kwargs.setdefault("pad_token_id", 2)
        return new_sample.old_sample(self, *args, **kwargs)
    new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
    transformers.generation_utils.GenerationMixin.sample = new_sample

    new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
    use_core_manipulations.sample = new_sample

    # Allow bad words filter to ban <|endoftext|> token
    import transformers.generation_logits_process
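A sketch of the sample() hook, with a stand-in Generator class instead of GenerationMixin: adjust keyword arguments, then fall through to the original implementation saved on the wrapper function.

class Generator:
    def sample(self, **kwargs):
        return f"sampling with {kwargs}"

def new_sample(self, *args, **kwargs):
    kwargs["eos_token_id"] = -1          # force a sentinel EOS id in this sketch
    kwargs.setdefault("pad_token_id", 2) # only fill in a pad token if the caller did not
    return new_sample.old_sample(self, *args, **kwargs)

new_sample.old_sample = Generator.sample  # save the original before replacing it
Generator.sample = new_sample

print(Generator().sample(temperature=0.7))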
@@ -2374,7 +2397,7 @@ def patch_transformers():


    old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
    old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria'] = old_get_stopping_criteria

    def new_get_stopping_criteria(self, *args, **kwargs):
        global tokenizer
        stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
@@ -2392,7 +2415,7 @@ def patch_transformers():
            stopping_criteria.insert(0, token_streamer)
        stopping_criteria.insert(0, ChatModeStopper(tokenizer=tokenizer))
        return stopping_criteria
    transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
    use_core_manipulations.get_stopping_criteria = new_get_stopping_criteria

def reset_model_settings():
    koboldai_vars.socketio = socketio
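A plain-Python sketch of the stopping-criteria hook: each criterion is a callable checked every step, generation stops once any of them fires, and the Kobold stoppers are inserted at the front of the list returned by the stock builder so they run first. The names and the toy generate loop here are illustrative only.

def stock_get_stopping_criteria():
    return [lambda step, text: step >= 5]  # stock length limit

def new_get_stopping_criteria():
    criteria = stock_get_stopping_criteria()
    criteria.insert(0, lambda step, text: text.endswith("\n"))  # e.g. a chat-mode stopper
    return criteria

def generate(tokens):
    text, criteria = "", new_get_stopping_criteria()
    for step, token in enumerate(tokens):
        text += token
        if any(c(step, text) for c in criteria):
            break
    return text

print(generate(["Hello", " world", "\n", "ignored"]))  # stops after the newline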
@@ -5395,56 +5418,57 @@ def raw_generate(
    result: GenerationResult
    time_start = time.time()

    if koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
        batch_encoded = tpu_raw_generate(
            prompt_tokens=prompt_tokens,
            max_new=max_new,
            batch_count=batch_count,
            gen_settings=gen_settings
        )
        result = GenerationResult(
            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
        )
    elif koboldai_vars.model in model_functions:
        batch_encoded = model_functions[koboldai_vars.model](
            prompt_tokens=prompt_tokens,
            max_new=max_new,
            batch_count=batch_count,
            gen_settings=gen_settings
        )
        result = GenerationResult(
            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
        )
    elif koboldai_vars.model.startswith("RWKV"):
        batch_encoded = rwkv_raw_generate(
            prompt_tokens=prompt_tokens,
            max_new=max_new,
            batch_count=batch_count,
            gen_settings=gen_settings
        )
        result = GenerationResult(
            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True
        )
    else:
        # Torch HF
        start_time = time.time()
        batch_encoded = torch_raw_generate(
            prompt_tokens=prompt_tokens,
            max_new=max_new if not bypass_hf_maxlength else int(2e9),
            do_streaming=do_streaming,
            do_dynamic_wi=do_dynamic_wi,
            batch_count=batch_count,
            gen_settings=gen_settings
        )
        logger.debug("raw_generate: run torch_raw_generate {}s".format(time.time()-start_time))
        start_time = time.time()
        result = GenerationResult(
            out_batches=batch_encoded,
            prompt=prompt_tokens,
            is_whole_generation=False,
            output_includes_prompt=True,
        )
        logger.debug("raw_generate: run GenerationResult {}s".format(time.time()-start_time))
    with use_core_manipulations():
        if koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
            batch_encoded = tpu_raw_generate(
                prompt_tokens=prompt_tokens,
                max_new=max_new,
                batch_count=batch_count,
                gen_settings=gen_settings
            )
            result = GenerationResult(
                out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
            )
        elif koboldai_vars.model in model_functions:
            batch_encoded = model_functions[koboldai_vars.model](
                prompt_tokens=prompt_tokens,
                max_new=max_new,
                batch_count=batch_count,
                gen_settings=gen_settings
            )
            result = GenerationResult(
                out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
            )
        elif koboldai_vars.model.startswith("RWKV"):
            batch_encoded = rwkv_raw_generate(
                prompt_tokens=prompt_tokens,
                max_new=max_new,
                batch_count=batch_count,
                gen_settings=gen_settings
            )
            result = GenerationResult(
                out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True
            )
        else:
            # Torch HF
            start_time = time.time()
            batch_encoded = torch_raw_generate(
                prompt_tokens=prompt_tokens,
                max_new=max_new if not bypass_hf_maxlength else int(2e9),
                do_streaming=do_streaming,
                do_dynamic_wi=do_dynamic_wi,
                batch_count=batch_count,
                gen_settings=gen_settings
            )
            logger.debug("raw_generate: run torch_raw_generate {}s".format(time.time()-start_time))
            start_time = time.time()
            result = GenerationResult(
                out_batches=batch_encoded,
                prompt=prompt_tokens,
                is_whole_generation=False,
                output_includes_prompt=True,
            )
            logger.debug("raw_generate: run GenerationResult {}s".format(time.time()-start_time))

    time_end = round(time.time() - time_start, 2)
    tokens_per_second = round(len(result.encoded[0]) / time_end, 2)
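A toy sketch of the reworked raw_generate flow, with stand-in backends and a contextlib stand-in for use_core_manipulations: the backend dispatch is unchanged, but it now runs inside the sandboxing context so the invasive patches only exist while the core model generates.

import time
from contextlib import contextmanager

@contextmanager
def use_core_manipulations():  # toy stand-in for the class defined earlier
    print("generation patches applied")
    try:
        yield
    finally:
        print("generation patches reverted")

model_functions = {"API": lambda prompt: prompt.split() + ["api", "tokens"]}

def raw_generate(model_name, prompt):
    time_start = time.time()
    with use_core_manipulations():
        if model_name in model_functions:
            out = model_functions[model_name](prompt)
        else:
            out = prompt.split() + ["torch", "tokens"]  # Torch HF fallback
    time_end = round(time.time() - time_start, 2)
    # guard against a zero-duration toy run; the real code divides directly
    print("tokens per second:", round(len(out) / max(time_end, 0.01), 2))
    return out

print(raw_generate("API", "Once upon a time"))
print(raw_generate("gpt2", "Once upon a time"))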
@@ -9384,14 +9408,10 @@ def summarize(text, max_length=100, min_length=30, unload=True):

    #Actual sumarization
    start_time = time.time()
    global old_transfomers_functions
    temp = transformers.generation_utils.GenerationMixin._get_stopping_criteria
    transformers.generation_utils.GenerationMixin._get_stopping_criteria = old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria']
    #make sure text is less than 1024 tokens, otherwise we'll crash
    if len(koboldai_vars.summary_tokenizer.encode(text)) > 1000:
        text = koboldai_vars.summary_tokenizer.decode(koboldai_vars.summary_tokenizer.encode(text)[:1000])
    output = tpool.execute(summarizer, text, max_length=max_length, min_length=min_length, do_sample=False)[0]['summary_text']
    transformers.generation_utils.GenerationMixin._get_stopping_criteria = temp
    logger.debug("Time to summarize: {}".format(time.time()-start_time))
    #move model back to CPU to save precious vram
    torch.cuda.empty_cache()
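A sketch (hypothetical names) of the swap-call-restore pattern summarize() uses here: temporarily reinstate the stock function from the registry, run the summarizer, then put the patched version back. A try/finally would make the restore exception-safe; this sketch mirrors the commit's straight-line form.

def stock_get_stopping_criteria():
    return ["max_length"]

def kobold_get_stopping_criteria():
    return ["chat_stopper", "token_streamer", "max_length"]

old_functions = {"get_stopping_criteria": stock_get_stopping_criteria}
get_stopping_criteria = kobold_get_stopping_criteria  # normally patched for story generation

def summarize():
    global get_stopping_criteria
    temp = get_stopping_criteria                                     # remember the patched version
    get_stopping_criteria = old_functions["get_stopping_criteria"]   # restore stock behaviour
    criteria_used = get_stopping_criteria()                          # the summarizer runs here
    get_stopping_criteria = temp                                     # put the patch back
    return criteria_used

print(summarize())              # ['max_length']
print(get_stopping_criteria())  # patched list again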