Compare commits

..

4 Commits

Author SHA1 Message Date
henk717 025db3bd04
Merge pull request #138 from VE-FORBRYDERNE/lazy-loader
Fix for lazy loader in PyTorch 1.12
2022-07-12 23:02:58 +02:00
henk717 836759d826
Merge pull request #137 from VE-FORBRYDERNE/jaxlib
TPU Colab hotfix
2022-07-12 23:02:40 +02:00
vfbd 39d48495ce Fix for lazy loader in PyTorch 1.12
There is no `torch._StorageBase` in PyTorch 1.12, but otherwise it still
works.
2022-07-12 16:48:01 -04:00
vfbd 70aa182671 Restrict jaxlib version in TPU Colabs 2022-07-12 16:30:26 -04:00
2 changed files with 3 additions and 2 deletions

View File

@ -5,6 +5,7 @@ requests
optax >= 0.0.5, <= 0.0.9
dm-haiku == 0.0.5
jax == 0.2.21
jaxlib >= 0.1.69, <= 0.3.7
transformers >= 4.19
progressbar2
git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck

View File

@ -51,7 +51,7 @@ import zipfile
import pickle
import torch
from torch.nn import Module
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
@ -72,7 +72,7 @@ STORAGE_TYPE_MAP = {
class LazyTensor:
def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, dtype: Optional[torch.dtype] = None, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
def __init__(self, storage_type, key: str, location: str, dtype: Optional[torch.dtype] = None, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
self.storage_type = storage_type
self.key = key
self.location = location