Fix and add some documentation to basic hf backend

somebody
2023-07-12 17:16:05 -05:00
parent 8077d6c3f9
commit 60473d4c23

@@ -25,6 +25,9 @@ model_backend_type = "Huggingface"
 class model_backend(HFInferenceModel):
+    # Model backends must inherit from InferenceModel. We inherit from HFInferenceModel here,
+    # as it provides some helpers for handling Huggingface configs.
     def __init__(self) -> None:
         super().__init__()
         self.model_name = "Basic Huggingface"
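For context, a new backend built on this base class needs little more than the pieces shown in this hunk. A minimal sketch, assuming the import path for HFInferenceModel and using an arbitrary display name; only the identifiers visible in the diff come from the repo:

```python
# Illustrative sketch, not part of this commit. The import path is assumed;
# only HFInferenceModel, __init__, and model_name appear in the diff above.
from modeling.inference_models.hf import HFInferenceModel  # path assumed

class model_backend(HFInferenceModel):
    def __init__(self) -> None:
        super().__init__()
        self.model_name = "My Custom Backend"  # display name, value is arbitrary
```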
@@ -60,16 +63,10 @@ class model_backend(HFInferenceModel):
         self.tokenizer = self._get_tokenizer(self.get_local_model_path())
 
-        # Patch sampler to use KAI samplers
-        def _patched_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
-            processors = _patched_get_logits_processor.original(*args, **kwargs)
-            return processors
-
-        use_core_manipulations.get_logits_processor = _patched_get_logits_processor
-        _patched_get_logits_processor.original = (
-            transformers.GenerationMixin._get_logits_processor
-        )
+        self.model.kai_model = self
+        utils.koboldai_vars.modeldim = self.model.get_input_embeddings().embedding_dim
 
+        # Patch Huggingface stuff to use our samplers
         class KoboldLogitsWarperList(LogitsProcessorList):
             def __call__(
                 _self,  # Unused
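The hunk above drops the old `_get_logits_processor` patch and keeps only the `LogitsProcessorList` subclass. A minimal sketch of that pattern with stock `transformers`, using a made-up class name and a trivial score edit:

```python
import torch
from transformers import LogitsProcessorList

# Sketch only: a LogitsProcessorList subclass whose __call__ rewrites the
# logits before sampling, the same hook the backend uses above.
class ExampleWarperList(LogitsProcessorList):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        scores[:, 0] = -float("inf")  # e.g. forbid token id 0
        return scores

warper_list = ExampleWarperList()
scores = warper_list(torch.tensor([[1, 2]]), torch.randn(1, 8))
```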
@@ -78,8 +75,10 @@ class model_backend(HFInferenceModel):
                 *args,
                 **kwargs,
             ):
+                # Kobold sampling is done here.
                 scores = self._apply_warpers(scores=scores, input_ids=input_ids)
 
+                # Things like Lua integration, phrase bias, and probability visualization are done here.
                 for processor in self.logits_processors:
                     scores = processor(self, scores=scores, input_ids=input_ids)
                     assert (
@@ -89,7 +88,7 @@ class model_backend(HFInferenceModel):
         def new_sample(self, *args, **kwargs):
             assert kwargs.pop("logits_warper", None) is not None
-            kwargs["logits_warper"] = lambda: KoboldLogitsWarperList()
+            kwargs["logits_warper"] = KoboldLogitsWarperList()
 
             if utils.koboldai_vars.newlinemode in ["s", "ns"]:
                 kwargs["eos_token_id"] = -1
@@ -100,12 +99,18 @@ class model_backend(HFInferenceModel):
         new_sample.old_sample = transformers.GenerationMixin.sample
         use_core_manipulations.sample = new_sample
 
-        self.model.kai_model = self
-        utils.koboldai_vars.modeldim = self.model.get_input_embeddings().embedding_dim
-
     def _apply_warpers(
         self, scores: torch.Tensor, input_ids: torch.Tensor
     ) -> torch.Tensor:
+        """Applies samplers/warpers to the given scores, returning the altered scores.
+
+        Args:
+            scores (torch.Tensor): The original scores.
+            input_ids (torch.Tensor): The input token sequence.
+
+        Returns:
+            torch.Tensor: The altered scores.
+        """
         warpers.update_settings()
 
         for sid in utils.koboldai_vars.sampler_order:
@@ -115,7 +120,7 @@ class model_backend(HFInferenceModel):
                 continue
 
             if warper == warpers.RepetitionPenalty:
-                # Rep pen needs more data than other samplers
+                # Rep pen needs access to input tokens to decide what to penalize
                 scores = warper.torch(scores, input_ids=input_ids)
             else:
                 scores = warper.torch(scores)
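`_apply_warpers` walks `utils.koboldai_vars.sampler_order` and, as the reworded comment says, only the repetition penalty needs the token history. A hedged sketch of the same ordered-pipeline idea using stock `transformers` processors instead of KoboldAI's `warpers` module, with arbitrary settings and toy tensors:

```python
import torch
from transformers import (
    RepetitionPenaltyLogitsProcessor,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

# Samplers run one after another in a fixed order; the repetition penalty is
# the stage that actually reads the token history, while the pure warpers
# only reshape the scores. Sizes and settings here are arbitrary.
input_ids = torch.tensor([[5, 17, 17, 3]])   # toy token history
scores = torch.randn(1, 100)                 # toy logits over a 100-token vocab

pipeline = [
    RepetitionPenaltyLogitsProcessor(penalty=1.2),
    TopKLogitsWarper(top_k=40),
    TopPLogitsWarper(top_p=0.9),
]
for stage in pipeline:
    scores = stage(input_ids, scores)        # transformers hands both to every stage
```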
@@ -137,6 +142,7 @@ class model_backend(HFInferenceModel):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
         else:
             gen_in = prompt_tokens
+
         if not self.usegpu:
             gen_in = gen_in.to("cpu")
         else:
@@ -161,6 +167,7 @@ class model_backend(HFInferenceModel):
                 use_cache=True,
                 num_return_sequences=batch_count,
             )
+
         logger.debug(
             "torch_raw_generate: run generator {}s".format(time.time() - start_time)
         )
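The surrounding code times a plain `model.generate()` call. A rough stand-alone sketch of that call with a small public model; the backend's actual kwargs and prompt handling differ:

```python
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Rough sketch only: time a generate() call the way the backend logs it.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

gen_in = tokenizer("Once upon a time", return_tensors="pt").input_ids
start_time = time.time()
with torch.no_grad():
    genout = model.generate(
        gen_in,
        do_sample=True,
        max_new_tokens=20,
        use_cache=True,
        num_return_sequences=2,  # plays the role of batch_count above
    )
print("torch_raw_generate: run generator {}s".format(time.time() - start_time))
```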