Hide the warning about `torch.distributed.reduce_op` being deprecated

This commit is contained in:
vfbd 2022-08-11 18:42:56 -04:00
parent 8c7ed92fef
commit e879d1c5f3
1 changed files with 10 additions and 7 deletions

View File

@ -37,6 +37,7 @@ import bisect
import functools import functools
import traceback import traceback
import inspect import inspect
import warnings
from collections.abc import Iterable from collections.abc import Iterable
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List, Optional, Type from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List, Optional, Type
@ -1958,10 +1959,12 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
model = None model = None
generator = None generator = None
model_config = None model_config = None
with torch.no_grad():
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="torch.distributed.reduce_op is deprecated")
for tensor in gc.get_objects(): for tensor in gc.get_objects():
try: try:
if torch.is_tensor(tensor): if torch.is_tensor(tensor):
with torch.no_grad():
tensor.set_(torch.tensor((), device=tensor.device, dtype=tensor.dtype)) tensor.set_(torch.tensor((), device=tensor.device, dtype=tensor.dtype))
except: except:
pass pass