Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Synced 2025-06-05 21:59:24 +02:00
Merge pull request #43 from VE-FORBRYDERNE/dynamic-scan-patch
Dynamic scan patch
aiserver.py (32 changed lines)
@@ -21,6 +21,7 @@ import collections
 import zipfile
 import packaging
 import contextlib
+import traceback
 from typing import Any, Union, Dict, Set, List
 
 import requests
@@ -744,6 +745,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             scores: torch.FloatTensor,
             **kwargs,
         ) -> bool:
+            if(vars.lua_koboldbridge.generated_cols >= vars.genamt):
+                self.regeneration_required = False
+                self.halt = False
+                return True
+
             assert input_ids.ndim == 2
             assert len(self.excluded_world_info) == input_ids.shape[0]
             self.regeneration_required = vars.lua_koboldbridge.regeneration_required
@@ -773,7 +779,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             excluded_world_info=self.kai_scanner_excluded_world_info,
             head_length=self.kai_scanner_head_length,
         )
-        stopping_criteria.append(self.kai_scanner)
+        stopping_criteria.insert(0, self.kai_scanner)
         return stopping_criteria
     transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
 
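For context, the two hunks above hook the dynamic world info scanner into the transformers stopping-criteria machinery: __call__ now returns True as soon as the generation budget (vars.genamt) is used up, and the scanner is inserted at the front of the criteria list rather than appended at the back. A minimal, hypothetical sketch of the same pattern — a boolean-returning StoppingCriteria that halts after a token budget; the BudgetStop name and numbers are illustrative, not from the patch:

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class BudgetStop(StoppingCriteria):
    """Illustrative only: stop once `budget` new tokens have been generated."""
    def __init__(self, start_length: int, budget: int):
        self.start_length = start_length
        self.budget = budget

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # input_ids grows by one column per generated token, so the difference
        # from the prompt length is the number of tokens generated so far.
        return input_ids.shape[-1] - self.start_length >= self.budget

# Placing a criterion at index 0 (as the patch does with kai_scanner) makes it
# the first one consulted on every generation step.
criteria = StoppingCriteriaList()
criteria.insert(0, BudgetStop(start_length=10, budget=80))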
@@ -1995,17 +2001,17 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
     lnanote = 0 # Placeholder for Author's Note length
 
     # Calculate token budget
-    prompttkns = tokenizer.encode(vars.comregex_ai.sub('', vars.prompt))
+    prompttkns = tokenizer.encode(vars.comregex_ai.sub('', vars.prompt), max_length=1+int(vars.max_length), truncation=True)
     lnprompt = len(prompttkns)
 
-    memtokens = tokenizer.encode(mem)
+    memtokens = tokenizer.encode(mem, max_length=1+int(vars.max_length), truncation=True)
     lnmem = len(memtokens)
 
-    witokens = tokenizer.encode(winfo)
+    witokens = tokenizer.encode(winfo, max_length=1+int(vars.max_length), truncation=True)
     lnwi = len(witokens)
 
     if(anotetxt != ""):
-        anotetkns = tokenizer.encode(anotetxt)
+        anotetkns = tokenizer.encode(anotetxt, max_length=1+int(vars.max_length), truncation=True)
         lnanote = len(anotetkns)
 
     lnsp = vars.sp.shape[0] if vars.sp is not None else 0
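This hunk and the following ones all add the same guard: every tokenizer.encode call is capped at roughly vars.max_length tokens, so oversized prompt, memory, world info, or author's note text can no longer run past the model's context window. A small, hypothetical demonstration of how max_length plus truncation behaves on a Hugging Face tokenizer (the model name, text, and limit are placeholders):

from transformers import AutoTokenizer

# Placeholder model; any GPT-2-style tokenizer shows the same behaviour.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

text = "some very long user-supplied memory text " * 200

# Without a cap, the full text is encoded, however long it is.
full = tokenizer.encode(text)

# With max_length + truncation=True, the result is clipped to at most 1024 ids.
clipped = tokenizer.encode(text, max_length=1024, truncation=True)

print(len(full), len(clipped))  # len(clipped) is at most 1024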
@@ -2034,7 +2040,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
 
         if(budget <= 0):
             break
-        acttkns = tokenizer.encode(chunk)
+        acttkns = tokenizer.encode(chunk, max_length=int(vars.max_length), truncation=True)
         tknlen = len(acttkns)
         if(tknlen < budget):
             tokens = acttkns + tokens
@@ -2168,7 +2174,7 @@ def calcsubmit(txt):
 #==================================================================#
 
 def _generate(txt, minimum, maximum, found_entries):
-    gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long()
+    gen_in = tokenizer.encode(txt, return_tensors="pt", max_length=int(vars.max_length), truncation=True).long()
     if(vars.sp is not None):
         soft_tokens = torch.arange(
             model.config.vocab_size,
@@ -2214,14 +2220,14 @@ def _generate(txt, minimum, maximum, found_entries):
         for r in range(vars.numseqs):
             for c in range(already_generated):
                 assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
-                genout[r][genout.shape[-1] - already_generated - c] = vars.lua_koboldbridge.generated[r+1][c+1]
+                genout[r][genout.shape[-1] - already_generated + c] = vars.lua_koboldbridge.generated[r+1][c+1]
         encoded = []
         for i in range(vars.numseqs):
             txt = tokenizer.decode(genout[i, -already_generated:])
             winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
             found_entries[i].update(_found_entries)
             txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
-            encoded.append(tokenizer.encode(txt, return_tensors="pt", truncation=True)[0].long().to(genout.device))
+            encoded.append(tokenizer.encode(txt, return_tensors="pt", max_length=int(vars.max_length), truncation=True)[0].long().to(genout.device))
         max_length = len(max(encoded, key=len))
         encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
         genout = torch.cat(
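The sign change in this hunk fixes where previously generated tokens are written back into genout: with a row of length L and already_generated tokens, positions L - already_generated + c for c = 0..already_generated-1 fill the tail of the row left to right, whereas the old "- c" index walked backwards through the row in reverse order. A small worked example (the array size and token values are made up for illustration):

import torch

L = 8                        # stand-in for genout.shape[-1]
already_generated = 3
genout = torch.zeros(1, L, dtype=torch.long)
generated = [101, 102, 103]  # stand-in for vars.lua_koboldbridge.generated[r+1]

for c in range(already_generated):
    genout[0][L - already_generated + c] = generated[c]

print(genout)  # tensor([[  0,   0,   0,   0,   0, 101, 102, 103]])
# With the old "- c" index the three writes would land at positions 5, 4, 3,
# i.e. in reverse order and partly outside the final already_generated slots.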
@@ -2275,7 +2281,7 @@ def generate(txt, minimum, maximum, found_entries=None):
             print("{0}{1}{2}".format(colors.YELLOW, "Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.", colors.END), file=sys.stderr)
         else:
             emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
-            print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
+            print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
         set_aibusy(0)
         return
 
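In this hunk (and the identical one below for tpumtjgenerate), str(e) only carried the exception message, so the console error report lost the stack trace; traceback.format_exc(), enabled by the new "import traceback" at the top of the file, returns the full formatted traceback of the exception currently being handled. A quick standalone comparison (the failing function is a made-up example):

import traceback

def broken():
    return 1 / 0  # placeholder failure

try:
    broken()
except Exception as e:
    print(str(e))                  # -> "division by zero" (message only)
    print(traceback.format_exc())  # -> full multi-line traceback with file,
                                   #    line number, and call chain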
@@ -2486,7 +2492,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
             print("{0}{1}{2}".format(colors.YELLOW, "Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.", colors.END), file=sys.stderr)
         else:
             emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
-            print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
+            print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
         set_aibusy(0)
         return
 
@@ -2527,8 +2533,8 @@ def getnewcontent(txt):
         return txt
 
     # Tokenize the last context and the generated content
-    ctxtokens = tokenizer.encode(vars.lastctx)
-    txttokens = tokenizer.encode(txt)
+    ctxtokens = tokenizer.encode(vars.lastctx, max_length=1+int(vars.max_length), truncation=True)
+    txttokens = tokenizer.encode(txt, max_length=1+int(vars.max_length), truncation=True)
     dif = (len(txttokens) - len(ctxtokens)) * -1
 
     # Remove the context from the returned text