From c43d0f9aa172a1c1119814f35b4391952fe1156c Mon Sep 17 00:00:00 2001
From: somebody
Date: Wed, 9 Nov 2022 17:31:25 -0600
Subject: [PATCH] Add context manager for sandboxing invasive transformers patches

---
 aiserver.py | 148 +++++++++++++++++++++++++++++-----------------------
 1 file changed, 84 insertions(+), 64 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index 0aaf6b80..0845ed48 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -113,6 +113,34 @@ def new_pretrainedtokenizerbase_from_pretrained(cls, *args, **kwargs):
     return tokenizer
 PreTrainedTokenizerBase.from_pretrained = new_pretrainedtokenizerbase_from_pretrained
 
+# We only want to use logit manipulations and such on our core text model
+class use_core_manipulations:
+    # These must be set by wherever they get setup
+    get_logits_processor: callable
+    sample: callable
+    get_stopping_criteria: callable
+
+    # We set these automatically
+    old_get_logits_processor: callable
+    old_sample: callable
+    old_get_stopping_criteria: callable
+
+    def __enter__(self):
+        use_core_manipulations.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
+        transformers.generation_utils.GenerationMixin._get_logits_processor = use_core_manipulations.get_logits_processor
+
+        use_core_manipulations.old_sample = transformers.generation_utils.GenerationMixin.sample
+        transformers.generation_utils.GenerationMixin.sample = use_core_manipulations.sample
+
+        use_core_manipulations.old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
+        transformers.generation_utils.GenerationMixin._get_stopping_criteria = use_core_manipulations.get_stopping_criteria
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        transformers.generation_utils.GenerationMixin._get_logits_processor = use_core_manipulations.old_get_logits_processor
+        transformers.generation_utils.GenerationMixin.sample = use_core_manipulations.old_sample
+        transformers.generation_utils.GenerationMixin._get_stopping_criteria = use_core_manipulations.old_get_stopping_criteria
+
 #==================================================================#
 # Variables & Storage
 #==================================================================#
@@ -1910,8 +1938,6 @@ def patch_transformers_download():
 
 def patch_transformers():
     global transformers
-    global old_transfomers_functions
-    old_transfomers_functions = {}
 
     patch_transformers_download()
 
@@ -1933,7 +1959,6 @@ def patch_transformers():
             PreTrainedModel._kai_patched = True
     if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
         old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
-        old_transfomers_functions['modeling_utils.get_checkpoint_shard_files'] = old_get_checkpoint_shard_files
         def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
             utils.num_shards = utils.get_num_shards(index_filename)
             utils.from_pretrained_index_filename = index_filename
@@ -1961,7 +1986,6 @@ def patch_transformers():
             if max_pos > self.weights.size(0):
                 self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
             return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
-        old_transfomers_functions['XGLMSinusoidalPositionalEmbedding.forward'] = XGLMSinusoidalPositionalEmbedding.forward
         XGLMSinusoidalPositionalEmbedding.forward = new_forward
 
@@ -1981,7 +2005,6 @@ def patch_transformers():
                 self.model = OPTModel(config)
                 self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
                 self.post_init()
-            old_transfomers_functions['OPTForCausalLM.__init__'] = OPTForCausalLM.__init__
             OPTForCausalLM.__init__ = new_init
 
@@ -2170,8 +2193,8 @@ def patch_transformers():
         processors.append(PhraseBiasLogitsProcessor())
         processors.append(ProbabilityVisualizerLogitsProcessor())
         return processors
+    use_core_manipulations.get_logits_processor = new_get_logits_processor
     new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
-    transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
 
     class KoboldLogitsWarperList(LogitsProcessorList):
         def __init__(self, beams: int = 1, **kwargs):
@@ -2204,9 +2227,9 @@ def patch_transformers():
             kwargs["eos_token_id"] = -1
             kwargs.setdefault("pad_token_id", 2)
         return new_sample.old_sample(self, *args, **kwargs)
-    new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
-    transformers.generation_utils.GenerationMixin.sample = new_sample
+    new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
+    use_core_manipulations.sample = new_sample
 
     # Allow bad words filter to ban <|endoftext|> token
     import transformers.generation_logits_process
@@ -2374,7 +2397,7 @@ def patch_transformers():
     old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
-    old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria'] = old_get_stopping_criteria
+
     def new_get_stopping_criteria(self, *args, **kwargs):
         global tokenizer
         stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
@@ -2392,7 +2415,7 @@ def patch_transformers():
         stopping_criteria.insert(0, token_streamer)
         stopping_criteria.insert(0, ChatModeStopper(tokenizer=tokenizer))
         return stopping_criteria
-    transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
+    use_core_manipulations.get_stopping_criteria = new_get_stopping_criteria
 
 def reset_model_settings():
     koboldai_vars.socketio = socketio
@@ -5395,56 +5418,57 @@ def raw_generate(
     result: GenerationResult
     time_start = time.time()
 
-    if koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
-        batch_encoded = tpu_raw_generate(
-            prompt_tokens=prompt_tokens,
-            max_new=max_new,
-            batch_count=batch_count,
-            gen_settings=gen_settings
-        )
-        result = GenerationResult(
-            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
-        )
-    elif koboldai_vars.model in model_functions:
-        batch_encoded = model_functions[koboldai_vars.model](
-            prompt_tokens=prompt_tokens,
-            max_new=max_new,
-            batch_count=batch_count,
-            gen_settings=gen_settings
-        )
-        result = GenerationResult(
-            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
-        )
-    elif koboldai_vars.model.startswith("RWKV"):
-        batch_encoded = rwkv_raw_generate(
-            prompt_tokens=prompt_tokens,
-            max_new=max_new,
-            batch_count=batch_count,
-            gen_settings=gen_settings
-        )
-        result = GenerationResult(
-            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True
-        )
-    else:
-        # Torch HF
-        start_time = time.time()
-        batch_encoded = torch_raw_generate(
-            prompt_tokens=prompt_tokens,
-            max_new=max_new if not bypass_hf_maxlength else int(2e9),
-            do_streaming=do_streaming,
-            do_dynamic_wi=do_dynamic_wi,
-            batch_count=batch_count,
-            gen_settings=gen_settings
-        )
-        logger.debug("raw_generate: run torch_raw_generate {}s".format(time.time()-start_time))
-        start_time = time.time()
-        result = GenerationResult(
-            out_batches=batch_encoded,
-            prompt=prompt_tokens,
-            is_whole_generation=False,
-            output_includes_prompt=True,
-        )
-        logger.debug("raw_generate: run GenerationResult {}s".format(time.time()-start_time))
+    with use_core_manipulations():
+        if koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
+            batch_encoded = tpu_raw_generate(
+                prompt_tokens=prompt_tokens,
+                max_new=max_new,
+                batch_count=batch_count,
+                gen_settings=gen_settings
+            )
+            result = GenerationResult(
+                out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
+            )
+        elif koboldai_vars.model in model_functions:
+            batch_encoded = model_functions[koboldai_vars.model](
+                prompt_tokens=prompt_tokens,
+                max_new=max_new,
+                batch_count=batch_count,
+                gen_settings=gen_settings
+            )
+            result = GenerationResult(
+                out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
+            )
+        elif koboldai_vars.model.startswith("RWKV"):
+            batch_encoded = rwkv_raw_generate(
+                prompt_tokens=prompt_tokens,
+                max_new=max_new,
+                batch_count=batch_count,
+                gen_settings=gen_settings
+            )
+            result = GenerationResult(
+                out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True
+            )
+        else:
+            # Torch HF
+            start_time = time.time()
+            batch_encoded = torch_raw_generate(
+                prompt_tokens=prompt_tokens,
+                max_new=max_new if not bypass_hf_maxlength else int(2e9),
+                do_streaming=do_streaming,
+                do_dynamic_wi=do_dynamic_wi,
+                batch_count=batch_count,
+                gen_settings=gen_settings
+            )
+            logger.debug("raw_generate: run torch_raw_generate {}s".format(time.time()-start_time))
+            start_time = time.time()
+            result = GenerationResult(
+                out_batches=batch_encoded,
+                prompt=prompt_tokens,
+                is_whole_generation=False,
+                output_includes_prompt=True,
+            )
+            logger.debug("raw_generate: run GenerationResult {}s".format(time.time()-start_time))
 
     time_end = round(time.time() - time_start, 2)
     tokens_per_second = round(len(result.encoded[0]) / time_end, 2)
@@ -9384,14 +9408,10 @@ def summarize(text, max_length=100, min_length=30, unload=True):
     #Actual sumarization
     start_time = time.time()
-    global old_transfomers_functions
-    temp = transformers.generation_utils.GenerationMixin._get_stopping_criteria
-    transformers.generation_utils.GenerationMixin._get_stopping_criteria = old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria']
     #make sure text is less than 1024 tokens, otherwise we'll crash
     if len(koboldai_vars.summary_tokenizer.encode(text)) > 1000:
         text = koboldai_vars.summary_tokenizer.decode(koboldai_vars.summary_tokenizer.encode(text)[:1000])
     output = tpool.execute(summarizer, text, max_length=max_length, min_length=min_length, do_sample=False)[0]['summary_text']
-    transformers.generation_utils.GenerationMixin._get_stopping_criteria = temp
     logger.debug("Time to summarize: {}".format(time.time()-start_time))
     #move model back to CPU to save precious vram
     torch.cuda.empty_cache()