Merge pull request #62 from VE-FORBRYDERNE/indicator

Show message when TPU backend is compiling
commit 260f4ffae0
henk717, 2022-01-17 04:02:58 +01:00, committed by GitHub
4 changed files with 54 additions and 6 deletions

aiserver.py

@@ -14,6 +14,7 @@ os.environ['EVENTLET_THREADPOOL_SIZE'] = '50'
 from eventlet import tpool
 from os import path, getcwd
+import time
 import re
 import json
 import collections
@@ -127,6 +128,8 @@ class vars:
     lua_edited  = set()   # Set of chunk numbers that were edited from a Lua generation modifier
     lua_deleted = set()   # Set of chunk numbers that were deleted from a Lua generation modifier
     generated_tkns = 0    # If using a backend that supports Lua generation modifiers, how many tokens have already been generated, otherwise 0
+    compiling   = False   # If using a TPU Colab, this will be set to True when the TPU backend starts compiling and then set to False again
+    checking    = False   # Whether or not we are actively checking to see if TPU backend is compiling or not
     spfilename  = ""      # Filename of soft prompt to load, or an empty string if not using a soft prompt
     userscripts = []      # List of userscripts to load
     last_userscripts = [] # List of previous userscript filenames from the previous time userscripts were sent via usstatitems
@@ -638,7 +641,7 @@ log.setLevel(logging.ERROR)
 # Start flask & SocketIO
 print("{0}Initializing Flask... {1}".format(colors.PURPLE, colors.END), end="")
-from flask import Flask, render_template, Response, request
+from flask import Flask, render_template, Response, request, copy_current_request_context
 from flask_socketio import SocketIO, emit
 app = Flask(__name__)
 app.config['SECRET KEY'] = 'secret!'
@@ -1054,6 +1057,13 @@ else:
                 break
         return excluded_world_info, regeneration_required, halt
 
+    def tpumtjgenerate_compiling_callback() -> None:
+        print(colors.GREEN + "TPU backend compilation triggered" + colors.END)
+        vars.compiling = True
+
+    def tpumtjgenerate_stopped_compiling_callback() -> None:
+        vars.compiling = False
+
     # If we're running Colab or OAI, we still need a tokenizer.
     if(vars.model == "Colab"):
         from transformers import GPT2TokenizerFast
@@ -1068,6 +1078,8 @@ else:
         import tpu_mtj_backend
         tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
         tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
+        tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback
+        tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback
         tpu_mtj_backend.load_model(vars.custmodpth)
         vars.allowsp = True
         vars.modeldim = int(tpu_mtj_backend.params["d_model"])
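
This wiring works because tpu_mtj_backend exposes its callbacks as plain module-level attributes with do-nothing (or NotImplementedError) defaults, which the host application overrides by simple assignment before load_model() runs. A minimal self-contained sketch of that injection pattern; the module and function names here are hypothetical, not from the commit:

    import types

    # Stand-in for a backend module that ships a no-op hook:
    backend = types.ModuleType("backend_sketch")
    backend.compiling_callback = lambda: None

    def on_compiling() -> None:
        print("backend started compiling")

    # The host installs its hook by plain attribute assignment...
    backend.compiling_callback = on_compiling
    # ...and the backend later invokes whatever is currently installed:
    backend.compiling_callback()  # prints "backend started compiling"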
@@ -1645,6 +1657,7 @@ def execute_genmod():
     vars.lua_koboldbridge.execute_genmod()
 
 def execute_outmod():
+    emit('from_server', {'cmd': 'hidemsg', 'data': ''}, broadcast=True)
     try:
         tpool.execute(vars.lua_koboldbridge.execute_outmod)
     except lupa.LuaError as e:
@@ -2251,6 +2264,18 @@ def settingschanged():
 #==================================================================#
 #  Take input text from SocketIO and decide what to do with it
 #==================================================================#
+
+def check_for_backend_compilation():
+    if(vars.checking):
+        return
+    vars.checking = True
+    for _ in range(31):
+        time.sleep(0.06276680299820175)
+        if(vars.compiling):
+            emit('from_server', {'cmd': 'warnmsg', 'data': 'Compiling TPU backend—this usually takes 1–2 minutes...'}, broadcast=True)
+            break
+    vars.checking = False
 
 def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False):
     # Ignore new submissions if the AI is currently busy
     if(vars.aibusy):
@@ -2966,6 +2991,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
     global past
 
+    socketio.start_background_task(copy_current_request_context(check_for_backend_compilation))
+
     if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
         context = np.tile(np.uint32(txt), (vars.numseqs, 1))
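
check_for_backend_compilation() runs as a SocketIO background task, which is why tpumtjgenerate() wraps it in copy_current_request_context: emit() needs the request context of the event that triggered generation, and start_background_task() does not carry that context over by itself. A self-contained sketch of the same pattern; the event name and the compiling flag are hypothetical, and socketio.sleep is used here as the portable equivalent of the time.sleep call above:

    from flask import Flask, copy_current_request_context
    from flask_socketio import SocketIO, emit

    app = Flask(__name__)
    socketio = SocketIO(app)
    compiling = False  # would be flipped by the backend's compiling callback

    @socketio.on('generate')
    def on_generate(data):
        @copy_current_request_context
        def check():
            # Poll for roughly two seconds (31 sleeps of ~0.0628 s, as above),
            # warn every client once if compilation starts, then stop checking.
            for _ in range(31):
                socketio.sleep(0.0628)
                if compiling:
                    emit('from_server', {'cmd': 'warnmsg', 'data': 'Compiling...'}, broadcast=True)
                    break
        socketio.start_background_task(check)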

static/application.js

@@ -663,9 +663,9 @@ function showMessage(msg) {
     message_text.html(msg);
 }
 
-function errMessage(msg) {
+function errMessage(msg, type="error") {
     message_text.removeClass();
-    message_text.addClass("color_red");
+    message_text.addClass(type == "warn" ? "color_orange" : "color_red");
     message_text.html(msg);
 }
@@ -1932,7 +1932,12 @@ $(document).ready(function(){
             }
         } else if(msg.cmd == "errmsg") {
             // Send error message
-            errMessage(msg.data);
+            errMessage(msg.data, "error");
+        } else if(msg.cmd == "warnmsg") {
+            // Send warning message
+            errMessage(msg.data, "warn");
+        } else if(msg.cmd == "hidemsg") {
+            hideMessage();
         } else if(msg.cmd == "texteffect") {
             // Apply color highlight to line of text
             newTextHighlight($("#n"+msg.data))

templates/index.html

@@ -17,7 +17,7 @@
         <script src="static/bootstrap.min.js"></script>
         <script src="static/bootstrap-toggle.min.js"></script>
         <script src="static/rangy-core.min.js"></script>
-        <script src="static/application.js?ver=1.16.4w"></script>
+        <script src="static/application.js?ver=1.16.4y"></script>
     </head>
     <body>
         <input type="file" id="remote-save-select" accept="application/json" style="display:none">

tpu_mtj_backend.py

@@ -26,6 +26,15 @@ def warper_callback(logits) -> np.array:
 def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
     raise NotImplementedError("`tpu_mtj_backend.stopping_callback()` needs to be defined")
 
+def started_compiling_callback() -> None:
+    pass
+
+def stopped_compiling_callback() -> None:
+    pass
+
+def compiling_callback() -> None:
+    pass
+
 def show_spinner():
     bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='')])
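
The compiling_callback() hook works because the functions it is placed in (generate_static and generate_initial below) are later transformed by JAX: a plain Python call in their bodies executes while JAX traces, and therefore compiles, the function, but not on later calls served from the compilation cache. That is what makes a bare function call usable as a "compilation started" signal. A minimal illustration with jax.jit; this sketch is not from the commit:

    import jax

    def compiling_callback():
        print("compilation triggered")

    @jax.jit
    def f(x):
        compiling_callback()  # executes only during tracing/compilation
        return x * 2

    f(1.0)                # first call: traced and compiled, message prints
    f(2.0)                # same shape and dtype: cached executable, no message
    f(jax.numpy.ones(3))  # new shape: retraced and recompiled, prints again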
@@ -358,6 +367,7 @@ class PenalizingCausalTransformer(CausalTransformer):
         # Initialize
         super().__init__(config)
         def generate_static(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
+            compiling_callback()
             numseqs = numseqs_aux.shape[0]
             # These are the tokens that we don't want the AI to ever write
             self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146])
@@ -452,6 +462,7 @@ class PenalizingCausalTransformer(CausalTransformer):
                 axis_resources={'shard': 'mp', 'batch': 'dp'},
             )
         def generate_initial(state, key, ctx, ctx_length, numseqs_aux, soft_embeddings=None):
+            compiling_callback()
             numseqs = numseqs_aux.shape[0]
             @hk.transform
             def generate_initial_inner(context, ctx_length):
@@ -552,6 +563,7 @@ class PenalizingCausalTransformer(CausalTransformer):
         n_generated = 0
         regeneration_required = False
         halt = False
+        started_compiling_callback()
         generate_data, sample_key = self.generate_initial_xmap(self.state, jnp.array(key.take(batch_size)), ctx, ctx_length, numseqs_aux, soft_embeddings)
         sample_key = np.asarray(sample_key[0, 0])
         while True:
@@ -574,13 +586,15 @@ class PenalizingCausalTransformer(CausalTransformer):
                     break
             else:
                 break
+        stopped_compiling_callback()
         return sample_data, n_generated, regeneration_required, halt
 
     def generate_static(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None):
         assert not return_logits
        key = hk.PRNGSequence(random.randint(0, 2 ** 60))
         batch_size = ctx.shape[0]
         self.batch_size = batch_size
-        return self.generate_static_xmap(
+        started_compiling_callback()
+        result = self.generate_static_xmap(
             self.state,
             jnp.array(key.take(batch_size)),
             ctx,
@@ -590,6 +604,8 @@ class PenalizingCausalTransformer(CausalTransformer):
             sampler_options,
             soft_embeddings,
         )
+        stopped_compiling_callback()
+        return result
 
     def infer_dynamic(
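
Both call sites bracket the potentially-compiling xmap call between started_compiling_callback() and stopped_compiling_callback() as straight-line code, so an exception raised in between would leave the "stopped" notification unsent. A generic sketch of the same bracketing with try/finally added; the helper is hypothetical and not part of the commit, and it assumes the module-level callbacks defined near the top of this file:

    from typing import Any, Callable

    def bracket_compiling(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        # Hypothetical helper: same bracketing as generate_static above,
        # but the finally clause guarantees the "stopped" notification
        # fires whether fn returns normally or raises.
        started_compiling_callback()
        try:
            return fn(*args, **kwargs)
        finally:
            stopped_compiling_callback()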