Fix for lazy loader in PyTorch 1.12
There is no `torch._StorageBase` in PyTorch 1.12, but otherwise it still works.
This commit is contained in:
parent
dd6da50e58
commit
39d48495ce
|
@ -51,7 +51,7 @@ import zipfile
|
|||
import pickle
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
|
||||
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
|
||||
|
@ -72,7 +72,7 @@ STORAGE_TYPE_MAP = {
|
|||
|
||||
|
||||
class LazyTensor:
|
||||
def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, dtype: Optional[torch.dtype] = None, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
|
||||
def __init__(self, storage_type, key: str, location: str, dtype: Optional[torch.dtype] = None, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
|
||||
self.storage_type = storage_type
|
||||
self.key = key
|
||||
self.location = location
|
||||
|
|
Loading…
Reference in New Issue