From dd1c25241d4c0816d35523f1e02448fdd94eebbd Mon Sep 17 00:00:00 2001 From: vfbd Date: Sun, 2 Oct 2022 17:43:54 -0400 Subject: [PATCH] Allow sampler seed and full determinism to be read/written in /config --- aiserver.py | 64 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/aiserver.py b/aiserver.py index 1b74167f..8d9bd27d 100644 --- a/aiserver.py +++ b/aiserver.py @@ -9880,6 +9880,60 @@ def put_config_soft_prompt(body: SoftPromptSettingSchema): settingschanged() return {} +class SamplerSeedSettingSchema(KoboldSchema): + value: int = fields.Integer(required=True) + +@api_v1.get("/config/sampler_seed") +@api_schema_wrap +def get_config_sampler_seed(): + """--- + get: + summary: Retrieve the current global sampler seed value + tags: + - config + responses: + 200: + description: Successful request + content: + application/json: + schema: SamplerSeedSettingSchema + example: + value: 3475097509890965500 + """ + return {"value": __import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()} + +@api_v1.put("/config/sampler_seed") +@api_schema_wrap +def put_config_sampler_seed(body: SamplerSeedSettingSchema): + """--- + put: + summary: Set the global sampler seed value + tags: + - config + requestBody: + required: true + content: + application/json: + schema: SamplerSeedSettingSchema + example: + value: 3475097509890965500 + responses: + 200: + description: Successful request + content: + application/json: + schema: EmptySchema + {api_validation_error_response} + """ + if vars.use_colab_tpu: + import tpu_mtj_backend + tpu_mtj_backend.set_rng_seed(body.value) + else: + import torch + torch.manual_seed(body.value) + vars.seed = body.value + return {} + config_endpoint_schemas: List[Type[KoboldSchema]] = [] def config_endpoint_schema(c: Type[KoboldSchema]): @@ -10087,6 +10141,16 @@ class SamplerOrderSettingSchema(KoboldSchema): name = "sampler order" example_yaml_value = "[6, 0, 1, 2, 3, 4, 5]" +@config_endpoint_schema +class SamplerFullDeterminismSettingSchema(KoboldSchema): + value = fields.Boolean(required=True) + class KoboldMeta: + route_name = "sampler_full_determinism" + obj = "vars" + var_name = "full_determinism" + name = "sampler full determinism" + example_yaml_value = "false" + for schema in config_endpoint_schemas: create_config_endpoint(schema=schema.__name__, method="GET")