Merge branch 'henk717:united' into united
This commit is contained in:
commit
b0f1bdf2fd
107
aiserver.py
107
aiserver.py
|
@ -1,7 +1,7 @@
|
||||||
#!/usr/bin/python3
|
#!/usr/bin/python3
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# KoboldAI
|
# KoboldAI
|
||||||
# Version: 1.16.4
|
# Version: 1.17.0
|
||||||
# By: KoboldAIDev and the KoboldAI Community
|
# By: KoboldAIDev and the KoboldAI Community
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ import packaging
|
||||||
import contextlib
|
import contextlib
|
||||||
import traceback
|
import traceback
|
||||||
import threading
|
import threading
|
||||||
|
from collections.abc import Iterable
|
||||||
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
|
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
@ -100,6 +101,8 @@ class vars:
|
||||||
genamt = 80 # Amount of text for each action to generate
|
genamt = 80 # Amount of text for each action to generate
|
||||||
ikgen = 200 # Number of characters for InferKit to generate
|
ikgen = 200 # Number of characters for InferKit to generate
|
||||||
rep_pen = 1.1 # Default generator repetition_penalty
|
rep_pen = 1.1 # Default generator repetition_penalty
|
||||||
|
rep_pen_slope = 0.0 # Default generator repetition penalty slope
|
||||||
|
rep_pen_range = 0 # Default generator repetition penalty range
|
||||||
temp = 0.5 # Default generator temperature
|
temp = 0.5 # Default generator temperature
|
||||||
top_p = 0.9 # Default generator top_p
|
top_p = 0.9 # Default generator top_p
|
||||||
top_k = 0 # Default generator top_k
|
top_k = 0 # Default generator top_k
|
||||||
|
@ -709,65 +712,32 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||||
|
|
||||||
# Patch transformers to use our custom logit warpers
|
# Patch transformers to use our custom logit warpers
|
||||||
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
|
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
|
||||||
|
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper
|
||||||
|
|
||||||
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
|
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
|
||||||
old_call = cls.__call__
|
old_call = cls.__call__
|
||||||
def new_call(self, *args, **kwargs):
|
def new_call(self, *args, **kwargs):
|
||||||
setattr(self, field_name, getattr(vars, var_name))
|
if(not isinstance(field_name, str) and isinstance(field_name, Iterable)):
|
||||||
|
conds = []
|
||||||
|
for f, v in zip(field_name, var_name):
|
||||||
|
conds.append(getattr(vars, v))
|
||||||
|
setattr(self, f, conds[-1])
|
||||||
|
else:
|
||||||
|
conds = getattr(vars, var_name)
|
||||||
|
setattr(self, field_name, conds)
|
||||||
assert len(args) == 2
|
assert len(args) == 2
|
||||||
if(cond is None or cond(getattr(vars, var_name))):
|
if(cond is None or cond(conds)):
|
||||||
return old_call(self, *args, **kwargs)
|
return old_call(self, *args, **kwargs)
|
||||||
return args[1]
|
return args[1]
|
||||||
cls.__call__ = new_call
|
cls.__call__ = new_call
|
||||||
dynamic_processor_wrap(RepetitionPenaltyLogitsProcessor, "penalty", "rep_pen", cond=lambda x: x != 1.0)
|
dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0)
|
||||||
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
|
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
|
||||||
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
|
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
|
||||||
|
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
|
||||||
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
|
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
|
||||||
|
RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
|
||||||
|
RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
|
||||||
|
|
||||||
class TailFreeLogitsWarper(LogitsWarper):
|
|
||||||
|
|
||||||
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
|
||||||
tfs = float(tfs)
|
|
||||||
if tfs < 0 or tfs > 1.0:
|
|
||||||
raise ValueError(f"`tfs` has to be a float > 0 and < 1, but is {tfs}")
|
|
||||||
self.tfs = tfs
|
|
||||||
self.filter_value = filter_value
|
|
||||||
self.min_tokens_to_keep = min_tokens_to_keep
|
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
|
||||||
self.tfs = vars.tfs
|
|
||||||
|
|
||||||
if self.filter_value >= 1.0:
|
|
||||||
return scores
|
|
||||||
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
|
|
||||||
probs = sorted_logits.softmax(dim=-1)
|
|
||||||
|
|
||||||
# Compute second derivative normalized CDF
|
|
||||||
d2 = probs.diff().diff().abs()
|
|
||||||
normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)
|
|
||||||
normalized_d2_cdf = normalized_d2.cumsum(dim=-1)
|
|
||||||
|
|
||||||
# Remove tokens with CDF value above the threshold (token with 0 are kept)
|
|
||||||
sorted_indices_to_remove = normalized_d2_cdf > self.tfs
|
|
||||||
|
|
||||||
# Centre the distribution around the cutoff as in the original implementation of the algorithm
|
|
||||||
sorted_indices_to_remove = torch.cat(
|
|
||||||
(
|
|
||||||
torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
|
|
||||||
sorted_indices_to_remove,
|
|
||||||
torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
|
|
||||||
),
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.min_tokens_to_keep > 1:
|
|
||||||
# Keep at least min_tokens_to_keep
|
|
||||||
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
|
|
||||||
|
|
||||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
|
||||||
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
|
||||||
return scores
|
|
||||||
|
|
||||||
class LuaLogitsProcessor(LogitsProcessor):
|
class LuaLogitsProcessor(LogitsProcessor):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -1085,6 +1055,8 @@ else:
|
||||||
"top_k": int(vars.top_k),
|
"top_k": int(vars.top_k),
|
||||||
"tfs": float(vars.tfs),
|
"tfs": float(vars.tfs),
|
||||||
"repetition_penalty": float(vars.rep_pen),
|
"repetition_penalty": float(vars.rep_pen),
|
||||||
|
"rpslope": float(vars.rep_pen_slope),
|
||||||
|
"rprange": int(vars.rep_pen_range),
|
||||||
}
|
}
|
||||||
|
|
||||||
# If we're running Colab or OAI, we still need a tokenizer.
|
# If we're running Colab or OAI, we still need a tokenizer.
|
||||||
|
@ -1432,6 +1404,8 @@ def lua_has_setting(setting):
|
||||||
"settopk",
|
"settopk",
|
||||||
"settfs",
|
"settfs",
|
||||||
"setreppen",
|
"setreppen",
|
||||||
|
"setreppenslope",
|
||||||
|
"setreppenrange",
|
||||||
"settknmax",
|
"settknmax",
|
||||||
"setwidepth",
|
"setwidepth",
|
||||||
"setuseprompt",
|
"setuseprompt",
|
||||||
|
@ -1447,6 +1421,8 @@ def lua_has_setting(setting):
|
||||||
"top_k",
|
"top_k",
|
||||||
"tfs",
|
"tfs",
|
||||||
"reppen",
|
"reppen",
|
||||||
|
"reppenslope",
|
||||||
|
"reppenrange",
|
||||||
"tknmax",
|
"tknmax",
|
||||||
"widepth",
|
"widepth",
|
||||||
"useprompt",
|
"useprompt",
|
||||||
|
@ -1478,6 +1454,8 @@ def lua_get_setting(setting):
|
||||||
if(setting in ("settopk", "topk", "top_k")): return vars.top_k
|
if(setting in ("settopk", "topk", "top_k")): return vars.top_k
|
||||||
if(setting in ("settfs", "tfs")): return vars.tfs
|
if(setting in ("settfs", "tfs")): return vars.tfs
|
||||||
if(setting in ("setreppen", "reppen")): return vars.rep_pen
|
if(setting in ("setreppen", "reppen")): return vars.rep_pen
|
||||||
|
if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
|
||||||
|
if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
|
||||||
if(setting in ("settknmax", "tknmax")): return vars.max_length
|
if(setting in ("settknmax", "tknmax")): return vars.max_length
|
||||||
if(setting == "anotedepth"): return vars.andepth
|
if(setting == "anotedepth"): return vars.andepth
|
||||||
if(setting in ("setwidepth", "widepth")): return vars.widepth
|
if(setting in ("setwidepth", "widepth")): return vars.widepth
|
||||||
|
@ -1509,6 +1487,8 @@ def lua_set_setting(setting, v):
|
||||||
if(setting in ("settopk", "topk")): vars.top_k = v
|
if(setting in ("settopk", "topk")): vars.top_k = v
|
||||||
if(setting in ("settfs", "tfs")): vars.tfs = v
|
if(setting in ("settfs", "tfs")): vars.tfs = v
|
||||||
if(setting in ("setreppen", "reppen")): vars.rep_pen = v
|
if(setting in ("setreppen", "reppen")): vars.rep_pen = v
|
||||||
|
if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
|
||||||
|
if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
|
||||||
if(setting in ("settknmax", "tknmax")): vars.max_length = v; return True
|
if(setting in ("settknmax", "tknmax")): vars.max_length = v; return True
|
||||||
if(setting == "anotedepth"): vars.andepth = v; return True
|
if(setting == "anotedepth"): vars.andepth = v; return True
|
||||||
if(setting in ("setwidepth", "widepth")): vars.widepth = v; return True
|
if(setting in ("setwidepth", "widepth")): vars.widepth = v; return True
|
||||||
|
@ -1906,6 +1886,16 @@ def get_message(msg):
|
||||||
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
|
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
|
||||||
settingschanged()
|
settingschanged()
|
||||||
refresh_settings()
|
refresh_settings()
|
||||||
|
elif(msg['cmd'] == 'setreppenslope'):
|
||||||
|
vars.rep_pen_slope = float(msg['data'])
|
||||||
|
emit('from_server', {'cmd': 'setlabelreppenslope', 'data': msg['data']}, broadcast=True)
|
||||||
|
settingschanged()
|
||||||
|
refresh_settings()
|
||||||
|
elif(msg['cmd'] == 'setreppenrange'):
|
||||||
|
vars.rep_pen_range = float(msg['data'])
|
||||||
|
emit('from_server', {'cmd': 'setlabelreppenrange', 'data': msg['data']}, broadcast=True)
|
||||||
|
settingschanged()
|
||||||
|
refresh_settings()
|
||||||
elif(msg['cmd'] == 'setoutput'):
|
elif(msg['cmd'] == 'setoutput'):
|
||||||
vars.genamt = int(msg['data'])
|
vars.genamt = int(msg['data'])
|
||||||
emit('from_server', {'cmd': 'setlabeloutput', 'data': msg['data']}, broadcast=True)
|
emit('from_server', {'cmd': 'setlabeloutput', 'data': msg['data']}, broadcast=True)
|
||||||
|
@ -2183,6 +2173,8 @@ def savesettings():
|
||||||
js["top_k"] = vars.top_k
|
js["top_k"] = vars.top_k
|
||||||
js["tfs"] = vars.tfs
|
js["tfs"] = vars.tfs
|
||||||
js["rep_pen"] = vars.rep_pen
|
js["rep_pen"] = vars.rep_pen
|
||||||
|
js["rep_pen_slope"] = vars.rep_pen_slope
|
||||||
|
js["rep_pen_range"] = vars.rep_pen_range
|
||||||
js["genamt"] = vars.genamt
|
js["genamt"] = vars.genamt
|
||||||
js["max_length"] = vars.max_length
|
js["max_length"] = vars.max_length
|
||||||
js["ikgen"] = vars.ikgen
|
js["ikgen"] = vars.ikgen
|
||||||
|
@ -2238,6 +2230,10 @@ def loadsettings():
|
||||||
vars.tfs = js["tfs"]
|
vars.tfs = js["tfs"]
|
||||||
if("rep_pen" in js):
|
if("rep_pen" in js):
|
||||||
vars.rep_pen = js["rep_pen"]
|
vars.rep_pen = js["rep_pen"]
|
||||||
|
if("rep_pen_slope" in js):
|
||||||
|
vars.rep_pen_slope = js["rep_pen_slope"]
|
||||||
|
if("rep_pen_range" in js):
|
||||||
|
vars.rep_pen_range = js["rep_pen_range"]
|
||||||
if("genamt" in js):
|
if("genamt" in js):
|
||||||
vars.genamt = js["genamt"]
|
vars.genamt = js["genamt"]
|
||||||
if("max_length" in js):
|
if("max_length" in js):
|
||||||
|
@ -2303,7 +2299,10 @@ def loadmodelsettings():
|
||||||
model_js_config = str(model_config).partition(' ')[2]
|
model_js_config = str(model_config).partition(' ')[2]
|
||||||
js = json.loads(model_js_config)
|
js = json.loads(model_js_config)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
model_js_config = open(vars.custmodpth.replace('/', '_') + "/config.json", "r")
|
try:
|
||||||
|
model_js_config = open(vars.custmodpth + "/config.json", "r")
|
||||||
|
except Exception as e:
|
||||||
|
model_js_config = open(vars.custmodpth.replace('/', '_') + "/config.json", "r")
|
||||||
js = json.load(model_js_config)
|
js = json.load(model_js_config)
|
||||||
if("badwordsids" in js):
|
if("badwordsids" in js):
|
||||||
vars.badwordsids = js["badwordsids"]
|
vars.badwordsids = js["badwordsids"]
|
||||||
|
@ -2317,6 +2316,10 @@ def loadmodelsettings():
|
||||||
vars.tfs = js["tfs"]
|
vars.tfs = js["tfs"]
|
||||||
if("rep_pen" in js):
|
if("rep_pen" in js):
|
||||||
vars.rep_pen = js["rep_pen"]
|
vars.rep_pen = js["rep_pen"]
|
||||||
|
if("rep_pen_slope" in js):
|
||||||
|
vars.rep_pen_slope = js["rep_pen_slope"]
|
||||||
|
if("rep_pen_range" in js):
|
||||||
|
vars.rep_pen_range = js["rep_pen_range"]
|
||||||
if("adventure" in js):
|
if("adventure" in js):
|
||||||
vars.adventure = js["adventure"]
|
vars.adventure = js["adventure"]
|
||||||
if("chatmode" in js):
|
if("chatmode" in js):
|
||||||
|
@ -3091,6 +3094,8 @@ def sendtocolab(txt, min, max):
|
||||||
'min': min,
|
'min': min,
|
||||||
'max': max,
|
'max': max,
|
||||||
'rep_pen': vars.rep_pen,
|
'rep_pen': vars.rep_pen,
|
||||||
|
'rep_pen_slope': vars.rep_pen_slope,
|
||||||
|
'rep_pen_range': vars.rep_pen_range,
|
||||||
'temperature': vars.temp,
|
'temperature': vars.temp,
|
||||||
'top_p': vars.top_p,
|
'top_p': vars.top_p,
|
||||||
'top_k': vars.top_k,
|
'top_k': vars.top_k,
|
||||||
|
@ -3233,6 +3238,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
||||||
tfs=vars.tfs,
|
tfs=vars.tfs,
|
||||||
numseqs=vars.numseqs,
|
numseqs=vars.numseqs,
|
||||||
repetition_penalty=vars.rep_pen,
|
repetition_penalty=vars.rep_pen,
|
||||||
|
rpslope=vars.rep_pen_slope,
|
||||||
|
rprange=vars.rep_pen_range,
|
||||||
soft_embeddings=vars.sp,
|
soft_embeddings=vars.sp,
|
||||||
soft_tokens=soft_tokens,
|
soft_tokens=soft_tokens,
|
||||||
)
|
)
|
||||||
|
@ -3415,6 +3422,8 @@ def refresh_settings():
|
||||||
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
|
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
|
||||||
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
|
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
|
||||||
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
|
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
|
||||||
|
emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
|
||||||
|
emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)
|
||||||
emit('from_server', {'cmd': 'updateoutlen', 'data': vars.genamt}, broadcast=True)
|
emit('from_server', {'cmd': 'updateoutlen', 'data': vars.genamt}, broadcast=True)
|
||||||
emit('from_server', {'cmd': 'updatetknmax', 'data': vars.max_length}, broadcast=True)
|
emit('from_server', {'cmd': 'updatetknmax', 'data': vars.max_length}, broadcast=True)
|
||||||
emit('from_server', {'cmd': 'updatenumseq', 'data': vars.numseqs}, broadcast=True)
|
emit('from_server', {'cmd': 'updatenumseq', 'data': vars.numseqs}, broadcast=True)
|
||||||
|
|
|
@ -867,6 +867,8 @@ return function(_python, _bridged)
|
||||||
---@field settopk integer
|
---@field settopk integer
|
||||||
---@field settfs number
|
---@field settfs number
|
||||||
---@field setreppen number
|
---@field setreppen number
|
||||||
|
---@field setreppenslope number
|
||||||
|
---@field setreppenrange number
|
||||||
---@field settknmax integer
|
---@field settknmax integer
|
||||||
---@field setwidepth integer
|
---@field setwidepth integer
|
||||||
---@field setuseprompt boolean
|
---@field setuseprompt boolean
|
||||||
|
@ -881,6 +883,8 @@ return function(_python, _bridged)
|
||||||
---@field top_k integer
|
---@field top_k integer
|
||||||
---@field tfs number
|
---@field tfs number
|
||||||
---@field reppen number
|
---@field reppen number
|
||||||
|
---@field reppenslope number
|
||||||
|
---@field reppenrange number
|
||||||
---@field tknmax integer
|
---@field tknmax integer
|
||||||
---@field widepth integer
|
---@field widepth integer
|
||||||
---@field useprompt boolean
|
---@field useprompt boolean
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
"private_outputs": true,
|
"private_outputs": true,
|
||||||
"provenance": [],
|
"provenance": [],
|
||||||
"collapsed_sections": [],
|
"collapsed_sections": [],
|
||||||
"authorship_tag": "ABX9TyPvJVmaGhMfnEaKeWczHtH+",
|
"authorship_tag": "ABX9TyM1P6rS/XmJbyV/HZRzXohF",
|
||||||
"include_colab_link": true
|
"include_colab_link": true
|
||||||
},
|
},
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
|
@ -39,7 +39,9 @@
|
||||||
"# Welcome to KoboldAI on Google Colab, GPU Edition!\n",
|
"# Welcome to KoboldAI on Google Colab, GPU Edition!\n",
|
||||||
"KoboldAI is a powerful and easy way to use a variety of AI based text generation experiences. You can use it to write stories, blog posts, play a text adventure game, use it like a chatbot and more! In some cases it might even help you with an assignment or programming task (But always make sure the information the AI mentions is correct, it loves to make stuff up).\n",
|
"KoboldAI is a powerful and easy way to use a variety of AI based text generation experiences. You can use it to write stories, blog posts, play a text adventure game, use it like a chatbot and more! In some cases it might even help you with an assignment or programming task (But always make sure the information the AI mentions is correct, it loves to make stuff up).\n",
|
||||||
"\n",
|
"\n",
|
||||||
"For more information about KoboldAI check our our Github readme : https://github.com/KoboldAI/KoboldAI-Client/blob/main/readme.md"
|
"For more information about KoboldAI check our our Github readme : https://github.com/KoboldAI/KoboldAI-Client/blob/main/readme.md\n",
|
||||||
|
"\n",
|
||||||
|
"For the larger AI models (That are typically more coherent) check out our [TPU edition](https://colab.research.google.com/github/KoboldAI/KoboldAI-Client/blob/main/colab/TPU.ipynb)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -81,7 +83,7 @@
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"source": [
|
"source": [
|
||||||
"## GPU Edition Model Descriptions\n",
|
"# GPU Edition Model Descriptions\n",
|
||||||
"| Model | Size | Style | Description |\n",
|
"| Model | Size | Style | Description |\n",
|
||||||
"| ------------------------------------------------------------ | -------- | ---------- | ------------------------------------------------------------ |\n",
|
"| ------------------------------------------------------------ | -------- | ---------- | ------------------------------------------------------------ |\n",
|
||||||
"| [GPT-Neo-2.7B-Picard](https://huggingface.co/KoboldAI/GPT-Neo-2.7B-Picard) by Mr Seeker | 2.7B GPU | Novel | Picard is a model trained for SFW Novels based on GPT-Neo-2.7B. It is focused on Novel style writing without the NSFW bias. While the name suggests a sci-fi model this model is designed for Novels of a variety of genre's. It is meant to be used in KoboldAI's regular mode. |\n",
|
"| [GPT-Neo-2.7B-Picard](https://huggingface.co/KoboldAI/GPT-Neo-2.7B-Picard) by Mr Seeker | 2.7B GPU | Novel | Picard is a model trained for SFW Novels based on GPT-Neo-2.7B. It is focused on Novel style writing without the NSFW bias. While the name suggests a sci-fi model this model is designed for Novels of a variety of genre's. It is meant to be used in KoboldAI's regular mode. |\n",
|
||||||
|
@ -91,6 +93,17 @@
|
||||||
"| [GPT-Neo-2.7B-Shinen](https://huggingface.co/KoboldAI/GPT-Neo-2.7B-Shinen) by Mr Seeker | 2.7B GPU | NSFW | Shinen is an alternative to the Horni model designed to be more explicit. If Horni is to tame for you shinen might produce better results. While it is a Novel model it is unsuitable for SFW stories due to its heavy NSFW bias. Shinen will not hold back. It is meant to be used in KoboldAI's regular mode. |\n",
|
"| [GPT-Neo-2.7B-Shinen](https://huggingface.co/KoboldAI/GPT-Neo-2.7B-Shinen) by Mr Seeker | 2.7B GPU | NSFW | Shinen is an alternative to the Horni model designed to be more explicit. If Horni is to tame for you shinen might produce better results. While it is a Novel model it is unsuitable for SFW stories due to its heavy NSFW bias. Shinen will not hold back. It is meant to be used in KoboldAI's regular mode. |\n",
|
||||||
"| [GPT-Neo-2.7B](https://huggingface.co/EleutherAI/gpt-neo-2.7B) by EleutherAI | 2.7B GPU | Generic | This is the base model for all the other 2.7B models, it is best used when you have a use case that we have no other models available for, such as writing blog articles or programming. It can also be a good basis for the experience of some of the softprompts if your softprompt is not about a subject the other models cover. |\n",
|
"| [GPT-Neo-2.7B](https://huggingface.co/EleutherAI/gpt-neo-2.7B) by EleutherAI | 2.7B GPU | Generic | This is the base model for all the other 2.7B models, it is best used when you have a use case that we have no other models available for, such as writing blog articles or programming. It can also be a good basis for the experience of some of the softprompts if your softprompt is not about a subject the other models cover. |\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"# [TPU Edition Model Descriptions](https://colab.research.google.com/github/KoboldAI/KoboldAI-Client/blob/main/colab/TPU.ipynb)\n",
|
||||||
|
"\n",
|
||||||
|
"| Model | Size | Style | Drive Space | Description |\n",
|
||||||
|
"| ------------------------------ | ------ | --------- | ----------- | ------------------------------------------------------------ |\n",
|
||||||
|
"| Skein 6B by VE_FORBDRYDERNE | 6B TPU | Hybrid | 0 GB | Skein is our flagship 6B model, it is a hybrid between a Adventure model and a Novel model. Best used with either Adventure mode or the You Bias userscript enabled. Skein has been trained on high quality Novels along with CYOA adventure stories and is not as wackey as the Adventure model. It also has tagging support. |\n",
|
||||||
|
"| Adventure 6B by VE_FORBRYDERNE | 6B TPU | Adventure | 0 GB | Adventure is a 6B model designed to mimick the behavior of AI Dungeon. It is exclusively for Adventure Mode and can take you on the epic and wackey adventures that AI Dungeon players love. It also features the many tropes of AI Dungeon as it has been trained on very similar data. It must be used in second person (You). |\n",
|
||||||
|
"| Lit 6B by Haru | 6B TPU | NSFW | 8 GB / 12 GB | Lit is a great NSFW model trained by Haru on both a large set of Literotica stories and high quality novels along with tagging support. Creating a high quality model for your NSFW stories. This model is exclusively a novel model and is best used in third person. |\n",
|
||||||
|
"| Generic 6B by EleutherAI | 6B TPU | Generic | 10 GB / 12 GB | GPT-J-6B is what all other models are based on, if you need something that has no specific bias towards any particular subject this is the model for you. Best used when the other models are not suitable for what you wish to do. Such as homework assistance, blog writing, coding and more. It needs more hand holding than other models and is more prone to undesirable formatting changes. |\n",
|
||||||
|
"| C1 6B by Haru | 6B TPU | Chatbot | 8 GB / 12 GB | C1 has been trained on various internet chatrooms, it makes the basis for an interesting chatbot model and has been optimized to be used in the Chatmode. |\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
"| Style | Description |\n",
|
"| Style | Description |\n",
|
||||||
"| --------- | ------------------------------------------------------------ |\n",
|
"| --------- | ------------------------------------------------------------ |\n",
|
||||||
"| Novel | For regular story writing, not compatible with Adventure mode or other specialty modes. |\n",
|
"| Novel | For regular story writing, not compatible with Adventure mode or other specialty modes. |\n",
|
||||||
|
@ -100,7 +113,7 @@
|
||||||
"| Hybrid | Hybrid models are a blend between different styles, for example they are trained on both Novel stories and Adventure stories. These models are great variety models that you can use for multiple different playstyles and modes, but depending on your usage you may need to enable Adventure Mode or the You bias (in userscripts). |\n",
|
"| Hybrid | Hybrid models are a blend between different styles, for example they are trained on both Novel stories and Adventure stories. These models are great variety models that you can use for multiple different playstyles and modes, but depending on your usage you may need to enable Adventure Mode or the You bias (in userscripts). |\n",
|
||||||
"| Generic | Generic models are not trained towards anything specific, typically used as a basis for other tasks and models. They can do everything the other models can do, but require much more handholding to work properly. Generic models are an ideal basis for tasks that we have no specific model for, or for experiencing a softprompt in its raw form. |\n",
|
"| Generic | Generic models are not trained towards anything specific, typically used as a basis for other tasks and models. They can do everything the other models can do, but require much more handholding to work properly. Generic models are an ideal basis for tasks that we have no specific model for, or for experiencing a softprompt in its raw form. |\n",
|
||||||
"\n",
|
"\n",
|
||||||
"## How to start KoboldAI in 7 simple steps\n",
|
"# How to start KoboldAI in 7 simple steps\n",
|
||||||
"Using KoboldAI on Google Colab is easy! Simply follow these steps to get started:\n",
|
"Using KoboldAI on Google Colab is easy! Simply follow these steps to get started:\n",
|
||||||
"1. Mobile phone? Tap the play button below next to \"<--- Tap this if you play on mobile\" to reveal an audio player, play the silent audio to keep the tab alive so Google will not shut you down when your using KoboldAI. If no audio player is revealed your phone browser does not support Google Colab in the mobile view, go to your browser menu and enable Desktop mode before you continue.\n",
|
"1. Mobile phone? Tap the play button below next to \"<--- Tap this if you play on mobile\" to reveal an audio player, play the silent audio to keep the tab alive so Google will not shut you down when your using KoboldAI. If no audio player is revealed your phone browser does not support Google Colab in the mobile view, go to your browser menu and enable Desktop mode before you continue.\n",
|
||||||
"2. Select the model that most describes what you would like to do, by default we have the most recommended model for people willing to try out KoboldAI selected. If you are an advanced user you can also type any GPT model name from Huggingface.co to load this up (Unlisted Models may or may not work depending on Colab's hardware limitations).\n",
|
"2. Select the model that most describes what you would like to do, by default we have the most recommended model for people willing to try out KoboldAI selected. If you are an advanced user you can also type any GPT model name from Huggingface.co to load this up (Unlisted Models may or may not work depending on Colab's hardware limitations).\n",
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
"For more information about KoboldAI check our our Github readme : https://github.com/KoboldAI/KoboldAI-Client/blob/main/readme.md\n",
|
"For more information about KoboldAI check our our Github readme : https://github.com/KoboldAI/KoboldAI-Client/blob/main/readme.md\n",
|
||||||
"\n",
|
"\n",
|
||||||
"More (smaller) models are available in the **[GPU edition](https://colab.research.google.com/github/henk717/KoboldAI/blob/united/colab/GPU.ipynb)**!"
|
"More (smaller) models are available in the **[GPU edition](https://colab.research.google.com/github/koboldai/KoboldAI-Client/blob/united/colab/GPU.ipynb)**!"
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "zrLGxVCEaqZx"
|
"id": "zrLGxVCEaqZx"
|
||||||
|
@ -86,7 +86,7 @@
|
||||||
"if Model == \"C1 6B\":\n",
|
"if Model == \"C1 6B\":\n",
|
||||||
" path = \"gpt-j-6b-c1-jax\"\n",
|
" path = \"gpt-j-6b-c1-jax\"\n",
|
||||||
" location = \"drive\"\n",
|
" location = \"drive\"\n",
|
||||||
" download = \"-a https://storage.henk.tech/KoboldAI/aria2.php?file=gpt-j-6b-c1-jax\"\n",
|
" download = \"-a https://storage.henk.tech/KoboldAI/aria2.php?file=gpt-j-6b-c1-jax.7z\"\n",
|
||||||
" extract = \"-z gpt-j-6b-c1-jax.7z\"\n",
|
" extract = \"-z gpt-j-6b-c1-jax.7z\"\n",
|
||||||
" ![[ -f /content/drive/MyDrive/KoboldAI/settings/gpt-j-6b-c1-jax.settings ]] || echo -e \"{\\n \\\"apikey\\\": \\\"\\\",\\n \\\"andepth\\\": 3,\\n \\\"temp\\\": 0.5,\\n \\\"top_p\\\": 0.9,\\n \\\"top_k\\\": 0,\\n \\\"tfs\\\": 1.0,\\n \\\"rep_pen\\\": 1.1,\\n \\\"genamt\\\": 80,\\n \\\"max_length\\\": 2048,\\n \\\"ikgen\\\": 200,\\n \\\"formatoptns\\\": {\\n \\\"frmttriminc\\\": true,\\n \\\"frmtrmblln\\\": false,\\n \\\"frmtrmspch\\\": false,\\n \\\"frmtadsnsp\\\": false\\n },\\n \\\"numseqs\\\": 1,\\n \\\"widepth\\\": 3,\\n \\\"useprompt\\\": true,\\n \\\"chatmode\\\": true\\n}\" > /content/drive/MyDrive/KoboldAI/settings/gpt-j-6b-c1-jax.settings\n",
|
" ![[ -f /content/drive/MyDrive/KoboldAI/settings/gpt-j-6b-c1-jax.settings ]] || echo -e \"{\\n \\\"apikey\\\": \\\"\\\",\\n \\\"andepth\\\": 3,\\n \\\"temp\\\": 0.5,\\n \\\"top_p\\\": 0.9,\\n \\\"top_k\\\": 0,\\n \\\"tfs\\\": 1.0,\\n \\\"rep_pen\\\": 1.1,\\n \\\"genamt\\\": 80,\\n \\\"max_length\\\": 2048,\\n \\\"ikgen\\\": 200,\\n \\\"formatoptns\\\": {\\n \\\"frmttriminc\\\": true,\\n \\\"frmtrmblln\\\": false,\\n \\\"frmtrmspch\\\": false,\\n \\\"frmtadsnsp\\\": false\\n },\\n \\\"numseqs\\\": 1,\\n \\\"widepth\\\": 3,\\n \\\"useprompt\\\": true,\\n \\\"chatmode\\\": true\\n}\" > /content/drive/MyDrive/KoboldAI/settings/gpt-j-6b-c1-jax.settings\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -106,7 +106,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
"| Model | Size | Style | Drive Space | Description |\n",
|
"| Model | Size | Style | Drive Space | Description |\n",
|
||||||
"| ------------------------------ | ------ | --------- | ----------- | ------------------------------------------------------------ |\n",
|
"| ------------------------------ | ------ | --------- | ----------- | ------------------------------------------------------------ |\n",
|
||||||
"| Skein 6B by VE_FORBDRYDERNE | 6B TPU | Hybrid | 0 GB | Skein is our flagship NSFW 6B model, it is a hybrid between a Adventure model and a Novel model. Best used with either Adventure mode or the You Bias userscript enabled. Skein has been trained on high quality Novels along with CYOA adventure stories and is not as wackey as the Adventure model. It also has tagging support. |\n",
|
"| Skein 6B by VE_FORBDRYDERNE | 6B TPU | Hybrid | 0 GB | Skein is our flagship 6B model, it is a hybrid between a Adventure model and a Novel model. Best used with either Adventure mode or the You Bias userscript enabled. Skein has been trained on high quality Novels along with CYOA adventure stories and is not as wackey as the Adventure model. It also has tagging support. |\n",
|
||||||
"| Adventure 6B by VE_FORBRYDERNE | 6B TPU | Adventure | 0 GB | Adventure is a 6B model designed to mimick the behavior of AI Dungeon. It is exclusively for Adventure Mode and can take you on the epic and wackey adventures that AI Dungeon players love. It also features the many tropes of AI Dungeon as it has been trained on very similar data. It must be used in second person (You). |\n",
|
"| Adventure 6B by VE_FORBRYDERNE | 6B TPU | Adventure | 0 GB | Adventure is a 6B model designed to mimick the behavior of AI Dungeon. It is exclusively for Adventure Mode and can take you on the epic and wackey adventures that AI Dungeon players love. It also features the many tropes of AI Dungeon as it has been trained on very similar data. It must be used in second person (You). |\n",
|
||||||
"| Lit 6B by Haru | 6B TPU | NSFW | 8 GB / 12 GB | Lit is a great NSFW model trained by Haru on both a large set of Literotica stories and high quality novels along with tagging support. Creating a high quality model for your NSFW stories. This model is exclusively a novel model and is best used in third person. |\n",
|
"| Lit 6B by Haru | 6B TPU | NSFW | 8 GB / 12 GB | Lit is a great NSFW model trained by Haru on both a large set of Literotica stories and high quality novels along with tagging support. Creating a high quality model for your NSFW stories. This model is exclusively a novel model and is best used in third person. |\n",
|
||||||
"| Generic 6B by EleutherAI | 6B TPU | Generic | 10 GB / 12 GB | GPT-J-6B is what all other models are based on, if you need something that has no specific bias towards any particular subject this is the model for you. Best used when the other models are not suitable for what you wish to do. Such as homework assistance, blog writing, coding and more. It needs more hand holding than other models and is more prone to undesirable formatting changes. |\n",
|
"| Generic 6B by EleutherAI | 6B TPU | Generic | 10 GB / 12 GB | GPT-J-6B is what all other models are based on, if you need something that has no specific bias towards any particular subject this is the model for you. Best used when the other models are not suitable for what you wish to do. Such as homework assistance, blog writing, coding and more. It needs more hand holding than other models and is more prone to undesirable formatting changes. |\n",
|
||||||
|
|
|
@ -1,4 +1,16 @@
|
||||||
gensettingstf = [{
|
gensettingstf = [
|
||||||
|
{
|
||||||
|
"uitype": "slider",
|
||||||
|
"unit": "int",
|
||||||
|
"label": "Amount to Generate",
|
||||||
|
"id": "setoutput",
|
||||||
|
"min": 16,
|
||||||
|
"max": 512,
|
||||||
|
"step": 2,
|
||||||
|
"default": 80,
|
||||||
|
"tooltip": "Number of tokens the AI should generate. Higher numbers will take longer to generate."
|
||||||
|
},
|
||||||
|
{
|
||||||
"uitype": "slider",
|
"uitype": "slider",
|
||||||
"unit": "float",
|
"unit": "float",
|
||||||
"label": "Temperature",
|
"label": "Temperature",
|
||||||
|
@ -56,13 +68,24 @@ gensettingstf = [{
|
||||||
{
|
{
|
||||||
"uitype": "slider",
|
"uitype": "slider",
|
||||||
"unit": "int",
|
"unit": "int",
|
||||||
"label": "Amount to Generate",
|
"label": "Rep Penalty Range",
|
||||||
"id": "setoutput",
|
"id": "setreppenrange",
|
||||||
"min": 16,
|
"min": 0,
|
||||||
"max": 512,
|
"max": 4096,
|
||||||
"step": 2,
|
"step": 4,
|
||||||
"default": 80,
|
"default": 0,
|
||||||
"tooltip": "Number of tokens the AI should generate. Higher numbers will take longer to generate."
|
"tooltip": "Repetition penalty range. If set higher than 0, only applies repetition penalty to the last few tokens of your story rather than applying it to the entire story. This slider controls the amount of tokens at the end of your story to apply it to."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"uitype": "slider",
|
||||||
|
"unit": "float",
|
||||||
|
"label": "Rep Penalty Slope",
|
||||||
|
"id": "setreppenslope",
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 10.0,
|
||||||
|
"step": 0.1,
|
||||||
|
"default": 0.0,
|
||||||
|
"tooltip": "Repetition penalty slope. If BOTH this setting and Rep Penalty Range are set higher than 0, will use sigmoid interpolation to apply repetition penalty more strongly on tokens that are closer to the end of your story. This setting controls the tension of the sigmoid curve; higher settings will result in the repetition penalty difference between the start and end of your story being more apparent. Setting this to 1 uses linear interpolation; setting this to 0 disables interpolation."
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"uitype": "slider",
|
"uitype": "slider",
|
||||||
|
|
|
@ -2014,6 +2014,14 @@ $(document).ready(function(){
|
||||||
// Send current rep pen value to input
|
// Send current rep pen value to input
|
||||||
$("#setreppen").val(parseFloat(msg.data));
|
$("#setreppen").val(parseFloat(msg.data));
|
||||||
$("#setreppencur").html(msg.data);
|
$("#setreppencur").html(msg.data);
|
||||||
|
} else if(msg.cmd == "updatereppenslope") {
|
||||||
|
// Send current rep pen value to input
|
||||||
|
$("#setreppenslope").val(parseFloat(msg.data));
|
||||||
|
$("#setreppenslopecur").html(msg.data);
|
||||||
|
} else if(msg.cmd == "updatereppenrange") {
|
||||||
|
// Send current rep pen value to input
|
||||||
|
$("#setreppenrange").val(parseFloat(msg.data));
|
||||||
|
$("#setreppenrangecur").html(msg.data);
|
||||||
} else if(msg.cmd == "updateoutlen") {
|
} else if(msg.cmd == "updateoutlen") {
|
||||||
// Send current output amt value to input
|
// Send current output amt value to input
|
||||||
$("#setoutput").val(parseInt(msg.data));
|
$("#setoutput").val(parseInt(msg.data));
|
||||||
|
@ -2041,6 +2049,12 @@ $(document).ready(function(){
|
||||||
} else if(msg.cmd == "setlabelreppen") {
|
} else if(msg.cmd == "setlabelreppen") {
|
||||||
// Update setting label with value from server
|
// Update setting label with value from server
|
||||||
$("#setreppencur").html(msg.data);
|
$("#setreppencur").html(msg.data);
|
||||||
|
} else if(msg.cmd == "setlabelreppenslope") {
|
||||||
|
// Update setting label with value from server
|
||||||
|
$("#setreppenslopecur").html(msg.data);
|
||||||
|
} else if(msg.cmd == "setlabelreppenrange") {
|
||||||
|
// Update setting label with value from server
|
||||||
|
$("#setreppenrangecur").html(msg.data);
|
||||||
} else if(msg.cmd == "setlabeloutput") {
|
} else if(msg.cmd == "setlabeloutput") {
|
||||||
// Update setting label with value from server
|
// Update setting label with value from server
|
||||||
$("#setoutputcur").html(msg.data);
|
$("#setoutputcur").html(msg.data);
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
<link rel="stylesheet" href="static/bootstrap.min.css">
|
<link rel="stylesheet" href="static/bootstrap.min.css">
|
||||||
<link rel="stylesheet" href="static/bootstrap-toggle.min.css">
|
<link rel="stylesheet" href="static/bootstrap-toggle.min.css">
|
||||||
<link rel="stylesheet" href="static/open-iconic-bootstrap.min.css">
|
<link rel="stylesheet" href="static/open-iconic-bootstrap.min.css">
|
||||||
<link rel="stylesheet" href="static/custom.css?ver=1.16.4o">
|
<link rel="stylesheet" href="static/custom.css?ver=1.17">
|
||||||
|
|
||||||
<script src="static/jquery-3.6.0.min.js"></script>
|
<script src="static/jquery-3.6.0.min.js"></script>
|
||||||
<script src="static/jquery-ui.sortable.min.js"></script>
|
<script src="static/jquery-ui.sortable.min.js"></script>
|
||||||
|
@ -17,7 +17,7 @@
|
||||||
<script src="static/bootstrap.min.js"></script>
|
<script src="static/bootstrap.min.js"></script>
|
||||||
<script src="static/bootstrap-toggle.min.js"></script>
|
<script src="static/bootstrap-toggle.min.js"></script>
|
||||||
<script src="static/rangy-core.min.js"></script>
|
<script src="static/rangy-core.min.js"></script>
|
||||||
<script src="static/application.js?ver=1.16.4ac"></script>
|
<script src="static/application.js?ver=1.17"></script>
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<input type="file" id="remote-save-select" accept="application/json" style="display:none">
|
<input type="file" id="remote-save-select" accept="application/json" style="display:none">
|
||||||
|
|
|
@ -1,3 +1,32 @@
|
||||||
|
'''
|
||||||
|
This file is AGPL-licensed.
|
||||||
|
|
||||||
|
Some of the code in this file is from Clover Edition:
|
||||||
|
https://github.com/cloveranon/Clover-Edition/blob/master/aidungeon/gpt2generator.py
|
||||||
|
|
||||||
|
The license for Clover Edition is shown below:
|
||||||
|
|
||||||
|
Copyright (c) 2019 Nick Walton
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
'''
|
||||||
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
|
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
|
||||||
import progressbar
|
import progressbar
|
||||||
|
@ -33,6 +62,8 @@ def settings_callback() -> dict:
|
||||||
"top_k": 0,
|
"top_k": 0,
|
||||||
"tfs": 1.0,
|
"tfs": 1.0,
|
||||||
"repetition_penalty": 1.0,
|
"repetition_penalty": 1.0,
|
||||||
|
"rpslope": 0.0,
|
||||||
|
"rprange": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
def started_compiling_callback() -> None:
|
def started_compiling_callback() -> None:
|
||||||
|
@ -78,26 +109,40 @@ def __batch_xmap(shard_dim=1):
|
||||||
return inner
|
return inner
|
||||||
|
|
||||||
|
|
||||||
def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty):
|
def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
|
||||||
'''
|
'''
|
||||||
This gets called by generate_loop_fn to apply repetition penalty
|
This gets called by generate_loop_fn to apply repetition penalty
|
||||||
to the 1D array logits using the provided 1D array of tokens to penalize
|
to the 1D array logits using the provided 1D array of tokens to penalize
|
||||||
'''
|
'''
|
||||||
tokens = np.minimum(tokens, params["n_vocab"]-1) # https://github.com/google/jax/issues/3774
|
tokens = np.minimum(tokens, params["n_vocab"]-1) # https://github.com/google/jax/issues/3774
|
||||||
|
rpslope = np.int32(rpslope)
|
||||||
|
rprange = np.int32(rprange)
|
||||||
|
clipped_rprange = rprange if rprange > 0 else tokens.shape[-1]
|
||||||
|
penalty_arange = np.roll(np.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]), generated_index, axis=-1)
|
||||||
# Make a new array with the same length as the tokens array but with
|
# Make a new array with the same length as the tokens array but with
|
||||||
# each element replaced by the value at the corresponding index in the
|
# each element replaced by the value at the corresponding index in the
|
||||||
# logits array; e.g.
|
# logits array; e.g.
|
||||||
# if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
|
# if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
|
||||||
# then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
|
# then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
|
||||||
penalty_logits = np.take(logits, tokens)
|
penalty_logits = np.take(logits, tokens)
|
||||||
|
# Repetition penalty slope
|
||||||
|
if rpslope != 0.0 and rprange > 0:
|
||||||
|
_penalty = (penalty_arange/(rprange - 1)) * 2 - 1
|
||||||
|
_penalty = (rpslope * _penalty) / (1 + np.abs(_penalty) * (rpslope - 1))
|
||||||
|
_penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
|
||||||
|
repetition_penalty = _penalty
|
||||||
# Divide positive values by repetition_penalty and multiply negative
|
# Divide positive values by repetition_penalty and multiply negative
|
||||||
# values by repetition_penalty (the academic publication that described
|
# values by repetition_penalty (the academic publication that described
|
||||||
# this technique actually just only divided, but that would cause tokens
|
# this technique actually just only divided, but that would cause tokens
|
||||||
# with negative logits to become more likely, which is obviously wrong)
|
# with negative logits to become more likely, which is obviously wrong)
|
||||||
penalty_logits = np.where(
|
penalty_logits = np.where(
|
||||||
penalty_logits > 0,
|
penalty_arange >= 0,
|
||||||
penalty_logits/repetition_penalty,
|
np.where(
|
||||||
penalty_logits*repetition_penalty,
|
penalty_logits > 0,
|
||||||
|
penalty_logits/repetition_penalty,
|
||||||
|
penalty_logits*repetition_penalty,
|
||||||
|
),
|
||||||
|
penalty_logits,
|
||||||
)
|
)
|
||||||
# Finally, put those penalized logit values back into their original
|
# Finally, put those penalized logit values back into their original
|
||||||
# positions in the logits array
|
# positions in the logits array
|
||||||
|
@ -202,25 +247,46 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
||||||
# probability distribution)
|
# probability distribution)
|
||||||
return jax.random.categorical(key, logits, -1).astype(np.uint32)
|
return jax.random.categorical(key, logits, -1).astype(np.uint32)
|
||||||
|
|
||||||
def apply_repetition_penalty_static(logits, tokens, repetition_penalty):
|
def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
|
||||||
'''
|
'''
|
||||||
This gets called by generate_loop_fn to apply repetition penalty
|
This gets called by generate_loop_fn to apply repetition penalty
|
||||||
to the 1D array logits using the provided 1D array of tokens to penalize
|
to the 1D array logits using the provided 1D array of tokens to penalize
|
||||||
'''
|
'''
|
||||||
|
rpslope = jnp.int32(rpslope)
|
||||||
|
rprange = jnp.int32(rprange)
|
||||||
|
clipped_rprange = jax.lax.cond(rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange)
|
||||||
|
penalty_arange = jnp.roll(jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]), generated_index, axis=-1)
|
||||||
# Make a new array with the same length as the tokens array but with
|
# Make a new array with the same length as the tokens array but with
|
||||||
# each element replaced by the value at the corresponding index in the
|
# each element replaced by the value at the corresponding index in the
|
||||||
# logits array; e.g.
|
# logits array; e.g.
|
||||||
# if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
|
# if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
|
||||||
# then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
|
# then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
|
||||||
penalty_logits = jnp.take(logits, tokens)
|
penalty_logits = jnp.take(logits, tokens)
|
||||||
|
# Repetition penalty slope
|
||||||
|
def apply_slope(carry):
|
||||||
|
repetition_penalty, rprange = carry
|
||||||
|
_penalty = (penalty_arange/(rprange - 1)) * 2 - 1
|
||||||
|
_penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
|
||||||
|
_penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
|
||||||
|
return _penalty
|
||||||
|
repetition_penalty = jax.lax.cond(
|
||||||
|
(rpslope != 0.0) & (rprange > 0), # Not a typo; do not use `and` here, it makes JAX crash
|
||||||
|
apply_slope,
|
||||||
|
lambda carry: jnp.full(tokens.shape, carry[0]),
|
||||||
|
(repetition_penalty, rprange),
|
||||||
|
)
|
||||||
# Divide positive values by repetition_penalty and multiply negative
|
# Divide positive values by repetition_penalty and multiply negative
|
||||||
# values by repetition_penalty (the academic publication that described
|
# values by repetition_penalty (the academic publication that described
|
||||||
# this technique actually just only divided, but that would cause tokens
|
# this technique actually just only divided, but that would cause tokens
|
||||||
# with negative logits to become more likely, which is obviously wrong)
|
# with negative logits to become more likely, which is obviously wrong)
|
||||||
penalty_logits = jnp.where(
|
penalty_logits = jnp.where(
|
||||||
penalty_logits > 0,
|
penalty_arange >= 0,
|
||||||
penalty_logits/repetition_penalty,
|
jnp.where(
|
||||||
penalty_logits*repetition_penalty,
|
penalty_logits > 0,
|
||||||
|
penalty_logits/repetition_penalty,
|
||||||
|
penalty_logits*repetition_penalty,
|
||||||
|
),
|
||||||
|
penalty_logits,
|
||||||
)
|
)
|
||||||
# Finally, put those penalized logit values back into their original
|
# Finally, put those penalized logit values back into their original
|
||||||
# positions in the logits array
|
# positions in the logits array
|
||||||
|
@ -325,7 +391,7 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
||||||
|
|
||||||
pad_token_id = 50256
|
pad_token_id = 50256
|
||||||
|
|
||||||
def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_options):
|
def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_index, gen_length, rpslope, rprange, sampler_options):
|
||||||
numseqs = numseqs_aux.shape[0]
|
numseqs = numseqs_aux.shape[0]
|
||||||
gi = data[0][1]
|
gi = data[0][1]
|
||||||
def sample_loop_fn(carry):
|
def sample_loop_fn(carry):
|
||||||
|
@ -339,7 +405,11 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_op
|
||||||
logits = apply_repetition_penalty_dynamic(
|
logits = apply_repetition_penalty_dynamic(
|
||||||
logits,
|
logits,
|
||||||
generated,
|
generated,
|
||||||
repetition_penalty
|
repetition_penalty,
|
||||||
|
generated_index,
|
||||||
|
gen_length,
|
||||||
|
rpslope,
|
||||||
|
rprange,
|
||||||
)
|
)
|
||||||
# Remove any tokens in the badwords list by setting
|
# Remove any tokens in the badwords list by setting
|
||||||
# their logits to negative infinity which effectively
|
# their logits to negative infinity which effectively
|
||||||
|
@ -401,6 +471,8 @@ class PenalizingCausalTransformer(CausalTransformer):
|
||||||
initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs))
|
initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs))
|
||||||
# Get repetition penalty from the arguments
|
# Get repetition penalty from the arguments
|
||||||
repetition_penalty = sampler_options.pop('repetition_penalty', None)
|
repetition_penalty = sampler_options.pop('repetition_penalty', None)
|
||||||
|
rpslope = sampler_options.pop('rpslope', None)
|
||||||
|
rprange = sampler_options.pop('rprange', None)
|
||||||
# This is the main generation loop
|
# This is the main generation loop
|
||||||
def generate_loop_fn(carry):
|
def generate_loop_fn(carry):
|
||||||
# Unpack current generate_loop_fn state
|
# Unpack current generate_loop_fn state
|
||||||
|
@ -427,7 +499,11 @@ class PenalizingCausalTransformer(CausalTransformer):
|
||||||
logits = apply_repetition_penalty_static(
|
logits = apply_repetition_penalty_static(
|
||||||
logits,
|
logits,
|
||||||
generated,
|
generated,
|
||||||
repetition_penalty
|
repetition_penalty,
|
||||||
|
generated_index,
|
||||||
|
gen_length,
|
||||||
|
rpslope,
|
||||||
|
rprange,
|
||||||
)
|
)
|
||||||
# Remove any tokens in the badwords list by setting
|
# Remove any tokens in the badwords list by setting
|
||||||
# their logits to negative infinity which effectively
|
# their logits to negative infinity which effectively
|
||||||
|
@ -586,7 +662,9 @@ class PenalizingCausalTransformer(CausalTransformer):
|
||||||
sample_data[i][2] = logits[i]
|
sample_data[i][2] = logits[i]
|
||||||
sampler_options = settings_callback()
|
sampler_options = settings_callback()
|
||||||
repetition_penalty = sampler_options.pop("repetition_penalty", 1.0)
|
repetition_penalty = sampler_options.pop("repetition_penalty", 1.0)
|
||||||
sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options)
|
rpslope = sampler_options.pop("rpslope", 0.0)
|
||||||
|
rprange = sampler_options.pop("rprange", 0)
|
||||||
|
sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, params["seq"] + n_generated, gen_length, rpslope, rprange, sampler_options)
|
||||||
n_generated += 1
|
n_generated += 1
|
||||||
for i in range(numseqs):
|
for i in range(numseqs):
|
||||||
generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1))
|
generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1))
|
||||||
|
@ -659,6 +737,8 @@ def infer_static(
|
||||||
top_k=0,
|
top_k=0,
|
||||||
tfs=1.0,
|
tfs=1.0,
|
||||||
repetition_penalty=1.0,
|
repetition_penalty=1.0,
|
||||||
|
rpslope=0.0,
|
||||||
|
rprange=0,
|
||||||
numseqs=1,
|
numseqs=1,
|
||||||
gen_len=80,
|
gen_len=80,
|
||||||
soft_embeddings: Optional[np.array] = None,
|
soft_embeddings: Optional[np.array] = None,
|
||||||
|
@ -679,6 +759,8 @@ def infer_static(
|
||||||
"top_p": top_p * np.ones(total_batch),
|
"top_p": top_p * np.ones(total_batch),
|
||||||
"tfs": tfs * np.ones(total_batch),
|
"tfs": tfs * np.ones(total_batch),
|
||||||
"repetition_penalty": repetition_penalty * np.ones(total_batch),
|
"repetition_penalty": repetition_penalty * np.ones(total_batch),
|
||||||
|
"rpslope": rpslope * np.ones(total_batch),
|
||||||
|
"rprange": np.full(total_batch, rprange, dtype=np.uint32),
|
||||||
"top_k": np.full(total_batch, top_k, dtype=np.uint32)
|
"top_k": np.full(total_batch, top_k, dtype=np.uint32)
|
||||||
}
|
}
|
||||||
output = network.generate_static(
|
output = network.generate_static(
|
||||||
|
|
|
@ -0,0 +1,91 @@
|
||||||
|
-- Logit viewer
|
||||||
|
-- Displays raw token scores and softmax probabilities during generation.
|
||||||
|
|
||||||
|
kobold = require("bridge")()
|
||||||
|
local userscript = {} ---@class KoboldUserScript
|
||||||
|
|
||||||
|
local K = 10
|
||||||
|
|
||||||
|
---@class Pair
|
||||||
|
---@field id integer
|
||||||
|
---@field score number
|
||||||
|
|
||||||
|
---@class ArrayBase
|
||||||
|
---@type table<any, Pair>
|
||||||
|
local _ = {}
|
||||||
|
|
||||||
|
---@class Array : ArrayBase
|
||||||
|
---@field n integer
|
||||||
|
|
||||||
|
---@param array Array
|
||||||
|
---@param index integer
|
||||||
|
---@return nil
|
||||||
|
local function bubble(array, index)
|
||||||
|
local j = 0
|
||||||
|
while (index<<1)+1 < array.n do
|
||||||
|
j = index
|
||||||
|
if array[(index<<1)+1].score > array[j].score then
|
||||||
|
j = (index<<1)+1
|
||||||
|
end
|
||||||
|
if (index<<1)+2 < array.n and array[(index<<1)+2].score > array[j].score then
|
||||||
|
j = (index<<1)+2
|
||||||
|
end
|
||||||
|
if index == j then
|
||||||
|
break
|
||||||
|
end
|
||||||
|
local b = array[index]
|
||||||
|
array[index] = array[j]
|
||||||
|
array[j] = b
|
||||||
|
index = j
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
---@param array Array
|
||||||
|
---@return nil
|
||||||
|
local function build(array)
|
||||||
|
for i = (array.n-1)>>1, 0, -1 do
|
||||||
|
bubble(array, i)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
---@param array Array
|
||||||
|
---@return Pair
|
||||||
|
local function pop(array)
|
||||||
|
local r = array[0]
|
||||||
|
array.n = array.n - 1
|
||||||
|
array[0] = array[array.n]
|
||||||
|
bubble(array, 0)
|
||||||
|
return r
|
||||||
|
end
|
||||||
|
|
||||||
|
function userscript.genmod()
|
||||||
|
if K > kobold.logits_cols then
|
||||||
|
error("K must be at most the vocabulary size of the model")
|
||||||
|
end
|
||||||
|
|
||||||
|
if kobold.generated_cols > 0 then
|
||||||
|
for s, logits in ipairs(kobold.logits) do
|
||||||
|
local token = kobold.generated[s][kobold.generated_cols]
|
||||||
|
print("Previous result for sequence " .. s .. ": [" .. kobold.decode(token):gsub("\n", "\\n") .. "] (" .. math.tointeger(token) .. ")")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
for s, logits in ipairs(kobold.logits) do
|
||||||
|
local a = {} ---@type Array
|
||||||
|
local sum = 0.0
|
||||||
|
for i = 0, kobold.logits_cols-1 do
|
||||||
|
a[i] = {id = i, score = logits[i + 1]}
|
||||||
|
a.n = i + 1
|
||||||
|
sum = sum + math.exp(logits[i + 1])
|
||||||
|
end
|
||||||
|
build(a)
|
||||||
|
print()
|
||||||
|
print("Top " .. K .. " scores for sequence " .. s .. ":")
|
||||||
|
for i = 1, K do
|
||||||
|
local e = pop(a)
|
||||||
|
print(("%.6f"):format(e.score), ("%.3f%% "):format(100 * (math.exp(e.score) / sum)), e.id, "[" .. (kobold.decode(e.id):gsub("\n", "\\n")) .. "]")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return userscript
|
|
@ -0,0 +1,100 @@
|
||||||
|
'''
|
||||||
|
This file is AGPL-licensed.
|
||||||
|
|
||||||
|
Some of the code in this file is from Clover Edition:
|
||||||
|
https://github.com/cloveranon/Clover-Edition/blob/master/aidungeon/gpt2generator.py
|
||||||
|
|
||||||
|
The license for Clover Edition is shown below:
|
||||||
|
|
||||||
|
Copyright (c) 2019 Nick Walton
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
'''
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import LogitsWarper, LogitsProcessor
|
||||||
|
|
||||||
|
|
||||||
|
class AdvancedRepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
self.penalty_range = int(self.penalty_range)
|
||||||
|
clipped_penalty_range = min(input_ids.shape[-1], self.penalty_range)
|
||||||
|
|
||||||
|
if self.penalty != 1.0:
|
||||||
|
if self.penalty_range > 0:
|
||||||
|
if clipped_penalty_range < input_ids.shape[1]:
|
||||||
|
input_ids = input_ids[..., -clipped_penalty_range:]
|
||||||
|
|
||||||
|
if self.penalty_slope != 0:
|
||||||
|
_penalty = (torch.arange(self.penalty_range, dtype=scores.dtype, device=scores.device)/(self.penalty_range - 1)) * 2. - 1
|
||||||
|
_penalty = (self.penalty_slope * _penalty) / (1 + torch.abs(_penalty) * (self.penalty_slope - 1))
|
||||||
|
_penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (self.penalty - 1)
|
||||||
|
self.penalty = _penalty[..., -clipped_penalty_range:]
|
||||||
|
|
||||||
|
score = torch.gather(scores, 1, input_ids)
|
||||||
|
score = torch.where(score <= 0, score * self.penalty, score / self.penalty)
|
||||||
|
scores.scatter_(1, input_ids, score)
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class TailFreeLogitsWarper(LogitsWarper):
|
||||||
|
|
||||||
|
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||||
|
tfs = float(tfs)
|
||||||
|
if tfs < 0 or tfs > 1.0:
|
||||||
|
raise ValueError(f"`tfs` has to be a float > 0 and < 1, but is {tfs}")
|
||||||
|
self.tfs = tfs
|
||||||
|
self.filter_value = filter_value
|
||||||
|
self.min_tokens_to_keep = min_tokens_to_keep
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
if self.filter_value >= 1.0:
|
||||||
|
return scores
|
||||||
|
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
|
||||||
|
probs = sorted_logits.softmax(dim=-1)
|
||||||
|
|
||||||
|
# Compute second derivative normalized CDF
|
||||||
|
d2 = probs.diff().diff().abs()
|
||||||
|
normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)
|
||||||
|
normalized_d2_cdf = normalized_d2.cumsum(dim=-1)
|
||||||
|
|
||||||
|
# Remove tokens with CDF value above the threshold (token with 0 are kept)
|
||||||
|
sorted_indices_to_remove = normalized_d2_cdf > self.tfs
|
||||||
|
|
||||||
|
# Centre the distribution around the cutoff as in the original implementation of the algorithm
|
||||||
|
sorted_indices_to_remove = torch.cat(
|
||||||
|
(
|
||||||
|
torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
|
||||||
|
sorted_indices_to_remove,
|
||||||
|
torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
|
||||||
|
),
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.min_tokens_to_keep > 1:
|
||||||
|
# Keep at least min_tokens_to_keep
|
||||||
|
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
|
||||||
|
|
||||||
|
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||||
|
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||||
|
return scores
|
Loading…
Reference in New Issue