Gen split progress

This commit is contained in:
somebody
2022-09-14 18:28:37 -05:00
parent 04621ccbbc
commit 807e3e88c2
2 changed files with 58 additions and 3 deletions

@@ -1841,6 +1841,10 @@ def patch_transformers():
        scores: torch.FloatTensor,
        **kwargs,
    ) -> bool:
        if koboldai_vars.inference_config.do_dynamic_wi:
            pass  # Placeholder hook for dynamic WI scanning; not wired up yet in this progress commit.
        koboldai_vars.generated_tkns += 1
        if(not koboldai_vars.standalone and koboldai_vars.lua_koboldbridge.generated_cols and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols):
            raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({koboldai_vars.generated_tkns} != {koboldai_vars.lua_koboldbridge.generated_cols})")
@@ -1874,8 +1878,9 @@ def patch_transformers():
            return self.regeneration_required or self.halt
    old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
    def new_get_stopping_criteria(self, *args, **kwargs):
        global tokenizer
        stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
        self.kai_scanner = DynamicWorldInfoScanCriteria(
            tokenizer=tokenizer,
            excluded_world_info=self.kai_scanner_excluded_world_info,
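The surrounding patch uses a wrap-and-extend pattern: keep a reference to the original _get_stopping_criteria, call it to get the stock StoppingCriteriaList, then attach the custom criterion so generate() consults it on every step. A minimal sketch of the same pattern, reusing the illustrative TokenCountCriteria above:

import transformers
from transformers import StoppingCriteriaList

old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria

def new_get_stopping_criteria(self, *args, **kwargs) -> StoppingCriteriaList:
    # Let transformers assemble its own criteria first...
    stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
    # ...then append ours; StoppingCriteriaList is a list subclass, and
    # generate() stops as soon as any entry returns True.
    stopping_criteria.append(TokenCountCriteria(prompt_length=0, max_new_tokens=80))
    return stopping_criteria

transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria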
@@ -2606,6 +2611,11 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
    sendsettings()
    refresh_settings()

    # Temporary smoke test for the new raw_generate() path.
    prompto = "What does 1+1 equal?\n"
    out = raw_generate(prompto, 80)
    print(f"{out=}")

    # Saving the tokenizer to the KoboldStoryRegister class so we can do token counting on the story data
    if 'tokenizer' in globals():
        koboldai_vars.tokenizer = tokenizer
@@ -4619,6 +4629,40 @@ def calcsubmit(txt):
        # Send it!
        ikrequest(subtxt)

def raw_generate(
    prompt: str,
    max_length: int,
    do_streaming: bool = False,
    do_dynamic_wi: bool = False,
):
    koboldai_vars.inference_config.do_streaming = do_streaming
    koboldai_vars.inference_config.do_dynamic_wi = do_dynamic_wi

    prompt_tokens = tokenizer.encode(prompt)
    gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]

    device = "cpu"
    if koboldai_vars.hascuda and koboldai_vars.usegpu:
        device = koboldai_vars.gpu_device
    elif koboldai_vars.hascuda and koboldai_vars.breakmodel:
        device = breakmodel.primary_device
    gen_in = gen_in.to(device)

    with torch.no_grad():
        genout = generator(
            gen_in,
            do_sample=True,
            # NOTE: max_length counts the prompt tokens as well as new ones.
            max_length=max_length,
            repetition_penalty=1.0,
            bad_words_ids=koboldai_vars.badwordsids,
            use_cache=True,
        )

    # genout[0] is the full sequence, so the prompt is included in the output text.
    text_out = tokenizer.decode(genout[0])
    return text_out
#==================================================================#
# Send text to generator and deal with output
#==================================================================#
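Usage-wise, raw_generate() is a thin synchronous wrapper: it sets the per-call inference_config flags, tokenizes, moves the batch to the right device, and decodes the result. A hypothetical call (prompt and length are arbitrary):

# Note that the decoded output includes the prompt, since genout[0] is the full sequence.
text = raw_generate("What does 1+1 equal?\n", 80)
print(text)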

@@ -1,3 +1,4 @@
from dataclasses import dataclass
import os, re, time, threading, json, pickle, base64, copy, tqdm, datetime, sys
from io import BytesIO
from flask import has_request_context
@@ -705,8 +706,8 @@ class user_settings(settings):
        process_variable_changes(self.socketio, self.__class__.__name__.replace("_settings", ""), name, value, old_value)

class system_settings(settings):
    local_only_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'regex_sl', 'acregex_ai', 'acregex_ui', 'comregex_ai', 'comregex_ui', 'sp', '_horde_pid', 'inference_config']
    no_save_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'sp', '_horde_pid', 'horde_share', 'aibusy', 'serverstarted', 'inference_config']
    settings_name = "system"
    def __init__(self, socketio):
        self.socketio = socketio
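Adding 'inference_config' to both exclusion lists keeps the per-call flags out of synced state and out of saves, since they are transient runtime toggles. A hedged sketch of how such exclusion lists are typically consumed when serializing (the to_json name is an assumption, not necessarily KoboldAI's actual serializer):

import json

def to_json(settings_obj) -> str:
    # Hypothetical serializer: skip anything the class lists in no_save_variables.
    data = {
        k: v for k, v in settings_obj.__dict__.items()
        if k not in settings_obj.no_save_variables
    }
    return json.dumps(data, default=str)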
@@ -785,6 +786,16 @@ class system_settings(settings):
        self._horde_pid = None
        self.cookies = {}  # Cookies for Colab: since Colab's URL changes, cookies are lost

        @dataclass
        class _inference_config:
            do_streaming: bool = False

            # NOTE: DynamicWorldInfoScanCriteria handles not only dynamic world
            # info, but also max length, aborting, regeneration requests, etc.
            # for kobold-rooted stuff. This would be nice to change in the future.
            do_dynamic_wi: bool = False

        self.inference_config = _inference_config()

    def __setattr__(self, name, value):
        new_variable = name not in self.__dict__
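Since _inference_config is a plain dataclass instance hanging off system_settings, writers toggle the flags per call and readers simply check them each generation step. A minimal sketch of that round trip, assuming koboldai_vars is the live system_settings object:

# Writer side: raw_generate() sets the flags for one generation call.
koboldai_vars.inference_config.do_streaming = False
koboldai_vars.inference_config.do_dynamic_wi = True

# Reader side: the stopping criterion consults the flag on every step.
if koboldai_vars.inference_config.do_dynamic_wi:
    pass  # scan newly generated tokens for world-info triggers (future work)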