diff --git a/environments/ipex.yml b/environments/ipex.yml index c9794e48..1eac5915 100644 --- a/environments/ipex.yml +++ b/environments/ipex.yml @@ -29,15 +29,18 @@ dependencies: - torch==2.0.0a0; sys_platform == 'win32' - intel_extension_for_pytorch==2.0.110+xpu; sys_platform == 'linux' - intel_extension_for_pytorch==2.0.110+gitba7f6c1; sys_platform == 'win32' - - intel-extension-for-transformers + - openvino + - onnxruntime-openvino - flask-cloudflared==0.0.10 - flask-ngrok - flask-cors - Werkzeug==2.3.7 - lupa==1.10 - transformers[sentencepiece]==4.34.0 + - intel-extension-for-transformers - huggingface_hub==0.16.4 - - optimum[onnxruntime,openvino,nncf,neural-compressor]==1.13.2 + - optimum[onnxruntime]==1.13.2 + - optimum-intel - safetensors==0.3.3 - accelerate==0.21.0 - git+https://github.com/VE-FORBRYDERNE/mkultra diff --git a/modeling/ipex/__init__.py b/modeling/ipex/__init__.py index 43accd9f..dc1985ed 100644 --- a/modeling/ipex/__init__.py +++ b/modeling/ipex/__init__.py @@ -30,6 +30,7 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.FloatTensor = torch.xpu.FloatTensor torch.Tensor.cuda = torch.Tensor.xpu torch.Tensor.is_cuda = torch.Tensor.is_xpu + torch.UntypedStorage.cuda = torch.UntypedStorage.xpu torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock torch.cuda._initialized = torch.xpu.lazy_init._initialized torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker diff --git a/modeling/ipex/attention.py b/modeling/ipex/attention.py index 84848b6a..52016466 100644 --- a/modeling/ipex/attention.py +++ b/modeling/ipex/attention.py @@ -74,6 +74,11 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. shape_one, batch_size_attention, query_tokens, shape_four = query.shape no_shape_one = False + if query.dtype != key.dtype: + key = key.to(dtype=query.dtype) + if query.dtype != value.dtype: + value = value.to(dtype=query.dtype) + block_multiply = query.element_size() slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply block_size = batch_size_attention * slice_block_size diff --git a/modeling/ipex/gradscaler.py b/modeling/ipex/gradscaler.py new file mode 100644 index 00000000..53021210 --- /dev/null +++ b/modeling/ipex/gradscaler.py @@ -0,0 +1,179 @@ +from collections import defaultdict +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +OptState = ipex.cpu.autocast._grad_scaler.OptState +_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator +_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state + +def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument + per_device_inv_scale = _MultiDeviceReplicator(inv_scale) + per_device_found_inf = _MultiDeviceReplicator(found_inf) + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be hundreds of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] + # sync grad to master weight + if hasattr(optimizer, "sync_grad"): + optimizer.sync_grad() + with torch.no_grad(): + for group in optimizer.param_groups: + for param in group["params"]: + if param.grad is None: + continue + if (not allow_fp16) and param.grad.dtype == torch.float16: + raise ValueError("Attempting to unscale FP16 gradients.") + if param.grad.is_sparse: + # is_coalesced() == False means the sparse grad has values with duplicate indices. + # coalesce() deduplicates indices and adds all values that have the same index. + # For scaled fp16 values, there's a good chance coalescing will cause overflow, + # so we should check the coalesced _values(). + if param.grad.dtype is torch.float16: + param.grad = param.grad.coalesce() + to_unscale = param.grad._values() + else: + to_unscale = param.grad + + # -: is there a way to split by device and dtype without appending in the inner loop? + to_unscale = to_unscale.to("cpu") + per_device_and_dtype_grads[to_unscale.device][ + to_unscale.dtype + ].append(to_unscale) + + for _, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + core._amp_foreach_non_finite_check_and_unscale_( + grads, + per_device_found_inf.get("cpu"), + per_device_inv_scale.get("cpu"), + ) + + return per_device_found_inf._per_device_tensors + +def unscale_(self, optimizer): + """ + Divides ("unscales") the optimizer's gradient tensors by the scale factor. + :meth:`unscale_` is optional, serving cases where you need to + :ref:`modify or inspect gradients` + between the backward pass(es) and :meth:`step`. + If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. + Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: + ... + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + scaler.step(optimizer) + scaler.update() + Args: + optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. + .. warning:: + :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, + and only after all gradients for that optimizer's assigned parameters have been accumulated. + Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. + .. warning:: + :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. + """ + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise + raise RuntimeError( + "unscale_() has already been called on this optimizer since the last update()." + ) + elif optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + assert self._scale is not None + inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) + found_inf = torch.full( + (1,), 0.0, dtype=torch.float32, device=self._scale.device + ) + + optimizer_state["found_inf_per_device"] = self._unscale_grads_( + optimizer, inv_scale, found_inf, False + ) + optimizer_state["stage"] = OptState.UNSCALED + +def update(self, new_scale=None): + """ + Updates the scale factor. + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + Args: + new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor. + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + """ + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) # type: ignore[union-attr] + else: + reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False." + assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined] + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale.copy_(new_scale) # type: ignore[union-attr] + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [ + found_inf.to(device="cpu", non_blocking=True) + for state in self._per_optimizer_states.values() + for found_inf in state["found_inf_per_device"].values() + ] + + assert len(found_infs) > 0, "No inf checks were recorded prior to update." + + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + to_device = _scale.device + _scale = _scale.to("cpu") + _growth_tracker = _growth_tracker.to("cpu") + + core._amp_update_scale_( + _scale, + _growth_tracker, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval, + ) + + _scale = _scale.to(to_device) + _growth_tracker = _growth_tracker.to(to_device) + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + +def gradscaler_init(): + torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_ + torch.xpu.amp.GradScaler.unscale_ = unscale_ + torch.xpu.amp.GradScaler.update = update + return torch.xpu.amp.GradScaler diff --git a/modeling/ipex/hijacks.py b/modeling/ipex/hijacks.py index cf0f2233..88827837 100644 --- a/modeling/ipex/hijacks.py +++ b/modeling/ipex/hijacks.py @@ -92,6 +92,7 @@ def ipex_autocast(*args, **kwargs): else: return original_autocast(*args, **kwargs) +#Embedding BF16 original_torch_cat = torch.cat def torch_cat(tensor, *args, **kwargs): if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): @@ -99,6 +100,7 @@ def torch_cat(tensor, *args, **kwargs): else: return original_torch_cat(tensor, *args, **kwargs) +#Latent antialias: original_interpolate = torch.nn.functional.interpolate def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments if antialias or align_corners is not None: @@ -118,19 +120,28 @@ def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name else: return original_linalg_solve(A, B, *args, **kwargs) +def is_cuda(self): + return self.device.type == 'xpu' + def ipex_hijacks(): + CondFunc('torch.tensor', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) CondFunc('torch.Tensor.to', lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) CondFunc('torch.Tensor.cuda', lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) + CondFunc('torch.UntypedStorage.__init__', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.UntypedStorage.cuda', + lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), + lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) CondFunc('torch.empty', lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.load', - lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs), - lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location)) CondFunc('torch.randn', lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), lambda orig_func, *args, device=None, **kwargs: check_device(device)) @@ -140,17 +151,19 @@ def ipex_hijacks(): CondFunc('torch.zeros', lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.tensor', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) CondFunc('torch.linspace', lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.load', + lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: + orig_func(orig_func, f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs), + lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location)) CondFunc('torch.Generator', - lambda orig_func, device=None: torch.xpu.Generator(device), + lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)), lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") + #TiledVAE and ControlNet: CondFunc('torch.batch_norm', lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, weight if weight is not None else torch.ones(input.size()[1], device=input.device), @@ -166,17 +179,23 @@ def ipex_hijacks(): CondFunc('torch.nn.modules.GroupNorm.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + #Training: CondFunc('torch.nn.modules.linear.Linear.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) CondFunc('torch.nn.modules.conv.Conv2d.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + #BF16: CondFunc('torch.nn.functional.layer_norm', lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: weight is not None and input.dtype != weight.data.dtype) + #SwinIR BF16: + CondFunc('torch.nn.functional.pad', + lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16), + lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16) #Diffusers Float64 (ARC GPUs doesn't support double or Float64): if not torch.xpu.has_fp64_dtype(): @@ -185,6 +204,7 @@ def ipex_hijacks(): lambda orig_func, ndarray: ndarray.dtype == float) #Broken functions when torch.cuda.is_available is True: + #Pin Memory: CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__', lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs), lambda orig_func, *args, **kwargs: True) @@ -195,5 +215,6 @@ def ipex_hijacks(): torch.autocast = ipex_autocast torch.cat = torch_cat torch.linalg.solve = linalg_solve + torch.UntypedStorage.is_cuda = is_cuda torch.nn.functional.interpolate = interpolate torch.backends.cuda.sdp_kernel = return_null_context