mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Move pickle stuff into modeling/pickling.py
Ended up not moving this to utils.py because most of the stuff in there isn't really model-related, and it feels messy to just throw whatever in there. Originally the new file was named "modeling/utils.py" and was going to be a place for assorted model-related functions, but I think this is better.
This commit is contained in:
@@ -51,15 +51,12 @@ import time
|
||||
import zipfile
|
||||
import pickle
|
||||
import torch
|
||||
import numpy as np
|
||||
import collections
|
||||
import _codecs
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type
|
||||
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
from torch.storage import UntypedStorage
|
||||
from modeling.pickling import RestrictedUnpickler, use_custom_unpickler
|
||||
from modeling.patches import LazyloadPatches
|
||||
|
||||
# Safetensors is a dependency for the local version, TPU/Colab doesn't
|
||||
@@ -236,84 +233,6 @@ class SafetensorsLazyTensor(LazyTensor):
|
||||
self.checkpoint_file, tensor_key=self.key, device=self.location
|
||||
)
|
||||
|
||||
def _patched_rebuild_from_type_v2(func, new_type, args, state):
    """Rebuild a pickled tensor-like object without forcing `LazyTensor`s
    into `torch.Tensor`s.

    Patched stand-in for torch._tensor._rebuild_from_type_v2: `func(*args)`
    produces the object, it is cast to `new_type` when needed, and `state`
    is then applied to it. The rebuilt object is returned.
    """
    rebuilt = func(*args)

    # BEGIN PATCH
    # A LazyTensor that is supposed to become a plain Tensor is left as-is;
    # every other type mismatch is resolved with as_subclass().
    if type(rebuilt) is not new_type:
        lazy_to_tensor = isinstance(rebuilt, LazyTensor) and new_type == Tensor
        if not lazy_to_tensor:
            # END PATCH
            rebuilt = rebuilt.as_subclass(new_type)

    # Tensor defines __setstate__ but not __getstate__, so __setstate__ is
    # only called when the class provides one that is NOT Tensor's own.
    setstate = getattr(rebuilt.__class__, "__setstate__", Tensor.__setstate__)
    if setstate is Tensor.__setstate__:
        return torch._utils._set_obj_state(rebuilt, state)

    rebuilt.__setstate__(state)
    return rebuilt
|
||||
|
||||
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that resolves only an explicit allowlist of globals.

    Any global outside the allowlist in `find_class` raises
    `pickle.UnpicklingError`, and persistent ids are only accepted when
    their first element is the string "storage".
    """

    def original_persistent_load(self, saved_id):
        """Default persistent-id loader: defer to the parent class.

        `load` replaces this attribute on the instance with whatever
        `persistent_load` was installed before loading started.
        """
        return super().persistent_load(saved_id)

    def forced_persistent_load(self, saved_id):
        """Validate a persistent id, then hand it to the original loader.

        Raises:
            pickle.UnpicklingError: if `saved_id[0]` is not "storage", or
                if `saved_id` cannot be indexed at all (empty or not a
                sequence), so malformed ids are rejected the same way.
        """
        try:
            tag = saved_id[0]
        except (TypeError, IndexError, KeyError):
            tag = None
        if tag != "storage":
            raise pickle.UnpicklingError("`saved_id[0]` must be 'storage'")
        return self.original_persistent_load(saved_id)

    def find_class(self, module, name):
        """Return the allowlisted object for `module`.`name`.

        Raises:
            pickle.UnpicklingError: for every global not listed below.
        """
        if module == "collections" and name == "OrderedDict":
            return collections.OrderedDict
        elif module == "torch._utils" and name in (
            "_rebuild_tensor_v2",
            "_rebuild_meta_tensor_no_storage",
        ):
            return getattr(torch._utils, name)
        elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
            # Substitute the patched rebuild function for torch's own.
            return _patched_rebuild_from_type_v2
        elif module == "torch" and name in (
            "DoubleStorage",
            "FloatStorage",
            "HalfStorage",
            "LongStorage",
            "IntStorage",
            "ShortStorage",
            "CharStorage",
            "ByteStorage",
            "BoolStorage",
            "BFloat16Storage",
            "Tensor",
            "float16",
        ):
            return getattr(torch, name)
        elif module == "numpy.core.multiarray" and name == "scalar":
            return np.core.multiarray.scalar
        elif module == "numpy" and name == "dtype":
            return np.dtype
        elif module == "_codecs" and name == "encode":
            return _codecs.encode
        else:
            # Forbid everything else.
            # The builtins module appears as "__builtin__" or "builtins"
            # depending on the pickle protocol; drop the prefix for both so
            # the message names the builtin the same way either way.
            is_builtin = module in ("__builtin__", "builtins")
            qualified_name = name if is_builtin else f"{module}.{name}"
            raise pickle.UnpicklingError(
                f"`{qualified_name}` is forbidden; the model you are loading probably contains malicious code. If you think this is incorrect ask the developer to unban the ability for {module} to execute {name}"
            )

    def load(self, *args, **kwargs):
        """Load the pickle with persistent ids routed through validation.

        The currently installed `persistent_load` (or the base class
        attribute when none is set) is stored as `original_persistent_load`,
        then `forced_persistent_load` is installed in its place before
        delegating to `pickle.Unpickler.load`.
        """
        self.original_persistent_load = getattr(
            self, "persistent_load", pickle.Unpickler.persistent_load
        )
        self.persistent_load = self.forced_persistent_load
        return super().load(*args, **kwargs)
|
||||
|
||||
|
||||
class _LazyUnpickler(RestrictedUnpickler):
|
||||
lazy_loaded_storages: Dict[str, LazyTensor]
|
||||
@@ -412,25 +331,6 @@ def patch_safetensors(callback):
|
||||
safetensors.torch.load_file = safetensors_load
|
||||
|
||||
|
||||
@contextlib.contextmanager
def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
    """Temporarily replace `pickle.Unpickler` and `pickle.load`.

    Inside the `with` body, `pickle.Unpickler` is `unpickler` and
    `pickle.load` builds an instance of whatever `pickle.Unpickler` currently
    is and calls its `load()`. Both attributes are restored on exit, whether
    or not the body raised.
    """
    saved_unpickler, saved_load = pickle.Unpickler, pickle.load

    def new_pickle_load(*args, **kwargs):
        # Look up pickle.Unpickler at call time, not definition time, so the
        # class installed at the moment of the call is the one that is used.
        return pickle.Unpickler(*args, **kwargs).load()

    pickle.Unpickler = unpickler
    pickle.load = new_pickle_load
    try:
        yield
    finally:
        pickle.Unpickler, pickle.load = saved_unpickler, saved_load
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
Reference in New Issue
Block a user