Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-04-26 15:58:47 +02:00)

Merge pull request #50 from VE-FORBRYDERNE/potluck

Chat mode GUI, and Lua and random story generator bug fixes

Commit bbd68020a5

aiserver.py (160 lines changed)
@@ -89,7 +89,6 @@ class vars:
 submission = "" # Same as above, but after applying input formatting
 lastctx = "" # The last context submitted to the generator
 model = "" # Model ID string chosen at startup
-model_orig = "" # Original model string before being changed by auto model type detection
 model_type = "" # Model Type (Automatically taken from the model config)
 noai = False # Runs the script without starting up the transformers pipeline
 aibusy = False # Stops submissions while the AI is working
@@ -124,6 +123,7 @@ class vars:
 lua_running = False # Whether or not Lua is running (i.e. wasn't stopped due to an error)
 lua_edited = set() # Set of chunk numbers that were edited from a Lua generation modifier
 lua_deleted = set() # Set of chunk numbers that were deleted from a Lua generation modifier
+generated_tkns = 0 # If using a backend that supports Lua generation modifiers, how many tokens have already been generated, otherwise 0
 spfilename = "" # Filename of soft prompt to load, or an empty string if not using a soft prompt
 userscripts = [] # List of userscripts to load
 last_userscripts = [] # List of previous userscript filenames from the previous time userscripts were send via usstatitems
@@ -158,6 +158,7 @@ class vars:
 saveow = False # Whether or not overwrite confirm has been displayed
 genseqs = [] # Temporary storage for generated sequences
 recentback = False # Whether Back button was recently used without Submitting or Retrying after
+recentrng = None # If a new random game was recently generated without Submitting after, this is the topic used (as a string), otherwise this is None
 useprompt = False # Whether to send the full prompt with every submit action
 breakmodel = False # For GPU users, whether to use both system RAM and VRAM to conserve VRAM while offering speedup compared to CPU-only
 bmsupported = False # Whether the breakmodel option is supported (GPT-Neo/GPT-J only, currently)
@@ -194,7 +195,7 @@ def getModelSelection():
 while(vars.model == ''):
 modelsel = input("Model #> ")
 if(modelsel.isnumeric() and int(modelsel) > 0 and int(modelsel) <= len(modellist)):
-vars.model = vars.model_orig = modellist[int(modelsel)-1][1]
+vars.model = modellist[int(modelsel)-1][1]
 else:
 print("{0}Please enter a valid selection.{1}".format(colors.RED, colors.END))
@@ -375,7 +376,7 @@ parser.add_argument("--override_rename", action='store_true', help="Renaming sto
 parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.")
 
 args = parser.parse_args()
-vars.model = vars.model_orig = args.model;
+vars.model = args.model;
 
 if args.remote:
 vars.remote = True;
@@ -801,7 +802,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
 scores: torch.FloatTensor,
 **kwargs,
 ) -> bool:
-if(vars.lua_koboldbridge.generated_cols >= vars.genamt):
+vars.generated_tkns += 1
+if(vars.lua_koboldbridge.generated_cols and vars.generated_tkns != vars.lua_koboldbridge.generated_cols):
+raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({vars.generated_tkns} != {vars.lua_koboldbridge.generated_cols})")
+if(vars.generated_tkns >= vars.genamt):
 self.regeneration_required = False
 self.halt = False
 return True
@@ -813,7 +817,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
 vars.lua_koboldbridge.regeneration_required = False
 
 for i in range(vars.numseqs):
-vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = input_ids[i, -1].item()
+vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(input_ids[i, -1].item())
 
 if(not vars.dynamicscan):
 return self.regeneration_required or self.halt
@@ -1150,7 +1154,7 @@ def lua_compute_context(submission, entries, folders):
 i += 1
 winfo, mem, anotetxt, _ = calcsubmitbudgetheader(submission, allowed_entries=allowed_entries, allowed_folders=allowed_folders, force_use_txt=True)
 txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
-return txt
+return tokenizer.decode(txt)
 
 #==================================================================#
 # Get property of a world info entry given its UID and property name
@@ -1452,6 +1456,8 @@ def lua_is_custommodel():
 #==================================================================#
 def execute_inmod():
 vars.lua_logname = ...
+vars.lua_edited = set()
+vars.lua_deleted = set()
 try:
 tpool.execute(vars.lua_koboldbridge.execute_inmod)
 except lupa.LuaError as e:
@@ -1465,8 +1471,6 @@ def execute_inmod():
 set_aibusy(0)
 
 def execute_genmod():
-vars.lua_edited = set()
-vars.lua_deleted = set()
 vars.lua_koboldbridge.execute_genmod()
 
 def execute_outmod():
@@ -1606,6 +1610,13 @@ def get_message(msg):
 if(msg['cmd'] == 'submit'):
 if(vars.mode == "play"):
 vars.lua_koboldbridge.feedback = None
+if(vars.chatmode):
+if(type(msg['chatname']) is not str):
+raise ValueError("Chatname must be a string")
+vars.chatname = msg['chatname']
+settingschanged()
+emit('from_server', {'cmd': 'setchatname', 'data': vars.chatname}, broadcast=True)
+vars.recentrng = None
 actionsubmit(msg['data'], actionmode=msg['actionmode'])
 elif(vars.mode == "edit"):
 editsubmit(msg['data'])
@@ -1613,6 +1624,12 @@ def get_message(msg):
 memsubmit(msg['data'])
 # Retry Action
 elif(msg['cmd'] == 'retry'):
+if(vars.chatmode):
+if(type(msg['chatname']) is not str):
+raise ValueError("Chatname must be a string")
+vars.chatname = msg['chatname']
+settingschanged()
+emit('from_server', {'cmd': 'setchatname', 'data': vars.chatname}, broadcast=True)
 actionretry(msg['data'])
 # Back/Undo Action
 elif(msg['cmd'] == 'back'):
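
A small sketch (not part of the commit) of the chat-name round trip these two hunks implement: when chat mode is on, the browser includes a 'chatname' field on submit and retry, and the server validates it, stores it, and broadcasts it back so every connected client stays in sync. The example name "Alice" and the helper apply_chatname are made up for illustration; field names mirror the diff.

# Hypothetical payload shapes (the 'chatname' key is omitted when chat mode is off):
submit_msg = {'cmd': 'submit', 'actionmode': 0, 'chatname': 'Alice', 'data': 'Hello!'}
retry_msg = {'cmd': 'retry', 'chatname': 'Alice', 'data': ''}

def apply_chatname(msg, state, broadcast):
    # Same validate/persist/broadcast pattern as the new server-side code:
    # reject non-string names, store the name, then echo it as 'setchatname'.
    if type(msg['chatname']) is not str:
        raise ValueError("Chatname must be a string")
    state['chatname'] = msg['chatname']
    broadcast({'cmd': 'setchatname', 'data': state['chatname']})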
@@ -2056,7 +2073,7 @@ def settingschanged():
 #==================================================================#
 # Take input text from SocketIO and decide what to do with it
 #==================================================================#
-def actionsubmit(data, actionmode=0, force_submit=False):
+def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False):
 # Ignore new submissions if the AI is currently busy
 if(vars.aibusy):
 return
@@ -2064,6 +2081,9 @@ def actionsubmit(data, actionmode=0, force_submit=False):
 while(True):
 set_aibusy(1)
 
+if(disable_recentrng):
+vars.recentrng = None
+
 vars.recentback = False
 vars.recentedit = False
 vars.actionmode = actionmode
@@ -2093,7 +2113,7 @@ def actionsubmit(data, actionmode=0, force_submit=False):
 assert False
 # Start the game
 vars.gamestarted = True
-if(not vars.noai and vars.lua_koboldbridge.generating and not vars.nopromptgen):
+if(not vars.noai and vars.lua_koboldbridge.generating and (not vars.nopromptgen or force_prompt_gen)):
 # Save this first action as the prompt
 vars.prompt = data
 # Clear the startup text from game screen
@@ -2102,6 +2122,7 @@ def actionsubmit(data, actionmode=0, force_submit=False):
 if(vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0):
 data = ""
 force_submit = True
+disable_recentrng = True
 continue
 emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
 break
@@ -2122,6 +2143,7 @@ def actionsubmit(data, actionmode=0, force_submit=False):
 refresh_story()
 data = ""
 force_submit = True
+disable_recentrng = True
 continue
 else:
 if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
@@ -2129,6 +2151,7 @@ def actionsubmit(data, actionmode=0, force_submit=False):
 refresh_story()
 data = ""
 force_submit = True
+disable_recentrng = True
 continue
 genselect(genout)
 refresh_story()
@@ -2157,6 +2180,7 @@ def actionsubmit(data, actionmode=0, force_submit=False):
 if(vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0):
 data = ""
 force_submit = True
+disable_recentrng = True
 continue
 emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
 break
@@ -2174,12 +2198,14 @@ def actionsubmit(data, actionmode=0, force_submit=False):
 if(vars.lua_koboldbridge.restart_sequence is not None):
 data = ""
 force_submit = True
+disable_recentrng = True
 continue
 else:
 if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
 genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
 data = ""
 force_submit = True
+disable_recentrng = True
 continue
 genselect(genout)
 set_aibusy(0)
@@ -2195,6 +2221,9 @@ def actionretry(data):
 return
 if(vars.aibusy):
 return
+if(vars.recentrng is not None):
+randomGameRequest(vars.recentrng)
+return
 # Remove last action if possible and resubmit
 if(vars.gamestarted if vars.useprompt else len(vars.actions) > 0):
 if(not vars.recentback and len(vars.actions) != 0 and len(vars.genseqs) == 0): # Don't pop if we're in the "Select sequence to keep" menu or if there are no non-prompt actions
@@ -2246,38 +2275,53 @@ def calcsubmitbudgetheader(txt, **kwargs):
 
 return winfo, mem, anotetxt, found_entries
 
-def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
+def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None, budget_deduction=0):
 forceanote = False # In case we don't have enough actions to hit A.N. depth
 anoteadded = False # In case our budget runs out before we hit A.N. depth
 anotetkns = [] # Placeholder for Author's Note tokens
 lnanote = 0 # Placeholder for Author's Note length
 
-# Calculate token budget
-prompttkns = tokenizer.encode(vars.comregex_ai.sub('', vars.prompt), max_length=1+int(vars.max_length), truncation=True)
-lnprompt = len(prompttkns)
-
-memtokens = tokenizer.encode(mem, max_length=1+int(vars.max_length), truncation=True)
-lnmem = len(memtokens)
-
-witokens = tokenizer.encode(winfo, max_length=1+int(vars.max_length), truncation=True)
-lnwi = len(witokens)
-
-if(anotetxt != ""):
-anotetkns = tokenizer.encode(anotetxt, max_length=1+int(vars.max_length), truncation=True)
-lnanote = len(anotetkns)
-
 lnsp = vars.sp.shape[0] if vars.sp is not None else 0
 
+# Calculate token budget
+prompttkns = tokenizer.encode(vars.comregex_ai.sub('', vars.prompt), max_length=int(2e9), truncation=True)
+lnprompt = len(prompttkns)
+
+memtokens = tokenizer.encode(mem, max_length=int(2e9), truncation=True)
+lnmem = len(memtokens)
+if(lnmem > vars.max_length - lnsp - vars.genamt - budget_deduction):
+raise OverflowError("The memory in your story is too long. Please either write a shorter memory text or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
+
+witokens = tokenizer.encode(winfo, max_length=int(2e9), truncation=True)
+lnwi = len(witokens)
+if(lnmem + lnwi > vars.max_length - lnsp - vars.genamt - budget_deduction):
+raise OverflowError("The current active world info keys take up too many tokens. Please either write shorter world info, decrease World Info Depth or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
+
+if(anotetxt != ""):
+anotetkns = tokenizer.encode(anotetxt, max_length=int(2e9), truncation=True)
+lnanote = len(anotetkns)
+if(lnmem + lnwi + lnanote > vars.max_length - lnsp - vars.genamt - budget_deduction):
+raise OverflowError("The author's note in your story is too long. Please either write a shorter author's note or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
+
 if(vars.useprompt):
-budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt
+budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt - budget_deduction
 else:
-budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt
+budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt - budget_deduction
+
+lnsubmission = len(tokenizer.encode(vars.comregex_ai.sub('', submission), max_length=int(2e9), truncation=True)) if submission is not None else 0
+maybe_lnprompt = lnprompt if vars.useprompt and actionlen > 0 else 0
+
+if(lnmem + lnwi + lnanote + maybe_lnprompt + lnsubmission > vars.max_length - lnsp - vars.genamt - budget_deduction):
+raise OverflowError("Your submission is too long. Please either write a shorter submission or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt. If you are using the Always Add Prompt setting, turning it off may help.")
+
+assert budget >= 0
+
 if(actionlen == 0):
 # First/Prompt action
-subtxt = vars.memory + winfo + anotetxt + vars.comregex_ai.sub('', vars.prompt)
-lnsub = lnsp + lnmem + lnwi + lnprompt + lnanote
-return subtxt, lnsub+1, lnsub+vars.genamt
+tokens = memtokens + witokens + anotetkns + prompttkns
+assert len(tokens) <= vars.max_length - lnsp - vars.genamt - budget_deduction
+ln = len(tokens) + lnsp
+return tokens, ln+1, ln+vars.genamt
 else:
 tokens = []
 
@@ -2290,9 +2334,10 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
 for key in reversed(actions):
 chunk = vars.comregex_ai.sub('', actions[key])
 
+assert budget >= 0
 if(budget <= 0):
 break
-acttkns = tokenizer.encode(chunk, max_length=int(vars.max_length), truncation=True)
+acttkns = tokenizer.encode(chunk, max_length=int(2e9), truncation=True)
 tknlen = len(acttkns)
 if(tknlen < budget):
 tokens = acttkns + tokens
@@ -2317,7 +2362,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
 prompttkns = prompttkns[-budget:]
 else:
 prompttkns = []
 
 # Did we get to add the A.N.? If not, do it here
 if(anotetxt != ""):
 if((not anoteadded) or forceanote):
@@ -2327,10 +2372,11 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
 else:
 # Prepend Memory, WI, and Prompt before action tokens
 tokens = memtokens + witokens + prompttkns + tokens
 
 # Send completed bundle to generator
+assert len(tokens) <= vars.max_length - lnsp - vars.genamt - budget_deduction
 ln = len(tokens) + lnsp
-return tokenizer.decode(tokens), ln+1, ln+vars.genamt
+return tokens, ln+1, ln+vars.genamt
 
 #==================================================================#
 # Take submitted text and build the text to be given to generator
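
A condensed sketch (not from the commit) of the budget arithmetic the rewritten calcsubmitbudget enforces. All quantities are token counts, names follow the diff, and the error messages here are simplified stand-ins for the fuller ones above.

def story_budget(max_length, genamt, lnsp, lnmem, lnwi, lnanote, lnprompt, useprompt, budget_deduction=0):
    # Tokens permanently reserved: soft prompt + generation length + caller deduction.
    reserved = lnsp + genamt + budget_deduction
    # Incremental checks, as in the hunk above, so the user gets a specific error.
    if lnmem > max_length - reserved:
        raise OverflowError("memory too long")
    if lnmem + lnwi > max_length - reserved:
        raise OverflowError("world info too long")
    if lnmem + lnwi + lnanote > max_length - reserved:
        raise OverflowError("author's note too long")
    budget = max_length - reserved - lnmem - lnanote - lnwi
    if useprompt:
        budget -= lnprompt  # the prompt is always sent, so it is budgeted up front
    return budget

# Example with made-up numbers:
# story_budget(2048, 80, 100, 150, 200, 50, 0, useprompt=False) == 1468
# i.e. 1468 tokens remain for story chunks.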
@@ -2345,23 +2391,23 @@ def calcsubmit(txt):
 
 # For all transformers models
 if(vars.model != "InferKit"):
-subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions)
+subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt)
 if(actionlen == 0):
 if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
 generate(subtxt, min, max, found_entries=found_entries)
 elif(vars.model == "Colab"):
-sendtocolab(subtxt, min, max)
+sendtocolab(tokenizer.decode(subtxt), min, max)
 elif(vars.model == "OAI"):
-oairequest(subtxt, min, max)
+oairequest(tokenizer.decode(subtxt), min, max)
 elif(vars.model == "TPUMeshTransformerGPTJ"):
 tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
 else:
 if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
 generate(subtxt, min, max, found_entries=found_entries)
 elif(vars.model == "Colab"):
-sendtocolab(subtxt, min, max)
+sendtocolab(tokenizer.decode(subtxt), min, max)
 elif(vars.model == "OAI"):
-oairequest(subtxt, min, max)
+oairequest(tokenizer.decode(subtxt), min, max)
 elif(vars.model == "TPUMeshTransformerGPTJ"):
 tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
 
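
The net effect of this hunk is that calcsubmitbudget now hands back token IDs rather than decoded text, and only the remote backends decode before sending. A rough sketch of the resulting dispatch (simplified; the stated rationale is an assumption, not something the commit spells out):

def dispatch_payload(model_name, tokens, decode):
    # decode would be e.g. tokenizer.decode in the real code.
    if model_name in ("Colab", "OAI"):
        return decode(tokens)   # remote APIs still take plain text
    return tokens               # local/TPU paths consume token IDs directly,
                                # presumably avoiding a decode/re-encode round trip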
@@ -2426,13 +2472,14 @@ def calcsubmit(txt):
 #==================================================================#
 
 def _generate(txt, minimum, maximum, found_entries):
-gen_in = tokenizer.encode(txt, return_tensors="pt", max_length=int(vars.max_length), truncation=True).long()
+gen_in = torch.tensor(txt, dtype=torch.long)[None]
 if(vars.sp is not None):
 soft_tokens = torch.arange(
 model.config.vocab_size,
 model.config.vocab_size + vars.sp.shape[0],
 )
 gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
+assert gen_in.shape[-1] + vars.genamt <= vars.max_length
 
 if(vars.hascuda and vars.usegpu):
 gen_in = gen_in.to(vars.gpu_device)
@@ -2464,11 +2511,14 @@ def _generate(txt, minimum, maximum, found_entries):
 num_return_sequences=numseqs
 )
 already_generated += len(genout[0]) - len(gen_in[0])
+assert already_generated <= vars.genamt
 if(model.kai_scanner.halt or not model.kai_scanner.regeneration_required):
 break
 assert genout.ndim >= 2
 assert genout.shape[0] == vars.numseqs
-if(already_generated != vars.lua_koboldbridge.generated_cols):
+if(vars.lua_koboldbridge.generated_cols and vars.generated_tkns != vars.lua_koboldbridge.generated_cols):
+raise RuntimeError("Inconsistency detected between KoboldAI Python and Lua backends")
+if(already_generated != vars.generated_tkns):
 raise RuntimeError("WI scanning error")
 for r in range(vars.numseqs):
 for c in range(already_generated):
@@ -2479,8 +2529,8 @@ def _generate(txt, minimum, maximum, found_entries):
 txt = tokenizer.decode(genout[i, -already_generated:])
 winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
 found_entries[i].update(_found_entries)
-txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions)
-encoded.append(tokenizer.encode(txt, return_tensors="pt", max_length=int(vars.max_length), truncation=True)[0].long().to(genout.device))
+txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
+encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
 max_length = len(max(encoded, key=len))
 encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
 genout = torch.cat(
@@ -2497,6 +2547,7 @@ def _generate(txt, minimum, maximum, found_entries):
 device=genout.device,
 )
 genout = torch.cat((soft_tokens.tile(vars.numseqs, 1), genout), dim=-1)
+assert genout.shape[-1] + vars.genamt - already_generated <= vars.max_length
 diff = genout.shape[-1] - gen_in.shape[-1]
 minimum += diff
 maximum += diff
@@ -2508,14 +2559,16 @@ def _generate(txt, minimum, maximum, found_entries):
 
 
 def generate(txt, minimum, maximum, found_entries=None):
+vars.generated_tkns = 0
+
 if(found_entries is None):
 found_entries = set()
 found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
 
-print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, txt, colors.END))
+print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))
 
 # Store context in memory to use it for comparison with generated content
-vars.lastctx = txt
+vars.lastctx = tokenizer.decode(txt)
 
 # Clear CUDA cache if using GPU
 if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
@@ -2541,7 +2594,7 @@ def generate(txt, minimum, maximum, found_entries=None):
 return
 
 for i in range(vars.numseqs):
-vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = genout[i, -1].item()
+vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(genout[i, -1].item())
 vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i, -already_generated:])
 
 execute_outmod()
@@ -2624,7 +2677,7 @@ def selectsequence(n):
 vars.genseqs = []
 
 if(vars.lua_koboldbridge.restart_sequence is not None):
-actionsubmit("", actionmode=vars.actionmode, force_submit=True)
+actionsubmit("", actionmode=vars.actionmode, force_submit=True, disable_recentrng=True)
 
 #==================================================================#
 # Send transformers-style request to ngrok/colab host
@@ -2712,7 +2765,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
 found_entries = set()
 found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
 
-print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, txt, colors.END))
+print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))
 
 # Submit input text to generator
 try:
@@ -2742,7 +2795,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
 
 genout = tpool.execute(
 tpu_mtj_backend.infer,
-txt,
+np.uint32(txt),
 gen_len = maximum-minimum+1,
 temp=vars.temp,
 top_p=vars.top_p,
@@ -2809,8 +2862,8 @@ def getnewcontent(txt):
 return txt
 
 # Tokenize the last context and the generated content
-ctxtokens = tokenizer.encode(vars.lastctx, max_length=1+int(vars.max_length), truncation=True)
-txttokens = tokenizer.encode(txt, max_length=1+int(vars.max_length), truncation=True)
+ctxtokens = tokenizer.encode(vars.lastctx, max_length=int(2e9), truncation=True)
+txttokens = tokenizer.encode(txt, max_length=int(2e9), truncation=True)
 dif = (len(txttokens) - len(ctxtokens)) * -1
 
 # Remove the context from the returned text
@@ -4126,10 +4179,11 @@ def newGameRequest():
 setStartState()
 
 def randomGameRequest(topic):
+vars.recentrng = topic
 newGameRequest()
 vars.memory = "You generate the following " + topic + " story concept :"
 vars.lua_koboldbridge.feedback = None
-actionsubmit("", force_submit=True)
+actionsubmit("", force_submit=True, force_prompt_gen=True)
 vars.memory = ""
 
 #==================================================================#
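
Taken together, the actionretry and randomGameRequest hunks give Retry a special case for random stories; a minimal sketch of that flow (names other than recentrng are placeholders for this illustration):

def retry_sketch(recentrng, random_game_request, normal_retry):
    # If the last generation came from the random story generator, Retry
    # re-rolls the same topic instead of resubmitting story text.
    if recentrng is not None:
        random_game_request(recentrng)  # recentrng holds the topic string set above
        return
    normal_retry()  # otherwise: remove the last action and resubmit, as before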

bridge.lua (52 lines changed)
@@ -888,6 +888,8 @@ return function(_python, _bridged)
 ---@field rmspch boolean
 ---@field adsnsp boolean
 ---@field singleline boolean
+---@field chatmode boolean
+---@field chatname string
 local KoboldSettings = setmetatable({
 _name = "KoboldSettings",
 }, metawrapper)
@@ -1038,7 +1040,7 @@ return function(_python, _bridged)
 ---@param t KoboldLib
 ---@return string
 function KoboldLib_getters.model(t)
-return bridged.vars.model_orig
+return bridged.vars.model
 end
 
 ---@param t KoboldLib
@@ -1526,7 +1528,7 @@ return function(_python, _bridged)
 end
 
 local old_loadfile = loadfile
-local old_package_loaded = package.loaded
+local package_loaded = {} ---@type table<table, table>
 local old_package_searchers = package.searchers
 ---@param modname string
 ---@param env table<string, any>
@@ -1546,8 +1548,10 @@ return function(_python, _bridged)
 return
 end
 local allowsearch = type(modname) == "string" and string.match(modname, "[^%w._-]") == nil and string.match(modname, "%.%.") == nil
-if allowsearch and old_package_loaded[modname] then
-return old_package_loaded[modname]
+if allowsearch and package_loaded[env] == nil then
+package_loaded[env] = {}
+elseif allowsearch and package_loaded[env][modname] then
+return package_loaded[env][modname]
 end
 local loader, path
 local errors = {}
@@ -1568,8 +1572,8 @@ return function(_python, _bridged)
 return
 end
 local retval = old_loadfile(path, "t", env)()
-old_package_loaded[modname] = retval == nil or retval
-return old_package_loaded[modname], path
+package_loaded[env][modname] = retval == nil or retval
+return package_loaded[env][modname], path
 end
 local function _safe_require(_g)
 ---@param modname string
@@ -1579,6 +1583,36 @@ return function(_python, _bridged)
 end
 end
 
+local old_input = io.input
+---@param file? string|file*
+local function safe_input(file)
+if type(file) == "string" then
+error("Calling `io.input` with a string as argument is disabled for security reasons")
+return
+end
+return old_input(file)
+end
+
+local old_output = io.output
+---@param file? string|file*
+local function safe_output(file)
+if type(file) == "string" then
+error("Calling `io.output` with a string as argument is disabled for security reasons")
+return
+end
+return old_output(file)
+end
+
+local old_lines = io.lines
+---@param filename? string
+local function safe_lines(filename, ...)
+if type(filename) == "string" then
+error("Calling `io.lines` with a string as first argument is disabled for security reasons")
+return
+end
+return old_lines(filename, ...)
+end
+
 local function redirected_print(...)
 local args = table.pack(...)
 for i = 1, args.n do
@@ -1709,12 +1743,12 @@ return function(_python, _bridged)
 stdin = io.stdin,
 stdout = io.stdout,
 stderr = io.stderr,
-input = io.input,
-output = io.output,
+input = safe_input,
+output = safe_output,
 read = io.read,
 write = io.write,
 close = _new_close(io.close),
-lines = io.lines,
+lines = safe_lines,
 flush = io.flush,
 type = io.type,
 },
@@ -30,6 +30,7 @@ var button_actwi;
 var game_text;
 var input_text;
 var message_text;
+var chat_name;
 var settings_menu;
 var format_menu;
 var wi_menu;
@@ -722,6 +723,7 @@ function exitEditMode() {
 function enterMemoryMode() {
 memorymode = true;
 setmodevisibility(false);
+setchatnamevisibility(false);
 showMessage("Edit the memory to be sent with each request to the AI.");
 button_actmem.html("Cancel");
 hide([button_actback, button_actretry, button_actwi]);
@@ -732,6 +734,7 @@ function enterMemoryMode() {
 function exitMemoryMode() {
 memorymode = false;
 setmodevisibility(adventure);
+setchatnamevisibility(chatmode);
 hideMessage();
 button_actmem.html("Memory");
 show([button_actback, button_actretry, button_actwi]);
@@ -744,6 +747,7 @@ function enterWiMode() {
 showMessage("World Info will be added to memory only when the key appears in submitted text or the last action.");
 button_actwi.html("Accept");
 hide([button_actback, button_actmem, button_actretry, game_text]);
+setchatnamevisibility(false);
 show([wi_menu]);
 disableSendBtn();
 $("#gamescreen").addClass("wigamescreen");
|
|||||||
hideMessage();
|
hideMessage();
|
||||||
button_actwi.html("W Info");
|
button_actwi.html("W Info");
|
||||||
hide([wi_menu]);
|
hide([wi_menu]);
|
||||||
|
setchatnamevisibility(chatmode);
|
||||||
show([button_actback, button_actmem, button_actretry, game_text]);
|
show([button_actback, button_actmem, button_actretry, game_text]);
|
||||||
enableSendBtn();
|
enableSendBtn();
|
||||||
$("#gamescreen").removeClass("wigamescreen");
|
$("#gamescreen").removeClass("wigamescreen");
|
||||||
@ -797,7 +802,7 @@ function dosubmit() {
|
|||||||
input_text.val("");
|
input_text.val("");
|
||||||
hideMessage();
|
hideMessage();
|
||||||
hidegenseqs();
|
hidegenseqs();
|
||||||
socket.send({'cmd': 'submit', 'actionmode': adventure ? action_mode : 0, 'data': txt});
|
socket.send({'cmd': 'submit', 'actionmode': adventure ? action_mode : 0, 'chatname': chatmode ? chat_name.val() : undefined, 'data': txt});
|
||||||
if(memorymode) {
|
if(memorymode) {
|
||||||
memorytext = input_text.val();
|
memorytext = input_text.val();
|
||||||
}
|
}
|
||||||
@ -1155,6 +1160,14 @@ function setmodevisibility(state) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function setchatnamevisibility(state) {
|
||||||
|
if(state){ // Enabling
|
||||||
|
show([chat_name]);
|
||||||
|
} else{ // Disabling
|
||||||
|
hide([chat_name]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function setadventure(state) {
|
function setadventure(state) {
|
||||||
adventure = state;
|
adventure = state;
|
||||||
if(state) {
|
if(state) {
|
||||||
@ -1169,6 +1182,7 @@ function setadventure(state) {
|
|||||||
|
|
||||||
function setchatmode(state) {
|
function setchatmode(state) {
|
||||||
chatmode = state;
|
chatmode = state;
|
||||||
|
setchatnamevisibility(state);
|
||||||
}
|
}
|
||||||
|
|
||||||
function autofocus(event) {
|
function autofocus(event) {
|
||||||
@ -1706,6 +1720,7 @@ $(document).ready(function(){
|
|||||||
game_text = $('#gametext');
|
game_text = $('#gametext');
|
||||||
input_text = $('#input_text');
|
input_text = $('#input_text');
|
||||||
message_text = $('#messagefield');
|
message_text = $('#messagefield');
|
||||||
|
chat_name = $('#chatname');
|
||||||
settings_menu = $("#settingsmenu");
|
settings_menu = $("#settingsmenu");
|
||||||
format_menu = $('#formatmenu');
|
format_menu = $('#formatmenu');
|
||||||
anote_menu = $('#anoterowcontainer');
|
anote_menu = $('#anoterowcontainer');
|
||||||
@ -2130,6 +2145,8 @@ $(document).ready(function(){
|
|||||||
} else if(msg.cmd == "hidegenseqs") {
|
} else if(msg.cmd == "hidegenseqs") {
|
||||||
// Collapse genseqs menu
|
// Collapse genseqs menu
|
||||||
hidegenseqs();
|
hidegenseqs();
|
||||||
|
} else if(msg.cmd == "setchatname") {
|
||||||
|
chat_name.val(msg.data);
|
||||||
} else if(msg.cmd == "setlabelnumseq") {
|
} else if(msg.cmd == "setlabelnumseq") {
|
||||||
// Update setting label with value from server
|
// Update setting label with value from server
|
||||||
$("#setnumseqcur").html(msg.data);
|
$("#setnumseqcur").html(msg.data);
|
||||||
@ -2234,7 +2251,7 @@ $(document).ready(function(){
|
|||||||
|
|
||||||
button_actretry.on("click", function(ev) {
|
button_actretry.on("click", function(ev) {
|
||||||
hideMessage();
|
hideMessage();
|
||||||
socket.send({'cmd': 'retry', 'data': ''});
|
socket.send({'cmd': 'retry', 'chatname': chatmode ? chat_name.val() : undefined, 'data': ''});
|
||||||
hidegenseqs();
|
hidegenseqs();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@@ -32,6 +32,13 @@ chunk.editing, chunk.editing * {
 display: flex;
 }
 
+#chatname {
+background-color: #404040;
+color: #ffffff;
+width: 200px;
+margin-left: 10px;
+}
+
 #menuitems {
 display: flex;
 width: 100%;
@@ -10,12 +10,12 @@
 <script src="static/bootstrap.min.js"></script>
 <script src="static/bootstrap-toggle.min.js"></script>
 <script src="static/rangy-core.min.js"></script>
-<script src="static/application.js?ver=1.16.4m"></script>
+<script src="static/application.js?ver=1.16.4n"></script>
 
 <link rel="stylesheet" href="static/jquery-ui.sortable.min.css">
 <link rel="stylesheet" href="static/bootstrap.min.css">
 <link rel="stylesheet" href="static/bootstrap-toggle.min.css">
-<link rel="stylesheet" href="static/custom.css?ver=1.16.4g">
+<link rel="stylesheet" href="static/custom.css?ver=1.16.4h">
 <link rel="stylesheet" href="static/open-iconic-bootstrap.min.css">
 </head>
 <body>
@@ -124,6 +124,7 @@
 <button type="button" class="btn btn-primary" id="btn_actundo">Back</button>
 <button type="button" class="btn btn-primary" id="btn_actretry">Retry</button>
 </div>
+<input type="text" id="chatname" class="form-control hidden" placeholder="Chat name">
 <div id="messagefield"></div>
 <div class="box flex-push-right">
 <input type="checkbox" data-toggle="toggle" data-onstyle="success" id="allowediting" disabled>
@@ -268,7 +268,7 @@ class PenalizingCausalTransformer(CausalTransformer):
 
 
 def infer(
-context: str,
+context: np.array,
 top_p=0.9,
 temp=0.5,
 top_k=0,
@@ -281,7 +281,7 @@ def infer(
 ) -> List[str]:
 maps.thread_resources.env = thread_resources_env
 total_batch = 1
-tokens = np.uint32(tokenizer.encode(context, max_length=params["seq"] - (soft_tokens.shape[0] if soft_tokens is not None else 0), truncation=True))
+tokens = context
 if(soft_tokens is not None):
 tokens = np.uint32(np.concatenate((soft_tokens, tokens)))
 provided_ctx = tokens.shape[0]

utils.py (12 lines changed)
@@ -73,13 +73,15 @@ def addsentencespacing(txt, vars):
 # Get last character of last action
 if(len(vars.actions) > 0):
 if(len(vars.actions[vars.actions.get_last_key()]) > 0):
-lastchar = vars.actions[vars.actions.get_last_key()][-1]
+action = vars.actions[vars.actions.get_last_key()]
+lastchar = action[-1] if len(action) else ""
 else:
 # Last action is blank, this should never happen, but
 # since it did let's bail out.
 return txt
 else:
-lastchar = vars.prompt[-1]
+action = vars.prompt
+lastchar = action[-1] if len(action) else ""
 if(lastchar == "." or lastchar == "!" or lastchar == "?" or lastchar == "," or lastchar == ";" or lastchar == ":"):
 txt = " " + txt
 return txt
@@ -88,13 +90,15 @@ def singlelineprocessing(txt, vars):
 txt = vars.regex_sl.sub('', txt)
 if(len(vars.actions) > 0):
 if(len(vars.actions[vars.actions.get_last_key()]) > 0):
-lastchar = vars.actions[vars.actions.get_last_key()][-1]
+action = vars.actions[vars.actions.get_last_key()]
+lastchar = action[-1] if len(action) else ""
 else:
 # Last action is blank, this should never happen, but
 # since it did let's bail out.
 return txt
 else:
-lastchar = vars.prompt[-1]
+action = vars.prompt
+lastchar = action[-1] if len(action) else ""
 if(lastchar != "\n"):
 txt = txt + "\n"
 return txt
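
Both utils.py hunks apply the same guard; a one-function sketch of the pattern (the helper name is made up):

def last_char(action):
    # Peek at the last character without raising IndexError on an empty prompt/action.
    return action[-1] if len(action) else ""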