From 5052b39c3fd9be7c6b23b635416d481e464bc100 Mon Sep 17 00:00:00 2001
From: somebody
Date: Thu, 25 Aug 2022 17:59:22 -0500
Subject: [PATCH] wip context viewer

---
 aiserver.py | 43 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 43 insertions(+)

diff --git a/aiserver.py b/aiserver.py
index 3ff03751..5ba01b20 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -39,6 +39,8 @@ import traceback
 import inspect
 import warnings
 import multiprocessing
+from enum import Enum
+from dataclasses import dataclass
 from collections.abc import Iterable
 from collections import OrderedDict
 from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List, Optional, Type
@@ -3970,6 +3972,19 @@ def check_for_backend_compilation():
             break
     koboldai_vars.checking = False
 
+class ContextType(Enum):
+    SOFT_PROMPT = 1
+    STORY = 2
+    WORLD_INFO = 3
+    MEMORY = 4
+    HEADER = 5
+
+@dataclass
+class ContextChunk:
+    def __init__(self, value, context_type: ContextType) -> None:
+        self.value = value
+        self.context_type = context_type
+
 def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False, no_generate=False):
     # Ignore new submissions if the AI is currently busy
     if(koboldai_vars.aibusy):
@@ -4226,14 +4241,18 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
         mem = koboldai_vars.memory + "\n"
     else:
         mem = koboldai_vars.memory
+
     if(use_authors_note and koboldai_vars.authornote != ""):
         anotetxt = ("\n" + koboldai_vars.authornotetemplate + "\n").replace("<|>", koboldai_vars.authornote)
     else:
         anotetxt = ""
+
     MIN_STORY_TOKENS = 8
     story_tokens = []
     mem_tokens = []
     wi_tokens = []
+    context = []
+
     story_budget = lambda: koboldai_vars.max_length - koboldai_vars.sp_length - koboldai_vars.genamt - len(tokenizer._koboldai_header) - len(story_tokens) - len(mem_tokens) - len(wi_tokens)
     budget = lambda: story_budget() + MIN_STORY_TOKENS
     if budget() < 0:
@@ -4241,15 +4260,20 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
             "msg": f"Your Max Tokens setting is too low for your current soft prompt and tokenizer to handle. It needs to be at least {koboldai_vars.max_length - budget()}.",
             "type": "token_overflow",
         }}), mimetype="application/json", status=500))
+
     if use_memory:
         mem_tokens = tokenizer.encode(utils.encodenewlines(mem))[-budget():]
+
     if use_world_info:
         world_info, _ = checkworldinfo(data, force_use_txt=True, scan_story=use_story)
         wi_tokens = tokenizer.encode(utils.encodenewlines(world_info))[-budget():]
+
     if use_story:
         if koboldai_vars.useprompt:
             story_tokens = tokenizer.encode(utils.encodenewlines(koboldai_vars.prompt))[-budget():]
+
     story_tokens = tokenizer.encode(utils.encodenewlines(data))[-story_budget():] + story_tokens
+
     if use_story:
         for i, action in enumerate(reversed(koboldai_vars.actions.values())):
             if story_budget() <= 0:
@@ -4260,6 +4284,23 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
                 story_tokens = tokenizer.encode(utils.encodenewlines(anotetxt))[-story_budget():] + story_tokens
         if not koboldai_vars.useprompt:
             story_tokens = tokenizer.encode(utils.encodenewlines(koboldai_vars.prompt))[-budget():] + story_tokens
+
+    # Context tracker
+    if koboldai_vars.sp:
+        context.append(ContextChunk(koboldai_vars.sp, ContextType.SOFT_PROMPT))
+
+    if tokenizer._koboldai_header:
+        context.append(ContextChunk(tokenizer._koboldai_header, ContextType.HEADER))
+
+    if mem_tokens:
+        context.append(ContextChunk(mem_tokens, ContextType.MEMORY))
+
+    if wi_tokens:
+        context.append(ContextChunk(wi_tokens, ContextType.WORLD_INFO))
+
+    if story_tokens:
+        context.append(ContextChunk(story_tokens, ContextType.STORY))
+
     tokens = tokenizer._koboldai_header + mem_tokens + wi_tokens + story_tokens
     assert story_budget() >= 0
     minimum = len(tokens) + 1
@@ -4270,6 +4311,8 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
     elif(koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
         genout = apiactionsubmit_tpumtjgenerate(tokens, minimum, maximum)
 
+    koboldai_vars.context = context
+    print(context)
     return genout
 
 #==================================================================#
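Reviewer note, not part of the patch: because ContextChunk hand-writes
__init__ instead of declaring annotated fields, @dataclass sees an empty
field list, so the generated __repr__ prints a bare ContextChunk() and the
debug print(context) at the end of apiactionsubmit shows no chunk contents.
A minimal standalone sketch of the field-declaration form, runnable outside
aiserver.py (the token IDs are hypothetical stand-ins; the tokenizer and
koboldai_vars plumbing is omitted):

    from dataclasses import dataclass
    from enum import Enum
    from typing import Any

    class ContextType(Enum):
        SOFT_PROMPT = 1
        STORY = 2
        WORLD_INFO = 3
        MEMORY = 4
        HEADER = 5

    # Annotated fields let @dataclass generate __init__, __repr__, and
    # __eq__; the patch's hand-written __init__ leaves the field list
    # empty instead, so none of the generated methods see the data.
    @dataclass
    class ContextChunk:
        value: Any
        context_type: ContextType

    # Hypothetical token IDs standing in for what apiactionsubmit() collects:
    context = [
        ContextChunk([318, 262, 1110], ContextType.MEMORY),
        ContextChunk([464, 1492, 373], ContextType.STORY),
    ]
    print(context)
    # [ContextChunk(value=[318, 262, 1110], context_type=<ContextType.MEMORY: 4>),
    #  ContextChunk(value=[464, 1492, 373], context_type=<ContextType.STORY: 2>)]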