Allow sampler seed and full determinism to be read/written in /config
This commit is contained in:
parent
1a59a4acea
commit
dd1c25241d
64
aiserver.py
64
aiserver.py
|
@ -9880,6 +9880,60 @@ def put_config_soft_prompt(body: SoftPromptSettingSchema):
|
||||||
settingschanged()
|
settingschanged()
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
class SamplerSeedSettingSchema(KoboldSchema):
|
||||||
|
value: int = fields.Integer(required=True)
|
||||||
|
|
||||||
|
@api_v1.get("/config/sampler_seed")
|
||||||
|
@api_schema_wrap
|
||||||
|
def get_config_sampler_seed():
|
||||||
|
"""---
|
||||||
|
get:
|
||||||
|
summary: Retrieve the current global sampler seed value
|
||||||
|
tags:
|
||||||
|
- config
|
||||||
|
responses:
|
||||||
|
200:
|
||||||
|
description: Successful request
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema: SamplerSeedSettingSchema
|
||||||
|
example:
|
||||||
|
value: 3475097509890965500
|
||||||
|
"""
|
||||||
|
return {"value": __import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()}
|
||||||
|
|
||||||
|
@api_v1.put("/config/sampler_seed")
|
||||||
|
@api_schema_wrap
|
||||||
|
def put_config_sampler_seed(body: SamplerSeedSettingSchema):
|
||||||
|
"""---
|
||||||
|
put:
|
||||||
|
summary: Set the global sampler seed value
|
||||||
|
tags:
|
||||||
|
- config
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema: SamplerSeedSettingSchema
|
||||||
|
example:
|
||||||
|
value: 3475097509890965500
|
||||||
|
responses:
|
||||||
|
200:
|
||||||
|
description: Successful request
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema: EmptySchema
|
||||||
|
{api_validation_error_response}
|
||||||
|
"""
|
||||||
|
if vars.use_colab_tpu:
|
||||||
|
import tpu_mtj_backend
|
||||||
|
tpu_mtj_backend.set_rng_seed(body.value)
|
||||||
|
else:
|
||||||
|
import torch
|
||||||
|
torch.manual_seed(body.value)
|
||||||
|
vars.seed = body.value
|
||||||
|
return {}
|
||||||
|
|
||||||
config_endpoint_schemas: List[Type[KoboldSchema]] = []
|
config_endpoint_schemas: List[Type[KoboldSchema]] = []
|
||||||
|
|
||||||
def config_endpoint_schema(c: Type[KoboldSchema]):
|
def config_endpoint_schema(c: Type[KoboldSchema]):
|
||||||
|
@ -10087,6 +10141,16 @@ class SamplerOrderSettingSchema(KoboldSchema):
|
||||||
name = "sampler order"
|
name = "sampler order"
|
||||||
example_yaml_value = "[6, 0, 1, 2, 3, 4, 5]"
|
example_yaml_value = "[6, 0, 1, 2, 3, 4, 5]"
|
||||||
|
|
||||||
|
@config_endpoint_schema
|
||||||
|
class SamplerFullDeterminismSettingSchema(KoboldSchema):
|
||||||
|
value = fields.Boolean(required=True)
|
||||||
|
class KoboldMeta:
|
||||||
|
route_name = "sampler_full_determinism"
|
||||||
|
obj = "vars"
|
||||||
|
var_name = "full_determinism"
|
||||||
|
name = "sampler full determinism"
|
||||||
|
example_yaml_value = "false"
|
||||||
|
|
||||||
|
|
||||||
for schema in config_endpoint_schemas:
|
for schema in config_endpoint_schemas:
|
||||||
create_config_endpoint(schema=schema.__name__, method="GET")
|
create_config_endpoint(schema=schema.__name__, method="GET")
|
||||||
|
|
Loading…
Reference in New Issue