diff --git a/README.md b/README.md index 789b78d1..853c9575 100644 --- a/README.md +++ b/README.md @@ -128,8 +128,11 @@ The easiest way for Windows users is to use the [offline installer](https://sour ### Installing KoboldAI on Linux using the KoboldAI Runtime (Easiest) 1. Clone the URL of this Github repository (For example git clone [https://github.com/koboldai/koboldai-client](https://github.com/koboldai/koboldai-client) ) -2. AMD user? Make sure ROCm is installed if you want GPU support. Is yours not compatible with ROCm? Follow the usual instructions. -3. Run play.sh or if your AMD GPU supports ROCm use play-rocm.sh +2. AMD user? Make sure ROCm is installed if you want GPU support. Is yours not compatible with ROCm? Follow the usual instructions. + Intel ARC user? Make sure OneAPI is installed if you want GPU support. +3. Run play.sh if you use an Nvidia GPU or you want to use CPU only + Run play-rocm.sh if you use an AMD GPU supported by ROCm + Run play-ipex.sh if you use an Intel ARC GPU KoboldAI will now automatically configure its dependencies and start up, everything is contained in its own conda runtime so we will not clutter your system. The files will be located in the runtime subfolder. If at any point you wish to force a reinstallation of the runtime you can do so with the install\_requirements.sh file. While you can run this manually it is not neccesary. @@ -148,6 +151,10 @@ If you would like to manually install KoboldAI you will need some python/conda p AMD GPU's have terrible compute support, this will currently not work on Windows and will only work for a select few Linux GPU's. [You can find a list of the compatible GPU's here](https://github.com/RadeonOpenCompute/ROCm#Hardware-and-Software-Support). Any GPU that is not listed is guaranteed not to work with KoboldAI and we will not be able to provide proper support on GPU's that are not compatible with the versions of ROCm we require. Make sure to first install ROCm on your Linux system using a guide for your distribution, after that you can follow the usual linux instructions above. +### Intel ARC GPU's (Linux or WSL) + +Make sure to first install OneAPI on your Linux or WSL system using a guide for your distribution, after that you can follow the usual linux instructions above. + ### Troubleshooting There are multiple things that can go wrong with the way Python handles its dependencies, unfortunately we do not have direct step by step solutions for every scenario but there are a few common solutions you can try. diff --git a/aiserver.py b/aiserver.py index 34e2334a..1c3bbaed 100644 --- a/aiserver.py +++ b/aiserver.py @@ -74,6 +74,13 @@ from utils import debounce import utils import koboldai_settings import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from modeling.ipex import ipex_init + ipex_init() +except Exception: + pass from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForTokenClassification import transformers import ipaddress @@ -930,7 +937,7 @@ tags = [ api_version = None # This gets set automatically so don't change this value api_v1 = KoboldAPISpec( - version="1.2.4", + version="1.2.5", prefixes=["/api/v1", "/api/latest"], tags=tags, ) @@ -8146,6 +8153,8 @@ def permutation_validator(lst: list): return True class GenerationInputSchema(SamplerSettingsSchema): + class Meta: + unknown = EXCLUDE # Doing it on this level is not a deliberate design choice on our part, it doesn't work nested... 
- Henk prompt: str = fields.String(required=True, metadata={"description": "This is the submission."}) use_memory: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the memory from the KoboldAI GUI when generating text."}) use_story: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the story from the KoboldAI GUI when generating text."}) @@ -8153,7 +8162,7 @@ class GenerationInputSchema(SamplerSettingsSchema): use_world_info: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the world info from the KoboldAI GUI when generating text."}) use_userscripts: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the userscripts from the KoboldAI GUI when generating text."}) soft_prompt: Optional[str] = fields.String(metadata={"description": "Soft prompt to use when generating. If set to the empty string or any other string containing no non-whitespace characters, uses no soft prompt."}, validate=[soft_prompt_validator, validate.Regexp(r"^[^/\\]*$")]) - max_length: int = fields.Integer(validate=validate.Range(min=1, max=512), metadata={"description": "Number of tokens to generate."}) + max_length: int = fields.Integer(validate=validate.Range(min=1), metadata={"description": "Number of tokens to generate."}) max_context_length: int = fields.Integer(validate=validate.Range(min=1), metadata={"description": "Maximum number of tokens to send to the model."}) n: int = fields.Integer(validate=validate.Range(min=1, max=5), metadata={"description": "Number of outputs to generate."}) disable_output_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, all output formatting options default to `false` instead of the value in the KoboldAI GUI."}) @@ -8168,7 +8177,7 @@ class GenerationInputSchema(SamplerSettingsSchema): sampler_order: Optional[List[int]] = fields.List(fields.Integer(), validate=[validate.Length(min=6), permutation_validator], metadata={"description": "Sampler order to be used. If N is the length of this array, then N must be greater than or equal to 6 and the array must be a permutation of the first N non-negative integers."}) sampler_seed: Optional[int] = fields.Integer(validate=validate.Range(min=0, max=2**64 - 1), metadata={"description": "RNG seed to use for sampling. If not specified, the global RNG will be used."}) sampler_full_determinism: Optional[bool] = fields.Boolean(metadata={"description": "If enabled, the generated text will always be the same as long as you use the same RNG seed, input and settings. If disabled, only the *sequence* of generated texts that you get when repeatedly generating text will be the same given the same RNG seed, input and settings."}) - stop_sequence: Optional[List[str]] = fields.List(fields.String(),metadata={"description": "An array of string sequences where the API will stop generating further tokens. The returned text WILL contain the stop sequence."}, validate=[validate.Length(max=10)]) + stop_sequence: Optional[List[str]] = fields.List(fields.String(),metadata={"description": "An array of string sequences where the API will stop generating further tokens. 
The returned text WILL contain the stop sequence."}) class GenerationResultSchema(KoboldSchema): diff --git a/api_example.py b/api_example.py new file mode 100644 index 00000000..aa35f883 --- /dev/null +++ b/api_example.py @@ -0,0 +1,57 @@ +import requests + +user = "User:" +bot = "Bot:" +ENDPOINT = "http://localhost:5000/api" +conversation_history = [] # using a list to update conversation history is more memory efficient than constantly updating a string + +def get_prompt(user_msg): + return { + "prompt": f"{user_msg}", + "use_story": "False", # Use the story from the KoboldAI UI, can be managed using other API calls (See /api for the documentation) + "use_memory": "False", # Use the memnory from the KoboldAI UI, can be managed using other API calls (See /api for the documentation) + "use_authors_note": "False", # Use the authors notes from the KoboldAI UI, can be managed using other API calls (See /api for the documentation) + "use_world_info": "False", # Use the World Info from the KoboldAI UI, can be managed using other API calls (See /api for the documentation) + "max_context_length": 2048, # How much of the prompt will we submit to the AI generator? (Prevents AI / memory overloading) + "max_length": 100, # How long should the response be? + "rep_pen": 1.1, # Prevent the AI from repeating itself + "rep_pen_range": 2048, # The range to which to apply the previous + "rep_pen_slope": 0.7, # This number determains the strength of the repetition penalty over time + "temperature": 0.5, # How random should the AI be? In a low value we pick the most probable token, high values are a dice roll + "tfs": 0.97, # Tail free sampling, https://www.trentonbricken.com/Tail-Free-Sampling/ + "top_a": 0.0, # Top A sampling , https://github.com/BlinkDL/RWKV-LM/tree/4cb363e5aa31978d801a47bc89d28e927ab6912e#the-top-a-sampling-method + "top_k": 0, # Keep the X most probable tokens + "top_p": 0.9, # Top P sampling / Nucleus Sampling, https://arxiv.org/pdf/1904.09751.pdf + "typical": 1.0, # Typical Sampling, https://arxiv.org/pdf/2202.00666.pdf + "sampler_order": [6,0,1,3,4,2,5], # Order to apply the samplers, our default in this script is already the optimal one. KoboldAI Lite contains an easy list of what the + "stop_sequence": [f"{user}"], # When should the AI stop generating? In this example we stop when it tries to speak on behalf of the user. + #"sampler_seed": 1337, # Use specific seed for text generation? This helps with consistency across tests. 
+ "singleline": "False", # Only return a response that fits on a single line, this can help with chatbots but also makes them less verbose + "sampler_full_determinism": "False", # Always return the same result for the same query, best used with a static seed + "frmttriminc": "True", # Trim incomplete sentences, prevents sentences that are unfinished but can interfere with coding and other non english sentences + "frmtrmblln": "False", #Remove blank lines + "quiet": "False" # Don't print what you are doing in the KoboldAI console, helps with user privacy + } + +while True: + try: + user_message = input(f"{user} ") + + if len(user_message.strip()) < 1: + print(f"{bot}Please provide a valid input.") + continue + + fullmsg = f"{conversation_history[-1] if conversation_history else ''}{user} {user_message}\n{bot}" # Add all of conversation history if it exists and add User and Bot names + prompt = get_prompt(fullmsg) # Process prompt into KoboldAI API format + response = requests.post(f"{ENDPOINT}/v1/generate", json=prompt) # Send prompt to API + if response.status_code == 200: + results = response.json()['results'] # Set results as JSON response + text = results[0]['text'] # inside results, look in first group for section labeled 'text' + response_text = text.split('\n')[0].replace(" ", " ") # Optional, keep only the text before a new line, and replace double spaces with normal ones + conversation_history.append(f"{fullmsg}{response_text}\n") # Add the response to the end of your conversation history + else: + print(response) + print(f"{bot} {response_text}") + + except Exception as e: + print(f"An error occurred: {e}") \ No newline at end of file diff --git a/environments/ipex.yml b/environments/ipex.yml new file mode 100644 index 00000000..bd00cd80 --- /dev/null +++ b/environments/ipex.yml @@ -0,0 +1,50 @@ +name: koboldai +channels: + - conda-forge + - defaults +dependencies: + - colorama + - flask=2.2.3 + - flask-socketio=5.3.2 + - flask-session=0.4.0 + - python-socketio=5.7.2 + - python=3.8.* + - eventlet=0.33.3 + - dnspython=2.2.1 + - markdown + - bleach=4.1.0 + - pip + - git=2.35.1 + - sentencepiece + - protobuf + - marshmallow>=3.13 + - apispec-webframeworks + - loguru + - termcolor + - Pillow + - psutil + - pip: + - -f https://developer.intel.com/ipex-whl-stable-xpu + - torch==2.0.1a0 + - intel_extension_for_pytorch==2.0.110+xpu + - flask-cloudflared==0.0.10 + - flask-ngrok + - flask-cors + - lupa==1.10 + - transformers==4.32.1 + - huggingface_hub==0.16.4 + - optimum==1.12.0 + - safetensors==0.3.3 + - accelerate==0.20.3 + - git+https://github.com/VE-FORBRYDERNE/mkultra + - ansi2html + - flask_compress + - ijson + - ftfy + - pydub + - diffusers + - git+https://github.com/0cc4m/hf_bleeding_edge/ + - einops + - peft==0.3.0 + - windows-curses; sys_platform == 'win32' + - pynvml \ No newline at end of file diff --git a/install_requirements.sh b/install_requirements.sh index 6e37c7e9..31fe5709 100755 --- a/install_requirements.sh +++ b/install_requirements.sh @@ -15,4 +15,11 @@ bin/micromamba create -f environments/rocm.yml -r runtime -n koboldai-rocm -y bin/micromamba create -f environments/rocm.yml -r runtime -n koboldai-rocm -y exit fi -echo Please specify either CUDA or ROCM +if [[ $1 = "ipex" || $1 = "IPEX" ]]; then +wget -qO- https://micromamba.snakepit.net/api/micromamba/linux-64/latest | tar -xvj bin/micromamba +bin/micromamba create -f environments/ipex.yml -r runtime -n koboldai-ipex -y +# Weird micromamba bug causes it to fail the first time, running it twice just to be safe, the 
second time is much faster +bin/micromamba create -f environments/ipex.yml -r runtime -n koboldai-ipex -y +exit +fi +echo Please specify either CUDA or ROCM or IPEX diff --git a/modeling/ipex/__init__.py b/modeling/ipex/__init__.py new file mode 100644 index 00000000..3503533b --- /dev/null +++ b/modeling/ipex/__init__.py @@ -0,0 +1,165 @@ +import os +import sys +import contextlib +import torch +import intel_extension_for_pytorch as ipex +from .diffusers import ipex_diffusers +from .hijacks import ipex_hijacks +from logger import logger + +#ControlNet depth_leres++ +class DummyDataParallel(torch.nn.Module): + def __new__(cls, module, device_ids=None, output_device=None, dim=0): + if type(device_ids) is list and len(device_ids) > 1: + logger.info("IPEX backend doesn't support DataParallel on multiple XPU devices") + return module.to("xpu") + +def return_null_context(*args, **kwargs): + return contextlib.nullcontext() + +def ipex_init(): + #Replace cuda with xpu: + torch.cuda.current_device = torch.xpu.current_device + torch.cuda.current_stream = torch.xpu.current_stream + torch.cuda.device = torch.xpu.device + torch.cuda.device_count = torch.xpu.device_count + torch.cuda.device_of = torch.xpu.device_of + torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard + torch.cuda.get_device_name = torch.xpu.get_device_name + torch.cuda.get_device_properties = torch.xpu.get_device_properties + torch.cuda.init = torch.xpu.init + torch.cuda.is_available = torch.xpu.is_available + torch.cuda.is_initialized = torch.xpu.is_initialized + torch.cuda.set_device = torch.xpu.set_device + torch.cuda.stream = torch.xpu.stream + torch.cuda.synchronize = torch.xpu.synchronize + torch.cuda.Event = torch.xpu.Event + torch.cuda.Stream = torch.xpu.Stream + torch.cuda.FloatTensor = torch.xpu.FloatTensor + torch.Tensor.cuda = torch.Tensor.xpu + torch.Tensor.is_cuda = torch.Tensor.is_xpu + torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock + torch.cuda._initialized = torch.xpu.lazy_init._initialized + torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls + torch.cuda._tls = torch.xpu.lazy_init._tls + torch.cuda.threading = torch.xpu.lazy_init.threading + torch.cuda.traceback = torch.xpu.lazy_init.traceback + torch.cuda.Optional = torch.xpu.Optional + torch.cuda.__cached__ = torch.xpu.__cached__ + torch.cuda.__loader__ = torch.xpu.__loader__ + torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage + torch.cuda.Tuple = torch.xpu.Tuple + torch.cuda.streams = torch.xpu.streams + torch.cuda._lazy_new = torch.xpu._lazy_new + torch.cuda.FloatStorage = torch.xpu.FloatStorage + torch.cuda.Any = torch.xpu.Any + torch.cuda.__doc__ = torch.xpu.__doc__ + torch.cuda.default_generators = torch.xpu.default_generators + torch.cuda.HalfTensor = torch.xpu.HalfTensor + torch.cuda._get_device_index = torch.xpu._get_device_index + torch.cuda.__path__ = torch.xpu.__path__ + torch.cuda.Device = torch.xpu.Device + torch.cuda.IntTensor = torch.xpu.IntTensor + torch.cuda.ByteStorage = torch.xpu.ByteStorage + torch.cuda.set_stream = torch.xpu.set_stream + torch.cuda.BoolStorage = torch.xpu.BoolStorage + torch.cuda.os = torch.xpu.os + torch.cuda.torch = torch.xpu.torch + torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage + torch.cuda.Union = torch.xpu.Union + torch.cuda.DoubleTensor = torch.xpu.DoubleTensor + torch.cuda.ShortTensor = torch.xpu.ShortTensor + torch.cuda.LongTensor = torch.xpu.LongTensor + 
torch.cuda.IntStorage = torch.xpu.IntStorage + torch.cuda.LongStorage = torch.xpu.LongStorage + torch.cuda.__annotations__ = torch.xpu.__annotations__ + torch.cuda.__package__ = torch.xpu.__package__ + torch.cuda.__builtins__ = torch.xpu.__builtins__ + torch.cuda.CharTensor = torch.xpu.CharTensor + torch.cuda.List = torch.xpu.List + torch.cuda._lazy_init = torch.xpu._lazy_init + torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor + torch.cuda.DoubleStorage = torch.xpu.DoubleStorage + torch.cuda.ByteTensor = torch.xpu.ByteTensor + torch.cuda.StreamContext = torch.xpu.StreamContext + torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage + torch.cuda.ShortStorage = torch.xpu.ShortStorage + torch.cuda._lazy_call = torch.xpu._lazy_call + torch.cuda.HalfStorage = torch.xpu.HalfStorage + torch.cuda.random = torch.xpu.random + torch.cuda._device = torch.xpu._device + torch.cuda.classproperty = torch.xpu.classproperty + torch.cuda.__name__ = torch.xpu.__name__ + torch.cuda._device_t = torch.xpu._device_t + torch.cuda.warnings = torch.xpu.warnings + torch.cuda.__spec__ = torch.xpu.__spec__ + torch.cuda.BoolTensor = torch.xpu.BoolTensor + torch.cuda.CharStorage = torch.xpu.CharStorage + torch.cuda.__file__ = torch.xpu.__file__ + torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork + #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + + #Memory: + torch.cuda.memory = torch.xpu.memory + if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): + torch.xpu.empty_cache = lambda: None + torch.cuda.empty_cache = torch.xpu.empty_cache + torch.cuda.memory_stats = torch.xpu.memory_stats + torch.cuda.memory_summary = torch.xpu.memory_summary + torch.cuda.memory_snapshot = torch.xpu.memory_snapshot + torch.cuda.memory_allocated = torch.xpu.memory_allocated + torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated + torch.cuda.memory_reserved = torch.xpu.memory_reserved + torch.cuda.memory_cached = torch.xpu.memory_reserved + torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved + torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved + torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats + torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict + torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats + + #RNG: + torch.cuda.get_rng_state = torch.xpu.get_rng_state + torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all + torch.cuda.set_rng_state = torch.xpu.set_rng_state + torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all + torch.cuda.manual_seed = torch.xpu.manual_seed + torch.cuda.manual_seed_all = torch.xpu.manual_seed_all + torch.cuda.seed = torch.xpu.seed + torch.cuda.seed_all = torch.xpu.seed_all + torch.cuda.initial_seed = torch.xpu.initial_seed + + #AMP: + torch.cuda.amp = torch.xpu.amp + if not hasattr(torch.cuda.amp, "common"): + torch.cuda.amp.common = contextlib.nullcontext() + torch.cuda.amp.common.amp_definitely_not_available = lambda: False + try: + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: + torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + + #C + torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream + ipex._C._DeviceProperties.major = 2023 + ipex._C._DeviceProperties.minor = 2 + + #Fix functions with ipex: + 
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_allocated(device)), torch.xpu.get_device_properties(device).total_memory] + torch._utils._get_available_device_type = lambda: "xpu" + torch.has_cuda = True + torch.cuda.has_half = True + torch.cuda.is_bf16_supported = True + torch.version.cuda = "11.7" + torch.cuda.get_device_capability = lambda: [11,7] + torch.cuda.get_device_properties.major = 11 + torch.cuda.get_device_properties.minor = 7 + torch.backends.cuda.sdp_kernel = return_null_context + torch.nn.DataParallel = DummyDataParallel + torch.cuda.ipc_collect = lambda: None + torch.cuda.utilization = lambda: 0 + + ipex_hijacks() + ipex_diffusers() diff --git a/modeling/ipex/diffusers.py b/modeling/ipex/diffusers.py new file mode 100644 index 00000000..e0359219 --- /dev/null +++ b/modeling/ipex/diffusers.py @@ -0,0 +1,260 @@ +import torch +import intel_extension_for_pytorch as ipex +import torch.nn.functional as F +import diffusers #0.20.2 + +Attention = diffusers.models.attention_processor.Attention + +class SlicedAttnProcessor: + r""" + Processor for implementing sliced attention. + + Args: + slice_size (`int`, *optional*): + The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and + `attention_head_dim` must be a multiple of the `slice_size`. + """ + + def __init__(self, slice_size): + self.slice_size = slice_size + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + batch_size_attention, query_tokens, shape_three = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype + ) + + #ARC GPUs can't allocate more than 4GB to a single block, Slice it: + block_multiply = 2.4 if query.dtype == torch.float32 else 1.2 + block_size = (batch_size_attention * query_tokens * shape_three) / 1024 * block_multiply #MB + split_2_slice_size = query_tokens + if block_size >= 4000: + do_split_2 = True + #Find something divisible with the query_tokens + while ((self.slice_size * split_2_slice_size * shape_three) / 1024 * block_multiply) > 4000: + split_2_slice_size = split_2_slice_size // 2 + if split_2_slice_size <= 1: + split_2_slice_size = 1 + break + else: + do_split_2 = False + + for i in range(batch_size_attention // self.slice_size): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + + if do_split_2: + for i2 in 
range(query_tokens // split_2_slice_size): + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice + else: + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + #ARC GPUs can't allocate more than 4GB to a single block, Slice it: + shape_one, batch_size_attention, query_tokens, 
shape_four = query.shape + block_multiply = 2.4 if query.dtype == torch.float32 else 1.2 + block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB + split_slice_size = batch_size_attention + if block_size >= 4000: + do_split = True + #Find something divisible with the shape_one + while ((shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply) > 4000: + split_slice_size = split_slice_size // 2 + if split_slice_size <= 1: + split_slice_size = 1 + break + else: + do_split = False + + split_block_size = (shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply #MB + split_2_slice_size = query_tokens + if split_block_size >= 4000: + do_split_2 = True + #Find something divisible with the batch_size_attention + while ((shape_one * split_slice_size * split_2_slice_size * shape_four) / 1024 * block_multiply) > 4000: + split_2_slice_size = split_2_slice_size // 2 + if split_2_slice_size <= 1: + split_2_slice_size = 1 + break + else: + do_split_2 = False + + if do_split: + hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(query_tokens // split_2_slice_size): + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + + query_slice = query[:, start_idx:end_idx, start_idx_2:end_idx_2] + key_slice = key[:, start_idx:end_idx, start_idx_2:end_idx_2] + attn_mask_slice = attention_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + + attn_slice = F.scaled_dot_product_attention( + query_slice, key_slice, value[:, start_idx:end_idx, start_idx_2:end_idx_2], + attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False + ) + hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice + else: + query_slice = query[:, start_idx:end_idx] + key_slice = key[:, start_idx:end_idx] + attn_mask_slice = attention_mask[:, start_idx:end_idx] if attention_mask is not None else None + + attn_slice = F.scaled_dot_product_attention( + query_slice, key_slice, value[:, start_idx:end_idx], + attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False + ) + hidden_states[:, start_idx:end_idx] = attn_slice + else: + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +def ipex_diffusers(): + #ARC GPUs can't allocate more than 4GB to a single block: + diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor + diffusers.models.attention_processor.AttnProcessor2_0 = AttnProcessor2_0 diff --git a/modeling/ipex/hijacks.py b/modeling/ipex/hijacks.py new file mode 100644 index 00000000..c50547bb --- /dev/null +++ b/modeling/ipex/hijacks.py @@ -0,0 +1,148 @@ +import torch +import intel_extension_for_pytorch as ipex +import 
importlib + +class CondFunc: + def __new__(cls, orig_func, sub_func, cond_func): + self = super(CondFunc, cls).__new__(cls) + if isinstance(orig_func, str): + func_path = orig_func.split('.') + for i in range(len(func_path)-1, -1, -1): + try: + resolved_obj = importlib.import_module('.'.join(func_path[:i])) + break + except ImportError: + pass + for attr_name in func_path[i:-1]: + resolved_obj = getattr(resolved_obj, attr_name) + orig_func = getattr(resolved_obj, func_path[-1]) + setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) + self.__init__(orig_func, sub_func, cond_func) + return lambda *args, **kwargs: self(*args, **kwargs) + def __init__(self, orig_func, sub_func, cond_func): + self.__orig_func = orig_func + self.__sub_func = sub_func + self.__cond_func = cond_func + def __call__(self, *args, **kwargs): + if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): + return self.__sub_func(self.__orig_func, *args, **kwargs) + else: + return self.__orig_func(*args, **kwargs) + +def check_device(device): + return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) + +def return_xpu(device): + return f"xpu:{device[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" + +def ipex_no_cuda(orig_func, *args, **kwargs): + torch.cuda.is_available = lambda: False + orig_func(*args, **kwargs) + torch.cuda.is_available = torch.xpu.is_available + +original_autocast = torch.autocast +def ipex_autocast(*args, **kwargs): + if args[0] == "cuda" or args[0] == "xpu": + if "dtype" in kwargs: + return original_autocast("xpu", *args[1:], **kwargs) + else: + return original_autocast("xpu", *args[1:], dtype=torch.float16, **kwargs) + else: + return original_autocast(*args, **kwargs) + +original_torch_cat = torch.cat +def torch_cat(input, *args, **kwargs): + if len(input) == 3 and (input[0].dtype != input[1].dtype or input[2].dtype != input[1].dtype): + return original_torch_cat([input[0].to(input[1].dtype), input[1], input[2].to(input[1].dtype)], *args, **kwargs) + else: + return original_torch_cat(input, *args, **kwargs) + +original_interpolate = torch.nn.functional.interpolate +def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): + if antialias: + return_device = input.device + return_dtype = input.dtype + return original_interpolate(input.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode, + align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype) + else: + return original_interpolate(input, size=size, scale_factor=scale_factor, mode=mode, + align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias) + +original_linalg_solve = torch.linalg.solve +def linalg_solve(orig_func, A, B, *args, **kwargs): + if A.device != torch.device("cpu") or B.device != torch.device("cpu"): + return_device = A.device + original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device) + else: + original_linalg_solve(A, B, *args, **kwargs) + +def ipex_hijacks(): + CondFunc('torch.Tensor.to', + lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), + lambda orig_func, self, 
device=None, *args, **kwargs: check_device(device)) + CondFunc('torch.Tensor.cuda', + lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), + lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) + CondFunc('torch.empty', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.load', + lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs), + lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location)) + CondFunc('torch.randn', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.ones', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.zeros', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.tensor', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + + CondFunc('torch.Generator', + lambda orig_func, device: torch.xpu.Generator(device), + lambda orig_func, device: device != torch.device("cpu") and device != "cpu") + + CondFunc('torch.batch_norm', + lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, + weight if weight is not None else torch.ones(input.size()[1], device=input.device), + bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), + lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) + CondFunc('torch.instance_norm', + lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, + weight if weight is not None else torch.ones(input.size()[1], device=input.device), + bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), + lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) + + #Functions with dtype errors: + CondFunc('torch.nn.modules.GroupNorm.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.bmm', + lambda orig_func, input, mat2, *args, **kwargs: orig_func(input, mat2.to(input.dtype), *args, **kwargs), + lambda orig_func, input, mat2, *args, **kwargs: input.dtype != mat2.dtype) + CondFunc('torch.nn.functional.layer_norm', + lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: + orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), + lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: + weight is not None and input.dtype != weight.data.dtype) + + #Diffusers Float64 (ARC GPUs doesn't support double or Float64): + if not torch.xpu.has_fp64_dtype(): + CondFunc('torch.from_numpy', + lambda orig_func, ndarray: orig_func(ndarray.astype('float32')), + lambda orig_func, ndarray: ndarray.dtype == float) + + #Broken functions when torch.cuda.is_available 
is True: + CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__', + lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs), + lambda orig_func, *args, **kwargs: True) + + #Functions that make compile mad with CondFunc: + torch.autocast = ipex_autocast + torch.cat = torch_cat + torch.linalg.solve = linalg_solve + torch.nn.functional.interpolate = interpolate diff --git a/play-ipex.sh b/play-ipex.sh new file mode 100755 index 00000000..eb6ecc29 --- /dev/null +++ b/play-ipex.sh @@ -0,0 +1,22 @@ +#!/bin/bash +export PYTHONNOUSERSITE=1 +if [ ! -f "runtime/envs/koboldai-ipex/bin/python" ]; then +./install_requirements.sh ipex +fi + +#Set OneAPI environment if it's not set by the user +if [ ! -x "$(command -v sycl-ls)" ] then + echo "Setting OneAPI environment" + if [[ -z "$ONEAPI_ROOT" ]] then + ONEAPI_ROOT=/opt/intel/oneapi + fi + source $ONEAPI_ROOT/setvars.sh +fi + +export LD_PRELOAD=/usr/lib/libstdc++.so +export NEOReadDebugKeys=1 +export ClDeviceGlobalMemSizeAvailablePercent=100 + +bin/micromamba run -r runtime -n koboldai-ipex python aiserver.py $* \ No newline at end of file diff --git a/static/klite.html b/static/klite.html index 33ba94f0..661e8945 100644
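
With `Meta.unknown = EXCLUDE` added to `GenerationInputSchema`, the 512-token cap removed from `max_length`, and the `Length(max=10)` validator dropped from `stop_sequence`, a generate request can now pass longer outputs, longer stop lists, and extra fields without tripping schema validation. A minimal sketch against a locally running KoboldAI instance (assumed to listen on http://localhost:5000; `some_future_field` is a deliberately invented key used only to show that unknown fields are ignored):

```python
import requests

# Field names match GenerationInputSchema; response shape matches api_example.py.
payload = {
    "prompt": "Once upon a time",
    "max_length": 768,              # values above 512 are accepted after this change
    "max_context_length": 2048,
    "stop_sequence": ["\nUser:"],   # generation stops here; the stop string stays in the output
    "some_future_field": True,      # silently dropped by the schema instead of raising an error
}
r = requests.post("http://localhost:5000/api/v1/generate", json=payload)
r.raise_for_status()
print(r.json()["results"][0]["text"])
```

The API version reported by the server is bumped to 1.2.5 to signal these relaxed validators.
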
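The `try`/`except` block added to `aiserver.py` only activates the XPU path when `intel_extension_for_pytorch` imports cleanly and a device is present; after `ipex_init()` runs, the `torch.cuda.*` names are aliased to their `torch.xpu.*` counterparts and the hijacks redirect `device="cuda"` arguments to XPU. A rough sketch of what downstream CUDA-oriented code then sees, assuming the IPEX wheel from `environments/ipex.yml` is installed, an ARC GPU is visible to the driver, and `modeling/ipex` is importable from the repository root:

```python
import torch

try:
    import intel_extension_for_pytorch as ipex  # noqa: F401
    if torch.xpu.is_available():
        from modeling.ipex import ipex_init
        ipex_init()
except Exception:
    pass  # fall back to stock CUDA/CPU behaviour when IPEX or an XPU is missing

# Code written against torch.cuda keeps working after the aliasing:
if torch.cuda.is_available():                     # resolves to torch.xpu.is_available()
    x = torch.ones(4, device="cuda")              # hijacked torch.ones redirects "cuda" to "xpu"
    print(torch.cuda.get_device_name(), x.device) # should report the Intel GPU and an xpu tensor
else:
    print("No XPU found, running on CPU")
```
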
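The repeated `block_multiply` / 4000 MB checks in `modeling/ipex/diffusers.py` implement one idea: estimate how large an attention block would be and keep halving the slice length until the estimate fits under the roughly 4GB single-allocation limit of ARC GPUs. A standalone sketch of that arithmetic (the function name and the example dimensions are illustrative, not part of the patch):

```python
def plan_attention_slices(batch_size_attention: int, query_tokens: int, head_dim: int,
                          fp32: bool = True, limit_mb: float = 4000.0):
    """Estimate the attention block size in MB and halve the query slice until it fits
    under the ~4GB allocation limit (mirrors SlicedAttnProcessor / AttnProcessor2_0)."""
    block_multiply = 2.4 if fp32 else 1.2  # rough size factors used in the patch
    block_mb = batch_size_attention * query_tokens * head_dim / 1024 * block_multiply
    if block_mb < limit_mb:
        return block_mb, None  # small enough, no slicing required
    slice_size = query_tokens
    while batch_size_attention * slice_size * head_dim / 1024 * block_multiply > limit_mb:
        slice_size //= 2
        if slice_size <= 1:
            slice_size = 1
            break
    return block_mb, slice_size

print(plan_attention_slices(16, 4096, 160))  # -> (24576.0, 512)
```

In this example a float32 block of 16 x 4096 x 160 is estimated at about 24 GB, so the 4096 query tokens are processed in slices of 512 before the attention scores are computed.
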
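The `CondFunc` helper in `modeling/ipex/hijacks.py` is a small conditional monkey-patching utility: it resolves a dotted path, swaps the attribute for a wrapper, and only routes a call through the substitute when the condition callable fires. A self-contained toy version of the same pattern (no torch required; `fakelib` and `CondPatch` are invented names for illustration):

```python
import types

# Stand-in for a library function whose device argument we want to rewrite conditionally.
fakelib = types.SimpleNamespace(make_tensor=lambda shape, device="cpu": (shape, device))

class CondPatch:
    """Simplified CondFunc: call sub_func only when cond_func says the call needs
    rewriting, otherwise fall through to the original function unchanged."""
    def __init__(self, owner, name, sub_func, cond_func):
        self._orig = getattr(owner, name)
        self._sub = sub_func
        self._cond = cond_func
        setattr(owner, name, self)

    def __call__(self, *args, **kwargs):
        if self._cond(self._orig, *args, **kwargs):
            return self._sub(self._orig, *args, **kwargs)
        return self._orig(*args, **kwargs)

# Redirect any request for a "cuda" device to "xpu"; leave everything else untouched.
CondPatch(
    fakelib, "make_tensor",
    sub_func=lambda orig, shape, device="cpu": orig(shape, device="xpu"),
    cond_func=lambda orig, shape, device="cpu": isinstance(device, str) and "cuda" in device,
)

print(fakelib.make_tensor((2, 2), device="cuda"))  # -> ((2, 2), 'xpu')
print(fakelib.make_tensor((2, 2), device="cpu"))   # -> ((2, 2), 'cpu')
```
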