GPTQ support for IPEX

commit 806fc4b8ad
parent 9e49d7bc3c
Author: Disty0
Date: 2023-09-19 17:09:51 +03:00
2 changed files with 54 additions and 19 deletions

File 1 of 2: conda environment spec

@@ -24,18 +24,22 @@ dependencies:
   - psutil
   - pip:
     - -f https://developer.intel.com/ipex-whl-stable-xpu
-    - torch==2.0.1a0
-    - intel_extension_for_pytorch==2.0.110+xpu
+    - torch==2.0.1a0; sys_platform == 'linux'
+    - torch==2.0.0a0; sys_platform == 'win32'
+    - intel_extension_for_pytorch==2.0.110+xpu; sys_platform == 'linux'
+    - intel_extension_for_pytorch==2.0.110+gitba7f6c1; sys_platform == 'win32'
+    - intel-extension-for-transformers
     - flask-cloudflared==0.0.10
     - flask-ngrok
     - flask-cors
     - lupa==1.10
     - transformers[sentencepiece]==4.33.1
     - huggingface_hub==0.16.4
-    - optimum[onnxruntime]==1.12.0
+    - optimum[openvino,nncf,neural-compressor]==1.12.0
     - safetensors==0.3.3
-    - accelerate==0.20.3
+    - accelerate==0.21.0
     - git+https://github.com/VE-FORBRYDERNE/mkultra
+    - flask-session
     - ansi2html
     - flask_compress
     - ijson
@@ -43,7 +47,14 @@ dependencies:
     - pydub
     - diffusers
     - git+https://github.com/0cc4m/hf_bleeding_edge/
+    - https://github.com/0cc4m/GPTQ-for-LLaMa/releases/download/0.0.6/gptq_koboldai-0.0.6-cp38-cp38-linux_x86_64.whl; sys_platform == 'linux'
+    - https://github.com/0cc4m/GPTQ-for-LLaMa/releases/download/0.0.6/gptq_koboldai-0.0.6-cp38-cp38-win_amd64.whl; sys_platform == 'win32'
+    - https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.4.1/auto_gptq-0.4.1+cu118-cp38-cp38-linux_x86_64.whl; sys_platform == 'linux'
+    - https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.4.1/auto_gptq-0.4.1+cu118-cp38-cp38-win_amd64.whl; sys_platform == 'win32'
     - einops
     - peft==0.3.0
+    - scipy
+    - https://github.com/0cc4m/exllama/releases/download/0.0.7/exllama-0.0.7-cp38-cp38-linux_x86_64.whl; sys_platform == 'linux'
+    - https://github.com/0cc4m/exllama/releases/download/0.0.7/exllama-0.0.7-cp38-cp38-win_amd64.whl; sys_platform == 'win32'
     - windows-curses; sys_platform == 'win32'
     - pynvml
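
The "; sys_platform == ..." suffixes used throughout this file are PEP 508 environment markers: pip evaluates them at install time, so each platform only pulls the wheels built for it. A minimal sketch of that evaluation, assuming the standalone packaging library is installed (pip vendors the same logic internally):

# Sketch: how a PEP 508 marker like the ones above evaluates on the
# current machine. Marker and evaluate() come from the packaging library.
from packaging.markers import Marker

linux_only = Marker("sys_platform == 'linux'")
windows_only = Marker("sys_platform == 'win32'")

# On a Linux host this prints True then False; on Windows, the reverse.
print(linux_only.evaluate())
print(windows_only.evaluate())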

File 2 of 2: IPEX attention hijack module

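The hunks below modify a hijack of torch.nn.functional.scaled_dot_product_attention: the stock kernel is kept under original_scaled_dot_product_attention, a slicing wrapper is installed over the public name, and every slice is delegated back to the saved original. A minimal sketch of that monkeypatch pattern (sliced_sdpa is an illustrative name, not from the commit):

import torch

# Keep a handle to the stock kernel before replacing it.
original_sdpa = torch.nn.functional.scaled_dot_product_attention

def sliced_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    # The real wrapper slices oversized inputs here; each slice is
    # delegated to the saved original kernel.
    return original_sdpa(query, key, value, attn_mask=attn_mask,
                         dropout_p=dropout_p, is_causal=is_causal)

# Install the wrapper; all downstream callers now go through it.
torch.nn.functional.scaled_dot_product_attention = sliced_sdpa
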
@@ -64,8 +64,14 @@ def torch_bmm(input, mat2, *, out=None):
 original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
 def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
     #ARC GPUs can't allocate more than 4GB to a single block, Slice it:
-    shape_one, batch_size_attention, query_tokens, shape_four = query.shape
-    block_multiply = 2.4 if query.dtype == torch.float32 else 1.2
+    if len(query.shape) == 3:
+        batch_size_attention, query_tokens, shape_four = query.shape
+        shape_one = 1
+        no_shape_one = True
+    else:
+        shape_one, batch_size_attention, query_tokens, shape_four = query.shape
+        no_shape_one = False
+    block_multiply = 3.6 if query.dtype == torch.float32 else 1.8
     block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB
     split_slice_size = batch_size_attention
     if block_size >= 4000:
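
This hunk does two things: it accepts 3-dimensional query tensors by treating the missing leading dimension as 1, and it retunes the per-dtype block_multiply constants. A standalone sketch of the resulting size heuristic (needs_slicing is an illustrative name; the constants are copied from the new code):

import torch

def needs_slicing(query: torch.Tensor) -> bool:
    # Mirror of the shape handling above: 3-D inputs get an implicit
    # leading dimension of 1.
    if len(query.shape) == 3:
        batch_size_attention, query_tokens, shape_four = query.shape
        shape_one = 1
    else:
        shape_one, batch_size_attention, query_tokens, shape_four = query.shape
    block_multiply = 3.6 if query.dtype == torch.float32 else 1.8
    block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply  # MB
    # ARC GPUs can't allocate more than 4GB to a single block.
    return block_size >= 4000

print(needs_slicing(torch.empty(8, 512, 64)))       # estimate ~922 MB: False
print(needs_slicing(torch.empty(2, 8, 1024, 128)))  # estimate ~7373 MB: True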
@@ -101,6 +107,15 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
                 for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
                     start_idx_2 = i2 * split_2_slice_size
                     end_idx_2 = (i2 + 1) * split_2_slice_size
-                    hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
-                        query[:, start_idx:end_idx, start_idx_2:end_idx_2],
-                        key[:, start_idx:end_idx, start_idx_2:end_idx_2],
+                    if no_shape_one:
+                        hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
+                            query[start_idx:end_idx, start_idx_2:end_idx_2],
+                            key[start_idx:end_idx, start_idx_2:end_idx_2],
+                            value[start_idx:end_idx, start_idx_2:end_idx_2],
+                            attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
+                            dropout_p=dropout_p, is_causal=is_causal
+                        )
+                    else:
+                        hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
+                            query[:, start_idx:end_idx, start_idx_2:end_idx_2],
+                            key[:, start_idx:end_idx, start_idx_2:end_idx_2],
@@ -108,6 +123,15 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
-                        attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
-                        dropout_p=dropout_p, is_causal=is_causal
-                    )
+                            attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
+                            dropout_p=dropout_p, is_causal=is_causal
+                        )
             else:
-                hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention(
-                    query[:, start_idx:end_idx],
+                if no_shape_one:
+                    hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
+                        query[start_idx:end_idx],
+                        key[start_idx:end_idx],
+                        value[start_idx:end_idx],
+                        attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
+                        dropout_p=dropout_p, is_causal=is_causal
+                    )
+                else:
+                    hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention(
+                        query[:, start_idx:end_idx],
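
For intuition on why the slicing in these hunks is safe: scaled dot-product attention is computed independently along the leading batch and head dimensions, so running the kernel on slices and writing the results back reproduces the unsliced output (up to floating-point kernel variation). A quick check of that property, illustrative and not part of the commit:

import torch

q, k, v = (torch.randn(4, 8, 64, 32) for _ in range(3))
full = torch.nn.functional.scaled_dot_product_attention(q, k, v)

sliced = torch.zeros_like(full)
split_slice_size = 4  # slice dim 1, as the loops above slice batch_size_attention
for i in range(q.shape[1] // split_slice_size):
    start_idx, end_idx = i * split_slice_size, (i + 1) * split_slice_size
    sliced[:, start_idx:end_idx] = torch.nn.functional.scaled_dot_product_attention(
        q[:, start_idx:end_idx], k[:, start_idx:end_idx], v[:, start_idx:end_idx]
    )

print(torch.allclose(full, sliced, atol=1e-6))  # expected: True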