diff --git a/modeling/inference_models/basic_hf/class.py b/modeling/inference_models/basic_hf/class.py
index 5cb64c30..ecbc55cc 100644
--- a/modeling/inference_models/basic_hf/class.py
+++ b/modeling/inference_models/basic_hf/class.py
@@ -25,6 +25,9 @@ model_backend_type = "Huggingface"
 
 
 class model_backend(HFInferenceModel):
+    # Model backends must inherit from InferenceModel. We inherit from HFInferenceModel here,
+    # as it provides some helpers for handling Huggingface configs.
+
     def __init__(self) -> None:
         super().__init__()
         self.model_name = "Basic Huggingface"
@@ -60,16 +63,10 @@ class model_backend(HFInferenceModel):
 
         self.tokenizer = self._get_tokenizer(self.get_local_model_path())
 
-        # Patch sampler to use KAI samplers
-        def _patched_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
-            processors = _patched_get_logits_processor.original(*args, **kwargs)
-            return processors
-
-        use_core_manipulations.get_logits_processor = _patched_get_logits_processor
-        _patched_get_logits_processor.original = (
-            transformers.GenerationMixin._get_logits_processor
-        )
+        self.model.kai_model = self
+        utils.koboldai_vars.modeldim = self.model.get_input_embeddings().embedding_dim
 
+        # Patch Huggingface stuff to use our samplers
         class KoboldLogitsWarperList(LogitsProcessorList):
             def __call__(
                 _self,  # Unused
@@ -78,8 +75,10 @@ class model_backend(HFInferenceModel):
                 *args,
                 **kwargs,
             ):
+                # Kobold sampling is done here.
                 scores = self._apply_warpers(scores=scores, input_ids=input_ids)
 
+                # Things like Lua integration, phrase bias, and probability visualization are done here.
                 for processor in self.logits_processors:
                     scores = processor(self, scores=scores, input_ids=input_ids)
                 assert (
@@ -89,7 +88,7 @@ class model_backend(HFInferenceModel):
 
         def new_sample(self, *args, **kwargs):
             assert kwargs.pop("logits_warper", None) is not None
-            kwargs["logits_warper"] = lambda: KoboldLogitsWarperList()
+            kwargs["logits_warper"] = KoboldLogitsWarperList()
 
             if utils.koboldai_vars.newlinemode in ["s", "ns"]:
                 kwargs["eos_token_id"] = -1
@@ -100,12 +99,18 @@ class model_backend(HFInferenceModel):
         new_sample.old_sample = transformers.GenerationMixin.sample
         use_core_manipulations.sample = new_sample
 
-        self.model.kai_model = self
-        utils.koboldai_vars.modeldim = self.model.get_input_embeddings().embedding_dim
-
     def _apply_warpers(
         self, scores: torch.Tensor, input_ids: torch.Tensor
     ) -> torch.Tensor:
+        """Applies samplers/warpers to the given scores, returning the altered scores.
+
+        Args:
+            scores (torch.Tensor): The original scores.
+            input_ids (torch.Tensor): The input token sequence.
+
+        Returns:
+            torch.Tensor: The altered scores.
+        """
         warpers.update_settings()
 
         for sid in utils.koboldai_vars.sampler_order:
@@ -115,7 +120,7 @@ class model_backend(HFInferenceModel):
                 continue
 
             if warper == warpers.RepetitionPenalty:
-                # Rep pen needs more data than other samplers
+                # Rep pen needs access to input tokens to decide what to penalize
                 scores = warper.torch(scores, input_ids=input_ids)
             else:
                 scores = warper.torch(scores)
@@ -137,6 +142,7 @@ class model_backend(HFInferenceModel):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
         else:
             gen_in = prompt_tokens
+
         if not self.usegpu:
             gen_in = gen_in.to("cpu")
         else:
@@ -161,6 +167,7 @@ class model_backend(HFInferenceModel):
             use_cache=True,
             num_return_sequences=batch_count,
         )
+
         logger.debug(
             "torch_raw_generate: run generator {}s".format(time.time() - start_time)
         )
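
Reviewer note: the `logits_warper` change above also fixes a latent bug. transformers invokes the warper directly as `logits_warper(input_ids, scores)`, so the kwarg must be the `LogitsProcessorList` instance itself, not a zero-argument lambda wrapping one. For reference, here is a minimal standalone sketch of the override pattern this diff relies on; `CustomWarperList` and the fixed-temperature warper are illustrative assumptions, not code from this PR or KoboldAI's actual sampler stack.

```python
# Sketch: subclass LogitsProcessorList and override __call__ so custom
# warpers run on the scores before any stock processors held in the list.
import torch
from transformers import LogitsProcessorList


class CustomWarperList(LogitsProcessorList):
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> torch.FloatTensor:
        # Custom warping first (hypothetical fixed-temperature warper)...
        scores = scores / 0.8
        # ...then any processors in the list, mirroring the diff's
        # "apply warpers, then run logits_processors" ordering.
        for processor in self:
            scores = processor(input_ids, scores)
        return scores


# Smoke test with dummy tensors; no model required.
warpers = CustomWarperList()
ids = torch.tensor([[1, 2, 3]])
scores = torch.randn(1, 50)
assert warpers(ids, scores).shape == (1, 50)
```

Because `LogitsProcessorList` is already callable with `(input_ids, scores)`, an instance of such a subclass can be passed anywhere transformers expects a `logits_warper`, which is exactly what `new_sample` does in the diff.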