Compare commits

...

10 Commits

Author SHA1 Message Date
henk717 59e3a40496
Merge pull request #165 from henk717/united
1.19.1
2022-10-12 15:35:09 +02:00
Henk 64715b18d6 Version bump 2022-10-12 14:54:11 +02:00
Henk d5143eeb80 LUA Error as Error 2022-10-12 01:23:00 +02:00
henk717 739cf0aae7
Merge pull request #227 from VE-FORBRYDERNE/pickle
Custom unpickler to avoid pickle's arbitrary code execution vulnerability
2022-10-07 02:12:53 +02:00
vfbd 323f593a96 Custom unpickler to avoid pickle's arbitrary code execution vulnerability 2022-10-06 20:08:08 -04:00
henk717 b85d74f22c
Merge branch 'KoboldAI:main' into united 2022-10-05 19:51:29 +02:00
henk717 9f18811ff9
Merge pull request #226 from VE-FORBRYDERNE/api-settings
Allow changing and reading sampler seed and sampler order from API
2022-10-04 20:30:25 +02:00
vfbd bdfa6d86b7 Seed has to be a 64-bit unsigned int or PyTorch will throw an error
tpu_mtj_backend's seed can be an integer of arbitrary size but we will limit it to a 64-bit unsigned integer anyways for consistency.
2022-10-02 17:50:32 -04:00
vfbd dd1c25241d Allow sampler seed and full determinism to be read/written in /config 2022-10-02 17:43:54 -04:00
vfbd 1a59a4acea Allow changing sampler seed and sampler order from API 2022-10-02 16:25:51 -04:00
3 changed files with 203 additions and 16 deletions

aiserver.py

@@ -1,7 +1,7 @@
 #!/usr/bin/python3
 #==================================================================#
 # KoboldAI
-# Version: 1.19.0
+# Version: 1.19.1
 # By: The KoboldAI Community
 #==================================================================#
@@ -377,6 +377,7 @@ class vars:
     comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)')  # Pattern for matching comments to remove them before sending them to the AI
     comregex_ui = re.compile(r'(&lt;\|(?:.|\n)*?\|&gt;)')  # Pattern for matching comments in the editor
     sampler_order = utils.default_sampler_order.copy()
+    rng_states = {}  # Used by the POST /generate endpoint to store sampler RNG states
     chatmode = False
     chatname = "You"
     adventure = False
@@ -630,7 +631,7 @@ tags = [
 api_version = None # This gets set automatically so don't change this value

 api_v1 = KoboldAPISpec(
-    version="1.1.4",
+    version="1.2.0",
     prefixes=["/api/v1", "/api/latest"],
     tags=tags,
 )
@@ -2963,7 +2964,7 @@ def load_lua_scripts():
         if(vars.serverstarted):
             emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True)
             sendUSStatItems()
-        logger.debug('LUA ERROR: ' + str(e).replace("\033", ""))
+        logger.error('LUA ERROR: ' + str(e).replace("\033", ""))
         logger.warning("Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.")
         if(vars.serverstarted):
             set_aibusy(0)
@@ -7450,6 +7451,13 @@ def story_load_validator(name: str):
         raise ValidationError("Must be a valid story name.")
     return True

+def permutation_validator(lst: list):
+    if any(not isinstance(e, int) for e in lst):
+        return
+    if min(lst) != 0 or max(lst) != len(lst) - 1 or len(set(lst)) != len(lst):
+        raise ValidationError("Must be a permutation of the first N non-negative integers, where N is the length of this array")
+    return True
+
 class GenerationInputSchema(SamplerSettingsSchema):
     prompt: str = fields.String(required=True, metadata={"description": "This is the submission."})
     use_memory: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the memory from the KoboldAI GUI when generating text."})
@@ -7469,6 +7477,9 @@ class GenerationInputSchema(SamplerSettingsSchema):
     disable_input_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, all input formatting options default to `false` instead of the value in the KoboldAI GUI"})
     frmtadsnsp: Optional[bool] = fields.Boolean(metadata={"description": "Input formatting option. When enabled, adds a leading space to your input if there is no trailing whitespace at the end of the previous action.\n\nIf `disable_input_formatting` is `true`, this defaults to `false` instead of the value in the KoboldAI GUI."})
     quiet: Optional[bool] = fields.Boolean(metadata={"description": "When enabled, Generated output will not be displayed in the console."})
+    sampler_order: Optional[List[int]] = fields.List(fields.Integer(), validate=[validate.Length(min=6), permutation_validator], metadata={"description": "Sampler order to be used. If N is the length of this array, then N must be greater than or equal to 6 and the array must be a permutation of the first N non-negative integers."})
+    sampler_seed: Optional[int] = fields.Integer(validate=validate.Range(min=0, max=2**64 - 1), metadata={"description": "RNG seed to use for sampling. If not specified, the global RNG will be used."})
+    sampler_full_determinism: Optional[bool] = fields.Boolean(metadata={"description": "If enabled, the generated text will always be the same as long as you use the same RNG seed, input and settings. If disabled, only the *sequence* of generated texts that you get when repeatedly generating text will be the same given the same RNG seed, input and settings."})

 class GenerationResultSchema(KoboldSchema):
     text: str = fields.String(required=True, metadata={"description": "Generated output as plain text."})
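Together, these three fields extend the body of the existing POST /api/v1/generate endpoint. A minimal client sketch exercising them; the URL, port, and prompt are assumptions for a local install, not part of the diff:

import requests

resp = requests.post(
    "http://127.0.0.1:5000/api/v1/generate",     # assumed local KoboldAI instance
    json={
        "prompt": "Once upon a time",
        "sampler_seed": 1234,                    # unsigned 64-bit seed for reproducible sampling
        "sampler_order": [6, 0, 1, 2, 3, 4, 5],  # permutation of the first N sampler ids, N >= 6
        "sampler_full_determinism": True,        # same seed + input + settings => same text
    },
)
print(resp.json())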
@@ -7559,6 +7570,29 @@ def _generate_text(body: GenerationInputSchema):
             "msg": "Server is busy; please try again later.",
             "type": "service_unavailable",
         }}), mimetype="application/json", status=503))
+    if vars.use_colab_tpu:
+        import tpu_mtj_backend
+    if hasattr(body, "sampler_seed"):
+        # If a seed was specified, we need to save the global RNG state so we
+        # can restore it later
+        old_seed = vars.seed
+        old_rng_state = tpu_mtj_backend.get_rng_state() if vars.use_colab_tpu else torch.get_rng_state()
+        vars.seed = body.sampler_seed
+        # We should try to use a previously saved RNG state with the same seed
+        if body.sampler_seed in vars.rng_states:
+            if vars.use_colab_tpu:
+                tpu_mtj_backend.set_rng_state(vars.rng_states[body.sampler_seed])
+            else:
+                torch.set_rng_state(vars.rng_states[body.sampler_seed])
+        else:
+            if vars.use_colab_tpu:
+                tpu_mtj_backend.set_rng_state(tpu_mtj_backend.new_rng_state(body.sampler_seed))
+            else:
+                torch.manual_seed(body.sampler_seed)
+            vars.rng_states[body.sampler_seed] = tpu_mtj_backend.get_rng_state() if vars.use_colab_tpu else torch.get_rng_state()
+    if hasattr(body, "sampler_order"):
+        if len(body.sampler_order) < 7:
+            body.sampler_order = [6] + body.sampler_order
     # This maps each property of the setting to use when sending the generate idempotently
     # To the object which typically contains it's value
     # This allows to set the property only for the API generation, and then revert the setting
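The block above is a save/override/restore dance around the global RNG: remember the current state, switch to a per-seed state (creating and caching it on first use), then put everything back once generation is done (see the restore hunk further down). A condensed, torch-only sketch of the same bookkeeping; the helper and variable names here are illustrative, and the TPU branch would use the tpu_mtj_backend equivalents instead:

import torch

rng_states = {}  # mirrors vars.rng_states: seed -> RNG state captured when that seed was first used

def generate_with_seed(seed: int, n: int = 3) -> torch.Tensor:
    old_state = torch.get_rng_state()            # save the global RNG so API calls don't disturb it
    if seed in rng_states:
        torch.set_rng_state(rng_states[seed])    # reuse the state recorded for this seed
    else:
        torch.manual_seed(seed)                  # first time this seed is seen
        rng_states[seed] = torch.get_rng_state()
    try:
        return torch.rand(n)                     # stand-in for the actual generation step
    finally:
        torch.set_rng_state(old_state)           # restore the global RNG afterwards

print(generate_with_seed(1234))
print(generate_with_seed(1234))  # starts from the same recorded state, so the output repeats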
@@ -7584,6 +7618,8 @@ def _generate_text(body: GenerationInputSchema):
         "max_context_length": ("vars", "max_length", None),
         "n": ("vars", "numseqs", None),
         "quiet": ("vars", "quiet", None),
+        "sampler_order": ("vars", "sampler_order", None),
+        "sampler_full_determinism": ("vars", "full_determinism", None),
     }
     saved_settings = {}
     set_aibusy(1)
@@ -7633,6 +7669,12 @@ def _generate_text(body: GenerationInputSchema):
     vars.output_streaming = output_streaming
     if vars.allowsp and getattr(body, "soft_prompt", None) is not None:
         spRequest(old_spfilename)
+    if hasattr(body, "sampler_seed"):
+        vars.seed = old_seed
+        if vars.use_colab_tpu:
+            tpu_mtj_backend.set_rng_state(old_rng_state)
+        else:
+            torch.set_rng_state(old_rng_state)
     set_aibusy(0)
     return output
@@ -9838,6 +9880,60 @@ def put_config_soft_prompt(body: SoftPromptSettingSchema):
     settingschanged()
     return {}

+class SamplerSeedSettingSchema(KoboldSchema):
+    value: int = fields.Integer(validate=validate.Range(min=0, max=2**64 - 1), required=True)
+
+@api_v1.get("/config/sampler_seed")
+@api_schema_wrap
+def get_config_sampler_seed():
+    """---
+    get:
+      summary: Retrieve the current global sampler seed value
+      tags:
+        - config
+      responses:
+        200:
+          description: Successful request
+          content:
+            application/json:
+              schema: SamplerSeedSettingSchema
+              example:
+                value: 3475097509890965500
+    """
+    return {"value": __import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()}
+
+@api_v1.put("/config/sampler_seed")
+@api_schema_wrap
+def put_config_sampler_seed(body: SamplerSeedSettingSchema):
+    """---
+    put:
+      summary: Set the global sampler seed value
+      tags:
+        - config
+      requestBody:
+        required: true
+        content:
+          application/json:
+            schema: SamplerSeedSettingSchema
+            example:
+              value: 3475097509890965500
+      responses:
+        200:
+          description: Successful request
+          content:
+            application/json:
+              schema: EmptySchema
+      {api_validation_error_response}
+    """
+    if vars.use_colab_tpu:
+        import tpu_mtj_backend
+        tpu_mtj_backend.set_rng_seed(body.value)
+    else:
+        import torch
+        torch.manual_seed(body.value)
+    vars.seed = body.value
+    return {}
+
 config_endpoint_schemas: List[Type[KoboldSchema]] = []

 def config_endpoint_schema(c: Type[KoboldSchema]):
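These two endpoints make the global sampler seed readable and writable over HTTP. A short usage sketch; the local URL is an assumption and the returned value is just the example from the spec above:

import requests

base = "http://127.0.0.1:5000/api/v1"  # assumed local KoboldAI instance

# Read the current global sampler seed
print(requests.get(f"{base}/config/sampler_seed").json())   # e.g. {"value": 3475097509890965500}

# Pin the seed; the value must fit in an unsigned 64-bit integer
requests.put(f"{base}/config/sampler_seed", json={"value": 1234})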
@@ -10035,6 +10131,25 @@ class AddSentenceSpacingSettingsSchema(KoboldSchema):
         name = "add sentence spacing (input formatting)"
         example_yaml_value = "false"

+@config_endpoint_schema
+class SamplerOrderSettingSchema(KoboldSchema):
+    value = fields.List(fields.Integer(), validate=[validate.Length(min=6), permutation_validator], required=True)
+    class KoboldMeta:
+        route_name = "sampler_order"
+        obj = "vars"
+        var_name = "sampler_order"
+        name = "sampler order"
+        example_yaml_value = "[6, 0, 1, 2, 3, 4, 5]"
+
+@config_endpoint_schema
+class SamplerFullDeterminismSettingSchema(KoboldSchema):
+    value = fields.Boolean(required=True)
+    class KoboldMeta:
+        route_name = "sampler_full_determinism"
+        obj = "vars"
+        var_name = "full_determinism"
+        name = "sampler full determinism"
+        example_yaml_value = "false"
+
 for schema in config_endpoint_schemas:
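Assuming the config_endpoint_schema decorator wires these up as GET/PUT routes under /config/<route_name>, just as it does for the other config schemas, the new sampler settings become reachable like this (local URL assumed):

import requests

base = "http://127.0.0.1:5000/api/v1"  # assumed local KoboldAI instance

# Reorder the samplers: a permutation of the first N sampler ids, N >= 6
requests.put(f"{base}/config/sampler_order", json={"value": [6, 0, 1, 2, 3, 4, 5]})

# Turn full determinism on, then read it back
requests.put(f"{base}/config/sampler_full_determinism", json={"value": True})
print(requests.get(f"{base}/config/sampler_full_determinism").json())  # {"value": true}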

torch_lazy_loader.py

@@ -50,9 +50,12 @@ import itertools
 import zipfile
 import pickle
 import torch
+import numpy as np
+import collections
+import _codecs
 import utils
 from torch.nn import Module
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

 _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
@@ -111,8 +114,50 @@ class LazyTensor:
         tensor._backward_hooks = self.backward_hooks
         return tensor

-class _LazyUnpickler(pickle.Unpickler):
+class RestrictedUnpickler(pickle.Unpickler):
+    def original_persistent_load(self, saved_id):
+        return super().persistent_load(saved_id)
+
+    def forced_persistent_load(self, saved_id):
+        if saved_id[0] != "storage":
+            raise pickle.UnpicklingError("`saved_id[0]` must be 'storage'")
+        return self.original_persistent_load(saved_id)
+
+    def find_class(self, module, name):
+        if module == "collections" and name == "OrderedDict":
+            return collections.OrderedDict
+        elif module == "torch._utils" and name == "_rebuild_tensor_v2":
+            return torch._utils._rebuild_tensor_v2
+        elif module == "torch" and name in (
+            "DoubleStorage",
+            "FloatStorage",
+            "HalfStorage",
+            "LongStorage",
+            "IntStorage",
+            "ShortStorage",
+            "CharStorage",
+            "ByteStorage",
+            "BoolStorage",
+            "BFloat16Storage",
+        ):
+            return getattr(torch, name)
+        elif module == "numpy.core.multiarray" and name == "scalar":
+            return np.core.multiarray.scalar
+        elif module == "numpy" and name == "dtype":
+            return np.dtype
+        elif module == "_codecs" and name == "encode":
+            return _codecs.encode
+        else:
+            # Forbid everything else.
+            qualified_name = name if module == "__builtin__" else f"{module}.{name}"
+            raise pickle.UnpicklingError(f"`{qualified_name}` is forbidden; the model you are loading probably contains malicious code")
+
+    def load(self, *args, **kwargs):
+        self.original_persistent_load = getattr(self, "persistent_load", pickle.Unpickler.persistent_load)
+        self.persistent_load = self.forced_persistent_load
+        return super().load(*args, **kwargs)
+
+class _LazyUnpickler(RestrictedUnpickler):
     lazy_loaded_storages: Dict[str, LazyTensor]

     def __init__(self, *args, **kwargs):
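The allowlist in find_class is what closes the arbitrary-code-execution hole: any global outside the short list of tensor-related types raises instead of being imported. A small demonstration of the effect, assuming the class is importable from torch_lazy_loader (this file's module, which the TPU backend hunk further down imports); the Exploit class is purely illustrative:

import io
import os
import pickle

from torch_lazy_loader import RestrictedUnpickler  # the class added above

class Exploit:
    # A classic pickle payload: unpickling it would call os.system(...)
    def __reduce__(self):
        return (os.system, ("echo pwned",))

payload = pickle.dumps(Exploit())

# pickle.loads(payload) would run the command; the allowlist stops it instead.
try:
    RestrictedUnpickler(io.BytesIO(payload)).load()
except pickle.UnpicklingError as e:
    print(e)  # reports the forbidden global and warns about malicious code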
@@ -127,7 +172,6 @@ class _LazyUnpickler(pickle.Unpickler):
         return LazyTensor(storage_type, key, location)

     def load(self, *args, **kwargs):
-        self.persistent_load = self.forced_persistent_load
         retval = super().load(*args, **kwargs)
         self.lazy_loaded_storages = {}
         return retval
@@ -213,16 +257,33 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
                 unexpected_keys.append(key)

+@contextlib.contextmanager
+def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
+    try:
+        old_unpickler = pickle.Unpickler
+        pickle.Unpickler = unpickler
+
+        old_pickle_load = pickle.load
+
+        def new_pickle_load(*args, **kwargs):
+            return pickle.Unpickler(*args, **kwargs).load()
+
+        pickle.load = new_pickle_load
+
+        yield
+
+    finally:
+        pickle.Unpickler = old_unpickler
+        pickle.load = old_pickle_load
+
 @contextlib.contextmanager
 def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False, use_accelerate_init_empty_weights=False):
     if not enable:
-        yield False
+        with use_custom_unpickler(RestrictedUnpickler):
+            yield False
         return

     try:
-        old_unpickler = pickle.Unpickler
-        pickle.Unpickler = _LazyUnpickler
-
         old_rebuild_tensor = torch._utils._rebuild_tensor
         torch._utils._rebuild_tensor = _rebuild_tensor
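use_custom_unpickler temporarily swaps pickle.Unpickler and pickle.load for the restricted variant, so anything that unpickles inside the block, including the unpickling torch.load performs internally, goes through the allowlist. A hedged usage sketch mirroring how the TPU backend diff below wraps its checkpoint loads; the checkpoint path is made up:

import torch
import torch_lazy_loader

# Every pickle.Unpickler / pickle.load call made inside this block uses the
# allowlisting RestrictedUnpickler (the context manager's default argument).
with torch_lazy_loader.use_custom_unpickler():
    state_dict = torch.load("untrusted_checkpoint.pt", map_location="cpu")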
@@ -261,10 +322,10 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
             old_load_from_state_dict = torch.nn.Module._load_from_state_dict
             torch.nn.Module._load_from_state_dict = _load_from_state_dict

-        yield True
+        with use_custom_unpickler(_LazyUnpickler):
+            yield True

     finally:
-        pickle.Unpickler = old_unpickler
         torch._utils._rebuild_tensor = old_rebuild_tensor
         torch.load = old_torch_load
         if dematerialized_modules:

tpu_mtj_backend.py

@@ -55,7 +55,7 @@ from mesh_transformer.util import to_bf16

 params: Dict[str, Any] = {}

-__seed = random.randrange(sys.maxsize)
+__seed = random.randrange(2**64)
 rng = random.Random(__seed)
@@ -69,8 +69,17 @@ def set_rng_seed(seed: int):
     return seed

 def randomize_rng_seed():
-    return set_rng_seed(random.randrange(sys.maxsize))
+    return set_rng_seed(random.randrange(2**64))
+
+def get_rng_state():
+    return rng
+
+def set_rng_state(state):
+    global rng
+    rng = state
+
+def new_rng_state(seed: int):
+    return random.Random(seed)

 def warper_callback(logits) -> np.array:
     raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
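On this backend the sampler "state" is simply the module-level random.Random instance, so saving and restoring state means holding on to that object and reinstalling it later. A toy, standalone illustration of the idea (it does not import the module itself):

import random

rng = random.Random(1234)     # module-level generator, as in tpu_mtj_backend
saved = rng                   # get_rng_state(): the state *is* the Random object
a = rng.randrange(2**64)

rng = random.Random(5678)     # e.g. new_rng_state(5678) installed via set_rng_state()
rng = saved                   # reinstall the saved object...
b = rng.randrange(2**64)      # ...and sampling resumes exactly where `a` left off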
@@ -946,6 +955,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
     import torch
     import torch.utils.dlpack
+    import torch_lazy_loader
     from tqdm.auto import tqdm

     move_xmap = jax.experimental.maps.xmap(
@@ -987,6 +997,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
                 continue
             layer = checkpoint_layer - 2
             shards = []
-            for checkpoint_shard in range(checkpoint_shards):
-                shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
+            with torch_lazy_loader.use_custom_unpickler(torch_lazy_loader.RestrictedUnpickler):
+                for checkpoint_shard in range(checkpoint_shards):
+                    shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
             for key in shards[0]: