Self-contained sampler patch (Don't merge)

Completely untested 3:00 AM code; beware! I will test and add more
documentation tomorrow.
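
Rough sketch of what this does, until real docs land: the backend swaps transformers' sampling entry points for wrappers so KAI's warpers run inside generate(). The core pattern looks roughly like the snippet below; names are illustrative (my_warper_list is a stand-in for the KoboldLogitsWarperList built in the patch), and the real code installs the hooks through use_core_manipulations instead of rebinding transformers permanently:

    import transformers
    from transformers import LogitsProcessorList, TemperatureLogitsWarper

    # Stand-in warper list; the patch builds a KoboldLogitsWarperList instead.
    my_warper_list = LogitsProcessorList([TemperatureLogitsWarper(0.8)])

    def new_sample(self, *args, **kwargs):
        # Swap in the custom warper list, then defer to the stock sampler.
        kwargs["logits_warper"] = my_warper_list
        return new_sample.old_sample(self, *args, **kwargs)

    new_sample.old_sample = transformers.GenerationMixin.sample  # keep the original
    transformers.GenerationMixin.sample = new_sample  # install the hook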
onesome
2023-07-12 03:22:43 -05:00
parent 20b4b4bcef
commit 8077d6c3f9


@@ -6,16 +6,24 @@ import time
 from typing import List, Optional, Union

 import torch
 import transformers
+from transformers import LogitsProcessorList
 from transformers.models.auto.modeling_auto import AutoModelForCausalLM

 import utils
 from logger import logger
-from modeling.inference_model import GenerationResult, GenerationSettings
+from modeling import warpers
+from modeling.inference_model import (
+    GenerationResult,
+    GenerationSettings,
+    use_core_manipulations,
+)
 from modeling.inference_models.hf import HFInferenceModel

 model_backend_name = "Basic Huggingface"
 model_backend_type = "Huggingface"


 class model_backend(HFInferenceModel):
     def __init__(self) -> None:
         super().__init__()
@@ -43,16 +51,77 @@ class model_backend(HFInferenceModel):
         self.init_model_config()

-        self.model = AutoModelForCausalLM.from_pretrained(self.get_local_model_path(), low_cpu_mem_usage=True)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.get_local_model_path(), low_cpu_mem_usage=True
+        )

         if self.usegpu:
             self.model = self.model.to("cuda")

         self.tokenizer = self._get_tokenizer(self.get_local_model_path())
+
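+        # use_core_manipulations scopes the monkey patches below to KAI's
+        # generation calls instead of rebinding transformers globally.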
+        # Patch sampler to use KAI samplers
+        def _patched_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
+            processors = _patched_get_logits_processor.original(*args, **kwargs)
+            return processors
+
+        use_core_manipulations.get_logits_processor = _patched_get_logits_processor
+        _patched_get_logits_processor.original = (
+            transformers.GenerationMixin._get_logits_processor
+        )
+
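+        # Stand-in for the stock warper list: routes scores through KAI's
+        # warpers first (via _apply_warpers), then the backend's own logits
+        # processors. `_self` is the list instance and goes unused; `self`
+        # is the backend, captured from the enclosing scope.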
+        class KoboldLogitsWarperList(LogitsProcessorList):
+            def __call__(
+                _self,  # Unused
+                input_ids: torch.LongTensor,
+                scores: torch.FloatTensor,
+                *args,
+                **kwargs,
+            ):
+                scores = self._apply_warpers(scores=scores, input_ids=input_ids)
+
+                for processor in self.logits_processors:
+                    scores = processor(self, scores=scores, input_ids=input_ids)
+                    assert (
+                        scores is not None
+                    ), f"Scores are None; processor '{processor}' is to blame"
+                return scores
+
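+        # Wrapper for GenerationMixin.sample: swap HF's warper list for ours
+        # and, in the s/ns newline modes, set eos_token_id to -1 so EOS can't
+        # end generation early (pad_token_id then needs an explicit default).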
+        def new_sample(self, *args, **kwargs):
+            assert kwargs.pop("logits_warper", None) is not None
+            # sample() calls logits_warper(input_ids, scores) directly, so it
+            # needs the list instance itself, not a zero-argument lambda.
+            kwargs["logits_warper"] = KoboldLogitsWarperList()
+
+            if utils.koboldai_vars.newlinemode in ["s", "ns"]:
+                kwargs["eos_token_id"] = -1
+                kwargs.setdefault("pad_token_id", 2)
+
+            return new_sample.old_sample(self, *args, **kwargs)
+
+        new_sample.old_sample = transformers.GenerationMixin.sample
+        use_core_manipulations.sample = new_sample
+
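+        # Backlink so generation code holding only the HF model can reach the
+        # KAI backend, and record the embedding width for the rest of KAI.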
+        self.model.kai_model = self
+        utils.koboldai_vars.modeldim = self.model.get_input_embeddings().embedding_dim
+
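+    # Run KAI's warpers over the scores in the user-configured sampler order.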
+    def _apply_warpers(
+        self, scores: torch.Tensor, input_ids: torch.Tensor
+    ) -> torch.Tensor:
+        warpers.update_settings()
+
+        for sid in utils.koboldai_vars.sampler_order:
+            warper = warpers.Warper.from_id(sid)
+
+            if not warper.value_is_valid():
+                continue
+
+            if warper == warpers.RepetitionPenalty:
+                # Rep pen needs more data than other samplers
+                scores = warper.torch(scores, input_ids=input_ids)
+            else:
+                scores = warper.torch(scores)
+
+            assert scores is not None, f"Scores are None; warper '{warper}' is to blame"
+        return scores
+
     def _raw_generate(
         self,