Fix for lazy loader in PyTorch 1.12

There is no `torch._StorageBase` in PyTorch 1.12, but otherwise it still
works.
This commit is contained in:
vfbd 2022-07-12 16:48:01 -04:00
parent dd6da50e58
commit 39d48495ce
1 changed files with 2 additions and 2 deletions

View File

@ -51,7 +51,7 @@ import zipfile
import pickle import pickle
import torch import torch
from torch.nn import Module from torch.nn import Module
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union from typing import Any, Callable, Dict, Optional, Tuple, Union
_EXTRA_STATE_KEY_SUFFIX = '_extra_state' _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
@ -72,7 +72,7 @@ STORAGE_TYPE_MAP = {
class LazyTensor: class LazyTensor:
def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, dtype: Optional[torch.dtype] = None, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None): def __init__(self, storage_type, key: str, location: str, dtype: Optional[torch.dtype] = None, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
self.storage_type = storage_type self.storage_type = storage_type
self.key = key self.key = key
self.location = location self.location = location