Merge pull request #290 from one-some/ui2-logit-things-are-just-too-invasive-get-them-away-NOW

Add context manager for sandboxing invasive transformers patches
This commit is contained in:
ebolam
2022-11-09 18:44:39 -05:00
committed by GitHub

View File

@@ -113,6 +113,34 @@ def new_pretrainedtokenizerbase_from_pretrained(cls, *args, **kwargs):
return tokenizer return tokenizer
PreTrainedTokenizerBase.from_pretrained = new_pretrainedtokenizerbase_from_pretrained PreTrainedTokenizerBase.from_pretrained = new_pretrainedtokenizerbase_from_pretrained
# We only want to use logit manipulations and such on our core text model
class use_core_manipulations:
# These must be set by wherever they get setup
get_logits_processor: callable
sample: callable
get_stopping_criteria: callable
# We set these automatically
old_get_logits_processor: callable
old_sample: callable
old_get_stopping_criteria: callable
def __enter__(self):
use_core_manipulations.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
transformers.generation_utils.GenerationMixin._get_logits_processor = use_core_manipulations.get_logits_processor
use_core_manipulations.old_sample = transformers.generation_utils.GenerationMixin.sample
transformers.generation_utils.GenerationMixin.sample = use_core_manipulations.sample
use_core_manipulations.old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
transformers.generation_utils.GenerationMixin._get_stopping_criteria = use_core_manipulations.get_stopping_criteria
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
transformers.generation_utils.GenerationMixin._get_logits_processor = use_core_manipulations.old_get_logits_processor
transformers.generation_utils.GenerationMixin.sample = use_core_manipulations.old_sample
transformers.generation_utils.GenerationMixin._get_stopping_criteria = use_core_manipulations.old_get_stopping_criteria
#==================================================================# #==================================================================#
# Variables & Storage # Variables & Storage
#==================================================================# #==================================================================#
@@ -1910,8 +1938,6 @@ def patch_transformers_download():
def patch_transformers(): def patch_transformers():
global transformers global transformers
global old_transfomers_functions
old_transfomers_functions = {}
patch_transformers_download() patch_transformers_download()
@@ -1933,7 +1959,6 @@ def patch_transformers():
PreTrainedModel._kai_patched = True PreTrainedModel._kai_patched = True
if(hasattr(modeling_utils, "get_checkpoint_shard_files")): if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
old_transfomers_functions['modeling_utils.get_checkpoint_shard_files'] = old_get_checkpoint_shard_files
def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs): def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
utils.num_shards = utils.get_num_shards(index_filename) utils.num_shards = utils.get_num_shards(index_filename)
utils.from_pretrained_index_filename = index_filename utils.from_pretrained_index_filename = index_filename
@@ -1961,7 +1986,6 @@ def patch_transformers():
if max_pos > self.weights.size(0): if max_pos > self.weights.size(0):
self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach() return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
old_transfomers_functions['XGLMSinusoidalPositionalEmbedding.forward'] = XGLMSinusoidalPositionalEmbedding.forward
XGLMSinusoidalPositionalEmbedding.forward = new_forward XGLMSinusoidalPositionalEmbedding.forward = new_forward
@@ -1981,7 +2005,6 @@ def patch_transformers():
self.model = OPTModel(config) self.model = OPTModel(config)
self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
self.post_init() self.post_init()
old_transfomers_functions['OPTForCausalLM.__init__'] = OPTForCausalLM.__init__
OPTForCausalLM.__init__ = new_init OPTForCausalLM.__init__ = new_init
@@ -2170,8 +2193,8 @@ def patch_transformers():
processors.append(PhraseBiasLogitsProcessor()) processors.append(PhraseBiasLogitsProcessor())
processors.append(ProbabilityVisualizerLogitsProcessor()) processors.append(ProbabilityVisualizerLogitsProcessor())
return processors return processors
use_core_manipulations.get_logits_processor = new_get_logits_processor
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
class KoboldLogitsWarperList(LogitsProcessorList): class KoboldLogitsWarperList(LogitsProcessorList):
def __init__(self, beams: int = 1, **kwargs): def __init__(self, beams: int = 1, **kwargs):
@@ -2204,9 +2227,9 @@ def patch_transformers():
kwargs["eos_token_id"] = -1 kwargs["eos_token_id"] = -1
kwargs.setdefault("pad_token_id", 2) kwargs.setdefault("pad_token_id", 2)
return new_sample.old_sample(self, *args, **kwargs) return new_sample.old_sample(self, *args, **kwargs)
new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
transformers.generation_utils.GenerationMixin.sample = new_sample
new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
use_core_manipulations.sample = new_sample
# Allow bad words filter to ban <|endoftext|> token # Allow bad words filter to ban <|endoftext|> token
import transformers.generation_logits_process import transformers.generation_logits_process
@@ -2374,7 +2397,7 @@ def patch_transformers():
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria'] = old_get_stopping_criteria
def new_get_stopping_criteria(self, *args, **kwargs): def new_get_stopping_criteria(self, *args, **kwargs):
global tokenizer global tokenizer
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs) stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
@@ -2392,7 +2415,7 @@ def patch_transformers():
stopping_criteria.insert(0, token_streamer) stopping_criteria.insert(0, token_streamer)
stopping_criteria.insert(0, ChatModeStopper(tokenizer=tokenizer)) stopping_criteria.insert(0, ChatModeStopper(tokenizer=tokenizer))
return stopping_criteria return stopping_criteria
transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria use_core_manipulations.get_stopping_criteria = new_get_stopping_criteria
def reset_model_settings(): def reset_model_settings():
koboldai_vars.socketio = socketio koboldai_vars.socketio = socketio
@@ -5395,56 +5418,57 @@ def raw_generate(
result: GenerationResult result: GenerationResult
time_start = time.time() time_start = time.time()
if koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"): with use_core_manipulations():
batch_encoded = tpu_raw_generate( if koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
prompt_tokens=prompt_tokens, batch_encoded = tpu_raw_generate(
max_new=max_new, prompt_tokens=prompt_tokens,
batch_count=batch_count, max_new=max_new,
gen_settings=gen_settings batch_count=batch_count,
) gen_settings=gen_settings
result = GenerationResult( )
out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True result = GenerationResult(
) out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
elif koboldai_vars.model in model_functions: )
batch_encoded = model_functions[koboldai_vars.model]( elif koboldai_vars.model in model_functions:
prompt_tokens=prompt_tokens, batch_encoded = model_functions[koboldai_vars.model](
max_new=max_new, prompt_tokens=prompt_tokens,
batch_count=batch_count, max_new=max_new,
gen_settings=gen_settings batch_count=batch_count,
) gen_settings=gen_settings
result = GenerationResult( )
out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True result = GenerationResult(
) out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
elif koboldai_vars.model.startswith("RWKV"): )
batch_encoded = rwkv_raw_generate( elif koboldai_vars.model.startswith("RWKV"):
prompt_tokens=prompt_tokens, batch_encoded = rwkv_raw_generate(
max_new=max_new, prompt_tokens=prompt_tokens,
batch_count=batch_count, max_new=max_new,
gen_settings=gen_settings batch_count=batch_count,
) gen_settings=gen_settings
result = GenerationResult( )
out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True result = GenerationResult(
) out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True
else: )
# Torch HF else:
start_time = time.time() # Torch HF
batch_encoded = torch_raw_generate( start_time = time.time()
prompt_tokens=prompt_tokens, batch_encoded = torch_raw_generate(
max_new=max_new if not bypass_hf_maxlength else int(2e9), prompt_tokens=prompt_tokens,
do_streaming=do_streaming, max_new=max_new if not bypass_hf_maxlength else int(2e9),
do_dynamic_wi=do_dynamic_wi, do_streaming=do_streaming,
batch_count=batch_count, do_dynamic_wi=do_dynamic_wi,
gen_settings=gen_settings batch_count=batch_count,
) gen_settings=gen_settings
logger.debug("raw_generate: run torch_raw_generate {}s".format(time.time()-start_time)) )
start_time = time.time() logger.debug("raw_generate: run torch_raw_generate {}s".format(time.time()-start_time))
result = GenerationResult( start_time = time.time()
out_batches=batch_encoded, result = GenerationResult(
prompt=prompt_tokens, out_batches=batch_encoded,
is_whole_generation=False, prompt=prompt_tokens,
output_includes_prompt=True, is_whole_generation=False,
) output_includes_prompt=True,
logger.debug("raw_generate: run GenerationResult {}s".format(time.time()-start_time)) )
logger.debug("raw_generate: run GenerationResult {}s".format(time.time()-start_time))
time_end = round(time.time() - time_start, 2) time_end = round(time.time() - time_start, 2)
tokens_per_second = round(len(result.encoded[0]) / time_end, 2) tokens_per_second = round(len(result.encoded[0]) / time_end, 2)
@@ -9384,14 +9408,10 @@ def summarize(text, max_length=100, min_length=30, unload=True):
#Actual sumarization #Actual sumarization
start_time = time.time() start_time = time.time()
global old_transfomers_functions
temp = transformers.generation_utils.GenerationMixin._get_stopping_criteria
transformers.generation_utils.GenerationMixin._get_stopping_criteria = old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria']
#make sure text is less than 1024 tokens, otherwise we'll crash #make sure text is less than 1024 tokens, otherwise we'll crash
if len(koboldai_vars.summary_tokenizer.encode(text)) > 1000: if len(koboldai_vars.summary_tokenizer.encode(text)) > 1000:
text = koboldai_vars.summary_tokenizer.decode(koboldai_vars.summary_tokenizer.encode(text)[:1000]) text = koboldai_vars.summary_tokenizer.decode(koboldai_vars.summary_tokenizer.encode(text)[:1000])
output = tpool.execute(summarizer, text, max_length=max_length, min_length=min_length, do_sample=False)[0]['summary_text'] output = tpool.execute(summarizer, text, max_length=max_length, min_length=min_length, do_sample=False)[0]['summary_text']
transformers.generation_utils.GenerationMixin._get_stopping_criteria = temp
logger.debug("Time to summarize: {}".format(time.time()-start_time)) logger.debug("Time to summarize: {}".format(time.time()-start_time))
#move model back to CPU to save precious vram #move model back to CPU to save precious vram
torch.cuda.empty_cache() torch.cuda.empty_cache()