Allow sampler seed and full determinism to be read/written in /config

This commit is contained in:
vfbd 2022-10-02 17:43:54 -04:00
parent 1a59a4acea
commit dd1c25241d
1 changed files with 64 additions and 0 deletions

View File

@ -9880,6 +9880,60 @@ def put_config_soft_prompt(body: SoftPromptSettingSchema):
settingschanged()
return {}
class SamplerSeedSettingSchema(KoboldSchema):
value: int = fields.Integer(required=True)
@api_v1.get("/config/sampler_seed")
@api_schema_wrap
def get_config_sampler_seed():
"""---
get:
summary: Retrieve the current global sampler seed value
tags:
- config
responses:
200:
description: Successful request
content:
application/json:
schema: SamplerSeedSettingSchema
example:
value: 3475097509890965500
"""
return {"value": __import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()}
@api_v1.put("/config/sampler_seed")
@api_schema_wrap
def put_config_sampler_seed(body: SamplerSeedSettingSchema):
"""---
put:
summary: Set the global sampler seed value
tags:
- config
requestBody:
required: true
content:
application/json:
schema: SamplerSeedSettingSchema
example:
value: 3475097509890965500
responses:
200:
description: Successful request
content:
application/json:
schema: EmptySchema
{api_validation_error_response}
"""
if vars.use_colab_tpu:
import tpu_mtj_backend
tpu_mtj_backend.set_rng_seed(body.value)
else:
import torch
torch.manual_seed(body.value)
vars.seed = body.value
return {}
config_endpoint_schemas: List[Type[KoboldSchema]] = []
def config_endpoint_schema(c: Type[KoboldSchema]):
@ -10087,6 +10141,16 @@ class SamplerOrderSettingSchema(KoboldSchema):
name = "sampler order"
example_yaml_value = "[6, 0, 1, 2, 3, 4, 5]"
@config_endpoint_schema
class SamplerFullDeterminismSettingSchema(KoboldSchema):
value = fields.Boolean(required=True)
class KoboldMeta:
route_name = "sampler_full_determinism"
obj = "vars"
var_name = "full_determinism"
name = "sampler full determinism"
example_yaml_value = "false"
for schema in config_endpoint_schemas:
create_config_endpoint(schema=schema.__name__, method="GET")