#!/usr/bin/python3
#==================================================================#
# KoboldAI
# Version: 1.18.0
# By: KoboldAIDev and the KoboldAI Community
#==================================================================#

# External packages
import eventlet
eventlet.monkey_patch(all=True, thread=False)
import os
os.system("")  # Enables ANSI escape code processing in the Windows terminal
__file__ = os.path.dirname(os.path.realpath(__file__))  # Resolve the directory this script lives in...
os.chdir(__file__)  # ...and make it the working directory
os.environ['EVENTLET_THREADPOOL_SIZE'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
from eventlet import tpool

import logging
logging.getLogger("urllib3").setLevel(logging.ERROR)

from os import path, getcwd
import time
import re
import json
import collections
import zipfile
import packaging
import packaging.version
import contextlib
import traceback
import threading
import markdown
import bleach
import itertools
import bisect
import functools
from collections.abc import Iterable
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List

import requests
import html
import argparse
import sys
import gc

import lupa
import importlib

# KoboldAI
import fileops
import gensettings
from utils import debounce
import utils
import structures
import torch
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, modeling_utils
from transformers import __version__ as transformers_version
import transformers
try:
    from transformers.models.opt.modeling_opt import OPTDecoder
except:
    pass
import transformers.generation_utils

global tpu_mtj_backend  # Imported and assigned later, only when a TPU backend is actually used

if lupa.LUA_VERSION[:2] != (5, 4):
    print(f"Please install lupa==1.10. You have lupa {lupa.__version__}.", file=sys.stderr)

patch_causallm_patched = False

# Make sure tqdm progress bars display properly in Colab
from tqdm.auto import tqdm
old_init = tqdm.__init__
def new_init(self, *args, **kwargs):
    old_init(self, *args, **kwargs)
    if(self.ncols == 0 and kwargs.get("ncols") != 0):
        self.ncols = 99
tqdm.__init__ = new_init
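# (Colab's output widget appears to report a terminal width of 0, which leaves
# self.ncols at 0 and makes tqdm draw a bare counter instead of a bar; forcing
# ncols to 99 restores the bar. A caller that explicitly passes ncols=0 is
# still respected.)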

#==================================================================#
# Variables & Storage
#==================================================================#

# Terminal tags for colored text
class colors:
    PURPLE    = '\033[95m'
    BLUE      = '\033[94m'
    CYAN      = '\033[96m'
    GREEN     = '\033[92m'
    YELLOW    = '\033[93m'
    RED       = '\033[91m'
    END       = '\033[0m'
    UNDERLINE = '\033[4m'

# AI models Menu
# This is a dict of lists where the key is the menu name, and the list is the menu items.
# Each item has 4 elements: 1: text to display, 2: model name (vars.model) or menu name (key of another menu),
# 3: the memory requirement for the model, 4: whether the item is a menu or not (True/False).
# A short usage sketch follows the menu definition below.
model_menu = {
    'mainmenu': [
        ["Load a model from its directory", "NeoCustom", "", False],
        ["Load an old GPT-2 model (eg CloverEdition)", "GPT2Custom", "", False],
        ["Adventure Models", "adventurelist", "", True],
        ["Novel Models", "novellist", "", True],
        ["NSFW Models", "nsfwlist", "", True],
        ["Chatbot Models", "chatlist", "", True],
        ["Untuned GPT-Neo/J", "gptneolist", "", True],
        ["Untuned Fairseq Dense", "fsdlist", "", True],
        ["Untuned OPT", "optlist", "", True],
        ["Untuned XGLM", "xglmlist", "", True],
        ["Untuned GPT2", "gpt2list", "", True],
        ["Online Services", "apilist", "", True],
        ["Read Only (No AI)", "ReadOnly", "", False]
    ],
    'adventurelist': [
        ["Nerys FSD 13B (Hybrid)", "KoboldAI/fairseq-dense-13B-Nerys", "32GB", False],
        ["Skein 6B", "KoboldAI/GPT-J-6B-Skein", "16GB", False],
        ["Adventure 6B", "KoboldAI/GPT-J-6B-Adventure", "16GB", False],
        ["Nerys FSD 2.7B (Hybrid)", "KoboldAI/fairseq-dense-2.7B-Nerys", "8GB", False],
        ["Adventure 2.7B", "KoboldAI/GPT-Neo-2.7B-AID", "8GB", False],
        ["Adventure 1.3B", "KoboldAI/GPT-Neo-1.3B-Adventure", "6GB", False],
        ["Adventure 125M (Mia)", "Merry/AID-Neo-125M", "2GB", False],
        ["Return to Main Menu", "mainmenu", "", True],
    ],
    'novellist': [
        ["Nerys FSD 13B (Hybrid)", "KoboldAI/fairseq-dense-13B-Nerys", "32GB", False],
        ["Janeway FSD 13B", "KoboldAI/fairseq-dense-13B-Janeway", "32GB", False],
        ["Janeway FSD 6.7B", "KoboldAI/fairseq-dense-6.7B-Janeway", "16GB", False],
        ["Janeway Neo 6B", "KoboldAI/GPT-J-6B-Janeway", "16GB", False],
        ["Janeway Neo 2.7B", "KoboldAI/GPT-Neo-2.7B-Janeway", "8GB", False],
        ["Janeway FSD 2.7B", "KoboldAI/fairseq-dense-2.7B-Janeway", "8GB", False],
        ["Nerys FSD 2.7B (Hybrid)", "KoboldAI/fairseq-dense-2.7B-Nerys", "8GB", False],
        ["Horni-LN 2.7B", "KoboldAI/GPT-Neo-2.7B-Horni-LN", "8GB", False],
        ["Picard 2.7B (Older Janeway)", "KoboldAI/GPT-Neo-2.7B-Picard", "8GB", False],
        ["Return to Main Menu", "mainmenu", "", True],
    ],
    'nsfwlist': [
        ["Shinen FSD 13B (NSFW)", "KoboldAI/fairseq-dense-13B-Shinen", "32GB", False],
        ["Shinen FSD 6.7B (NSFW)", "KoboldAI/fairseq-dense-6.7B-Shinen", "16GB", False],
        ["Lit 6B (NSFW)", "hakurei/lit-6B", "16GB", False],
        ["Shinen 6B (NSFW)", "KoboldAI/GPT-J-6B-Shinen", "16GB", False],
        ["Horni 2.7B (NSFW)", "KoboldAI/GPT-Neo-2.7B-Horni", "8GB", False],
        ["Shinen 2.7B (NSFW)", "KoboldAI/GPT-Neo-2.7B-Shinen", "8GB", False],
        ["Return to Main Menu", "mainmenu", "", True],
    ],
    'chatlist': [
        ["Convo 6B (Chatbot)", "hitomi-team/convo-6B", "16GB", False],
        ["C1 6B (Chatbot)", "hakurei/c1-6B", "16GB", False],
        ["C1 1.3B (Chatbot)", "iokru/c1-1.3B", "6GB", False],
        ["Return to Main Menu", "mainmenu", "", True],
    ],
    'gptneolist': [
        ["GPT-J 6B", "EleutherAI/gpt-j-6B", "16GB", False],
        ["GPT-Neo 2.7B", "EleutherAI/gpt-neo-2.7B", "8GB", False],
        ["GPT-Neo 1.3B", "EleutherAI/gpt-neo-1.3B", "6GB", False],
        ["GPT-Neo 125M", "EleutherAI/gpt-neo-125M", "2GB", False],
        ["Return to Main Menu", "mainmenu", "", True],
    ],
    'gpt2list': [
        ["GPT-2 XL", "gpt2-xl", "6GB", False],
        ["GPT-2 Large", "gpt2-large", "4GB", False],
        ["GPT-2 Med", "gpt2-medium", "2GB", False],
        ["GPT-2", "gpt2", "2GB", False],
        ["Return to Main Menu", "mainmenu", "", True],
    ],
    'optlist': [
        ["OPT 30B", "facebook/opt-30b", "64GB", False],
        ["OPT 13B", "facebook/opt-13b", "32GB", False],
        ["OPT 6.7B", "facebook/opt-6.7b", "16GB", False],
        ["OPT 2.7B", "facebook/opt-2.7b", "8GB", False],
        ["OPT 1.3B", "facebook/opt-1.3b", "4GB", False],
        ["OPT 350M", "facebook/opt-350m", "2GB", False],
        ["OPT 125M", "facebook/opt-125m", "1GB", False],
        ["Return to Main Menu", "mainmenu", "", True],
    ],
    'fsdlist': [
        ["Fairseq Dense 13B", "KoboldAI/fairseq-dense-13B", "32GB", False],
        ["Fairseq Dense 6.7B", "KoboldAI/fairseq-dense-6.7B", "16GB", False],
        ["Fairseq Dense 2.7B", "KoboldAI/fairseq-dense-2.7B", "8GB", False],
        ["Fairseq Dense 1.3B", "KoboldAI/fairseq-dense-1.3B", "4GB", False],
        ["Fairseq Dense 355M", "KoboldAI/fairseq-dense-355M", "2GB", False],
        ["Fairseq Dense 125M", "KoboldAI/fairseq-dense-125M", "1GB", False],
        ["Return to Main Menu", "mainmenu", "", True],
    ],
    'xglmlist': [
        ["XGLM 4.5B (Larger Dataset)", "facebook/xglm-4.5B", "12GB", False],
        ["XGLM 7.5B", "facebook/xglm-7.5B", "18GB", False],
        ["XGLM 2.9B", "facebook/xglm-2.9B", "10GB", False],
        ["XGLM 1.7B", "facebook/xglm-1.7B", "6GB", False],
        ["XGLM 564M", "facebook/xglm-564M", "4GB", False],
        ["Return to Main Menu", "mainmenu", "", True],
    ],
    'apilist': [
        ["GooseAI API (requires API key)", "GooseAI", "", False],
        ["OpenAI API (requires API key)", "OAI", "", False],
        ["InferKit API (requires API key)", "InferKit", "", False],
        ["KoboldAI Server API (Old Google Colab)", "Colab", "", False],
        ["Return to Main Menu", "mainmenu", "", True],
    ]
}
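
# A menu row is consumed positionally. For illustration only (this loop is not
# called anywhere by itself), iterating a menu might look like:
#
#   for label, target, vram, is_menu in model_menu['mainmenu']:
#       if is_menu:
#           pass  # descend into model_menu[target]
#       else:
#           pass  # treat target as the model ID to load (vars.model)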

# Variables
class vars:
    lastact     = ""     # The last action received from the user
    submission  = ""     # Same as above, but after applying input formatting
    lastctx     = ""     # The last context submitted to the generator
    model       = ""     # Model ID string chosen at startup
    model_type  = ""     # Model type (automatically taken from the model config)
    noai        = False  # Runs the script without starting up the transformers pipeline
    aibusy      = False  # Stops submissions while the AI is working
    max_length  = 2048   # Maximum number of tokens to submit per action
    ikmax       = 3000   # Maximum number of characters to submit to InferKit
    genamt      = 80     # Amount of text for each action to generate
    ikgen       = 200    # Number of characters for InferKit to generate
    rep_pen     = 1.1    # Default generator repetition_penalty
    rep_pen_slope = 0.7  # Default generator repetition penalty slope
    rep_pen_range = 1024 # Default generator repetition penalty range
    temp        = 0.5    # Default generator temperature
    top_p       = 0.9    # Default generator top_p
    top_k       = 0      # Default generator top_k
    tfs         = 1.0    # Default generator tfs (tail-free sampling)
    typical     = 1.0    # Default generator typical sampling threshold
    numseqs     = 1      # Number of sequences to ask the generator to create
    gamestarted = False  # Whether the game has started (disables UI elements)
    gamesaved   = True   # Whether or not the current game is saved
    serverstarted = False  # Whether or not the Flask server has started
    prompt      = ""     # Prompt
    memory      = ""     # Text submitted to memory field
    authornote  = ""     # Text submitted to Author's Note field
    authornotetemplate = "[Author's note: <|>]"  # Author's note template
    setauthornotetemplate = authornotetemplate   # Saved author's note template in settings
    andepth     = 3      # How far back in history to append author's note
    actions     = structures.KoboldStoryRegister()  # Actions submitted by user and AI
    actions_metadata = {}  # One dictionary per action, holding information about that action such as alternative options.
                           # Contains at least the same number of items as actions. The Back action removes an item from actions, but not from actions_metadata.
                           # Dictionary keys are:
                           #   Selected Text: (text the user had selected. None when this is a newly generated action)
                           #   Alternative Generated Text: {Text, Pinned, Previous Selection, Edited}
    worldinfo   = []     # List of World Info key/value objects
    worldinfo_i = []     # List of World Info key/value objects sans uninitialized entries
    worldinfo_u = {}     # Dictionary of World Info UID - key/value pairs
    wifolders_d = {}     # Dictionary of World Info folder UID-info pairs
    wifolders_l = []     # List of World Info folder UIDs
    wifolders_u = {}     # Dictionary of pairs of folder UID - list of WI UID
    modelconfig = {}     # Raw contents of the model's config.json, or empty dictionary if none found
    lua_state   = None   # Lua state of the Lua scripting system
    lua_koboldbridge = None  # `koboldbridge` from bridge.lua
    lua_kobold  = None   # `kobold` from bridge.lua
    lua_koboldcore = None  # `koboldcore` from bridge.lua
    lua_logname = ...    # Name of previous userscript that logged to terminal
    lua_running = False  # Whether or not Lua is running (i.e. wasn't stopped due to an error)
    lua_edited  = set()  # Set of chunk numbers that were edited from a Lua generation modifier
    lua_deleted = set()  # Set of chunk numbers that were deleted from a Lua generation modifier
    generated_tkns = 0   # If using a backend that supports Lua generation modifiers, how many tokens have already been generated, otherwise 0
    abort       = False  # Whether or not generation was aborted by clicking on the submit button during generation
    compiling   = False  # If using a TPU Colab, this will be set to True when the TPU backend starts compiling and then set to False again
    checking    = False  # Whether or not we are actively checking to see if TPU backend is compiling or not
    sp_changed  = False  # This gets set to True whenever a userscript changes the soft prompt so that check_for_sp_change() can alert the browser that the soft prompt has changed
    spfilename  = ""     # Filename of soft prompt to load, or an empty string if not using a soft prompt
    userscripts = []     # List of userscripts to load
    last_userscripts = []  # List of userscript filenames from the previous time userscripts were sent via usstatitems
    corescript  = "default.lua"  # Filename of corescript to load
    # badwords  = []     # Array of str/chr values that should be removed from output
    badwordsids = []
    badwordsids_default = [[13460], [6880], [50256], [42496], [4613], [17414], [22039], [16410], [27], [29], [38430], [37922], [15913], [24618], [28725], [58], [47175], [36937], [26700], [12878], [16471], [37981], [5218], [29795], [13412], [45160], [3693], [49778], [4211], [20598], [36475], [33409], [44167], [32406], [29847], [29342], [42669], [685], [25787], [7359], [3784], [5320], [33994], [33490], [34516], [43734], [17635], [24293], [9959], [23785], [21737], [28401], [18161], [26358], [32509], [1279], [38155], [18189], [26894], [6927], [14610], [23834], [11037], [14631], [26933], [46904], [22330], [25915], [47934], [38214], [1875], [14692], [41832], [13163], [25970], [29565], [44926], [19841], [37250], [49029], [9609], [44438], [16791], [17816], [30109], [41888], [47527], [42924], [23984], [49074], [33717], [31161], [49082], [30138], [31175], [12240], [14804], [7131], [26076], [33250], [3556], [38381], [36338], [32756], [46581], [17912], [49146]]  # Tokenized array of badwords used to prevent AI artifacting
    badwordsids_neox = [[0], [1], [44162], [9502], [12520], [31841], [36320], [49824], [34417], [6038], [34494], [24815], [26635], [24345], [3455], [28905], [44270], [17278], [32666], [46880], [7086], [43189], [37322], [17778], [20879], [49821], [3138], [14490], [4681], [21391], [26786], [43134], [9336], [683], [48074], [41256], [19181], [29650], [28532], [36487], [45114], [46275], [16445], [15104], [11337], [1168], [5647], [29], [27482], [44965], [43782], [31011], [42944], [47389], [6334], [17548], [38329], [32044], [35487], [2239], [34761], [7444], [1084], [12399], [18990], [17636], [39083], [1184], [35830], [28365], [16731], [43467], [47744], [1138], [16079], [40116], [45564], [18297], [42368], [5456], [18022], [42696], [34476], [23505], [23741], [39334], [37944], [45382], [38709], [33440], [26077], [43600], [34418], [36033], [6660], [48167], [48471], [15775], [19884], [41533], [1008], [31053], [36692], [46576], [20095], [20629], [31759], [46410], [41000], [13488], [30952], [39258], [16160], [27655], [22367], [42767], [43736], [49694], [13811], [12004], [46768], [6257], [37471], [5264], [44153], [33805], [20977], [21083], [25416], [14277], [31096], [42041], [18331], [33376], [22372], [46294], [28379], [38475], [1656], [5204], [27075], [50001], [16616], [11396], [7748], [48744], [35402], [28120], [41512], [4207], [43144], [14767], [15640], [16595], [41305], [44479], [38958], [18474], [22734], [30522], [46267], [60], [13976], [31830], [48701], [39822], [9014], [21966], [31422], [28052], [34607], [2479], [3851], [32214], [44082], [45507], [3001], [34368], [34758], [13380], [38363], [4299], [46802], [30996], [12630], [49236], [7082], [8795], [5218], [44740], [9686], [9983], [45301], [27114], [40125], [1570], [26997], [544], [5290], [49193], [23781], [14193], [40000], [2947], [43781], [9102], [48064], [42274], [18772], [49384], [9884], [45635], [43521], [31258], [32056], [47686], [21760], [13143], [10148], [26119], [44308], [31379], [36399], [23983], [46694], [36134], [8562], [12977], [35117], [28591], [49021], [47093], [28653], [29013], [46468], [8605], [7254], [25896], [5032], [8168], [36893], [38270], [20499], [27501], [34419], [29547], [28571], [36586], [20871], [30537], [26842], [21375], [31148], [27618], [33094], [3291], [31789], [28391], [870], [9793], [41361], [47916], [27468], [43856], [8850], [35237], [15707], [47552], [2730], [41449], [45488], [3073], [49806], [21938], [24430], [22747], [20924], [46145], [20481], [20197], [8239], [28231], [17987], [42804], [47269], [29972], [49884], [21382], [46295], [36676], [34616], [3921], [26991], [27720], [46265], [654], [9855], [40354], [5291], [34904], [44342], [2470], [14598], [880], [19282], [2498], [24237], [21431], [16369], [8994], [44524], [45662], [13663], [37077], [1447], [37786], [30863], [42854], [1019], [20322], [4398], [12159], [44072], [48664], [31547], [18736], [9259], [31], [16354], [21810], [4357], [37982], [5064], [2033], [32871], [47446], [62], [22158], [37387], [8743], [47007], [17981], [11049], [4622], [37916], [36786], [35138], [29925], [14157], [18095], [27829], [1181], [22226], [5709], [4725], [30189], [37014], [1254], [11380], [42989], [696], [24576], [39487], [30119], [1092], [8088], [2194], [9899], [14412], [21828], [3725], [13544], [5180], [44679], [34398], [3891], [28739], [14219], [37594], [49550], [11326], [6904], [17266], [5749], [10174], [23405], [9955], [38271], [41018], [13011], [48392], [36784], [24254], [21687], [23734], [5413], [41447], [45472], [10122], [17555], [15830], [47384], [12084], [31350], [47940], [11661], [27988], [45443], [905], [49651], [16614], [34993], [6781], [30803], [35869], [8001], [41604], [28118], [46462], [46762], [16262], [17281], [5774], [10943], [5013], [18257], [6750], [4713], [3951], [11899], [38791], [16943], [37596], [9318], [18413], [40473], [13208], [16375]]
    badwordsids_opt = [[44717], [46613], [48513], [49923], [50185], [48755], [8488], [43303], [49659], [48601], [49817], [45405], [48742], [49925], [47720], [11227], [48937], [48784], [50017], [42248], [49310], [48082], [49895], [50025], [49092], [49007], [8061], [44226], [0], [742], [28578], [15698], [49784], [46679], [39365], [49281], [49609], [48081], [48906], [46161], [48554], [49670], [48677], [49721], [49632], [48610], [48462], [47457], [10975], [46077], [28696], [48709], [43839], [49798], [49154], [48203], [49625], [48395], [50155], [47161], [49095], [48833], [49420], [49666], [48443], [22176], [49242], [48651], [49138], [49750], [40389], [48021], [21838], [49070], [45333], [40862], [1], [49915], [33525], [49858], [50254], [44403], [48992], [48872], [46117], [49853], [47567], [50206], [41552], [50068], [48999], [49703], [49940], [49329], [47620], [49868], [49962], [2], [44082], [50236], [31274], [50260], [47052], [42645], [49177], [17523], [48691], [49900], [49069], [49358], [48794], [47529], [46479], [48457], [646], [49910], [48077], [48935], [46386], [48902], [49151], [48759], [49803], [45587], [48392], [47789], [48654], [49836], [49230], [48188], [50264], [46844], [44690], [48505], [50161], [27779], [49995], [41833], [50154], [49097], [48520], [50018], [8174], [50084], [49366], [49526], [50193], [7479], [49982], [3]]
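    # Each entry above is a single-token sequence in the nested-list form that Hugging Face
    # transformers expects for its `bad_words_ids` generation argument (which list applies
    # depends on the loaded tokenizer: GPT-2-style, NeoX or OPT vocabularies).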
    fp32_model  = False  # Whether or not the most recently loaded HF model was in fp32 format
    deletewi    = None   # Temporary storage for UID to delete
    wirmvwhtsp  = False  # Whether to remove leading whitespace from WI entries
    widepth     = 3      # How many historical actions to scan for WI hits
    mode        = "play" # Whether the interface is in play, memory, or edit mode
    editln      = 0      # Which line was last selected in Edit Mode
    gpu_device  = 0      # Which PyTorch device to use when using pure GPU generation
    url         = "https://api.inferkit.com/v1/models/standard/generate"  # InferKit API URL
    oaiurl      = ""     # OpenAI API URL
    oaiengines  = "https://api.openai.com/v1/engines"
    colaburl    = ""     # Ngrok url for Google Colab mode
    apikey      = ""     # API key to use for InferKit API calls
    oaiapikey   = ""     # API key to use for OpenAI API calls
    savedir     = getcwd()+"\\stories"
    hascuda     = False  # Whether torch has detected CUDA on the system
    usegpu      = False  # Whether to launch pipeline with GPU support
    custmodpth  = ""     # Filesystem location of custom model to run
    formatoptns = {'frmttriminc': True, 'frmtrmblln': False, 'frmtrmspch': False, 'frmtadsnsp': False, 'singleline': False}  # Container for state of formatting options
    importnum   = -1     # Selection on import popup list
    importjs    = {}     # Temporary storage for import data
    loadselect  = ""     # Temporary storage for story filename to load
    spselect    = ""     # Temporary storage for soft prompt filename to load
    spmeta      = None   # Metadata of current soft prompt, or None if not using a soft prompt
    sp          = None   # Current soft prompt tensor (as a NumPy array)
    sp_length   = 0      # Length of current soft prompt in tokens, or 0 if not using a soft prompt
    has_genmod  = False  # Whether or not at least one loaded Lua userscript has a generation modifier
    svowname    = ""     # Filename that was flagged for overwrite confirm
    saveow      = False  # Whether or not overwrite confirm has been displayed
    autosave    = False  # Whether or not to automatically save after each action
    genseqs     = []     # Temporary storage for generated sequences
    recentback  = False  # Whether Back button was recently used without Submitting or Retrying after
    recentrng   = None   # If a new random game was recently generated without Submitting after, this is the topic used (as a string), otherwise this is None
    recentrngm  = None   # If a new random game was recently generated without Submitting after, this is the memory used (as a string), otherwise this is None
    useprompt   = False  # Whether to send the full prompt with every submit action
    breakmodel  = False  # For GPU users, whether to use both system RAM and VRAM to conserve VRAM while offering speedup compared to CPU-only
    bmsupported = False  # Whether the breakmodel option is supported (GPT-Neo/GPT-J/XGLM/OPT only, currently)
    nobreakmodel = False  # Whether something has specifically requested that breakmodel be disabled (e.g. a model's config)
    smandelete  = False  # Whether stories can be deleted from inside the browser
    smanrename  = False  # Whether stories can be renamed from inside the browser
    allowsp     = False  # Whether we are allowed to use soft prompts (by default enabled if we're using GPT-2, GPT-Neo or GPT-J)
    modeldim    = -1     # Embedding dimension of your model (e.g. it's 4096 for GPT-J-6B and 2560 for GPT-Neo-2.7B)
    laststory   = None   # Filename (without extension) of most recent story JSON file we loaded
    regex_sl    = re.compile(r'\n*(?<=.) *\n(.|\n)*')  # Pattern for limiting the output to a single line
    acregex_ai  = re.compile(r'\n* *>(.|\n)*')  # Pattern for matching adventure actions from the AI so we can remove them
    acregex_ui  = re.compile(r'^ *(&gt;.*)$', re.MULTILINE)  # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses)
    comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)')  # Pattern for matching comments to remove them before sending them to the AI
    comregex_ui = re.compile(r'(&lt;\|(?:.|\n)*?\|&gt;)')  # Pattern for matching comments in the editor
    chatmode    = False
    chatname    = "You"
    adventure   = False
    actionmode  = 1
    dynamicscan = False
    host        = False
    nopromptgen = False
    rngpersist  = False
    nogenmod    = False
    welcome     = False  # Custom Welcome Text (False is default)
    newlinemode = "n"
    quiet       = False  # If set will suppress any story text from being printed to the console (will only be seen on the client web page)
    debug       = False  # If set to true, will send debug information to the client for display
    lazy_load   = True   # Whether or not to use torch_lazy_loader.py for transformers models in order to reduce CPU memory usage
    use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" or os.environ.get("TPU_NAME", "") != ""  # Whether or not we're in a Colab TPU instance or Kaggle TPU instance and are going to use the TPU rather than the CPU
    revision    = None

utils.vars = vars

class Send_to_socketio(object):
    def write(self, bar):
        print(bar, end="")
        time.sleep(0.01)
        try:
            emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", "&nbsp;")}, broadcast=True)
        except:
            pass
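
# Send_to_socketio only needs a write() method: it quacks like a text stream, so an
# instance can be handed to anything that writes progress output to a file-like object
# (tqdm's file= argument, for example), mirroring that output to both the terminal and,
# when a client is connected, the browser. The bare except covers writes that happen
# before the socket exists.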

# Set logging level to reduce chatter from Flask
log = logging.getLogger('werkzeug')
log.setLevel(logging.ERROR)

# Start flask & SocketIO
print("{0}Initializing Flask...{1}".format(colors.PURPLE, colors.END), end="")
from flask import Flask, render_template, Response, request, copy_current_request_context, send_from_directory
from flask_socketio import SocketIO, emit
app = Flask(__name__, root_path=os.getcwd())
app.config['SECRET_KEY'] = 'secret!'
socketio = SocketIO(app, async_method="eventlet")
print("{0}OK!{1}".format(colors.GREEN, colors.END))

#==================================================================#
# Function to get model selection at startup
#==================================================================#
def sendModelSelection(menu="mainmenu", folder="./models"):
    # If we send one of the manual load options, send back the list of model directories, otherwise send the menu
    if menu in ('NeoCustom', 'GPT2Custom'):
        (paths, breadcrumbs) = get_folder_path_info(folder)
        menu_list = [[folder, menu, "", False] for folder in paths]
        menu_list.append(["Return to Main Menu", "mainmenu", "", True])
        emit('from_server', {'cmd': 'show_model_menu', 'data': menu_list, 'menu': menu, 'breadcrumbs': breadcrumbs}, broadcast=True)
    else:
        emit('from_server', {'cmd': 'show_model_menu', 'data': model_menu[menu], 'menu': menu, 'breadcrumbs': []}, broadcast=True)

def get_folder_path_info(base):
    if base == 'This PC':
        breadcrumbs = [['This PC', 'This PC']]
        paths = [["{}:\\".format(chr(i)), "{}:\\".format(chr(i))] for i in range(65, 91) if os.path.exists("{}:".format(chr(i)))]
    else:
        path = os.path.abspath(base)
        if path[-1] == "\\":
            path = path[:-1]
        breadcrumbs = []
        for i in range(len(path.split("\\"))):
            breadcrumbs.append(["\\".join(path.split("\\")[:i+1]),
                                path.split("\\")[i]])
        if len(breadcrumbs) == 1:
            breadcrumbs = [["{}:\\".format(chr(i)), "{}:\\".format(chr(i))] for i in range(65, 91) if os.path.exists("{}:".format(chr(i)))]
        else:
            if len([["{}:\\".format(chr(i)), "{}:\\".format(chr(i))] for i in range(65, 91) if os.path.exists("{}:".format(chr(i)))]) > 0:
                breadcrumbs.insert(0, ['This PC', 'This PC'])
        paths = []
        base_path = os.path.abspath(base)
        for item in os.listdir(base_path):
            if os.path.isdir(os.path.join(base_path, item)):
                paths.append([os.path.join(base_path, item), item])
    # Paths/breadcrumbs is a list of lists, where the first element in the sublist is the full path and the second is the folder name
    return (paths, breadcrumbs)
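
# Example (hypothetical Windows layout): get_folder_path_info("C:\\KoboldAI\\models")
# would return paths like [["C:\\KoboldAI\\models\\gpt-neo-2.7B", "gpt-neo-2.7B"], ...]
# and breadcrumbs [["C:", "C:"], ["C:\\KoboldAI", "KoboldAI"], ["C:\\KoboldAI\\models", "models"]],
# with ["This PC", "This PC"] prepended when the path has more than one component and
# drive letters are detected.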

def getModelSelection(modellist):
    print("    #    Model\t\t\t\t\t\tVRAM\n    ========================================================")
    i = 1
    for m in modellist:
        print("    {0} - {1}\t\t\t{2}".format("{:<2}".format(i), m[0].ljust(25), m[2]))
        i += 1
    print(" ")
    modelsel = 0
    vars.model = ''
    while(vars.model == ''):
        modelsel = input("Model #> ")
        if(modelsel.isnumeric() and int(modelsel) > 0 and int(modelsel) <= len(modellist)):
            vars.model = modellist[int(modelsel)-1][1]
        else:
            print("{0}Please enter a valid selection.{1}".format(colors.RED, colors.END))

    # Model Lists
    # If a menu key was chosen, descend into that menu; otherwise vars.model now holds a model ID
    if(vars.model in model_menu):
        getModelSelection(model_menu[vars.model])
    elif(vars.model == "Return"):
        getModelSelection(model_menu['mainmenu'])

    # If custom model was selected, get the filesystem location and store it
    if(vars.model == "NeoCustom" or vars.model == "GPT2Custom"):
        print("{0}Please choose the folder where pytorch_model.bin is located:{1}\n".format(colors.CYAN, colors.END))
        modpath = fileops.getdirpath(getcwd() + "/models", "Select Model Folder")

        if(modpath):
            # Save directory to vars
            vars.custmodpth = modpath
        else:
            # Print error and retry model selection
            print("{0}Model select cancelled!{1}".format(colors.RED, colors.END))
            print("{0}Select an AI model to continue:{1}\n".format(colors.CYAN, colors.END))
            getModelSelection(model_menu['mainmenu'])

def check_if_dir_is_model(path):
    try:
        from transformers import AutoConfig
        model_config = AutoConfig.from_pretrained(path, revision=vars.revision, cache_dir="cache")
    except:
        return False
    return True
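
# In other words, a directory "is a model" exactly when transformers can load a config
# from it. For instance (hypothetical path), check_if_dir_is_model("./models/gpt-neo-2.7B")
# returns True only if that folder contains a parseable config.json.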

#==================================================================#
# Return all keys in tokenizer dictionary containing char
#==================================================================#
#def gettokenids(char):
#    keys = []
#    for key in vocab_keys:
#        if(key.find(char) != -1):
#            keys.append(key)
#    return keys

#==================================================================#
# Return Model Name
#==================================================================#
def getmodelname():
    if(args.configname):
        modelname = args.configname
        return modelname
    if(vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
        modelname = os.path.basename(os.path.normpath(vars.custmodpth))
        return modelname
    else:
        modelname = vars.model
        return modelname

#==================================================================#
# Breakmodel configuration functions
#==================================================================#
def device_list(n_layers, primary=None, selected=None):
    device_count = torch.cuda.device_count()
    if(device_count < 2):
        primary = None
    gpu_blocks = breakmodel.gpu_blocks + (device_count - len(breakmodel.gpu_blocks))*[0]
    print(f"{colors.YELLOW}       DEVICE ID  |  LAYERS  |  DEVICE NAME{colors.END}")
    for i in range(device_count):
        name = torch.cuda.get_device_name(i)
        if(len(name) > 47):
            name = "..." + name[-44:]
        row_color = colors.END
        sep_color = colors.YELLOW
        print(f"{row_color}{colors.YELLOW + '->' + row_color if i == selected else '  '} {'(primary)' if i == primary else ' '*9} {i:3}  {sep_color}|{row_color}     {gpu_blocks[i]:3}  {sep_color}|{row_color}  {name}{colors.END}")
    row_color = colors.END
    sep_color = colors.YELLOW
    print(f"{row_color}   {' '*9} N/A  {sep_color}|{row_color}     {n_layers:3}  {sep_color}|{row_color}  (CPU){colors.END}")

def device_config(config):
    global breakmodel, generator
    import breakmodel
    n_layers = utils.num_layers(config)
    if(args.breakmodel_gpulayers is not None):
        try:
            breakmodel.gpu_blocks = list(map(int, args.breakmodel_gpulayers.split(',')))
            assert len(breakmodel.gpu_blocks) <= torch.cuda.device_count()
            s = n_layers
            for i in range(len(breakmodel.gpu_blocks)):
                if(breakmodel.gpu_blocks[i] <= -1):
                    # A negative entry is a sentinel meaning "all remaining layers on this device"
                    breakmodel.gpu_blocks[i] = s
                    break
                else:
                    s -= breakmodel.gpu_blocks[i]
            assert sum(breakmodel.gpu_blocks) <= n_layers
            n_layers -= sum(breakmodel.gpu_blocks)
        except:
            print("WARNING: --breakmodel_gpulayers is malformed. Please use the --help option to see correct usage of --breakmodel_gpulayers. Defaulting to all layers on device 0.", file=sys.stderr)
            breakmodel.gpu_blocks = [n_layers]
            n_layers = 0
    elif(args.breakmodel_layers is not None):
        breakmodel.gpu_blocks = [n_layers - max(0, min(n_layers, args.breakmodel_layers))]
        n_layers -= sum(breakmodel.gpu_blocks)
    elif(args.model is not None):
        print("Breakmodel not specified, assuming GPU 0")
        breakmodel.gpu_blocks = [n_layers]
        n_layers = 0
    else:
        device_count = torch.cuda.device_count()
        if(device_count > 1):
            print(colors.CYAN + "\nPlease select one of your GPUs to be your primary GPU.")
            print("VRAM usage in your primary GPU will be higher than for your other ones.")
            print("It is recommended you make your fastest GPU your primary GPU.")
            device_list(n_layers)
            while(True):
                primaryselect = input("device ID> ")
                if(primaryselect.isnumeric() and 0 <= int(primaryselect) < device_count):
                    breakmodel.primary_device = int(primaryselect)
                    break
                else:
                    print(f"{colors.RED}Please enter an integer between 0 and {device_count-1}.{colors.END}")
        else:
            breakmodel.primary_device = 0

        print(colors.PURPLE + "\nIf you don't have enough VRAM to run the model on a single GPU")
        print("you can split the model between your CPU and your GPU(s), or between")
        print("multiple GPUs if you have more than one.")
        print("By putting more 'layers' on a GPU or CPU, more computations will be")
        print("done on that device and more VRAM or RAM will be required on that device")
        print("(roughly proportional to number of layers).")
        print("It should be noted that GPUs are orders of magnitude faster than the CPU.")
        print(f"This model has{colors.YELLOW} {n_layers} {colors.PURPLE}layers.{colors.END}\n")

        for i in range(device_count):
            device_list(n_layers, primary=breakmodel.primary_device, selected=i)
            print(f"{colors.CYAN}\nHow many of the remaining{colors.YELLOW} {n_layers} {colors.CYAN}layers would you like to put into device {i}?\nYou can also enter -1 to allocate all remaining layers to this device.{colors.END}\n")
            while(True):
                layerselect = input("# of layers> ")
                if((layerselect.isnumeric() or layerselect.strip() == '-1') and -1 <= int(layerselect) <= n_layers):
                    layerselect = int(layerselect)
                    layerselect = n_layers if layerselect == -1 else layerselect
                    breakmodel.gpu_blocks.append(layerselect)
                    n_layers -= layerselect
                    break
                else:
                    print(f"{colors.RED}Please enter an integer between -1 and {n_layers}.{colors.END}")
            if(n_layers == 0):
                break

        print(colors.PURPLE + "\nFinal device configuration:")
        device_list(n_layers)

    # If all layers are on the same device, use the old GPU generation mode
    while(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] == 0):
        breakmodel.gpu_blocks.pop()
    if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, utils.num_layers(config))):
        vars.breakmodel = False
        vars.usegpu = True
        vars.gpu_device = len(breakmodel.gpu_blocks)-1
        return

    if(not breakmodel.gpu_blocks):
        print("Nothing assigned to a GPU, reverting to CPU only mode")
        vars.breakmodel = False
        vars.usegpu = False
        return
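
# Worked example (hypothetical flag values): with --breakmodel_gpulayers 8,-1 on a model
# whose config reports 28 layers, device 0 is assigned 8 layers, the -1 sentinel expands
# to the 20 remaining layers for device 1, and n_layers ends at 0, so nothing stays on
# the CPU. A trailing block equal to -1 or the full layer count means everything fits on
# one device, in which case the code above falls back to plain single-GPU generation.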

def move_model_to_devices(model):
    global generator

    if(not vars.breakmodel):
        if(vars.usegpu):
            model = model.half().to(vars.gpu_device)
        else:
            model = model.to('cpu').float()
        generator = model.generate
        return

    model.half()
    gc.collect()

    if(hasattr(model, "transformer")):
        model.transformer.wte.to(breakmodel.primary_device)
        model.transformer.ln_f.to(breakmodel.primary_device)
        if(hasattr(model, 'lm_head')):
            model.lm_head.to(breakmodel.primary_device)
        if(hasattr(model.transformer, 'wpe')):
            model.transformer.wpe.to(breakmodel.primary_device)
    elif(not hasattr(model.model, "decoder")):
        model.model.embed_tokens.to(breakmodel.primary_device)
        model.model.layer_norm.to(breakmodel.primary_device)
        model.lm_head.to(breakmodel.primary_device)
        model.model.embed_positions.to(breakmodel.primary_device)
    else:
        model.model.decoder.embed_tokens.to(breakmodel.primary_device)
        if(model.model.decoder.project_in is not None):
            model.model.decoder.project_in.to(breakmodel.primary_device)
        if(model.model.decoder.project_out is not None):
            model.model.decoder.project_out.to(breakmodel.primary_device)
        model.model.decoder.embed_positions.to(breakmodel.primary_device)
    gc.collect()
    GPTNeoModel.forward = breakmodel.new_forward_neo
    if("GPTJModel" in globals()):
        GPTJModel.forward = breakmodel.new_forward_neo  # type: ignore
    if("XGLMModel" in globals()):
        XGLMModel.forward = breakmodel.new_forward_xglm  # type: ignore
    if("OPTDecoder" in globals()):
        OPTDecoder.forward = breakmodel.new_forward_opt  # type: ignore
    generator = model.generate
    if(hasattr(model, "transformer")):
        breakmodel.move_hidden_layers(model.transformer)
    elif(not hasattr(model.model, "decoder")):
        breakmodel.move_hidden_layers(model.model, model.model.layers)
    else:
        breakmodel.move_hidden_layers(model.model.decoder, model.model.decoder.layers)

#==================================================================#
#  Allow the models to override some settings
#==================================================================#
def loadmodelsettings():
    try:
        js = json.loads(str(model_config).partition(' ')[2])
    except Exception as e:
        try:
            try:
                js = json.load(open(vars.custmodpth + "/config.json", "r"))
            except Exception as e:
                js = json.load(open(vars.custmodpth.replace('/', '_') + "/config.json", "r"))
        except Exception as e:
            js = {}
    if vars.model_type == "xglm" or js.get("compat", "j") == "fairseq_lm":
        vars.newlinemode = "s"  # Default to </s> newline mode if using XGLM
    if vars.model_type == "opt":
        vars.newlinemode = "ns"  # Handle </s> but don't convert newlines if using Fairseq models that have newlines trained in them
    vars.modelconfig = js
2022-01-30 17:06:15 +01:00
if ( " badwordsids " in js ) :
vars . badwordsids = js [ " badwordsids " ]
if ( " nobreakmodel " in js ) :
vars . nobreakmodel = js [ " nobreakmodel " ]
if ( " temp " in js ) :
vars . temp = js [ " temp " ]
if ( " top_p " in js ) :
vars . top_p = js [ " top_p " ]
if ( " top_k " in js ) :
vars . top_k = js [ " top_k " ]
if ( " tfs " in js ) :
vars . tfs = js [ " tfs " ]
2022-03-27 22:25:50 +02:00
if ( " typical " in js ) :
vars . typical = js [ " typical " ]
2022-01-30 17:06:15 +01:00
if ( " rep_pen " in js ) :
vars . rep_pen = js [ " rep_pen " ]
if ( " rep_pen_slope " in js ) :
vars . rep_pen_slope = js [ " rep_pen_slope " ]
if ( " rep_pen_range " in js ) :
vars . rep_pen_range = js [ " rep_pen_range " ]
if ( " adventure " in js ) :
vars . adventure = js [ " adventure " ]
if ( " chatmode " in js ) :
vars . chatmode = js [ " chatmode " ]
if ( " dynamicscan " in js ) :
vars . dynamicscan = js [ " dynamicscan " ]
if ( " formatoptns " in js ) :
vars . formatoptns = js [ " formatoptns " ]
2022-01-30 19:47:30 +01:00
if ( " welcome " in js ) :
vars . welcome = js [ " welcome " ]
2022-01-31 18:39:34 +01:00
if ( " newlinemode " in js ) :
vars . newlinemode = js [ " newlinemode " ]
2022-01-30 17:06:15 +01:00
if ( " antemplate " in js ) :
vars . setauthornotetemplate = js [ " antemplate " ]
if ( not vars . gamestarted ) :
vars . authornotetemplate = vars . setauthornotetemplate

#==================================================================#
#  Take settings from vars and write them to client settings file
#==================================================================#
def savesettings():
    # Build json to write
    js = {}
    js["apikey"]        = vars.apikey
    js["andepth"]       = vars.andepth
    js["temp"]          = vars.temp
    js["top_p"]         = vars.top_p
    js["top_k"]         = vars.top_k
    js["tfs"]           = vars.tfs
    js["typical"]       = vars.typical
    js["rep_pen"]       = vars.rep_pen
    js["rep_pen_slope"] = vars.rep_pen_slope
    js["rep_pen_range"] = vars.rep_pen_range
    js["genamt"]        = vars.genamt
    js["max_length"]    = vars.max_length
    js["ikgen"]         = vars.ikgen
    js["formatoptns"]   = vars.formatoptns
    js["numseqs"]       = vars.numseqs
    js["widepth"]       = vars.widepth
    js["useprompt"]     = vars.useprompt
    js["adventure"]     = vars.adventure
    js["chatmode"]      = vars.chatmode
    js["chatname"]      = vars.chatname
    js["dynamicscan"]   = vars.dynamicscan
    js["nopromptgen"]   = vars.nopromptgen
    js["rngpersist"]    = vars.rngpersist
    js["nogenmod"]      = vars.nogenmod
    js["autosave"]      = vars.autosave
    js["welcome"]       = vars.welcome
    js["newlinemode"]   = vars.newlinemode
    js["antemplate"]    = vars.setauthornotetemplate
    js["userscripts"]   = vars.userscripts
    js["corescript"]    = vars.corescript
    js["softprompt"]    = vars.spfilename

    # Write it
    if not os.path.exists('settings'):
        os.mkdir('settings')
    file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "w")
    try:
        file.write(json.dumps(js, indent=3))
    finally:
        file.close()

#==================================================================#
#  Don't save settings unless 2 seconds have passed without modification
#==================================================================#
@debounce(2)
def settingschanged():
    print("{0}Saving settings!{1}".format(colors.GREEN, colors.END))
    savesettings()
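
# For reference, utils.debounce is assumed here to behave like the classic timer-based
# debounce decorator: each call cancels the pending timer and schedules the real function
# `wait` seconds out, so a burst of UI slider changes produces a single disk write.
# A minimal sketch of that idea (not the actual utils code):
#
#   def debounce(wait):
#       def decorator(fn):
#           def debounced(*args, **kwargs):
#               def call_it():
#                   fn(*args, **kwargs)
#               if hasattr(debounced, "t"):
#                   debounced.t.cancel()
#               debounced.t = threading.Timer(wait, call_it)
#               debounced.t.start()
#           return debounced
#       return decorator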

#==================================================================#
#  Read settings from client file JSON and send to vars
#==================================================================#
def loadsettings():
    if(path.exists("defaults/" + getmodelname().replace('/', '_') + ".settings")):
        # Read file contents into JSON object
        file = open("defaults/" + getmodelname().replace('/', '_') + ".settings", "r")
        js = json.load(file)
        processsettings(js)
        file.close()
    if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
        # Read file contents into JSON object
        file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
        js = json.load(file)
        processsettings(js)
        file.close()

def processsettings(js):
    # Copy file contents to vars
    if("apikey" in js):
        vars.apikey = js["apikey"]
    if("andepth" in js):
        vars.andepth = js["andepth"]
    if("temp" in js):
        vars.temp = js["temp"]
    if("top_p" in js):
        vars.top_p = js["top_p"]
    if("top_k" in js):
        vars.top_k = js["top_k"]
    if("tfs" in js):
        vars.tfs = js["tfs"]
    if("typical" in js):
        vars.typical = js["typical"]
    if("rep_pen" in js):
        vars.rep_pen = js["rep_pen"]
    if("rep_pen_slope" in js):
        vars.rep_pen_slope = js["rep_pen_slope"]
    if("rep_pen_range" in js):
        vars.rep_pen_range = js["rep_pen_range"]
    if("genamt" in js):
        vars.genamt = js["genamt"]
    if("max_length" in js):
        vars.max_length = js["max_length"]
    if("ikgen" in js):
        vars.ikgen = js["ikgen"]
    if("formatoptns" in js):
        vars.formatoptns = js["formatoptns"]
    if("numseqs" in js):
        vars.numseqs = js["numseqs"]
    if("widepth" in js):
        vars.widepth = js["widepth"]
    if("useprompt" in js):
        vars.useprompt = js["useprompt"]
    if("adventure" in js):
        vars.adventure = js["adventure"]
    if("chatmode" in js):
        vars.chatmode = js["chatmode"]
    if("chatname" in js):
        vars.chatname = js["chatname"]
    if("dynamicscan" in js):
        vars.dynamicscan = js["dynamicscan"]
    if("nopromptgen" in js):
        vars.nopromptgen = js["nopromptgen"]
    if("rngpersist" in js):
        vars.rngpersist = js["rngpersist"]
    if("nogenmod" in js):
        vars.nogenmod = js["nogenmod"]
    if("autosave" in js):
        vars.autosave = js["autosave"]
    if("newlinemode" in js):
        vars.newlinemode = js["newlinemode"]
    if("welcome" in js):
        vars.welcome = js["welcome"]
    if("antemplate" in js):
        vars.setauthornotetemplate = js["antemplate"]
        if(not vars.gamestarted):
            vars.authornotetemplate = vars.setauthornotetemplate
    if("userscripts" in js):
        vars.userscripts = []
        for userscript in js["userscripts"]:
            if type(userscript) is not str:
                continue
            userscript = userscript.strip()
            if len(userscript) != 0 and all(q not in userscript for q in ("..", ":")) and all(userscript[0] not in q for q in ("/", "\\")) and os.path.exists(fileops.uspath(userscript)):
                vars.userscripts.append(userscript)
    if("corescript" in js and type(js["corescript"]) is str and all(q not in js["corescript"] for q in ("..", ":")) and all(js["corescript"][0] not in q for q in ("/", "\\"))):
        vars.corescript = js["corescript"]
    else:
        vars.corescript = "default.lua"
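
# The "..", ":" and leading-slash checks above are a path-traversal guard: script names
# read from a settings file may only refer to files inside the expected scripts folder
# (no parent-directory hops, drive prefixes or absolute paths), and userscripts must
# additionally exist at fileops.uspath(userscript) before they are accepted.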

#==================================================================#
#  Load a soft prompt from a file
#==================================================================#
def check_for_sp_change():
    while(True):
        time.sleep(0.1)
        if(vars.sp_changed):
            with app.app_context():
                emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, namespace=None, broadcast=True)
            vars.sp_changed = False

socketio.start_background_task(check_for_sp_change)

def spRequest(filename):
    if(not vars.allowsp):
        raise RuntimeError("Soft prompts are not supported by your current model/backend")

    old_filename = vars.spfilename

    vars.spfilename = ""
    settingschanged()

    if(len(filename) == 0):
        vars.sp = None
        vars.sp_length = 0
        if(old_filename != filename):
            vars.sp_changed = True
        return

    global np
    if 'np' not in globals():
        import numpy as np

    z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
    if not isinstance(z, zipfile.ZipFile):
        raise RuntimeError(f"{repr(filename)} is not a valid soft prompt file")
    with z.open('meta.json') as f:
        vars.spmeta = json.load(f)
    z.close()

    with np.load(fileops.sppath(filename), allow_pickle=False) as f:
        tensor = f['tensor.npy']

    # If the tensor is in bfloat16 format, convert it to float32
    if(tensor.dtype == 'V2'):
        tensor.dtype = np.uint16
        tensor = np.uint32(tensor) << 16
        tensor.dtype = np.float32
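    # (bfloat16 is simply the top 16 bits of an IEEE float32, so viewing the raw
    # bytes as uint16, widening to uint32 and shifting left by 16 reconstructs the
    # float32 bit pattern exactly; NumPy has no native bfloat16, hence the 'V2'
    # raw-void dtype on load.)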
if ( tensor . dtype != np . float16 ) :
tensor = np . float32 ( tensor )
assert not np . isinf ( tensor ) . any ( ) and not np . isnan ( tensor ) . any ( )
vars . sp_length = tensor . shape [ - 2 ]
vars . spmeta [ " n_tokens " ] = vars . sp_length
2022-03-15 04:14:20 +01:00
if ( vars . use_colab_tpu or vars . model in ( " TPUMeshTransformerGPTJ " , " TPUMeshTransformerGPTNeoX " ) ) :
2022-02-24 03:09:31 +01:00
rows = tensor . shape [ 0 ]
padding_amount = tpu_mtj_backend . params [ " seq " ] - ( tpu_mtj_backend . params [ " seq " ] % - tpu_mtj_backend . params [ " cores_per_replica " ] ) - rows
tensor = np . pad ( tensor , ( ( 0 , padding_amount ) , ( 0 , 0 ) ) )
tensor = tensor . reshape (
tpu_mtj_backend . params [ " cores_per_replica " ] ,
- 1 ,
2022-05-13 04:21:15 +02:00
tpu_mtj_backend . params . get ( " d_embed " , tpu_mtj_backend . params [ " d_model " ] ) ,
2022-02-24 03:09:31 +01:00
)
vars . sp = tpu_mtj_backend . shard_xmap ( np . float32 ( tensor ) )
else :
vars . sp = torch . from_numpy ( tensor )
vars . spfilename = filename
settingschanged ( )
2022-04-12 21:59:05 +02:00
if ( old_filename != filename ) :
vars . sp_changed = True
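
# A quick sketch of the two tricks used in spRequest above (numbers are
# illustrative):
#
# bfloat16 -> float32: a bfloat16 value is the top 16 bits of a float32,
# so widening to uint32 and shifting left by 16 reinstates it, e.g.
#     0x3F80 (bfloat16 1.0)  ->  0x3F800000  ->  float32 1.0
#
# TPU padding: seq - (seq % -cores_per_replica) rounds seq up to a multiple
# of cores_per_replica; subtracting the existing rows gives how many zero
# rows to pad so the tensor splits evenly across cores, e.g. with seq=2048,
# cores_per_replica=8 and rows=20:
#     padding_amount = 2048 - (2048 % -8) - 20 = 2028
# for 2048 rows total, reshaped to (8, 256, d_embed).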
#==================================================================#
# Startup
#==================================================================#
def general_startup(override_args=None):
    global args
    # Parsing Parameters
    parser = argparse.ArgumentParser(description="KoboldAI Server")
    parser.add_argument("--remote", action='store_true', help="Optimizes KoboldAI for Remote Play")
    parser.add_argument("--noaimenu", action='store_true', help="Disables the ability to select the AI")
    parser.add_argument("--ngrok", action='store_true', help="Optimizes KoboldAI for Remote Play using Ngrok")
    parser.add_argument("--localtunnel", action='store_true', help="Optimizes KoboldAI for Remote Play using Localtunnel")
    parser.add_argument("--host", action='store_true', help="Optimizes KoboldAI for Remote Play without using a proxy service")
    parser.add_argument("--port", type=int, help="Specify the port on which the application will be joinable")
    parser.add_argument("--aria2_port", type=int, help="Specify the port on which aria2's RPC interface will be open if aria2 is installed (defaults to 6799)")
    parser.add_argument("--model", help="Specify the Model Type to skip the Menu")
    parser.add_argument("--path", help="Specify the Path for local models (For model NeoCustom or GPT2Custom)")
    parser.add_argument("--revision", help="Specify the model revision for huggingface models (can be a git branch/tag name or a git commit hash)")
    parser.add_argument("--cpu", action='store_true', help="By default unattended launches are on the GPU; use this option to force CPU usage.")
    parser.add_argument("--breakmodel", action='store_true', help=argparse.SUPPRESS)
    parser.add_argument("--breakmodel_layers", type=int, help=argparse.SUPPRESS)
    parser.add_argument("--breakmodel_gpulayers", type=str, help="If using a model that supports hybrid generation, this is a comma-separated list that specifies how many layers to put on each GPU device. For example, to put 8 layers on device 0, 9 layers on device 1 and 11 layers on device 2, use --breakmodel_gpulayers 8,9,11")
    parser.add_argument("--override_delete", action='store_true', help="Deleting stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow deleting stories if using --remote and prevent deleting stories otherwise.")
    parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.")
    parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.")
    parser.add_argument("--colab", action='store_true', help="Optimize for Google Colab.")
    parser.add_argument("--nobreakmodel", action='store_true', help="Disables Breakmodel support completely.")
    parser.add_argument("--unblock", action='store_true', default=False, help="Unblocks the KoboldAI port to be accessible from other machines without optimizing for remote play (It is recommended to use --host instead)")
    parser.add_argument("--quiet", action='store_true', default=False, help="If present will suppress any story related text from showing on the console")
    parser.add_argument("--no_aria2", action='store_true', default=False, help="Prevents KoboldAI from using aria2 to download huggingface models more efficiently, in case aria2 is causing you issues")
    parser.add_argument("--lowmem", action='store_true', help="Extra Low Memory loading for the GPU, slower but memory does not peak to twice the usage")
    parser.add_argument("--savemodel", action='store_true', help="Saves the model to the models folder even if --colab is used (Allows you to save models to Google Drive)")
    #args: argparse.Namespace = None

    if "pytest" in sys.modules and override_args is None:
        args = parser.parse_args([])
        return

    if override_args is not None:
        import shlex
        args = parser.parse_args(shlex.split(override_args))
    elif(os.environ.get("KOBOLDAI_ARGS") is not None):
        import shlex
        args = parser.parse_args(shlex.split(os.environ["KOBOLDAI_ARGS"]))
    else:
        args = parser.parse_args()

    vars.model = args.model
    vars.revision = args.revision

    if args.colab:
        args.remote = True
        args.override_rename = True
        args.override_delete = True
        args.nobreakmodel = True
        args.quiet = True
        args.lowmem = True
        args.noaimenu = True

    if args.quiet:
        vars.quiet = True

    if args.nobreakmodel:
        vars.nobreakmodel = True

    if args.remote:
        vars.host = True

    if args.ngrok:
        vars.host = True

    if args.localtunnel:
        vars.host = True

    if args.host:
        vars.host = True

    if args.cpu:
        vars.use_colab_tpu = False

    vars.smandelete = vars.host == args.override_delete
    vars.smanrename = vars.host == args.override_rename

    vars.aria2_port = args.aria2_port or 6799

    # Now let's look to see if we are going to force a load of a model from a user selected folder
    if(vars.model == "selectfolder"):
        print("{0}Please choose the folder where pytorch_model.bin is located:{1}\n".format(colors.CYAN, colors.END))
        modpath = fileops.getdirpath(getcwd() + "/models", "Select Model Folder")

        if(modpath):
            # Save directory to vars
            vars.model = "NeoCustom"
            vars.custmodpth = modpath
    elif args.model:
        print("Welcome to KoboldAI!\nYou have selected the following Model:", vars.model)
        if args.path:
            print("You have selected the following path for your Model:", args.path)
            vars.custmodpth = args.path
            vars.colaburl = args.path + "/request"  # Let's just use the same parameter to keep it simple
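
# For reference, a couple of example launches using the flags defined above
# (assuming this file is run as aiserver.py; the paths and layer split are
# placeholders):
#     python aiserver.py --model NeoCustom --path /path/to/model --breakmodel_gpulayers 8,9,11
#     python aiserver.py --remote --quiet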
#==================================================================#
# Load Model
#==================================================================#

def tpumtjgetsofttokens():
    soft_tokens = None
    if(vars.sp is None):
        global np
        if 'np' not in globals():
            import numpy as np
        tensor = np.zeros((1, tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"])), dtype=np.float32)
        rows = tensor.shape[0]
        padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
        tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
        tensor = tensor.reshape(
            tpu_mtj_backend.params["cores_per_replica"],
            -1,
            tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]),
        )
        vars.sp = tpu_mtj_backend.shard_xmap(tensor)
    soft_tokens = np.arange(
        tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
        tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
        dtype=np.uint32,
    )
    return soft_tokens
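
# A note on the id scheme above: soft prompt positions are addressed with
# token ids placed just past the padded vocabulary, so the backend can tell
# them apart from real tokens. With illustrative values of
# n_vocab + n_vocab_padding = 50432 and sp_length = 20, the soft tokens
# would be ids 50432..50451.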
def get_model_info(model, directory=""):
    # if the model is in the api list
    key = False
    breakmodel = False
    gpu = False
    layer_count = None
    key_value = ""
    break_values = []
    url = False
    gpu_count = torch.cuda.device_count()
    gpu_names = []
    for i in range(gpu_count):
        gpu_names.append(torch.cuda.get_device_name(i))
    if model in [x[1] for x in model_menu['apilist']]:
        if path.exists("settings/{}.settings".format(model)):
            with open("settings/{}.settings".format(model), "r") as file:
                # Check if API key exists
                js = json.load(file)
                if("apikey" in js and js["apikey"] != ""):
                    # API key exists, grab it and close the file
                    key_value = js["apikey"]
                elif 'oaiapikey' in js and js['oaiapikey'] != "":
                    key_value = js["oaiapikey"]
        key = True
    elif model == 'ReadOnly':
        pass
    elif model == 'Colab':
        url = True
    elif not torch.cuda.is_available():
        pass
    else:
        layer_count = get_layer_count(model, directory=directory)
        if layer_count is None:
            breakmodel = False
        else:
            breakmodel = True
            if path.exists("settings/{}.breakmodel".format(model.replace("/", "_"))):
                with open("settings/{}.breakmodel".format(model.replace("/", "_")), "r") as file:
                    break_values = file.read().split(",")
            else:
                break_values = [layer_count]
            break_values += [0] * (gpu_count - len(break_values))
    emit('from_server', {'cmd': 'selected_model_info', 'key_value': key_value, 'key': key,
                         'gpu': gpu, 'layer_count': layer_count, 'breakmodel': breakmodel,
                         'break_values': break_values, 'gpu_count': gpu_count,
                         'url': url, 'gpu_names': gpu_names}, broadcast=True)
    if key_value != "":
        get_oai_models(key_value)
def get_layer_count(model, directory=""):
    if(model not in ["InferKit", "Colab", "OAI", "GooseAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
        if(vars.model == "GPT2Custom"):
            model_config = open(vars.custmodpth + "/config.json", "r")
        # Get the model_type from the config or assume a model type if it isn't present
        else:
            from transformers import AutoConfig
            if directory == "":
                model_config = AutoConfig.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
            elif(os.path.isdir(vars.custmodpth.replace('/', '_'))):
                model_config = AutoConfig.from_pretrained(vars.custmodpth.replace('/', '_'), revision=vars.revision, cache_dir="cache")
            elif(os.path.isdir(directory)):
                model_config = AutoConfig.from_pretrained(directory, revision=vars.revision, cache_dir="cache")
            else:
                model_config = AutoConfig.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
        return utils.num_layers(model_config)
    else:
        return None
def get_oai_models(key):
    vars.oaiapikey = key
    if vars.model == 'OAI':
        url = "https://api.openai.com/v1/engines"
    elif vars.model == 'GooseAI':
        url = "https://api.goose.ai/v1/engines"
    else:
        return

    # Get list of models from OAI
    print("{0}Retrieving engine list...{1}".format(colors.PURPLE, colors.END), end="")
    req = requests.get(
        url,
        headers={
            'Authorization': 'Bearer ' + key
        }
    )
    if(req.status_code == 200):
        engines = req.json()["data"]
        try:
            engines = [[en["id"], "{} ({})".format(en['id'], "Ready" if en["ready"] == True else "Not Ready")] for en in engines]
        except:
            print(engines)
            raise

        online_model = ""
        changed = False

        # Save the key
        if not path.exists("settings"):
            # If the client settings file doesn't exist, create it
            # Write API key to file
            os.makedirs('settings', exist_ok=True)
        if path.exists("settings/{}.settings".format(vars.model)):
            with open("settings/{}.settings".format(vars.model), "r") as file:
                js = json.load(file)
                if 'online_model' in js:
                    online_model = js['online_model']
                if "apikey" in js:
                    if js['apikey'] != key:
                        changed = True
        if changed:
            with open("settings/{}.settings".format(vars.model), "w") as file:
                js["apikey"] = key
                file.write(json.dumps(js, indent=3))

        emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True)
    else:
        # Something went wrong, print the message and quit since we can't initialize an engine
        print("{0}ERROR!{1}".format(colors.RED, colors.END))
        print(req.json())
        emit('from_server', {'cmd': 'errmsg', 'data': req.json()})
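
# For reference, the parsing above expects the engines endpoint to return
# JSON shaped roughly like this (inferred from the field accesses in the
# code, not from the API documentation):
#     {"data": [{"id": "some-engine", "ready": true}, ...]}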
def patch_transformers():
    global transformers
    old_from_pretrained = PreTrainedModel.from_pretrained.__func__
    @classmethod
    def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        vars.fp32_model = False
        utils.num_shards = None
        utils.current_shard = 0
        utils.from_pretrained_model_name = pretrained_model_name_or_path
        utils.from_pretrained_index_filename = None
        utils.from_pretrained_kwargs = kwargs
        utils.bar = None
        if not args.no_aria2:
            utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
        return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    PreTrainedModel.from_pretrained = new_from_pretrained
    if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
        old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
        def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
            utils.num_shards = utils.get_num_shards(index_filename)
            utils.from_pretrained_index_filename = index_filename
            return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
        modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
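    # The two hooks above follow the wrap-and-delegate pattern used throughout
    # this function: keep a reference to the original callable, install a
    # wrapper that records bookkeeping state (shard counts, index filename,
    # the aria2 download hook), then defer to the original so transformers'
    # behavior is otherwise unchanged.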
    # Some versions of transformers 4.17.0.dev0 are affected by
    # https://github.com/huggingface/transformers/issues/15736
    # This is a workaround for those versions of transformers.
    if(transformers_version == "4.17.0.dev0"):
        try:
            from transformers.models.xglm.modeling_xglm import XGLMSinusoidalPositionalEmbedding
        except ImportError:
            pass
        else:
            @torch.no_grad()
            def new_forward(self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0):
                bsz, seq_len = inputs_embeds.size()[:-1]
                input_shape = inputs_embeds.size()[:-1]
                sequence_length = input_shape[1]
                position_ids = torch.arange(
                    past_key_values_length + self.padding_idx + 1, past_key_values_length + sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
                ).unsqueeze(0).expand(input_shape).contiguous()
                max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
                if max_pos > self.weights.size(0):
                    self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
                return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
            XGLMSinusoidalPositionalEmbedding.forward = new_forward
    # Patch transformers to use our soft prompt
    def patch_causallm(cls):
        old_forward = cls.forward
        def new_causallm_forward(self, *args, **kwargs):
            input_ids = kwargs.get('input_ids').to(self.device)
            assert input_ids is not None
            kwargs['input_ids'] = None
            if(vars.sp is not None):
                shifted_input_ids = input_ids - self.config.vocab_size
            input_ids.clamp_(max=self.config.vocab_size - 1)
            if(hasattr(self, "transformer")):
                inputs_embeds = self.transformer.wte(input_ids)
            elif(not hasattr(self.model, "decoder")):
                inputs_embeds = self.model.embed_tokens(input_ids)
            else:
                inputs_embeds = self.model.decoder.embed_tokens(input_ids)
            if(vars.sp is not None):
                vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
                inputs_embeds = torch.where(
                    (shifted_input_ids >= 0)[..., None],
                    vars.sp[shifted_input_ids.clamp(min=0)],
                    inputs_embeds,
                )
            if(hasattr(self, "model") and hasattr(self.model, "embed_scale")):
                inputs_embeds *= self.model.embed_scale
            kwargs['inputs_embeds'] = inputs_embeds
            return old_forward(self, *args, **kwargs)
        cls.forward = new_causallm_forward
    for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
        patch_causallm(cls)
    for c in ("GPTJForCausalLM", "XGLMForCausalLM", "OPTForCausalLM"):
        try:
            patch_causallm(getattr(__import__("transformers"), c))
        except:
            pass
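    # How the soft prompt substitution above works (illustrative numbers):
    # ids below vocab_size embed normally, while ids at or above it select
    # rows of vars.sp instead. With vocab_size=50257, a token id of 50260
    # shifts to 50260 - 50257 = 3, so row 3 of the soft prompt tensor
    # replaces the regular embedding at that position. In sketch form:
    #     shifted = ids - vocab_size          # >= 0 only for soft tokens
    #     embeds  = where(shifted >= 0, sp[shifted.clamp(min=0)], wte(ids))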
    # Fix a bug in OPTForCausalLM where self.lm_head is the wrong size
    if(packaging.version.parse("4.19.0.dev0") <= packaging.version.parse(transformers_version) < packaging.version.parse("4.20.0")):
        try:
            from transformers import OPTForCausalLM, OPTModel
        except ImportError:
            pass
        else:
            # This is the same as the original __init__ but with
            # config.hidden_size
            # replaced with
            # config.word_embed_proj_dim
            def new_init(self, config):
                super(OPTForCausalLM, self).__init__(config)
                self.model = OPTModel(config)
                self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
                self.post_init()
            OPTForCausalLM.__init__ = new_init
    # Patch transformers to use our custom logit warpers
    from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
    from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper

    def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
        old_call = cls.__call__
        def new_call(self, *args, **kwargs):
            if(not isinstance(field_name, str) and isinstance(field_name, Iterable)):
                conds = []
                for f, v in zip(field_name, var_name):
                    conds.append(getattr(vars, v))
                    setattr(self, f, conds[-1])
            else:
                conds = getattr(vars, var_name)
                setattr(self, field_name, conds)
            assert len(args) == 2
            if(cond is None or cond(conds)):
                return old_call(self, *args, **kwargs)
            return args[1]
        cls.__call__ = new_call
    dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0)
    dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
    dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
    dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
    dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0)
    dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
    RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
    RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
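    # dynamic_processor_wrap above makes the stock warpers read their settings
    # from vars on every call instead of at construction time, so slider
    # changes in the UI take effect mid-generation without rebuilding the
    # processor list. The cond= lambdas also turn a warper into a no-op
    # (returning the scores unchanged) when its setting is neutral, e.g.
    # top_k <= 0, top_p >= 1.0 or temperature == 1.0.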
    class LuaLogitsProcessor(LogitsProcessor):
        def __init__(self):
            pass

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            assert scores.ndim == 2
            assert input_ids.ndim == 2
            self.regeneration_required = False
            self.halt = False

            scores_shape = scores.shape
            scores_list = scores.tolist()
            vars.lua_koboldbridge.logits = vars.lua_state.table()
            for r, row in enumerate(scores_list):
                vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
            vars.lua_koboldbridge.vocab_size = scores_shape[-1]

            execute_genmod()

            scores = torch.tensor(
                tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
                device=scores.device,
                dtype=scores.dtype,
            )
            assert scores.shape == scores_shape

            return scores

    def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
        processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
        processors.insert(0, LuaLogitsProcessor())
        return processors
    new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
    transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
    def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
        warper_list = LogitsProcessorList()
        warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
        warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
        warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
        warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
        warper_list.append(TemperatureLogitsWarper(temperature=0.5))
        return warper_list

    def new_sample(self, *args, **kwargs):
        assert kwargs.pop("logits_warper", None) is not None
        kwargs["logits_warper"] = new_get_logits_warper(
            beams=1,
        )
        if(vars.newlinemode == "s") or (vars.newlinemode == "ns"):
            kwargs["eos_token_id"] = -1
            kwargs.setdefault("pad_token_id", 2)
        return new_sample.old_sample(self, *args, **kwargs)
    new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
    transformers.generation_utils.GenerationMixin.sample = new_sample
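    # Note: the constants passed to the warpers above (top_k=1, top_p=0.5,
    # and so on) are only placeholders; dynamic_processor_wrap overwrites
    # each field with the live value from vars the moment the warper is
    # actually called.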
    # Allow bad words filter to ban <|endoftext|> token
    import transformers.generation_logits_process
    def new_init(self, bad_words_ids: List[List[int]], eos_token_id: int):
        return new_init.old_init(self, bad_words_ids, -1)
    new_init.old_init = transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__
    transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__ = new_init
    # Sets up dynamic world info scanner
    class DynamicWorldInfoScanCriteria(StoppingCriteria):
        def __init__(
            self,
            tokenizer,
            excluded_world_info: List[Set],
        ):
            self.regeneration_required = False
            self.halt = False
            self.tokenizer = tokenizer
            self.excluded_world_info = excluded_world_info

        def __call__(
            self,
            input_ids: torch.LongTensor,
            scores: torch.FloatTensor,
            **kwargs,
        ) -> bool:
            vars.generated_tkns += 1
            if(vars.lua_koboldbridge.generated_cols and vars.generated_tkns != vars.lua_koboldbridge.generated_cols):
                raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({vars.generated_tkns} != {vars.lua_koboldbridge.generated_cols})")
            if(vars.abort or vars.generated_tkns >= vars.genamt):
                self.regeneration_required = False
                self.halt = False
                return True

            assert input_ids.ndim == 2
            assert len(self.excluded_world_info) == input_ids.shape[0]
            self.regeneration_required = vars.lua_koboldbridge.regeneration_required
            self.halt = not vars.lua_koboldbridge.generating
            vars.lua_koboldbridge.regeneration_required = False

            for i in range(vars.numseqs):
                vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(input_ids[i, -1].item())

            if(not vars.dynamicscan):
                return self.regeneration_required or self.halt
            tail = input_ids[..., -vars.generated_tkns:]
            for i, t in enumerate(tail):
                decoded = utils.decodenewlines(tokenizer.decode(t))
                _, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions)
                found -= self.excluded_world_info[i]
                if(len(found) != 0):
                    self.regeneration_required = True
                    break
            return self.regeneration_required or self.halt
    old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
    def new_get_stopping_criteria(self, *args, **kwargs):
        stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
        global tokenizer
        self.kai_scanner = DynamicWorldInfoScanCriteria(
            tokenizer=tokenizer,
            excluded_world_info=self.kai_scanner_excluded_world_info,
        )
        stopping_criteria.insert(0, self.kai_scanner)
        return stopping_criteria
    transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
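    # DynamicWorldInfoScanCriteria above piggybacks on transformers' stopping
    # criteria API: after each generated token it mirrors the token into the
    # Lua bridge, and when dynamic scanning is on it decodes the generated
    # tail and asks checkworldinfo whether a new world info entry was
    # triggered; if so, generation stops early so the text can be regenerated
    # with the new entry in context.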
def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model=""):
    global model
    global generator
    global torch
    global model_config
    global GPT2TokenizerFast
    global tokenizer
    print("Loading vars.model: {} vars.custmodpth: {}".format(vars.model, vars.custmodpth))
    vars.noai = False
    if not initial_load:
        set_aibusy(True)
        if vars.model != 'ReadOnly':
            emit('from_server', {'cmd': 'model_load_status', 'data': "Loading {}".format(vars.model)}, broadcast=True)
            # Have to add a sleep so the server will send the emit for some reason
            time.sleep(0.1)
    if gpu_layers is not None:
        args.breakmodel_gpulayers = gpu_layers

    # We need to wipe out the existing model and refresh the cuda cache
    model = None
    generator = None
    model_config = None
    try:
        torch.cuda.empty_cache()
    except:
        pass

    # Reload our badwords
    vars.badwordsids = vars.badwordsids_default
    # Let's set the GooseAI or OpenAI server URLs if that's applicable
    if online_model != "":
        if path.exists("settings/{}.settings".format(vars.model)):
            changed = False
            with open("settings/{}.settings".format(vars.model), "r") as file:
                # Check if API key exists
                js = json.load(file)
                if 'online_model' in js:
                    if js['online_model'] != online_model:
                        changed = True
                        js['online_model'] = online_model
                else:
                    changed = True
                    js['online_model'] = online_model
            if changed:
                with open("settings/{}.settings".format(vars.model), "w") as file:
                    file.write(json.dumps(js, indent=3))
        # Swap OAI Server if GooseAI was selected
        if(vars.model == "GooseAI"):
            vars.oaiengines = "https://api.goose.ai/v1/engines"
            vars.model = "OAI"
            args.configname = "GooseAI" + "/" + online_model
        else:
            args.configname = vars.model + "/" + online_model
        vars.oaiurl = vars.oaiengines + "/{0}/completions".format(online_model)
    # If transformers model was selected & GPU available, ask to use CPU or GPU
    if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI", "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
        vars.allowsp = True
        # Test for GPU support

        # Make model path the same as the model name to make this consistent with the other loading method if it isn't a known model type
        # This code is not just a workaround for below, it is also used to make the behavior consistent with other loading methods - Henk717
        if(not vars.model in ["NeoCustom", "GPT2Custom"]):
            vars.custmodpth = vars.model
        elif(vars.model == "NeoCustom"):
            vars.model = os.path.basename(os.path.normpath(vars.custmodpth))

        # Get the model_type from the config or assume a model type if it isn't present
        from transformers import AutoConfig
        if(os.path.isdir(vars.custmodpth.replace('/', '_'))):
            try:
                model_config = AutoConfig.from_pretrained(vars.custmodpth.replace('/', '_'), revision=vars.revision, cache_dir="cache")
                vars.model_type = model_config.model_type
            except ValueError as e:
                vars.model_type = "not_found"
        elif(os.path.isdir("models/{}".format(vars.custmodpth.replace('/', '_')))):
            try:
                model_config = AutoConfig.from_pretrained("models/{}".format(vars.custmodpth.replace('/', '_')), revision=vars.revision, cache_dir="cache")
                vars.model_type = model_config.model_type
            except ValueError as e:
                vars.model_type = "not_found"
        else:
            try:
                model_config = AutoConfig.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
                vars.model_type = model_config.model_type
            except ValueError as e:
                vars.model_type = "not_found"
        if(vars.model_type == "not_found" and vars.model == "NeoCustom"):
            vars.model_type = "gpt_neo"
        elif(vars.model_type == "not_found" and vars.model == "GPT2Custom"):
            vars.model_type = "gpt2"
        elif(vars.model_type == "not_found"):
            print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
            vars.model_type = "gpt_neo"

    if(vars.model_type == "opt"):
        vars.badwordsids = vars.badwordsids_opt

    if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI", "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
        loadmodelsettings()
        loadsettings()
        print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
        vars.hascuda = torch.cuda.is_available()
        vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm", "opt") and not vars.nobreakmodel
        if(args.breakmodel is not None and args.breakmodel):
            print("WARNING: --breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --breakmodel_gpulayers is used (see --help for details).", file=sys.stderr)
        if(args.breakmodel_layers is not None):
            print("WARNING: --breakmodel_layers is deprecated. Use --breakmodel_gpulayers instead (see --help for details).", file=sys.stderr)
        if(args.model and vars.bmsupported and not args.breakmodel_gpulayers and not args.breakmodel_layers):
            print("WARNING: Model launched without the --breakmodel_gpulayers argument, defaulting to GPU only mode.", file=sys.stderr)
            vars.bmsupported = False
        if(not vars.bmsupported and (args.breakmodel_gpulayers is not None or args.breakmodel_layers is not None)):
            print("WARNING: This model does not support hybrid generation. --breakmodel_gpulayers will be ignored.", file=sys.stderr)
        if(vars.hascuda):
            print("{0}FOUND!{1}".format(colors.GREEN, colors.END))
        else:
            print("{0}NOT FOUND!{1}".format(colors.YELLOW, colors.END))

        if args.model:
            if(vars.hascuda):
                genselected = True
                vars.usegpu = True
                vars.breakmodel = False
            if(vars.bmsupported):
                vars.usegpu = False
                vars.breakmodel = True
            if(args.cpu):
                vars.usegpu = False
                vars.breakmodel = False
        elif(vars.hascuda):
            if(vars.bmsupported):
                genselected = True
                vars.usegpu = False
                vars.breakmodel = True
            else:
                genselected = False
        else:
            genselected = False

        if(vars.hascuda):
            if(use_gpu):
                if(vars.bmsupported):
                    vars.breakmodel = True
                    vars.usegpu = False
                    genselected = True
                else:
                    vars.breakmodel = False
                    vars.usegpu = True
                    genselected = True
            else:
                vars.breakmodel = False
                vars.usegpu = False
                genselected = True
    # Ask for API key if InferKit was selected
    if(vars.model == "InferKit"):
        vars.apikey = vars.oaiapikey

    # Swap OAI Server if GooseAI was selected
    if(vars.model == "GooseAI"):
        vars.oaiengines = "https://api.goose.ai/v1/engines"
        vars.model = "OAI"
        args.configname = "GooseAI"

    # Ask for API key if OpenAI was selected
    if(vars.model == "OAI"):
        if not args.configname:
            args.configname = "OAI"

    if(vars.model == "ReadOnly"):
        vars.noai = True
    # Start transformers and create pipeline
    if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI", "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
        if(not vars.noai):
            print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
            for m in ("GPTJModel", "XGLMModel"):
                try:
                    globals()[m] = getattr(__import__("transformers"), m)
                except:
                    pass
            # Lazy loader
            import torch_lazy_loader
            def get_lazy_load_callback(n_layers, convert_to_float16=True):
                if not vars.lazy_load:
                    return

                from tqdm.auto import tqdm

                if "breakmodel" in globals():
                    gpu_blocks = breakmodel.gpu_blocks
                    ram_blocks = n_layers - sum(gpu_blocks)
                    cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
                else:
                    ram_blocks = gpu_blocks = cumulative_gpu_blocks = None

                def lazy_load_callback(model_dict, f, **_):
                    if lazy_load_callback.nested:
                        return
                    lazy_load_callback.nested = True

                    device_map = {}

                    for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
                        for layer in range(n_layers):
                            key = _key.format(layer=layer)
                            if key not in model_dict:
                                continue
                            device = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" if not vars.hascuda or not vars.breakmodel or layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
                            device_map[key] = device

                    for key, value in model_dict.items():
                        if isinstance(value, torch_lazy_loader.LazyTensor) and key not in device_map:
                            device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu"

                    if utils.num_shards is None or utils.current_shard == 0:
                        if utils.num_shards is not None:
                            num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
                        else:
                            num_tensors = len(device_map)
                        print(flush=True)
                        utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=Send_to_socketio())

                    with zipfile.ZipFile(f, "r") as z:
                        try:
                            last_storage_key = None
                            f = None
                            current_offset = 0
                            if utils.num_shards is not None:
                                utils.current_shard += 1
                            for key in sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
                                storage_key = model_dict[key].key
                                if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
                                    last_storage_key = storage_key
                                    if isinstance(f, zipfile.ZipExtFile):
                                        f.close()
                                    f = z.open(f"archive/data/{storage_key}")
                                    current_offset = 0
                                if current_offset != model_dict[key].seek_offset:
                                    f.read(model_dict[key].seek_offset - current_offset)
                                    current_offset = model_dict[key].seek_offset
                                device = device_map[key]
                                size = functools.reduce(lambda x, y: x * y, model_dict[key].shape, 1)
                                dtype = model_dict[key].dtype
                                nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
                                #print(f"Transferring <{key}> to {'(CPU)' if device == 'cpu' else '[device ' + str(device) + ']'} ... ", end="", flush=True)
                                model_dict[key] = model_dict[key].materialize(f, map_location="cpu")
                                if model_dict[key].dtype is torch.float32:
                                    vars.fp32_model = True
                                if convert_to_float16 and vars.hascuda and (vars.breakmodel or vars.usegpu) and model_dict[key].dtype is torch.float32:
                                    model_dict[key] = model_dict[key].to(torch.float16)
                                if not vars.usegpu and not vars.breakmodel and model_dict[key].dtype is torch.float16:
                                    model_dict[key] = model_dict[key].to(torch.float32)
                                model_dict[key] = model_dict[key].to(device)
                                #print("OK", flush=True)
                                current_offset += nbytes
                                utils.bar.update(1)
                        finally:
                            if utils.num_shards is None or utils.current_shard >= utils.num_shards:
                                utils.bar.close()
                                utils.bar = None
                            lazy_load_callback.nested = False
                            if isinstance(f, zipfile.ZipExtFile):
                                f.close()

                lazy_load_callback.nested = False
                return lazy_load_callback

            lazy_load_config_path = os.path.join("maps", vars.model_type + ".json")
            if(vars.lazy_load and "model_config" in globals() and os.path.isfile(lazy_load_config_path)):
                with open(lazy_load_config_path) as f:
                    lazy_load_spec = json.load(f)
            else:
                vars.lazy_load = False
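            # Rough shape of the lazy loading scheme above: instead of letting
            # torch materialize every tensor in RAM first, lazy_load_callback
            # walks the checkpoint zip once in storage order, reads each tensor
            # directly from the archive, optionally converts float32 weights
            # to float16, and places it straight onto its target device as
            # chosen by the device map (GPU, CPU, or a breakmodel layer split).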
            def get_hidden_size_from_model(model):
                try:
                    return int(model.model.decoder.project_in.in_features)
                except:
                    try:
                        return int(model.model.decoder.embed_tokens.out_features)
                    except:
                        try:
                            return int(model.transformer.hidden_size)
                        except:
                            try:
                                return int(model.transformer.embed_dim)
                            except:
                                return int(model.lm_head.in_features)
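            # The nested fallbacks above probe several architectures in turn;
            # roughly: decoders with an input projection (where the embedding
            # width differs from the hidden size), then plain decoder embedding
            # widths, then GPT-style transformer attributes, with
            # lm_head.in_features as the catch-all. The exact mapping to model
            # families is an educated guess.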
            def maybe_low_cpu_mem_usage() -> Dict[str, Any]:
                if(packaging.version.parse(transformers_version) < packaging.version.parse("4.11.0")):
                    print(f"\nWARNING: Please upgrade to transformers 4.11.0 for lower RAM usage. You have transformers {transformers_version}.", file=sys.stderr)
                    return {}
                return {"low_cpu_mem_usage": True}

            @contextlib.contextmanager
            def maybe_use_float16(always_use=False):
                if(always_use or (vars.hascuda and args.lowmem and (vars.usegpu or vars.breakmodel))):
                    original_dtype = torch.get_default_dtype()
                    torch.set_default_dtype(torch.float16)
                    try:
                        yield True
                    finally:
                        # Restore the default dtype even if model loading raises
                        torch.set_default_dtype(original_dtype)
                else:
                    yield False
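            # A minimal usage sketch of the context manager above:
            #     with maybe_use_float16():
            #         model = AutoModelForCausalLM.from_pretrained(...)
            # Inside the block, newly constructed float tensors default to
            # float16 whenever low memory mode applies; the previous default
            # dtype is restored on exit (including on error, via the
            # try/finally).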
            # If custom GPT2 model was chosen
            if(vars.model == "GPT2Custom"):
                vars.lazy_load = False
                model_config = open(vars.custmodpth + "/config.json", "r")
                js = json.load(model_config)
                with(maybe_use_float16()):
                    model = GPT2LMHeadModel.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
                tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
                vars.modeldim = get_hidden_size_from_model(model)
                # Is CUDA available? If so, use GPU, otherwise fall back to CPU
                if(vars.hascuda and vars.usegpu):
                    model = model.half().to(vars.gpu_device)
                    generator = model.generate
                else:
                    model = model.to('cpu').float()
                    generator = model.generate
            # Use the Generic implementation
            else:
                lowmem = maybe_low_cpu_mem_usage()
                # We must disable low_cpu_mem_usage (by setting lowmem to {}) if
                # using a GPT-2 model because GPT-2 is not compatible with this
                # feature yet
                if(vars.model_type == "gpt2"):
                    lowmem = {}

                # If we're using torch_lazy_loader, we need to get breakmodel config
                # early so that it knows where to load the individual model tensors
                if(vars.lazy_load and vars.hascuda and vars.breakmodel):
                    device_config(model_config)

                # Download model from Huggingface if it does not exist, otherwise load locally

                # If we specify a model and it's in the root directory, we need to move it to the models directory (legacy folder structure to new)
                if os.path.isdir(vars.model.replace('/', '_')):
                    import shutil
                    shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
                print("\n", flush=True)
                with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(utils.num_layers(model_config)) if vars.lazy_load else None, dematerialized_modules=True):
                    if(vars.lazy_load):  # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
                        lowmem = {}
                    if(os.path.isdir(vars.custmodpth)):
                        try:
                            tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
                        except Exception as e:
                            try:
                                tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
                            except Exception as e:
                                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
                        try:
                            model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache", **lowmem)
                        except Exception as e:
                            model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache", **lowmem)
                    elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
                        try:
                            tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
                        except Exception as e:
                            try:
                                tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
                            except Exception as e:
                                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
                        try:
                            model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache", **lowmem)
                        except Exception as e:
                            model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache", **lowmem)
                    else:
                        old_rebuild_tensor = torch._utils._rebuild_tensor
                        def new_rebuild_tensor(storage: Union[torch_lazy_loader.LazyTensor, torch.Storage], storage_offset, shape, stride):
                            if(not isinstance(storage, torch_lazy_loader.LazyTensor)):
                                dtype = storage.dtype
                            else:
                                dtype = storage.storage_type.dtype
                                if(not isinstance(dtype, torch.dtype)):
                                    dtype = storage.storage_type(0).dtype
                            if(dtype is torch.float32 and len(shape) >= 2):
                                vars.fp32_model = True
                            return old_rebuild_tensor(storage, storage_offset, shape, stride)
                        torch._utils._rebuild_tensor = new_rebuild_tensor

                        try:
                            tokenizer = AutoTokenizer.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
                        except Exception as e:
                            try:
                                tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
                            except Exception as e:
                                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
                        try:
                            model = AutoModelForCausalLM.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache", **lowmem)
                        except Exception as e:
                            model = GPTNeoForCausalLM.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache", **lowmem)

                        torch._utils._rebuild_tensor = old_rebuild_tensor
                if not args.colab or args.savemodel:
                    import shutil
                    tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
                    if(vars.fp32_model):  # Use save_pretrained to convert fp32 models to fp16
                        model = model.half()
                        model.save_pretrained("models/{}".format(vars.model.replace('/', '_')), max_shard_size="500MiB")
                    else:  # For fp16 models, we can just copy the model files directly
                        import transformers.configuration_utils
                        import transformers.modeling_utils
                        import transformers.file_utils
                        # Save the config.json
                        shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, transformers.configuration_utils.CONFIG_NAME, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.configuration_utils.CONFIG_NAME))
                        if(utils.num_shards is None):
                            # Save the pytorch_model.bin of an unsharded model
                            shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, transformers.modeling_utils.WEIGHTS_NAME, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_NAME))
                        else:
                            with open(utils.from_pretrained_index_filename) as f:
                                map_data = json.load(f)
                            filenames = set(map_data["weight_map"].values())
                            # Save the pytorch_model.bin.index.json of a sharded model
                            shutil.move(utils.from_pretrained_index_filename, os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_INDEX_NAME))
                            # Then save the pytorch_model-#####-of-#####.bin files
                            for filename in filenames:
                                shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, filename, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), filename))
                    shutil.rmtree("cache/")
                if(vars.hascuda):
                    if(vars.usegpu):
                        vars.modeldim = get_hidden_size_from_model(model)
                        model = model.half().to(vars.gpu_device)
                        generator = model.generate
                    elif(vars.breakmodel):  # Use both RAM and VRAM (breakmodel)
                        vars.modeldim = get_hidden_size_from_model(model)
                        if(not vars.lazy_load):
                            device_config(model.config)
                        move_model_to_devices(model)
                    else:
                        model = model.to('cpu').float()
                        vars.modeldim = get_hidden_size_from_model(model)
                        generator = model.generate
                else:
                    model.to('cpu').float()
                    vars.modeldim = get_hidden_size_from_model(model)
                    generator = model.generate

            # Suppress Author's Note by flagging square brackets (Old implementation)
            #vocab         = tokenizer.get_vocab()
            #vocab_keys    = vocab.keys()
            #vars.badwords = gettokenids("[")
            #for key in vars.badwords:
            #    vars.badwordsids.append([vocab[key]])

            print("{0}OK! {1} pipeline created!{2}".format(colors.GREEN, vars.model, colors.END))
        else:
            from transformers import GPT2TokenizerFast
            tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
    else:
        from transformers import PreTrainedModel
        from transformers import modeling_utils
        old_from_pretrained = PreTrainedModel.from_pretrained.__func__
        @classmethod
        def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
            vars.fp32_model = False
            utils.num_shards = None
            utils.current_shard = 0
            utils.from_pretrained_model_name = pretrained_model_name_or_path
            utils.from_pretrained_index_filename = None
            utils.from_pretrained_kwargs = kwargs
            utils.bar = None
            if not args.no_aria2:
                utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
            return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
        PreTrainedModel.from_pretrained = new_from_pretrained
        if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
            old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
            def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
                utils.num_shards = utils.get_num_shards(index_filename)
                utils.from_pretrained_index_filename = index_filename
                return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
            modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
    def tpumtjgenerate_warper_callback(scores) -> "np.array":
        scores_shape = scores.shape
        scores_list = scores.tolist()
        vars.lua_koboldbridge.logits = vars.lua_state.table()
        for r, row in enumerate(scores_list):
            vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
        vars.lua_koboldbridge.vocab_size = scores_shape[-1]

        execute_genmod()

        scores = np.array(
            tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
            dtype=scores.dtype,
        )
        assert scores.shape == scores_shape

        return scores

    def tpumtjgenerate_stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
        vars.generated_tkns += 1

        assert len(excluded_world_info) == len(generated)
        regeneration_required = vars.lua_koboldbridge.regeneration_required
        halt = vars.abort or not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
        vars.lua_koboldbridge.regeneration_required = False

        global past

        for i in range(vars.numseqs):
            vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item())

        if(not vars.dynamicscan or halt):
            return excluded_world_info, regeneration_required, halt

        for i, t in enumerate(generated):
            decoded = utils.decodenewlines(tokenizer.decode(past[i])) + utils.decodenewlines(tokenizer.decode(t[tpu_mtj_backend.params["seq"]: tpu_mtj_backend.params["seq"] + n_generated]))
            _, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions)
            found -= excluded_world_info[i]
            if(len(found) != 0):
                regeneration_required = True
                break
        return excluded_world_info, regeneration_required, halt

    def tpumtjgenerate_compiling_callback() -> None:
        print(colors.GREEN + "TPU backend compilation triggered" + colors.END)
        vars.compiling = True

    def tpumtjgenerate_stopped_compiling_callback() -> None:
        vars.compiling = False

    def tpumtjgenerate_settings_callback() -> dict:
        return {
            "top_p": float(vars.top_p),
            "temp": float(vars.temp),
            "top_k": int(vars.top_k),
            "tfs": float(vars.tfs),
            "typical": float(vars.typical),
            "repetition_penalty": float(vars.rep_pen),
            "rpslope": float(vars.rep_pen_slope),
            "rprange": int(vars.rep_pen_range),
        }
    # If we're running Colab or OAI, we still need a tokenizer.
    if(vars.model == "Colab"):
        from transformers import GPT2TokenizerFast
        tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B", revision=vars.revision, cache_dir="cache")
        loadsettings()
    elif(vars.model == "OAI"):
        from transformers import GPT2TokenizerFast
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
        loadsettings()
    # Load the TPU backend if requested
    elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
        global tpu_mtj_backend
        import tpu_mtj_backend
        if(vars.model == "TPUMeshTransformerGPTNeoX"):
            vars.badwordsids = vars.badwordsids_neox
        print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
        if vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)):
            raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
        if(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "opt"):
            tpu_mtj_backend.pad_token_id = 1
        tpu_mtj_backend.vars = vars
        tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
        tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
        tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback
        tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback
        tpu_mtj_backend.settings_callback = tpumtjgenerate_settings_callback
        vars.allowsp = True
        loadmodelsettings()
        loadsettings()
        tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and vars.use_colab_tpu, **vars.modelconfig)
        vars.modeldim = int(tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]))
        tokenizer = tpu_mtj_backend.tokenizer
    else:
        loadsettings()

lua_startup()
# Load scripts
load_lua_scripts()

final_startup()
if not initial_load:
    set_aibusy(False)
    emit('from_server', {'cmd': 'hide_model_name'}, broadcast=True)
    time.sleep(0.1)
    if not vars.gamestarted:
        setStartState()
        sendsettings()
        refresh_settings()

# Set up Flask routes
@app.route('/')
@app.route('/index')
def index():
    return render_template('index.html', hide_ai_menu=args.noaimenu)

@app.route('/favicon.ico')
def favicon():
    return send_from_directory(app.root_path,
                               'koboldai.ico', mimetype='image/vnd.microsoft.icon')

@app.route('/download')
def download():
    save_format = request.args.get("format", "json").strip().lower()

    if(save_format == "plaintext"):
        txt = vars.prompt + "".join(vars.actions.values())
        save = Response(txt)
        filename = path.basename(vars.savedir)
        if filename[-5:] == ".json":
            filename = filename[:-5]
        save.headers.set('Content-Disposition', 'attachment', filename='%s.txt' % filename)
        return(save)

    # Build json to write
    js = {}
    js["gamestarted"] = vars.gamestarted
    js["prompt"] = vars.prompt
    js["memory"] = vars.memory
    js["authorsnote"] = vars.authornote
    js["anotetemplate"] = vars.authornotetemplate
    js["actions"] = tuple(vars.actions.values())
    js["actions_metadata"] = vars.actions_metadata
    js["worldinfo"] = []

    # Extract only the important bits of WI
    for wi in vars.worldinfo:
        if(wi["constant"] or wi["key"] != ""):
            js["worldinfo"].append({
                "key": wi["key"],
                "keysecondary": wi["keysecondary"],
                "content": wi["content"],
                "comment": wi["comment"],
                "folder": wi["folder"],
                "selective": wi["selective"],
                "constant": wi["constant"]
            })

    save = Response(json.dumps(js, indent=3))
    filename = path.basename(vars.savedir)
    if filename[-5:] == ".json":
        filename = filename[:-5]
    save.headers.set('Content-Disposition', 'attachment', filename='%s.json' % filename)
    return(save)
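
# For reference, the JSON produced above has roughly this shape (field values
# are illustrative placeholders, not real story data):
#
#     {
#        "gamestarted": true,
#        "prompt": "You are ...",
#        "memory": "",
#        "authorsnote": "",
#        "anotetemplate": "[Author's note: <|>]",
#        "actions": ["..."],
#        "actions_metadata": {...},
#        "worldinfo": [{"key": "...", "keysecondary": "", "content": "...",
#                       "comment": "", "folder": null, "selective": false,
#                       "constant": false}]
#     }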

#============================ LUA API =============================#
_bridged = {}
F = TypeVar("F", bound=Callable)

def lua_startup():
    global _bridged
    global F
    global bridged
    if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
        file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
        js = json.load(file)
        if("userscripts" in js):
            vars.userscripts = []
            for userscript in js["userscripts"]:
                if type(userscript) is not str:
                    continue
                userscript = userscript.strip()
                if len(userscript) != 0 and all(q not in userscript for q in ("..", ":")) and all(userscript[0] not in q for q in ("/", "\\")) and os.path.exists(fileops.uspath(userscript)):
                    vars.userscripts.append(userscript)
        if("corescript" in js and type(js["corescript"]) is str and all(q not in js["corescript"] for q in ("..", ":")) and all(js["corescript"][0] not in q for q in ("/", "\\"))):
            vars.corescript = js["corescript"]
        else:
            vars.corescript = "default.lua"
        file.close()

    #==================================================================#
    #  Lua runtime startup
    #==================================================================#

    print("", end="", flush=True)
    print(colors.PURPLE + "Initializing Lua Bridge... " + colors.END, end="", flush=True)

    # Set up Lua state
    vars.lua_state = lupa.LuaRuntime(unpack_returned_tuples=True)

    # Load bridge.lua
    bridged = {
        "corescript_path": "cores",
        "userscript_path": "userscripts",
        "config_path": "userscripts",
        "lib_paths": vars.lua_state.table("lualibs", os.path.join("extern", "lualibs")),
        "vars": vars,
    }
    for kwarg in _bridged:
        bridged[kwarg] = _bridged[kwarg]
    try:
        vars.lua_kobold, vars.lua_koboldcore, vars.lua_koboldbridge = vars.lua_state.globals().dofile("bridge.lua")(
            vars.lua_state.globals().python,
            bridged,
        )
    except lupa.LuaError as e:
        print(colors.RED + "ERROR!" + colors.END)
        vars.lua_koboldbridge.obliterate_multiverse()
        print("{0}{1}{2}".format(colors.RED, "***LUA ERROR***: ", colors.END), end="", file=sys.stderr)
        print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
        exit(1)
    print(colors.GREEN + "OK!" + colors.END)

def lua_log_format_name(name):
    return f"[{name}]" if type(name) is str else "CORE"

def bridged_kwarg(name=None):
    def _bridged_kwarg(f: F):
        _bridged[name if name is not None else f.__name__[4:] if f.__name__[:4] == "lua_" else f.__name__] = f
        return f
    return _bridged_kwarg
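
# Usage sketch (hypothetical function, not one of the real bridged callbacks):
# the decorator registers the wrapped function in _bridged so that lua_startup()
# can pass it to bridge.lua; a leading "lua_" is stripped from the key:
#
#     @bridged_kwarg()
#     def lua_ping():
#         return "pong"
#
#     # afterwards, _bridged["ping"] is lua_ping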

#==================================================================#
#  Event triggered when a userscript is loaded
#==================================================================#
@bridged_kwarg()
def load_callback(filename, modulename):
    print(colors.GREEN + f"Loading Userscript [{modulename}] <{filename}>" + colors.END)

#==================================================================#
#  Load all Lua scripts
#==================================================================#
def load_lua_scripts():
    print(colors.GREEN + "Loading Core Script" + colors.END)

    filenames = []
    modulenames = []
    descriptions = []

    lst = fileops.getusfiles(long_desc=True)
    filenames_dict = {ob["filename"]: i for i, ob in enumerate(lst)}

    for filename in vars.userscripts:
        if filename in filenames_dict:
            i = filenames_dict[filename]
            filenames.append(filename)
            modulenames.append(lst[i]["modulename"])
            descriptions.append(lst[i]["description"])

    vars.has_genmod = False

    try:
        vars.lua_koboldbridge.obliterate_multiverse()
        tpool.execute(vars.lua_koboldbridge.load_corescript, vars.corescript)
        vars.has_genmod = tpool.execute(vars.lua_koboldbridge.load_userscripts, filenames, modulenames, descriptions)
        vars.lua_running = True
    except lupa.LuaError as e:
        try:
            vars.lua_koboldbridge.obliterate_multiverse()
        except:
            pass
        vars.lua_running = False
        if(vars.serverstarted):
            emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True)
            sendUSStatItems()
        print("{0}{1}{2}".format(colors.RED, "***LUA ERROR***: ", colors.END), end="", file=sys.stderr)
        print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
        print("{0}{1}{2}".format(colors.YELLOW, "Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.", colors.END), file=sys.stderr)
        if(vars.serverstarted):
            set_aibusy(0)

#==================================================================#
#  Print message that originates from the userscript with the given name
#==================================================================#
@bridged_kwarg()
def lua_print(msg):
    if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
        vars.lua_logname = vars.lua_koboldbridge.logging_name
        print(colors.BLUE + lua_log_format_name(vars.lua_logname) + ":" + colors.END, file=sys.stderr)
    print(colors.PURPLE + msg.replace("\033", "") + colors.END)

#==================================================================#
#  Print warning that originates from the userscript with the given name
#==================================================================#
@bridged_kwarg()
def lua_warn(msg):
    if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
        vars.lua_logname = vars.lua_koboldbridge.logging_name
        print(colors.BLUE + lua_log_format_name(vars.lua_logname) + ":" + colors.END, file=sys.stderr)
    print(colors.YELLOW + msg.replace("\033", "") + colors.END)

#==================================================================#
#  Decode tokens into a string using current tokenizer
#==================================================================#
@bridged_kwarg()
def lua_decode(tokens):
    tokens = list(tokens.values())
    assert type(tokens) is list
    if("tokenizer" not in globals()):
        from transformers import GPT2TokenizerFast
        global tokenizer
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
    return utils.decodenewlines(tokenizer.decode(tokens))

#==================================================================#
#  Encode string into list of token IDs using current tokenizer
#==================================================================#
@bridged_kwarg()
def lua_encode(string):
    assert type(string) is str
    if("tokenizer" not in globals()):
        from transformers import GPT2TokenizerFast
        global tokenizer
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
    return tokenizer.encode(utils.encodenewlines(string), max_length=int(4e9), truncation=True)
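
# Round-trip sketch: with the GPT-2 fallback tokenizer loaded above, one would
# expect something like (token IDs from the standard GPT-2 vocabulary):
#
#     lua_encode("Hello world")       # -> [15496, 995]
#     tokenizer.decode([15496, 995])  # -> "Hello world"
#
# (lua_decode itself takes a 1-indexed Lua table rather than a Python list.)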

#==================================================================#
#  Computes context given a submission, Lua array of entry UIDs and a Lua array
#  of folder UIDs
#==================================================================#
@bridged_kwarg()
def lua_compute_context(submission, entries, folders, kwargs):
    assert type(submission) is str
    if(kwargs is None):
        kwargs = vars.lua_state.table()
    actions = vars._actions if vars.lua_koboldbridge.userstate == "genmod" else vars.actions
    allowed_entries = None
    allowed_folders = None
    if(entries is not None):
        allowed_entries = set()
        i = 1
        while(entries[i] is not None):
            allowed_entries.add(int(entries[i]))
            i += 1
    if(folders is not None):
        allowed_folders = set()
        i = 1
        while(folders[i] is not None):
            allowed_folders.add(int(folders[i]))
            i += 1
    winfo, mem, anotetxt, _ = calcsubmitbudgetheader(
        submission,
        allowed_entries=allowed_entries,
        allowed_folders=allowed_folders,
        force_use_txt=True,
        scan_story=kwargs["scan_story"] if kwargs["scan_story"] != None else True,
    )
    txt, _, _ = calcsubmitbudget(
        len(actions),
        winfo,
        mem,
        anotetxt,
        actions,
    )
    return utils.decodenewlines(tokenizer.decode(txt))
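
# Usage sketch from the Python side (Lua callers pass 1-indexed tables for
# entries/folders; None means "no restriction"):
#
#     context = lua_compute_context("Go north.", None, None, None)
#     # -> the exact context string (memory + WI + A/N + story) the model would see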

#==================================================================#
#  Get property of a world info entry given its UID and property name
#==================================================================#
@bridged_kwarg()
def lua_get_attr(uid, k):
    assert type(uid) is int and type(k) is str
    if(uid in vars.worldinfo_u and k in (
        "key",
        "keysecondary",
        "content",
        "comment",
        "folder",
        "num",
        "selective",
        "constant",
        "uid",
    )):
        return vars.worldinfo_u[uid][k]

#==================================================================#
#  Set property of a world info entry given its UID, property name and new value
#==================================================================#
@bridged_kwarg()
def lua_set_attr(uid, k, v):
    assert type(uid) is int and type(k) is str
    assert uid in vars.worldinfo_u and k in (
        "key",
        "keysecondary",
        "content",
        "comment",
        "selective",
        "constant",
    )
    if(type(vars.worldinfo_u[uid][k]) is int and type(v) is float):
        v = int(v)
    assert type(vars.worldinfo_u[uid][k]) is type(v)
    vars.worldinfo_u[uid][k] = v
    print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set {k} of world info entry {uid} to {v}" + colors.END)
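
# Example (hypothetical UID): update a world info entry's text from a script
# and let the assertions guard the key name and value type:
#
#     lua_set_attr(42, "content", "The castle gates are sealed.")
#     lua_get_attr(42, "content")  # -> "The castle gates are sealed."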

#==================================================================#
#  Get property of a world info folder given its UID and property name
#==================================================================#
@bridged_kwarg()
def lua_folder_get_attr(uid, k):
    assert type(uid) is int and type(k) is str
    if(uid in vars.wifolders_d and k in (
        "name",
    )):
        return vars.wifolders_d[uid][k]

#==================================================================#
#  Set property of a world info folder given its UID, property name and new value
#==================================================================#
@bridged_kwarg()
def lua_folder_set_attr(uid, k, v):
    assert type(uid) is int and type(k) is str
    assert uid in vars.wifolders_d and k in (
        "name",
    )
    if(type(vars.wifolders_d[uid][k]) is int and type(v) is float):
        v = int(v)
    assert type(vars.wifolders_d[uid][k]) is type(v)
    vars.wifolders_d[uid][k] = v
    print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set {k} of world info folder {uid} to {v}" + colors.END)

#==================================================================#
#  Get the "Amount to Generate"
#==================================================================#
@bridged_kwarg()
def lua_get_genamt():
    return vars.genamt

#==================================================================#
#  Set the "Amount to Generate"
#==================================================================#
@bridged_kwarg()
def lua_set_genamt(genamt):
    assert vars.lua_koboldbridge.userstate != "genmod" and type(genamt) in (int, float) and genamt >= 0
    print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set genamt to {int(genamt)}" + colors.END)
    vars.genamt = int(genamt)

#==================================================================#
#  Get the "Gens Per Action"
#==================================================================#
@bridged_kwarg()
def lua_get_numseqs():
    return vars.numseqs

#==================================================================#
#  Set the "Gens Per Action"
#==================================================================#
@bridged_kwarg()
def lua_set_numseqs(numseqs):
    assert type(numseqs) in (int, float) and numseqs >= 1
    print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set numseqs to {int(numseqs)}" + colors.END)
    vars.numseqs = int(numseqs)

#==================================================================#
#  Check if a setting exists with the given name
#==================================================================#
@bridged_kwarg()
def lua_has_setting(setting):
    return setting in (
        "anotedepth",
        "settemp",
        "settopp",
        "settopk",
        "settfs",
        "settypical",
        "setreppen",
        "setreppenslope",
        "setreppenrange",
        "settknmax",
        "setwidepth",
        "setuseprompt",
        "setadventure",
        "setchatmode",
        "setdynamicscan",
        "setnopromptgen",
        "autosave",
        "setrngpersist",
        "temp",
        "topp",
        "top_p",
        "topk",
        "top_k",
        "tfs",
        "typical",
        "reppen",
        "reppenslope",
        "reppenrange",
        "tknmax",
        "widepth",
        "useprompt",
        "chatmode",
        "chatname",
        "adventure",
        "dynamicscan",
        "nopromptgen",
        "rngpersist",
        "frmttriminc",
        "frmtrmblln",
        "frmtrmspch",
        "frmtadsnsp",
        "frmtsingleline",
        "triminc",
        "rmblln",
        "rmspch",
        "adsnsp",
        "singleline",
    )

#==================================================================#
#  Return the setting with the given name if it exists
#==================================================================#
@bridged_kwarg()
def lua_get_setting(setting):
    if(setting in ("settemp", "temp")): return vars.temp
    if(setting in ("settopp", "topp", "top_p")): return vars.top_p
    if(setting in ("settopk", "topk", "top_k")): return vars.top_k
    if(setting in ("settfs", "tfs")): return vars.tfs
    if(setting in ("settypical", "typical")): return vars.typical
    if(setting in ("setreppen", "reppen")): return vars.rep_pen
    if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
    if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
    if(setting in ("settknmax", "tknmax")): return vars.max_length
    if(setting == "anotedepth"): return vars.andepth
    if(setting in ("setwidepth", "widepth")): return vars.widepth
    if(setting in ("setuseprompt", "useprompt")): return vars.useprompt
    if(setting in ("setadventure", "adventure")): return vars.adventure
    if(setting in ("setchatmode", "chatmode")): return vars.chatmode
    if(setting in ("setdynamicscan", "dynamicscan")): return vars.dynamicscan
    if(setting in ("setnopromptgen", "nopromptgen")): return vars.nopromptgen
    if(setting == "autosave"): return vars.autosave
    if(setting in ("setrngpersist", "rngpersist")): return vars.rngpersist
    if(setting in ("frmttriminc", "triminc")): return vars.formatoptns["frmttriminc"]
    if(setting in ("frmtrmblln", "rmblln")): return vars.formatoptns["frmtrmblln"]
    if(setting in ("frmtrmspch", "rmspch")): return vars.formatoptns["frmtrmspch"]
    if(setting in ("frmtadsnsp", "adsnsp")): return vars.formatoptns["frmtadsnsp"]
    if(setting in ("frmtsingleline", "singleline")): return vars.formatoptns["singleline"]

#==================================================================#
#  Set the setting with the given name if it exists
#==================================================================#
@bridged_kwarg()
def lua_set_setting(setting, v):
    actual_type = type(lua_get_setting(setting))
    assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float))
    v = actual_type(v)
    print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set {setting} to {v}" + colors.END)
    if(setting in ("setadventure", "adventure") and v):
        vars.actionmode = 1
    if(setting in ("settemp", "temp")): vars.temp = v
    if(setting in ("settopp", "topp")): vars.top_p = v
    if(setting in ("settopk", "topk")): vars.top_k = v
    if(setting in ("settfs", "tfs")): vars.tfs = v
    if(setting in ("settypical", "typical")): vars.typical = v
    if(setting in ("setreppen", "reppen")): vars.rep_pen = v
    if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
    if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
    if(setting in ("settknmax", "tknmax")): vars.max_length = v; return True
    if(setting == "anotedepth"): vars.andepth = v; return True
    if(setting in ("setwidepth", "widepth")): vars.widepth = v; return True
    if(setting in ("setuseprompt", "useprompt")): vars.useprompt = v; return True
    if(setting in ("setadventure", "adventure")): vars.adventure = v
    if(setting in ("setdynamicscan", "dynamicscan")): vars.dynamicscan = v
    if(setting in ("setnopromptgen", "nopromptgen")): vars.nopromptgen = v
    if(setting == "autosave"): vars.autosave = v
    if(setting in ("setrngpersist", "rngpersist")): vars.rngpersist = v
    if(setting in ("setchatmode", "chatmode")): vars.chatmode = v
    if(setting in ("frmttriminc", "triminc")): vars.formatoptns["frmttriminc"] = v
    if(setting in ("frmtrmblln", "rmblln")): vars.formatoptns["frmtrmblln"] = v
    if(setting in ("frmtrmspch", "rmspch")): vars.formatoptns["frmtrmspch"] = v
    if(setting in ("frmtadsnsp", "adsnsp")): vars.formatoptns["frmtadsnsp"] = v
    if(setting in ("frmtsingleline", "singleline")): vars.formatoptns["singleline"] = v

#==================================================================#
#  Get contents of memory
#==================================================================#
@bridged_kwarg()
def lua_get_memory():
    return vars.memory

#==================================================================#
#  Set contents of memory
#==================================================================#
@bridged_kwarg()
def lua_set_memory(m):
    assert type(m) is str
    vars.memory = m

#==================================================================#
#  Get contents of author's note
#==================================================================#
@bridged_kwarg()
def lua_get_authorsnote():
    return vars.authornote

#==================================================================#
#  Set contents of author's note
#==================================================================#
@bridged_kwarg()
def lua_set_authorsnote(m):
    assert type(m) is str
    vars.authornote = m

#==================================================================#
#  Get contents of author's note template
#==================================================================#
@bridged_kwarg()
def lua_get_authorsnotetemplate():
    return vars.authornotetemplate

#==================================================================#
#  Set contents of author's note template
#==================================================================#
@bridged_kwarg()
def lua_set_authorsnotetemplate(m):
    assert type(m) is str
    vars.authornotetemplate = m

#==================================================================#
#  Save settings and send them to client
#==================================================================#
@bridged_kwarg()
def lua_resend_settings():
    settingschanged()
    refresh_settings()

#==================================================================#
#  Set story chunk text and delete the chunk if the new chunk is empty
#==================================================================#
@bridged_kwarg()
def lua_set_chunk(k, v):
    assert type(k) in (int, type(None)) and type(v) is str
    assert k >= 0
    assert k != 0 or len(v) != 0
    if(len(v) == 0):
        print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} deleted story chunk {k}" + colors.END)
        chunk = int(k)
        if(vars.lua_koboldbridge.userstate == "genmod"):
            del vars._actions[chunk-1]
        vars.lua_deleted.add(chunk)
        if(not hasattr(vars, "_actions") or vars._actions is not vars.actions):
            # Instead of deleting we'll blank out the text. This way our actions and
            # actions_metadata stay in sync and we can restore the chunk on an undo.
            vars.actions[chunk-1] = ""
            vars.actions_metadata[chunk-1]['Alternative Text'] = [{"Text": vars.actions_metadata[chunk-1]['Selected Text'], "Pinned": False, "Editted": True}] + vars.actions_metadata[chunk-1]['Alternative Text']
            vars.actions_metadata[chunk-1]['Selected Text'] = ''
        send_debug()
    else:
        if(k == 0):
            print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} edited prompt chunk" + colors.END)
        else:
            print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} edited story chunk {k}" + colors.END)
        chunk = int(k)
        if(chunk == 0):
            if(vars.lua_koboldbridge.userstate == "genmod"):
                vars._prompt = v
            vars.lua_edited.add(chunk)
            vars.prompt = v
        else:
            if(vars.lua_koboldbridge.userstate == "genmod"):
                vars._actions[chunk-1] = v
            vars.lua_edited.add(chunk)
            vars.actions[chunk-1] = v
            vars.actions_metadata[chunk-1]['Alternative Text'] = [{"Text": vars.actions_metadata[chunk-1]['Selected Text'], "Pinned": False, "Editted": True}] + vars.actions_metadata[chunk-1]['Alternative Text']
            vars.actions_metadata[chunk-1]['Selected Text'] = v
        send_debug()
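
# Behaviour summary with illustrative calls (story chunk numbers are 1-based;
# chunk 0 is the prompt, which can be edited but never deleted):
#
#     lua_set_chunk(0, "New prompt text")   # edits the prompt chunk
#     lua_set_chunk(3, "Revised text")      # edits story chunk 3
#     lua_set_chunk(3, "")                  # deletes story chunk 3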

#==================================================================#
#  Get model type as "gpt-2-xl", "gpt-neo-2.7B", etc.
#==================================================================#
@bridged_kwarg()
def lua_get_modeltype():
    if(vars.noai):
        return "readonly"
    if(vars.model in ("Colab", "OAI", "InferKit")):
        return "api"
    if(not vars.use_colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
        hidden_size = get_hidden_size_from_model(model)
    if(vars.model in ("gpt2",) or (vars.model_type == "gpt2" and hidden_size == 768)):
        return "gpt2"
    if(vars.model in ("gpt2-medium",) or (vars.model_type == "gpt2" and hidden_size == 1024)):
        return "gpt2-medium"
    if(vars.model in ("gpt2-large",) or (vars.model_type == "gpt2" and hidden_size == 1280)):
        return "gpt2-large"
    if(vars.model in ("gpt2-xl",) or (vars.model_type == "gpt2" and hidden_size == 1600)):
        return "gpt2-xl"
    if(vars.model_type == "gpt_neo" and hidden_size == 768):
        return "gpt-neo-125M"
    if(vars.model in ("EleutherAI/gpt-neo-1.3B",) or (vars.model_type == "gpt_neo" and hidden_size == 2048)):
        return "gpt-neo-1.3B"
    if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model_type == "gpt_neo" and hidden_size == 2560)):
        return "gpt-neo-2.7B"
    if(vars.model in ("EleutherAI/gpt-j-6B",) or ((vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ") and tpu_mtj_backend.params["d_model"] == 4096) or (vars.model_type in ("gpt_neo", "gptj") and hidden_size == 4096)):
        return "gpt-j-6B"
    return "unknown"

#==================================================================#
#  Get model backend as "transformers" or "mtj"
#==================================================================#
@bridged_kwarg()
def lua_get_modelbackend():
    if(vars.noai):
        return "readonly"
    if(vars.model in ("Colab", "OAI", "InferKit")):
        return "api"
    if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
        return "mtj"
    return "transformers"

#==================================================================#
#  Check whether model is loaded from a custom path
#==================================================================#
@bridged_kwarg()
def lua_is_custommodel():
    return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")

#==================================================================#
#  Return the filename (as a string) of the current soft prompt, or
#  None if no soft prompt is loaded
#==================================================================#
@bridged_kwarg()
def lua_get_spfilename():
    return vars.spfilename.strip() or None

#==================================================================#
#  When called with a string as argument, sets the current soft prompt;
#  when called with None as argument, uses no soft prompt.
#  Returns True if soft prompt changed, False otherwise.
#==================================================================#
@bridged_kwarg()
def lua_set_spfilename(filename: Union[str, None]):
    if(filename is None):
        filename = ""
    filename = str(filename).strip()
    changed = lua_get_spfilename() != filename
    assert all(q not in filename for q in ("/", "\\"))
    spRequest(filename)
    return changed
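
# Usage sketch (the filename here is hypothetical):
#
#     lua_set_spfilename("surreal_fantasy.zip")  # load that soft prompt; True if changed
#     lua_set_spfilename(None)                   # unload any soft prompt
#     lua_get_spfilename()                       # -> current filename, or None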

#==================================================================#
#  Run the Lua input, generation and output modifiers
#==================================================================#
def execute_inmod():
    setgamesaved(False)
    vars.lua_logname = ...
    vars.lua_edited = set()
    vars.lua_deleted = set()
    try:
        tpool.execute(vars.lua_koboldbridge.execute_inmod)
    except lupa.LuaError as e:
        vars.lua_koboldbridge.obliterate_multiverse()
        vars.lua_running = False
        emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True)
        sendUSStatItems()
        print("{0}{1}{2}".format(colors.RED, "***LUA ERROR***: ", colors.END), end="", file=sys.stderr)
        print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
        print("{0}{1}{2}".format(colors.YELLOW, "Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.", colors.END), file=sys.stderr)
        set_aibusy(0)

def execute_genmod():
    vars.lua_koboldbridge.execute_genmod()

def execute_outmod():
    setgamesaved(False)
    emit('from_server', {'cmd': 'hidemsg', 'data': ''}, broadcast=True)
    try:
        tpool.execute(vars.lua_koboldbridge.execute_outmod)
    except lupa.LuaError as e:
        vars.lua_koboldbridge.obliterate_multiverse()
        vars.lua_running = False
        emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True)
        sendUSStatItems()
        print("{0}{1}{2}".format(colors.RED, "***LUA ERROR***: ", colors.END), end="", file=sys.stderr)
        print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
        print("{0}{1}{2}".format(colors.YELLOW, "Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.", colors.END), file=sys.stderr)
        set_aibusy(0)
    if(vars.lua_koboldbridge.resend_settings_required):
        vars.lua_koboldbridge.resend_settings_required = False
        lua_resend_settings()
    for k in vars.lua_edited:
        inlineedit(k, vars.actions[k])
    for k in vars.lua_deleted:
        inlinedelete(k)
2021-12-11 18:45:45 +01:00
2021-11-14 03:13:52 +01:00
#============================ METHODS =============================#
#==================================================================#
# Event triggered when browser SocketIO is loaded and connects to server
#==================================================================#
@socketio.on ( ' connect ' )
def do_connect ( ) :
print ( " {0} Client connected! {1} " . format ( colors . GREEN , colors . END ) )
2021-12-27 18:52:06 +01:00
emit ( ' from_server ' , { ' cmd ' : ' setchatname ' , ' data ' : vars . chatname } )
2021-12-30 05:43:36 +01:00
emit ( ' from_server ' , { ' cmd ' : ' setanotetemplate ' , ' data ' : vars . authornotetemplate } )
2022-01-18 23:20:45 +01:00
emit ( ' from_server ' , { ' cmd ' : ' connected ' , ' smandelete ' : vars . smandelete , ' smanrename ' : vars . smanrename , ' modelname ' : getmodelname ( ) } )
2022-02-18 01:08:12 +01:00
if ( vars . host ) :
2021-11-14 03:13:52 +01:00
emit ( ' from_server ' , { ' cmd ' : ' runs_remotely ' } )
if ( vars . allowsp ) :
emit ( ' from_server ' , { ' cmd ' : ' allowsp ' , ' data ' : vars . allowsp } )
2021-12-23 05:33:27 +01:00
sendUSStatItems ( )
2021-12-23 19:02:11 +01:00
emit ( ' from_server ' , { ' cmd ' : ' spstatitems ' , ' data ' : { vars . spfilename : vars . spmeta } if vars . allowsp and len ( vars . spfilename ) else { } } , broadcast = True )
2021-12-23 05:33:27 +01:00
2021-11-14 03:13:52 +01:00
if ( not vars . gamestarted ) :
setStartState ( )
sendsettings ( )
refresh_settings ( )
vars . laststory = None
emit ( ' from_server ' , { ' cmd ' : ' setstoryname ' , ' data ' : vars . laststory } )
sendwi ( )
emit ( ' from_server ' , { ' cmd ' : ' setmemory ' , ' data ' : vars . memory } )
emit ( ' from_server ' , { ' cmd ' : ' setanote ' , ' data ' : vars . authornote } )
vars . mode = " play "
else :
# Game in session, send current game data and ready state to browser
refresh_story ( )
sendsettings ( )
refresh_settings ( )
emit ( ' from_server ' , { ' cmd ' : ' setstoryname ' , ' data ' : vars . laststory } )
sendwi ( )
emit ( ' from_server ' , { ' cmd ' : ' setmemory ' , ' data ' : vars . memory } )
emit ( ' from_server ' , { ' cmd ' : ' setanote ' , ' data ' : vars . authornote } )
if ( vars . mode == " play " ) :
if ( not vars . aibusy ) :
emit ( ' from_server ' , { ' cmd ' : ' setgamestate ' , ' data ' : ' ready ' } )
else :
emit ( ' from_server ' , { ' cmd ' : ' setgamestate ' , ' data ' : ' wait ' } )
elif ( vars . mode == " edit " ) :
emit ( ' from_server ' , { ' cmd ' : ' editmode ' , ' data ' : ' true ' } )
elif ( vars . mode == " memory " ) :
emit ( ' from_server ' , { ' cmd ' : ' memmode ' , ' data ' : ' true ' } )
elif ( vars . mode == " wi " ) :
emit ( ' from_server ' , { ' cmd ' : ' wimode ' , ' data ' : ' true ' } )
2022-01-18 23:20:45 +01:00
emit ( ' from_server ' , { ' cmd ' : ' gamesaved ' , ' data ' : vars . gamesaved } , broadcast = True )

#==================================================================#
#  Event triggered when browser SocketIO sends data to the server
#==================================================================#
@socketio.on('message')
def get_message(msg):
    if not vars.quiet:
        print("{0}Data received:{1}{2}".format(colors.GREEN, msg, colors.END))
    # Submit action
    if(msg['cmd'] == 'submit'):
        if(vars.mode == "play"):
            if(vars.aibusy):
                if(msg.get('allowabort', False)):
                    vars.abort = True
                return
            vars.abort = False
            vars.lua_koboldbridge.feedback = None
            if(vars.chatmode):
                if(type(msg['chatname']) is not str):
                    raise ValueError("Chatname must be a string")
                vars.chatname = msg['chatname']
                settingschanged()
                emit('from_server', {'cmd': 'setchatname', 'data': vars.chatname})
            vars.recentrng = vars.recentrngm = None
            actionsubmit(msg['data'], actionmode=msg['actionmode'])
        elif(vars.mode == "edit"):
            editsubmit(msg['data'])
        elif(vars.mode == "memory"):
            memsubmit(msg['data'])
    # Retry Action
    elif(msg['cmd'] == 'retry'):
        if(vars.aibusy):
            if(msg.get('allowabort', False)):
                vars.abort = True
            return
        vars.abort = False
        if(vars.chatmode):
            if(type(msg['chatname']) is not str):
                raise ValueError("Chatname must be a string")
            vars.chatname = msg['chatname']
            settingschanged()
            emit('from_server', {'cmd': 'setchatname', 'data': vars.chatname})
        actionretry(msg['data'])
    # Back/Undo Action
    elif(msg['cmd'] == 'back'):
        ignore = actionback()
    # Forward/Redo Action
    elif(msg['cmd'] == 'redo'):
        actionredo()
    # EditMode Action (old)
    elif(msg['cmd'] == 'edit'):
        if(vars.mode == "play"):
            vars.mode = "edit"
            emit('from_server', {'cmd': 'editmode', 'data': 'true'}, broadcast=True)
        elif(vars.mode == "edit"):
            vars.mode = "play"
            emit('from_server', {'cmd': 'editmode', 'data': 'false'}, broadcast=True)
    # EditLine Action (old)
    elif(msg['cmd'] == 'editline'):
        editrequest(int(msg['data']))
    # Inline edit
    elif(msg['cmd'] == 'inlineedit'):
        inlineedit(msg['chunk'], msg['data'])
    elif(msg['cmd'] == 'inlinedelete'):
        inlinedelete(msg['data'])
    # DeleteLine Action (old)
    elif(msg['cmd'] == 'delete'):
        deleterequest()
    elif(msg['cmd'] == 'memory'):
        togglememorymode()
    elif(not vars.host and msg['cmd'] == 'savetofile'):
        savetofile()
    elif(not vars.host and msg['cmd'] == 'loadfromfile'):
        loadfromfile()
    elif(msg['cmd'] == 'loadfromstring'):
        loadRequest(json.loads(msg['data']), filename=msg['filename'])
    elif(not vars.host and msg['cmd'] == 'import'):
        importRequest()
    elif(msg['cmd'] == 'newgame'):
        newGameRequest()
    elif(msg['cmd'] == 'rndgame'):
        randomGameRequest(msg['data'], memory=msg['memory'])
    elif(msg['cmd'] == 'settemp'):
        vars.temp = float(msg['data'])
        emit('from_server', {'cmd': 'setlabeltemp', 'data': msg['data']}, broadcast=True)
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'settopp'):
        vars.top_p = float(msg['data'])
        emit('from_server', {'cmd': 'setlabeltopp', 'data': msg['data']}, broadcast=True)
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'settopk'):
        vars.top_k = int(msg['data'])
        emit('from_server', {'cmd': 'setlabeltopk', 'data': msg['data']}, broadcast=True)
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'settfs'):
        vars.tfs = float(msg['data'])
        emit('from_server', {'cmd': 'setlabeltfs', 'data': msg['data']}, broadcast=True)
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'settypical'):
        vars.typical = float(msg['data'])
        emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True)
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'setreppen'):
        vars.rep_pen = float(msg['data'])
        emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'setreppenslope'):
        vars.rep_pen_slope = float(msg['data'])
        emit('from_server', {'cmd': 'setlabelreppenslope', 'data': msg['data']}, broadcast=True)
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'setreppenrange'):
        vars.rep_pen_range = float(msg['data'])
        emit('from_server', {'cmd': 'setlabelreppenrange', 'data': msg['data']}, broadcast=True)
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'setoutput'):
        vars.genamt = int(msg['data'])
        emit('from_server', {'cmd': 'setlabeloutput', 'data': msg['data']}, broadcast=True)
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'settknmax'):
        vars.max_length = int(msg['data'])
        emit('from_server', {'cmd': 'setlabeltknmax', 'data': msg['data']}, broadcast=True)
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'setikgen'):
        vars.ikgen = int(msg['data'])
        emit('from_server', {'cmd': 'setlabelikgen', 'data': msg['data']}, broadcast=True)
        settingschanged()
        refresh_settings()
    # Author's Note field update
    elif(msg['cmd'] == 'anote'):
        anotesubmit(msg['data'], template=msg['template'])
    # Author's Note depth update
    elif(msg['cmd'] == 'anotedepth'):
        vars.andepth = int(msg['data'])
        emit('from_server', {'cmd': 'setlabelanotedepth', 'data': msg['data']}, broadcast=True)
        settingschanged()
        refresh_settings()
    # Format - Trim incomplete sentences
    elif(msg['cmd'] == 'frmttriminc'):
        if('frmttriminc' in vars.formatoptns):
            vars.formatoptns["frmttriminc"] = msg['data']
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'frmtrmblln'):
        if('frmtrmblln' in vars.formatoptns):
            vars.formatoptns["frmtrmblln"] = msg['data']
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'frmtrmspch'):
        if('frmtrmspch' in vars.formatoptns):
            vars.formatoptns["frmtrmspch"] = msg['data']
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'frmtadsnsp'):
        if('frmtadsnsp' in vars.formatoptns):
            vars.formatoptns["frmtadsnsp"] = msg['data']
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'singleline'):
        if('singleline' in vars.formatoptns):
            vars.formatoptns["singleline"] = msg['data']
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'importselect'):
        vars.importnum = int(msg["data"].replace("import", ""))
    elif(msg['cmd'] == 'importcancel'):
        emit('from_server', {'cmd': 'popupshow', 'data': False})
        vars.importjs = {}
    elif(msg['cmd'] == 'importaccept'):
        emit('from_server', {'cmd': 'popupshow', 'data': False})
        importgame()
    elif(msg['cmd'] == 'wi'):
        togglewimode()
    elif(msg['cmd'] == 'wiinit'):
        if(int(msg['data']) < len(vars.worldinfo)):
            setgamesaved(False)
            vars.worldinfo[msg['data']]["init"] = True
            addwiitem(folder_uid=msg['folder'])
    elif(msg['cmd'] == 'wifolderinit'):
        addwifolder()
    elif(msg['cmd'] == 'wimoveitem'):
        movewiitem(msg['destination'], msg['data'])
    elif(msg['cmd'] == 'wimovefolder'):
        movewifolder(msg['destination'], msg['data'])
    elif(msg['cmd'] == 'widelete'):
        deletewi(msg['data'])
    elif(msg['cmd'] == 'wifolderdelete'):
        deletewifolder(msg['data'])
    elif(msg['cmd'] == 'wiexpand'):
        assert 0 <= int(msg['data']) < len(vars.worldinfo)
        setgamesaved(False)
        emit('from_server', {'cmd': 'wiexpand', 'data': msg['data']}, broadcast=True)
    elif(msg['cmd'] == 'wiexpandfolder'):
        assert 0 <= int(msg['data']) < len(vars.worldinfo)
        setgamesaved(False)
        emit('from_server', {'cmd': 'wiexpandfolder', 'data': msg['data']}, broadcast=True)
    elif(msg['cmd'] == 'wifoldercollapsecontent'):
        setgamesaved(False)
        vars.wifolders_d[msg['data']]['collapsed'] = True
        emit('from_server', {'cmd': 'wifoldercollapsecontent', 'data': msg['data']}, broadcast=True)
    elif(msg['cmd'] == 'wifolderexpandcontent'):
        setgamesaved(False)
        vars.wifolders_d[msg['data']]['collapsed'] = False
        emit('from_server', {'cmd': 'wifolderexpandcontent', 'data': msg['data']}, broadcast=True)
    elif(msg['cmd'] == 'wiupdate'):
        setgamesaved(False)
        num = int(msg['num'])
        fields = ("key", "keysecondary", "content", "comment")
        for field in fields:
            if(field in msg['data'] and type(msg['data'][field]) is str):
                vars.worldinfo[num][field] = msg['data'][field]
        emit('from_server', {'cmd': 'wiupdate', 'num': msg['num'], 'data': {field: vars.worldinfo[num][field] for field in fields}}, broadcast=True)
    elif(msg['cmd'] == 'wifolderupdate'):
        setgamesaved(False)
        uid = int(msg['uid'])
        fields = ("name", "collapsed")
        for field in fields:
            if(field in msg['data'] and type(msg['data'][field]) is (str if field != "collapsed" else bool)):
                vars.wifolders_d[uid][field] = msg['data'][field]
        emit('from_server', {'cmd': 'wifolderupdate', 'uid': msg['uid'], 'data': {field: vars.wifolders_d[uid][field] for field in fields}}, broadcast=True)
    elif(msg['cmd'] == 'wiselon'):
        setgamesaved(False)
        vars.worldinfo[msg['data']]["selective"] = True
        emit('from_server', {'cmd': 'wiselon', 'data': msg['data']}, broadcast=True)
    elif(msg['cmd'] == 'wiseloff'):
        setgamesaved(False)
        vars.worldinfo[msg['data']]["selective"] = False
        emit('from_server', {'cmd': 'wiseloff', 'data': msg['data']}, broadcast=True)
    elif(msg['cmd'] == 'wiconstanton'):
        setgamesaved(False)
        vars.worldinfo[msg['data']]["constant"] = True
        emit('from_server', {'cmd': 'wiconstanton', 'data': msg['data']}, broadcast=True)
    elif(msg['cmd'] == 'wiconstantoff'):
        setgamesaved(False)
        vars.worldinfo[msg['data']]["constant"] = False
        emit('from_server', {'cmd': 'wiconstantoff', 'data': msg['data']}, broadcast=True)
    elif(msg['cmd'] == 'sendwilist'):
        commitwi(msg['data'])
    elif(msg['cmd'] == 'aidgimport'):
        importAidgRequest(msg['data'])
    elif(msg['cmd'] == 'saveasrequest'):
        saveas(msg['data'])
    elif(msg['cmd'] == 'saverequest'):
        save()
    elif(msg['cmd'] == 'loadlistrequest'):
        getloadlist()
    elif(msg['cmd'] == 'splistrequest'):
        getsplist()
    elif(msg['cmd'] == 'uslistrequest'):
        unloaded, loaded = getuslist()
        emit('from_server', {'cmd': 'buildus', 'data': {"unloaded": unloaded, "loaded": loaded}})
    elif(msg['cmd'] == 'usloaded'):
        vars.userscripts = []
        for userscript in msg['data']:
            if type(userscript) is not str:
                continue
            userscript = userscript.strip()
            if len(userscript) != 0 and all(q not in userscript for q in ("..", ":")) and all(userscript[0] not in q for q in ("/", "\\")) and os.path.exists(fileops.uspath(userscript)):
                vars.userscripts.append(userscript)
        settingschanged()
    elif(msg['cmd'] == 'usload'):
        load_lua_scripts()
        unloaded, loaded = getuslist()
        sendUSStatItems()
2022-03-07 01:51:35 +01:00
elif ( msg [ ' cmd ' ] == ' list_model ' ) :
sendModelSelection ( menu = msg [ ' data ' ] )
    elif(msg['cmd'] == 'load_model'):
        if not os.path.exists("settings/"):
            os.mkdir("settings")
        changed = True
        if os.path.exists("settings/" + vars.model.replace('/', '_') + ".breakmodel"):
            with open("settings/" + vars.model.replace('/', '_') + ".breakmodel", "r") as file:
                if file.read() == msg['gpu_layers']:
                    changed = False
        if changed:
            with open("settings/" + vars.model.replace('/', '_') + ".breakmodel", "w") as file:
                file.write(msg['gpu_layers'])
        vars.colaburl = msg['url'] + "/request"
        load_model(use_gpu=msg['use_gpu'], gpu_layers=msg['gpu_layers'], online_model=msg['online_model'])
    elif(msg['cmd'] == 'show_model'):
        print("Model Name: {}".format(getmodelname()))
        emit('from_server', {'cmd': 'show_model_name', 'data': getmodelname()}, broadcast=True)
    elif(msg['cmd'] == 'selectmodel'):
        # This is run when a model line is selected from the UI (a line from the
        # model_menu variable) that is tagged as not a menu; otherwise the
        # msg['cmd'] == 'list_model' branch above runs instead.
        # We have to do a bit of processing though: if a custom path is selected,
        # we need to list out the contents of folders, but if something else is
        # selected, we may need to show the model layers for each GPU, and we
        # might also need to show key input. All of that happens here.
        # The data variable contains the model name, but our Custom lines need a
        # bit more processing: if we're on a custom line for which a model has
        # been selected, the path variable will be in msg, so if it's missing we
        # need to run the menu to show the model folders in the models folder.
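        # Illustrative (hypothetical) payloads this branch expects, for reference:
        #   {'cmd': 'selectmodel', 'data': 'NeoCustom'}                          -> list folders under ./models
        #   {'cmd': 'selectmodel', 'data': 'NeoCustom', 'path': './models/foo'}  -> load that folder if it is a model
        #   {'cmd': 'selectmodel', 'data': 'EleutherAI/gpt-neo-2.7B'}            -> fetch model info (GPU layers/key input)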
        if msg['data'] in ('NeoCustom', 'GPT2Custom') and 'path' not in msg:
            if 'folder' not in msg:
                folder = "./models"
            else:
                folder = msg['folder']
            sendModelSelection(menu=msg['data'], folder=folder)
        elif msg['data'] in ('NeoCustom', 'GPT2Custom'):
            if check_if_dir_is_model(msg['path']):
                vars.model = msg['data']
                vars.custmodpth = msg['path']
                get_model_info(msg['data'], directory=msg['path'])
            else:
                sendModelSelection(menu=msg['data'], folder=msg['path'])
        else:
            vars.model = msg['data']
            if 'path' in msg:
                vars.custmodpth = msg['path']
                get_model_info(msg['data'], directory=msg['path'])
            else:
                get_model_info(vars.model)
    elif(msg['cmd'] == 'OAI_Key_Update'):
        get_oai_models(msg['key'])
    elif(msg['cmd'] == 'loadselect'):
        vars.loadselect = msg["data"]
    elif(msg['cmd'] == 'spselect'):
        vars.spselect = msg["data"]
    elif(msg['cmd'] == 'loadrequest'):
        loadRequest(fileops.storypath(vars.loadselect))
    elif(msg['cmd'] == 'sprequest'):
        spRequest(vars.spselect)
    elif(msg['cmd'] == 'deletestory'):
        deletesave(msg['data'])
    elif(msg['cmd'] == 'renamestory'):
        renamesave(msg['data'], msg['newname'])
    elif(msg['cmd'] == 'clearoverwrite'):
        vars.svowname = ""
        vars.saveow = False
    elif(msg['cmd'] == 'seqsel'):
        selectsequence(msg['data'])
    elif(msg['cmd'] == 'seqpin'):
        pinsequence(msg['data'])
    elif(msg['cmd'] == 'setnumseq'):
        vars.numseqs = int(msg['data'])
        emit('from_server', {'cmd': 'setlabelnumseq', 'data': msg['data']})
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'setwidepth'):
        vars.widepth = int(msg['data'])
        emit('from_server', {'cmd': 'setlabelwidepth', 'data': msg['data']})
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'setuseprompt'):
        vars.useprompt = msg['data']
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'setadventure'):
        vars.adventure = msg['data']
        vars.chatmode = False
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'autosave'):
        vars.autosave = msg['data']
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'setchatmode'):
        vars.chatmode = msg['data']
        vars.adventure = False
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'setdynamicscan'):
        vars.dynamicscan = msg['data']
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'setnopromptgen'):
        vars.nopromptgen = msg['data']
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'setrngpersist'):
        vars.rngpersist = msg['data']
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'setnogenmod'):
        vars.nogenmod = msg['data']
        settingschanged()
        refresh_settings()
    elif(not vars.host and msg['cmd'] == 'importwi'):
        wiimportrequest()
    elif(msg['cmd'] == 'debug'):
        vars.debug = msg['data']
        emit('from_server', {'cmd': 'set_debug', 'data': msg['data']}, broadcast=True)
        if vars.debug:
            send_debug()

#==================================================================#
#  Send userscripts list to client
#==================================================================#
def sendUSStatItems():
    _, loaded = getuslist()
    loaded = loaded if vars.lua_running else []
    last_userscripts = [e["filename"] for e in loaded]
    emit('from_server', {'cmd': 'usstatitems', 'data': loaded, 'flash': last_userscripts != vars.last_userscripts}, broadcast=True)
    vars.last_userscripts = last_userscripts

#==================================================================#
#  KoboldAI Markup Formatting (Mixture of Markdown and sanitized html)
#==================================================================#
def kml(txt):
    txt = txt.replace('\\>', '>')
    txt = bleach.clean(markdown.markdown(txt), tags=['p', 'em', 'strong', 'code', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li', 'ul', 'b', 'i', 'a', 'span', 'button'], styles=['color', 'font-weight'], attributes=['id', 'class', 'style', 'href'])
    return txt
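
# A minimal usage sketch (illustrative input, not from the source):
#   kml("**bold** <script>alert(1)</script>")
# renders the Markdown to HTML and lets bleach neutralize the disallowed
# <script> tag, yielding roughly "<p><strong>bold</strong> ...</p>" with the
# script escaped rather than executed.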

#==================================================================#
#  Send start message and tell Javascript to set UI state
#==================================================================#
def setStartState():
    if(vars.welcome):
        txt = kml(vars.welcome) + "<br/>"
    else:
        txt = "<span>Welcome to <span class=\"color_cyan\">KoboldAI</span>! You are running <span class=\"color_green\">" + getmodelname() + "</span>.<br/>"
    if(not vars.noai and not vars.welcome):
        txt = txt + "Please load a game or enter a prompt below to begin!</span>"
    if(vars.noai):
        txt = txt + "Please load or import a story to read. There is no AI in this mode."
    emit('from_server', {'cmd': 'updatescreen', 'gamestarted': vars.gamestarted, 'data': txt}, broadcast=True)
    emit('from_server', {'cmd': 'setgamestate', 'data': 'start'}, broadcast=True)

#==================================================================#
#  Transmit applicable settings to SocketIO to build UI sliders/toggles
#==================================================================#
def sendsettings():
    # Send settings for selected AI type
    if(vars.model != "InferKit"):
        for set in gensettings.gensettingstf:
            emit('from_server', {'cmd': 'addsetting', 'data': set})
    else:
        for set in gensettings.gensettingsik:
            emit('from_server', {'cmd': 'addsetting', 'data': set})

    # Send formatting options
    for frm in gensettings.formatcontrols:
        emit('from_server', {'cmd': 'addformat', 'data': frm})
        # Add format key to vars if it wasn't loaded with client.settings
        if(not frm["id"] in vars.formatoptns):
            vars.formatoptns[frm["id"]] = False

#==================================================================#
#  Set value of gamesaved
#==================================================================#
def setgamesaved(gamesaved):
    assert type(gamesaved) is bool
    if(gamesaved != vars.gamesaved):
        emit('from_server', {'cmd': 'gamesaved', 'data': gamesaved}, broadcast=True)
    vars.gamesaved = gamesaved

#==================================================================#
#  Take input text from SocketIO and decide what to do with it
#==================================================================#

def check_for_backend_compilation():
    if(vars.checking):
        return
    vars.checking = True
    for _ in range(31):
        time.sleep(0.06276680299820175)
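        # (Note: 31 iterations at ~0.0628 s each gives roughly a 1.95 s window
        # in which a slow TPU backend compile will trigger the warning below.)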
        if(vars.compiling):
            emit('from_server', {'cmd': 'warnmsg', 'data': 'Compiling TPU backend—this usually takes 1–2 minutes...'}, broadcast=True)
            break
    vars.checking = False

def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False):
    # Ignore new submissions if the AI is currently busy
    if(vars.aibusy):
        return

    while(True):
        set_aibusy(1)

        if(disable_recentrng):
            vars.recentrng = vars.recentrngm = None

        vars.recentback = False
        vars.recentedit = False
        vars.actionmode = actionmode

        # "Action" mode
        if(actionmode == 1):
            data = data.strip().lstrip('>')
            data = re.sub(r'\n+', ' ', data)
            if(len(data)):
                data = f"\n\n> {data}\n"
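        # For example (illustrative): an input of "look around" becomes
        # "\n\n> look around\n", matching the adventure-game transcript style.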

        # "Chat" mode
        if(vars.chatmode and vars.gamestarted):
            data = re.sub(r'\n+', ' ', data)
            if(len(data)):
                data = f"\n{vars.chatname}: {data}\n"
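        # For example (illustrative): with vars.chatname == "You", the input
        # "hello" is rewritten as "\nYou: hello\n" before submission.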

        # If we're not continuing, store a copy of the raw input
        if(data != ""):
            vars.lastact = data

        if(not vars.gamestarted):
            vars.submission = data
            execute_inmod()
            data = vars.submission
            if(not force_submit and len(data.strip()) == 0):
                assert False
            # Start the game
            vars.gamestarted = True
            if(not vars.noai and vars.lua_koboldbridge.generating and (not vars.nopromptgen or force_prompt_gen)):
                # Save this first action as the prompt
                vars.prompt = data
                # Clear the startup text from game screen
                emit('from_server', {'cmd': 'updatescreen', 'gamestarted': False, 'data': 'Please wait, generating story...'}, broadcast=True)
                calcsubmit(data) # Run the first action through the generator
                if(not vars.abort and vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0):
                    data = ""
                    force_submit = True
                    disable_recentrng = True
                    continue
                emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
                break
            else:
                # Save this first action as the prompt
                vars.prompt = data if len(data) > 0 else '"'
                for i in range(vars.numseqs):
                    vars.lua_koboldbridge.outputs[i+1] = ""
                execute_outmod()
                vars.lua_koboldbridge.regeneration_required = False
                genout = []
                for i in range(vars.numseqs):
                    genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
                    assert type(genout[-1]["generated_text"]) is str
                if(len(genout) == 1):
                    genresult(genout[0]["generated_text"], flash=False)
                    refresh_story()
                    if(len(vars.actions) > 0):
                        emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() + 1}, broadcast=True)
                    if(not vars.abort and vars.lua_koboldbridge.restart_sequence is not None):
                        data = ""
                        force_submit = True
                        disable_recentrng = True
                        continue
                else:
                    if(not vars.abort and vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
                        genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"], flash=False)
                        refresh_story()
                        data = ""
                        force_submit = True
                        disable_recentrng = True
                        continue
                    genselect(genout)
                    refresh_story()
                set_aibusy(0)
                emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
                break
        else:
            # Apply input formatting & scripts before sending to tokenizer
            if(vars.actionmode == 0):
                data = applyinputformatting(data)
            vars.submission = data
            execute_inmod()
            data = vars.submission
            # Don't append submission if it's a blank/continue action
            if(data != ""):
                # Store the result in the Action log
                if(len(vars.prompt.strip()) == 0):
                    vars.prompt = data
                else:
                    vars.actions.append(data)
                    # We now need to update actions_metadata.
                    # There are two conditions:
                    # 1. This is a totally new action (user entered)
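                    # (Illustrative) shape of an actions_metadata entry, keyed by action id:
                    #   {"Selected Text": "You open the door.",
                    #    "Alternative Text": [{"Text": "...", "Pinned": False,
                    #                          "Previous Selection": False, "Edited": False}]}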
                    if vars.actions.get_last_key() not in vars.actions_metadata:
                        vars.actions_metadata[vars.actions.get_last_key()] = {"Selected Text": data, "Alternative Text": []}
                    else:
                        # 2. We've selected a chunk of text that was presented previously
                        try:
                            alternatives = [item['Text'] for item in vars.actions_metadata[len(vars.actions)-1]["Alternative Text"]]
                        except:
                            print(len(vars.actions))
                            print(vars.actions_metadata)
                            raise
                        if data in alternatives:
                            alternatives = [item for item in vars.actions_metadata[vars.actions.get_last_key()]["Alternative Text"] if item['Text'] != data]
                            vars.actions_metadata[vars.actions.get_last_key()]["Alternative Text"] = alternatives
                        vars.actions_metadata[vars.actions.get_last_key()]["Selected Text"] = data
                update_story_chunk('last')
                send_debug()

            if(not vars.noai and vars.lua_koboldbridge.generating):
                # Off to the tokenizer!
                calcsubmit(data)
                if(not vars.abort and vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0):
                    data = ""
                    force_submit = True
                    disable_recentrng = True
                    continue
                emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
                break
            else:
                for i in range(vars.numseqs):
                    vars.lua_koboldbridge.outputs[i+1] = ""
                execute_outmod()
                vars.lua_koboldbridge.regeneration_required = False
                genout = []
                for i in range(vars.numseqs):
                    genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
                    assert type(genout[-1]["generated_text"]) is str
                if(len(genout) == 1):
                    genresult(genout[0]["generated_text"])
                    if(not vars.abort and vars.lua_koboldbridge.restart_sequence is not None):
                        data = ""
                        force_submit = True
                        disable_recentrng = True
                        continue
                else:
                    if(not vars.abort and vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
                        genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
                        data = ""
                        force_submit = True
                        disable_recentrng = True
                        continue
                    genselect(genout)
                set_aibusy(0)
                emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
                break

#==================================================================#
#  Retry the last action
#==================================================================#
def actionretry(data):
    if(vars.noai):
        emit('from_server', {'cmd': 'errmsg', 'data': "Retry function unavailable in Read Only mode."})
        return
    if(vars.recentrng is not None):
        if(not vars.aibusy):
            randomGameRequest(vars.recentrng, memory=vars.recentrngm)
        return
    if actionback():
        actionsubmit("", actionmode=vars.actionmode, force_submit=True)
        send_debug()
    elif(not vars.useprompt):
        emit('from_server', {'cmd': 'errmsg', 'data': "Please enable \"Always Add Prompt\" to retry with your prompt."})

#==================================================================#
#  Remove the last action and update the game screen
#==================================================================#
def actionback():
    if(vars.aibusy):
        return
    # Remove last index of actions and refresh game screen
    if(len(vars.genseqs) == 0 and len(vars.actions) > 0):
        # Move the selected text into the alternative text of actions_metadata
        # so that this action can be redone later
        vars.actions_metadata[vars.actions.get_last_key()]['Alternative Text'] = [{'Text': vars.actions_metadata[vars.actions.get_last_key()]['Selected Text'],
                                                                                   'Pinned': False,
                                                                                   'Previous Selection': True,
                                                                                   'Edited': False}] + vars.actions_metadata[vars.actions.get_last_key()]['Alternative Text']
        vars.actions_metadata[vars.actions.get_last_key()]['Selected Text'] = ""

        last_key = vars.actions.get_last_key()
        vars.actions.pop()
        vars.recentback = True
        remove_story_chunk(last_key + 1)
        # For the redo to not get out of whack, we need to reset the max id in the actions sequence
        vars.actions.set_next_id(last_key)
        success = True
    elif(len(vars.genseqs) == 0):
        emit('from_server', {'cmd': 'errmsg', 'data': "Cannot delete the prompt."})
        success = False
    else:
        vars.genseqs = []
        success = True
    send_debug()
    return success

def actionredo():
    # First we need to find the next valid key.
    # We might have deleted text, so we don't want to show a redo for a blank chunk.
    restore_id = vars.actions.get_last_key() + 1
    if restore_id in vars.actions_metadata:
        ok_to_use = False
        while not ok_to_use:
            for item in vars.actions_metadata[restore_id]['Alternative Text']:
                if item['Previous Selection'] and item['Text'] != "":
                    ok_to_use = True
            if not ok_to_use:
                restore_id += 1
                if restore_id not in vars.actions_metadata:
                    return
            else:
                vars.actions.set_next_id(restore_id)

    if restore_id in vars.actions_metadata:
        genout = [{"generated_text": item['Text']} for item in vars.actions_metadata[restore_id]['Alternative Text'] if item["Previous Selection"]]
        if len(genout) > 0:
            genout = genout + [{"generated_text": item['Text']} for item in vars.actions_metadata[restore_id]['Alternative Text'] if item["Pinned"] and not item["Previous Selection"]]
            if len(genout) == 1:
                vars.actions_metadata[restore_id]['Alternative Text'] = [item for item in vars.actions_metadata[restore_id]['Alternative Text'] if not item["Previous Selection"]]
                genresult(genout[0]['generated_text'], flash=True, ignore_formatting=True)
            else:
                # Store sequences in memory until selection is made
                vars.genseqs = genout
                # Send sequences to UI for selection
                genout = [[item['Text'], "redo"] for item in vars.actions_metadata[restore_id]['Alternative Text'] if item["Previous Selection"]]
                emit('from_server', {'cmd': 'genseqs', 'data': genout}, broadcast=True)
    else:
        emit('from_server', {'cmd': 'popuperror', 'data': "There's nothing to redo"}, broadcast=True)
    send_debug()

#==================================================================#
#  Assemble World Info, memory and Author's Note for the submission
#==================================================================#
def calcsubmitbudgetheader(txt, **kwargs):
    # Scan for WorldInfo matches
    winfo, found_entries = checkworldinfo(txt, **kwargs)

    # Add a newline to the end of memory
    if(vars.memory != "" and vars.memory[-1] != "\n"):
        mem = vars.memory + "\n"
    else:
        mem = vars.memory

    # Build Author's Note if set
    if(vars.authornote != ""):
        anotetxt = ("\n" + vars.authornotetemplate + "\n").replace("<|>", vars.authornote)
    else:
        anotetxt = ""

    return winfo, mem, anotetxt, found_entries
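
# For example (illustrative values, not from the source): with an author's note
# template of "[Author's note: <|>]" and vars.authornote == "Write in second
# person.", calcsubmitbudgetheader() returns anotetxt as
# "\n[Author's note: Write in second person.]\n".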

def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None, budget_deduction=0):
    forceanote   = False  # In case we don't have enough actions to hit A.N. depth
    anoteadded   = False  # In case our budget runs out before we hit A.N. depth
    anotetkns    = []     # Placeholder for Author's Note tokens
    lnanote      = 0      # Placeholder for Author's Note length
    lnsp = vars.sp_length

    if("tokenizer" not in globals()):
        from transformers import GPT2TokenizerFast
        global tokenizer
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")

    # Calculate token budget
    prompttkns = tokenizer.encode(utils.encodenewlines(vars.comregex_ai.sub('', vars.prompt)), max_length=int(2e9), truncation=True)
    lnprompt   = len(prompttkns)

    memtokens = tokenizer.encode(utils.encodenewlines(mem), max_length=int(2e9), truncation=True)
    lnmem     = len(memtokens)
    if(lnmem > vars.max_length - lnsp - vars.genamt - budget_deduction):
        raise OverflowError("The memory in your story is too long. Please either write a shorter memory text or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")

    witokens = tokenizer.encode(utils.encodenewlines(winfo), max_length=int(2e9), truncation=True)
    lnwi     = len(witokens)
    if(lnmem + lnwi > vars.max_length - lnsp - vars.genamt - budget_deduction):
        raise OverflowError("The current active world info keys take up too many tokens. Please either write shorter world info, decrease World Info Depth or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")

    if(anotetxt != ""):
        anotetkns = tokenizer.encode(utils.encodenewlines(anotetxt), max_length=int(2e9), truncation=True)
        lnanote   = len(anotetkns)
        if(lnmem + lnwi + lnanote > vars.max_length - lnsp - vars.genamt - budget_deduction):
            raise OverflowError("The author's note in your story is too long. Please either write a shorter author's note or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")

    if(vars.useprompt):
        budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt - budget_deduction
    else:
        budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt - budget_deduction
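    # A worked example (hypothetical numbers): with max_length=2048, genamt=80,
    # a 20-token soft prompt, "Always Add Prompt" off, lnmem=100, lnanote=30 and
    # lnwi=50, the remaining budget is 2048 - 20 - 100 - 30 - 50 - 80 = 1768
    # tokens for story text.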

    lnsubmission = len(tokenizer.encode(utils.encodenewlines(vars.comregex_ai.sub('', submission)), max_length=int(2e9), truncation=True)) if submission is not None else 0
    maybe_lnprompt = lnprompt if vars.useprompt and actionlen > 0 else 0

    if(lnmem + lnwi + lnanote + maybe_lnprompt + lnsubmission > vars.max_length - lnsp - vars.genamt - budget_deduction):
        raise OverflowError("Your submission is too long. Please either write a shorter submission or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt. If you are using the Always Add Prompt setting, turning it off may help.")

    assert budget >= 0

    if(actionlen == 0):
        # First/Prompt action
        tokens = memtokens + witokens + anotetkns + prompttkns
        assert len(tokens) <= vars.max_length - lnsp - vars.genamt - budget_deduction
        ln = len(tokens) + lnsp
        return tokens, ln+1, ln+vars.genamt
    else:
        tokens = []

        # Check if we have the action depth to hit our A.N. depth
        if(anotetxt != "" and actionlen < vars.andepth):
            forceanote = True

        # Get most recent action tokens up to our budget
        n = 0
        for key in reversed(actions):
            chunk = vars.comregex_ai.sub('', actions[key])

            assert budget >= 0
            if(budget <= 0):
                break
            acttkns = tokenizer.encode(utils.encodenewlines(chunk), max_length=int(2e9), truncation=True)
            tknlen = len(acttkns)
            if(tknlen < budget):
                tokens = acttkns + tokens
                budget -= tknlen
            else:
                count = budget * -1
                tokens = acttkns[count:] + tokens
                budget = 0
                break

            # Inject Author's Note if we've reached the desired depth
            if(n == vars.andepth-1):
                if(anotetxt != ""):
                    tokens = anotetkns + tokens # A.N. len already taken from bdgt
                    anoteadded = True
            n += 1
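        # For example (illustrative): with vars.andepth == 3, the author's note
        # is spliced in once the three most recent actions have been prepended,
        # so it ends up three actions above the end of the context.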

        # If we're not using the prompt every time and there's still budget left,
        # add some prompt.
        if(not vars.useprompt):
            if(budget > 0):
                prompttkns = prompttkns[-budget:]
            else:
                prompttkns = []

        # Did we get to add the A.N.? If not, do it here
        if(anotetxt != ""):
            if((not anoteadded) or forceanote):
                tokens = memtokens + witokens + anotetkns + prompttkns + tokens
            else:
                tokens = memtokens + witokens + prompttkns + tokens
        else:
            # Prepend Memory, WI, and Prompt before action tokens
            tokens = memtokens + witokens + prompttkns + tokens

        # Send completed bundle to generator
        assert len(tokens) <= vars.max_length - lnsp - vars.genamt - budget_deduction
        ln = len(tokens) + lnsp
        return tokens, ln+1, ln+vars.genamt

#==================================================================#
#  Take submitted text and build the text to be given to generator
#==================================================================#
def calcsubmit(txt):
    anotetxt     = ""     # Placeholder for Author's Note text
    forceanote   = False  # In case we don't have enough actions to hit A.N. depth
    anoteadded   = False  # In case our budget runs out before we hit A.N. depth
    actionlen    = len(vars.actions)

    winfo, mem, anotetxt, found_entries = calcsubmitbudgetheader(txt)

    # For all transformers models
    if(vars.model != "InferKit"):
        subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt)
        # The same backend dispatch applies to both the first (prompt) action and
        # all later ones; calcsubmitbudget() has already handled the difference.
        if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
            generate(subtxt, min, max, found_entries=found_entries)
        elif(vars.model == "Colab"):
            sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
        elif(vars.model == "OAI"):
            oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
        elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
            tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
    # For InferKit web API
    else:
        # Check if we have the action depth to hit our A.N. depth
        if(anotetxt != "" and actionlen < vars.andepth):
            forceanote = True

        if(vars.useprompt):
            budget = vars.ikmax - len(vars.comregex_ai.sub('', vars.prompt)) - len(anotetxt) - len(mem) - len(winfo) - 1
        else:
            budget = vars.ikmax - len(anotetxt) - len(mem) - len(winfo) - 1

        subtxt = ""
        prompt = vars.comregex_ai.sub('', vars.prompt)
        n = 0
        for key in reversed(vars.actions):
            chunk = vars.actions[key]

            if(budget <= 0):
                break
            actlen = len(chunk)
            if(actlen < budget):
                subtxt = chunk + subtxt
                budget -= actlen
            else:
                count = budget * -1
                subtxt = chunk[count:] + subtxt
                budget = 0
                break

            # If we're not using the prompt every time and there's still budget left,
            # add some prompt.
            if(not vars.useprompt):
                if(budget > 0):
                    prompt = vars.comregex_ai.sub('', vars.prompt)[-budget:]
                else:
                    prompt = ""

            # Inject Author's Note if we've reached the desired depth
            if(n == vars.andepth-1):
                if(anotetxt != ""):
                    subtxt = anotetxt + subtxt # A.N. len already taken from bdgt
                    anoteadded = True
            n += 1

        # Did we get to add the A.N.? If not, do it here
        if(anotetxt != ""):
            if((not anoteadded) or forceanote):
                subtxt = mem + winfo + anotetxt + prompt + subtxt
            else:
                subtxt = mem + winfo + prompt + subtxt
        else:
            subtxt = mem + winfo + prompt + subtxt

        # Send it!
        ikrequest(subtxt)

#==================================================================#
#  Send text to generator and deal with output
#==================================================================#

def _generate(txt, minimum, maximum, found_entries):
    gen_in = torch.tensor(txt, dtype=torch.long)[None]
    if(vars.sp is not None):
        soft_tokens = torch.arange(
            model.config.vocab_size,
            model.config.vocab_size + vars.sp.shape[0],
        )
        gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
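        # Soft-prompt tokens get IDs just past the real vocabulary
        # (vocab_size .. vocab_size + sp_length - 1); the embedding layer is
        # patched elsewhere in this file to map those IDs to the soft-prompt
        # embedding matrix instead of the normal word embeddings.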

    assert gen_in.shape[-1] + vars.genamt <= vars.max_length

    if(vars.hascuda and vars.usegpu):
        gen_in = gen_in.to(vars.gpu_device)
    elif(vars.hascuda and vars.breakmodel):
        gen_in = gen_in.to(breakmodel.primary_device)
    else:
        gen_in = gen_in.to('cpu')

    model.kai_scanner_excluded_world_info = found_entries

    vars._actions = vars.actions
    vars._prompt = vars.prompt
    if(vars.dynamicscan):
        vars._actions = vars._actions.copy()

    with torch.no_grad():
        already_generated = 0
        numseqs = vars.numseqs
        while True:
            genout = generator(
                gen_in,
                do_sample=True,
                max_length=int(2e9),
                repetition_penalty=1.1,
                bad_words_ids=vars.badwordsids,
                use_cache=True,
                num_return_sequences=numseqs
            )
            already_generated += len(genout[0]) - len(gen_in[0])
            assert already_generated <= vars.genamt
            if(model.kai_scanner.halt or not model.kai_scanner.regeneration_required):
                break

            assert genout.ndim >= 2
            assert genout.shape[0] == vars.numseqs
            if(vars.lua_koboldbridge.generated_cols and vars.generated_tkns != vars.lua_koboldbridge.generated_cols):
                raise RuntimeError("Inconsistency detected between KoboldAI Python and Lua backends")
            if(already_generated != vars.generated_tkns):
                raise RuntimeError("WI scanning error")
            for r in range(vars.numseqs):
                for c in range(already_generated):
                    assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
                    genout[r][genout.shape[-1] - already_generated + c] = vars.lua_koboldbridge.generated[r+1][c+1]

            encoded = []
            for i in range(vars.numseqs):
                txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
                winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions)
                found_entries[i].update(_found_entries)
                txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
                encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
            max_length = len(max(encoded, key=len))
            encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
            genout = torch.cat(
                (
                    encoded,
                    genout[..., -already_generated:],
                ),
                dim=-1
            )
            if(vars.sp is not None):
                soft_tokens = torch.arange(
                    model.config.vocab_size,
                    model.config.vocab_size + vars.sp.shape[0],
                    device=genout.device,
                )
                genout = torch.cat((soft_tokens.tile(vars.numseqs, 1), genout), dim=-1)
            assert genout.shape[-1] + vars.genamt - already_generated <= vars.max_length
            diff = genout.shape[-1] - gen_in.shape[-1]
            minimum += diff
            maximum += diff
            gen_in = genout
            numseqs = 1

    return genout, already_generated

def generate(txt, minimum, maximum, found_entries=None):
    vars.generated_tkns = 0

    if(found_entries is None):
        found_entries = set()
    found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))

    if not vars.quiet:
        print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END))

    # Store context in memory to use it for comparison with generated content
    vars.lastctx = utils.decodenewlines(tokenizer.decode(txt))

    # Clear CUDA cache if using GPU
    if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
        gc.collect()
        torch.cuda.empty_cache()

    # Submit input text to generator
    try:
        genout, already_generated = tpool.execute(_generate, txt, minimum, maximum, found_entries)
    except Exception as e:
        if(issubclass(type(e), lupa.LuaError)):
            vars.lua_koboldbridge.obliterate_multiverse()
            vars.lua_running = False
            emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True)
            sendUSStatItems()
            print("{0}{1}{2}".format(colors.RED, "***LUA ERROR***: ", colors.END), end="", file=sys.stderr)
            print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
            print("{0}{1}{2}".format(colors.YELLOW, "Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.", colors.END), file=sys.stderr)
        else:
            emit('from_server', {'cmd': 'errmsg', 'data': 'Error occurred during generator call; please check console.'}, broadcast=True)
            print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
        set_aibusy(0)
        return

    for i in range(vars.numseqs):
        vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(genout[i, -1].item())
        vars.lua_koboldbridge.outputs[i+1] = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))

    execute_outmod()
    if(vars.lua_koboldbridge.regeneration_required):
        vars.lua_koboldbridge.regeneration_required = False
        genout = []
        for i in range(vars.numseqs):
            genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
            assert type(genout[-1]["generated_text"]) is str
    else:
        genout = [{"generated_text": utils.decodenewlines(tokenizer.decode(tokens[-already_generated:]))} for tokens in genout]

    if(len(genout) == 1):
        genresult(genout[0]["generated_text"])
    else:
        if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
            genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
        else:
            genselect(genout)

    # Clear CUDA cache again if using GPU
    if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
        del genout
        gc.collect()
        torch.cuda.empty_cache()

    set_aibusy(0)

#==================================================================#
#  Deal with a single return sequence from generate()
#==================================================================#
def genresult(genout, flash=True, ignore_formatting=False):
    if not vars.quiet:
        print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))

    # Format output before continuing
    if not ignore_formatting:
        genout = applyoutputformatting(genout)

    vars.lua_koboldbridge.feedback = genout

    if(len(genout) == 0):
        return

    # Add formatted text to Actions array and refresh the game screen
    if(len(vars.prompt.strip()) == 0):
        vars.prompt = genout
    else:
        vars.actions.append(genout)
        if vars.actions.get_last_key() not in vars.actions_metadata:
            vars.actions_metadata[vars.actions.get_last_key()] = {'Selected Text': genout, 'Alternative Text': []}
        else:
            vars.actions_metadata[vars.actions.get_last_key()]['Selected Text'] = genout
    update_story_chunk('last')
    if(flash):
        emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() + 1 if len(vars.actions) else 0}, broadcast=True)
    send_debug()

#==================================================================#
#  Send generator sequences to the UI for selection
#==================================================================#
def genselect(genout):
    i = 0
    for result in genout:
        # Apply output formatting rules to sequences
        result["generated_text"] = applyoutputformatting(result["generated_text"])
        if not vars.quiet:
            print("{0}[Result {1}]\n{2}{3}".format(colors.CYAN, i, result["generated_text"], colors.END))
        i += 1

    # Add the options to the actions metadata.
    # If we've already generated text for this action but haven't selected one,
    # we'll want to drop all non-pinned, non-previous-selection and non-edited
    # options, then add the new ones.
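    # (Illustrative) merge behaviour: if the pending entry's alternatives are
    # [pinned A, unpinned B] and the fresh batch is [C, D], the result keeps
    # [pinned A, C, D]; B is dropped because it was neither pinned, a previous
    # selection, nor edited.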
    if vars.actions.get_next_id() in vars.actions_metadata:
        if(vars.actions_metadata[vars.actions.get_next_id()]['Selected Text'] == ""):
            vars.actions_metadata[vars.actions.get_next_id()]['Alternative Text'] = [
                {"Text": item['Text'], "Pinned": item['Pinned'],
                 "Previous Selection": item["Previous Selection"],
                 "Edited": item["Edited"]}
                for item in vars.actions_metadata[vars.actions.get_next_id()]['Alternative Text']
                if item['Pinned'] or item["Previous Selection"] or item["Edited"]
            ] + [
                {"Text": text["generated_text"], "Pinned": False, "Previous Selection": False, "Edited": False}
                for text in genout
            ]
        else:
            vars.actions_metadata[vars.actions.get_next_id()] = {'Selected Text': '', 'Alternative Text': [{"Text": text["generated_text"], "Pinned": False, "Previous Selection": False, "Edited": False} for text in genout]}
    else:
        vars.actions_metadata[vars.actions.get_next_id()] = {'Selected Text': '', 'Alternative Text': [{"Text": text["generated_text"], "Pinned": False, "Previous Selection": False, "Edited": False} for text in genout]}

    genout = [{"generated_text": item['Text']} for item in vars.actions_metadata[vars.actions.get_next_id()]['Alternative Text'] if not item["Previous Selection"] and not item["Edited"]]

    # Store sequences in memory until selection is made
    vars.genseqs = genout

    genout = [[item['Text'], "pinned" if item['Pinned'] else "normal"] for item in vars.actions_metadata[vars.actions.get_next_id()]['Alternative Text'] if not item["Previous Selection"] and not item["Edited"]]

    # Send sequences to UI for selection
    emit('from_server', {'cmd': 'genseqs', 'data': genout}, broadcast=True)
    send_debug()

#==================================================================#
#  Send selected sequence to action log and refresh UI
#==================================================================#
def selectsequence(n):
    if(len(vars.genseqs) == 0):
        return
    vars.lua_koboldbridge.feedback = vars.genseqs[int(n)]["generated_text"]
    if(len(vars.lua_koboldbridge.feedback) != 0):
        vars.actions.append(vars.lua_koboldbridge.feedback)
        # We'll want to remove the option from the alternative text and put it in selected text
        vars.actions_metadata[vars.actions.get_last_key()]['Alternative Text'] = [item for item in vars.actions_metadata[vars.actions.get_last_key()]['Alternative Text'] if item['Text'] != vars.lua_koboldbridge.feedback]
        vars.actions_metadata[vars.actions.get_last_key()]['Selected Text'] = vars.lua_koboldbridge.feedback
        update_story_chunk('last')
        emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() + 1 if len(vars.actions) else 0}, broadcast=True)
    emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True)
    vars.genseqs = []

    if(vars.lua_koboldbridge.restart_sequence is not None):
        actionsubmit("", actionmode=vars.actionmode, force_submit=True, disable_recentrng=True)
    send_debug()

#==================================================================#
#  Pin/Unpin the selected sequence
#==================================================================#
def pinsequence(n):
    if n.isnumeric():
        text = vars.genseqs[int(n)]['generated_text']
        if text in [item['Text'] for item in vars.actions_metadata[vars.actions.get_next_id()]['Alternative Text']]:
            alternatives = vars.actions_metadata[vars.actions.get_next_id()]['Alternative Text']
            for i in range(len(alternatives)):
                if alternatives[i]['Text'] == text:
                    alternatives[i]['Pinned'] = not alternatives[i]['Pinned']
                    break
            vars.actions_metadata[vars.actions.get_next_id()]['Alternative Text'] = alternatives
    send_debug()

#==================================================================#
#  Send transformers-style request to ngrok/colab host
#==================================================================#
def sendtocolab(txt, min, max):
    # Log request to console
    if not vars.quiet:
        print("{0}Tokens:{1}, Txt:{2}{3}".format(colors.YELLOW, min-1, txt, colors.END))

    # Store context in memory to use it for comparison with generated content
    vars.lastctx = txt

    # Build request JSON data
    reqdata = {
        'text': txt,
        'min': min,
        'max': max,
        'rep_pen': vars.rep_pen,
        'rep_pen_slope': vars.rep_pen_slope,
        'rep_pen_range': vars.rep_pen_range,
        'temperature': vars.temp,
        'top_p': vars.top_p,
        'top_k': vars.top_k,
        'tfs': vars.tfs,
        'typical': vars.typical,
        'numseqs': vars.numseqs,
        'retfultxt': False
    }

    # Create request
    req = requests.post(
        vars.colaburl,
        json=reqdata
    )
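    # (Illustrative) response shapes handled below, inferred from the parsing code:
    #   current hosts:  {"data": {"seqs": ["seq 1 text", "seq 2 text", ...]}}
    #   outdated hosts: {"data": {"text": "full story text"}}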
    # Deal with the response
    if(req.status_code == 200):
        js = req.json()["data"]

        # Try to be backwards compatible with outdated colab
        if("text" in js):
            genout = [getnewcontent(js["text"])]
        else:
            genout = js["seqs"]

        for i in range(vars.numseqs):
            vars.lua_koboldbridge.outputs[i+1] = genout[i]

        execute_outmod()
        if(vars.lua_koboldbridge.regeneration_required):
            vars.lua_koboldbridge.regeneration_required = False
            genout = []
            for i in range(vars.numseqs):
                genout.append(vars.lua_koboldbridge.outputs[i+1])
                assert type(genout[-1]) is str

        if(len(genout) == 1):
            genresult(genout[0])
        else:
            # Convert torch output format to transformers
            seqs = []
            for seq in genout:
                seqs.append({"generated_text": seq})
            if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
                genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
            else:
                genselect(genout)

        # Format output before continuing
        #genout = applyoutputformatting(getnewcontent(genout))

        # Add formatted text to Actions array and refresh the game screen
        #vars.actions.append(genout)
        #refresh_story()
        #emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() + 1 if len(vars.actions) else 0})

        set_aibusy(0)
    else:
        errmsg = "Colab API Error: Failed to get a reply from the server. Please check the colab console."
        print("{0}{1}{2}".format(colors.RED, errmsg, colors.END))
        emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
        set_aibusy(0)

#==================================================================#
#  Send text to TPU mesh transformer backend
#==================================================================#
def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
    vars.generated_tkns = 0

    if(found_entries is None):
        found_entries = set()
    found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))

    if not vars.quiet:
        print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END))

    vars._actions = vars.actions
    vars._prompt = vars.prompt
    if(vars.dynamicscan):
        vars._actions = vars._actions.copy()

    # Submit input text to generator
    try:
        soft_tokens = tpumtjgetsofttokens()

        global past

        socketio.start_background_task(copy_current_request_context(check_for_backend_compilation))

        if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
            context = np.tile(np.uint32(txt), (vars.numseqs, 1))
            past = np.empty((vars.numseqs, 0), dtype=np.uint32)
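            # Dynamic-scan loop (as implemented below): each infer_dynamic call
            # generates until the world-info scanner requests a regeneration or
            # the backend halts; `past` accumulates the tokens generated so far,
            # and the context is rebuilt around them before the next call.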
            while(True):
                genout, n_generated, regeneration_required, halt = tpool.execute(
                    tpu_mtj_backend.infer_dynamic,
                    context,
                    gen_len=maximum-minimum+1,
                    numseqs=vars.numseqs,
                    soft_embeddings=vars.sp,
                    soft_tokens=soft_tokens,
                    excluded_world_info=found_entries,
                )
                past = np.pad(past, ((0, 0), (0, n_generated)))
                for r in range(vars.numseqs):
                    for c in range(vars.lua_koboldbridge.generated_cols):
                        assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
                        past[r, c] = vars.lua_koboldbridge.generated[r+1][c+1]

                if(vars.abort or halt or not regeneration_required):
                    break

                print("(regeneration triggered)")

                encoded = []
                for i in range(vars.numseqs):
                    txt = utils.decodenewlines(tokenizer.decode(past[i]))
                    winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions)
                    found_entries[i].update(_found_entries)
                    txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
                    encoded.append(np.array(txt, dtype=np.uint32))
                max_length = len(max(encoded, key=len))
                encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded))
                context = np.concatenate(
                    (
                        encoded,
                        past,
                    ),
                    axis=-1,
                )
        else:
            genout = tpool.execute(
                tpu_mtj_backend.infer_static,
                np.uint32(txt),
                gen_len=maximum-minimum+1,
                temp=vars.temp,
                top_p=vars.top_p,
                top_k=vars.top_k,
                tfs=vars.tfs,
                typical=vars.typical,
                numseqs=vars.numseqs,
                repetition_penalty=vars.rep_pen,
                rpslope=vars.rep_pen_slope,
                rprange=vars.rep_pen_range,
                soft_embeddings=vars.sp,
                soft_tokens=soft_tokens,
            )
            past = genout
            for i in range(vars.numseqs):
                vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist())
            vars.lua_koboldbridge.generated_cols = vars.generated_tkns = genout[0].shape[-1]
    except Exception as e:
        if(issubclass(type(e), lupa.LuaError)):
            vars.lua_koboldbridge.obliterate_multiverse()
            vars.lua_running = False
            emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True)
            sendUSStatItems()
            print("{0}{1}{2}".format(colors.RED, "***LUA ERROR***: ", colors.END), end="", file=sys.stderr)
            print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
            print("{0}{1}{2}".format(colors.YELLOW, "Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.", colors.END), file=sys.stderr)
        else:
            emit('from_server', {'cmd': 'errmsg', 'data': 'Error occurred during generator call; please check console.'}, broadcast=True)
            print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
        set_aibusy(0)
        return

    for i in range(vars.numseqs):
        vars.lua_koboldbridge.outputs[i+1] = utils.decodenewlines(tokenizer.decode(past[i]))
    genout = past

    execute_outmod()
    if(vars.lua_koboldbridge.regeneration_required):
        vars.lua_koboldbridge.regeneration_required = False
        genout = []
        for i in range(vars.numseqs):
            genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
            assert type(genout[-1]["generated_text"]) is str
    else:
        genout = [{"generated_text": utils.decodenewlines(tokenizer.decode(txt))} for txt in genout]

    if(len(genout) == 1):
        genresult(genout[0]["generated_text"])
    else:
        if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
            genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
        else:
            genselect(genout)

    set_aibusy(0)
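
# Dynamic-scan note (illustrative): infer_dynamic reports back whenever newly
# generated tokens trigger a world info entry that wasn't already in
# excluded_world_info. A hypothetical run: the model emits "...the dragon...",
# the WI key "dragon" matches, regeneration_required comes back True, the
# context is rebuilt with that entry's content, and generation resumes until
# the backend reports halt or no further regeneration is required.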

#==================================================================#
# Replaces returns and newlines with HTML breaks
#==================================================================#
def formatforhtml(txt):
    return txt.replace("\\r\\n", "<br/>").replace("\\r", "<br/>").replace("\\n", "<br/>").replace("\r\n", "<br/>").replace('\n', '<br/>').replace('\r', '<br/>').replace('</s>', '<br/>')
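
# Example (illustrative): formatforhtml("line one\nline two") returns
# "line one<br/>line two"; both literal backslash-escaped newlines (as stored
# in encoded strings) and real newline characters collapse to <br/> tags.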

#==================================================================#
# Strips submitted text from the text returned by the AI
#==================================================================#
def getnewcontent(txt):
    # If the submitted context was blank, then everything is new
    if(vars.lastctx == ""):
        return txt

    # Tokenize the last context and the generated content
    ctxtokens = tokenizer.encode(utils.encodenewlines(vars.lastctx), max_length=int(2e9), truncation=True)
    txttokens = tokenizer.encode(utils.encodenewlines(txt), max_length=int(2e9), truncation=True)
    dif = (len(txttokens) - len(ctxtokens)) * -1

    # Remove the context from the returned text
    newtokens = txttokens[dif:]

    return utils.decodenewlines(tokenizer.decode(newtokens))
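
# Example (illustrative): if the submitted context encoded to 10 tokens and
# the returned text to 14, then dif = (14 - 10) * -1 = -4 and txttokens[-4:]
# keeps just the 4 newly generated tokens, which are decoded and returned.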

#==================================================================#
# Applies chosen formatting options to text submitted to AI
#==================================================================#
def applyinputformatting(txt):
    # Add sentence spacing
    if(vars.formatoptns["frmtadsnsp"]):
        txt = utils.addsentencespacing(txt, vars)

    return txt

#==================================================================#
# Applies chosen formatting options to text returned from AI
#==================================================================#
def applyoutputformatting(txt):
    # Use standard quotes and apostrophes
    txt = utils.fixquotes(txt)

    # Adventure mode clipping of all characters after '>'
    if(vars.adventure):
        txt = vars.acregex_ai.sub('', txt)

    # Trim incomplete sentences
    if(vars.formatoptns["frmttriminc"] and not vars.chatmode):
        txt = utils.trimincompletesentence(txt)

    # Replace blank lines
    if(vars.formatoptns["frmtrmblln"] or vars.chatmode):
        txt = utils.replaceblanklines(txt)

    # Remove special characters
    if(vars.formatoptns["frmtrmspch"]):
        txt = utils.removespecialchars(txt, vars)

    # Single Line Mode
    if(vars.formatoptns["singleline"] or vars.chatmode):
        txt = utils.singlelineprocessing(txt, vars)

    return txt

#==================================================================#
# Sends the current story content to the Game Screen
#==================================================================#
def refresh_story():
    text_parts = ['<chunk n="0" id="n0" tabindex="-1">', vars.comregex_ui.sub(lambda m: '\n'.join('<comment>' + l + '</comment>' for l in m.group().split('\n')), html.escape(vars.prompt)), '</chunk>']
    for idx in vars.actions:
        item = vars.actions[idx]
        idx += 1
        item = html.escape(item)
        item = vars.comregex_ui.sub(lambda m: '\n'.join('<comment>' + l + '</comment>' for l in m.group().split('\n')), item)  # Add special formatting to comments
        item = vars.acregex_ui.sub('<action>\\1</action>', item)  # Add special formatting to adventure actions
        text_parts.extend(('<chunk n="', str(idx), '" id="n', str(idx), '" tabindex="-1">', item, '</chunk>'))
    emit('from_server', {'cmd': 'updatescreen', 'gamestarted': vars.gamestarted, 'data': formatforhtml(''.join(text_parts))}, broadcast=True)

#==================================================================#
# Signals the Game Screen to update one of the chunks
#==================================================================#
def update_story_chunk(idx: Union[int, str]):
    if idx == 'last':
        if len(vars.actions) <= 1:
            # In this case, we are better off just refreshing the whole thing as the
            # prompt might not have been shown yet (with a "Generating story..."
            # message instead).
            refresh_story()
            setgamesaved(False)
            return

        idx = (vars.actions.get_last_key() if len(vars.actions) else 0) + 1

    if idx == 0:
        text = vars.prompt
    else:
        # Actions are 0 based, but in chunks 0 is the prompt.
        # So the chunk index is one more than the corresponding action index.
        if(idx - 1 not in vars.actions):
            return
        text = vars.actions[idx - 1]

    item = html.escape(text)
    item = vars.comregex_ui.sub(lambda m: '\n'.join('<comment>' + l + '</comment>' for l in m.group().split('\n')), item)  # Add special formatting to comments
    item = vars.acregex_ui.sub('<action>\\1</action>', item)  # Add special formatting to adventure actions

    chunk_text = f'<chunk n="{idx}" id="n{idx}" tabindex="-1">{formatforhtml(item)}</chunk>'
    emit('from_server', {'cmd': 'updatechunk', 'data': {'index': idx, 'html': chunk_text}}, broadcast=True)

    setgamesaved(False)

    # If we've set the auto save flag, we'll now save the file
    if vars.autosave and (".json" in vars.savedir):
        save()
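
# Chunk numbering (illustrative): chunk 0 is always the prompt and action i
# renders as chunk i+1, so update_story_chunk(3) re-renders vars.actions[2],
# while update_story_chunk('last') resolves to the chunk of the most recently
# appended action.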

#==================================================================#
# Signals the Game Screen to remove one of the chunks
#==================================================================#
def remove_story_chunk(idx: int):
    emit('from_server', {'cmd': 'removechunk', 'data': idx}, broadcast=True)
    setgamesaved(False)

#==================================================================#
# Sends the current generator settings to the Game Menu
#==================================================================#
def refresh_settings():
    # Suppress toggle change events while loading state
    emit('from_server', {'cmd': 'allowtoggle', 'data': False}, broadcast=True)

    if(vars.model != "InferKit"):
        emit('from_server', {'cmd': 'updatetemp', 'data': vars.temp}, broadcast=True)
        emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p}, broadcast=True)
        emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
        emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
        emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True)
        emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
        emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
        emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)
        emit('from_server', {'cmd': 'updateoutlen', 'data': vars.genamt}, broadcast=True)
        emit('from_server', {'cmd': 'updatetknmax', 'data': vars.max_length}, broadcast=True)
        emit('from_server', {'cmd': 'updatenumseq', 'data': vars.numseqs}, broadcast=True)
    else:
        emit('from_server', {'cmd': 'updatetemp', 'data': vars.temp}, broadcast=True)
        emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p}, broadcast=True)
        emit('from_server', {'cmd': 'updateikgen', 'data': vars.ikgen}, broadcast=True)

    emit('from_server', {'cmd': 'updateanotedepth', 'data': vars.andepth}, broadcast=True)
    emit('from_server', {'cmd': 'updatewidepth', 'data': vars.widepth}, broadcast=True)
    emit('from_server', {'cmd': 'updateuseprompt', 'data': vars.useprompt}, broadcast=True)
    emit('from_server', {'cmd': 'updateadventure', 'data': vars.adventure}, broadcast=True)
    emit('from_server', {'cmd': 'updatechatmode', 'data': vars.chatmode}, broadcast=True)
    emit('from_server', {'cmd': 'updatedynamicscan', 'data': vars.dynamicscan}, broadcast=True)
    emit('from_server', {'cmd': 'updateautosave', 'data': vars.autosave}, broadcast=True)
    emit('from_server', {'cmd': 'updatenopromptgen', 'data': vars.nopromptgen}, broadcast=True)
    emit('from_server', {'cmd': 'updaterngpersist', 'data': vars.rngpersist}, broadcast=True)
    emit('from_server', {'cmd': 'updatenogenmod', 'data': vars.nogenmod}, broadcast=True)

    emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]}, broadcast=True)
    emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]}, broadcast=True)
    emit('from_server', {'cmd': 'updatefrmtrmspch', 'data': vars.formatoptns["frmtrmspch"]}, broadcast=True)
    emit('from_server', {'cmd': 'updatefrmtadsnsp', 'data': vars.formatoptns["frmtadsnsp"]}, broadcast=True)
    emit('from_server', {'cmd': 'updatesingleline', 'data': vars.formatoptns["singleline"]}, broadcast=True)

    # Allow toggle events again
    emit('from_server', {'cmd': 'allowtoggle', 'data': True}, broadcast=True)

#==================================================================#
# Sets the logical and display states for the AI Busy condition
#==================================================================#
def set_aibusy(state):
    if(state):
        vars.aibusy = True
        emit('from_server', {'cmd': 'setgamestate', 'data': 'wait'}, broadcast=True)
    else:
        vars.aibusy = False
        emit('from_server', {'cmd': 'setgamestate', 'data': 'ready'}, broadcast=True)

#==================================================================#
#
#==================================================================#
def editrequest(n):
    if(n == 0):
        txt = vars.prompt
    else:
        txt = vars.actions[n-1]
    vars.editln = n
    emit('from_server', {'cmd': 'setinputtext', 'data': txt}, broadcast=True)
    emit('from_server', {'cmd': 'enablesubmit', 'data': ''}, broadcast=True)

#==================================================================#
#
#==================================================================#
def editsubmit(data):
    vars.recentedit = True
    if(vars.editln == 0):
        vars.prompt = data
    else:
        vars.actions_metadata[vars.editln-1]['Alternative Text'] = vars.actions_metadata[vars.editln-1]['Alternative Text'] + [{"Text": vars.actions[vars.editln-1], "Pinned": False,
                                                                                                                                "Previous Selection": False,
                                                                                                                                "Edited": True}]
        vars.actions_metadata[vars.editln-1]['Selected Text'] = data
        vars.actions[vars.editln-1] = data

    vars.mode = "play"
    update_story_chunk(vars.editln)
    emit('from_server', {'cmd': 'texteffect', 'data': vars.editln}, broadcast=True)
    emit('from_server', {'cmd': 'editmode', 'data': 'false'})
    send_debug()

#==================================================================#
#
#==================================================================#
def deleterequest():
    vars.recentedit = True
    # Don't delete prompt
    if(vars.editln == 0):
        # Send error message
        pass
    else:
        vars.actions_metadata[vars.editln-1]['Alternative Text'] = [{"Text": vars.actions[vars.editln-1], "Pinned": False,
                                                                     "Previous Selection": True, "Edited": False}] + vars.actions_metadata[vars.editln-1]['Alternative Text']
        vars.actions_metadata[vars.editln-1]['Selected Text'] = ''
        vars.actions[vars.editln-1] = ''

        vars.mode = "play"
        remove_story_chunk(vars.editln)
        emit('from_server', {'cmd': 'editmode', 'data': 'false'})
    send_debug()

#==================================================================#
#
#==================================================================#
def inlineedit(chunk, data):
    vars.recentedit = True
    chunk = int(chunk)
    if(chunk == 0):
        if(len(data.strip()) == 0):
            return
        vars.prompt = data
    else:
        if(chunk-1 in vars.actions):
            vars.actions_metadata[chunk-1]['Alternative Text'] = vars.actions_metadata[chunk-1]['Alternative Text'] + [{"Text": vars.actions[chunk-1], "Pinned": False,
                                                                                                                        "Previous Selection": False,
                                                                                                                        "Edited": True}]
            vars.actions_metadata[chunk-1]['Selected Text'] = data
            vars.actions[chunk-1] = data
        else:
            print(f"WARNING: Attempted to edit non-existent chunk {chunk}")

    setgamesaved(False)
    update_story_chunk(chunk)
    emit('from_server', {'cmd': 'texteffect', 'data': chunk}, broadcast=True)
    emit('from_server', {'cmd': 'editmode', 'data': 'false'}, broadcast=True)
    send_debug()

#==================================================================#
#
#==================================================================#
def inlinedelete(chunk):
    vars.recentedit = True
    chunk = int(chunk)
    # Don't delete prompt
    if(chunk == 0):
        # Send error message
        update_story_chunk(chunk)
        emit('from_server', {'cmd': 'errmsg', 'data': "Cannot delete the prompt."})
        emit('from_server', {'cmd': 'editmode', 'data': 'false'}, broadcast=True)
    else:
        if(chunk-1 in vars.actions):
            vars.actions_metadata[chunk-1]['Alternative Text'] = [{"Text": vars.actions[chunk-1], "Pinned": False,
                                                                   "Previous Selection": True,
                                                                   "Edited": False}] + vars.actions_metadata[chunk-1]['Alternative Text']
            vars.actions_metadata[chunk-1]['Selected Text'] = ''
            vars.actions[chunk-1] = ''
        else:
            print(f"WARNING: Attempted to delete non-existent chunk {chunk}")

        setgamesaved(False)
        remove_story_chunk(chunk)
        emit('from_server', {'cmd': 'editmode', 'data': 'false'}, broadcast=True)
    send_debug()

#==================================================================#
# Toggles the game mode for memory editing and sends UI commands
#==================================================================#
def togglememorymode():
    if(vars.mode == "play"):
        vars.mode = "memory"
        emit('from_server', {'cmd': 'memmode', 'data': 'true'}, broadcast=True)
        emit('from_server', {'cmd': 'setinputtext', 'data': vars.memory}, broadcast=True)
        emit('from_server', {'cmd': 'setanote', 'data': vars.authornote}, broadcast=True)
        emit('from_server', {'cmd': 'setanotetemplate', 'data': vars.authornotetemplate}, broadcast=True)
    elif(vars.mode == "memory"):
        vars.mode = "play"
        emit('from_server', {'cmd': 'memmode', 'data': 'false'}, broadcast=True)

#==================================================================#
# Toggles the game mode for WI editing and sends UI commands
#==================================================================#
def togglewimode():
    if(vars.mode == "play"):
        vars.mode = "wi"
        emit('from_server', {'cmd': 'wimode', 'data': 'true'}, broadcast=True)
    elif(vars.mode == "wi"):
        # Commit WI fields first
        requestwi()
        # Then set UI state back to Play
        vars.mode = "play"
        emit('from_server', {'cmd': 'wimode', 'data': 'false'}, broadcast=True)
        sendwi()

#==================================================================#
#
#==================================================================#
def addwiitem(folder_uid=None):
    assert folder_uid is None or folder_uid in vars.wifolders_d
    ob = {"key": "", "keysecondary": "", "content": "", "comment": "", "folder": folder_uid, "num": len(vars.worldinfo), "init": False, "selective": False, "constant": False}
    vars.worldinfo.append(ob)
    while(True):
        uid = int.from_bytes(os.urandom(4), "little", signed=True)
        if(uid not in vars.worldinfo_u):
            break
    vars.worldinfo_u[uid] = vars.worldinfo[-1]
    vars.worldinfo[-1]["uid"] = uid
    if(folder_uid is not None):
        vars.wifolders_u[folder_uid].append(vars.worldinfo[-1])
    emit('from_server', {'cmd': 'addwiitem', 'data': ob}, broadcast=True)
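
# UID note (illustrative): os.urandom(4) read as a signed little-endian
# integer yields a uniform value in [-2**31, 2**31 - 1], so even with
# thousands of entries a collision is vanishingly rare and the rejection
# loop above almost never iterates more than once.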

#==================================================================#
# Creates a new WI folder with an unused cryptographically secure random UID
#==================================================================#
def addwifolder():
    while(True):
        uid = int.from_bytes(os.urandom(4), "little", signed=True)
        if(uid not in vars.wifolders_d):
            break
    ob = {"name": "", "collapsed": False}
    vars.wifolders_d[uid] = ob
    vars.wifolders_l.append(uid)
    vars.wifolders_u[uid] = []
    emit('from_server', {'cmd': 'addwifolder', 'uid': uid, 'data': ob}, broadcast=True)
    addwiitem(folder_uid=uid)

#==================================================================#
# Move the WI entry with UID src so that it immediately precedes
# the WI entry with UID dst
#==================================================================#
def movewiitem(dst, src):
    setgamesaved(False)
    if(vars.worldinfo_u[src]["folder"] is not None):
        for i, e in enumerate(vars.wifolders_u[vars.worldinfo_u[src]["folder"]]):
            if(e is vars.worldinfo_u[src]):
                vars.wifolders_u[vars.worldinfo_u[src]["folder"]].pop(i)
                break
    if(vars.worldinfo_u[dst]["folder"] is not None):
        vars.wifolders_u[vars.worldinfo_u[dst]["folder"]].append(vars.worldinfo_u[src])
    vars.worldinfo_u[src]["folder"] = vars.worldinfo_u[dst]["folder"]
    for i, e in enumerate(vars.worldinfo):
        if(e is vars.worldinfo_u[src]):
            _src = i
        elif(e is vars.worldinfo_u[dst]):
            _dst = i
    vars.worldinfo.insert(_dst - (_dst >= _src), vars.worldinfo.pop(_src))
    sendwi()
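
# Index arithmetic above (illustrative): after pop(_src), every element that
# was after _src shifts left by one, so when _dst >= _src the insertion index
# must be reduced by one. E.g. with entries [A, B, C] and src=A (index 0),
# dst=C (index 2): pop(0) leaves [B, C] and insert(2-1, A) yields [B, A, C],
# placing A immediately before C.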

#==================================================================#
# Move the WI folder with UID src so that it immediately precedes
# the WI folder with UID dst
#==================================================================#
def movewifolder(dst, src):
    setgamesaved(False)
    vars.wifolders_l.remove(src)
    if(dst is None):
        # If dst is None, that means we should move src to be the last folder
        vars.wifolders_l.append(src)
    else:
        vars.wifolders_l.insert(vars.wifolders_l.index(dst), src)
    sendwi()

#==================================================================#
#
#==================================================================#
def sendwi():
    # Cache len of WI
    ln = len(vars.worldinfo)

    # Clear contents of WI container
    emit('from_server', {'cmd': 'wistart', 'wifolders_d': vars.wifolders_d, 'wifolders_l': vars.wifolders_l, 'data': ''}, broadcast=True)

    # Stable-sort WI entries in order of folder
    stablesortwi()

    vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]

    # If there are no WI entries, send an empty WI object
    if(ln == 0):
        addwiitem()
    else:
        # Send contents of WI array
        last_folder = ...
        for wi in vars.worldinfo:
            if(wi["folder"] != last_folder):
                emit('from_server', {'cmd': 'addwifolder', 'uid': wi["folder"], 'data': vars.wifolders_d[wi["folder"]] if wi["folder"] is not None else None}, broadcast=True)
                last_folder = wi["folder"]
            ob = wi
            emit('from_server', {'cmd': 'addwiitem', 'data': ob}, broadcast=True)

    emit('from_server', {'cmd': 'wifinish', 'data': ''}, broadcast=True)

#==================================================================#
# Request current contents of all WI HTML elements
#==================================================================#
def requestwi():
    # Renamed from `list` to avoid shadowing the builtin
    wilist = []
    for wi in vars.worldinfo:
        wilist.append(wi["num"])
    emit('from_server', {'cmd': 'requestwiitem', 'data': wilist})

#==================================================================#
# Stable-sort WI items so that items in the same folder are adjacent,
# and items in different folders are sorted based on the order of the folders
#==================================================================#
def stablesortwi():
    mapping = {uid: index for index, uid in enumerate(vars.wifolders_l)}
    vars.worldinfo.sort(key=lambda x: mapping[x["folder"]] if x["folder"] is not None else float("inf"))
    last_folder = ...
    last_wi = None
    for i, wi in enumerate(vars.worldinfo):
        wi["num"] = i
        wi["init"] = True
        if(wi["folder"] != last_folder):
            if(last_wi is not None and last_folder is not ...):
                last_wi["init"] = False
            last_folder = wi["folder"]
        last_wi = wi
    if(last_wi is not None):
        last_wi["init"] = False
    for folder in vars.wifolders_u:
        vars.wifolders_u[folder].sort(key=lambda x: x["num"])
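
# Sort key sketch (illustrative): with vars.wifolders_l == [42, 7], an entry
# in folder 42 sorts with key 0, one in folder 7 with key 1, and folderless
# entries sort last with key inf; list.sort is stable, so entries keep their
# relative order within each folder. The trailing "init" bookkeeping marks
# the last (blank) entry of each folder as uninitialized.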

#==================================================================#
# Extract object from server and send it to WI objects
#==================================================================#
def commitwi(ar):
    for ob in ar:
        ob["uid"] = int(ob["uid"])
        vars.worldinfo_u[ob["uid"]]["key"] = ob["key"]
        vars.worldinfo_u[ob["uid"]]["keysecondary"] = ob["keysecondary"]
        vars.worldinfo_u[ob["uid"]]["content"] = ob["content"]
        vars.worldinfo_u[ob["uid"]]["comment"] = ob.get("comment", "")
        vars.worldinfo_u[ob["uid"]]["folder"] = ob.get("folder", None)
        vars.worldinfo_u[ob["uid"]]["selective"] = ob["selective"]
        vars.worldinfo_u[ob["uid"]]["constant"] = ob.get("constant", False)
    stablesortwi()
    vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]

#==================================================================#
#
#==================================================================#
def deletewi(uid):
    if(uid in vars.worldinfo_u):
        setgamesaved(False)
        # Store UID of deletion request
        vars.deletewi = uid
        if(vars.deletewi is not None):
            if(vars.worldinfo_u[vars.deletewi]["folder"] is not None):
                for i, e in enumerate(vars.wifolders_u[vars.worldinfo_u[vars.deletewi]["folder"]]):
                    if(e is vars.worldinfo_u[vars.deletewi]):
                        vars.wifolders_u[vars.worldinfo_u[vars.deletewi]["folder"]].pop(i)
            for i, e in enumerate(vars.worldinfo):
                if(e is vars.worldinfo_u[vars.deletewi]):
                    del vars.worldinfo[i]
                    break
            del vars.worldinfo_u[vars.deletewi]
            # Send the new WI array structure
            sendwi()
            # And reset deletewi
            vars.deletewi = None

#==================================================================#
#
#==================================================================#
def deletewifolder(uid):
    uid = int(uid)
    del vars.wifolders_u[uid]
    del vars.wifolders_d[uid]
    del vars.wifolders_l[vars.wifolders_l.index(uid)]
    setgamesaved(False)
    # Delete uninitialized entries in the folder we're going to delete
    vars.worldinfo = [wi for wi in vars.worldinfo if wi["folder"] != uid or wi["init"]]
    vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
    # Move WI entries that are inside of the folder we're going to delete
    # so that they're outside of all folders
    for wi in vars.worldinfo:
        if(wi["folder"] == uid):
            wi["folder"] = None
    sendwi()

#==================================================================#
# Look for WI keys in text to generator
#==================================================================#
def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False, scan_story=True, actions=None):
    original_txt = txt

    if(actions is None):
        actions = vars.actions

    # Don't go any further if WI is empty
    if(len(vars.worldinfo) == 0):
        return "", set()

    # Cache actions length
    ln = len(actions)

    # Don't bother calculating action history if widepth is 0
    if(vars.widepth > 0 and scan_story):
        depth = vars.widepth
        # If this is not a continue, add 1 to widepth since submitted
        # text is already in action history @ -1
        if(not force_use_txt and (txt != "" and vars.prompt != txt)):
            txt = ""
            depth += 1

        if(ln > 0):
            chunks = collections.deque()
            i = 0
            for key in reversed(actions):
                chunk = actions[key]
                chunks.appendleft(chunk)
                i += 1
                if(i == depth):
                    break

        if(ln >= depth):
            txt = "".join(chunks)
        elif(ln > 0):
            txt = vars.comregex_ai.sub('', vars.prompt) + "".join(chunks)
        elif(ln == 0):
            txt = vars.comregex_ai.sub('', vars.prompt)

    if(force_use_txt):
        txt += original_txt

    # Scan text for matches on WI keys
    wimem = ""
    found_entries = set()
    for wi in vars.worldinfo:
        if(allowed_entries is not None and wi["uid"] not in allowed_entries):
            continue
        if(allowed_folders is not None and wi["folder"] not in allowed_folders):
            continue

        if(wi.get("constant", False)):
            wimem = wimem + wi["content"] + "\n"
            found_entries.add(id(wi))
            continue

        if(len(wi["key"].strip()) > 0 and (not wi.get("selective", False) or len(wi.get("keysecondary", "").strip()) > 0)):
            # Split comma-separated keys
            keys = wi["key"].split(",")
            keys_secondary = wi.get("keysecondary", "").split(",")

            for k in keys:
                ky = k
                # Remove leading/trailing spaces if the option is enabled
                if(vars.wirmvwhtsp):
                    ky = k.strip()
                if ky in txt:
                    if wi.get("selective", False) and len(keys_secondary):
                        found = False
                        for ks in keys_secondary:
                            ksy = ks
                            if(vars.wirmvwhtsp):
                                ksy = ks.strip()
                            if ksy in txt:
                                wimem = wimem + wi["content"] + "\n"
                                found_entries.add(id(wi))
                                found = True
                                break
                        if found:
                            break
                    else:
                        wimem = wimem + wi["content"] + "\n"
                        found_entries.add(id(wi))
                        break

    return wimem, found_entries
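
# Matching semantics (illustrative): a non-selective entry fires when any of
# its comma-separated primary keys appears in the scanned text; a selective
# entry additionally requires one of its secondary keys. For example, an
# entry with key "dragon" and secondary key "cave" (selective) only injects
# its content when the recent story text contains both "dragon" and "cave".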

#==================================================================#
# Commit changes to Memory storage
#==================================================================#
def memsubmit(data):
    emit('from_server', {'cmd': 'setinputtext', 'data': data}, broadcast=True)
    # Maybe check for length at some point
    # For now just send it to storage
    if(data != vars.memory):
        setgamesaved(False)
    vars.memory = data
    vars.mode = "play"
    emit('from_server', {'cmd': 'memmode', 'data': 'false'}, broadcast=True)

    # Ask for contents of Author's Note field
    emit('from_server', {'cmd': 'getanote', 'data': ''})

#==================================================================#
# Commit changes to Author's Note
#==================================================================#
def anotesubmit(data, template=""):
    assert type(data) is str and type(template) is str
    # Maybe check for length at some point
    # For now just send it to storage
    if(data != vars.authornote):
        setgamesaved(False)
    vars.authornote = data

    if(vars.authornotetemplate != template):
        vars.setauthornotetemplate = template
        settingschanged()
    vars.authornotetemplate = template

    emit('from_server', {'cmd': 'setanote', 'data': vars.authornote}, broadcast=True)
    emit('from_server', {'cmd': 'setanotetemplate', 'data': vars.authornotetemplate}, broadcast=True)

#==================================================================#
# Assembles game data into a request to InferKit API
#==================================================================#
def ikrequest(txt):
    # Log request to console
    if not vars.quiet:
        print("{0}Len:{1}, Txt:{2}{3}".format(colors.YELLOW, len(txt), txt, colors.END))

    # Build request JSON data
    reqdata = {
        'forceNoEnd': True,
        'length': vars.ikgen,
        'prompt': {
            'isContinuation': False,
            'text': txt
        },
        'startFromBeginning': False,
        'streamResponse': False,
        'temperature': vars.temp,
        'topP': vars.top_p
    }
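
    # Example payload (illustrative, with hypothetical values) for a 60-token
    # continuation at temperature 0.5:
    #   {'forceNoEnd': True, 'length': 60,
    #    'prompt': {'isContinuation': False, 'text': '...story so far...'},
    #    'startFromBeginning': False, 'streamResponse': False,
    #    'temperature': 0.5, 'topP': 0.9}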

    # Create request
    req = requests.post(
        vars.url,
        json=reqdata,
        headers={
            'Authorization': 'Bearer ' + vars.apikey
        }
    )

    # Deal with the response
    if(req.status_code == 200):
        genout = req.json()["data"]["text"]

        vars.lua_koboldbridge.outputs[1] = genout

        execute_outmod()
        if(vars.lua_koboldbridge.regeneration_required):
            vars.lua_koboldbridge.regeneration_required = False
            genout = vars.lua_koboldbridge.outputs[1]
            assert type(genout) is str

        if not vars.quiet:
            print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))

        vars.actions.append(genout)
        if vars.actions.get_last_key() not in vars.actions_metadata:
            vars.actions_metadata[vars.actions.get_last_key()] = {"Selected Text": genout, "Alternative Text": []}
        else:
            # 2. We've selected a chunk of text that was presented previously
            alternatives = [item['Text'] for item in vars.actions_metadata[vars.actions.get_last_key()]["Alternative Text"]]
            if genout in alternatives:
                alternatives = [item for item in vars.actions_metadata[vars.actions.get_last_key()]["Alternative Text"] if item['Text'] != genout]
                vars.actions_metadata[vars.actions.get_last_key()]["Alternative Text"] = alternatives
            vars.actions_metadata[vars.actions.get_last_key()]["Selected Text"] = genout
        update_story_chunk('last')
        emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() + 1 if len(vars.actions) else 0}, broadcast=True)
        send_debug()
        set_aibusy(0)
    else:
        # Send error message to web client
        er = req.json()
        if("error" in er):
            code = er["error"]["extensions"]["code"]
        elif("errors" in er):
            code = er["errors"][0]["extensions"]["code"]

        errmsg = "InferKit API Error: {0} - {1}".format(req.status_code, code)
        emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
        set_aibusy(0)

#==================================================================#
# Assembles game data into a request to OpenAI API
#==================================================================#
def oairequest(txt, min, max):
    # Log request to console
    if not vars.quiet:
        print("{0}Len:{1}, Txt:{2}{3}".format(colors.YELLOW, len(txt), txt, colors.END))

    # Store context in memory to use it for comparison with generated content
    vars.lastctx = txt

    # Build request JSON data
    if 'GooseAI' in args.configname:
        reqdata = {
            'prompt': txt,
            'max_tokens': vars.genamt,
            'temperature': vars.temp,
            'top_p': vars.top_p,
            'top_k': vars.top_k,
            'tfs': vars.tfs,
            'typical_p': vars.typical,
            'repetition_penalty': vars.rep_pen,
            'repetition_penalty_slope': vars.rep_pen_slope,
            'repetition_penalty_range': vars.rep_pen_range,
            'n': vars.numseqs,
            'stream': False
        }
    else:
        reqdata = {
            'prompt': txt,
            'max_tokens': vars.genamt,
            'temperature': vars.temp,
            'top_p': vars.top_p,
            'n': vars.numseqs,
            'stream': False
        }
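
    # Illustrative note: the GooseAI payload above carries the extended
    # sampler settings (top_k, tfs, typical_p and the repetition penalty
    # family), while the plain OpenAI completions payload is limited to the
    # parameters that API supports; both branches share only prompt,
    # max_tokens, temperature, top_p, n and stream.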

    req = requests.post(
        vars.oaiurl,
        json=reqdata,
        headers={
            'Authorization': 'Bearer ' + vars.oaiapikey,
            'Content-Type': 'application/json'
        }
    )

    # Deal with the response
    if(req.status_code == 200):
        outputs = [out["text"] for out in req.json()["choices"]]

        for idx in range(len(outputs)):
            vars.lua_koboldbridge.outputs[idx+1] = outputs[idx]

        execute_outmod()
        if(vars.lua_koboldbridge.regeneration_required):
            vars.lua_koboldbridge.regeneration_required = False
            genout = []
            for i in range(len(outputs)):
                genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
                assert type(genout[-1]["generated_text"]) is str
        else:
            genout = [{"generated_text": utils.decodenewlines(txt)} for txt in outputs]

        if vars.actions.get_last_key() not in vars.actions_metadata:
            vars.actions_metadata[vars.actions.get_last_key()] = {"Selected Text": genout[0], "Alternative Text": []}
        else:
            # 2. We've selected a chunk of text that was presented previously
            try:
                alternatives = [item['Text'] for item in vars.actions_metadata[len(vars.actions)-1]["Alternative Text"]]
            except:
                print(len(vars.actions))
                print(vars.actions_metadata)
                raise
            if genout in alternatives:
                alternatives = [item for item in vars.actions_metadata[vars.actions.get_last_key()]["Alternative Text"] if item['Text'] != genout]
                vars.actions_metadata[vars.actions.get_last_key()]["Alternative Text"] = alternatives
            vars.actions_metadata[vars.actions.get_last_key()]["Selected Text"] = genout

        if(len(genout) == 1):
            genresult(genout[0]["generated_text"])
        else:
            if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
                genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
            else:
                genselect(genout)

        if not vars.quiet:
            print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))

        set_aibusy(0)
    else:
        # Send error message to web client
        er = req.json()
        if("error" in er):
            type = er["error"]["type"]
            message = er["error"]["message"]

        errmsg = "OpenAI API Error: {0} - {1}".format(type, message)
        emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
        set_aibusy(0)

#==================================================================#
# Forces UI to Play mode
#==================================================================#
def exitModes():
    if(vars.mode == "edit"):
        emit('from_server', {'cmd': 'editmode', 'data': 'false'}, broadcast=True)
    elif(vars.mode == "memory"):
        emit('from_server', {'cmd': 'memmode', 'data': 'false'}, broadcast=True)
    elif(vars.mode == "wi"):
        emit('from_server', {'cmd': 'wimode', 'data': 'false'}, broadcast=True)
    vars.mode = "play"

#==================================================================#
# Launch in-browser save prompt
#==================================================================#
def saveas(data):
    name = data['name']
    savepins = data['pins']
    # Check if filename exists already
    name = utils.cleanfilename(name)
    if(not fileops.saveexists(name) or (vars.saveow and vars.svowname == name)):
        # All clear to save
        e = saveRequest(fileops.storypath(name), savepins=savepins)
        vars.saveow = False
        vars.svowname = ""
        if(e is None):
            emit('from_server', {'cmd': 'hidesaveas', 'data': ''})
        else:
            print("{0}{1}{2}".format(colors.RED, str(e), colors.END))
            emit('from_server', {'cmd': 'popuperror', 'data': str(e)})
    else:
        # File exists, prompt for overwrite
        vars.saveow = True
        vars.svowname = name
        emit('from_server', {'cmd': 'askforoverwrite', 'data': ''})

#==================================================================#
# Launch in-browser story-delete prompt
#==================================================================#
def deletesave(name):
    name = utils.cleanfilename(name)
    e = fileops.deletesave(name)
    if(e is None):
        if(vars.smandelete):
            emit('from_server', {'cmd': 'hidepopupdelete', 'data': ''})
            getloadlist()
        else:
            emit('from_server', {'cmd': 'popuperror', 'data': "The server denied your request to delete this story"})
    else:
        print("{0}{1}{2}".format(colors.RED, str(e), colors.END))
        emit('from_server', {'cmd': 'popuperror', 'data': str(e)})

#==================================================================#
# Launch in-browser story-rename prompt
#==================================================================#
def renamesave(name, newname):
    # Check if filename exists already
    name = utils.cleanfilename(name)
    newname = utils.cleanfilename(newname)
    if(not fileops.saveexists(newname) or name == newname or (vars.saveow and vars.svowname == newname)):
        e = fileops.renamesave(name, newname)
        vars.saveow = False
        vars.svowname = ""
        if(e is None):
            if(vars.smanrename):
                emit('from_server', {'cmd': 'hidepopuprename', 'data': ''})
                getloadlist()
            else:
                emit('from_server', {'cmd': 'popuperror', 'data': "The server denied your request to rename this story"})
        else:
            print("{0}{1}{2}".format(colors.RED, str(e), colors.END))
            emit('from_server', {'cmd': 'popuperror', 'data': str(e)})
    else:
        # File exists, prompt for overwrite
        vars.saveow = True
        vars.svowname = newname
        emit('from_server', {'cmd': 'askforoverwrite', 'data': ''})

#==================================================================#
# Save the currently running story
#==================================================================#
def save():
    # Check if a file is currently open
    if(".json" in vars.savedir):
        saveRequest(vars.savedir)
    else:
        emit('from_server', {'cmd': 'saveas', 'data': ''})

#==================================================================#
# Save the story via file browser
#==================================================================#
def savetofile():
    savpath = fileops.getsavepath(vars.savedir, "Save Story As", [("Json", "*.json")])
    saveRequest(savpath)

#==================================================================#
# Save the story to specified path
#==================================================================#
def saveRequest(savpath, savepins=True):
    if(savpath):
        # Leave Edit/Memory mode before continuing
        exitModes()

        # Save path for future saves
        vars.savedir = savpath
        txtpath = os.path.splitext(savpath)[0] + ".txt"

        # Build json to write
        js = {}
        js["gamestarted"] = vars.gamestarted
        js["prompt"] = vars.prompt
        js["memory"] = vars.memory
        js["authorsnote"] = vars.authornote
        js["anotetemplate"] = vars.authornotetemplate
        js["actions"] = tuple(vars.actions.values())
        if savepins:
            js["actions_metadata"] = vars.actions_metadata
        js["worldinfo"] = []
        js["wifolders_d"] = vars.wifolders_d
        js["wifolders_l"] = vars.wifolders_l

        # Extract only the important bits of WI
        for wi in vars.worldinfo_i:
            js["worldinfo"].append({
                "key": wi["key"],
                "keysecondary": wi["keysecondary"],
                "content": wi["content"],
                "comment": wi["comment"],
                "folder": wi["folder"],
                "selective": wi["selective"],
                "constant": wi["constant"]
            })

        txt = vars.prompt + "".join(vars.actions.values())

        # Write it
        try:
            file = open(savpath, "w")
        except Exception as e:
            return e
        try:
            file.write(json.dumps(js, indent=3))
        except Exception as e:
            file.close()
            return e
        file.close()

        try:
            file = open(txtpath, "w")
        except Exception as e:
            return e
        try:
            file.write(txt)
        except Exception as e:
            file.close()
            return e
        file.close()

        filename = path.basename(savpath)
        if(filename.endswith('.json')):
            filename = filename[:-5]
        vars.laststory = filename
        emit('from_server', {'cmd': 'setstoryname', 'data': vars.laststory}, broadcast=True)
        setgamesaved(True)
        print("{0}Story saved to {1}!{2}".format(colors.GREEN, path.basename(savpath), colors.END))

#==================================================================#
# Show list of saved stories
#==================================================================#
def getloadlist():
    emit('from_server', {'cmd': 'buildload', 'data': fileops.getstoryfiles()})

#==================================================================#
# Show list of soft prompts
#==================================================================#
def getsplist():
    if(vars.allowsp):
        emit('from_server', {'cmd': 'buildsp', 'data': fileops.getspfiles(vars.modeldim)})

#==================================================================#
# Get list of userscripts
#==================================================================#
def getuslist():
    files = {i: v for i, v in enumerate(fileops.getusfiles())}
    loaded = []
    unloaded = []
    userscripts = set(vars.userscripts)
    for i in range(len(files)):
        if files[i]["filename"] not in userscripts:
            unloaded.append(files[i])
    files = {files[k]["filename"]: files[k] for k in files}
    userscripts = set(files.keys())
    for filename in vars.userscripts:
        if filename in userscripts:
            loaded.append(files[filename])
    return unloaded, loaded
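
# Example (illustrative): with scripts a.lua and b.lua on disk and
# vars.userscripts == ["b.lua"], getuslist() returns
# ([{... a.lua ...}], [{... b.lua ...}]); unloaded scripts come first, then
# the loaded ones in the order they appear in vars.userscripts.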

#==================================================================#
# Load a saved story via file browser
#==================================================================#
def loadfromfile():
    loadpath = fileops.getloadpath(vars.savedir, "Select Story File", [("Json", "*.json")])
    loadRequest(loadpath)

#==================================================================#
# Load a stored story from a file
#==================================================================#
def loadRequest(loadpath, filename=None):
    if(loadpath):
        # Leave Edit/Memory mode before continuing
        exitModes()

        # Read file contents into JSON object
        if(isinstance(loadpath, str)):
            with open(loadpath, "r") as file:
                js = json.load(file)
            if(filename is None):
                filename = path.basename(loadpath)
        else:
            js = loadpath
            if(filename is None):
                filename = "untitled.json"

        # Copy file contents to vars
        vars.gamestarted = js["gamestarted"]
        vars.prompt = js["prompt"]
        vars.memory = js["memory"]
        vars.worldinfo = []
        vars.worldinfo_u = {}
        vars.wifolders_d = {int(k): v for k, v in js.get("wifolders_d", {}).items()}
        vars.wifolders_l = js.get("wifolders_l", [])
        vars.wifolders_u = {uid: [] for uid in vars.wifolders_d}
        vars.lastact = ""
        vars.submission = ""
        vars.lastctx = ""

        del vars.actions
        vars.actions = structures.KoboldStoryRegister()
        actions = collections.deque(js["actions"])

        if "actions_metadata" in js:
            if type(js["actions_metadata"]) == dict:
                temp = js["actions_metadata"]
                vars.actions_metadata = {}
                # We need to redo the numbering of actions_metadata since the
                # actions list doesn't preserve its numbering on saving
                if len(temp) > 0:
                    counter = 0
                    temp = {int(k): v for k, v in temp.items()}
                    for i in range(max(temp)+1):
                        if i in temp:
                            vars.actions_metadata[counter] = temp[i]
                            counter += 1
                del temp
            else:
                # Fix if we're using the old metadata format
                vars.actions_metadata = {}
                i = 0
                for text in js['actions']:
                    vars.actions_metadata[i] = {'Selected Text': text, 'Alternative Text': []}
                    i += 1
        else:
            vars.actions_metadata = {}
            i = 0
            for text in js['actions']:
                vars.actions_metadata[i] = {'Selected Text': text, 'Alternative Text': []}
                i += 1
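
        # Renumbering sketch (illustrative): a saved metadata dict with gaps,
        # e.g. {"0": m0, "2": m2}, is compacted to consecutive keys
        # {0: m0, 1: m2} so it lines up with the freshly rebuilt actions list.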

        if(len(vars.prompt.strip()) == 0):
            while(len(actions)):
                action = actions.popleft()
                if(len(action.strip()) != 0):
                    vars.prompt = action
                    break
            else:
                vars.gamestarted = False
        if(vars.gamestarted):
            for s in actions:
                vars.actions.append(s)

        # Try not to break older save files
        if("authorsnote" in js):
            vars.authornote = js["authorsnote"]
        else:
            vars.authornote = ""
        if("anotetemplate" in js):
            vars.authornotetemplate = js["anotetemplate"]
        else:
            vars.authornotetemplate = "[Author's note: <|>]"
if ( " worldinfo " in js ) :
num = 0
for wi in js [ " worldinfo " ] :
vars . worldinfo . append ( {
" key " : wi [ " key " ] ,
" keysecondary " : wi . get ( " keysecondary " , " " ) ,
" content " : wi [ " content " ] ,
2021-12-05 05:59:28 +01:00
" comment " : wi . get ( " comment " , " " ) ,
" folder " : wi . get ( " folder " , None ) ,
2021-11-14 03:13:52 +01:00
" num " : num ,
" init " : True ,
" selective " : wi . get ( " selective " , False ) ,
2021-12-11 01:45:57 +01:00
" constant " : wi . get ( " constant " , False ) ,
" uid " : None ,
2021-11-14 03:13:52 +01:00
} )
2021-12-11 01:45:57 +01:00
while ( True ) :
uid = int . from_bytes ( os . urandom ( 4 ) , " little " , signed = True )
if ( uid not in vars . worldinfo_u ) :
break
vars . worldinfo_u [ uid ] = vars . worldinfo [ - 1 ]
vars . worldinfo [ - 1 ] [ " uid " ] = uid
2022-01-04 20:13:36 +01:00
if ( vars . worldinfo [ - 1 ] [ " folder " ] is not None ) :
vars . wifolders_u [ vars . worldinfo [ - 1 ] [ " folder " ] ] . append ( vars . worldinfo [ - 1 ] )
2021-11-14 03:13:52 +01:00
num + = 1
2021-12-05 05:59:28 +01:00
for uid in vars . wifolders_l + [ None ] :
2021-12-11 01:45:57 +01:00
vars . worldinfo . append ( { " key " : " " , " keysecondary " : " " , " content " : " " , " comment " : " " , " folder " : uid , " num " : None , " init " : False , " selective " : False , " constant " : False , " uid " : None } )
2021-12-19 00:00:06 +01:00
while ( True ) :
uid = int . from_bytes ( os . urandom ( 4 ) , " little " , signed = True )
if ( uid not in vars . worldinfo_u ) :
break
vars . worldinfo_u [ uid ] = vars . worldinfo [ - 1 ]
vars . worldinfo [ - 1 ] [ " uid " ] = uid
2022-01-04 20:13:36 +01:00
if ( vars . worldinfo [ - 1 ] [ " folder " ] is not None ) :
vars . wifolders_u [ vars . worldinfo [ - 1 ] [ " folder " ] ] . append ( vars . worldinfo [ - 1 ] )
2021-12-05 05:59:28 +01:00
stablesortwi ( )
2021-12-12 01:11:38 +01:00
vars . worldinfo_i = [ wi for wi in vars . worldinfo if wi [ " init " ] ]
2021-12-05 05:59:28 +01:00
2021-11-14 03:13:52 +01:00
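        # Illustrative only: this uid-assignment pattern repeats several times in this
        # file. A hypothetical helper (assign_wi_uid is not a real KoboldAI function)
        # capturing it might look like:
        #
        #     def assign_wi_uid(entry, worldinfo_u):
        #         # Draw random signed 32-bit ids until an unused one is found
        #         while True:
        #             uid = int.from_bytes(os.urandom(4), "little", signed=True)
        #             if uid not in worldinfo_u:
        #                 break
        #         worldinfo_u[uid] = entry
        #         entry["uid"] = uid
        #         return uid
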
        # Save path for save button
        vars.savedir = loadpath

        # Clear loadselect var
        vars.loadselect = ""

        # Refresh game screen
        _filename = filename
        if(filename.endswith('.json')):
            _filename = filename[:-5]
        vars.laststory = _filename
        emit('from_server', {'cmd': 'setstoryname', 'data': vars.laststory}, broadcast=True)
        setgamesaved(True)
        sendwi()
        emit('from_server', {'cmd': 'setmemory', 'data': vars.memory}, broadcast=True)
        emit('from_server', {'cmd': 'setanote', 'data': vars.authornote}, broadcast=True)
        emit('from_server', {'cmd': 'setanotetemplate', 'data': vars.authornotetemplate}, broadcast=True)
        refresh_story()
        emit('from_server', {'cmd': 'setgamestate', 'data': 'ready'}, broadcast=True)
        emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True)
        print("{0}Story loaded from {1}!{2}".format(colors.GREEN, filename, colors.END))
        send_debug()

#==================================================================#
#  Import an AI Dungeon game exported with Mimi's tool
#==================================================================#
def importRequest():
    importpath = fileops.getloadpath(vars.savedir, "Select AID CAT File", [("Json", "*.json")])

    if(importpath):
        # Leave Edit/Memory mode before continuing
        exitModes()

        # Read file contents into JSON object
        with open(importpath, "rb") as file:
            vars.importjs = json.load(file)

        # If a bundle file is being imported, select just the Adventures object
        if type(vars.importjs) is dict and "stories" in vars.importjs:
            vars.importjs = vars.importjs["stories"]

        # Clear Popup Contents
        emit('from_server', {'cmd': 'clearpopup', 'data': ''}, broadcast=True)

        # Initialize vars
        num = 0
        vars.importnum = -1

        # Get list of stories
        for story in vars.importjs:
            ob = {}
            ob["num"] = num
            if(story["title"] != "" and story["title"] is not None):
                ob["title"] = story["title"]
            else:
                ob["title"] = "(No Title)"
            if(story["description"] != "" and story["description"] is not None):
                ob["descr"] = story["description"]
            else:
                ob["descr"] = "(No Description)"
            if("actions" in story):
                ob["acts"] = len(story["actions"])
            elif("actionWindow" in story):
                ob["acts"] = len(story["actionWindow"])
            else:
                ob["acts"] = 0  # Safety default for exports with no action list
            emit('from_server', {'cmd': 'addimportline', 'data': ob})
            num += 1

        # Show Popup
        emit('from_server', {'cmd': 'popupshow', 'data': True})

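# Illustrative only: importRequest() expects roughly the following AID CAT export
# shape (field names taken from the parsing code above and in importgame();
# the values here are made up):
#
#     {
#       "stories": [
#         {
#           "title": "My Adventure",
#           "description": "An example export",
#           "actions": [{"text": "You enter the cave."}],
#           "memory": "...",
#           "authorsNote": "...",
#           "worldInfo": [{"keys": "cave", "entry": "A dark cave."}]
#         }
#       ]
#     }
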
#==================================================================#
#  Import an AI Dungeon game selected in popup
#==================================================================#
def importgame():
    if(vars.importnum >= 0):
        # Cache reference to selected game
        ref = vars.importjs[vars.importnum]

        # Copy game contents to vars
        vars.gamestarted = True

        # Support for different versions of export script
        if("actions" in ref):
            if(len(ref["actions"]) > 0):
                vars.prompt = ref["actions"][0]["text"]
            else:
                vars.prompt = ""
        elif("actionWindow" in ref):
            if(len(ref["actionWindow"]) > 0):
                vars.prompt = ref["actionWindow"][0]["text"]
            else:
                vars.prompt = ""
        else:
            vars.prompt = ""
        vars.memory = ref["memory"]
        vars.authornote = ref["authorsNote"] if type(ref["authorsNote"]) is str else ""
        vars.authornotetemplate = "[Author's note: <|>]"
        vars.actions = structures.KoboldStoryRegister()
        vars.actions_metadata = {}
        vars.worldinfo = []
        vars.worldinfo_i = []
        vars.worldinfo_u = {}
        vars.wifolders_d = {}
        vars.wifolders_l = []
        vars.wifolders_u = {uid: [] for uid in vars.wifolders_d}
        vars.lastact = ""
        vars.submission = ""
        vars.lastctx = ""
        # Get all actions except for prompt
        if("actions" in ref):
            if(len(ref["actions"]) > 1):
                for act in ref["actions"][1:]:
                    vars.actions.append(act["text"])
        elif("actionWindow" in ref):
            if(len(ref["actionWindow"]) > 1):
                for act in ref["actionWindow"][1:]:
                    vars.actions.append(act["text"])

        # Get just the important parts of world info
        if(ref["worldInfo"] is not None):
            if(len(ref["worldInfo"]) > 1):
                num = 0
                for wi in ref["worldInfo"]:
                    vars.worldinfo.append({
                        "key": wi["keys"],
                        "keysecondary": wi.get("keysecondary", ""),
                        "content": wi["entry"],
                        "comment": wi.get("comment", ""),
                        "folder": wi.get("folder", None),
                        "num": num,
                        "init": True,
                        "selective": wi.get("selective", False),
                        "constant": wi.get("constant", False),
                        "uid": None,
                    })
                    # Assign a random signed 32-bit uid, retrying on collision
                    while(True):
                        uid = int.from_bytes(os.urandom(4), "little", signed=True)
                        if(uid not in vars.worldinfo_u):
                            break
                    vars.worldinfo_u[uid] = vars.worldinfo[-1]
                    vars.worldinfo[-1]["uid"] = uid
                    if(vars.worldinfo[-1]["folder"] is not None):
                        vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])
                    num += 1

        # Append one blank (uninitialized) entry at the end of each folder, plus the root
        for uid in vars.wifolders_l + [None]:
            vars.worldinfo.append({"key": "", "keysecondary": "", "content": "", "comment": "", "folder": uid, "num": None, "init": False, "selective": False, "constant": False, "uid": None})
            while(True):
                uid = int.from_bytes(os.urandom(4), "little", signed=True)
                if(uid not in vars.worldinfo_u):
                    break
            vars.worldinfo_u[uid] = vars.worldinfo[-1]
            vars.worldinfo[-1]["uid"] = uid
            if(vars.worldinfo[-1]["folder"] is not None):
                vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])

        stablesortwi()
        vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
        # Clear import data
        vars.importjs = {}

        # Reset current save
        vars.savedir = getcwd() + "\\stories"

        # Refresh game screen
        vars.laststory = None
        emit('from_server', {'cmd': 'setstoryname', 'data': vars.laststory}, broadcast=True)
        setgamesaved(False)
        sendwi()
        emit('from_server', {'cmd': 'setmemory', 'data': vars.memory}, broadcast=True)
        emit('from_server', {'cmd': 'setanote', 'data': vars.authornote}, broadcast=True)
        emit('from_server', {'cmd': 'setanotetemplate', 'data': vars.authornotetemplate}, broadcast=True)
        refresh_story()
        emit('from_server', {'cmd': 'setgamestate', 'data': 'ready'}, broadcast=True)
        emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True)

#==================================================================#
#  Import an aidg.club prompt and start a new game with it
#==================================================================#
def importAidgRequest(id):
    exitModes()

    urlformat = "https://prompts.aidg.club/api/"
    req = requests.get(urlformat+id)

    if(req.status_code == 200):
        js = req.json()

        # Import game state
        vars.gamestarted = True
        vars.prompt = js["promptContent"]
        vars.memory = js["memory"]
        vars.authornote = js["authorsNote"]
        vars.authornotetemplate = "[Author's note: <|>]"
        vars.actions = structures.KoboldStoryRegister()
        vars.actions_metadata = {}
        vars.worldinfo = []
        vars.worldinfo_i = []
        vars.worldinfo_u = {}
        vars.wifolders_d = {}
        vars.wifolders_l = []
        vars.wifolders_u = {uid: [] for uid in vars.wifolders_d}
        vars.lastact = ""
        vars.submission = ""
        vars.lastctx = ""
        num = 0
        for wi in js["worldInfos"]:
            vars.worldinfo.append({
                "key": wi["keys"],
                "keysecondary": wi.get("keysecondary", ""),
                "content": wi["entry"],
                "comment": wi.get("comment", ""),
                "folder": wi.get("folder", None),
                "num": num,
                "init": True,
                "selective": wi.get("selective", False),
                "constant": wi.get("constant", False),
                "uid": None,
            })
            # Assign a random signed 32-bit uid, retrying on collision
            while(True):
                uid = int.from_bytes(os.urandom(4), "little", signed=True)
                if(uid not in vars.worldinfo_u):
                    break
            vars.worldinfo_u[uid] = vars.worldinfo[-1]
            vars.worldinfo[-1]["uid"] = uid
            if(vars.worldinfo[-1]["folder"] is not None):
                vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])
            num += 1

        # Append one blank (uninitialized) entry at the end of each folder, plus the root
        for uid in vars.wifolders_l + [None]:
            vars.worldinfo.append({"key": "", "keysecondary": "", "content": "", "comment": "", "folder": uid, "num": None, "init": False, "selective": False, "constant": False, "uid": None})
            while(True):
                uid = int.from_bytes(os.urandom(4), "little", signed=True)
                if(uid not in vars.worldinfo_u):
                    break
            vars.worldinfo_u[uid] = vars.worldinfo[-1]
            vars.worldinfo[-1]["uid"] = uid
            if(vars.worldinfo[-1]["folder"] is not None):
                vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])

        stablesortwi()
        vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
        # Reset current save
        vars.savedir = getcwd() + "\\stories"

        # Refresh game screen
        vars.laststory = None
        emit('from_server', {'cmd': 'setstoryname', 'data': vars.laststory}, broadcast=True)
        setgamesaved(False)
        sendwi()
        emit('from_server', {'cmd': 'setmemory', 'data': vars.memory}, broadcast=True)
        emit('from_server', {'cmd': 'setanote', 'data': vars.authornote}, broadcast=True)
        emit('from_server', {'cmd': 'setanotetemplate', 'data': vars.authornotetemplate}, broadcast=True)
        refresh_story()
        emit('from_server', {'cmd': 'setgamestate', 'data': 'ready'}, broadcast=True)

#==================================================================#
#  Import World Info JSON file
#==================================================================#
def wiimportrequest():
    importpath = fileops.getloadpath(vars.savedir, "Select World Info File", [("Json", "*.json")])
    if(importpath):
        with open(importpath, "rb") as file:
            js = json.load(file)
        if(len(js) > 0):
            # If the most recent WI entry is blank, remove it
            if(not vars.worldinfo[-1]["init"]):
                del vars.worldinfo[-1]
            # Now grab the new stuff
            num = len(vars.worldinfo)
            for wi in js:
                vars.worldinfo.append({
                    "key": wi["keys"],
                    "keysecondary": wi.get("keysecondary", ""),
                    "content": wi["entry"],
                    "comment": wi.get("comment", ""),
                    "folder": wi.get("folder", None),
                    "num": num,
                    "init": True,
                    "selective": wi.get("selective", False),
                    "constant": wi.get("constant", False),
                    "uid": None,
                })
                # Assign a random signed 32-bit uid, retrying on collision
                while(True):
                    uid = int.from_bytes(os.urandom(4), "little", signed=True)
                    if(uid not in vars.worldinfo_u):
                        break
                vars.worldinfo_u[uid] = vars.worldinfo[-1]
                vars.worldinfo[-1]["uid"] = uid
                if(vars.worldinfo[-1]["folder"] is not None):
                    vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])
                num += 1
            # Re-append a blank entry at the root so the UI has an empty row to edit
            for uid in [None]:
                vars.worldinfo.append({"key": "", "keysecondary": "", "content": "", "comment": "", "folder": uid, "num": None, "init": False, "selective": False, "constant": False, "uid": None})
                while(True):
                    uid = int.from_bytes(os.urandom(4), "little", signed=True)
                    if(uid not in vars.worldinfo_u):
                        break
                vars.worldinfo_u[uid] = vars.worldinfo[-1]
                vars.worldinfo[-1]["uid"] = uid
                if(vars.worldinfo[-1]["folder"] is not None):
                    vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])

        if not vars.quiet:
            print("{0}".format(vars.worldinfo[0]))

        # Refresh game screen
        setgamesaved(False)
        sendwi()
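
# Illustrative only: the World Info import above expects a JSON list shaped
# roughly like this (field names from the parsing code; the values are made up):
#
#     [
#       {
#         "keys": "dragon",
#         "entry": "Dragons are ancient, fire-breathing reptiles.",
#         "selective": false,
#         "constant": false
#       }
#     ]
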
#==================================================================#
#  Starts a new story
#==================================================================#
def newGameRequest():
    # Leave Edit/Memory mode before continuing
    exitModes()

    # Clear vars values
    vars.gamestarted = False
    vars.prompt = ""
    vars.memory = ""
    vars.actions = structures.KoboldStoryRegister()
    vars.actions_metadata = {}
    vars.authornote = ""
    vars.authornotetemplate = vars.setauthornotetemplate
    vars.worldinfo = []
    vars.worldinfo_i = []
    vars.worldinfo_u = {}
    vars.wifolders_d = {}
    vars.wifolders_l = []
    vars.lastact = ""
    vars.submission = ""
    vars.lastctx = ""

    # Reset current save
    vars.savedir = getcwd() + "\\stories"

    # Refresh game screen
    vars.laststory = None
    emit('from_server', {'cmd': 'setstoryname', 'data': vars.laststory}, broadcast=True)
    setgamesaved(True)
    sendwi()
    emit('from_server', {'cmd': 'setmemory', 'data': vars.memory}, broadcast=True)
    emit('from_server', {'cmd': 'setanote', 'data': vars.authornote}, broadcast=True)
    emit('from_server', {'cmd': 'setanotetemplate', 'data': vars.authornotetemplate}, broadcast=True)
    setStartState()

def randomGameRequest(topic, memory=""):
    if(vars.noai):
        newGameRequest()
        vars.memory = memory
        emit('from_server', {'cmd': 'setmemory', 'data': vars.memory}, broadcast=True)
        return
    vars.recentrng = topic
    vars.recentrngm = memory
    newGameRequest()
    setgamesaved(False)
    _memory = memory
    if(len(memory) > 0):
        _memory = memory.rstrip() + "\n\n"
    # Temporarily seed memory with an instruction so the model generates a prompt
    vars.memory = _memory + "You generate the following " + topic + " story concept:"
    vars.lua_koboldbridge.feedback = None
    actionsubmit("", force_submit=True, force_prompt_gen=True)
    # Restore the user's original memory once the prompt has been generated
    vars.memory = memory
    emit('from_server', {'cmd': 'setmemory', 'data': vars.memory}, broadcast=True)

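# Illustrative only: a typical call, e.g. from the UI's random-story handler,
# might look like (topic and memory values are made up):
#
#     randomGameRequest("fantasy", memory="The kingdom is at war.")
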
def final_startup():
    # Prevent tokenizer from taking extra time the first time it's used
    def __preempt_tokenizer():
        if("tokenizer" not in globals()):
            return
        utils.decodenewlines(tokenizer.decode([25678, 559]))
        tokenizer.encode(utils.encodenewlines("eunoia"))
    threading.Thread(target=__preempt_tokenizer).start()

    # Load soft prompt specified by the settings file, if applicable
    if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
        file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
        js = json.load(file)
        # Only accept a soft prompt filename that cannot escape the softprompts folder
        if(vars.allowsp and "softprompt" in js and type(js["softprompt"]) is str and all(q not in js["softprompt"] for q in ("..", ":")) and (len(js["softprompt"]) == 0 or all(js["softprompt"][0] not in q for q in ("/", "\\")))):
            spRequest(js["softprompt"])
        else:
            vars.spfilename = ""
        file.close()

    # Precompile TPU backend if required
    if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
        soft_tokens = tpumtjgetsofttokens()
        if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
            threading.Thread(
                target=tpu_mtj_backend.infer_dynamic,
                args=(np.tile(np.uint32((23403, 727, 20185)), (vars.numseqs, 1)),),
                kwargs={
                    "soft_embeddings": vars.sp,
                    "soft_tokens": soft_tokens,
                    "gen_len": 1,
                    "use_callback": False,
                    "numseqs": vars.numseqs,
                    "excluded_world_info": list(set() for _ in range(vars.numseqs)),
                },
            ).start()
        else:
            threading.Thread(
                target=tpu_mtj_backend.infer_static,
                args=(np.uint32((23403, 727, 20185)),),
                kwargs={
                    "soft_embeddings": vars.sp,
                    "soft_tokens": soft_tokens,
                    "gen_len": 1,
                    "numseqs": vars.numseqs,
                },
            ).start()
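
# Illustrative only: the warm-up threads above run a 1-token generation so the
# TPU backend compiles its inference graph before the first user request arrives.
# A hypothetical sketch of the same idea for a local Hugging Face model (warm_up
# is not a real KoboldAI function) might be:
#
#     def warm_up(model, tokenizer):
#         # One short dummy generation triggers lazy initialization/compilation
#         ids = tokenizer("warmup", return_tensors="pt").input_ids
#         model.generate(ids, max_new_tokens=1)
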
def send_debug():
    if vars.debug:
        debug_info = ""
        # Each field is collected in its own try block so a single missing
        # attribute doesn't prevent the rest of the debug info from being sent
        try:
            debug_info = "{}Newline Mode: {}\n".format(debug_info, vars.newlinemode)
        except:
            pass
        try:
            debug_info = "{}Action Length: {}\n".format(debug_info, vars.actions.get_last_key())
        except:
            pass
        try:
            debug_info = "{}Actions Metadata Length: {}\n".format(debug_info, max(vars.actions_metadata) if len(vars.actions_metadata) > 0 else 0)
        except:
            pass
        try:
            debug_info = "{}Actions: {}\n".format(debug_info, [k for k in vars.actions])
        except:
            pass
        try:
            debug_info = "{}Actions Metadata: {}\n".format(debug_info, [k for k in vars.actions_metadata])
        except:
            pass
        try:
            debug_info = "{}Last Action: {}\n".format(debug_info, vars.actions[vars.actions.get_last_key()])
        except:
            pass
        try:
            debug_info = "{}Last Metadata: {}\n".format(debug_info, vars.actions_metadata[max(vars.actions_metadata)])
        except:
            pass
        emit('from_server', {'cmd': 'debug_info', 'data': debug_info}, broadcast=True)

#==================================================================#
#  Final startup commands to launch Flask app
#==================================================================#
print("", end="", flush=True)
if __name__ == "__main__":
    print("{0}\nStarting webserver...{1}".format(colors.GREEN, colors.END), flush=True)

    general_startup()
    patch_transformers()
    #show_select_model_list()
    if vars.model == "" or vars.model is None:
        vars.model = "ReadOnly"
    load_model(initial_load=True)

    # Start Flask/SocketIO (Blocking, so this must be last method!)
    port = args.port if "port" in args and args.port is not None else 5000

    #socketio.run(app, host='0.0.0.0', port=port)
    if(vars.host):
        if(args.localtunnel):
            import subprocess, shutil
            localtunnel = subprocess.Popen([shutil.which('lt'), '-p', str(port), 'http'], stdout=subprocess.PIPE)
            attempts = 0
            while attempts < 10:
                try:
                    # Parse the public URL out of the localtunnel process output;
                    # 'cloudflare' holds the public tunnel URL regardless of provider
                    cloudflare = str(localtunnel.stdout.readline())
                    cloudflare = (re.search(r"(?P<url>https?:\/\/[^\s]+loca.lt)", cloudflare).group("url"))
                    break
                except:
                    attempts += 1
                    time.sleep(3)
                    continue
            if attempts == 10:
                print("LocalTunnel could not be created, falling back to cloudflare...")
                from flask_cloudflared import _run_cloudflared
                cloudflare = _run_cloudflared(port)
        elif(args.ngrok):
            from flask_ngrok import _run_ngrok
            cloudflare = _run_ngrok()
        elif(args.remote):
            from flask_cloudflared import _run_cloudflared
            cloudflare = _run_cloudflared(port)
        if(args.localtunnel or args.ngrok or args.remote):
            with open('cloudflare.log', 'w') as cloudflarelog:
                cloudflarelog.write("KoboldAI has finished loading and is available at the following link: " + cloudflare)
            print(format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link: " + cloudflare + format(colors.END))
        else:
            print("{0}Webserver has started, you can now connect to this machine at port {1}{2}"
                  .format(colors.GREEN, port, colors.END))
        vars.serverstarted = True
        socketio.run(app, host='0.0.0.0', port=port)
    else:
        import webbrowser
        webbrowser.open_new('http://localhost:{0}'.format(port))
        print("{0}Server started!\nYou may now connect with a browser at http://127.0.0.1:{1}/{2}"
              .format(colors.GREEN, port, colors.END))
        vars.serverstarted = True
        if args.unblock:
            socketio.run(app, port=port, host='0.0.0.0')
        else:
            socketio.run(app, port=port)
else:
    # Running under a WSGI server rather than directly as a script
    general_startup()
    patch_transformers()
    #show_select_model_list()
    if vars.model == "" or vars.model is None:
        vars.model = "ReadOnly"
    load_model(initial_load=True)
    print("{0}\nServer started in WSGI mode!{1}".format(colors.GREEN, colors.END), flush=True)