Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Ended up not moving this to utils.py because most of the stuff in there isn't really model-related, and it feels messy to just throw whatever in there. Originally the file was named "modeling/utils.py" and was going to be a place for assorted model-related functions, but I think this is better.
from __future__ import annotations

import collections
import contextlib
import pickle

import _codecs
import numpy as np
import torch
from torch import Tensor

import modeling


def _patched_rebuild_from_type_v2(func, new_type, args, state):
    """A patched version of torch._tensor._rebuild_from_type_v2 that
    does not attempt to convert `LazyTensor`s to `torch.Tensor`s."""

    ret = func(*args)

    # BEGIN PATCH
    transformation_ok = isinstance(ret, modeling.lazy_loader.LazyTensor) and new_type == Tensor
    if type(ret) is not new_type and not transformation_ok:
        # END PATCH
        ret = ret.as_subclass(new_type)

        # Tensor does define __setstate__ even though it doesn't define
        # __getstate__. So only use __setstate__ if it is NOT the one defined
        # on Tensor
        if (
            getattr(ret.__class__, "__setstate__", Tensor.__setstate__)
            is not Tensor.__setstate__
        ):
            ret.__setstate__(state)
        else:
            ret = torch._utils._set_obj_state(ret, state)
    return ret


class RestrictedUnpickler(pickle.Unpickler):
    """An unpickler that only resolves the small allowlist of globals needed
    to rebuild PyTorch checkpoints (tensor/storage rebuilders plus a few numpy
    helpers) and refuses to load anything else."""

    def original_persistent_load(self, saved_id):
        return super().persistent_load(saved_id)

    def forced_persistent_load(self, saved_id):
        # Only allow the kind of persistent ids torch uses for tensor storages.
        if saved_id[0] != "storage":
            raise pickle.UnpicklingError("`saved_id[0]` must be 'storage'")
        return self.original_persistent_load(saved_id)

    def find_class(self, module, name):
        if module == "collections" and name == "OrderedDict":
            return collections.OrderedDict
        elif module == "torch._utils" and name in (
            "_rebuild_tensor_v2",
            "_rebuild_meta_tensor_no_storage",
        ):
            return getattr(torch._utils, name)
        elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
            return _patched_rebuild_from_type_v2
        elif module == "torch" and name in (
            "DoubleStorage",
            "FloatStorage",
            "HalfStorage",
            "LongStorage",
            "IntStorage",
            "ShortStorage",
            "CharStorage",
            "ByteStorage",
            "BoolStorage",
            "BFloat16Storage",
            "Tensor",
            "float16",
        ):
            return getattr(torch, name)
        elif module == "numpy.core.multiarray" and name == "scalar":
            return np.core.multiarray.scalar
        elif module == "numpy" and name == "dtype":
            return np.dtype
        elif module == "_codecs" and name == "encode":
            return _codecs.encode
        else:
            # Forbid everything else.
            qualified_name = name if module == "__builtin__" else f"{module}.{name}"
            raise pickle.UnpicklingError(
                f"`{qualified_name}` is forbidden; the model you are loading probably contains malicious code. If you think this is incorrect ask the developer to unban the ability for {module} to execute {name}"
            )

    def load(self, *args, **kwargs):
        # Wrap whichever persistent_load handler is already installed (for
        # example the one torch.load sets) so that forced_persistent_load can
        # validate each persistent id before delegating to it.
        self.original_persistent_load = getattr(
            self, "persistent_load", pickle.Unpickler.persistent_load
        )
        self.persistent_load = self.forced_persistent_load
        return super().load(*args, **kwargs)


@contextlib.contextmanager
def use_custom_unpickler(unpickler: pickle.Unpickler = RestrictedUnpickler):
    """Temporarily replace pickle.Unpickler and pickle.load so that anything
    unpickled while the context is active goes through `unpickler`."""
    try:
        old_unpickler = pickle.Unpickler
        pickle.Unpickler = unpickler

        old_pickle_load = pickle.load

        def new_pickle_load(*args, **kwargs):
            return pickle.Unpickler(*args, **kwargs).load()

        pickle.load = new_pickle_load
        yield

    finally:
        pickle.Unpickler = old_unpickler
        pickle.load = old_pickle_load
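A minimal usage sketch, not part of the file above: the import path modeling.pickling is an assumption about where this module lives, the Evil class and file name are purely illustrative, and the torch.load part assumes a torch version whose default pickle_module is the standard pickle module (so the monkeypatch actually takes effect).

import io
import os
import pickle

import torch

# Assumed import path for this file; adjust to wherever it actually lives.
from modeling.pickling import RestrictedUnpickler, use_custom_unpickler

# 1) Direct use: a pickle whose reduce hook would call os.system is rejected
#    by find_class instead of being executed.
class Evil:
    def __reduce__(self):
        return (os.system, ("echo pwned",))

payload = pickle.dumps(Evil())
try:
    RestrictedUnpickler(io.BytesIO(payload)).load()
except pickle.UnpicklingError as exc:
    print("blocked:", exc)

# 2) Context-manager use: while the block is active, pickle.Unpickler is the
#    restricted class, so the unpickler torch.load builds from the pickle
#    module rejects forbidden globals instead of running their code.
with use_custom_unpickler(RestrictedUnpickler):
    state_dict = torch.load("pytorch_model.bin", map_location="cpu")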