Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
IPEX fix SDPA and linalg solve
@@ -4,160 +4,167 @@ import contextlib
 import torch
 import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
 from .hijacks import ipex_hijacks
+from .attention import attention_init
 
 # pylint: disable=protected-access, missing-function-docstring, line-too-long
 
 def ipex_init(): # pylint: disable=too-many-statements
+    try:
     #Replace cuda with xpu:
     torch.cuda.current_device = torch.xpu.current_device
     torch.cuda.current_stream = torch.xpu.current_stream
     torch.cuda.device = torch.xpu.device
     torch.cuda.device_count = torch.xpu.device_count
     torch.cuda.device_of = torch.xpu.device_of
     torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard
     torch.cuda.get_device_name = torch.xpu.get_device_name
     torch.cuda.get_device_properties = torch.xpu.get_device_properties
     torch.cuda.init = torch.xpu.init
     torch.cuda.is_available = torch.xpu.is_available
     torch.cuda.is_initialized = torch.xpu.is_initialized
     torch.cuda.is_current_stream_capturing = lambda: False
     torch.cuda.set_device = torch.xpu.set_device
     torch.cuda.stream = torch.xpu.stream
     torch.cuda.synchronize = torch.xpu.synchronize
     torch.cuda.Event = torch.xpu.Event
     torch.cuda.Stream = torch.xpu.Stream
     torch.cuda.FloatTensor = torch.xpu.FloatTensor
     torch.Tensor.cuda = torch.Tensor.xpu
     torch.Tensor.is_cuda = torch.Tensor.is_xpu
     torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
     torch.cuda._initialized = torch.xpu.lazy_init._initialized
     torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
     torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
     torch.cuda._tls = torch.xpu.lazy_init._tls
     torch.cuda.threading = torch.xpu.lazy_init.threading
     torch.cuda.traceback = torch.xpu.lazy_init.traceback
     torch.cuda.Optional = torch.xpu.Optional
     torch.cuda.__cached__ = torch.xpu.__cached__
     torch.cuda.__loader__ = torch.xpu.__loader__
     torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
     torch.cuda.Tuple = torch.xpu.Tuple
     torch.cuda.streams = torch.xpu.streams
     torch.cuda._lazy_new = torch.xpu._lazy_new
     torch.cuda.FloatStorage = torch.xpu.FloatStorage
     torch.cuda.Any = torch.xpu.Any
     torch.cuda.__doc__ = torch.xpu.__doc__
     torch.cuda.default_generators = torch.xpu.default_generators
     torch.cuda.HalfTensor = torch.xpu.HalfTensor
     torch.cuda._get_device_index = torch.xpu._get_device_index
     torch.cuda.__path__ = torch.xpu.__path__
     torch.cuda.Device = torch.xpu.Device
     torch.cuda.IntTensor = torch.xpu.IntTensor
     torch.cuda.ByteStorage = torch.xpu.ByteStorage
     torch.cuda.set_stream = torch.xpu.set_stream
     torch.cuda.BoolStorage = torch.xpu.BoolStorage
     torch.cuda.os = torch.xpu.os
     torch.cuda.torch = torch.xpu.torch
     torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
     torch.cuda.Union = torch.xpu.Union
     torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
     torch.cuda.ShortTensor = torch.xpu.ShortTensor
     torch.cuda.LongTensor = torch.xpu.LongTensor
     torch.cuda.IntStorage = torch.xpu.IntStorage
     torch.cuda.LongStorage = torch.xpu.LongStorage
     torch.cuda.__annotations__ = torch.xpu.__annotations__
     torch.cuda.__package__ = torch.xpu.__package__
     torch.cuda.__builtins__ = torch.xpu.__builtins__
     torch.cuda.CharTensor = torch.xpu.CharTensor
     torch.cuda.List = torch.xpu.List
     torch.cuda._lazy_init = torch.xpu._lazy_init
     torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
     torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
     torch.cuda.ByteTensor = torch.xpu.ByteTensor
     torch.cuda.StreamContext = torch.xpu.StreamContext
     torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
     torch.cuda.ShortStorage = torch.xpu.ShortStorage
     torch.cuda._lazy_call = torch.xpu._lazy_call
     torch.cuda.HalfStorage = torch.xpu.HalfStorage
     torch.cuda.random = torch.xpu.random
     torch.cuda._device = torch.xpu._device
     torch.cuda.classproperty = torch.xpu.classproperty
     torch.cuda.__name__ = torch.xpu.__name__
     torch.cuda._device_t = torch.xpu._device_t
     torch.cuda.warnings = torch.xpu.warnings
     torch.cuda.__spec__ = torch.xpu.__spec__
     torch.cuda.BoolTensor = torch.xpu.BoolTensor
     torch.cuda.CharStorage = torch.xpu.CharStorage
     torch.cuda.__file__ = torch.xpu.__file__
     torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
     #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
 
     #Memory:
     torch.cuda.memory = torch.xpu.memory
     if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
         torch.xpu.empty_cache = lambda: None
     torch.cuda.empty_cache = torch.xpu.empty_cache
     torch.cuda.memory_stats = torch.xpu.memory_stats
     torch.cuda.memory_summary = torch.xpu.memory_summary
     torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
     torch.cuda.memory_allocated = torch.xpu.memory_allocated
     torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
     torch.cuda.memory_reserved = torch.xpu.memory_reserved
     torch.cuda.memory_cached = torch.xpu.memory_reserved
     torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
     torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
     torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
     torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
     torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
     torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
     torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
 
     #RNG:
     torch.cuda.get_rng_state = torch.xpu.get_rng_state
     torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
     torch.cuda.set_rng_state = torch.xpu.set_rng_state
     torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
     torch.cuda.manual_seed = torch.xpu.manual_seed
     torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
     torch.cuda.seed = torch.xpu.seed
     torch.cuda.seed_all = torch.xpu.seed_all
     torch.cuda.initial_seed = torch.xpu.initial_seed
 
     #AMP:
     torch.cuda.amp = torch.xpu.amp
     if not hasattr(torch.cuda.amp, "common"):
         torch.cuda.amp.common = contextlib.nullcontext()
         torch.cuda.amp.common.amp_definitely_not_available = lambda: False
     try:
         torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
     except Exception: # pylint: disable=broad-exception-caught
         try:
             from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
             gradscaler_init()
             torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
         except Exception: # pylint: disable=broad-exception-caught
             torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
 
     #C
     torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
     ipex._C._DeviceProperties.major = 2023
     ipex._C._DeviceProperties.minor = 2
 
     #Fix functions with ipex:
     torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_allocated(device)), torch.xpu.get_device_properties(device).total_memory]
     torch._utils._get_available_device_type = lambda: "xpu"
     torch.has_cuda = True
     torch.cuda.has_half = True
-    torch.cuda.is_bf16_supported = True
+    torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
+    torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
     torch.version.cuda = "11.7"
-    torch.cuda.get_device_capability = lambda: [11,7]
+    torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
     torch.cuda.get_device_properties.major = 11
     torch.cuda.get_device_properties.minor = 7
-    torch.cuda.ipc_collect = lambda: None
-    torch.cuda.utilization = lambda: 0
+    torch.cuda.ipc_collect = lambda *args, **kwargs: None
+    torch.cuda.utilization = lambda *args, **kwargs: 0
 
     ipex_hijacks()
+    attention_init()
     try:
-        from .diffusers import ipex_diffusers # pylint: disable=import-outside-toplevel, import-error
+        from .diffusers import ipex_diffusers
         ipex_diffusers()
     except Exception: # pylint: disable=broad-exception-caught
         pass
+    except Exception as e:
+        return False, e
+    return True, None
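Note: ipex_init() aliases the torch.cuda namespace onto torch.xpu so CUDA-only code paths run on Intel GPUs; this commit additionally wires in attention_init() and makes the function report success as a (bool, error) tuple instead of raising. A minimal caller sketch, where the import path and the calling code are assumptions and not part of this diff:

import torch
from modeling.ipex import ipex_init  # assumed package path for the module patched above

ok, err = ipex_init()
if ok:
    # torch.cuda.* now resolves to the torch.xpu.* equivalents set up above
    print(torch.cuda.is_available(), torch.cuda.get_device_name(0))
else:
    print(f"IPEX hijack failed: {err}")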
modeling/ipex/attention.py (new file, 128 additions)
@@ -0,0 +1,128 @@
+import torch
+import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
+
+# pylint: disable=protected-access, missing-function-docstring, line-too-long
+
+original_torch_bmm = torch.bmm
+def torch_bmm(input, mat2, *, out=None):
+    if input.dtype != mat2.dtype:
+        mat2 = mat2.to(input.dtype)
+
+    #ARC GPUs can't allocate more than 4GB to a single block, Slice it:
+    batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
+    block_multiply = 2.4 if input.dtype == torch.float32 else 1.2
+    block_size = (batch_size_attention * input_tokens * mat2_shape) / 1024 * block_multiply #MB
+    split_slice_size = batch_size_attention
+    if block_size >= 4000:
+        do_split = True
+        #Find something divisible with the input_tokens
+        while ((split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply) > 4000:
+            split_slice_size = split_slice_size // 2
+            if split_slice_size <= 1:
+                split_slice_size = 1
+                break
+    else:
+        do_split = False
+
+    split_block_size = (split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply #MB
+    split_2_slice_size = input_tokens
+    if split_block_size >= 4000:
+        do_split_2 = True
+        #Find something divisible with the input_tokens
+        while ((split_slice_size * split_2_slice_size * mat2_shape) / 1024 * block_multiply) > 4000:
+            split_2_slice_size = split_2_slice_size // 2
+            if split_2_slice_size <= 1:
+                split_2_slice_size = 1
+                break
+    else:
+        do_split_2 = False
+
+    if do_split:
+        hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
+        for i in range(batch_size_attention // split_slice_size):
+            start_idx = i * split_slice_size
+            end_idx = (i + 1) * split_slice_size
+            if do_split_2:
+                for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
+                    start_idx_2 = i2 * split_2_slice_size
+                    end_idx_2 = (i2 + 1) * split_2_slice_size
+                    hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
+                        input[start_idx:end_idx, start_idx_2:end_idx_2],
+                        mat2[start_idx:end_idx, start_idx_2:end_idx_2],
+                        out=out
+                    )
+            else:
+                hidden_states[start_idx:end_idx] = original_torch_bmm(
+                    input[start_idx:end_idx],
+                    mat2[start_idx:end_idx],
+                    out=out
+                )
+    else:
+        return original_torch_bmm(input, mat2, out=out)
+    return hidden_states
+
+original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
+def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
+    #ARC GPUs can't allocate more than 4GB to a single block, Slice it:
+    shape_one, batch_size_attention, query_tokens, shape_four = query.shape
+    block_multiply = 2.4 if query.dtype == torch.float32 else 1.2
+    block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB
+    split_slice_size = batch_size_attention
+    if block_size >= 4000:
+        do_split = True
+        #Find something divisible with the shape_one
+        while ((shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply) > 4000:
+            split_slice_size = split_slice_size // 2
+            if split_slice_size <= 1:
+                split_slice_size = 1
+                break
+    else:
+        do_split = False
+
+    split_block_size = (shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply #MB
+    split_2_slice_size = query_tokens
+    if split_block_size >= 4000:
+        do_split_2 = True
+        #Find something divisible with the batch_size_attention
+        while ((shape_one * split_slice_size * split_2_slice_size * shape_four) / 1024 * block_multiply) > 4000:
+            split_2_slice_size = split_2_slice_size // 2
+            if split_2_slice_size <= 1:
+                split_2_slice_size = 1
+                break
+    else:
+        do_split_2 = False
+
+    if do_split:
+        hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
+        for i in range(batch_size_attention // split_slice_size):
+            start_idx = i * split_slice_size
+            end_idx = (i + 1) * split_slice_size
+            if do_split_2:
+                for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
+                    start_idx_2 = i2 * split_2_slice_size
+                    end_idx_2 = (i2 + 1) * split_2_slice_size
+                    hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
+                        query[:, start_idx:end_idx, start_idx_2:end_idx_2],
+                        key[:, start_idx:end_idx, start_idx_2:end_idx_2],
+                        value[:, start_idx:end_idx, start_idx_2:end_idx_2],
+                        attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
+                        dropout_p=dropout_p, is_causal=is_causal
+                    )
+            else:
+                hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention(
+                    query[:, start_idx:end_idx],
+                    key[:, start_idx:end_idx],
+                    value[:, start_idx:end_idx],
+                    attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask,
+                    dropout_p=dropout_p, is_causal=is_causal
+                )
+    else:
+        return original_scaled_dot_product_attention(
+            query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
+        )
+    return hidden_states
+
+def attention_init():
+    #ARC GPUs can't allocate more than 4GB to a single block:
+    torch.bmm = torch_bmm
+    torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
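Note: both torch_bmm and the SDPA wrapper above work around the ~4GB single-allocation limit on Arc GPUs by estimating the block size in MB and halving the slice along the batch (and, if still too large, the token) dimension until the estimate drops below 4000 MB. The heuristic in isolation, as a rough standalone sketch (the function name and parameters here are illustrative, not part of the diff):

def pick_slice_size(batch, tokens, last_dim, dtype_is_fp32=True):
    block_multiply = 2.4 if dtype_is_fp32 else 1.2
    slice_size = batch
    # halve the batch slice until the estimated block fits under ~4000 MB
    while (slice_size * tokens * last_dim) / 1024 * block_multiply > 4000 and slice_size > 1:
        slice_size //= 2
    return max(slice_size, 1)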
@@ -1,12 +1,9 @@
 import torch
 import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
-import torch.nn.functional as F # pylint: disable=ungrouped-imports
 import diffusers #0.20.2 # pylint: disable=import-error
 
 # pylint: disable=protected-access, missing-function-docstring, line-too-long
 
-Attention = diffusers.models.attention_processor.Attention
-
 class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
     r"""
     Processor for implementing sliced attention.
@@ -20,7 +17,7 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
     def __init__(self, slice_size):
         self.slice_size = slice_size
 
-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches
+    def __call__(self, attn: diffusers.models.attention_processor.Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches
         residual = hidden_states
 
         input_ndim = hidden_states.ndim
@@ -116,147 +113,6 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
 
         return hidden_states
 
-
-class AttnProcessor2_0: # pylint: disable=too-few-public-methods, invalid-name
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def __call__( # pylint: disable=too-many-arguments, too-many-statements, too-many-locals, too-many-branches
-        self,
-        attn: Attention,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-    ):
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        #ARC GPUs can't allocate more than 4GB to a single block, Slice it:
-        shape_one, batch_size_attention, query_tokens, shape_four = query.shape
-        block_multiply = 2.4 if query.dtype == torch.float32 else 1.2
-        block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB
-        split_slice_size = batch_size_attention
-        if block_size >= 4000:
-            do_split = True
-            #Find something divisible with the shape_one
-            while ((shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply) > 4000:
-                split_slice_size = split_slice_size // 2
-                if split_slice_size <= 1:
-                    split_slice_size = 1
-                    break
-        else:
-            do_split = False
-
-        split_block_size = (shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply #MB
-        split_2_slice_size = query_tokens
-        if split_block_size >= 4000:
-            do_split_2 = True
-            #Find something divisible with the batch_size_attention
-            while ((shape_one * split_slice_size * split_2_slice_size * shape_four) / 1024 * block_multiply) > 4000:
-                split_2_slice_size = split_2_slice_size // 2
-                if split_2_slice_size <= 1:
-                    split_2_slice_size = 1
-                    break
-        else:
-            do_split_2 = False
-
-        if do_split:
-            hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
-            for i in range(batch_size_attention // split_slice_size):
-                start_idx = i * split_slice_size
-                end_idx = (i + 1) * split_slice_size
-                if do_split_2:
-                    for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
-                        start_idx_2 = i2 * split_2_slice_size
-                        end_idx_2 = (i2 + 1) * split_2_slice_size
-
-                        query_slice = query[:, start_idx:end_idx, start_idx_2:end_idx_2]
-                        key_slice = key[:, start_idx:end_idx, start_idx_2:end_idx_2]
-                        attn_mask_slice = attention_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
-
-                        attn_slice = F.scaled_dot_product_attention(
-                            query_slice, key_slice, value[:, start_idx:end_idx, start_idx_2:end_idx_2],
-                            attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False
-                        )
-                        hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
-                else:
-                    query_slice = query[:, start_idx:end_idx]
-                    key_slice = key[:, start_idx:end_idx]
-                    attn_mask_slice = attention_mask[:, start_idx:end_idx] if attention_mask is not None else None
-
-                    attn_slice = F.scaled_dot_product_attention(
-                        query_slice, key_slice, value[:, start_idx:end_idx],
-                        attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False
-                    )
-                    hidden_states[:, start_idx:end_idx] = attn_slice
-        else:
-            hidden_states = F.scaled_dot_product_attention(
-                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-            )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-
 def ipex_diffusers():
     #ARC GPUs can't allocate more than 4GB to a single block:
     diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
-    diffusers.models.attention_processor.AttnProcessor2_0 = AttnProcessor2_0
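Note: the custom AttnProcessor2_0 override appears to be dropped in favor of the global SDPA hijack installed by attention_init(), since diffusers' stock AttnProcessor2_0 now routes through the sliced torch.nn.functional.scaled_dot_product_attention anyway; only SlicedAttnProcessor keeps a dedicated override. A small sanity-check sketch (module path and usage are assumptions):

import torch.nn.functional as F
from modeling.ipex.attention import attention_init, scaled_dot_product_attention  # assumed path

attention_init()
# stock diffusers attention processors now hit the sliced kernel through F.scaled_dot_product_attention
assert F.scaled_dot_product_attention is scaled_dot_product_attention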
@@ -34,7 +34,7 @@ class CondFunc: # pylint: disable=missing-class-docstring
 
 _utils = torch.utils.data._utils
 def _shutdown_workers(self):
-    if _utils is None or _utils.python_exit_status is True or _utils.python_exit_status is None:
+    if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None:
         return
     if hasattr(self, "_shutdown") and not self._shutdown:
         self._shutdown = True
@@ -50,13 +50,13 @@ def _shutdown_workers(self):
                 if self._persistent_workers or self._workers_status[worker_id]:
                     self._mark_worker_as_unavailable(worker_id, shutdown=True)
             for w in self._workers: # pylint: disable=invalid-name
-                w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
+                w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
             for q in self._index_queues: # pylint: disable=invalid-name
                 q.cancel_join_thread()
                 q.close()
         finally:
             if self._worker_pids_set:
-                _utils.signal_handling._remove_worker_pids(id(self))
+                torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
                 self._worker_pids_set = False
             for w in self._workers: # pylint: disable=invalid-name
                 if w.is_alive():
@@ -75,7 +75,7 @@ def check_device(device):
     return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
 
 def return_xpu(device):
-    return f"xpu:{device[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
+    return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
 
 def ipex_no_cuda(orig_func, *args, **kwargs):
     torch.cuda.is_available = lambda: False
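Note: return_xpu previously took only the last character of a device string, which mangles multi-digit device indices; splitting on ':' keeps the whole index. Illustrative inputs and outputs (hypothetical values):

# old: f"xpu:{device[-1]}"             -> "cuda:12" becomes "xpu:2" (last character only)
# new: f"xpu:{device.split(':')[-1]}"  -> "cuda:12" becomes "xpu:12"
print(return_xpu("cuda:12"), return_xpu(0), return_xpu("cuda"))  # "xpu:12", "xpu:0", "xpu"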
@@ -84,7 +84,7 @@ def ipex_no_cuda(orig_func, *args, **kwargs):
 
 original_autocast = torch.autocast
 def ipex_autocast(*args, **kwargs):
-    if args[0] == "cuda" or args[0] == "xpu":
+    if len(args) > 0 and (args[0] == "cuda" or args[0] == "xpu"):
         if "dtype" in kwargs:
             return original_autocast("xpu", *args[1:], **kwargs)
         else:
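Note: the added len(args) guard matters because torch.autocast is often entered with device_type passed as a keyword argument, which leaves args empty and made the old args[0] lookup raise IndexError. Illustrative use once the hijack is installed (assuming torch.autocast has been replaced by ipex_autocast):

with torch.autocast(device_type="cuda", dtype=torch.float16):
    pass  # args is empty here; with the guard the call presumably falls through to the original autocast instead of crashing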
@@ -114,9 +114,9 @@ original_linalg_solve = torch.linalg.solve
 def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
     if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
         return_device = A.device
-        original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
+        return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
     else:
-        original_linalg_solve(A, B, *args, **kwargs)
+        return original_linalg_solve(A, B, *args, **kwargs)
 
 def ipex_hijacks():
     CondFunc('torch.Tensor.to',
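Note: this is the "linalg solve" half of the commit title; the wrapper moved the tensors to CPU and solved there, but never returned the result, so every call through the hijack produced None. A tiny illustration with hypothetical values:

import torch
A, B = torch.eye(2), torch.ones(2)
x = torch.linalg.solve(A, B)  # with the missing return, x was None; now it is tensor([1., 1.])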
@@ -169,9 +169,9 @@ def ipex_hijacks():
     CondFunc('torch.nn.modules.linear.Linear.forward',
         lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
         lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
-    CondFunc('torch.bmm',
-        lambda orig_func, input, mat2, *args, **kwargs: orig_func(input, mat2.to(input.dtype), *args, **kwargs),
-        lambda orig_func, input, mat2, *args, **kwargs: input.dtype != mat2.dtype)
+    CondFunc('torch.nn.modules.conv.Conv2d.forward',
+        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
+        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
     CondFunc('torch.nn.functional.layer_norm',
         lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
             orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
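Note: the torch.bmm dtype-matching hijack is removed here, presumably because the sliced torch_bmm installed by attention_init() now owns torch.bmm, and a Conv2d.forward dtype cast takes its slot instead. CondFunc's implementation is not shown in this hunk; conceptually it patches a dotted name with a wrapper that only rewrites the call when its predicate fires, roughly like this simplified sketch (not the actual class):

def cond_patch(owner, name, sub_func, cond_func):
    orig = getattr(owner, name)
    def wrapper(*args, **kwargs):
        # apply the substitute only when the condition matches, otherwise call the original
        if cond_func(orig, *args, **kwargs):
            return sub_func(orig, *args, **kwargs)
        return orig(*args, **kwargs)
    setattr(owner, name, wrapper)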