mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-03-09 16:10:16 +01:00
Merge pull request #43 from VE-FORBRYDERNE/dynamic-scan-patch
Dynamic scan patch
This commit is contained in:
commit
46b0473229
32
aiserver.py
32
aiserver.py
@ -21,6 +21,7 @@ import collections
|
||||
import zipfile
|
||||
import packaging
|
||||
import contextlib
|
||||
import traceback
|
||||
from typing import Any, Union, Dict, Set, List
|
||||
|
||||
import requests
|
||||
@ -744,6 +745,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
scores: torch.FloatTensor,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
if(vars.lua_koboldbridge.generated_cols >= vars.genamt):
|
||||
self.regeneration_required = False
|
||||
self.halt = False
|
||||
return True
|
||||
|
||||
assert input_ids.ndim == 2
|
||||
assert len(self.excluded_world_info) == input_ids.shape[0]
|
||||
self.regeneration_required = vars.lua_koboldbridge.regeneration_required
|
||||
@ -773,7 +779,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
excluded_world_info=self.kai_scanner_excluded_world_info,
|
||||
head_length=self.kai_scanner_head_length,
|
||||
)
|
||||
stopping_criteria.append(self.kai_scanner)
|
||||
stopping_criteria.insert(0, self.kai_scanner)
|
||||
return stopping_criteria
|
||||
transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
|
||||
|
||||
@ -1995,17 +2001,17 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
|
||||
lnanote = 0 # Placeholder for Author's Note length
|
||||
|
||||
# Calculate token budget
|
||||
prompttkns = tokenizer.encode(vars.comregex_ai.sub('', vars.prompt))
|
||||
prompttkns = tokenizer.encode(vars.comregex_ai.sub('', vars.prompt), max_length=1+int(vars.max_length), truncation=True)
|
||||
lnprompt = len(prompttkns)
|
||||
|
||||
memtokens = tokenizer.encode(mem)
|
||||
memtokens = tokenizer.encode(mem, max_length=1+int(vars.max_length), truncation=True)
|
||||
lnmem = len(memtokens)
|
||||
|
||||
witokens = tokenizer.encode(winfo)
|
||||
witokens = tokenizer.encode(winfo, max_length=1+int(vars.max_length), truncation=True)
|
||||
lnwi = len(witokens)
|
||||
|
||||
if(anotetxt != ""):
|
||||
anotetkns = tokenizer.encode(anotetxt)
|
||||
anotetkns = tokenizer.encode(anotetxt, max_length=1+int(vars.max_length), truncation=True)
|
||||
lnanote = len(anotetkns)
|
||||
|
||||
lnsp = vars.sp.shape[0] if vars.sp is not None else 0
|
||||
@ -2034,7 +2040,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
|
||||
|
||||
if(budget <= 0):
|
||||
break
|
||||
acttkns = tokenizer.encode(chunk)
|
||||
acttkns = tokenizer.encode(chunk, max_length=int(vars.max_length), truncation=True)
|
||||
tknlen = len(acttkns)
|
||||
if(tknlen < budget):
|
||||
tokens = acttkns + tokens
|
||||
@ -2168,7 +2174,7 @@ def calcsubmit(txt):
|
||||
#==================================================================#
|
||||
|
||||
def _generate(txt, minimum, maximum, found_entries):
|
||||
gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long()
|
||||
gen_in = tokenizer.encode(txt, return_tensors="pt", max_length=int(vars.max_length), truncation=True).long()
|
||||
if(vars.sp is not None):
|
||||
soft_tokens = torch.arange(
|
||||
model.config.vocab_size,
|
||||
@ -2214,14 +2220,14 @@ def _generate(txt, minimum, maximum, found_entries):
|
||||
for r in range(vars.numseqs):
|
||||
for c in range(already_generated):
|
||||
assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
|
||||
genout[r][genout.shape[-1] - already_generated - c] = vars.lua_koboldbridge.generated[r+1][c+1]
|
||||
genout[r][genout.shape[-1] - already_generated + c] = vars.lua_koboldbridge.generated[r+1][c+1]
|
||||
encoded = []
|
||||
for i in range(vars.numseqs):
|
||||
txt = tokenizer.decode(genout[i, -already_generated:])
|
||||
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
|
||||
found_entries[i].update(_found_entries)
|
||||
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
|
||||
encoded.append(tokenizer.encode(txt, return_tensors="pt", truncation=True)[0].long().to(genout.device))
|
||||
encoded.append(tokenizer.encode(txt, return_tensors="pt", max_length=int(vars.max_length), truncation=True)[0].long().to(genout.device))
|
||||
max_length = len(max(encoded, key=len))
|
||||
encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
|
||||
genout = torch.cat(
|
||||
@ -2275,7 +2281,7 @@ def generate(txt, minimum, maximum, found_entries=None):
|
||||
print("{0}{1}{2}".format(colors.YELLOW, "Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.", colors.END), file=sys.stderr)
|
||||
else:
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
|
||||
print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
|
||||
print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
|
||||
set_aibusy(0)
|
||||
return
|
||||
|
||||
@ -2486,7 +2492,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
||||
print("{0}{1}{2}".format(colors.YELLOW, "Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.", colors.END), file=sys.stderr)
|
||||
else:
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
|
||||
print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
|
||||
print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
|
||||
set_aibusy(0)
|
||||
return
|
||||
|
||||
@ -2527,8 +2533,8 @@ def getnewcontent(txt):
|
||||
return txt
|
||||
|
||||
# Tokenize the last context and the generated content
|
||||
ctxtokens = tokenizer.encode(vars.lastctx)
|
||||
txttokens = tokenizer.encode(txt)
|
||||
ctxtokens = tokenizer.encode(vars.lastctx, max_length=1+int(vars.max_length), truncation=True)
|
||||
txttokens = tokenizer.encode(txt, max_length=1+int(vars.max_length), truncation=True)
|
||||
dif = (len(txttokens) - len(ctxtokens)) * -1
|
||||
|
||||
# Remove the context from the returned text
|
||||
|
Loading…
x
Reference in New Issue
Block a user