Merge pull request #290 from one-some/ui2-logit-things-are-just-too-invasive-get-them-away-NOW

Add context manager for sandboxing invasive transformers patches
ebolam authored on 2022-11-09 18:44:39 -05:00 · committed by GitHub


@@ -113,6 +113,34 @@ def new_pretrainedtokenizerbase_from_pretrained(cls, *args, **kwargs):
     return tokenizer
 PreTrainedTokenizerBase.from_pretrained = new_pretrainedtokenizerbase_from_pretrained
 
+# We only want to use logit manipulations and such on our core text model
+class use_core_manipulations:
+    # These must be set by wherever they get setup
+    get_logits_processor: callable
+    sample: callable
+    get_stopping_criteria: callable
+
+    # We set these automatically
+    old_get_logits_processor: callable
+    old_sample: callable
+    old_get_stopping_criteria: callable
+
+    def __enter__(self):
+        use_core_manipulations.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
+        transformers.generation_utils.GenerationMixin._get_logits_processor = use_core_manipulations.get_logits_processor
+
+        use_core_manipulations.old_sample = transformers.generation_utils.GenerationMixin.sample
+        transformers.generation_utils.GenerationMixin.sample = use_core_manipulations.sample
+
+        use_core_manipulations.old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
+        transformers.generation_utils.GenerationMixin._get_stopping_criteria = use_core_manipulations.get_stopping_criteria
+
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        transformers.generation_utils.GenerationMixin._get_logits_processor = use_core_manipulations.old_get_logits_processor
+        transformers.generation_utils.GenerationMixin.sample = use_core_manipulations.old_sample
+        transformers.generation_utils.GenerationMixin._get_stopping_criteria = use_core_manipulations.old_get_stopping_criteria
 #==================================================================#
 # Variables & Storage
 #==================================================================#
@@ -1910,8 +1938,6 @@ def patch_transformers_download():
 def patch_transformers():
     global transformers
-    global old_transfomers_functions
-    old_transfomers_functions = {}
     patch_transformers_download()
@@ -1933,7 +1959,6 @@ def patch_transformers():
         PreTrainedModel._kai_patched = True
     if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
         old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
-        old_transfomers_functions['modeling_utils.get_checkpoint_shard_files'] = old_get_checkpoint_shard_files
         def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
             utils.num_shards = utils.get_num_shards(index_filename)
             utils.from_pretrained_index_filename = index_filename
@@ -1961,7 +1986,6 @@ def patch_transformers():
             if max_pos > self.weights.size(0):
                 self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
             return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
-        old_transfomers_functions['XGLMSinusoidalPositionalEmbedding.forward'] = XGLMSinusoidalPositionalEmbedding.forward
         XGLMSinusoidalPositionalEmbedding.forward = new_forward
@@ -1981,7 +2005,6 @@ def patch_transformers():
             self.model = OPTModel(config)
             self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
             self.post_init()
-        old_transfomers_functions['OPTForCausalLM.__init__'] = OPTForCausalLM.__init__
         OPTForCausalLM.__init__ = new_init
@@ -2170,8 +2193,8 @@ def patch_transformers():
         processors.append(PhraseBiasLogitsProcessor())
         processors.append(ProbabilityVisualizerLogitsProcessor())
         return processors
+    use_core_manipulations.get_logits_processor = new_get_logits_processor
     new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
-    transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
 
     class KoboldLogitsWarperList(LogitsProcessorList):
         def __init__(self, beams: int = 1, **kwargs):
@@ -2204,9 +2227,9 @@ def patch_transformers():
             kwargs["eos_token_id"] = -1
             kwargs.setdefault("pad_token_id", 2)
         return new_sample.old_sample(self, *args, **kwargs)
-    new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
-    transformers.generation_utils.GenerationMixin.sample = new_sample
+    new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
+    use_core_manipulations.sample = new_sample
 
     # Allow bad words filter to ban <|endoftext|> token
     import transformers.generation_logits_process
@@ -2374,7 +2397,7 @@ def patch_transformers():
     old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
-    old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria'] = old_get_stopping_criteria
     def new_get_stopping_criteria(self, *args, **kwargs):
         global tokenizer
         stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
@@ -2392,7 +2415,7 @@ def patch_transformers():
             stopping_criteria.insert(0, token_streamer)
         stopping_criteria.insert(0, ChatModeStopper(tokenizer=tokenizer))
         return stopping_criteria
-    transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
+    use_core_manipulations.get_stopping_criteria = new_get_stopping_criteria
 
 def reset_model_settings():
     koboldai_vars.socketio = socketio
@@ -5395,6 +5418,7 @@ def raw_generate(
     result: GenerationResult
     time_start = time.time()
 
+    with use_core_manipulations():
         if koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
             batch_encoded = tpu_raw_generate(
                 prompt_tokens=prompt_tokens,
@@ -9384,14 +9408,10 @@ def summarize(text, max_length=100, min_length=30, unload=True):
     #Actual sumarization
     start_time = time.time()
-    global old_transfomers_functions
-    temp = transformers.generation_utils.GenerationMixin._get_stopping_criteria
-    transformers.generation_utils.GenerationMixin._get_stopping_criteria = old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria']
     #make sure text is less than 1024 tokens, otherwise we'll crash
     if len(koboldai_vars.summary_tokenizer.encode(text)) > 1000:
         text = koboldai_vars.summary_tokenizer.decode(koboldai_vars.summary_tokenizer.encode(text)[:1000])
     output = tpool.execute(summarizer, text, max_length=max_length, min_length=min_length, do_sample=False)[0]['summary_text']
-    transformers.generation_utils.GenerationMixin._get_stopping_criteria = temp
     logger.debug("Time to summarize: {}".format(time.time()-start_time))
 
     #move model back to CPU to save precious vram
     torch.cuda.empty_cache()
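
For readers skimming the diff: use_core_manipulations is a plain save/patch/restore context manager. Below is a minimal, self-contained sketch of that pattern; the Engine class, patch_generate manager, and loud_generate function are hypothetical stand-ins for illustration only, not code from this commit or from transformers.

# Minimal sketch of the save/patch/restore context-manager pattern.
# Engine, patch_generate, and loud_generate are hypothetical stand-ins.

class Engine:
    def generate(self) -> str:
        return "plain output"


class patch_generate:
    # Must be assigned by the caller before entering the context.
    replacement: callable
    # Captured automatically in __enter__ so __exit__ can restore it.
    _original: callable

    def __enter__(self):
        patch_generate._original = Engine.generate
        Engine.generate = patch_generate.replacement
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Always put the original method back when the block ends.
        Engine.generate = patch_generate._original


def loud_generate(self) -> str:
    return "patched output"


patch_generate.replacement = loud_generate

engine = Engine()
with patch_generate():
    print(engine.generate())  # prints "patched output"
print(engine.generate())      # prints "plain output" (original restored)

The commit applies the same idea to three GenerationMixin hooks at once, which is why raw_generate can wrap only its own generation call in "with use_core_manipulations():" while other users of transformers, such as the summarizer, see the unpatched behavior.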