Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)

Commit: IPEX Torch 2.1
@@ -25,10 +25,8 @@ dependencies:
   - ffmpeg
   - pip:
     - --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-    - torch==2.0.1a0; sys_platform == 'linux'
-    - torch==2.0.0a0; sys_platform == 'win32'
-    - intel_extension_for_pytorch==2.0.110+xpu; sys_platform == 'linux'
-    - intel_extension_for_pytorch==2.0.110+gitba7f6c1; sys_platform == 'win32'
+    - torch==2.1.0a0
+    - intel-extension-for-pytorch==2.1.10+xpu
     - openvino
     - onnxruntime-openvino
     - flask-cloudflared==0.0.10
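As a quick sanity check after rebuilding the environment, a minimal sketch (illustrative only, assuming the pinned wheels above installed cleanly) of confirming that the 2.1 stack loads and exposes an XPU device:

import torch
import intel_extension_for_pytorch as ipex  # importing IPEX registers the 'xpu' backend

print(torch.__version__)          # expected to report a 2.1.0a0 build
print(ipex.__version__)           # expected to report 2.1.10+xpu
print(torch.xpu.is_available())   # True only with an Intel GPU and working drivers
print(torch.xpu.device_count())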
@@ -4,7 +4,6 @@ import contextlib
 import torch
 import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
 from .hijacks import ipex_hijacks
-from .attention import attention_init
 
 # pylint: disable=protected-access, missing-function-docstring, line-too-long
 
@@ -157,15 +156,9 @@ def ipex_init(): # pylint: disable=too-many-statements
     torch.cuda.get_device_properties.minor = 7
     torch.cuda.ipc_collect = lambda *args, **kwargs: None
     torch.cuda.utilization = lambda *args, **kwargs: 0
-    if hasattr(torch.xpu, 'getDeviceIdListForCard'):
-        torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard
-        torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard
-    else:
-        torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card
-        torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card
 
     ipex_hijacks()
-    attention_init()
-    try:
-        from .diffusers import ipex_diffusers
-        ipex_diffusers()
+    if not torch.xpu.has_fp64_dtype():
+        try:
+            from .diffusers import ipex_diffusers
+            ipex_diffusers()
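The diffusers workarounds are now gated on torch.xpu.has_fp64_dtype() instead of being applied unconditionally. A minimal sketch (illustrative, not part of the commit) of querying that flag:

import torch
import intel_extension_for_pytorch as ipex  # pylint: disable=unused-import

if torch.xpu.is_available():
    # Arc/Alchemist GPUs report False here; the FP64 workarounds only apply in that case.
    print("native FP64 support:", torch.xpu.has_fp64_dtype())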
@@ -4,10 +4,7 @@ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
 # pylint: disable=protected-access, missing-function-docstring, line-too-long
 
 original_torch_bmm = torch.bmm
-def torch_bmm(input, mat2, *, out=None):
-    if input.dtype != mat2.dtype:
-        mat2 = mat2.to(input.dtype)
-
+def torch_bmm_32_bit(input, mat2, *, out=None):
     # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
     batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
     block_multiply = input.element_size()
@@ -64,7 +61,7 @@ def torch_bmm(input, mat2, *, out=None):
     return hidden_states
 
 original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
-def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
+def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
     # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
     if len(query.shape) == 3:
         batch_size_attention, query_tokens, shape_four = query.shape
@@ -74,11 +71,6 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
         shape_one, batch_size_attention, query_tokens, shape_four = query.shape
         no_shape_one = False
 
-    if query.dtype != key.dtype:
-        key = key.to(dtype=query.dtype)
-    if query.dtype != value.dtype:
-        value = value.to(dtype=query.dtype)
-
     block_multiply = query.element_size()
     slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
     block_size = batch_size_attention * slice_block_size
@@ -155,8 +147,3 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
         query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
     )
     return hidden_states
-
-def attention_init():
-    #ARC GPUs can't allocate more than 4GB to a single block:
-    torch.bmm = torch_bmm
-    torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
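The renamed torch_bmm_32_bit and scaled_dot_product_attention_32_bit helpers keep the existing slicing logic: the batched matmul is split so that no single allocation exceeds the roughly 4GB limit of Arc/Alchemist GPUs. A simplified, hypothetical sketch of that idea (not the repo's actual slicing code, which also slices along the token dimension):

import torch

def sliced_bmm(input, mat2, max_block_mb=4000):
    # Estimate the full result size; fall back to a single torch.bmm call when it fits.
    batch, tokens, _ = input.shape
    out_mb = batch * tokens * mat2.shape[2] * input.element_size() / 1024 / 1024
    if out_mb <= max_block_mb:
        return torch.bmm(input, mat2)
    # Otherwise split along the batch dimension and concatenate the partial results.
    splits = int(out_mb // max_block_mb) + 1
    step = max(1, batch // splits)
    return torch.cat(
        [torch.bmm(input[i:i + step], mat2[i:i + step]) for i in range(0, batch, step)],
        dim=0,
    )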
@@ -1,6 +1,6 @@
 import torch
 import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
-import diffusers #0.21.1 # pylint: disable=import-error
+import diffusers #0.24.0 # pylint: disable=import-error
 from diffusers.models.attention_processor import Attention
 
 # pylint: disable=protected-access, missing-function-docstring, line-too-long
@@ -5,6 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
 
 # pylint: disable=protected-access, missing-function-docstring, line-too-long
 
+device_supports_fp64 = torch.xpu.has_fp64_dtype()
 OptState = ipex.cpu.autocast._grad_scaler.OptState
 _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
 _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
@@ -96,6 +97,9 @@ def unscale_(self, optimizer):
 
     # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
    assert self._scale is not None
-    inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
+    if device_supports_fp64:
+        inv_scale = self._scale.double().reciprocal().float()
+    else:
+        inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
     found_inf = torch.full(
         (1,), 0.0, dtype=torch.float32, device=self._scale.device
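Both branches above produce the same inverse scale; the new path simply skips the CPU round-trip when the device can do the double-precision reciprocal natively. An illustrative comparison (runs on CPU, where both paths are available):

import torch

scale = torch.tensor([65536.0], dtype=torch.float32)

# device_supports_fp64 == True: reciprocal computed in FP64 directly on the tensor's device
inv_native = scale.double().reciprocal().float()

# device_supports_fp64 == False: round-trip through the CPU for the FP64 step
inv_roundtrip = scale.to("cpu").double().reciprocal().float().to(scale.device)

assert torch.equal(inv_native, inv_roundtrip)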
@@ -120,6 +120,32 @@ def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
     else:
         return original_linalg_solve(A, B, *args, **kwargs)
 
+if torch.xpu.has_fp64_dtype():
+    original_torch_bmm = torch.bmm
+    original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
+else:
+    # 64 bit attention workarounds for Alchemist:
+    try:
+        from .attention import torch_bmm_32_bit as original_torch_bmm
+        from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
+    except Exception: # pylint: disable=broad-exception-caught
+        original_torch_bmm = torch.bmm
+        original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
+
+# dtype errors:
+def torch_bmm(input, mat2, *, out=None):
+    if input.dtype != mat2.dtype:
+        mat2 = mat2.to(input.dtype)
+    return original_torch_bmm(input, mat2, out=out)
+
+def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
+    if query.dtype != key.dtype:
+        key = key.to(dtype=query.dtype)
+    if query.dtype != value.dtype:
+        value = value.to(dtype=query.dtype)
+    return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
+
+@property
 def is_cuda(self):
     return self.device.type == 'xpu'
 
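The new module-level torch_bmm and scaled_dot_product_attention wrappers only coerce mismatched dtypes before delegating to whichever original (or 32-bit sliced) implementation was selected above. A small self-contained sketch of the same coercion pattern in isolation:

import torch

original_torch_bmm = torch.bmm

def torch_bmm(input, mat2, *, out=None):
    # Promote the right operand so both sides share the left operand's dtype.
    if input.dtype != mat2.dtype:
        mat2 = mat2.to(input.dtype)
    return original_torch_bmm(input, mat2, out=out)

a = torch.randn(2, 4, 8, dtype=torch.float32)
b = torch.randn(2, 8, 3, dtype=torch.float64)
print(torch_bmm(a, b).dtype)  # torch.float32; plain torch.bmm(a, b) raises a dtype mismatch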
@@ -158,7 +184,7 @@ def ipex_hijacks():
         lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs:
         orig_func(orig_func, f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs),
         lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location))
-    CondFunc('torch.Generator',
-        lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)),
-        lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
+    if hasattr(torch.xpu, "Generator"):
+        CondFunc('torch.Generator',
+            lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)),
+            lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
@@ -197,7 +223,7 @@ def ipex_hijacks():
         lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16),
         lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16)
 
-    #Diffusers Float64 (ARC GPUs doesn't support double or Float64):
+    # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
     if not torch.xpu.has_fp64_dtype():
         CondFunc('torch.from_numpy',
             lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
@@ -210,11 +236,16 @@ def ipex_hijacks():
         lambda orig_func, *args, **kwargs: True)
 
     # Functions that make compile mad with CondFunc:
-    torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
     torch.nn.DataParallel = DummyDataParallel
+    torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
+
     torch.autocast = ipex_autocast
-    torch.cat = torch_cat
-    torch.linalg.solve = linalg_solve
-    torch.UntypedStorage.is_cuda = is_cuda
-    torch.nn.functional.interpolate = interpolate
     torch.backends.cuda.sdp_kernel = return_null_context
+    torch.UntypedStorage.is_cuda = is_cuda
+
+    torch.nn.functional.interpolate = interpolate
+    torch.linalg.solve = linalg_solve
+
+    torch.bmm = torch_bmm
+    torch.cat = torch_cat
+    torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
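For readers unfamiliar with the CondFunc helper used throughout ipex_hijacks(), here is a rough, hypothetical sketch of the pattern (the repo's real implementation differs in details): the original attribute is saved and replaced with a callable that only routes to the substitute when the condition matches, which is also why the functions listed under "make compile mad with CondFunc" are assigned directly instead.

import importlib

class CondFuncSketch:
    def __init__(self, orig_path, sub_func, cond_func):
        module_path, attr = orig_path.rsplit('.', 1)
        self._owner = importlib.import_module(module_path)
        self._orig = getattr(self._owner, attr)
        self._sub = sub_func
        self._cond = cond_func
        setattr(self._owner, attr, self)  # patch the attribute in place

    def __call__(self, *args, **kwargs):
        # Run the substitute only when the condition callback approves the call.
        if self._cond is None or self._cond(self._orig, *args, **kwargs):
            return self._sub(self._orig, *args, **kwargs)
        return self._orig(*args, **kwargs)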