mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
IPEX fix SDPA and linalg solve
This commit is contained in:
@@ -4,10 +4,12 @@ import contextlib
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
from .hijacks import ipex_hijacks
|
||||
from .attention import attention_init
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
def ipex_init(): # pylint: disable=too-many-statements
|
||||
try:
|
||||
#Replace cuda with xpu:
|
||||
torch.cuda.current_device = torch.xpu.current_device
|
||||
torch.cuda.current_stream = torch.xpu.current_stream
|
||||
@@ -147,17 +149,22 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch._utils._get_available_device_type = lambda: "xpu"
|
||||
torch.has_cuda = True
|
||||
torch.cuda.has_half = True
|
||||
torch.cuda.is_bf16_supported = True
|
||||
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
|
||||
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
|
||||
torch.version.cuda = "11.7"
|
||||
torch.cuda.get_device_capability = lambda: [11,7]
|
||||
torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
|
||||
torch.cuda.get_device_properties.major = 11
|
||||
torch.cuda.get_device_properties.minor = 7
|
||||
torch.cuda.ipc_collect = lambda: None
|
||||
torch.cuda.utilization = lambda: 0
|
||||
torch.cuda.ipc_collect = lambda *args, **kwargs: None
|
||||
torch.cuda.utilization = lambda *args, **kwargs: 0
|
||||
|
||||
ipex_hijacks()
|
||||
attention_init()
|
||||
try:
|
||||
from .diffusers import ipex_diffusers # pylint: disable=import-outside-toplevel, import-error
|
||||
from .diffusers import ipex_diffusers
|
||||
ipex_diffusers()
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
pass
|
||||
except Exception as e:
|
||||
return False, e
|
||||
return True, None
|
||||
|
128
modeling/ipex/attention.py
Normal file
128
modeling/ipex/attention.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
original_torch_bmm = torch.bmm
|
||||
def torch_bmm(input, mat2, *, out=None):
|
||||
if input.dtype != mat2.dtype:
|
||||
mat2 = mat2.to(input.dtype)
|
||||
|
||||
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||
batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
|
||||
block_multiply = 2.4 if input.dtype == torch.float32 else 1.2
|
||||
block_size = (batch_size_attention * input_tokens * mat2_shape) / 1024 * block_multiply #MB
|
||||
split_slice_size = batch_size_attention
|
||||
if block_size >= 4000:
|
||||
do_split = True
|
||||
#Find something divisible with the input_tokens
|
||||
while ((split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply) > 4000:
|
||||
split_slice_size = split_slice_size // 2
|
||||
if split_slice_size <= 1:
|
||||
split_slice_size = 1
|
||||
break
|
||||
else:
|
||||
do_split = False
|
||||
|
||||
split_block_size = (split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply #MB
|
||||
split_2_slice_size = input_tokens
|
||||
if split_block_size >= 4000:
|
||||
do_split_2 = True
|
||||
#Find something divisible with the input_tokens
|
||||
while ((split_slice_size * split_2_slice_size * mat2_shape) / 1024 * block_multiply) > 4000:
|
||||
split_2_slice_size = split_2_slice_size // 2
|
||||
if split_2_slice_size <= 1:
|
||||
split_2_slice_size = 1
|
||||
break
|
||||
else:
|
||||
do_split_2 = False
|
||||
|
||||
if do_split:
|
||||
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
start_idx = i * split_slice_size
|
||||
end_idx = (i + 1) * split_slice_size
|
||||
if do_split_2:
|
||||
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_2 = i2 * split_2_slice_size
|
||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
|
||||
input[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
out=out
|
||||
)
|
||||
else:
|
||||
hidden_states[start_idx:end_idx] = original_torch_bmm(
|
||||
input[start_idx:end_idx],
|
||||
mat2[start_idx:end_idx],
|
||||
out=out
|
||||
)
|
||||
else:
|
||||
return original_torch_bmm(input, mat2, out=out)
|
||||
return hidden_states
|
||||
|
||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
||||
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||
shape_one, batch_size_attention, query_tokens, shape_four = query.shape
|
||||
block_multiply = 2.4 if query.dtype == torch.float32 else 1.2
|
||||
block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB
|
||||
split_slice_size = batch_size_attention
|
||||
if block_size >= 4000:
|
||||
do_split = True
|
||||
#Find something divisible with the shape_one
|
||||
while ((shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply) > 4000:
|
||||
split_slice_size = split_slice_size // 2
|
||||
if split_slice_size <= 1:
|
||||
split_slice_size = 1
|
||||
break
|
||||
else:
|
||||
do_split = False
|
||||
|
||||
split_block_size = (shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply #MB
|
||||
split_2_slice_size = query_tokens
|
||||
if split_block_size >= 4000:
|
||||
do_split_2 = True
|
||||
#Find something divisible with the batch_size_attention
|
||||
while ((shape_one * split_slice_size * split_2_slice_size * shape_four) / 1024 * block_multiply) > 4000:
|
||||
split_2_slice_size = split_2_slice_size // 2
|
||||
if split_2_slice_size <= 1:
|
||||
split_2_slice_size = 1
|
||||
break
|
||||
else:
|
||||
do_split_2 = False
|
||||
|
||||
if do_split:
|
||||
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
start_idx = i * split_slice_size
|
||||
end_idx = (i + 1) * split_slice_size
|
||||
if do_split_2:
|
||||
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_2 = i2 * split_2_slice_size
|
||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||
hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
|
||||
query[:, start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
key[:, start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
value[:, start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
else:
|
||||
hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention(
|
||||
query[:, start_idx:end_idx],
|
||||
key[:, start_idx:end_idx],
|
||||
value[:, start_idx:end_idx],
|
||||
attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
else:
|
||||
return original_scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def attention_init():
|
||||
#ARC GPUs can't allocate more than 4GB to a single block:
|
||||
torch.bmm = torch_bmm
|
||||
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
@@ -1,12 +1,9 @@
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
import torch.nn.functional as F # pylint: disable=ungrouped-imports
|
||||
import diffusers #0.20.2 # pylint: disable=import-error
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
Attention = diffusers.models.attention_processor.Attention
|
||||
|
||||
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
||||
r"""
|
||||
Processor for implementing sliced attention.
|
||||
@@ -20,7 +17,7 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
||||
def __init__(self, slice_size):
|
||||
self.slice_size = slice_size
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
||||
def __call__(self, attn: diffusers.models.attention_processor.Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
||||
residual = hidden_states
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
@@ -116,147 +113,6 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
||||
|
||||
return hidden_states
|
||||
|
||||
class AttnProcessor2_0: # pylint: disable=too-few-public-methods, invalid-name
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__( # pylint: disable=too-many-arguments, too-many-statements, too-many-locals, too-many-branches
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||
shape_one, batch_size_attention, query_tokens, shape_four = query.shape
|
||||
block_multiply = 2.4 if query.dtype == torch.float32 else 1.2
|
||||
block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB
|
||||
split_slice_size = batch_size_attention
|
||||
if block_size >= 4000:
|
||||
do_split = True
|
||||
#Find something divisible with the shape_one
|
||||
while ((shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply) > 4000:
|
||||
split_slice_size = split_slice_size // 2
|
||||
if split_slice_size <= 1:
|
||||
split_slice_size = 1
|
||||
break
|
||||
else:
|
||||
do_split = False
|
||||
|
||||
split_block_size = (shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply #MB
|
||||
split_2_slice_size = query_tokens
|
||||
if split_block_size >= 4000:
|
||||
do_split_2 = True
|
||||
#Find something divisible with the batch_size_attention
|
||||
while ((shape_one * split_slice_size * split_2_slice_size * shape_four) / 1024 * block_multiply) > 4000:
|
||||
split_2_slice_size = split_2_slice_size // 2
|
||||
if split_2_slice_size <= 1:
|
||||
split_2_slice_size = 1
|
||||
break
|
||||
else:
|
||||
do_split_2 = False
|
||||
|
||||
if do_split:
|
||||
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
start_idx = i * split_slice_size
|
||||
end_idx = (i + 1) * split_slice_size
|
||||
if do_split_2:
|
||||
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_2 = i2 * split_2_slice_size
|
||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||
|
||||
query_slice = query[:, start_idx:end_idx, start_idx_2:end_idx_2]
|
||||
key_slice = key[:, start_idx:end_idx, start_idx_2:end_idx_2]
|
||||
attn_mask_slice = attention_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
||||
|
||||
attn_slice = F.scaled_dot_product_attention(
|
||||
query_slice, key_slice, value[:, start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
||||
else:
|
||||
query_slice = query[:, start_idx:end_idx]
|
||||
key_slice = key[:, start_idx:end_idx]
|
||||
attn_mask_slice = attention_mask[:, start_idx:end_idx] if attention_mask is not None else None
|
||||
|
||||
attn_slice = F.scaled_dot_product_attention(
|
||||
query_slice, key_slice, value[:, start_idx:end_idx],
|
||||
attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states[:, start_idx:end_idx] = attn_slice
|
||||
else:
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
def ipex_diffusers():
|
||||
#ARC GPUs can't allocate more than 4GB to a single block:
|
||||
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
|
||||
diffusers.models.attention_processor.AttnProcessor2_0 = AttnProcessor2_0
|
||||
|
@@ -34,7 +34,7 @@ class CondFunc: # pylint: disable=missing-class-docstring
|
||||
|
||||
_utils = torch.utils.data._utils
|
||||
def _shutdown_workers(self):
|
||||
if _utils is None or _utils.python_exit_status is True or _utils.python_exit_status is None:
|
||||
if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None:
|
||||
return
|
||||
if hasattr(self, "_shutdown") and not self._shutdown:
|
||||
self._shutdown = True
|
||||
@@ -50,13 +50,13 @@ def _shutdown_workers(self):
|
||||
if self._persistent_workers or self._workers_status[worker_id]:
|
||||
self._mark_worker_as_unavailable(worker_id, shutdown=True)
|
||||
for w in self._workers: # pylint: disable=invalid-name
|
||||
w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
|
||||
w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
|
||||
for q in self._index_queues: # pylint: disable=invalid-name
|
||||
q.cancel_join_thread()
|
||||
q.close()
|
||||
finally:
|
||||
if self._worker_pids_set:
|
||||
_utils.signal_handling._remove_worker_pids(id(self))
|
||||
torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
|
||||
self._worker_pids_set = False
|
||||
for w in self._workers: # pylint: disable=invalid-name
|
||||
if w.is_alive():
|
||||
@@ -75,7 +75,7 @@ def check_device(device):
|
||||
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
|
||||
|
||||
def return_xpu(device):
|
||||
return f"xpu:{device[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
|
||||
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
|
||||
|
||||
def ipex_no_cuda(orig_func, *args, **kwargs):
|
||||
torch.cuda.is_available = lambda: False
|
||||
@@ -84,7 +84,7 @@ def ipex_no_cuda(orig_func, *args, **kwargs):
|
||||
|
||||
original_autocast = torch.autocast
|
||||
def ipex_autocast(*args, **kwargs):
|
||||
if args[0] == "cuda" or args[0] == "xpu":
|
||||
if len(args) > 0 and (args[0] == "cuda" or args[0] == "xpu"):
|
||||
if "dtype" in kwargs:
|
||||
return original_autocast("xpu", *args[1:], **kwargs)
|
||||
else:
|
||||
@@ -114,9 +114,9 @@ original_linalg_solve = torch.linalg.solve
|
||||
def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
|
||||
if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
|
||||
return_device = A.device
|
||||
original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
|
||||
return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
|
||||
else:
|
||||
original_linalg_solve(A, B, *args, **kwargs)
|
||||
return original_linalg_solve(A, B, *args, **kwargs)
|
||||
|
||||
def ipex_hijacks():
|
||||
CondFunc('torch.Tensor.to',
|
||||
@@ -169,9 +169,9 @@ def ipex_hijacks():
|
||||
CondFunc('torch.nn.modules.linear.Linear.forward',
|
||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||
CondFunc('torch.bmm',
|
||||
lambda orig_func, input, mat2, *args, **kwargs: orig_func(input, mat2.to(input.dtype), *args, **kwargs),
|
||||
lambda orig_func, input, mat2, *args, **kwargs: input.dtype != mat2.dtype)
|
||||
CondFunc('torch.nn.modules.conv.Conv2d.forward',
|
||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||
CondFunc('torch.nn.functional.layer_norm',
|
||||
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
|
||||
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
|
||||
|
Reference in New Issue
Block a user