From ce2e4e1f9ecb9f4c8077ad95a12dff9b9b5da434 Mon Sep 17 00:00:00 2001 From: KoboldAI Dev Date: Sun, 16 May 2021 14:53:19 -0400 Subject: [PATCH] Switched aidg.club import from HTML scrape to API call Added square bracket to bad_words_ids to help suppress AN tag from leaking into generator output Added version number to CSS/JS ref to address browser loading outdated versions from cache --- aiserver.py | 49 ++++++++++++++++++++++---------------------- templates/index.html | 4 ++-- 2 files changed, 26 insertions(+), 27 deletions(-) diff --git a/aiserver.py b/aiserver.py index af732bb3..199a9582 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1,6 +1,6 @@ #==================================================================# # KoboldAI Client -# Version: Dev-0.1 +# Version: 1.14.0 # By: KoboldAIDev #==================================================================# @@ -680,6 +680,11 @@ def generate(txt, min, max): if(vars.hascuda and vars.usegpu): torch.cuda.empty_cache() + # Suppress Author's Note by flagging square brackets + bad_words = [] + bad_words.append(tokenizer("[", add_prefix_space=True).input_ids) + bad_words.append(tokenizer("[", add_prefix_space=False).input_ids) + # Submit input text to generator genout = generator( txt, @@ -688,7 +693,8 @@ def generate(txt, min, max): max_length=max, repetition_penalty=vars.rep_pen, top_p=vars.top_p, - temperature=vars.temp + temperature=vars.temp, + bad_words_ids=bad_words )[0]["generated_text"] print("{0}{1}{2}".format(colors.CYAN, genout, colors.END)) @@ -1284,39 +1290,32 @@ def importgame(): #==================================================================# # Import an aidg.club prompt and start a new game with it. #==================================================================# -def importAidgRequest(id): - import html - import re - +def importAidgRequest(id): exitModes() - urlformat = "https://prompts.aidg.club/" + urlformat = "https://prompts.aidg.club/api/" req = requests.get(urlformat+id) if(req.status_code == 200): - contents = html.unescape(req.text) - title = re.search("

(.*?)

", contents, re.IGNORECASE | re.MULTILINE | re.DOTALL).group(1).strip() + js = req.json() - keys = re.findall("
(.*?)
", contents, re.IGNORECASE | re.MULTILINE | re.DOTALL) - contents = re.findall("(.*?)", contents, re.IGNORECASE | re.MULTILINE | re.DOTALL) - - # Initialize game state + # Import game state vars.gamestarted = True - vars.prompt = "" - vars.memory = "" - vars.authornote = "" + vars.prompt = js["promptContent"] + vars.memory = js["memory"] + vars.authornote = js["authorsNote"] vars.actions = [] vars.worldinfo = [] - for i in range(len(keys)): - if(keys[i] == "Description"): - pass - elif(keys[i] == "Prompt"): - vars.prompt = contents[i] - elif(keys[i] == "Memory"): - vars.memory = contents[i] - elif(keys[i] == "Author's Note"): - vars.authornote = contents[i] + num = 0 + for wi in js["worldInfos"]: + vars.worldinfo.append({ + "key": wi["keys"], + "content": wi["entry"], + "num": num, + "init": True + }) + num += 1 # Refresh game screen sendwi() diff --git a/templates/index.html b/templates/index.html index 3764e8db..b7116126 100644 --- a/templates/index.html +++ b/templates/index.html @@ -6,13 +6,13 @@ - + - +