Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Self-contained sampler patch (Don't merge)
Completely untested 3:00 AM code; beware! I will test and add more documentation tomorrow.
@@ -6,16 +6,24 @@ import time
 from typing import List, Optional, Union
 
 import torch
+import transformers
+from transformers import LogitsProcessorList
 from transformers.models.auto.modeling_auto import AutoModelForCausalLM
 
 import utils
-from modeling.inference_model import GenerationResult, GenerationSettings
+from logger import logger
+from modeling import warpers
+from modeling.inference_model import (
+    GenerationResult,
+    GenerationSettings,
+    use_core_manipulations,
+)
 from modeling.inference_models.hf import HFInferenceModel
 
 model_backend_name = "Basic Huggingface"
 model_backend_type = "Huggingface"
 
 
 class model_backend(HFInferenceModel):
     def __init__(self) -> None:
         super().__init__()
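The patch hinges on the newly imported use_core_manipulations from modeling.inference_model. A minimal sketch of the pattern, assuming it behaves as a context manager that swaps transformers.GenerationMixin internals for the duration of a generation call and restores them afterwards (illustrative, not the project's exact code):

# Minimal sketch of the use_core_manipulations pattern (illustrative,
# not the project's exact implementation).
import transformers

class use_core_manipulations:
    # Backends assign replacements here before generating; None = leave as-is.
    get_logits_processor = None
    sample = None

    def __enter__(self):
        if use_core_manipulations.get_logits_processor:
            self._old_glp = transformers.GenerationMixin._get_logits_processor
            transformers.GenerationMixin._get_logits_processor = (
                use_core_manipulations.get_logits_processor
            )
        if use_core_manipulations.sample:
            self._old_sample = transformers.GenerationMixin.sample
            transformers.GenerationMixin.sample = use_core_manipulations.sample

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the originals even if generation raised.
        if use_core_manipulations.get_logits_processor:
            transformers.GenerationMixin._get_logits_processor = self._old_glp
        if use_core_manipulations.sample:
            transformers.GenerationMixin.sample = self._old_sample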

@@ -25,7 +33,7 @@ class model_backend(HFInferenceModel):
         # them in subclasses?
         self.hf_torch = True
         self.nobreakmodel = True
 
     def _load(self, save_model: bool, initial_load: bool) -> None:
         utils.koboldai_vars.allowsp = False
@@ -43,16 +51,77 @@ class model_backend(HFInferenceModel):
 
         self.init_model_config()
 
-        self.model = AutoModelForCausalLM.from_pretrained(self.get_local_model_path(), low_cpu_mem_usage=True)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.get_local_model_path(), low_cpu_mem_usage=True
+        )
 
         if self.usegpu:
             self.model = self.model.to("cuda")
 
         self.tokenizer = self._get_tokenizer(self.get_local_model_path())
 
+        # Patch sampler to use KAI samplers
+        def _patched_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
+            # Pass-through for now: keep the processors transformers builds.
+            # KAI's own sampling is applied by KoboldLogitsWarperList below.
+            processors = _patched_get_logits_processor.original(*args, **kwargs)
+            return processors
+
+        use_core_manipulations.get_logits_processor = _patched_get_logits_processor
+        _patched_get_logits_processor.original = (
+            transformers.GenerationMixin._get_logits_processor
+        )
+
+        class KoboldLogitsWarperList(LogitsProcessorList):
+            def __call__(
+                _self,  # Unused; the closed-over `self` is the model backend
+                input_ids: torch.LongTensor,
+                scores: torch.FloatTensor,
+                *args,
+                **kwargs,
+            ):
+                scores = self._apply_warpers(scores=scores, input_ids=input_ids)
+
+                for processor in self.logits_processors:
+                    scores = processor(self, scores=scores, input_ids=input_ids)
+                    assert (
+                        scores is not None
+                    ), f"Scores are None; processor '{processor}' is to blame"
+                return scores
+
+        def new_sample(self, *args, **kwargs):
+            assert kwargs.pop("logits_warper", None) is not None
+            # transformers calls this as logits_warper(input_ids, scores), so it
+            # must be the list instance itself, not a zero-argument factory.
+            kwargs["logits_warper"] = KoboldLogitsWarperList()
+
+            if utils.koboldai_vars.newlinemode in ["s", "ns"]:
+                # In these modes </s> doubles as the newline, so it must not
+                # end generation; -1 never matches a real token id.
+                kwargs["eos_token_id"] = -1
+                kwargs.setdefault("pad_token_id", 2)
+
+            return new_sample.old_sample(self, *args, **kwargs)
+
+        new_sample.old_sample = transformers.GenerationMixin.sample
+        use_core_manipulations.sample = new_sample
 
         self.model.kai_model = self
         utils.koboldai_vars.modeldim = self.model.get_input_embeddings().embedding_dim
 
+    def _apply_warpers(
+        self, scores: torch.Tensor, input_ids: torch.Tensor
+    ) -> torch.Tensor:
+        warpers.update_settings()
+
+        for sid in utils.koboldai_vars.sampler_order:
+            warper = warpers.Warper.from_id(sid)
+
+            if not warper.value_is_valid():
+                continue
+
+            if warper == warpers.RepetitionPenalty:
+                # Rep pen needs more data than other samplers
+                scores = warper.torch(scores, input_ids=input_ids)
+            else:
+                scores = warper.torch(scores)
+
+            assert scores is not None, f"Scores are None; warper '{warper}' is to blame"
+        return scores
+
     def _raw_generate(
         self,
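Each id in sampler_order maps to a Warper whose torch() method rewrites the logits in place in the configured order. As a rough idea of what one such warper does, here is a self-contained sketch of top-k filtering followed by temperature (an illustration only, not KoboldAI's actual warpers module):

import torch

def warp_top_k_temperature(scores: torch.Tensor, k: int = 40, temp: float = 0.8) -> torch.Tensor:
    # Mask everything below the k-th largest logit...
    cutoff = torch.topk(scores, k).values[..., -1, None]
    scores = scores.masked_fill(scores < cutoff, -float("inf"))
    # ...then sharpen (temp < 1) or flatten (temp > 1) the distribution.
    return scores / temp

logits = torch.randn(1, 32000)         # one batch row over a 32k-token vocab
warped = warp_top_k_temperature(logits)
probs = torch.softmax(warped, dim=-1)  # only the top 40 tokens keep any mass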
 
@@ -102,4 +171,4 @@ class model_backend(HFInferenceModel):
             prompt=prompt_tokens,
             is_whole_generation=False,
             output_includes_prompt=True,
         )
     )
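A closing note on the eos_token_id = -1 trick in new_sample above: sampled token ids are always non-negative, so the early-stop check can never fire, and the default pad_token_id of 2 corresponds to </s> in several common vocabularies (e.g. OPT and LLaMA). A tiny illustration:

import torch

# Sampled token ids are always >= 0, so -1 can never match:
next_tokens = torch.tensor([13, 2, 198])
print((next_tokens == -1).any())  # tensor(False) -> generation never stops early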