Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-02-26 00:17:41 +01:00)
commit 8d3eb44d2e

aiserver.py (118 changed lines)
@@ -963,7 +963,10 @@ def loadmodelsettings():
     if("nobreakmodel" in js):
         vars.nobreakmodel = js["nobreakmodel"]
     if("sampler_order" in js):
-        vars.sampler_order = js["sampler_order"]
+        sampler_order = vars.sampler_order
+        if(len(sampler_order) < 7):
+            sampler_order = [6] + sampler_order
+        vars.sampler_order = sampler_order
     if("temp" in js):
         vars.temp = js["temp"]
     if("top_p" in js):
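Note: the added lines are the backward-compatibility rule this commit applies wherever a sampler order is read: settings written before repetition penalty became sampler 6 only hold six entries, so 6 is prepended and repetition penalty runs first. A minimal sketch of that rule (illustrative only, not code from the commit):

    def pad_sampler_order(order):
        # order: list of sampler indices loaded from an older settings file
        if len(order) < 7:
            order = [6] + order   # repetition penalty (index 6) goes to the front
        return order

    assert pad_sampler_order([0, 1, 2, 3, 4, 5]) == [6, 0, 1, 2, 3, 4, 5]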
@@ -1094,7 +1097,10 @@ def processsettings(js):
     if("andepth" in js):
         vars.andepth = js["andepth"]
     if("sampler_order" in js):
-        vars.sampler_order = js["sampler_order"]
+        sampler_order = vars.sampler_order
+        if(len(sampler_order) < 7):
+            sampler_order = [6] + sampler_order
+        vars.sampler_order = sampler_order
     if("temp" in js):
         vars.temp = js["temp"]
     if("top_p" in js):
@@ -1474,22 +1480,22 @@ def get_model_info(model, directory=""):

 def get_layer_count(model, directory=""):
     if(model not in ["InferKit", "Colab", "API", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
-        if(vars.model == "GPT2Custom"):
-            model_config = open(vars.custmodpth + "/config.json", "r")
+        if(model == "GPT2Custom"):
+            with open(os.path.join(directory, "config.json"), "r") as f:
+                model_config = json.load(f)
         # Get the model_type from the config or assume a model type if it isn't present
         else:
+            if(directory):
+                model = directory
             from transformers import AutoConfig
-            if directory == "":
-                model_config = AutoConfig.from_pretrained(model, revision=vars.revision, cache_dir="cache")
             if(os.path.isdir(model.replace('/', '_'))):
                 model_config = AutoConfig.from_pretrained(model.replace('/', '_'), revision=vars.revision, cache_dir="cache")
             elif(os.path.isdir("models/{}".format(model.replace('/', '_')))):
                 model_config = AutoConfig.from_pretrained("models/{}".format(model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
+            elif(os.path.isdir(directory)):
+                model_config = AutoConfig.from_pretrained(directory, revision=vars.revision, cache_dir="cache")
-            elif(os.path.isdir(vars.custmodpth.replace('/', '_'))):
-                model_config = AutoConfig.from_pretrained(vars.custmodpth.replace('/', '_'), revision=vars.revision, cache_dir="cache")
             else:
-                model_config = AutoConfig.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
+                model_config = AutoConfig.from_pretrained(model, revision=vars.revision, cache_dir="cache")
         return utils.num_layers(model_config)
     else:
         return None
@@ -1727,8 +1733,6 @@ def patch_transformers():
     dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
     dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0)
     dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
-    RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
-    RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__

     class LuaLogitsProcessor(LogitsProcessor):

@@ -1805,9 +1809,13 @@ def patch_transformers():
             self.__warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
             self.__warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
             self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5))
+            self.__warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor())

         def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs):
-            for k in vars.sampler_order:
+            sampler_order = vars.sampler_order[:]
+            if len(sampler_order) < 7:  # Add repetition penalty at beginning if it's not present
+                sampler_order = [6] + sampler_order
+            for k in sampler_order:
                 scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
             return scores

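Note: for reference, the sampler indices these orders refer to, as far as this diff shows them (the k == 3..6 branches in tpu_mtj_backend.py further down; indices 0..2 are inferred from the "top-k, then top-a, then top-p" docstring order and are an assumption here). The helper below is a simplified sketch, not code from the commit; the real warpers also receive input_ids.

    SAMPLERS = {
        0: "top_k",       # inferred from the docstring order
        1: "top_a",       # inferred
        2: "top_p",       # inferred
        3: "tfs",
        4: "typical",
        5: "temperature",
        6: "repetition_penalty",  # new in this commit
    }

    def apply_in_order(scores, warper_list, sampler_order):
        # Mirrors the patched __call__ above: prepend 6 for legacy six-entry
        # orders, then let each index pick its warper out of the list.
        if len(sampler_order) < 7:
            sampler_order = [6] + sampler_order
        for k in sampler_order:
            scores = warper_list[k](scores)
        return scores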
@@ -1940,22 +1948,24 @@ def reset_model_settings():
     vars.badwordsids = []
     vars.fp32_model = False  # Whether or not the most recently loaded HF model was in fp32 format
     vars.modeldim = -1  # Embedding dimension of your model (e.g. it's 4096 for GPT-J-6B and 2560 for GPT-Neo-2.7B)
-    vars.sampler_order = [0, 1, 2, 3, 4, 5]
+    vars.sampler_order = [6, 0, 1, 2, 3, 4, 5]
     vars.newlinemode = "n"
     vars.revision = None

-def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=False, online_model=""):
+def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=False, online_model="", use_breakmodel_args=False, breakmodel_args_default_to_cpu=False):
     global model
     global generator
     global torch
     global model_config
     global GPT2TokenizerFast
     global tokenizer
+    if(initial_load):
+        use_breakmodel_args = True
     reset_model_settings()
     if not utils.HAS_ACCELERATE:
         disk_layers = None
     vars.noai = False
-    if not initial_load:
+    if not use_breakmodel_args:
         set_aibusy(True)
         if vars.model != 'ReadOnly':
             emit('from_server', {'cmd': 'model_load_status', 'data': "Loading {}".format(vars.model)}, broadcast=True)
@@ -1963,12 +1973,16 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=False, online_model=""):
         time.sleep(0.1)
     if gpu_layers is not None:
         args.breakmodel_gpulayers = gpu_layers
-    elif initial_load:
+    elif use_breakmodel_args:
         gpu_layers = args.breakmodel_gpulayers
+        if breakmodel_args_default_to_cpu and gpu_layers is None:
+            gpu_layers = args.breakmodel_gpulayers = []
     if disk_layers is not None:
         args.breakmodel_disklayers = int(disk_layers)
-    elif initial_load:
+    elif use_breakmodel_args:
         disk_layers = args.breakmodel_disklayers
+        if breakmodel_args_default_to_cpu and disk_layers is None:
+            disk_layers = args.breakmodel_disklayers = 0

     #We need to wipe out the existing model and refresh the cuda cache
     model = None
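Note: the two new keyword arguments decide where layer counts come from when the caller passes none. As the hunk reads: an explicit argument wins, then the command-line breakmodel values when use_breakmodel_args is set, and breakmodel_args_default_to_cpu falls back to putting nothing on the GPU when no value was given. A sketch of that precedence (names are illustrative, not from the commit):

    def resolve_gpu_layers(gpu_layers, args_gpulayers, use_breakmodel_args, default_to_cpu):
        # gpu_layers: value passed by the caller, or None
        # args_gpulayers: value taken from the --breakmodel_gpulayers command line, or None
        if gpu_layers is not None:
            return gpu_layers
        if use_breakmodel_args:
            gpu_layers = args_gpulayers
            if default_to_cpu and gpu_layers is None:
                gpu_layers = []   # no layers on the GPU, so the model stays on CPU
        return gpu_layers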
@@ -2062,6 +2076,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=False, online_model=""):
     if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "API", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
         loadmodelsettings()
         loadsettings()
+        print(2)
         print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
         vars.hascuda = torch.cuda.is_available()
         vars.bmsupported = (utils.HAS_ACCELERATE or vars.model_type in ("gpt_neo", "gptj", "xglm", "opt")) and not vars.nobreakmodel
@@ -2311,7 +2326,6 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=False, online_model=""):
         # If we're using torch_lazy_loader, we need to get breakmodel config
         # early so that it knows where to load the individual model tensors
         if (utils.HAS_ACCELERATE or vars.lazy_load and vars.hascuda and vars.breakmodel) and not vars.nobreakmodel:
-            print(1)
             device_config(model_config)

         # Download model from Huggingface if it does not exist, otherwise load locally
@@ -2551,8 +2565,11 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=False, online_model=""):
                 vars.compiling = False

     def tpumtjgenerate_settings_callback() -> dict:
+        sampler_order = vars.sampler_order[:]
+        if len(sampler_order) < 7:  # Add repetition penalty at beginning if it's not present
+            sampler_order = [6] + sampler_order
         return {
-            "sampler_order": vars.sampler_order,
+            "sampler_order": sampler_order,
             "top_p": float(vars.top_p),
             "temp": float(vars.temp),
             "top_k": int(vars.top_k),
@@ -3659,12 +3676,16 @@ def get_message(msg):
         sendUSStatItems()
     elif(msg['cmd'] == 'samplers'):
         sampler_order = msg["data"]
+        sampler_order_min_length = 6
+        sampler_order_max_length = 7
         if(not isinstance(sampler_order, list)):
             raise ValueError(f"Sampler order must be a list, but got a {type(sampler_order)}")
-        if(len(sampler_order) != len(vars.sampler_order)):
-            raise ValueError(f"Sampler order must be a list of length {len(vars.sampler_order)}, but got a list of length {len(sampler_order)}")
+        if(not (sampler_order_min_length <= len(sampler_order) <= sampler_order_max_length)):
+            raise ValueError(f"Sampler order must be a list of length greater than or equal to {sampler_order_min_length} and less than or equal to {sampler_order_max_length}, but got a list of length {len(sampler_order)}")
         if(not all(isinstance(e, int) for e in sampler_order)):
             raise ValueError(f"Sampler order must be a list of ints, but got a list with at least one non-int element")
+        if(min(sampler_order) != 0 or max(sampler_order) != len(sampler_order) - 1 or len(set(sampler_order)) != len(sampler_order)):
+            raise ValueError(f"Sampler order list of length {len(sampler_order)} must be a permutation of the first {len(sampler_order)} nonnegative integers")
         vars.sampler_order = sampler_order
         settingschanged()
     elif(msg['cmd'] == 'list_model'):
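Note: the checks above accept either the legacy six-entry order or the new seven-entry one and require a permutation of 0..len-1. A compact restatement (sketch only, not code from the commit):

    def validate_sampler_order(sampler_order):
        if not isinstance(sampler_order, list):
            raise ValueError("Sampler order must be a list")
        if not (6 <= len(sampler_order) <= 7):
            raise ValueError("Sampler order must have 6 or 7 entries")
        if not all(isinstance(e, int) for e in sampler_order):
            raise ValueError("Sampler order must contain only ints")
        if sorted(sampler_order) != list(range(len(sampler_order))):
            raise ValueError("Sampler order must be a permutation of 0..len-1")
        return sampler_order

    validate_sampler_order([6, 0, 1, 2, 3, 4, 5])   # ok
    validate_sampler_order([0, 1, 2, 3, 4, 5])      # ok: legacy order without the rep-pen entry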
@@ -3848,7 +3869,6 @@ def get_message(msg):
             emit(
                 'from_server',
                 {'cmd': 'showfieldbudget', 'data': {"length": None, "max": None, "field": field}},
                 broadcast=True
             )
             return

@@ -4618,7 +4638,7 @@ def _generate(txt, minimum, maximum, found_entries):
         gen_in,
         do_sample=True,
         max_length=int(2e9),
-        repetition_penalty=1.1,
+        repetition_penalty=1.0,
         bad_words_ids=vars.badwordsids,
         use_cache=True,
         num_return_sequences=numseqs
@@ -7242,6 +7262,9 @@ class WorldInfoFoldersUIDsSchema(KoboldSchema):
 class WorldInfoUIDsSchema(WorldInfoEntriesUIDsSchema):
     folders: List[WorldInfoFolderSchema] = fields.List(fields.Nested(WorldInfoFolderUIDsSchema), required=True)

+class ModelSelectionSchema(KoboldSchema):
+    model: str = fields.String(required=True, validate=validate.Regexp(r"^(?!\s*NeoCustom)(?!\s*GPT2Custom)(?!\s*TPUMeshTransformerGPTJ)(?!\s*TPUMeshTransformerGPTNeoX)(?!\s*GooseAI)(?!\s*OAI)(?!\s*InferKit)(?!\s*Colab)(?!\s*API).*$"), metadata={"description": 'Hugging Face model ID, the path to a model folder (relative to the "models" folder in the KoboldAI root folder) or "ReadOnly" for no model'})
+
 def _generate_text(body: GenerationInputSchema):
     if vars.aibusy or vars.genseqs:
         abort(Response(json.dumps({"detail": {
@@ -7453,6 +7476,49 @@ def get_model():
     return {"result": vars.model}


+@api_v1.put("/model")
+@api_schema_wrap
+def put_model(body: ModelSelectionSchema):
+    """---
+    put:
+      summary: Load a model
+      description: |-2
+        Loads a model given its Hugging Face model ID, the path to a model folder (relative to the "models" folder in the KoboldAI root folder) or "ReadOnly" for no model.
+      tags:
+        - model
+      requestBody:
+        required: true
+        content:
+          application/json:
+            schema: ModelSelectionSchema
+            example:
+              model: ReadOnly
+      responses:
+        200:
+          description: Successful request
+          content:
+            application/json:
+              schema: EmptySchema
+        {api_validation_error_response}
+        {api_server_busy_response}
+    """
+    if vars.aibusy or vars.genseqs:
+        abort(Response(json.dumps({"detail": {
+            "msg": "Server is busy; please try again later.",
+            "type": "service_unavailable",
+        }}), mimetype="application/json", status=503))
+    set_aibusy(1)
+    old_model = vars.model
+    vars.model = body.model.strip()
+    try:
+        load_model(use_breakmodel_args=True, breakmodel_args_default_to_cpu=True)
+    except Exception as e:
+        vars.model = old_model
+        raise e
+    set_aibusy(0)
+    return {}
+
+
 def prompt_validator(prompt: str):
     if len(prompt.strip()) == 0:
         raise ValidationError("String does not match expected pattern.")
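Note: an example client call for the new endpoint (sketch; it assumes the server's default local address and that the api_v1 blueprint is mounted at /api/v1, neither of which is shown in this diff):

    import requests

    resp = requests.put(
        "http://localhost:5000/api/v1/model",   # assumed base URL
        json={"model": "ReadOnly"},             # any value accepted by ModelSelectionSchema
    )
    resp.raise_for_status()                     # a 503 is returned while the server is busy
    print(resp.json())                          # {} on success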
@@ -66,7 +66,7 @@
     "#@title <b><-- Select your model below and then click this to start KoboldAI</b>\n",
     "#@markdown You can find a description of the models below along with instructions on how to start KoboldAI.\n",
     "\n",
-    "Model = \"Nerys 13B V2\" #@param [\"Nerys 13B V2\", \"Janeway 13B\", \"Shinen 13B\", \"Skein 6B\", \"Janeway 6B\", \"Adventure 6B\", \"Shinen 6B\", \"Lit 6B\", \"NeoX 20B\", \"OPT 13B\", \"Fairseq Dense 13B\", \"GPT-J-6B\"] {allow-input: true}\n",
+    "Model = \"Nerys 13B V2\" #@param [\"Nerys 13B V2\", \"Janeway 13B\", \"Shinen 13B\", \"Skein 20B\", \"Skein 6B\", \"Janeway 6B\", \"Adventure 6B\", \"Shinen 6B\", \"Lit 6B\", \"NeoX 20B\", \"OPT 13B\", \"Fairseq Dense 13B\", \"GPT-J-6B\"] {allow-input: true}\n",
     "Version = \"Official\" #@param [\"Official\", \"United\"] {allow-input: true}\n",
     "Provider = \"Cloudflare\" #@param [\"Localtunnel\", \"Cloudflare\"]\n",
     "\n",
@@ -93,6 +93,10 @@
     "  Model = \"KoboldAI/fairseq-dense-13B-Shinen\"\n",
     "  path = \"\"\n",
     "  download = \"\"\n",
+    "elif Model == \"Skein 20B\":\n",
+    "  Model = \"KoboldAI/GPT-NeoX-20B-Skein\"\n",
+    "  path = \"\"\n",
+    "  download = \"\"\n",
     "elif Model == \"NeoX 20B\":\n",
     "  Model = \"EleutherAI/gpt-neox-20b\"\n",
     "  path = \"\"\n",
@@ -128,7 +132,7 @@
     "elif Model == \"GPT-J-6B\":\n",
     "  Model = \"EleutherAI/gpt-j-6B\"\n",
     "  path = \"\"\n",
-    "  download = \"\"\n",
+    "  download = \"\"\n",
     "else:\n",
     "  path = \"\"\n",
     "  download = \"\"\n",
@@ -225,4 +229,4 @@
   },
   "nbformat": 4,
   "nbformat_minor": 0
-}
+}
@@ -274,6 +274,17 @@ gensettingstf = [
         "default": 0,
         "tooltip": "Shows token selection probabilities. Does not work with more than one gens per action."
     },
+    {
+        "uitype": "toggle",
+        "unit": "bool",
+        "label": "Show Field Budget",
+        "id": "setshowbudget",
+        "min": 0,
+        "max": 1,
+        "step": 1,
+        "default": 0,
+        "tooltip": "Shows token usage when typing in relevant text boxes. <b>May lag slower devices.</b>"
+    },
 ]

 gensettingsik =[{
@@ -241,8 +241,27 @@ function addSetting(ob) {
         if(ob.id == "setadventure"){
             setadventure($(this).prop('checked'));
         }

     });
     }
+
+    if (ob.id === "setshowbudget") {
+        $("#setshowbudget").on("change", function () {
+            for (const el of document.getElementsByClassName("input-token-usage")) {
+                if (this.checked) {
+                    el.classList.remove("hidden");
+                } else {
+                    el.classList.add("hidden");
+                }
+            }
+        });
+
+        if (!$("#setshowbudget")[0].checked) {
+            for (const el of document.getElementsByClassName("input-token-usage")) {
+                el.classList.add("hidden");
+            }
+        }
+    }
 }

 function refreshTitle() {
@@ -1287,12 +1306,13 @@ function buildSamplerList(samplers) {
         "Tail-free Sampling",
         "Typical Sampling",
         "Temperature",
+        "Repetition Penalty",
     ]
     for(i=0; i<samplers.length; i++) {
         samplerslist.append("<div class=\"flex\">\
             <div class=\"samplerslistitem flex-row-container\" sid=\""+samplers[i]+"\">\
                 <div class=\"flex-row\">\
-                    <div>"+samplers_lookup_table[samplers[i]]+"</div>\
+                    <div>"+(samplers[i] < samplers_lookup_table.length ? samplers_lookup_table[samplers[i]] : "Unknown sampler #" + samplers[i])+"</div>\
                 </div>\
             </div>\
         </div>");
@@ -2165,6 +2185,9 @@ function interpolateRGB(color0, color1, t) {
 }

 function updateInputBudget(inputElement) {
+    let budgetElement = document.getElementById("setshowbudget");
+    if (budgetElement && !budgetElement.checked) return;
+
     let data = {"unencoded": inputElement.value, "field": inputElement.id};

     if (inputElement.id === "anoteinput") {
@@ -2182,7 +2205,6 @@ function registerTokenCounters() {

         let span = document.createElement("span");
         span.classList.add("input-token-usage");
-        span.innerText = "?/? Tokens";
         el.appendChild(span);

         let inputElement = el.querySelector("input, textarea");
|
||||
$("#showmodelnamecontainer").removeClass("hidden");
|
||||
} else if(msg.cmd == 'hide_model_name') {
|
||||
$("#showmodelnamecontainer").addClass("hidden");
|
||||
$(window).off('beforeunload');
|
||||
location.reload();
|
||||
//console.log("Closing window");
|
||||
} else if(msg.cmd == 'model_load_status') {
|
||||
|
@@ -473,7 +473,7 @@ body.connected #popupfooter, #popupfooter.always-available {
 }

 #samplerslist {
-    height: 300px;
+    height: 310px;
     overflow-y: scroll;
     overflow-wrap: anywhere;
 }
@@ -176,7 +176,7 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
     logits[tokens] = penalty_logits
     return logits

-def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
+def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
     '''
     This gets called by generate_loop_fn to apply a series of 6 filters
     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
@@ -312,6 +312,7 @@ def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
         if k == 3 and tfs < 1.0: logits = tail_free_filter(logits)
         if k == 4 and typical < 1.0: logits = typical_filter(logits)
         if k == 5 and temp != 1.0: logits = temp_filter(logits)
+        if k == 6 and rpargs[1] != 1.0: logits = apply_repetition_penalty_dynamic(logits, *rpargs)
     # Finally, pick one token using the softmax thingy again (it gives
     # an array whose elements sum to 1 so it can be used nicely as a
     # probability distribution)
@@ -362,7 +363,7 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
     # positions in the logits array
     return logits.at[tokens].set(penalty_logits)

-def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
+def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
     '''
     This gets called by generate_loop_fn to apply a series of 6 filters
     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
@@ -497,6 +498,7 @@ def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
         logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits)
         logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits)
         logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: apply_repetition_penalty_static(*x), lambda x: x[0], (logits, *rpargs))
     # Finally, pick one token using the softmax thingy again (it gives
     # an array whose elements sum to 1 so it can be used nicely as a
     # probability distribution)
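Note: both branches of jax.lax.cond must accept the same operand and return values of the same shape, which is why the added line packs (logits, *rpargs) into a single tuple and the "skip" branch hands back x[0]. A minimal, self-contained sketch of that pattern (not code from the commit):

    import jax
    import jax.numpy as jnp

    def penalize(operand):
        logits, penalty = operand
        return logits / penalty        # stand-in for apply_repetition_penalty_static

    def skip(operand):
        return operand[0]              # return the logits unchanged

    logits = jnp.array([1.0, 2.0, 3.0])
    out = jax.lax.cond(True, penalize, skip, (logits, 1.1))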
@@ -513,17 +515,6 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
         # Get the pseudo-random number generator key that will
         # be used by kobold_sample_dynamic to randomly pick a token
         sample_key, new_key = jax.random.split(sample_key, num=2)
-        # Apply repetition penalty to all tokens that are
-        # currently inside the "generated" array
-        logits = apply_repetition_penalty_dynamic(
-            logits,
-            generated,
-            repetition_penalty,
-            generated_index,
-            gen_length,
-            rpslope,
-            rprange,
-        )
         # Remove any tokens in the badwords list by setting
         # their logits to negative infinity which effectively
         # makes their probabilities of being chosen zero
@@ -535,6 +526,14 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
         next_token = kobold_sample_dynamic(
             sample_key,
             logits,
+            (
+                generated,
+                repetition_penalty,
+                generated_index,
+                gen_length,
+                rpslope,
+                rprange,
+            ),
             **sampler_options,
         )
         # Remember what token was picked
@@ -606,18 +605,6 @@ class PenalizingCausalTransformer(CausalTransformer):
                 assert logits.shape == (1, config["n_vocab"])
                 # Flatten it into a 1D array to make it easier to use
                 logits = logits[0]
-                # Apply repetition penalty to all tokens that are
-                # currently inside the "generated" array
-                if repetition_penalty is not None:
-                    logits = apply_repetition_penalty_static(
-                        logits,
-                        generated,
-                        repetition_penalty,
-                        generated_index,
-                        gen_length,
-                        rpslope,
-                        rprange,
-                    )
                 # Remove any tokens in the badwords list by setting
                 # their logits to negative infinity which effectively
                 # makes their probabilities of being chosen zero
@@ -629,6 +616,14 @@ class PenalizingCausalTransformer(CausalTransformer):
                 next_token = kobold_sample_static(
                     sample_key,
                     logits,
+                    (
+                        generated,
+                        repetition_penalty,
+                        generated_index,
+                        gen_length,
+                        rpslope,
+                        rprange,
+                    ),
                     **sampler_options,
                 )
                 # Remember what token was picked
@@ -863,6 +858,9 @@ def infer_static(
     maps.thread_resources.env = thread_resources_env
     if sampler_order is None:
         sampler_order = utils.default_sampler_order.copy()
+    sampler_order = sampler_order[:]
+    if len(sampler_order) < 7:  # Add repetition penalty at beginning if it's not present
+        sampler_order = [6] + sampler_order
     sampler_order = np.uint32(sampler_order)
     total_batch = 1
     tokens = context
utils.py (4 changed lines)
@@ -33,7 +33,7 @@ layers_module_names: Optional[List[str]] = None
 module_names: Optional[List[str]] = None
 named_buffers: Optional[List[tuple]] = None

-default_sampler_order = [0, 1, 2, 3, 4, 5]
+default_sampler_order = [6, 0, 1, 2, 3, 4, 5]

 #==================================================================#
 # Decorator to prevent a function's actions from being run until
@@ -167,7 +167,7 @@ def decodenewlines(txt):
 # Returns number of layers given an HF model config
 #==================================================================#
 def num_layers(config):
-    return config.num_layers if hasattr(config, "num_layers") else config.n_layer if hasattr(config, "n_layer") else config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else None
+    return config["n_layer"] if isinstance(config, dict) else config.num_layers if hasattr(config, "num_layers") else config.n_layer if hasattr(config, "n_layer") else config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else None

 #==================================================================#
 # Downloads huggingface checkpoints using aria2c if possible
@@ -28,10 +28,10 @@ SOFTWARE.
 '''

 import torch
-from transformers import LogitsWarper, LogitsProcessor
+from transformers import LogitsWarper


-class AdvancedRepetitionPenaltyLogitsProcessor(LogitsProcessor):
+class AdvancedRepetitionPenaltyLogitsProcessor(LogitsWarper):
     def __init__(self, *args, **kwargs):
         pass

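Note: in transformers, LogitsWarper and LogitsProcessor expose the same __call__(input_ids, scores) shape; the base-class switch matches how the processor is now used, appended to the sampling warper list in patch_transformers() instead of monkey-patching RepetitionPenaltyLogitsProcessor. A minimal warper with that interface (sketch, not from the commit):

    import torch
    from transformers import LogitsWarper

    class NeutralWarper(LogitsWarper):
        # Anything appended to the sampling warper list just needs this call shape.
        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            return scores  # pass the logits through unchanged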