mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-01-07 13:52:10 +01:00
Make transformers warpers dynamically update their parameters
So that if you change, e.g., `top_p`, from a Lua generation modifier or from the settings menu during generation, the rest of the generation will use the new setting value instead of retaining the settings it had when generation began.
This commit is contained in:
parent
91b6289897
commit
380b54167a
15
aiserver.py
15
aiserver.py
@ -7,6 +7,7 @@
|
||||
|
||||
# External packages
|
||||
import eventlet
|
||||
from transformers.generation_logits_process import RepetitionPenaltyLogitsProcessor
|
||||
eventlet.monkey_patch()
|
||||
import os
|
||||
os.environ['EVENTLET_THREADPOOL_SIZE'] = '1'
|
||||
@ -617,6 +618,18 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
|
||||
# Patch transformers to use our custom logit warpers
|
||||
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper
|
||||
|
||||
def dynamic_processor_wrap(cls, field_name, var_name):
|
||||
old_call = cls.__call__
|
||||
def new_call(self, *args, **kwargs):
|
||||
setattr(self, field_name, getattr(vars, var_name))
|
||||
return old_call(self, *args, **kwargs)
|
||||
cls.__call__ = new_call
|
||||
dynamic_processor_wrap(RepetitionPenaltyLogitsProcessor, "penalty", "rep_pen")
|
||||
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k")
|
||||
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p")
|
||||
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp")
|
||||
|
||||
class TailFreeLogitsWarper(LogitsWarper):
|
||||
|
||||
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
@ -628,6 +641,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
self.tfs = vars.tfs
|
||||
|
||||
if self.filter_value >= 1.0:
|
||||
return scores
|
||||
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
|
||||
|
Loading…
Reference in New Issue
Block a user