diff --git a/aiserver.py b/aiserver.py
index c20a3f95..8e79692d 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -23,6 +23,9 @@ from ansi2html import Ansi2HTMLConverter
 
 logging.getLogger("urllib3").setLevel(logging.ERROR)
 
+import attention_bias
+attention_bias.do_patches()
+
 from os import path, getcwd
 import time
 import re
@@ -4946,7 +4949,27 @@ def calcsubmit(txt):
     if(koboldai_vars.model != "InferKit"):
         #subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
         subtxt, min, max, found_entries = koboldai_vars.calc_ai_text(submitted_text=txt)
+
+        if koboldai_vars.memory_attn_bias > 1:
+            offset = 0
+            bounds = None
+            for c in koboldai_vars.context:
+                length = len(tokenizer.encode(c["text"]))
+                if c["type"] == "memory":
+                    bounds = [offset, offset + length]
+                    break
+                offset += length
+
+            print(f"Memory bounds: {bounds}")
+            assert bounds
+
+            bias = [1] * bounds[0]
+            bias += [koboldai_vars.memory_attn_bias] * (bounds[1] - bounds[0])
+
+            attention_bias.attention_bias = torch.Tensor(bias).to(breakmodel.primary_device)
+            print(f"Bias by {koboldai_vars.memory_attn_bias} -- {attention_bias.attention_bias}")
         generate(subtxt, min, max, found_entries)
+        attention_bias.attention_bias = None
 
 # For InferKit web API
diff --git a/attention_bias.py b/attention_bias.py
new file mode 100644
index 00000000..f0df11c0
--- /dev/null
+++ b/attention_bias.py
@@ -0,0 +1,173 @@
+# All OPT code is under the following license:
+# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import nn
+
+import transformers
+from typing import Optional, Tuple
+
+from typing import Optional, Tuple
+
+has_harassed_user = False
+attention_bias = None
+
+# Attention patch for attention bias
+
+def OPTAttention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    key_value_states: Optional[torch.Tensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    layer_head_mask: Optional[torch.Tensor] = None,
+    output_attentions: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    global attention_bias
+    global has_harassed_user
+    """Input shape: Batch x Time x Channel"""
+
+    # if key_value_states are provided this layer is used as a cross-attention layer
+    # for the decoder
+    is_cross_attention = key_value_states is not None
+
+    bsz, tgt_len, _ = hidden_states.size()
+
+    # get query proj
+    query_states = self.q_proj(hidden_states) * self.scaling
+    # get key, value proj
+    if is_cross_attention and past_key_value is not None:
+        # reuse k,v, cross_attentions
+        key_states = past_key_value[0]
+        value_states = past_key_value[1]
+    elif is_cross_attention:
+        # cross_attentions
+        key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+        value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+    elif past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+    else:
+        # self_attention
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+    if self.is_decoder:
+        # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+        # Further calls to cross_attention layer can then reuse all cross-attention
+        # key/value_states (first "if" case)
+        # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+        # all previous decoder key/value_states. Further calls to uni-directional self-attention
+        # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+        # if encoder bi-directional self-attention `past_key_value` is always `None`
+        past_key_value = (key_states, value_states)
+
+    proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+    query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+    key_states = key_states.view(*proj_shape)
+    value_states = value_states.view(*proj_shape)
+
+    src_len = key_states.size(1)
+    attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+    if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+        raise ValueError(
+            f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+            f" {attn_weights.size()}"
+        )
+
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            )
+        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+    dtype_attn_weights = attn_weights.dtype
+
+    # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
+    if dtype_attn_weights == torch.float16:
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(dtype_attn_weights)
+    else:
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+
+    if layer_head_mask is not None:
+        if layer_head_mask.size() != (self.num_heads,):
+            raise ValueError(
+                f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                f" {layer_head_mask.size()}"
+            )
+        attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+    if output_attentions:
+        # this operation is a bit awkward, but it's required to
+        # make sure that attn_weights keeps its gradient.
+        # In order to do so, attn_weights have to be reshaped
+        # twice and have to be reused in the following
+        attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+        attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+    else:
+        attn_weights_reshaped = None
+
+
+    attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+    # BEGIN ATTENTION BIAS
+    if attention_bias is not None and self.is_decoder:
+        if not has_harassed_user:
+            print("[attention] Applying attention bias (will not show this again!!!!!!!!!!!)")
+            has_harassed_user = True
+
+        extra_tokens = attn_probs.shape[2] - attention_bias.shape[0]
+
+        # Obviously we add tokens during generation
+        att = nn.functional.pad(attention_bias, pad=(0, extra_tokens), value=1)
+
+        # May be slow, not sure
+        att = att.to(attn_probs.device)
+
+        attn_probs[:, :, :] *= att
+        attn_probs = nn.functional.normalize(attn_probs, p=1, dim=2)
+    # END ATTENTION BIAS
+
+    attn_output = torch.bmm(attn_probs, value_states)
+
+    if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+        raise ValueError(
+            f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+            f" {attn_output.size()}"
+        )
+
+    attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+    attn_output = attn_output.transpose(1, 2)
+
+    # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+    # partitioned across GPUs when using tensor-parallelism.
+    attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+    attn_output = self.out_proj(attn_output)
+
+
+    return attn_output, attn_weights_reshaped, past_key_value
+
+# Patch patch patch!
+def do_patches():
+    transformers.models.opt.modeling_opt.OPTAttention.forward = OPTAttention_forward
\ No newline at end of file
diff --git a/koboldai_settings.py b/koboldai_settings.py
index 436bbc05..d611be81 100644
--- a/koboldai_settings.py
+++ b/koboldai_settings.py
@@ -713,6 +713,9 @@ class story_settings(settings):
         #must be at bottom
         self.no_save = False #Temporary disable save (doesn't save with the file)
+
+        # bias experiment
+        self.memory_attn_bias = 1
 
     def save_story(self):
         if not self.no_save:
diff --git a/static/koboldai.js b/static/koboldai.js
index af9f9837..7f7658c6 100644
--- a/static/koboldai.js
+++ b/static/koboldai.js
@@ -2297,6 +2297,7 @@ function sync_to_server(item) {
 			value = item.checked;
 		} else {
 			value = item.value;
+			if (item.classList.contains("sync_as_float")) value = parseFloat(value);
 		}
 	} else {
 		value = item.textContent;
diff --git a/templates/story flyout.html b/templates/story flyout.html
index 313aeb47..50af1382 100644
--- a/templates/story flyout.html
+++ b/templates/story flyout.html
@@ -12,6 +12,21 @@ Important information the AI should always keep in mind.
+
+
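
Notes on the calcsubmit() hunk: the bias handed to the patch is a per-token vector built from the token span that the memory occupies in the assembled context. Below is a standalone sketch of that construction; `context`, `token_length()`, and `MEMORY_ATTN_BIAS` are stand-ins for `koboldai_vars.context`, `len(tokenizer.encode(...))`, and the UI slider value, so the names and numbers are illustrative only.

    import torch

    MEMORY_ATTN_BIAS = 3.0  # stand-in for koboldai_vars.memory_attn_bias

    context = [
        {"type": "world_info", "text": "The moon base is abandoned."},
        {"type": "memory",     "text": "You are a lone astronaut."},
        {"type": "action",     "text": "You open the airlock."},
    ]

    def token_length(text: str) -> int:
        # Stand-in for len(tokenizer.encode(text)).
        return len(text.split())

    # Walk the context until the memory chunk is found, tracking its token span.
    offset = 0
    bounds = None
    for chunk in context:
        length = token_length(chunk["text"])
        if chunk["type"] == "memory":
            bounds = (offset, offset + length)
            break
        offset += length

    assert bounds is not None, "no memory chunk in context"

    # 1.0 for every token before the memory, the bias for each memory token;
    # tokens after the memory are covered by the padding inside the patch.
    bias = [1.0] * bounds[0] + [MEMORY_ATTN_BIAS] * (bounds[1] - bounds[0])
    attention_bias = torch.tensor(bias)
    print(bounds, attention_bias)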
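
The BEGIN/END ATTENTION BIAS block in the patched forward rescales the already-softmaxed attention probabilities column-wise and then L1-normalizes each row so it sums to 1 again. A minimal toy run of that step, using made-up shapes and values (only torch is required), shows the memory columns gaining attention mass:

    import torch
    import torch.nn.functional as F

    # One attention row over 6 context tokens, shaped like the patch's
    # attn_probs: (bsz * num_heads, tgt_len, src_len). Already softmaxed.
    attn_probs = torch.tensor([[[0.10, 0.15, 0.20, 0.25, 0.20, 0.10]]])

    # Suppose tokens 1-2 are the memory span and the bias slider is set to 3.
    attention_bias = torch.tensor([1.0, 3.0, 3.0, 1.0])

    # During generation src_len grows past the biased prompt, so pad the bias
    # with 1s on the right, as nn.functional.pad(..., value=1) does in the patch.
    extra_tokens = attn_probs.shape[2] - attention_bias.shape[0]
    att = F.pad(attention_bias, pad=(0, extra_tokens), value=1)

    # Scale the memory columns up, then L1-normalize each row so it is a
    # probability distribution again.
    biased = F.normalize(attn_probs * att, p=1, dim=2)

    print(att)                # tensor([1., 3., 3., 1., 1., 1.])
    print(biased)             # memory tokens now hold a larger share of the row
    print(biased.sum(dim=2))  # still 1.0 per row

In this example the memory tokens' combined share rises from 0.35 to roughly 0.62, which is the intended effect of the memory_attn_bias setting.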