Allow sampler seed and full determinism to be read/written in /config
This commit is contained in:
parent
1a59a4acea
commit
dd1c25241d
64
aiserver.py
64
aiserver.py
|
@ -9880,6 +9880,60 @@ def put_config_soft_prompt(body: SoftPromptSettingSchema):
|
|||
settingschanged()
|
||||
return {}
|
||||
|
||||
class SamplerSeedSettingSchema(KoboldSchema):
|
||||
value: int = fields.Integer(required=True)
|
||||
|
||||
@api_v1.get("/config/sampler_seed")
|
||||
@api_schema_wrap
|
||||
def get_config_sampler_seed():
|
||||
"""---
|
||||
get:
|
||||
summary: Retrieve the current global sampler seed value
|
||||
tags:
|
||||
- config
|
||||
responses:
|
||||
200:
|
||||
description: Successful request
|
||||
content:
|
||||
application/json:
|
||||
schema: SamplerSeedSettingSchema
|
||||
example:
|
||||
value: 3475097509890965500
|
||||
"""
|
||||
return {"value": __import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()}
|
||||
|
||||
@api_v1.put("/config/sampler_seed")
|
||||
@api_schema_wrap
|
||||
def put_config_sampler_seed(body: SamplerSeedSettingSchema):
|
||||
"""---
|
||||
put:
|
||||
summary: Set the global sampler seed value
|
||||
tags:
|
||||
- config
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema: SamplerSeedSettingSchema
|
||||
example:
|
||||
value: 3475097509890965500
|
||||
responses:
|
||||
200:
|
||||
description: Successful request
|
||||
content:
|
||||
application/json:
|
||||
schema: EmptySchema
|
||||
{api_validation_error_response}
|
||||
"""
|
||||
if vars.use_colab_tpu:
|
||||
import tpu_mtj_backend
|
||||
tpu_mtj_backend.set_rng_seed(body.value)
|
||||
else:
|
||||
import torch
|
||||
torch.manual_seed(body.value)
|
||||
vars.seed = body.value
|
||||
return {}
|
||||
|
||||
config_endpoint_schemas: List[Type[KoboldSchema]] = []
|
||||
|
||||
def config_endpoint_schema(c: Type[KoboldSchema]):
|
||||
|
@ -10087,6 +10141,16 @@ class SamplerOrderSettingSchema(KoboldSchema):
|
|||
name = "sampler order"
|
||||
example_yaml_value = "[6, 0, 1, 2, 3, 4, 5]"
|
||||
|
||||
@config_endpoint_schema
|
||||
class SamplerFullDeterminismSettingSchema(KoboldSchema):
|
||||
value = fields.Boolean(required=True)
|
||||
class KoboldMeta:
|
||||
route_name = "sampler_full_determinism"
|
||||
obj = "vars"
|
||||
var_name = "full_determinism"
|
||||
name = "sampler full determinism"
|
||||
example_yaml_value = "false"
|
||||
|
||||
|
||||
for schema in config_endpoint_schemas:
|
||||
create_config_endpoint(schema=schema.__name__, method="GET")
|
||||
|
|
Loading…
Reference in New Issue