mirror of https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00

Gen split progress
aiserver.py
@@ -1841,6 +1841,10 @@ def patch_transformers():
             scores: torch.FloatTensor,
             **kwargs,
         ) -> bool:
+
+            if koboldai_vars.inference_config.do_dynamic_wi:
+                pass
+
             koboldai_vars.generated_tkns += 1
             if(not koboldai_vars.standalone and koboldai_vars.lua_koboldbridge.generated_cols and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols):
                 raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({koboldai_vars.generated_tkns} != {koboldai_vars.lua_koboldbridge.generated_cols})")
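
Note: this hunk lands inside the stopping-criteria callback that Hugging Face transformers invokes once per generated token; the new do_dynamic_wi branch is still a stub (pass). As a standalone illustration of the protocol being hooked here, not code from this commit, a minimal criteria class looks like this:

import torch
import transformers

# Illustrative only: a bare-bones StoppingCriteria that counts decoding
# steps, mirroring how the callback above ticks generated_tkns.
class CountingCriteria(transformers.StoppingCriteria):
    def __init__(self, limit: int):
        self.generated = 0
        self.limit = limit

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs,
    ) -> bool:
        self.generated += 1                    # one call per generated token
        return self.generated >= self.limit    # returning True stops generation
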
@@ -1874,8 +1878,9 @@ def patch_transformers():
             return self.regeneration_required or self.halt
     old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
     def new_get_stopping_criteria(self, *args, **kwargs):
-        stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
         global tokenizer
+        stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
+
         self.kai_scanner = DynamicWorldInfoScanCriteria(
             tokenizer=tokenizer,
             excluded_world_info=self.kai_scanner_excluded_world_info,
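
Note: this hunk only reorders new_get_stopping_criteria; the call to the saved original moves below the global tokenizer declaration, with no behavioural change. The save/wrap/replace monkey-patching pattern used throughout patch_transformers can be sketched generically (SomeClass and the method bodies are illustrative, not from the commit):

class SomeClass:
    def method(self, x):
        return x * 2

_old_method = SomeClass.method      # keep a handle on the original

def _new_method(self, x):
    result = _old_method(self, x)   # delegate to the original first
    return result + 1               # then layer extra behaviour on top

SomeClass.method = _new_method      # install the wrapper

assert SomeClass().method(3) == 7   # (3 * 2) + 1
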
@@ -2606,6 +2611,11 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
     sendsettings()
     refresh_settings()
 
+    prompto = "What does 1+1 equal?\n"
+    print("Hehe")
+    out = raw_generate(prompto, 80)
+    print(f"{out=}")
+
     #Saving the tokenizer to the KoboldStoryRegister class so we can do token counting on the story data
     if 'tokenizer' in [x for x in globals()]:
         koboldai_vars.tokenizer = tokenizer
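
Note: the lines added to load_model are a temporary smoke test: a hard-coded prompt is pushed through the new raw_generate as soon as the model finishes loading (the print("Hehe") marker suggests debug scaffolding rather than a permanent feature). The f"{out=}" form is the Python 3.8+ f-string debugging specifier, which prints both the expression and its value:

# The `=` specifier in an f-string (Python 3.8+) prints name and value,
# which is what the smoke test above uses to dump the generator output.
out = "1+1 equals 2."
print(f"{out=}")   # prints: out='1+1 equals 2.'
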
@@ -4619,6 +4629,40 @@ def calcsubmit(txt):
         # Send it!
         ikrequest(subtxt)
 
+def raw_generate(
+    prompt: str,
+    max_length: int,
+
+    do_streaming: bool = False,
+    do_dynamic_wi: bool = False,
+):
+
+    koboldai_vars.inference_config.do_streaming = do_streaming
+    koboldai_vars.inference_config.do_dynamic_wi = do_dynamic_wi
+
+    prompt_tokens = tokenizer.encode(prompt)
+    gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
+
+    device = "cpu"
+    if koboldai_vars.hascuda and koboldai_vars.usegpu:
+        device = koboldai_vars.gpu_device
+    elif koboldai_vars.hascuda and koboldai_vars.breakmodel:
+        device = breakmodel.primary_device
+    gen_in = gen_in.to(device)
+
+    with torch.no_grad():
+        genout = generator(
+            gen_in,
+            do_sample=True,
+            max_length=max_length,
+            repetition_penalty=1.0,
+            bad_words_ids=koboldai_vars.badwordsids,
+            use_cache=True,
+        )
+
+    text_out = tokenizer.decode(genout[0])
+    return text_out
+
 #==================================================================#
 #  Send text to generator and deal with output
 #==================================================================#
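
Note: raw_generate is the heart of the "gen split": it stashes the streaming/dynamic-WI flags on inference_config, encodes the prompt, picks a device (CPU, a single GPU, or breakmodel's primary device when the model is split across devices), samples under torch.no_grad(), and decodes. One caveat: Hugging Face's generate returns the prompt tokens followed by the continuation, so tokenizer.decode(genout[0]) echoes the prompt back at the start of text_out. A usage sketch (assumes a loaded model and tokenizer in scope, as in the load_model smoke test above):

# Hypothetical call site mirroring the smoke test in load_model.
text = raw_generate("What does 1+1 equal?\n", max_length=80)
print(text)   # prompt echoed back, then up to ~80 tokens of continuation

# To return only the new text, raw_generate could slice off the prompt:
#     continuation = tokenizer.decode(genout[0][len(prompt_tokens):])
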

koboldai_settings.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 import os, re, time, threading, json, pickle, base64, copy, tqdm, datetime, sys
 from io import BytesIO
 from flask import has_request_context
@@ -705,8 +706,8 @@ class user_settings(settings):
             process_variable_changes(self.socketio, self.__class__.__name__.replace("_settings", ""), name, value, old_value)
 
 class system_settings(settings):
-    local_only_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'regex_sl', 'acregex_ai', 'acregex_ui', 'comregex_ai', 'comregex_ui', 'sp', '_horde_pid']
-    no_save_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'sp', '_horde_pid', 'horde_share', 'aibusy', 'serverstarted']
+    local_only_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'regex_sl', 'acregex_ai', 'acregex_ui', 'comregex_ai', 'comregex_ui', 'sp', '_horde_pid', 'inference_config']
+    no_save_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'sp', '_horde_pid', 'horde_share', 'aibusy', 'serverstarted', 'inference_config']
     settings_name = "system"
     def __init__(self, socketio):
         self.socketio = socketio
@@ -784,6 +785,16 @@ class system_settings(settings):
         self.horde_share = False
         self._horde_pid = None
         self.cookies = {} #cookies for colab since colab's URL changes, cookies are lost
+
+        @dataclass
+        class _inference_config:
+            do_streaming: bool = False
+
+            # NOTE: DynamicWorldInfoScanCriteria handles not only dynamic world
+            # info, but also max length, aborting, regeneration requests, etc
+            # for kobold-rooted stuff. This would be nice to change in the future.
+            do_dynamic_wi: bool = False
+        self.inference_config = _inference_config()
 
 
     def __setattr__(self, name, value):
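
Note: the new _inference_config dataclass groups per-generation flags on system_settings, and adding 'inference_config' to both local_only_variables and no_save_variables keeps the object out of socket synchronisation and save files alike: it is runtime-only state. A standalone sketch of the same dataclass pattern (illustrative, using the commit's two flags):

from dataclasses import dataclass

@dataclass
class _inference_config:
    do_streaming: bool = False
    do_dynamic_wi: bool = False

cfg = _inference_config()    # safe defaults at startup
cfg.do_streaming = True      # raw_generate flips flags per call
print(cfg)                   # _inference_config(do_streaming=True, do_dynamic_wi=False)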