Allow changing sampler seed and sampler order from API
This commit is contained in:
parent
7bd3125f5a
commit
1a59a4acea
53
aiserver.py
53
aiserver.py
|
@ -377,6 +377,7 @@ class vars:
|
|||
comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)') # Pattern for matching comments to remove them before sending them to the AI
|
||||
comregex_ui = re.compile(r'(<\|(?:.|\n)*?\|>)') # Pattern for matching comments in the editor
|
||||
sampler_order = utils.default_sampler_order.copy()
|
||||
rng_states = {} # Used by the POST /generate endpoint to store sampler RNG states
|
||||
chatmode = False
|
||||
chatname = "You"
|
||||
adventure = False
|
||||
|
@ -630,7 +631,7 @@ tags = [
|
|||
api_version = None # This gets set automatically so don't change this value
|
||||
|
||||
api_v1 = KoboldAPISpec(
|
||||
version="1.1.4",
|
||||
version="1.2.0",
|
||||
prefixes=["/api/v1", "/api/latest"],
|
||||
tags=tags,
|
||||
)
|
||||
|
@ -7450,6 +7451,13 @@ def story_load_validator(name: str):
|
|||
raise ValidationError("Must be a valid story name.")
|
||||
return True
|
||||
|
||||
def permutation_validator(lst: list):
|
||||
if any(not isinstance(e, int) for e in lst):
|
||||
return
|
||||
if min(lst) != 0 or max(lst) != len(lst) - 1 or len(set(lst)) != len(lst):
|
||||
raise ValidationError("Must be a permutation of the first N non-negative integers, where N is the length of this array")
|
||||
return True
|
||||
|
||||
class GenerationInputSchema(SamplerSettingsSchema):
|
||||
prompt: str = fields.String(required=True, metadata={"description": "This is the submission."})
|
||||
use_memory: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the memory from the KoboldAI GUI when generating text."})
|
||||
|
@ -7469,6 +7477,9 @@ class GenerationInputSchema(SamplerSettingsSchema):
|
|||
disable_input_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, all input formatting options default to `false` instead of the value in the KoboldAI GUI"})
|
||||
frmtadsnsp: Optional[bool] = fields.Boolean(metadata={"description": "Input formatting option. When enabled, adds a leading space to your input if there is no trailing whitespace at the end of the previous action.\n\nIf `disable_input_formatting` is `true`, this defaults to `false` instead of the value in the KoboldAI GUI."})
|
||||
quiet: Optional[bool] = fields.Boolean(metadata={"description": "When enabled, Generated output will not be displayed in the console."})
|
||||
sampler_order: Optional[List[int]] = fields.List(fields.Integer(), validate=[validate.Length(min=6), permutation_validator], metadata={"description": "Sampler order to be used. If N is the length of this array, then N must be greater than or equal to 6 and the array must be a permutation of the first N non-negative integers."})
|
||||
sampler_seed: Optional[int] = fields.Integer(metadata={"description": "RNG seed to use for sampling. If not specified, the global RNG will be used."})
|
||||
sampler_full_determinism: Optional[bool] = fields.Boolean(metadata={"description": "If enabled, the generated text will always be the same as long as you use the same RNG seed, input and settings. If disabled, only the *sequence* of generated texts that you get when repeatedly generating text will be the same given the same RNG seed, input and settings."})
|
||||
|
||||
class GenerationResultSchema(KoboldSchema):
|
||||
text: str = fields.String(required=True, metadata={"description": "Generated output as plain text."})
|
||||
|
@ -7559,6 +7570,29 @@ def _generate_text(body: GenerationInputSchema):
|
|||
"msg": "Server is busy; please try again later.",
|
||||
"type": "service_unavailable",
|
||||
}}), mimetype="application/json", status=503))
|
||||
if vars.use_colab_tpu:
|
||||
import tpu_mtj_backend
|
||||
if hasattr(body, "sampler_seed"):
|
||||
# If a seed was specified, we need to save the global RNG state so we
|
||||
# can restore it later
|
||||
old_seed = vars.seed
|
||||
old_rng_state = tpu_mtj_backend.get_rng_state() if vars.use_colab_tpu else torch.get_rng_state()
|
||||
vars.seed = body.sampler_seed
|
||||
# We should try to use a previously saved RNG state with the same seed
|
||||
if body.sampler_seed in vars.rng_states:
|
||||
if vars.use_colab_tpu:
|
||||
tpu_mtj_backend.set_rng_state(vars.rng_states[body.sampler_seed])
|
||||
else:
|
||||
torch.set_rng_state(vars.rng_states[body.sampler_seed])
|
||||
else:
|
||||
if vars.use_colab_tpu:
|
||||
tpu_mtj_backend.set_rng_state(tpu_mtj_backend.new_rng_state(body.sampler_seed))
|
||||
else:
|
||||
torch.manual_seed(body.sampler_seed)
|
||||
vars.rng_states[body.sampler_seed] = tpu_mtj_backend.get_rng_state() if vars.use_colab_tpu else torch.get_rng_state()
|
||||
if hasattr(body, "sampler_order"):
|
||||
if len(body.sampler_order) < 7:
|
||||
body.sampler_order = [6] + body.sampler_order
|
||||
# This maps each property of the setting to use when sending the generate idempotently
|
||||
# To the object which typically contains it's value
|
||||
# This allows to set the property only for the API generation, and then revert the setting
|
||||
|
@ -7584,6 +7618,8 @@ def _generate_text(body: GenerationInputSchema):
|
|||
"max_context_length": ("vars", "max_length", None),
|
||||
"n": ("vars", "numseqs", None),
|
||||
"quiet": ("vars", "quiet", None),
|
||||
"sampler_order": ("vars", "sampler_order", None),
|
||||
"sampler_full_determinism": ("vars", "full_determinism", None),
|
||||
}
|
||||
saved_settings = {}
|
||||
set_aibusy(1)
|
||||
|
@ -7633,6 +7669,12 @@ def _generate_text(body: GenerationInputSchema):
|
|||
vars.output_streaming = output_streaming
|
||||
if vars.allowsp and getattr(body, "soft_prompt", None) is not None:
|
||||
spRequest(old_spfilename)
|
||||
if hasattr(body, "sampler_seed"):
|
||||
vars.seed = old_seed
|
||||
if vars.use_colab_tpu:
|
||||
tpu_mtj_backend.set_rng_state(old_rng_state)
|
||||
else:
|
||||
torch.set_rng_state(old_rng_state)
|
||||
set_aibusy(0)
|
||||
return output
|
||||
|
||||
|
@ -10035,6 +10077,15 @@ class AddSentenceSpacingSettingsSchema(KoboldSchema):
|
|||
name = "add sentence spacing (input formatting)"
|
||||
example_yaml_value = "false"
|
||||
|
||||
@config_endpoint_schema
|
||||
class SamplerOrderSettingSchema(KoboldSchema):
|
||||
value = fields.List(fields.Integer(), validate=[validate.Length(min=6), permutation_validator], required=True)
|
||||
class KoboldMeta:
|
||||
route_name = "sampler_order"
|
||||
obj = "vars"
|
||||
var_name = "sampler_order"
|
||||
name = "sampler order"
|
||||
example_yaml_value = "[6, 0, 1, 2, 3, 4, 5]"
|
||||
|
||||
|
||||
for schema in config_endpoint_schemas:
|
||||
|
|
|
@ -71,6 +71,15 @@ def set_rng_seed(seed: int):
|
|||
def randomize_rng_seed():
|
||||
return set_rng_seed(random.randrange(sys.maxsize))
|
||||
|
||||
def get_rng_state():
|
||||
return rng
|
||||
|
||||
def set_rng_state(state):
|
||||
global rng
|
||||
rng = state
|
||||
|
||||
def new_rng_state(seed: int):
|
||||
return random.Random(seed)
|
||||
|
||||
def warper_callback(logits) -> np.array:
|
||||
raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
|
||||
|
|
Loading…
Reference in New Issue