From d152befe54e0f89ce44e7e6d10e1f7e6c6464c09 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Wed, 13 Dec 2023 20:23:19 +0300 Subject: [PATCH] IPEX Torch 2.1 --- environments/ipex.yml | 6 ++-- modeling/ipex/__init__.py | 33 +++++++---------- modeling/ipex/attention.py | 29 +++++---------- modeling/ipex/diffusers.py | 2 +- modeling/ipex/gradscaler.py | 6 +++- modeling/ipex/hijacks.py | 71 ++++++++++++++++++++++++++----------- 6 files changed, 80 insertions(+), 67 deletions(-) diff --git a/environments/ipex.yml b/environments/ipex.yml index 1eac5915..d322697b 100644 --- a/environments/ipex.yml +++ b/environments/ipex.yml @@ -25,10 +25,8 @@ dependencies: - ffmpeg - pip: - --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ - - torch==2.0.1a0; sys_platform == 'linux' - - torch==2.0.0a0; sys_platform == 'win32' - - intel_extension_for_pytorch==2.0.110+xpu; sys_platform == 'linux' - - intel_extension_for_pytorch==2.0.110+gitba7f6c1; sys_platform == 'win32' + - torch==2.1.0a0 + - intel-extension-for-pytorch==2.1.10+xpu - openvino - onnxruntime-openvino - flask-cloudflared==0.0.10 diff --git a/modeling/ipex/__init__.py b/modeling/ipex/__init__.py index dc1985ed..c7854791 100644 --- a/modeling/ipex/__init__.py +++ b/modeling/ipex/__init__.py @@ -4,13 +4,12 @@ import contextlib import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import from .hijacks import ipex_hijacks -from .attention import attention_init # pylint: disable=protected-access, missing-function-docstring, line-too-long def ipex_init(): # pylint: disable=too-many-statements try: - #Replace cuda with xpu: + # Replace cuda with xpu: torch.cuda.current_device = torch.xpu.current_device torch.cuda.current_stream = torch.xpu.current_stream torch.cuda.device = torch.xpu.device @@ -91,9 +90,9 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.CharStorage = torch.xpu.CharStorage torch.cuda.__file__ = torch.xpu.__file__ torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork - #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing - #Memory: + # Memory: torch.cuda.memory = torch.xpu.memory if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): torch.xpu.empty_cache = lambda: None @@ -113,7 +112,7 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats - #RNG: + # RNG: torch.cuda.get_rng_state = torch.xpu.get_rng_state torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all torch.cuda.set_rng_state = torch.xpu.set_rng_state @@ -124,7 +123,7 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.seed_all = torch.xpu.seed_all torch.cuda.initial_seed = torch.xpu.initial_seed - #AMP: + # AMP: torch.cuda.amp = torch.xpu.amp if not hasattr(torch.cuda.amp, "common"): torch.cuda.amp.common = contextlib.nullcontext() @@ -139,12 +138,12 @@ def ipex_init(): # pylint: disable=too-many-statements except Exception: # pylint: disable=broad-exception-caught torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler - #C + # C torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream ipex._C._DeviceProperties.major = 2023 ipex._C._DeviceProperties.minor = 2 - #Fix functions with ipex: + # Fix functions with ipex: torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] torch._utils._get_available_device_type = lambda: "xpu" torch.has_cuda = True @@ -157,20 +156,14 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.get_device_properties.minor = 7 torch.cuda.ipc_collect = lambda *args, **kwargs: None torch.cuda.utilization = lambda *args, **kwargs: 0 - if hasattr(torch.xpu, 'getDeviceIdListForCard'): - torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard - torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard - else: - torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card - torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card ipex_hijacks() - attention_init() - try: - from .diffusers import ipex_diffusers - ipex_diffusers() - except Exception: # pylint: disable=broad-exception-caught - pass + if not torch.xpu.has_fp64_dtype(): + try: + from .diffusers import ipex_diffusers + ipex_diffusers() + except Exception: # pylint: disable=broad-exception-caught + pass except Exception as e: return False, e return True, None diff --git a/modeling/ipex/attention.py b/modeling/ipex/attention.py index 52016466..ced59637 100644 --- a/modeling/ipex/attention.py +++ b/modeling/ipex/attention.py @@ -4,11 +4,8 @@ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unuse # pylint: disable=protected-access, missing-function-docstring, line-too-long original_torch_bmm = torch.bmm -def torch_bmm(input, mat2, *, out=None): - if input.dtype != mat2.dtype: - mat2 = mat2.to(input.dtype) - - #ARC GPUs can't allocate more than 4GB to a single block, Slice it: +def torch_bmm_32_bit(input, mat2, *, out=None): + # ARC GPUs can't allocate more than 4GB to a single block, Slice it: batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2] block_multiply = input.element_size() slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply @@ -17,7 +14,7 @@ def torch_bmm(input, mat2, *, out=None): split_slice_size = batch_size_attention if block_size > 4: do_split = True - #Find something divisible with the input_tokens + # Find something divisible with the input_tokens while (split_slice_size * slice_block_size) > 4: split_slice_size = split_slice_size // 2 if split_slice_size <= 1: @@ -30,7 +27,7 @@ def torch_bmm(input, mat2, *, out=None): if split_slice_size * slice_block_size > 4: slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply do_split_2 = True - #Find something divisible with the input_tokens + # Find something divisible with the input_tokens while (split_2_slice_size * slice_block_size2) > 4: split_2_slice_size = split_2_slice_size // 2 if split_2_slice_size <= 1: @@ -64,8 +61,8 @@ def torch_bmm(input, mat2, *, out=None): return hidden_states original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention -def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): - #ARC GPUs can't allocate more than 4GB to a single block, Slice it: +def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): + # ARC GPUs can't allocate more than 4GB to a single block, Slice it: if len(query.shape) == 3: batch_size_attention, query_tokens, shape_four = query.shape shape_one = 1 @@ -74,11 +71,6 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. shape_one, batch_size_attention, query_tokens, shape_four = query.shape no_shape_one = False - if query.dtype != key.dtype: - key = key.to(dtype=query.dtype) - if query.dtype != value.dtype: - value = value.to(dtype=query.dtype) - block_multiply = query.element_size() slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply block_size = batch_size_attention * slice_block_size @@ -86,7 +78,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. split_slice_size = batch_size_attention if block_size > 4: do_split = True - #Find something divisible with the shape_one + # Find something divisible with the shape_one while (split_slice_size * slice_block_size) > 4: split_slice_size = split_slice_size // 2 if split_slice_size <= 1: @@ -99,7 +91,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. if split_slice_size * slice_block_size > 4: slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply do_split_2 = True - #Find something divisible with the batch_size_attention + # Find something divisible with the batch_size_attention while (split_2_slice_size * slice_block_size2) > 4: split_2_slice_size = split_2_slice_size // 2 if split_2_slice_size <= 1: @@ -155,8 +147,3 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal ) return hidden_states - -def attention_init(): - #ARC GPUs can't allocate more than 4GB to a single block: - torch.bmm = torch_bmm - torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention diff --git a/modeling/ipex/diffusers.py b/modeling/ipex/diffusers.py index 005ee49f..c32af507 100644 --- a/modeling/ipex/diffusers.py +++ b/modeling/ipex/diffusers.py @@ -1,6 +1,6 @@ import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -import diffusers #0.21.1 # pylint: disable=import-error +import diffusers #0.24.0 # pylint: disable=import-error from diffusers.models.attention_processor import Attention # pylint: disable=protected-access, missing-function-docstring, line-too-long diff --git a/modeling/ipex/gradscaler.py b/modeling/ipex/gradscaler.py index 53021210..6eb56bc2 100644 --- a/modeling/ipex/gradscaler.py +++ b/modeling/ipex/gradscaler.py @@ -5,6 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, un # pylint: disable=protected-access, missing-function-docstring, line-too-long +device_supports_fp64 = torch.xpu.has_fp64_dtype() OptState = ipex.cpu.autocast._grad_scaler.OptState _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state @@ -96,7 +97,10 @@ def unscale_(self, optimizer): # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. assert self._scale is not None - inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) + if device_supports_fp64: + inv_scale = self._scale.double().reciprocal().float() + else: + inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) found_inf = torch.full( (1,), 0.0, dtype=torch.float32, device=self._scale.device ) diff --git a/modeling/ipex/hijacks.py b/modeling/ipex/hijacks.py index 88827837..fdb8dea9 100644 --- a/modeling/ipex/hijacks.py +++ b/modeling/ipex/hijacks.py @@ -92,7 +92,7 @@ def ipex_autocast(*args, **kwargs): else: return original_autocast(*args, **kwargs) -#Embedding BF16 +# Embedding BF16 original_torch_cat = torch.cat def torch_cat(tensor, *args, **kwargs): if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): @@ -100,7 +100,7 @@ def torch_cat(tensor, *args, **kwargs): else: return original_torch_cat(tensor, *args, **kwargs) -#Latent antialias: +# Latent antialias: original_interpolate = torch.nn.functional.interpolate def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments if antialias or align_corners is not None: @@ -120,6 +120,32 @@ def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name else: return original_linalg_solve(A, B, *args, **kwargs) +if torch.xpu.has_fp64_dtype(): + original_torch_bmm = torch.bmm + original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention +else: + # 64 bit attention workarounds for Alchemist: + try: + from .attention import torch_bmm_32_bit as original_torch_bmm + from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention + except Exception: # pylint: disable=broad-exception-caught + original_torch_bmm = torch.bmm + original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention + +# dtype errors: +def torch_bmm(input, mat2, *, out=None): + if input.dtype != mat2.dtype: + mat2 = mat2.to(input.dtype) + return original_torch_bmm(input, mat2, out=out) + +def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): + if query.dtype != key.dtype: + key = key.to(dtype=query.dtype) + if query.dtype != value.dtype: + value = value.to(dtype=query.dtype) + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) + +@property def is_cuda(self): return self.device.type == 'xpu' @@ -158,12 +184,12 @@ def ipex_hijacks(): lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: orig_func(orig_func, f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs), lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location)) + if hasattr(torch.xpu, "Generator"): + CondFunc('torch.Generator', + lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)), + lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") - CondFunc('torch.Generator', - lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)), - lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") - - #TiledVAE and ControlNet: + # TiledVAE and ControlNet: CondFunc('torch.batch_norm', lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, weight if weight is not None else torch.ones(input.size()[1], device=input.device), @@ -175,46 +201,51 @@ def ipex_hijacks(): bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) - #Functions with dtype errors: + # Functions with dtype errors: CondFunc('torch.nn.modules.GroupNorm.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - #Training: + # Training: CondFunc('torch.nn.modules.linear.Linear.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) CondFunc('torch.nn.modules.conv.Conv2d.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - #BF16: + # BF16: CondFunc('torch.nn.functional.layer_norm', lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: weight is not None and input.dtype != weight.data.dtype) - #SwinIR BF16: + # SwinIR BF16: CondFunc('torch.nn.functional.pad', lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16), lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16) - #Diffusers Float64 (ARC GPUs doesn't support double or Float64): + # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit): if not torch.xpu.has_fp64_dtype(): CondFunc('torch.from_numpy', lambda orig_func, ndarray: orig_func(ndarray.astype('float32')), lambda orig_func, ndarray: ndarray.dtype == float) - #Broken functions when torch.cuda.is_available is True: - #Pin Memory: + # Broken functions when torch.cuda.is_available is True: + # Pin Memory: CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__', lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs), lambda orig_func, *args, **kwargs: True) - #Functions that make compile mad with CondFunc: - torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers + # Functions that make compile mad with CondFunc: torch.nn.DataParallel = DummyDataParallel + torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers + torch.autocast = ipex_autocast - torch.cat = torch_cat - torch.linalg.solve = linalg_solve - torch.UntypedStorage.is_cuda = is_cuda - torch.nn.functional.interpolate = interpolate torch.backends.cuda.sdp_kernel = return_null_context + torch.UntypedStorage.is_cuda = is_cuda + + torch.nn.functional.interpolate = interpolate + torch.linalg.solve = linalg_solve + + torch.bmm = torch_bmm + torch.cat = torch_cat + torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention