Replace exllama samplers with kobold's inbuilt ones

0cc4m
2023-06-06 19:21:34 +02:00
parent 94520d5c80
commit 39dfb18455


@@ -15,6 +15,10 @@ import gc
import utils
from logger import logger
from modeling import warpers
from modeling.warpers import Warper
from modeling.stoppers import Stoppers
from modeling.post_token_hooks import PostTokenHooks
from modeling.inference_model import (
    GenerationResult,
    GenerationSettings,
@@ -30,6 +34,11 @@ from exllama.generator import ExLlamaGenerator
model_backend_name = "ExLlama"
# When set to true, messages will appear in the console if samplers are not
# changing the scores. Keep in mind some samplers don't always change the
# scores for each token.
LOG_SAMPLER_NO_EFFECT = False
def load_model_gptq_settings(path):
    try:
@@ -86,7 +95,7 @@ class model_backend(InferenceModel):
    def _load(self, save_model: bool, initial_load: bool) -> None:
        self.model = self._get_model(self.get_local_model_path(), {})
        self.tokenizer = self._get_tokenizer(self.get_local_model_path()))
        self.tokenizer = self._get_tokenizer(self.get_local_model_path())

        self.cache = ExLlamaCache(self.model)
@@ -203,6 +212,34 @@ class model_backend(InferenceModel):
        except:
            pass
    def _apply_warpers(
        self, scores: torch.Tensor, input_ids: torch.Tensor
    ) -> torch.Tensor:
        warpers.update_settings()

        if LOG_SAMPLER_NO_EFFECT:
            pre = torch.Tensor(scores)

        for sid in utils.koboldai_vars.sampler_order:
            warper = Warper.from_id(sid)

            if not warper.value_is_valid():
                continue

            if warper == warpers.RepetitionPenalty:
                # Rep pen needs more data than other samplers
                scores = warper.torch(scores, input_ids=input_ids)
            else:
                scores = warper.torch(scores)

            assert scores is not None, f"Scores are None; warper '{warper}' is to blame"
            if LOG_SAMPLER_NO_EFFECT:
                if torch.equal(pre, scores):
                    logger.info(f"{warper} had no effect on the scores.")
                pre = torch.Tensor(scores)

        return scores
    def _raw_generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
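
For context on what `_apply_warpers` above is doing: it walks `utils.koboldai_vars.sampler_order` and lets each enabled warper transform the next-token logits in turn, so ExLlama uses the same sampler settings as the other backends. The snippet below is a minimal, self-contained sketch of that chained-warper pattern in plain PyTorch; the `temperature_warper`/`top_k_warper` functions and their defaults are illustrative stand-ins, not KoboldAI's actual `Warper` API.

```python
import torch

# Sketch of the chained-warper idea (not KoboldAI's Warper API): each step
# takes the raw logits for the next token and returns adjusted logits, and
# the steps run in a configurable order.

def temperature_warper(scores: torch.Tensor, temperature: float = 0.7) -> torch.Tensor:
    # Higher temperature flattens the distribution, lower sharpens it.
    return scores / temperature

def top_k_warper(scores: torch.Tensor, top_k: int = 40) -> torch.Tensor:
    # Mask out everything below the k-th highest-scoring token.
    kth_best = torch.topk(scores, top_k, dim=-1).values[..., -1, None]
    return scores.masked_fill(scores < kth_best, float("-inf"))

def apply_warpers(scores: torch.Tensor, order) -> torch.Tensor:
    for warper in order:
        scores = warper(scores)
    return scores

if __name__ == "__main__":
    logits = torch.randn(1, 32000)                  # fake vocabulary logits
    warped = apply_warpers(logits, [temperature_warper, top_k_warper])
    probs = torch.softmax(warped, dim=-1)
    print(torch.multinomial(probs, 1))              # sample one token id
```

Chaining standalone transforms like this keeps the sampler settings backend-agnostic, which is presumably why the commit drops ExLlama's own sampling in favour of the shared warpers.
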
@@ -228,8 +265,23 @@ class model_backend(InferenceModel):
        self.generator.gen_begin(gen_in)

        # from pudb.remote import set_trace
        # set_trace(term_size=(200, 60))

        for i in range(max_new):
            token = self.generator.gen_single_token()
            logits = self.model.forward(self.generator.sequence[:, -1:], self.cache)
            logits[:, :, self.tokenizer.bos_token_id] = -10000.0

            logits = torch.unsqueeze(logits[0, -1, :], 0)

            scores = self._apply_warpers(logits, gen_in)

            scores = torch.softmax(scores, dim=-1)

            token = torch.multinomial(scores, 1)

            self.generator.gen_accept_token(token)

            if token.item() == self.tokenizer.eos_token_id: break

        return GenerationResult(
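
The rewritten loop above no longer asks ExLlama's generator to pick the token: it runs a forward pass itself, suppresses the BOS token, pushes the logits through `_apply_warpers`, samples with `torch.multinomial`, and feeds the choice back via `gen_accept_token`. The sketch below lays out that forward/warp/softmax/multinomial/accept control flow generically; `model`, `cache`, `tokenizer`, and `apply_warpers` are placeholder parameters here, not ExLlama's exact interface.

```python
import torch

def sample_loop(model, cache, tokenizer, sequence: torch.Tensor,
                max_new: int, apply_warpers) -> torch.Tensor:
    """Generic token-by-token sampling loop (illustrative, not the backend's exact code)."""
    for _ in range(max_new):
        # Score only the newest position; the cache carries the earlier context.
        logits = model.forward(sequence[:, -1:], cache)
        logits[:, :, tokenizer.bos_token_id] = -10000.0   # never emit BOS mid-stream

        scores = torch.unsqueeze(logits[0, -1, :], 0)      # (1, vocab)
        scores = apply_warpers(scores)                      # samplers adjust the logits
        probs = torch.softmax(scores, dim=-1)
        token = torch.multinomial(probs, 1)                  # draw one token id

        sequence = torch.cat([sequence, token], dim=-1)
        if token.item() == tokenizer.eos_token_id:
            break
    return sequence
```

Compared with the old `gen_single_token` path, the backend now only has to produce raw logits and leaves every sampling decision to KoboldAI's warpers, which is what the commit title describes.
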