wip context viewer

Author: somebody
Date: 2022-08-25 17:59:22 -05:00
Parent: b9af9a1669
Commit: 5052b39c3f

@@ -39,6 +39,8 @@ import traceback
 import inspect
 import warnings
 import multiprocessing
+from enum import Enum
+from dataclasses import dataclass
 from collections.abc import Iterable
 from collections import OrderedDict
 from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List, Optional, Type
@@ -3970,6 +3972,19 @@ def check_for_backend_compilation():
             break
     koboldai_vars.checking = False
+
+class ContextType(Enum):
+    SOFT_PROMPT = 1
+    STORY = 2
+    WORLD_INFO = 3
+    MEMORY = 4
+    HEADER = 5
+
+@dataclass
+class ContextChunk:
+    # A labeled slice of the final prompt: value holds the chunk's tokens (or
+    # the soft-prompt data), context_type says which segment it came from.
+    value: Any
+    context_type: ContextType
+
 def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False, no_generate=False):
     # Ignore new submissions if the AI is currently busy
     if(koboldai_vars.aibusy):
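
The new types give every prompt segment a machine-readable label. For illustration, a minimal sketch of how they compose, assuming the ContextType/ContextChunk definitions above (the token ids and the per-type tally are made up, not part of this commit):

    from collections import Counter

    chunks = [
        ContextChunk([50256], ContextType.HEADER),                # token ids
        ContextChunk([1212, 318, 1029], ContextType.MEMORY),
        ContextChunk([7454, 2402, 257, 640], ContextType.STORY),
    ]

    # Tally token counts per segment type, e.g. for a viewer's summary bar.
    sizes = Counter()
    for chunk in chunks:
        sizes[chunk.context_type.name] += len(chunk.value)
    print(sizes)  # Counter({'STORY': 4, 'MEMORY': 3, 'HEADER': 1})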
@@ -4226,14 +4241,18 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
             mem = koboldai_vars.memory + "\n"
         else:
             mem = koboldai_vars.memory
     if(use_authors_note and koboldai_vars.authornote != ""):
         anotetxt = ("\n" + koboldai_vars.authornotetemplate + "\n").replace("<|>", koboldai_vars.authornote)
     else:
         anotetxt = ""
     MIN_STORY_TOKENS = 8
     story_tokens = []
     mem_tokens = []
     wi_tokens = []
+    context = []
     story_budget = lambda: koboldai_vars.max_length - koboldai_vars.sp_length - koboldai_vars.genamt - len(tokenizer._koboldai_header) - len(story_tokens) - len(mem_tokens) - len(wi_tokens)
     budget = lambda: story_budget() + MIN_STORY_TOKENS
     if budget() < 0:
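
The two lambdas implement the token accounting: story_budget() is what remains for story text after the soft prompt, the generation reservation, the backend header, and the already-encoded segments are subtracted, while budget() lets memory and world info reach MIN_STORY_TOKENS further. A worked example with made-up numbers (these are not KoboldAI defaults):

    max_length, sp_length, genamt, header_len = 2048, 20, 80, 1
    story_tokens, mem_tokens, wi_tokens = [], [0] * 100, [0] * 50

    MIN_STORY_TOKENS = 8
    story_budget = lambda: max_length - sp_length - genamt - header_len - len(story_tokens) - len(mem_tokens) - len(wi_tokens)
    budget = lambda: story_budget() + MIN_STORY_TOKENS

    print(story_budget())  # 1797 tokens still available for story text
    print(budget())        # 1805: the cap applied when truncating memory/world info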
@@ -4241,15 +4260,20 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
"msg": f"Your Max Tokens setting is too low for your current soft prompt and tokenizer to handle. It needs to be at least {koboldai_vars.max_length - budget()}.", "msg": f"Your Max Tokens setting is too low for your current soft prompt and tokenizer to handle. It needs to be at least {koboldai_vars.max_length - budget()}.",
"type": "token_overflow", "type": "token_overflow",
}}), mimetype="application/json", status=500)) }}), mimetype="application/json", status=500))
if use_memory: if use_memory:
mem_tokens = tokenizer.encode(utils.encodenewlines(mem))[-budget():] mem_tokens = tokenizer.encode(utils.encodenewlines(mem))[-budget():]
if use_world_info: if use_world_info:
world_info, _ = checkworldinfo(data, force_use_txt=True, scan_story=use_story) world_info, _ = checkworldinfo(data, force_use_txt=True, scan_story=use_story)
wi_tokens = tokenizer.encode(utils.encodenewlines(world_info))[-budget():] wi_tokens = tokenizer.encode(utils.encodenewlines(world_info))[-budget():]
if use_story: if use_story:
if koboldai_vars.useprompt: if koboldai_vars.useprompt:
story_tokens = tokenizer.encode(utils.encodenewlines(koboldai_vars.prompt))[-budget():] story_tokens = tokenizer.encode(utils.encodenewlines(koboldai_vars.prompt))[-budget():]
story_tokens = tokenizer.encode(utils.encodenewlines(data))[-story_budget():] + story_tokens story_tokens = tokenizer.encode(utils.encodenewlines(data))[-story_budget():] + story_tokens
if use_story: if use_story:
for i, action in enumerate(reversed(koboldai_vars.actions.values())): for i, action in enumerate(reversed(koboldai_vars.actions.values())):
if story_budget() <= 0: if story_budget() <= 0:
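
Each segment is encoded and then tail-truncated with a negative slice, so when space runs out it is the oldest tokens that get dropped. A toy illustration (made-up token ids):

    tokens = [101, 102, 103, 104, 105, 106]
    budget = 4
    print(tokens[-budget:])  # [103, 104, 105, 106]: the most recent tokens survive

    # Edge case worth knowing: a budget of 0 keeps everything, because
    # tokens[-0:] is the same slice as tokens[0:].
    print(tokens[-0:])       # [101, 102, 103, 104, 105, 106]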
@@ -4260,6 +4284,23 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
                 story_tokens = tokenizer.encode(utils.encodenewlines(anotetxt))[-story_budget():] + story_tokens
     if not koboldai_vars.useprompt:
         story_tokens = tokenizer.encode(utils.encodenewlines(koboldai_vars.prompt))[-budget():] + story_tokens
+
+    # Context tracker: record each segment that made it into the prompt so the
+    # UI can show where the token budget went.
+    if koboldai_vars.sp:
+        context.append(ContextChunk(koboldai_vars.sp, ContextType.SOFT_PROMPT))
+    if tokenizer._koboldai_header:
+        context.append(ContextChunk(tokenizer._koboldai_header, ContextType.HEADER))
+    if mem_tokens:
+        context.append(ContextChunk(mem_tokens, ContextType.MEMORY))
+    if wi_tokens:
+        context.append(ContextChunk(wi_tokens, ContextType.WORLD_INFO))
+    if story_tokens:
+        context.append(ContextChunk(story_tokens, ContextType.STORY))
+
     tokens = tokenizer._koboldai_header + mem_tokens + wi_tokens + story_tokens
     assert story_budget() >= 0
     minimum = len(tokens) + 1
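
With the chunks recorded, a viewer endpoint can report how the final prompt was assembled. A hypothetical sketch of such a serializer, not part of this commit; serialize_context and its JSON shape are illustrative only, and it assumes list-of-token-id chunk values plus a tokenizer with a decode() method:

    import json

    def serialize_context(context, tokenizer):
        # Hypothetical helper: one JSON entry per tracked chunk.
        payload = []
        for chunk in context:
            entry = {"type": chunk.context_type.name, "tokens": len(chunk.value)}
            if isinstance(chunk.value, list):
                # Token-id chunks can be decoded back to text for display;
                # the soft-prompt chunk is skipped since it is not token ids.
                entry["text"] = tokenizer.decode(chunk.value)
            payload.append(entry)
        return json.dumps(payload)

A front end could render this per-type breakdown to show which segments consumed the token budget.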
@@ -4270,6 +4311,8 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
     elif(koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
         genout = apiactionsubmit_tpumtjgenerate(tokens, minimum, maximum)
+
+    koboldai_vars.context = context
+    print(context)
+
     return genout
 #==================================================================#