wip context viewer

This commit is contained in:
somebody
2022-08-25 17:59:22 -05:00
parent b9af9a1669
commit 5052b39c3f

View File

@@ -39,6 +39,8 @@ import traceback
import inspect
import warnings
import multiprocessing
from enum import Enum
from dataclasses import dataclass
from collections.abc import Iterable
from collections import OrderedDict
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List, Optional, Type
@@ -3970,6 +3972,19 @@ def check_for_backend_compilation():
break
koboldai_vars.checking = False
class ContextType(Enum):
    """Label for the part of the prompt a context chunk was taken from."""

    # Soft prompt (recorded from koboldai_vars.sp).
    SOFT_PROMPT = 1
    # Story tokens.
    STORY = 2
    # World info tokens.
    WORLD_INFO = 3
    # Memory tokens.
    MEMORY = 4
    # Tokenizer header (recorded from tokenizer._koboldai_header).
    HEADER = 5
@dataclass
class ContextChunk:
    """One labelled segment of the context assembled for a generation request.

    The attributes are declared as dataclass fields (rather than assigned in a
    hand-written ``__init__``) so that the generated ``__init__``, ``__repr__``
    and ``__eq__`` actually include them. Without declared fields ``@dataclass``
    sees an empty field list: ``repr()`` prints ``ContextChunk()`` and every
    pair of instances compares equal.

    Construct as ``ContextChunk(value, context_type)``; both arguments may also
    be passed by keyword.
    """

    # Payload of the segment. Callers pass token lists, the tokenizer header
    # and the soft prompt here, so no narrower type is enforced.
    # NOTE(review): equality is now field-based; if `value` can be an
    # array-like whose `==` does not return a plain bool, comparing two
    # chunks may misbehave -- confirm against how the soft prompt is stored.
    value: Any
    # Which part of the prompt this segment came from. Quoted so that building
    # this class does not depend on ContextType already being defined.
    context_type: "ContextType"
def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False, no_generate=False):
# Ignore new submissions if the AI is currently busy
if(koboldai_vars.aibusy):
@@ -4226,14 +4241,18 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
mem = koboldai_vars.memory + "\n"
else:
mem = koboldai_vars.memory
if(use_authors_note and koboldai_vars.authornote != ""):
anotetxt = ("\n" + koboldai_vars.authornotetemplate + "\n").replace("<|>", koboldai_vars.authornote)
else:
anotetxt = ""
MIN_STORY_TOKENS = 8
story_tokens = []
mem_tokens = []
wi_tokens = []
context = []
story_budget = lambda: koboldai_vars.max_length - koboldai_vars.sp_length - koboldai_vars.genamt - len(tokenizer._koboldai_header) - len(story_tokens) - len(mem_tokens) - len(wi_tokens)
budget = lambda: story_budget() + MIN_STORY_TOKENS
if budget() < 0:
@@ -4241,15 +4260,20 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
"msg": f"Your Max Tokens setting is too low for your current soft prompt and tokenizer to handle. It needs to be at least {koboldai_vars.max_length - budget()}.",
"type": "token_overflow",
}}), mimetype="application/json", status=500))
if use_memory:
mem_tokens = tokenizer.encode(utils.encodenewlines(mem))[-budget():]
if use_world_info:
world_info, _ = checkworldinfo(data, force_use_txt=True, scan_story=use_story)
wi_tokens = tokenizer.encode(utils.encodenewlines(world_info))[-budget():]
if use_story:
if koboldai_vars.useprompt:
story_tokens = tokenizer.encode(utils.encodenewlines(koboldai_vars.prompt))[-budget():]
story_tokens = tokenizer.encode(utils.encodenewlines(data))[-story_budget():] + story_tokens
if use_story:
for i, action in enumerate(reversed(koboldai_vars.actions.values())):
if story_budget() <= 0:
@@ -4260,6 +4284,23 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
story_tokens = tokenizer.encode(utils.encodenewlines(anotetxt))[-story_budget():] + story_tokens
if not koboldai_vars.useprompt:
story_tokens = tokenizer.encode(utils.encodenewlines(koboldai_vars.prompt))[-budget():] + story_tokens
# Context tracker: record each non-empty prompt segment, tagged with its ContextType
if koboldai_vars.sp:
context.append(ContextChunk(koboldai_vars.sp, ContextType.SOFT_PROMPT))
if tokenizer._koboldai_header:
context.append(ContextChunk(tokenizer._koboldai_header, ContextType.HEADER))
if mem_tokens:
context.append(ContextChunk(mem_tokens, ContextType.MEMORY))
if wi_tokens:
context.append(ContextChunk(wi_tokens, ContextType.WORLD_INFO))
if story_tokens:
context.append(ContextChunk(story_tokens, ContextType.STORY))
tokens = tokenizer._koboldai_header + mem_tokens + wi_tokens + story_tokens
assert story_budget() >= 0
minimum = len(tokens) + 1
@@ -4270,6 +4311,8 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
elif(koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
genout = apiactionsubmit_tpumtjgenerate(tokens, minimum, maximum)
koboldai_vars.context = context
print(context)
return genout
#==================================================================#