diff --git a/aiserver.py b/aiserver.py
index c20a3f95..8e79692d 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -23,6 +23,9 @@ from ansi2html import Ansi2HTMLConverter
 
 logging.getLogger("urllib3").setLevel(logging.ERROR)
 
+import attention_bias
+attention_bias.do_patches()
+
 from os import path, getcwd
 import time
 import re
@@ -4946,7 +4949,27 @@ def calcsubmit(txt):
     if(koboldai_vars.model != "InferKit"):
         #subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
         subtxt, min, max, found_entries = koboldai_vars.calc_ai_text(submitted_text=txt)
+
+        if koboldai_vars.memory_attn_bias > 1:
+            offset = 0
+            bounds = None
+            for c in koboldai_vars.context:
+                length = len(tokenizer.encode(c["text"]))
+                if c["type"] == "memory":
+                    bounds = [offset, offset + length]
+                    break
+                offset += length
+
+            print(f"Memory bounds: {bounds}")
+            assert bounds
+
+            bias = [1] * bounds[0]
+            bias += [koboldai_vars.memory_attn_bias] * (bounds[1] - bounds[0])
+
+            attention_bias.attention_bias = torch.Tensor(bias).to(breakmodel.primary_device)
+            print(f"Bias by {koboldai_vars.memory_attn_bias} -- {attention_bias.attention_bias}")
         generate(subtxt, min, max, found_entries)
+        attention_bias.attention_bias = None
 
 
     # For InferKit web API
diff --git a/attention_bias.py b/attention_bias.py
new file mode 100644
index 00000000..f0df11c0
--- /dev/null
+++ b/attention_bias.py
@@ -0,0 +1,173 @@
+# All OPT code is under the following license:
+# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import nn
+
+import transformers
+from typing import Optional, Tuple
+
+has_harassed_user = False
+attention_bias = None
+
+# Attention patch for attention bias
+
+def OPTAttention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    key_value_states: Optional[torch.Tensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    layer_head_mask: Optional[torch.Tensor] = None,
+    output_attentions: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    """Input shape: Batch x Time x Channel"""
+    global attention_bias
+    global has_harassed_user
+
+    # if key_value_states are provided this layer is used as a cross-attention layer
+    # for the decoder
+    is_cross_attention = key_value_states is not None
+
+    bsz, tgt_len, _ = hidden_states.size()
+
+    # get query proj
+    query_states = self.q_proj(hidden_states) * self.scaling
+    # get key, value proj
+    if is_cross_attention and past_key_value is not None:
+        # reuse k,v, cross_attentions
+        key_states = past_key_value[0]
+        value_states = past_key_value[1]
+    elif is_cross_attention:
+        # cross_attentions
+        key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+        value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+    elif past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+    else:
+        # self_attention
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+    if self.is_decoder:
+        # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+        # Further calls to cross_attention layer can then reuse all cross-attention
+        # key/value_states (first "if" case)
+        # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+        # all previous decoder key/value_states. Further calls to uni-directional self-attention
+        # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+        # if encoder bi-directional self-attention `past_key_value` is always `None`
+        past_key_value = (key_states, value_states)
+
+    proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+    query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+    key_states = key_states.view(*proj_shape)
+    value_states = value_states.view(*proj_shape)
+
+    src_len = key_states.size(1)
+    attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+    if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+        raise ValueError(
+            f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+            f" {attn_weights.size()}"
+        )
+
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            )
+        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+    dtype_attn_weights = attn_weights.dtype
+
+    # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
+    if dtype_attn_weights == torch.float16:
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(dtype_attn_weights)
+    else:
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+    if layer_head_mask is not None:
+        if layer_head_mask.size() != (self.num_heads,):
+            raise ValueError(
+                f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                f" {layer_head_mask.size()}"
+            )
+        attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+    if output_attentions:
+        # this operation is a bit awkward, but it's required to
+        # make sure that attn_weights keeps its gradient.
+        # In order to do so, attn_weights have to be reshaped
+        # twice and have to be reused in the following
+        attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+        attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+    else:
+        attn_weights_reshaped = None
+
+    attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+    # BEGIN ATTENTION BIAS
+    if attention_bias is not None and self.is_decoder:
+        if not has_harassed_user:
+            print("[attention] Applying attention bias (will not show this again!!!!!!!!!!!)")
+            has_harassed_user = True
+
+        extra_tokens = attn_probs.shape[2] - attention_bias.shape[0]
+
+        # Obviously we add tokens during generation
+        att = nn.functional.pad(attention_bias, pad=(0, extra_tokens), value=1)
+
+        # May be slow, not sure
+        att = att.to(attn_probs.device)
+
+        attn_probs[:, :, :] *= att
+        attn_probs = nn.functional.normalize(attn_probs, p=1, dim=2)
+    # END ATTENTION BIAS
+
+    attn_output = torch.bmm(attn_probs, value_states)
+
+    if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+        raise ValueError(
+            f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+            f" {attn_output.size()}"
+        )
+
+    attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+    attn_output = attn_output.transpose(1, 2)
+
+    # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+    # partitioned across GPUs when using tensor-parallelism.
+    attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+    attn_output = self.out_proj(attn_output)
+
+    return attn_output, attn_weights_reshaped, past_key_value
+
+
+# Patch patch patch!
+def do_patches():
+    transformers.models.opt.modeling_opt.OPTAttention.forward = OPTAttention_forward
\ No newline at end of file
diff --git a/koboldai_settings.py b/koboldai_settings.py
index 436bbc05..d611be81 100644
--- a/koboldai_settings.py
+++ b/koboldai_settings.py
@@ -713,6 +713,9 @@ class story_settings(settings):
         #must be at bottom
         self.no_save = False #Temporary disable save (doesn't save with the file)
+
+        # bias experiment
+        self.memory_attn_bias = 1
 
     def save_story(self):
         if not self.no_save:
diff --git a/static/koboldai.js b/static/koboldai.js
index af9f9837..7f7658c6 100644
--- a/static/koboldai.js
+++ b/static/koboldai.js
@@ -2297,6 +2297,7 @@ function sync_to_server(item) {
 			value = item.checked;
 		} else {
 			value = item.value;
+			if (item.classList.contains("sync_as_float")) value = parseFloat(value);
 		}
 	} else {
 		value = item.textContent;
diff --git a/templates/story flyout.html b/templates/story flyout.html
index 313aeb47..50af1382 100644
--- a/templates/story flyout.html
+++ b/templates/story flyout.html
@@ -12,6 +12,21 @@
 			Important information the AI should always keep in mind.
+
+			Attention Bias test
+
+			Note: This is OPT-only for now! Patches will be written for other models once it's known that this actually has a positive effect. Upon first use of this bias, you should see "Applying attention bias" in the console.
+			This is an experimental setting that may change how much attention the AI pays to memory. Values in the ballpark of 15 or higher may cause incoherence; the option to select them is present for experimentation.
+
+			1
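For anyone reviewing the calcsubmit hunk, here is a minimal, self-contained sketch of the bias-vector bookkeeping it performs. The helper names and the `encode` lambda (one token per word) are toy stand-ins for the real tokenizer, purely for the demo; the chunk layout with "type" and "text" keys matches what the patch reads from `koboldai_vars.context`.

from typing import Callable, List, Optional, Tuple

def memory_bounds(context: List[dict], encode: Callable[[str], list]) -> Optional[Tuple[int, int]]:
    """Token offsets (start, end) of the first "memory" chunk, or None if absent."""
    offset = 0
    for chunk in context:
        length = len(encode(chunk["text"]))
        if chunk["type"] == "memory":
            return offset, offset + length
        offset += length
    return None

def build_bias(context: List[dict], encode: Callable[[str], list], value: float) -> Optional[List[float]]:
    """1.0 for tokens before the memory, `value` for the memory tokens themselves.
    Tokens after the memory are covered by the pad-with-1 step inside the attention patch."""
    bounds = memory_bounds(context, encode)
    if bounds is None:
        return None
    start, end = bounds
    return [1.0] * start + [value] * (end - start)

# Toy tokenizer: one token per whitespace-separated word (demo assumption only).
encode = lambda text: text.split()
ctx = [
    {"type": "prompt", "text": "once upon a time"},
    {"type": "memory", "text": "the hero fears water"},
    {"type": "action", "text": "you walk on"},
]
print(build_bias(ctx, encode, 15.0))
# [1.0, 1.0, 1.0, 1.0, 15.0, 15.0, 15.0, 15.0]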
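The core of attention_bias.py is the block between BEGIN ATTENTION BIAS and END ATTENTION BIAS: after softmax and dropout, each row of attention probabilities is multiplied by a per-key-token weight (padded with 1s for tokens generated after the bias was built) and then L1-renormalized so every row still sums to 1. A standalone sketch of that reweighting, with illustrative tensor shapes and a made-up helper name, looks like this:

import torch
import torch.nn.functional as F

def bias_attention_probs(attn_probs: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Reweight already-softmaxed attention.

    attn_probs: (heads, tgt_len, src_len), each row sums to 1.
    bias:       (n,) per-key-token multiplier with n <= src_len; positions past
                the end of the bias (tokens added during generation) count as 1.
    """
    extra_tokens = attn_probs.shape[-1] - bias.shape[0]
    weights = F.pad(bias, pad=(0, extra_tokens), value=1.0).to(attn_probs.device)
    reweighted = attn_probs * weights              # boost the biased (memory) columns
    return F.normalize(reweighted, p=1, dim=-1)    # rows sum to 1 again

# Toy check: 2 heads, 1 query, 4 key tokens; boost tokens 1 and 2 by 5x.
probs = torch.full((2, 1, 4), 0.25)
bias = torch.tensor([1.0, 5.0, 5.0])  # deliberately shorter than src_len
print(bias_attention_probs(probs, bias))
# each row becomes [0.0833, 0.4167, 0.4167, 0.0833]

Note that the patch scales the post-softmax probabilities rather than adding a log-bias to the pre-softmax scores; both are plausible designs, and the sketch follows what the patch actually does.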