Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Replace exllama samplers with kobold's inbuilt ones
@@ -15,6 +15,10 @@ import gc
 import utils
 from logger import logger
 
+from modeling import warpers
+from modeling.warpers import Warper
+from modeling.stoppers import Stoppers
+from modeling.post_token_hooks import PostTokenHooks
 from modeling.inference_model import (
     GenerationResult,
     GenerationSettings,
@@ -30,6 +34,11 @@ from exllama.generator import ExLlamaGenerator
 
 model_backend_name = "ExLlama"
 
+# When set to true, messages will appear in the console if samplers are not
+# changing the scores. Keep in mind some samplers don't always change the
+# scores for each token.
+LOG_SAMPLER_NO_EFFECT = False
+
 
 def load_model_gptq_settings(path):
     try:
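A minimal sketch of the before/after check this flag enables; temperature() and the score values below are made-up stand-ins, and the real check appears in _apply_warpers later in this diff:

    import torch

    LOG_SAMPLER_NO_EFFECT = True

    def temperature(scores: torch.Tensor, temp: float = 1.0) -> torch.Tensor:
        # A temperature of 1.0 leaves the scores unchanged, which is exactly
        # the "sampler had no effect" case the flag is meant to surface.
        return scores / temp

    scores = torch.tensor([[2.0, 0.5, -1.0]])
    pre = scores.clone()
    scores = temperature(scores, temp=1.0)

    if LOG_SAMPLER_NO_EFFECT and torch.equal(pre, scores):
        print("temperature had no effect on the scores.")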
@@ -86,7 +95,7 @@ class model_backend(InferenceModel):
 
     def _load(self, save_model: bool, initial_load: bool) -> None:
         self.model = self._get_model(self.get_local_model_path(), {})
-        self.tokenizer = self._get_tokenizer(self.get_local_model_path()))
+        self.tokenizer = self._get_tokenizer(self.get_local_model_path())
 
         self.cache = ExLlamaCache(self.model)
 
@@ -203,6 +212,34 @@ class model_backend(InferenceModel):
         except:
             pass
 
+    def _apply_warpers(
+        self, scores: torch.Tensor, input_ids: torch.Tensor
+    ) -> torch.Tensor:
+        warpers.update_settings()
+
+        if LOG_SAMPLER_NO_EFFECT:
+            pre = torch.Tensor(scores)
+
+        for sid in utils.koboldai_vars.sampler_order:
+            warper = Warper.from_id(sid)
+
+            if not warper.value_is_valid():
+                continue
+
+            if warper == warpers.RepetitionPenalty:
+                # Rep pen needs more data than other samplers
+                scores = warper.torch(scores, input_ids=input_ids)
+            else:
+                scores = warper.torch(scores)
+
+            assert scores is not None, f"Scores are None; warper '{warper}' is to blame"
+
+            if LOG_SAMPLER_NO_EFFECT:
+                if torch.equal(pre, scores):
+                    logger.info(warper, "had no effect on the scores.")
+                pre = torch.Tensor(scores)
+        return scores
+
     def _raw_generate(
         self,
         prompt_tokens: Union[List[int], torch.Tensor],
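The _apply_warpers helper added above chains Kobold's samplers over the raw logits in the user's configured order. The sketch below imitates that chain with two hand-rolled stand-ins (temperature and top-k); the function names and the vocabulary size are invented for illustration, and it does not call Kobold's Warper classes, which depend on the running app's koboldai_vars state:

    import torch

    def temperature_warp(scores: torch.Tensor, temp: float = 0.7) -> torch.Tensor:
        # Flatten or sharpen the distribution by scaling the logits.
        return scores / temp

    def top_k_warp(scores: torch.Tensor, k: int = 40) -> torch.Tensor:
        # Mask every logit outside the k largest so they can never be sampled.
        kth_best = torch.topk(scores, k, dim=-1).values[..., -1, None]
        return scores.masked_fill(scores < kth_best, -float("inf"))

    def apply_warpers(scores: torch.Tensor) -> torch.Tensor:
        # Mirrors the loop in _apply_warpers: each warper sees the output of
        # the previous one, applied in a configurable order.
        for warp in (temperature_warp, top_k_warp):
            scores = warp(scores)
            assert scores is not None
        return scores

    logits = torch.randn(1, 32000)   # fake vocabulary-sized scores
    warped = apply_warpers(logits)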
@@ -228,8 +265,23 @@ class model_backend(InferenceModel):
 
         self.generator.gen_begin(gen_in)
 
+        # from pudb.remote import set_trace
+        # set_trace(term_size=(200, 60))
+
         for i in range(max_new):
-            token = self.generator.gen_single_token()
+            logits = self.model.forward(self.generator.sequence[:, -1:], self.cache)
+            logits[:, :, self.tokenizer.bos_token_id] = -10000.0
+
+            logits = torch.unsqueeze(logits[0, -1, :], 0)
+
+            scores = self._apply_warpers(logits, gen_in)
+
+            scores = torch.softmax(scores, dim=-1)
+
+            token = torch.multinomial(scores, 1)
+
+            self.generator.gen_accept_token(token)
+
             if token.item() == self.tokenizer.eos_token_id: break
 
         return GenerationResult(
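The rewritten loop above replaces ExLlama's built-in gen_single_token() with manual sampling: a forward pass for the last token, banning the BOS id, running the warper chain, softmax, then a multinomial draw. A self-contained sketch of that per-token flow, with a random-logits stub standing in for self.model.forward and made-up vocabulary size and BOS/EOS ids:

    import torch

    VOCAB, BOS_ID, EOS_ID, MAX_NEW = 32000, 1, 2, 16

    def fake_forward(last_token: torch.Tensor) -> torch.Tensor:
        # Stand-in for self.model.forward(...): returns [batch, seq, vocab] logits.
        return torch.randn(1, 1, VOCAB)

    sequence = torch.tensor([[BOS_ID]])
    for _ in range(MAX_NEW):
        logits = fake_forward(sequence[:, -1:])
        logits[:, :, BOS_ID] = -10000.0                 # never re-emit BOS mid-generation

        logits = torch.unsqueeze(logits[0, -1, :], 0)   # -> [1, vocab]
        scores = logits / 0.7                           # warper chain stand-in (temperature)
        probs = torch.softmax(scores, dim=-1)
        token = torch.multinomial(probs, 1)             # sample one token id

        sequence = torch.cat([sequence, token], dim=-1)
        if token.item() == EOS_ID:
            break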