From 8077d6c3f97002d36b4272791f1ad28e57b28813 Mon Sep 17 00:00:00 2001
From: onesome
Date: Wed, 12 Jul 2023 03:22:43 -0500
Subject: [PATCH] Self-contained sampler patch (Don't merge)

Completely untested 3:00 AM code; beware! I will test and add more
documentation tomorrow.
---
 modeling/inference_models/basic_hf/class.py | 77 +++++++++++++++++++--
 1 file changed, 73 insertions(+), 4 deletions(-)

diff --git a/modeling/inference_models/basic_hf/class.py b/modeling/inference_models/basic_hf/class.py
index 2914c9bc..5cb64c30 100644
--- a/modeling/inference_models/basic_hf/class.py
+++ b/modeling/inference_models/basic_hf/class.py
@@ -6,16 +6,24 @@ import time
 from typing import List, Optional, Union
 
 import torch
+import transformers
+from transformers import LogitsProcessorList
 from transformers.models.auto.modeling_auto import AutoModelForCausalLM
 
 import utils
 from logger import logger
-from modeling.inference_model import GenerationResult, GenerationSettings
+from modeling import warpers
+from modeling.inference_model import (
+    GenerationResult,
+    GenerationSettings,
+    use_core_manipulations,
+)
 from modeling.inference_models.hf import HFInferenceModel
 
 model_backend_name = "Basic Huggingface"
 model_backend_type = "Huggingface"
 
+
 class model_backend(HFInferenceModel):
     def __init__(self) -> None:
         super().__init__()
@@ -25,7 +33,7 @@ class model_backend(HFInferenceModel):
         # them in subclasses?
         self.hf_torch = True
         self.nobreakmodel = True
-    
+
     def _load(self, save_model: bool, initial_load: bool) -> None:
         utils.koboldai_vars.allowsp = False
 
@@ -43,16 +51,77 @@ class model_backend(HFInferenceModel):
 
         self.init_model_config()
 
-        self.model = AutoModelForCausalLM.from_pretrained(self.get_local_model_path(), low_cpu_mem_usage=True)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.get_local_model_path(), low_cpu_mem_usage=True
+        )
 
         if self.usegpu:
             self.model = self.model.to("cuda")
 
         self.tokenizer = self._get_tokenizer(self.get_local_model_path())
 
+        # Patch sampler to use KAI samplers
+        def _patched_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
+            processors = _patched_get_logits_processor.original(*args, **kwargs)
+            return processors
+
+        use_core_manipulations.get_logits_processor = _patched_get_logits_processor
+        _patched_get_logits_processor.original = (
+            transformers.GenerationMixin._get_logits_processor
+        )
+
+        class KoboldLogitsWarperList(LogitsProcessorList):
+            def __call__(
+                _self,  # Unused; `self` from the enclosing scope is the backend
+                input_ids: torch.LongTensor,
+                scores: torch.FloatTensor,
+                *args,
+                **kwargs,
+            ):
+                scores = self._apply_warpers(scores=scores, input_ids=input_ids)
+
+                for processor in self.logits_processors:
+                    scores = processor(self, scores=scores, input_ids=input_ids)
+                    assert (
+                        scores is not None
+                    ), f"Scores are None; processor '{processor}' is to blame"
+                return scores
+
+        def new_sample(self, *args, **kwargs):
+            assert kwargs.pop("logits_warper", None) is not None
+            kwargs["logits_warper"] = KoboldLogitsWarperList()  # called as (input_ids, scores)
+
+            if utils.koboldai_vars.newlinemode in ["s", "ns"]:
+                kwargs["eos_token_id"] = -1
+                kwargs.setdefault("pad_token_id", 2)
+
+            return new_sample.old_sample(self, *args, **kwargs)
+
+        new_sample.old_sample = transformers.GenerationMixin.sample
+        use_core_manipulations.sample = new_sample
+
         self.model.kai_model = self
         utils.koboldai_vars.modeldim = self.model.get_input_embeddings().embedding_dim
 
+    def _apply_warpers(
+        self, scores: torch.Tensor, input_ids: torch.Tensor
+    ) -> torch.Tensor:
+        warpers.update_settings()
+
+        for sid in utils.koboldai_vars.sampler_order:
+            warper = warpers.Warper.from_id(sid)
+
+            if not warper.value_is_valid():
+                continue
+
+            if warper == warpers.RepetitionPenalty:
+                # Rep pen needs more data than other samplers
+                scores = warper.torch(scores, input_ids=input_ids)
+            else:
+                scores = warper.torch(scores)
+
+            assert scores is not None, f"Scores are None; warper '{warper}' is to blame"
+        return scores
     def _raw_generate(
         self,
         prompt_tokens: Union[List[int], torch.Tensor],
@@ -102,4 +171,4 @@ class model_backend(HFInferenceModel):
             prompt=prompt_tokens,
             is_whole_generation=False,
             output_includes_prompt=True,
-        )
\ No newline at end of file
+        )
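
Reviewer note: the whole patch leans on one pattern, "stash the original method on the
replacement, install the replacement, delegate back" (see _patched_get_logits_processor.original
and new_sample.old_sample above). The snippet below is a minimal, self-contained sketch of that
pattern using made-up names (Greeter, polite_greet); it is illustration only and not part of the
KoboldAI or transformers APIs touched by this diff.

    # Illustration only: mirrors how new_sample() stashes GenerationMixin.sample
    # on itself as new_sample.old_sample before taking its place.
    class Greeter:
        def greet(self, name: str) -> str:
            return f"Hello, {name}"


    def polite_greet(self, name: str) -> str:
        # Adjust the arguments, then fall through to the saved original method.
        return polite_greet.original(self, name.title()) + "!"


    polite_greet.original = Greeter.greet  # stash the original on the wrapper
    Greeter.greet = polite_greet           # install the replacement

    assert Greeter().greet("world") == "Hello, World!"

    Greeter.greet = polite_greet.original  # restore the unpatched behaviour
    assert Greeter().greet("world") == "Hello, world"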