Hide the warning about `torch.distributed.reduce_op` being deprecated
This commit is contained in:
parent
8c7ed92fef
commit
e879d1c5f3
|
@ -37,6 +37,7 @@ import bisect
|
||||||
import functools
|
import functools
|
||||||
import traceback
|
import traceback
|
||||||
import inspect
|
import inspect
|
||||||
|
import warnings
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List, Optional, Type
|
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List, Optional, Type
|
||||||
|
|
||||||
|
@ -1958,10 +1959,12 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
||||||
model = None
|
model = None
|
||||||
generator = None
|
generator = None
|
||||||
model_config = None
|
model_config = None
|
||||||
|
with torch.no_grad():
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings("ignore", message="torch.distributed.reduce_op is deprecated")
|
||||||
for tensor in gc.get_objects():
|
for tensor in gc.get_objects():
|
||||||
try:
|
try:
|
||||||
if torch.is_tensor(tensor):
|
if torch.is_tensor(tensor):
|
||||||
with torch.no_grad():
|
|
||||||
tensor.set_(torch.tensor((), device=tensor.device, dtype=tensor.dtype))
|
tensor.set_(torch.tensor((), device=tensor.device, dtype=tensor.dtype))
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue