Merge pull request #290 from one-some/ui2-logit-things-are-just-too-invasive-get-them-away-NOW

Add context manager for sandboxing invasive transformers patches
This commit is contained in:
ebolam
2022-11-09 18:44:39 -05:00
committed by GitHub

View File

@@ -113,6 +113,34 @@ def new_pretrainedtokenizerbase_from_pretrained(cls, *args, **kwargs):
return tokenizer
PreTrainedTokenizerBase.from_pretrained = new_pretrainedtokenizerbase_from_pretrained
# We only want to use logit manipulations and such on our core text model
class use_core_manipulations:
# These must be set by wherever they get setup
get_logits_processor: callable
sample: callable
get_stopping_criteria: callable
# We set these automatically
old_get_logits_processor: callable
old_sample: callable
old_get_stopping_criteria: callable
def __enter__(self):
use_core_manipulations.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
transformers.generation_utils.GenerationMixin._get_logits_processor = use_core_manipulations.get_logits_processor
use_core_manipulations.old_sample = transformers.generation_utils.GenerationMixin.sample
transformers.generation_utils.GenerationMixin.sample = use_core_manipulations.sample
use_core_manipulations.old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
transformers.generation_utils.GenerationMixin._get_stopping_criteria = use_core_manipulations.get_stopping_criteria
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
transformers.generation_utils.GenerationMixin._get_logits_processor = use_core_manipulations.old_get_logits_processor
transformers.generation_utils.GenerationMixin.sample = use_core_manipulations.old_sample
transformers.generation_utils.GenerationMixin._get_stopping_criteria = use_core_manipulations.old_get_stopping_criteria
#==================================================================#
# Variables & Storage
#==================================================================#
@@ -1910,8 +1938,6 @@ def patch_transformers_download():
def patch_transformers():
global transformers
global old_transfomers_functions
old_transfomers_functions = {}
patch_transformers_download()
@@ -1933,7 +1959,6 @@ def patch_transformers():
PreTrainedModel._kai_patched = True
if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
old_transfomers_functions['modeling_utils.get_checkpoint_shard_files'] = old_get_checkpoint_shard_files
def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
utils.num_shards = utils.get_num_shards(index_filename)
utils.from_pretrained_index_filename = index_filename
@@ -1961,7 +1986,6 @@ def patch_transformers():
if max_pos > self.weights.size(0):
self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
old_transfomers_functions['XGLMSinusoidalPositionalEmbedding.forward'] = XGLMSinusoidalPositionalEmbedding.forward
XGLMSinusoidalPositionalEmbedding.forward = new_forward
@@ -1981,7 +2005,6 @@ def patch_transformers():
self.model = OPTModel(config)
self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
self.post_init()
old_transfomers_functions['OPTForCausalLM.__init__'] = OPTForCausalLM.__init__
OPTForCausalLM.__init__ = new_init
@@ -2170,8 +2193,8 @@ def patch_transformers():
processors.append(PhraseBiasLogitsProcessor())
processors.append(ProbabilityVisualizerLogitsProcessor())
return processors
use_core_manipulations.get_logits_processor = new_get_logits_processor
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
class KoboldLogitsWarperList(LogitsProcessorList):
def __init__(self, beams: int = 1, **kwargs):
@@ -2204,9 +2227,9 @@ def patch_transformers():
kwargs["eos_token_id"] = -1
kwargs.setdefault("pad_token_id", 2)
return new_sample.old_sample(self, *args, **kwargs)
new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
transformers.generation_utils.GenerationMixin.sample = new_sample
new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
use_core_manipulations.sample = new_sample
# Allow bad words filter to ban <|endoftext|> token
import transformers.generation_logits_process
@@ -2374,7 +2397,7 @@ def patch_transformers():
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria'] = old_get_stopping_criteria
def new_get_stopping_criteria(self, *args, **kwargs):
global tokenizer
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
@@ -2392,7 +2415,7 @@ def patch_transformers():
stopping_criteria.insert(0, token_streamer)
stopping_criteria.insert(0, ChatModeStopper(tokenizer=tokenizer))
return stopping_criteria
transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
use_core_manipulations.get_stopping_criteria = new_get_stopping_criteria
def reset_model_settings():
koboldai_vars.socketio = socketio
@@ -5395,56 +5418,57 @@ def raw_generate(
result: GenerationResult
time_start = time.time()
if koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
batch_encoded = tpu_raw_generate(
prompt_tokens=prompt_tokens,
max_new=max_new,
batch_count=batch_count,
gen_settings=gen_settings
)
result = GenerationResult(
out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
)
elif koboldai_vars.model in model_functions:
batch_encoded = model_functions[koboldai_vars.model](
prompt_tokens=prompt_tokens,
max_new=max_new,
batch_count=batch_count,
gen_settings=gen_settings
)
result = GenerationResult(
out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
)
elif koboldai_vars.model.startswith("RWKV"):
batch_encoded = rwkv_raw_generate(
prompt_tokens=prompt_tokens,
max_new=max_new,
batch_count=batch_count,
gen_settings=gen_settings
)
result = GenerationResult(
out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True
)
else:
# Torch HF
start_time = time.time()
batch_encoded = torch_raw_generate(
prompt_tokens=prompt_tokens,
max_new=max_new if not bypass_hf_maxlength else int(2e9),
do_streaming=do_streaming,
do_dynamic_wi=do_dynamic_wi,
batch_count=batch_count,
gen_settings=gen_settings
)
logger.debug("raw_generate: run torch_raw_generate {}s".format(time.time()-start_time))
start_time = time.time()
result = GenerationResult(
out_batches=batch_encoded,
prompt=prompt_tokens,
is_whole_generation=False,
output_includes_prompt=True,
)
logger.debug("raw_generate: run GenerationResult {}s".format(time.time()-start_time))
with use_core_manipulations():
if koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
batch_encoded = tpu_raw_generate(
prompt_tokens=prompt_tokens,
max_new=max_new,
batch_count=batch_count,
gen_settings=gen_settings
)
result = GenerationResult(
out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
)
elif koboldai_vars.model in model_functions:
batch_encoded = model_functions[koboldai_vars.model](
prompt_tokens=prompt_tokens,
max_new=max_new,
batch_count=batch_count,
gen_settings=gen_settings
)
result = GenerationResult(
out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
)
elif koboldai_vars.model.startswith("RWKV"):
batch_encoded = rwkv_raw_generate(
prompt_tokens=prompt_tokens,
max_new=max_new,
batch_count=batch_count,
gen_settings=gen_settings
)
result = GenerationResult(
out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True
)
else:
# Torch HF
start_time = time.time()
batch_encoded = torch_raw_generate(
prompt_tokens=prompt_tokens,
max_new=max_new if not bypass_hf_maxlength else int(2e9),
do_streaming=do_streaming,
do_dynamic_wi=do_dynamic_wi,
batch_count=batch_count,
gen_settings=gen_settings
)
logger.debug("raw_generate: run torch_raw_generate {}s".format(time.time()-start_time))
start_time = time.time()
result = GenerationResult(
out_batches=batch_encoded,
prompt=prompt_tokens,
is_whole_generation=False,
output_includes_prompt=True,
)
logger.debug("raw_generate: run GenerationResult {}s".format(time.time()-start_time))
time_end = round(time.time() - time_start, 2)
tokens_per_second = round(len(result.encoded[0]) / time_end, 2)
@@ -9384,14 +9408,10 @@ def summarize(text, max_length=100, min_length=30, unload=True):
#Actual sumarization
start_time = time.time()
global old_transfomers_functions
temp = transformers.generation_utils.GenerationMixin._get_stopping_criteria
transformers.generation_utils.GenerationMixin._get_stopping_criteria = old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria']
#make sure text is less than 1024 tokens, otherwise we'll crash
if len(koboldai_vars.summary_tokenizer.encode(text)) > 1000:
text = koboldai_vars.summary_tokenizer.decode(koboldai_vars.summary_tokenizer.encode(text)[:1000])
output = tpool.execute(summarizer, text, max_length=max_length, min_length=min_length, do_sample=False)[0]['summary_text']
transformers.generation_utils.GenerationMixin._get_stopping_criteria = temp
logger.debug("Time to summarize: {}".format(time.time()-start_time))
#move model back to CPU to save precious vram
torch.cuda.empty_cache()