diff --git a/aiserver.py b/aiserver.py
index e8322f4a..469e79b1 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -1841,6 +1841,10 @@ def patch_transformers():
             scores: torch.FloatTensor,
             **kwargs,
         ) -> bool:
+
+            if koboldai_vars.inference_config.do_dynamic_wi:
+                pass
+
             koboldai_vars.generated_tkns += 1
             if(not koboldai_vars.standalone and koboldai_vars.lua_koboldbridge.generated_cols and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols):
                 raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({koboldai_vars.generated_tkns} != {koboldai_vars.lua_koboldbridge.generated_cols})")
@@ -1874,8 +1878,9 @@ def patch_transformers():
             return self.regeneration_required or self.halt
     old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
     def new_get_stopping_criteria(self, *args, **kwargs):
-        stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
         global tokenizer
+        stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
+
         self.kai_scanner = DynamicWorldInfoScanCriteria(
             tokenizer=tokenizer,
             excluded_world_info=self.kai_scanner_excluded_world_info,
@@ -2606,6 +2611,11 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
     sendsettings()
     refresh_settings()
 
+    prompto = "What does 1+1 equal?\n"
+    print("Hehe")
+    out = raw_generate(prompto, 80)
+    print(f"{out=}")
+
     #Saving the tokenizer to the KoboldStoryRegister class so we can do token counting on the story data
     if 'tokenizer' in [x for x in globals()]:
         koboldai_vars.tokenizer = tokenizer
@@ -4619,6 +4629,40 @@ def calcsubmit(txt):
         # Send it!
         ikrequest(subtxt)
 
+def raw_generate(
+    prompt: str,
+    max_length: int,
+
+    do_streaming: bool = False,
+    do_dynamic_wi: bool = False,
+):
+
+    koboldai_vars.inference_config.do_streaming = do_streaming
+    koboldai_vars.inference_config.do_dynamic_wi = do_dynamic_wi
+
+    prompt_tokens = tokenizer.encode(prompt)
+    gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
+
+    device = "cpu"
+    if koboldai_vars.hascuda and koboldai_vars.usegpu:
+        device = koboldai_vars.gpu_device
+    elif koboldai_vars.hascuda and koboldai_vars.breakmodel:
+        device = breakmodel.primary_device
+    gen_in = gen_in.to(device)
+
+    with torch.no_grad():
+        genout = generator(
+            gen_in,
+            do_sample=True,
+            max_length=max_length,
+            repetition_penalty=1.0,
+            bad_words_ids=koboldai_vars.badwordsids,
+            use_cache=True,
+        )
+
+    text_out = tokenizer.decode(genout[0])
+    return text_out
+
 #==================================================================#
 #  Send text to generator and deal with output
 #==================================================================#
diff --git a/koboldai_settings.py b/koboldai_settings.py
index b29db50f..931039c0 100644
--- a/koboldai_settings.py
+++ b/koboldai_settings.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 import os, re, time, threading, json, pickle, base64, copy, tqdm, datetime, sys
 from io import BytesIO
 from flask import has_request_context
@@ -705,8 +706,8 @@ class user_settings(settings):
         process_variable_changes(self.socketio, self.__class__.__name__.replace("_settings", ""), name, value, old_value)
 
 class system_settings(settings):
-    local_only_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'regex_sl', 'acregex_ai', 'acregex_ui', 'comregex_ai', 'comregex_ui', 'sp', '_horde_pid']
-    no_save_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'sp', '_horde_pid', 'horde_share', 'aibusy', 'serverstarted']
+    local_only_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'regex_sl', 'acregex_ai', 'acregex_ui', 'comregex_ai', 'comregex_ui', 'sp', '_horde_pid', 'inference_config']
+    no_save_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'sp', '_horde_pid', 'horde_share', 'aibusy', 'serverstarted', 'inference_config']
     settings_name = "system"
     def __init__(self, socketio):
         self.socketio = socketio
@@ -784,6 +785,16 @@ class system_settings(settings):
         self.horde_share = False
         self._horde_pid = None
         self.cookies = {} #cookies for colab since colab's URL changes, cookies are lost
+
+        @dataclass
+        class _inference_config:
+            do_streaming: bool = False
+
+            # NOTE: DynamicWorldInfoScanCriteria handles not only dynamic world
+            # info, but also max length, aborting, regeneration requests, etc
+            # for kobold-rooted stuff. This would be nice to change in the future.
+            do_dynamic_wi: bool = False
+        self.inference_config = _inference_config()
 
     def __setattr__(self, name, value):
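A minimal sketch of how the new raw_generate() helper is meant to be called, mirroring the temporary smoke test added to load_model() above. It assumes load_model() has already initialized the module-level tokenizer, generator, and koboldai_vars globals; the prompt string is illustrative only.

    # Sketch of a raw_generate() caller (assumes a model is already loaded).
    # Both keyword flags are written into koboldai_vars.inference_config at
    # the top of raw_generate(), so every call fully resets both settings.
    out = raw_generate(
        "What does 1+1 equal?\n",
        max_length=80,
        do_streaming=False,   # leave token streaming off
        do_dynamic_wi=False,  # skip the (currently empty) dynamic-WI branch
    )
    print(out)  # prompt plus continuation: tokenizer.decode(genout[0]) includes the input tokens

Since inference_config is added to both local_only_variables and no_save_variables in system_settings, these per-call flags are never synced to clients or persisted to save files.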