mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Move pickle stuff into modeling/pickling.py
Ended up not moving to utils.py because most of the stuff in there isn't really model related, and it feels messy to just throw whatever in there. Originally the file was named "modeling/utils.py" and was going to be a place for assorted model-related functions, but I think this is better.
This commit is contained in:
74
aiserver.py
74
aiserver.py
@@ -71,6 +71,8 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForToken
|
|||||||
import transformers
|
import transformers
|
||||||
import ipaddress
|
import ipaddress
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from modeling.inference_models.utils import RestrictedUnpickler, use_custom_unpickler
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from transformers.models.opt.modeling_opt import OPTDecoder
|
from transformers.models.opt.modeling_opt import OPTDecoder
|
||||||
except:
|
except:
|
||||||
@@ -1678,79 +1680,7 @@ def unload_model():
|
|||||||
#Reload our badwords
|
#Reload our badwords
|
||||||
koboldai_vars.badwordsids = koboldai_settings.badwordsids_default
|
koboldai_vars.badwordsids = koboldai_settings.badwordsids_default
|
||||||
|
|
||||||
class RestrictedUnpickler(pickle.Unpickler):
|
|
||||||
def original_persistent_load(self, saved_id):
|
|
||||||
return super().persistent_load(saved_id)
|
|
||||||
|
|
||||||
def forced_persistent_load(self, saved_id):
|
|
||||||
if saved_id[0] != "storage":
|
|
||||||
raise pickle.UnpicklingError("`saved_id[0]` must be 'storage'")
|
|
||||||
return self.original_persistent_load(saved_id)
|
|
||||||
|
|
||||||
def find_class(self, module, name):
|
|
||||||
if module == "collections" and name == "OrderedDict":
|
|
||||||
return collections.OrderedDict
|
|
||||||
elif module == "torch._utils" and name in (
|
|
||||||
"_rebuild_tensor_v2",
|
|
||||||
"_rebuild_meta_tensor_no_storage",
|
|
||||||
):
|
|
||||||
return getattr(torch._utils, name)
|
|
||||||
elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
|
|
||||||
return torch._tensor._rebuild_from_type_v2
|
|
||||||
elif module == "torch" and name in (
|
|
||||||
"DoubleStorage",
|
|
||||||
"FloatStorage",
|
|
||||||
"HalfStorage",
|
|
||||||
"LongStorage",
|
|
||||||
"IntStorage",
|
|
||||||
"ShortStorage",
|
|
||||||
"CharStorage",
|
|
||||||
"ByteStorage",
|
|
||||||
"BoolStorage",
|
|
||||||
"BFloat16Storage",
|
|
||||||
"Tensor",
|
|
||||||
"float16",
|
|
||||||
):
|
|
||||||
return getattr(torch, name)
|
|
||||||
elif module == "numpy.core.multiarray" and name == "scalar":
|
|
||||||
return np.core.multiarray.scalar
|
|
||||||
elif module == "numpy" and name == "dtype":
|
|
||||||
return np.dtype
|
|
||||||
elif module == "_codecs" and name == "encode":
|
|
||||||
return _codecs.encode
|
|
||||||
else:
|
|
||||||
# Forbid everything else.
|
|
||||||
qualified_name = name if module == "__builtin__" else f"{module}.{name}"
|
|
||||||
raise pickle.UnpicklingError(
|
|
||||||
f"`{qualified_name}` is forbidden; the model you are loading probably contains malicious code. If you think this is incorrect ask the developer to unban the ability for {module} to execute {name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def load(self, *args, **kwargs):
|
|
||||||
self.original_persistent_load = getattr(
|
|
||||||
self, "persistent_load", pickle.Unpickler.persistent_load
|
|
||||||
)
|
|
||||||
self.persistent_load = self.forced_persistent_load
|
|
||||||
return super().load(*args, **kwargs)
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
|
|
||||||
try:
|
|
||||||
old_unpickler = pickle.Unpickler
|
|
||||||
pickle.Unpickler = unpickler
|
|
||||||
|
|
||||||
old_pickle_load = pickle.load
|
|
||||||
|
|
||||||
def new_pickle_load(*args, **kwargs):
|
|
||||||
return pickle.Unpickler(*args, **kwargs).load()
|
|
||||||
|
|
||||||
pickle.load = new_pickle_load
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
finally:
|
|
||||||
pickle.Unpickler = old_unpickler
|
|
||||||
pickle.load = old_pickle_load
|
|
||||||
|
|
||||||
def load_model(model_backend, initial_load=False):
|
def load_model(model_backend, initial_load=False):
|
||||||
global model
|
global model
|
||||||
global tokenizer
|
global tokenizer
|
||||||
|
@@ -51,15 +51,12 @@ import time
|
|||||||
import zipfile
|
import zipfile
|
||||||
import pickle
|
import pickle
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
|
||||||
import collections
|
|
||||||
import _codecs
|
|
||||||
import os
|
import os
|
||||||
from typing import Any, Callable, Dict, Optional, Tuple, Type
|
from typing import Any, Callable, Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
from torch import Tensor
|
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from torch.storage import UntypedStorage
|
from torch.storage import UntypedStorage
|
||||||
|
from modeling.pickling import RestrictedUnpickler, use_custom_unpickler
|
||||||
from modeling.patches import LazyloadPatches
|
from modeling.patches import LazyloadPatches
|
||||||
|
|
||||||
# Safetensors is a dependency for the local version, TPU/Colab doesn't
|
# Safetensors is a dependency for the local version, TPU/Colab doesn't
|
||||||
@@ -236,84 +233,6 @@ class SafetensorsLazyTensor(LazyTensor):
|
|||||||
self.checkpoint_file, tensor_key=self.key, device=self.location
|
self.checkpoint_file, tensor_key=self.key, device=self.location
|
||||||
)
|
)
|
||||||
|
|
||||||
def _patched_rebuild_from_type_v2(func, new_type, args, state):
|
|
||||||
"""A patched version of torch._tensor._rebuild_from_type_v2 that
|
|
||||||
does not attempt to convert `LazyTensor`s to `torch.Tensor`s."""
|
|
||||||
|
|
||||||
ret = func(*args)
|
|
||||||
|
|
||||||
# BEGIN PATCH
|
|
||||||
transformation_ok = isinstance(ret, LazyTensor) and new_type == Tensor
|
|
||||||
if type(ret) is not new_type and not transformation_ok:
|
|
||||||
# END PATCH
|
|
||||||
ret = ret.as_subclass(new_type)
|
|
||||||
|
|
||||||
# Tensor does define __setstate__ even though it doesn't define
|
|
||||||
# __getstate__. So only use __setstate__ if it is NOT the one defined
|
|
||||||
# on Tensor
|
|
||||||
if (
|
|
||||||
getattr(ret.__class__, "__setstate__", Tensor.__setstate__)
|
|
||||||
is not Tensor.__setstate__
|
|
||||||
):
|
|
||||||
ret.__setstate__(state)
|
|
||||||
else:
|
|
||||||
ret = torch._utils._set_obj_state(ret, state)
|
|
||||||
return ret
|
|
||||||
|
|
||||||
class RestrictedUnpickler(pickle.Unpickler):
|
|
||||||
def original_persistent_load(self, saved_id):
|
|
||||||
return super().persistent_load(saved_id)
|
|
||||||
|
|
||||||
def forced_persistent_load(self, saved_id):
|
|
||||||
if saved_id[0] != "storage":
|
|
||||||
raise pickle.UnpicklingError("`saved_id[0]` must be 'storage'")
|
|
||||||
return self.original_persistent_load(saved_id)
|
|
||||||
|
|
||||||
def find_class(self, module, name):
|
|
||||||
if module == "collections" and name == "OrderedDict":
|
|
||||||
return collections.OrderedDict
|
|
||||||
elif module == "torch._utils" and name in (
|
|
||||||
"_rebuild_tensor_v2",
|
|
||||||
"_rebuild_meta_tensor_no_storage",
|
|
||||||
):
|
|
||||||
return getattr(torch._utils, name)
|
|
||||||
elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
|
|
||||||
return _patched_rebuild_from_type_v2
|
|
||||||
elif module == "torch" and name in (
|
|
||||||
"DoubleStorage",
|
|
||||||
"FloatStorage",
|
|
||||||
"HalfStorage",
|
|
||||||
"LongStorage",
|
|
||||||
"IntStorage",
|
|
||||||
"ShortStorage",
|
|
||||||
"CharStorage",
|
|
||||||
"ByteStorage",
|
|
||||||
"BoolStorage",
|
|
||||||
"BFloat16Storage",
|
|
||||||
"Tensor",
|
|
||||||
"float16",
|
|
||||||
):
|
|
||||||
return getattr(torch, name)
|
|
||||||
elif module == "numpy.core.multiarray" and name == "scalar":
|
|
||||||
return np.core.multiarray.scalar
|
|
||||||
elif module == "numpy" and name == "dtype":
|
|
||||||
return np.dtype
|
|
||||||
elif module == "_codecs" and name == "encode":
|
|
||||||
return _codecs.encode
|
|
||||||
else:
|
|
||||||
# Forbid everything else.
|
|
||||||
qualified_name = name if module == "__builtin__" else f"{module}.{name}"
|
|
||||||
raise pickle.UnpicklingError(
|
|
||||||
f"`{qualified_name}` is forbidden; the model you are loading probably contains malicious code. If you think this is incorrect ask the developer to unban the ability for {module} to execute {name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def load(self, *args, **kwargs):
|
|
||||||
self.original_persistent_load = getattr(
|
|
||||||
self, "persistent_load", pickle.Unpickler.persistent_load
|
|
||||||
)
|
|
||||||
self.persistent_load = self.forced_persistent_load
|
|
||||||
return super().load(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class _LazyUnpickler(RestrictedUnpickler):
|
class _LazyUnpickler(RestrictedUnpickler):
|
||||||
lazy_loaded_storages: Dict[str, LazyTensor]
|
lazy_loaded_storages: Dict[str, LazyTensor]
|
||||||
@@ -412,25 +331,6 @@ def patch_safetensors(callback):
|
|||||||
safetensors.torch.load_file = safetensors_load
|
safetensors.torch.load_file = safetensors_load
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
|
|
||||||
try:
|
|
||||||
old_unpickler = pickle.Unpickler
|
|
||||||
pickle.Unpickler = unpickler
|
|
||||||
|
|
||||||
old_pickle_load = pickle.load
|
|
||||||
|
|
||||||
def new_pickle_load(*args, **kwargs):
|
|
||||||
return pickle.Unpickler(*args, **kwargs).load()
|
|
||||||
|
|
||||||
pickle.load = new_pickle_load
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
finally:
|
|
||||||
pickle.Unpickler = old_unpickler
|
|
||||||
pickle.load = old_pickle_load
|
|
||||||
|
|
||||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
||||||
for hook in self._load_state_dict_pre_hooks.values():
|
for hook in self._load_state_dict_pre_hooks.values():
|
||||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||||
|
111
modeling/pickling.py
Normal file
111
modeling/pickling.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import contextlib
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import _codecs
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
import modeling
|
||||||
|
|
||||||
|
|
||||||
|
def _patched_rebuild_from_type_v2(func, new_type, args, state):
    """A patched version of torch._tensor._rebuild_from_type_v2 that
    does not attempt to convert `LazyTensor`s to `torch.Tensor`s.

    Registered by RestrictedUnpickler.find_class in place of the real
    torch._tensor._rebuild_from_type_v2, so lazily-loaded tensors survive
    unpickling without being materialized.

    Args:
        func: rebuild callable stored in the pickle (e.g. _rebuild_tensor_v2).
        new_type: target tensor subclass recorded in the checkpoint.
        args: positional arguments for `func`.
        state: pickled object state to restore onto the rebuilt tensor.
    """

    ret = func(*args)

    # BEGIN PATCH
    # Skip the subclass cast when a LazyTensor is being rebuilt as a plain
    # torch.Tensor — that "transformation" must be deferred, not performed now.
    transformation_ok = isinstance(ret, modeling.lazy_loader.LazyTensor) and new_type == Tensor
    if type(ret) is not new_type and not transformation_ok:
        # END PATCH
        ret = ret.as_subclass(new_type)

    # Tensor does define __setstate__ even though it doesn't define
    # __getstate__. So only use __setstate__ if it is NOT the one defined
    # on Tensor
    # NOTE(review): structure mirrors upstream torch._tensor._rebuild_from_type_v2;
    # keep in sync with the torch version this project pins.
    if (
        getattr(ret.__class__, "__setstate__", Tensor.__setstate__)
        is not Tensor.__setstate__
    ):
        ret.__setstate__(state)
    else:
        ret = torch._utils._set_obj_state(ret, state)
    return ret
|
||||||
|
|
||||||
|
|
||||||
|
class RestrictedUnpickler(pickle.Unpickler):
    """An unpickler restricted to the handful of globals that legitimate
    torch model checkpoints need.

    Any other global requested by the pickle stream is treated as hostile
    and aborts loading with `pickle.UnpicklingError`.
    """

    def original_persistent_load(self, saved_id):
        # Delegate to whatever persistent_load the base class provides.
        return super().persistent_load(saved_id)

    def forced_persistent_load(self, saved_id):
        # Torch checkpoints only ever emit "storage" persistent ids; anything
        # else is not a tensor storage and is rejected outright.
        if saved_id[0] != "storage":
            raise pickle.UnpicklingError("`saved_id[0]` must be 'storage'")
        return self.original_persistent_load(saved_id)

    def find_class(self, module, name):
        # Storage/tensor names allowed from the top-level `torch` module.
        permitted_torch_names = {
            "DoubleStorage",
            "FloatStorage",
            "HalfStorage",
            "LongStorage",
            "IntStorage",
            "ShortStorage",
            "CharStorage",
            "ByteStorage",
            "BoolStorage",
            "BFloat16Storage",
            "Tensor",
            "float16",
        }

        if module == "collections" and name == "OrderedDict":
            return collections.OrderedDict
        if module == "torch._utils" and name in (
            "_rebuild_tensor_v2",
            "_rebuild_meta_tensor_no_storage",
        ):
            return getattr(torch._utils, name)
        if module == "torch._tensor" and name == "_rebuild_from_type_v2":
            # Substitute our lazy-load-aware rebuild helper for the torch one.
            return _patched_rebuild_from_type_v2
        if module == "torch" and name in permitted_torch_names:
            return getattr(torch, name)
        if module == "numpy.core.multiarray" and name == "scalar":
            return np.core.multiarray.scalar
        if module == "numpy" and name == "dtype":
            return np.dtype
        if module == "_codecs" and name == "encode":
            return _codecs.encode

        # Forbid everything else.
        qualified_name = name if module == "__builtin__" else f"{module}.{name}"
        raise pickle.UnpicklingError(
            f"`{qualified_name}` is forbidden; the model you are loading probably contains malicious code. If you think this is incorrect ask the developer to unban the ability for {module} to execute {name}"
        )

    def load(self, *args, **kwargs):
        # Remember the stock persistent_load, then force every persistent id
        # through the validating hook for the duration of this load.
        self.original_persistent_load = getattr(
            self, "persistent_load", pickle.Unpickler.persistent_load
        )
        self.persistent_load = self.forced_persistent_load
        return super().load(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
def use_custom_unpickler(unpickler: type[pickle.Unpickler] = RestrictedUnpickler):
    """Context manager that globally swaps in `unpickler` for the duration
    of the `with` block.

    Both `pickle.Unpickler` and `pickle.load` are monkey-patched so that
    third-party code (e.g. torch.load) deserializes through the restricted
    unpickler; the originals are restored on exit even if the body raises.

    Args:
        unpickler: the Unpickler *class* (not instance) to install.
            Fixed annotation: the previous `pickle.Unpickler` annotation
            described an instance, but the default value and the
            `pickle.Unpickler = unpickler` assignment require a class.
    """
    try:
        old_unpickler = pickle.Unpickler
        pickle.Unpickler = unpickler

        old_pickle_load = pickle.load

        def new_pickle_load(*args, **kwargs):
            # Re-read pickle.Unpickler at call time so nested overrides win.
            return pickle.Unpickler(*args, **kwargs).load()

        pickle.load = new_pickle_load
        yield

    finally:
        # Always restore the stdlib originals, even on error.
        pickle.Unpickler = old_unpickler
        pickle.load = old_pickle_load
|
Reference in New Issue
Block a user