Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
wip context viewer
Changed file: aiserver.py (43 additions, 0 deletions)
@@ -39,6 +39,8 @@ import traceback
 import inspect
 import warnings
 import multiprocessing
+from enum import Enum
+from dataclasses import dataclass
 from collections.abc import Iterable
 from collections import OrderedDict
 from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List, Optional, Type
@@ -3970,6 +3972,19 @@ def check_for_backend_compilation():
             break
     koboldai_vars.checking = False
 
+class ContextType(Enum):
+    SOFT_PROMPT = 1
+    STORY = 2
+    WORLD_INFO = 3
+    MEMORY = 4
+    HEADER = 5
+
+@dataclass
+class ContextChunk:
+    def __init__(self, value, context_type: ContextType) -> None:
+        self.value = value
+        self.context_type = context_type
+
 def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False, no_generate=False):
     # Ignore new submissions if the AI is currently busy
     if(koboldai_vars.aibusy):
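A note on the new ContextChunk class above: @dataclass generates __init__ (and __repr__) from annotated class-level fields, so pairing the decorator with a hand-written __init__ and no field annotations means it contributes nothing here; the generated __repr__ sees no fields, so the print(context) added later in this commit would render each chunk as a bare "ContextChunk()". A minimal sketch of the equivalent declarative form, with the Any annotation assumed (not from the commit), since a chunk's value may be a soft prompt or a token list:

from dataclasses import dataclass
from enum import Enum
from typing import Any

class ContextType(Enum):
    SOFT_PROMPT = 1
    STORY = 2
    WORLD_INFO = 3
    MEMORY = 4
    HEADER = 5

# With declared fields, @dataclass generates __init__, __repr__, and __eq__,
# so printing a chunk shows its contents instead of an empty "ContextChunk()".
# The Any type is an assumption; the commit does not annotate value.
@dataclass
class ContextChunk:
    value: Any
    context_type: ContextType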
@@ -4226,14 +4241,18 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
         mem = koboldai_vars.memory + "\n"
     else:
         mem = koboldai_vars.memory
+
     if(use_authors_note and koboldai_vars.authornote != ""):
         anotetxt = ("\n" + koboldai_vars.authornotetemplate + "\n").replace("<|>", koboldai_vars.authornote)
     else:
         anotetxt = ""
+
     MIN_STORY_TOKENS = 8
     story_tokens = []
     mem_tokens = []
     wi_tokens = []
+    context = []
+
     story_budget = lambda: koboldai_vars.max_length - koboldai_vars.sp_length - koboldai_vars.genamt - len(tokenizer._koboldai_header) - len(story_tokens) - len(mem_tokens) - len(wi_tokens)
     budget = lambda: story_budget() + MIN_STORY_TOKENS
     if budget() < 0:
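Both lambdas above recompute the remaining space on every call, because the token lists they close over are rebound as the function proceeds; budget() is simply story_budget() plus the 8-token MIN_STORY_TOKENS floor, which appears intended to guarantee the story a minimum share of the context window. A toy sketch of the interaction, with invented numbers rather than values from the commit:

# Stand-ins for the koboldai_vars fields used above; all numbers invented.
max_length, sp_length, genamt, header_len = 2048, 0, 80, 1
MIN_STORY_TOKENS = 8
story_tokens, mem_tokens, wi_tokens = [], [], []

# Recomputed on each call: the lists above are rebound later, and the
# lambdas see the new bindings, not a snapshot.
story_budget = lambda: (max_length - sp_length - genamt - header_len
                        - len(story_tokens) - len(mem_tokens) - len(wi_tokens))
budget = lambda: story_budget() + MIN_STORY_TOKENS

assert story_budget() == 1967
assert budget() == 1975       # strict remainder plus the 8-token floor
mem_tokens = [0] * 100        # memory spends from the shared pool
assert story_budget() == 1867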
@@ -4241,15 +4260,20 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
             "msg": f"Your Max Tokens setting is too low for your current soft prompt and tokenizer to handle. It needs to be at least {koboldai_vars.max_length - budget()}.",
             "type": "token_overflow",
         }}), mimetype="application/json", status=500))
+
     if use_memory:
         mem_tokens = tokenizer.encode(utils.encodenewlines(mem))[-budget():]
+
     if use_world_info:
         world_info, _ = checkworldinfo(data, force_use_txt=True, scan_story=use_story)
         wi_tokens = tokenizer.encode(utils.encodenewlines(world_info))[-budget():]
+
     if use_story:
         if koboldai_vars.useprompt:
             story_tokens = tokenizer.encode(utils.encodenewlines(koboldai_vars.prompt))[-budget():]
+
     story_tokens = tokenizer.encode(utils.encodenewlines(data))[-story_budget():] + story_tokens
+
     if use_story:
         for i, action in enumerate(reversed(koboldai_vars.actions.values())):
             if story_budget() <= 0:
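Each tokenizer.encode(...)[-budget():] call above keeps only the newest tokens that still fit the remaining budget. One caveat worth flagging in a WIP review: if a budget ever reaches exactly zero, [-0:] is the same as [0:] and keeps the whole list rather than nothing, which is presumably why the story loop bails out on story_budget() <= 0 before slicing. A small self-contained illustration:

tokens = list(range(100))   # stand-in for encoded text, oldest token first
remaining = 10              # what budget() might return at this point

kept = tokens[-remaining:]  # negative slice keeps only the newest tokens
assert kept == list(range(90, 100))

# Pitfall: -0 == 0, so a zero budget keeps the ENTIRE list, not an empty one.
assert tokens[-0:] == tokens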
@@ -4260,6 +4284,23 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
                 story_tokens = tokenizer.encode(utils.encodenewlines(anotetxt))[-story_budget():] + story_tokens
     if not koboldai_vars.useprompt:
         story_tokens = tokenizer.encode(utils.encodenewlines(koboldai_vars.prompt))[-budget():] + story_tokens
+
+    # Context tracker
+    if koboldai_vars.sp:
+        context.append(ContextChunk(koboldai_vars.sp, ContextType.SOFT_PROMPT))
+
+    if tokenizer._koboldai_header:
+        context.append(ContextChunk(tokenizer._koboldai_header, ContextType.HEADER))
+
+    if mem_tokens:
+        context.append(ContextChunk(mem_tokens, ContextType.MEMORY))
+
+    if wi_tokens:
+        context.append(ContextChunk(wi_tokens, ContextType.WORLD_INFO))
+
+    if story_tokens:
+        context.append(ContextChunk(story_tokens, ContextType.STORY))
+
     tokens = tokenizer._koboldai_header + mem_tokens + wi_tokens + story_tokens
     assert story_budget() >= 0
     minimum = len(tokens) + 1
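The tracker appends chunks in the same order the prompt is assembled, so concatenating every chunk except the soft prompt should reproduce the tokens variable built just below it. A hedged sketch of that invariant; the helper name and the list-of-token-ids assumption are illustrative, not from the commit:

def reassemble(context):
    # Flatten all token-list chunks, skipping the soft prompt, which is
    # also excluded from the `tokens` concatenation in the diff.
    return [tok
            for chunk in context
            if chunk.context_type != ContextType.SOFT_PROMPT
            for tok in chunk.value]

# Inside apiactionsubmit this should then hold:
# assert reassemble(context) == tokens  # header + mem + wi + story, in order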
@@ -4270,6 +4311,8 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
     elif(koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
         genout = apiactionsubmit_tpumtjgenerate(tokens, minimum, maximum)
 
+    koboldai_vars.context = context
+    print(context)
     return genout
 
 #==================================================================#