Hide the warning about `torch.distributed.reduce_op` being deprecated

This commit is contained in:
vfbd 2022-08-11 18:42:56 -04:00
parent 8c7ed92fef
commit e879d1c5f3
1 changed files with 10 additions and 7 deletions

View File

@ -37,6 +37,7 @@ import bisect
import functools
import traceback
import inspect
import warnings
from collections.abc import Iterable
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List, Optional, Type
@ -1958,13 +1959,15 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
model = None
generator = None
model_config = None
for tensor in gc.get_objects():
try:
if torch.is_tensor(tensor):
with torch.no_grad():
tensor.set_(torch.tensor((), device=tensor.device, dtype=tensor.dtype))
except:
pass
with torch.no_grad():
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="torch.distributed.reduce_op is deprecated")
for tensor in gc.get_objects():
try:
if torch.is_tensor(tensor):
tensor.set_(torch.tensor((), device=tensor.device, dtype=tensor.dtype))
except:
pass
gc.collect()
try:
torch.cuda.empty_cache()