From f5bdd78e2d3ccb660bd19418c6329ada209030e9 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Sat, 9 Sep 2023 22:21:16 +0300 Subject: [PATCH] IPEX fix SDPA and linalg solve --- modeling/ipex/__init__.py | 299 +++++++++++++++++++------------------ modeling/ipex/attention.py | 128 ++++++++++++++++ modeling/ipex/diffusers.py | 146 +----------------- modeling/ipex/hijacks.py | 20 +-- 4 files changed, 292 insertions(+), 301 deletions(-) create mode 100644 modeling/ipex/attention.py diff --git a/modeling/ipex/__init__.py b/modeling/ipex/__init__.py index 39415396..9ec69012 100644 --- a/modeling/ipex/__init__.py +++ b/modeling/ipex/__init__.py @@ -4,160 +4,167 @@ import contextlib import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import from .hijacks import ipex_hijacks +from .attention import attention_init # pylint: disable=protected-access, missing-function-docstring, line-too-long def ipex_init(): # pylint: disable=too-many-statements - #Replace cuda with xpu: - torch.cuda.current_device = torch.xpu.current_device - torch.cuda.current_stream = torch.xpu.current_stream - torch.cuda.device = torch.xpu.device - torch.cuda.device_count = torch.xpu.device_count - torch.cuda.device_of = torch.xpu.device_of - torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard - torch.cuda.get_device_name = torch.xpu.get_device_name - torch.cuda.get_device_properties = torch.xpu.get_device_properties - torch.cuda.init = torch.xpu.init - torch.cuda.is_available = torch.xpu.is_available - torch.cuda.is_initialized = torch.xpu.is_initialized - torch.cuda.is_current_stream_capturing = lambda: False - torch.cuda.set_device = torch.xpu.set_device - torch.cuda.stream = torch.xpu.stream - torch.cuda.synchronize = torch.xpu.synchronize - torch.cuda.Event = torch.xpu.Event - torch.cuda.Stream = torch.xpu.Stream - torch.cuda.FloatTensor = torch.xpu.FloatTensor - torch.Tensor.cuda = torch.Tensor.xpu - torch.Tensor.is_cuda = torch.Tensor.is_xpu - torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock - torch.cuda._initialized = torch.xpu.lazy_init._initialized - torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker - torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls - torch.cuda._tls = torch.xpu.lazy_init._tls - torch.cuda.threading = torch.xpu.lazy_init.threading - torch.cuda.traceback = torch.xpu.lazy_init.traceback - torch.cuda.Optional = torch.xpu.Optional - torch.cuda.__cached__ = torch.xpu.__cached__ - torch.cuda.__loader__ = torch.xpu.__loader__ - torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage - torch.cuda.Tuple = torch.xpu.Tuple - torch.cuda.streams = torch.xpu.streams - torch.cuda._lazy_new = torch.xpu._lazy_new - torch.cuda.FloatStorage = torch.xpu.FloatStorage - torch.cuda.Any = torch.xpu.Any - torch.cuda.__doc__ = torch.xpu.__doc__ - torch.cuda.default_generators = torch.xpu.default_generators - torch.cuda.HalfTensor = torch.xpu.HalfTensor - torch.cuda._get_device_index = torch.xpu._get_device_index - torch.cuda.__path__ = torch.xpu.__path__ - torch.cuda.Device = torch.xpu.Device - torch.cuda.IntTensor = torch.xpu.IntTensor - torch.cuda.ByteStorage = torch.xpu.ByteStorage - torch.cuda.set_stream = torch.xpu.set_stream - torch.cuda.BoolStorage = torch.xpu.BoolStorage - torch.cuda.os = torch.xpu.os - torch.cuda.torch = torch.xpu.torch - torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage - torch.cuda.Union = torch.xpu.Union - torch.cuda.DoubleTensor = torch.xpu.DoubleTensor - torch.cuda.ShortTensor = torch.xpu.ShortTensor - torch.cuda.LongTensor = torch.xpu.LongTensor - torch.cuda.IntStorage = torch.xpu.IntStorage - torch.cuda.LongStorage = torch.xpu.LongStorage - torch.cuda.__annotations__ = torch.xpu.__annotations__ - torch.cuda.__package__ = torch.xpu.__package__ - torch.cuda.__builtins__ = torch.xpu.__builtins__ - torch.cuda.CharTensor = torch.xpu.CharTensor - torch.cuda.List = torch.xpu.List - torch.cuda._lazy_init = torch.xpu._lazy_init - torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor - torch.cuda.DoubleStorage = torch.xpu.DoubleStorage - torch.cuda.ByteTensor = torch.xpu.ByteTensor - torch.cuda.StreamContext = torch.xpu.StreamContext - torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage - torch.cuda.ShortStorage = torch.xpu.ShortStorage - torch.cuda._lazy_call = torch.xpu._lazy_call - torch.cuda.HalfStorage = torch.xpu.HalfStorage - torch.cuda.random = torch.xpu.random - torch.cuda._device = torch.xpu._device - torch.cuda.classproperty = torch.xpu.classproperty - torch.cuda.__name__ = torch.xpu.__name__ - torch.cuda._device_t = torch.xpu._device_t - torch.cuda.warnings = torch.xpu.warnings - torch.cuda.__spec__ = torch.xpu.__spec__ - torch.cuda.BoolTensor = torch.xpu.BoolTensor - torch.cuda.CharStorage = torch.xpu.CharStorage - torch.cuda.__file__ = torch.xpu.__file__ - torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork - #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing - - #Memory: - torch.cuda.memory = torch.xpu.memory - if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): - torch.xpu.empty_cache = lambda: None - torch.cuda.empty_cache = torch.xpu.empty_cache - torch.cuda.memory_stats = torch.xpu.memory_stats - torch.cuda.memory_summary = torch.xpu.memory_summary - torch.cuda.memory_snapshot = torch.xpu.memory_snapshot - torch.cuda.memory_allocated = torch.xpu.memory_allocated - torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated - torch.cuda.memory_reserved = torch.xpu.memory_reserved - torch.cuda.memory_cached = torch.xpu.memory_reserved - torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved - torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved - torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats - torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats - torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats - torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict - torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats - - #RNG: - torch.cuda.get_rng_state = torch.xpu.get_rng_state - torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all - torch.cuda.set_rng_state = torch.xpu.set_rng_state - torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all - torch.cuda.manual_seed = torch.xpu.manual_seed - torch.cuda.manual_seed_all = torch.xpu.manual_seed_all - torch.cuda.seed = torch.xpu.seed - torch.cuda.seed_all = torch.xpu.seed_all - torch.cuda.initial_seed = torch.xpu.initial_seed - - #AMP: - torch.cuda.amp = torch.xpu.amp - if not hasattr(torch.cuda.amp, "common"): - torch.cuda.amp.common = contextlib.nullcontext() - torch.cuda.amp.common.amp_definitely_not_available = lambda: False try: - torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler - except Exception: # pylint: disable=broad-exception-caught + #Replace cuda with xpu: + torch.cuda.current_device = torch.xpu.current_device + torch.cuda.current_stream = torch.xpu.current_stream + torch.cuda.device = torch.xpu.device + torch.cuda.device_count = torch.xpu.device_count + torch.cuda.device_of = torch.xpu.device_of + torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard + torch.cuda.get_device_name = torch.xpu.get_device_name + torch.cuda.get_device_properties = torch.xpu.get_device_properties + torch.cuda.init = torch.xpu.init + torch.cuda.is_available = torch.xpu.is_available + torch.cuda.is_initialized = torch.xpu.is_initialized + torch.cuda.is_current_stream_capturing = lambda: False + torch.cuda.set_device = torch.xpu.set_device + torch.cuda.stream = torch.xpu.stream + torch.cuda.synchronize = torch.xpu.synchronize + torch.cuda.Event = torch.xpu.Event + torch.cuda.Stream = torch.xpu.Stream + torch.cuda.FloatTensor = torch.xpu.FloatTensor + torch.Tensor.cuda = torch.Tensor.xpu + torch.Tensor.is_cuda = torch.Tensor.is_xpu + torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock + torch.cuda._initialized = torch.xpu.lazy_init._initialized + torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls + torch.cuda._tls = torch.xpu.lazy_init._tls + torch.cuda.threading = torch.xpu.lazy_init.threading + torch.cuda.traceback = torch.xpu.lazy_init.traceback + torch.cuda.Optional = torch.xpu.Optional + torch.cuda.__cached__ = torch.xpu.__cached__ + torch.cuda.__loader__ = torch.xpu.__loader__ + torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage + torch.cuda.Tuple = torch.xpu.Tuple + torch.cuda.streams = torch.xpu.streams + torch.cuda._lazy_new = torch.xpu._lazy_new + torch.cuda.FloatStorage = torch.xpu.FloatStorage + torch.cuda.Any = torch.xpu.Any + torch.cuda.__doc__ = torch.xpu.__doc__ + torch.cuda.default_generators = torch.xpu.default_generators + torch.cuda.HalfTensor = torch.xpu.HalfTensor + torch.cuda._get_device_index = torch.xpu._get_device_index + torch.cuda.__path__ = torch.xpu.__path__ + torch.cuda.Device = torch.xpu.Device + torch.cuda.IntTensor = torch.xpu.IntTensor + torch.cuda.ByteStorage = torch.xpu.ByteStorage + torch.cuda.set_stream = torch.xpu.set_stream + torch.cuda.BoolStorage = torch.xpu.BoolStorage + torch.cuda.os = torch.xpu.os + torch.cuda.torch = torch.xpu.torch + torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage + torch.cuda.Union = torch.xpu.Union + torch.cuda.DoubleTensor = torch.xpu.DoubleTensor + torch.cuda.ShortTensor = torch.xpu.ShortTensor + torch.cuda.LongTensor = torch.xpu.LongTensor + torch.cuda.IntStorage = torch.xpu.IntStorage + torch.cuda.LongStorage = torch.xpu.LongStorage + torch.cuda.__annotations__ = torch.xpu.__annotations__ + torch.cuda.__package__ = torch.xpu.__package__ + torch.cuda.__builtins__ = torch.xpu.__builtins__ + torch.cuda.CharTensor = torch.xpu.CharTensor + torch.cuda.List = torch.xpu.List + torch.cuda._lazy_init = torch.xpu._lazy_init + torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor + torch.cuda.DoubleStorage = torch.xpu.DoubleStorage + torch.cuda.ByteTensor = torch.xpu.ByteTensor + torch.cuda.StreamContext = torch.xpu.StreamContext + torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage + torch.cuda.ShortStorage = torch.xpu.ShortStorage + torch.cuda._lazy_call = torch.xpu._lazy_call + torch.cuda.HalfStorage = torch.xpu.HalfStorage + torch.cuda.random = torch.xpu.random + torch.cuda._device = torch.xpu._device + torch.cuda.classproperty = torch.xpu.classproperty + torch.cuda.__name__ = torch.xpu.__name__ + torch.cuda._device_t = torch.xpu._device_t + torch.cuda.warnings = torch.xpu.warnings + torch.cuda.__spec__ = torch.xpu.__spec__ + torch.cuda.BoolTensor = torch.xpu.BoolTensor + torch.cuda.CharStorage = torch.xpu.CharStorage + torch.cuda.__file__ = torch.xpu.__file__ + torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork + #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + + #Memory: + torch.cuda.memory = torch.xpu.memory + if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): + torch.xpu.empty_cache = lambda: None + torch.cuda.empty_cache = torch.xpu.empty_cache + torch.cuda.memory_stats = torch.xpu.memory_stats + torch.cuda.memory_summary = torch.xpu.memory_summary + torch.cuda.memory_snapshot = torch.xpu.memory_snapshot + torch.cuda.memory_allocated = torch.xpu.memory_allocated + torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated + torch.cuda.memory_reserved = torch.xpu.memory_reserved + torch.cuda.memory_cached = torch.xpu.memory_reserved + torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved + torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved + torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats + torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict + torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats + + #RNG: + torch.cuda.get_rng_state = torch.xpu.get_rng_state + torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all + torch.cuda.set_rng_state = torch.xpu.set_rng_state + torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all + torch.cuda.manual_seed = torch.xpu.manual_seed + torch.cuda.manual_seed_all = torch.xpu.manual_seed_all + torch.cuda.seed = torch.xpu.seed + torch.cuda.seed_all = torch.xpu.seed_all + torch.cuda.initial_seed = torch.xpu.initial_seed + + #AMP: + torch.cuda.amp = torch.xpu.amp + if not hasattr(torch.cuda.amp, "common"): + torch.cuda.amp.common = contextlib.nullcontext() + torch.cuda.amp.common.amp_definitely_not_available = lambda: False try: - from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error - gradscaler_init() torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler except Exception: # pylint: disable=broad-exception-caught - torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + try: + from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error + gradscaler_init() + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler - #C - torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream - ipex._C._DeviceProperties.major = 2023 - ipex._C._DeviceProperties.minor = 2 + #C + torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream + ipex._C._DeviceProperties.major = 2023 + ipex._C._DeviceProperties.minor = 2 - #Fix functions with ipex: - torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_allocated(device)), torch.xpu.get_device_properties(device).total_memory] - torch._utils._get_available_device_type = lambda: "xpu" - torch.has_cuda = True - torch.cuda.has_half = True - torch.cuda.is_bf16_supported = True - torch.version.cuda = "11.7" - torch.cuda.get_device_capability = lambda: [11,7] - torch.cuda.get_device_properties.major = 11 - torch.cuda.get_device_properties.minor = 7 - torch.cuda.ipc_collect = lambda: None - torch.cuda.utilization = lambda: 0 + #Fix functions with ipex: + torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_allocated(device)), torch.xpu.get_device_properties(device).total_memory] + torch._utils._get_available_device_type = lambda: "xpu" + torch.has_cuda = True + torch.cuda.has_half = True + torch.cuda.is_bf16_supported = lambda *args, **kwargs: True + torch.cuda.is_fp16_supported = lambda *args, **kwargs: True + torch.version.cuda = "11.7" + torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7] + torch.cuda.get_device_properties.major = 11 + torch.cuda.get_device_properties.minor = 7 + torch.cuda.ipc_collect = lambda *args, **kwargs: None + torch.cuda.utilization = lambda *args, **kwargs: 0 - ipex_hijacks() - try: - from .diffusers import ipex_diffusers # pylint: disable=import-outside-toplevel, import-error - ipex_diffusers() - except Exception: # pylint: disable=broad-exception-caught - pass + ipex_hijacks() + attention_init() + try: + from .diffusers import ipex_diffusers + ipex_diffusers() + except Exception: # pylint: disable=broad-exception-caught + pass + except Exception as e: + return False, e + return True, None diff --git a/modeling/ipex/attention.py b/modeling/ipex/attention.py new file mode 100644 index 00000000..d7335bfa --- /dev/null +++ b/modeling/ipex/attention.py @@ -0,0 +1,128 @@ +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +original_torch_bmm = torch.bmm +def torch_bmm(input, mat2, *, out=None): + if input.dtype != mat2.dtype: + mat2 = mat2.to(input.dtype) + + #ARC GPUs can't allocate more than 4GB to a single block, Slice it: + batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2] + block_multiply = 2.4 if input.dtype == torch.float32 else 1.2 + block_size = (batch_size_attention * input_tokens * mat2_shape) / 1024 * block_multiply #MB + split_slice_size = batch_size_attention + if block_size >= 4000: + do_split = True + #Find something divisible with the input_tokens + while ((split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply) > 4000: + split_slice_size = split_slice_size // 2 + if split_slice_size <= 1: + split_slice_size = 1 + break + else: + do_split = False + + split_block_size = (split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply #MB + split_2_slice_size = input_tokens + if split_block_size >= 4000: + do_split_2 = True + #Find something divisible with the input_tokens + while ((split_slice_size * split_2_slice_size * mat2_shape) / 1024 * block_multiply) > 4000: + split_2_slice_size = split_2_slice_size // 2 + if split_2_slice_size <= 1: + split_2_slice_size = 1 + break + else: + do_split_2 = False + + if do_split: + hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype) + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( + input[start_idx:end_idx, start_idx_2:end_idx_2], + mat2[start_idx:end_idx, start_idx_2:end_idx_2], + out=out + ) + else: + hidden_states[start_idx:end_idx] = original_torch_bmm( + input[start_idx:end_idx], + mat2[start_idx:end_idx], + out=out + ) + else: + return original_torch_bmm(input, mat2, out=out) + return hidden_states + +original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention +def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): + #ARC GPUs can't allocate more than 4GB to a single block, Slice it: + shape_one, batch_size_attention, query_tokens, shape_four = query.shape + block_multiply = 2.4 if query.dtype == torch.float32 else 1.2 + block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB + split_slice_size = batch_size_attention + if block_size >= 4000: + do_split = True + #Find something divisible with the shape_one + while ((shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply) > 4000: + split_slice_size = split_slice_size // 2 + if split_slice_size <= 1: + split_slice_size = 1 + break + else: + do_split = False + + split_block_size = (shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply #MB + split_2_slice_size = query_tokens + if split_block_size >= 4000: + do_split_2 = True + #Find something divisible with the batch_size_attention + while ((shape_one * split_slice_size * split_2_slice_size * shape_four) / 1024 * block_multiply) > 4000: + split_2_slice_size = split_2_slice_size // 2 + if split_2_slice_size <= 1: + split_2_slice_size = 1 + break + else: + do_split_2 = False + + if do_split: + hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( + query[:, start_idx:end_idx, start_idx_2:end_idx_2], + key[:, start_idx:end_idx, start_idx_2:end_idx_2], + value[:, start_idx:end_idx, start_idx_2:end_idx_2], + attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal + ) + else: + hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention( + query[:, start_idx:end_idx], + key[:, start_idx:end_idx], + value[:, start_idx:end_idx], + attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal + ) + else: + return original_scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal + ) + return hidden_states + +def attention_init(): + #ARC GPUs can't allocate more than 4GB to a single block: + torch.bmm = torch_bmm + torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention diff --git a/modeling/ipex/diffusers.py b/modeling/ipex/diffusers.py index 18563b06..3435abe1 100644 --- a/modeling/ipex/diffusers.py +++ b/modeling/ipex/diffusers.py @@ -1,12 +1,9 @@ import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -import torch.nn.functional as F # pylint: disable=ungrouped-imports import diffusers #0.20.2 # pylint: disable=import-error # pylint: disable=protected-access, missing-function-docstring, line-too-long -Attention = diffusers.models.attention_processor.Attention - class SlicedAttnProcessor: # pylint: disable=too-few-public-methods r""" Processor for implementing sliced attention. @@ -20,7 +17,7 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods def __init__(self, slice_size): self.slice_size = slice_size - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches + def __call__(self, attn: diffusers.models.attention_processor.Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches residual = hidden_states input_ndim = hidden_states.ndim @@ -116,147 +113,6 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods return hidden_states -class AttnProcessor2_0: # pylint: disable=too-few-public-methods, invalid-name - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( # pylint: disable=too-many-arguments, too-many-statements, too-many-locals, too-many-branches - self, - attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - ): - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - #ARC GPUs can't allocate more than 4GB to a single block, Slice it: - shape_one, batch_size_attention, query_tokens, shape_four = query.shape - block_multiply = 2.4 if query.dtype == torch.float32 else 1.2 - block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB - split_slice_size = batch_size_attention - if block_size >= 4000: - do_split = True - #Find something divisible with the shape_one - while ((shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply) > 4000: - split_slice_size = split_slice_size // 2 - if split_slice_size <= 1: - split_slice_size = 1 - break - else: - do_split = False - - split_block_size = (shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply #MB - split_2_slice_size = query_tokens - if split_block_size >= 4000: - do_split_2 = True - #Find something divisible with the batch_size_attention - while ((shape_one * split_slice_size * split_2_slice_size * shape_four) / 1024 * block_multiply) > 4000: - split_2_slice_size = split_2_slice_size // 2 - if split_2_slice_size <= 1: - split_2_slice_size = 1 - break - else: - do_split_2 = False - - if do_split: - hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) - for i in range(batch_size_attention // split_slice_size): - start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - - query_slice = query[:, start_idx:end_idx, start_idx_2:end_idx_2] - key_slice = key[:, start_idx:end_idx, start_idx_2:end_idx_2] - attn_mask_slice = attention_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None - - attn_slice = F.scaled_dot_product_attention( - query_slice, key_slice, value[:, start_idx:end_idx, start_idx_2:end_idx_2], - attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False - ) - hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice - else: - query_slice = query[:, start_idx:end_idx] - key_slice = key[:, start_idx:end_idx] - attn_mask_slice = attention_mask[:, start_idx:end_idx] if attention_mask is not None else None - - attn_slice = F.scaled_dot_product_attention( - query_slice, key_slice, value[:, start_idx:end_idx], - attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False - ) - hidden_states[:, start_idx:end_idx] = attn_slice - else: - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - def ipex_diffusers(): #ARC GPUs can't allocate more than 4GB to a single block: diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor - diffusers.models.attention_processor.AttnProcessor2_0 = AttnProcessor2_0 diff --git a/modeling/ipex/hijacks.py b/modeling/ipex/hijacks.py index 8dd619c4..cf0f2233 100644 --- a/modeling/ipex/hijacks.py +++ b/modeling/ipex/hijacks.py @@ -34,7 +34,7 @@ class CondFunc: # pylint: disable=missing-class-docstring _utils = torch.utils.data._utils def _shutdown_workers(self): - if _utils is None or _utils.python_exit_status is True or _utils.python_exit_status is None: + if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None: return if hasattr(self, "_shutdown") and not self._shutdown: self._shutdown = True @@ -50,13 +50,13 @@ def _shutdown_workers(self): if self._persistent_workers or self._workers_status[worker_id]: self._mark_worker_as_unavailable(worker_id, shutdown=True) for w in self._workers: # pylint: disable=invalid-name - w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL) + w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL) for q in self._index_queues: # pylint: disable=invalid-name q.cancel_join_thread() q.close() finally: if self._worker_pids_set: - _utils.signal_handling._remove_worker_pids(id(self)) + torch.utils.data._utils.signal_handling._remove_worker_pids(id(self)) self._worker_pids_set = False for w in self._workers: # pylint: disable=invalid-name if w.is_alive(): @@ -75,7 +75,7 @@ def check_device(device): return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) def return_xpu(device): - return f"xpu:{device[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" + return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" def ipex_no_cuda(orig_func, *args, **kwargs): torch.cuda.is_available = lambda: False @@ -84,7 +84,7 @@ def ipex_no_cuda(orig_func, *args, **kwargs): original_autocast = torch.autocast def ipex_autocast(*args, **kwargs): - if args[0] == "cuda" or args[0] == "xpu": + if len(args) > 0 and (args[0] == "cuda" or args[0] == "xpu"): if "dtype" in kwargs: return original_autocast("xpu", *args[1:], **kwargs) else: @@ -114,9 +114,9 @@ original_linalg_solve = torch.linalg.solve def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name if A.device != torch.device("cpu") or B.device != torch.device("cpu"): return_device = A.device - original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device) + return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device) else: - original_linalg_solve(A, B, *args, **kwargs) + return original_linalg_solve(A, B, *args, **kwargs) def ipex_hijacks(): CondFunc('torch.Tensor.to', @@ -169,9 +169,9 @@ def ipex_hijacks(): CondFunc('torch.nn.modules.linear.Linear.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - CondFunc('torch.bmm', - lambda orig_func, input, mat2, *args, **kwargs: orig_func(input, mat2.to(input.dtype), *args, **kwargs), - lambda orig_func, input, mat2, *args, **kwargs: input.dtype != mat2.dtype) + CondFunc('torch.nn.modules.conv.Conv2d.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) CondFunc('torch.nn.functional.layer_norm', lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),