Merge branch 'UI2' of https://github.com/ebolam/KoboldAI into ui2-chat2-again

This commit is contained in:
somebody
2022-11-09 19:08:39 -06:00
5 changed files with 145 additions and 95 deletions

View File

@@ -113,6 +113,34 @@ def new_pretrainedtokenizerbase_from_pretrained(cls, *args, **kwargs):
    return tokenizer
PreTrainedTokenizerBase.from_pretrained = new_pretrainedtokenizerbase_from_pretrained

# We only want to use logit manipulations and such on our core text model
class use_core_manipulations:
    # These must be set by wherever they get setup
    get_logits_processor: callable
    sample: callable
    get_stopping_criteria: callable

    # We set these automatically
    old_get_logits_processor: callable
    old_sample: callable
    old_get_stopping_criteria: callable

    def __enter__(self):
        use_core_manipulations.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
        transformers.generation_utils.GenerationMixin._get_logits_processor = use_core_manipulations.get_logits_processor

        use_core_manipulations.old_sample = transformers.generation_utils.GenerationMixin.sample
        transformers.generation_utils.GenerationMixin.sample = use_core_manipulations.sample

        use_core_manipulations.old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
        transformers.generation_utils.GenerationMixin._get_stopping_criteria = use_core_manipulations.get_stopping_criteria
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        transformers.generation_utils.GenerationMixin._get_logits_processor = use_core_manipulations.old_get_logits_processor
        transformers.generation_utils.GenerationMixin.sample = use_core_manipulations.old_sample
        transformers.generation_utils.GenerationMixin._get_stopping_criteria = use_core_manipulations.old_get_stopping_criteria
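
Reviewer note: the context manager above only swaps in whatever has been assigned to its class-level hooks, so patch_transformers() below registers its replacements on use_core_manipulations instead of monkey-patching GenerationMixin permanently. A minimal usage sketch, with hypothetical my_* stand-ins for the registered hooks (these names are not part of the commit):

    # Hypothetical hooks standing in for the ones patch_transformers() registers.
    use_core_manipulations.get_logits_processor = my_get_logits_processor
    use_core_manipulations.sample = my_sample
    use_core_manipulations.get_stopping_criteria = my_get_stopping_criteria

    with use_core_manipulations():
        # Inside the block, GenerationMixin uses the KoboldAI hooks.
        output = model.generate(input_ids, do_sample=True)
    # After the block, the stock transformers behaviour is restored.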
#==================================================================#
# Variables & Storage
#==================================================================#
@@ -1910,8 +1938,6 @@ def patch_transformers_download():
def patch_transformers():
    global transformers
    global old_transfomers_functions
    old_transfomers_functions = {}

    patch_transformers_download()
@@ -1933,7 +1959,6 @@ def patch_transformers():
        PreTrainedModel._kai_patched = True
    if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
        old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
        old_transfomers_functions['modeling_utils.get_checkpoint_shard_files'] = old_get_checkpoint_shard_files
        def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
            utils.num_shards = utils.get_num_shards(index_filename)
            utils.from_pretrained_index_filename = index_filename
@@ -1961,7 +1986,6 @@ def patch_transformers():
            if max_pos > self.weights.size(0):
                self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
            return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
        old_transfomers_functions['XGLMSinusoidalPositionalEmbedding.forward'] = XGLMSinusoidalPositionalEmbedding.forward
        XGLMSinusoidalPositionalEmbedding.forward = new_forward
@@ -1981,7 +2005,6 @@ def patch_transformers():
            self.model = OPTModel(config)
            self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
            self.post_init()
        old_transfomers_functions['OPTForCausalLM.__init__'] = OPTForCausalLM.__init__
        OPTForCausalLM.__init__ = new_init
@@ -2170,8 +2193,8 @@ def patch_transformers():
        processors.append(PhraseBiasLogitsProcessor())
        processors.append(ProbabilityVisualizerLogitsProcessor())
        return processors
    use_core_manipulations.get_logits_processor = new_get_logits_processor
    new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
    transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
class KoboldLogitsWarperList(LogitsProcessorList):
def __init__(self, beams: int = 1, **kwargs):
@@ -2204,9 +2227,9 @@ def patch_transformers():
kwargs["eos_token_id"] = -1
kwargs.setdefault("pad_token_id", 2)
return new_sample.old_sample(self, *args, **kwargs)
new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
transformers.generation_utils.GenerationMixin.sample = new_sample
new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
use_core_manipulations.sample = new_sample
# Allow bad words filter to ban <|endoftext|> token
import transformers.generation_logits_process
@@ -2374,7 +2397,7 @@ def patch_transformers():
    old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
    old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria'] = old_get_stopping_criteria
    def new_get_stopping_criteria(self, *args, **kwargs):
        global tokenizer
        stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
@@ -2392,7 +2415,7 @@ def patch_transformers():
            stopping_criteria.insert(0, token_streamer)
        stopping_criteria.insert(0, ChatModeStopper(tokenizer=tokenizer))
        return stopping_criteria
    transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
    use_core_manipulations.get_stopping_criteria = new_get_stopping_criteria

def reset_model_settings():
    koboldai_vars.socketio = socketio
@@ -5395,56 +5418,57 @@ def raw_generate(
    result: GenerationResult
    time_start = time.time()

    if koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
        batch_encoded = tpu_raw_generate(
            prompt_tokens=prompt_tokens,
            max_new=max_new,
            batch_count=batch_count,
            gen_settings=gen_settings
        )
        result = GenerationResult(
            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
        )
    elif koboldai_vars.model in model_functions:
        batch_encoded = model_functions[koboldai_vars.model](
            prompt_tokens=prompt_tokens,
            max_new=max_new,
            batch_count=batch_count,
            gen_settings=gen_settings
        )
        result = GenerationResult(
            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
        )
    elif koboldai_vars.model.startswith("RWKV"):
        batch_encoded = rwkv_raw_generate(
            prompt_tokens=prompt_tokens,
            max_new=max_new,
            batch_count=batch_count,
            gen_settings=gen_settings
        )
        result = GenerationResult(
            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True
        )
    else:
        # Torch HF
        start_time = time.time()
        batch_encoded = torch_raw_generate(
            prompt_tokens=prompt_tokens,
            max_new=max_new if not bypass_hf_maxlength else int(2e9),
            do_streaming=do_streaming,
            do_dynamic_wi=do_dynamic_wi,
            batch_count=batch_count,
            gen_settings=gen_settings
        )
        logger.debug("raw_generate: run torch_raw_generate {}s".format(time.time()-start_time))
        start_time = time.time()
        result = GenerationResult(
            out_batches=batch_encoded,
            prompt=prompt_tokens,
            is_whole_generation=False,
            output_includes_prompt=True,
        )
        logger.debug("raw_generate: run GenerationResult {}s".format(time.time()-start_time))
    with use_core_manipulations():
        if koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
            batch_encoded = tpu_raw_generate(
                prompt_tokens=prompt_tokens,
                max_new=max_new,
                batch_count=batch_count,
                gen_settings=gen_settings
            )
            result = GenerationResult(
                out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
            )
        elif koboldai_vars.model in model_functions:
            batch_encoded = model_functions[koboldai_vars.model](
                prompt_tokens=prompt_tokens,
                max_new=max_new,
                batch_count=batch_count,
                gen_settings=gen_settings
            )
            result = GenerationResult(
                out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
            )
        elif koboldai_vars.model.startswith("RWKV"):
            batch_encoded = rwkv_raw_generate(
                prompt_tokens=prompt_tokens,
                max_new=max_new,
                batch_count=batch_count,
                gen_settings=gen_settings
            )
            result = GenerationResult(
                out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True
            )
        else:
            # Torch HF
            start_time = time.time()
            batch_encoded = torch_raw_generate(
                prompt_tokens=prompt_tokens,
                max_new=max_new if not bypass_hf_maxlength else int(2e9),
                do_streaming=do_streaming,
                do_dynamic_wi=do_dynamic_wi,
                batch_count=batch_count,
                gen_settings=gen_settings
            )
            logger.debug("raw_generate: run torch_raw_generate {}s".format(time.time()-start_time))
            start_time = time.time()
            result = GenerationResult(
                out_batches=batch_encoded,
                prompt=prompt_tokens,
                is_whole_generation=False,
                output_includes_prompt=True,
            )
            logger.debug("raw_generate: run GenerationResult {}s".format(time.time()-start_time))
    time_end = round(time.time() - time_start, 2)
    tokens_per_second = round(len(result.encoded[0]) / time_end, 2)
@@ -8077,8 +8101,8 @@ def file_popup(popup_title, starting_folder, return_event, upload=True, jailed=T
def get_files_folders(starting_folder):
    import stat
    session['current_folder'] = os.path.abspath(starting_folder).replace("\\", "/")
    item_check = session['popup_item_check']
    extra_parameter_function = session['extra_parameter_function']
    item_check = globals()[session['popup_item_check']] if session['popup_item_check'] is not None else None
    extra_parameter_function = globals()[session['extra_parameter_function']] if session['extra_parameter_function'] is not None else None
    show_breadcrumbs = session['popup_show_breadcrumbs']
    show_hidden = session['popup_show_hidden']
    folder_only = session['popup_folder_only']
@@ -8090,7 +8114,7 @@ def get_files_folders(starting_folder):
    sort = session['sort']
    desc = session['desc']
    show_folders = session['show_folders']
    advanced_sort = session['advanced_sort']
    advanced_sort = globals()[session['advanced_sort']] if session['advanced_sort'] is not None else None

    if starting_folder == 'This PC':
        breadcrumbs = [['This PC', 'This PC']]
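
Reviewer note on the two hunks above: the popup helpers (item_check, extra_parameter_function, advanced_sort) are now stored in the Flask session as function names rather than function objects, presumably because session contents have to be serializable; get_files_folders resolves them back to module-level callables with globals(). A rough sketch of that lookup, assuming the named function exists at module scope:

    # Sketch of the name-to-callable lookup used for the popup helpers.
    # The UI_2_* handlers below pass names such as "valid_story".
    name = session['popup_item_check']                        # e.g. "valid_story" or None
    item_check = globals()[name] if name is not None else None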
@@ -8444,10 +8468,10 @@ def UI_2_load_model(data):
@logger.catch
def UI_2_load_story_list(data):
    file_popup("Select Story to Load", "./stories", "load_story", upload=True, jailed=True, folder_only=False, renameable=True,
               deleteable=True, show_breadcrumbs=True, item_check=valid_story,
               valid_only=True, hide_extention=True, extra_parameter_function=get_story_listing_data,
               deleteable=True, show_breadcrumbs=True, item_check="valid_story",
               valid_only=True, hide_extention=True, extra_parameter_function="get_story_listing_data",
               column_names=['Story Name', 'Action Count', 'Last Loaded'], show_filename=False,
               column_widths=['minmax(150px, auto)', '140px', '160px'], advanced_sort=story_sort,
               column_widths=['minmax(150px, auto)', '140px', '160px'], advanced_sort="story_sort",
               sort="Modified", desc=True, rename_return_emit_name="popup_rename_story")
@logger.catch
@@ -8809,8 +8833,8 @@ def UI_2_load_softprompt_list(data):
socketio.emit("error", "Soft prompts are not supported by your current model/backend", broadcast=True, room="UI_2")
assert koboldai_vars.allowsp, "Soft prompts are not supported by your current model/backend"
file_popup("Select Softprompt to Load", "./softprompts", "load_softprompt", upload=True, jailed=True, folder_only=False, renameable=True,
deleteable=True, show_breadcrumbs=True, item_check=valid_softprompt,
valid_only=True, hide_extention=True, extra_parameter_function=get_softprompt_desc,
deleteable=True, show_breadcrumbs=True, item_check="valid_softprompt",
valid_only=True, hide_extention=True, extra_parameter_function="get_softprompt_desc",
column_names=['Softprompt Name', 'Softprompt Description'],
show_filename=False,
column_widths=['150px', 'auto'])
@@ -8852,8 +8876,8 @@ def UI_2_load_softprompt(data):
@logger.catch
def UI_2_load_userscripts_list(data):
    file_popup("Select Userscripts to Load", "./userscripts", "load_userscripts", upload=True, jailed=True, folder_only=False, renameable=True, editable=True,
               deleteable=True, show_breadcrumbs=False, item_check=valid_userscripts_to_load,
               valid_only=True, hide_extention=True, extra_parameter_function=get_userscripts_desc,
               deleteable=True, show_breadcrumbs=False, item_check="valid_userscripts_to_load",
               valid_only=True, hide_extention=True, extra_parameter_function="get_userscripts_desc",
               column_names=['Module Name', 'Description'],
               show_filename=False, show_folders=False,
               column_widths=['150px', 'auto'])
@@ -9199,24 +9223,33 @@ def text2img_horde(prompt,
filename = "story_art.png"):
logger.debug("Generating Image using Horde")
koboldai_vars.generating_image = True
final_imgen_params = {
"n": 1,
"width": 512,
"height": 512,
"steps": 50,
}
final_submit_dict = {
"prompt": "{}, {}".format(prompt, art_guide),
"api_key": koboldai_vars.sh_apikey if koboldai_vars.sh_apikey != '' else "0000000000",
"params": final_imgen_params,
"trusted_workers": False,
"models": [
"stable_diffusion"
],
"params": {
"n":1,
"nsfw": True,
"sampler_name": "k_euler_a",
"karras": True,
"cfg_scale": 7.0,
"steps":25,
"width":512,
"height":512}
}
cluster_headers = {'apikey': koboldai_vars.sh_apikey if koboldai_vars.sh_apikey != '' else "0000000000",}
logger.debug(final_submit_dict)
submit_req = requests.post('https://stablehorde.net/api/v1/generate/sync', json = final_submit_dict)
submit_req = requests.post('https://stablehorde.net/api/v2/generate/sync', json = final_submit_dict, headers=cluster_headers)
if submit_req.ok:
results = submit_req.json()
for iter in range(len(results)):
b64img = results[iter]["img"]
for iter in range(len(results['generations'])):
b64img = results['generations'][iter]["img"]
base64_bytes = b64img.encode('utf-8')
img_bytes = base64.b64decode(base64_bytes)
img = Image.open(BytesIO(img_bytes))
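
Reviewer note: the Horde request above switches from the v1 to the v2 sync endpoint; the API key moves into an apikey header, the sampler settings ride along under params, and the response nests the base64 images under a generations list. A condensed sketch of the new call, matching the code above (error handling omitted; requests, base64, BytesIO and Image are imported elsewhere in the module):

    # Sketch of the v2 sync call; final_submit_dict and cluster_headers
    # are built exactly as in the diff above.
    submit_req = requests.post('https://stablehorde.net/api/v2/generate/sync',
                               json=final_submit_dict, headers=cluster_headers)
    if submit_req.ok:
        for generation in submit_req.json()['generations']:
            img_bytes = base64.b64decode(generation["img"].encode('utf-8'))
            img = Image.open(BytesIO(img_bytes))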
@@ -9285,7 +9318,7 @@ def text2img_api(prompt,
"prompt": "{}, {}".format(prompt, art_guide),
"params": final_imgen_params,
}
apiaddress = 'http://127.0.0.1:7860/sdapi/v1/txt2img'
apiaddress = '{}/sdapi/v1/txt2img'.format(koboldai_vars.img_gen_api_url)
payload_json = json.dumps(final_submit_dict)
logger.debug(final_submit_dict)
submit_req = requests.post(url=f'{apiaddress}', data=payload_json).json()
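
Reviewer note: the hard-coded local address is replaced by the new img_gen_api_url user setting (added in the settings files below), so the SD-WebUI endpoint is built from whatever base URL the user configured. A hypothetical helper, not part of the commit, showing the same construction while tolerating a trailing slash in the setting:

    # Hypothetical helper (not in the commit); the diff itself uses a plain format string.
    def txt2img_endpoint(base_url: str) -> str:
        return base_url.rstrip("/") + "/sdapi/v1/txt2img"

    apiaddress = txt2img_endpoint(koboldai_vars.img_gen_api_url)  # default base is "http://127.0.0.1:7860/"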
@@ -9384,14 +9417,10 @@ def summarize(text, max_length=100, min_length=30, unload=True):
    #Actual sumarization
    start_time = time.time()
    global old_transfomers_functions
    temp = transformers.generation_utils.GenerationMixin._get_stopping_criteria
    transformers.generation_utils.GenerationMixin._get_stopping_criteria = old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria']

    #make sure text is less than 1024 tokens, otherwise we'll crash
    if len(koboldai_vars.summary_tokenizer.encode(text)) > 1000:
        text = koboldai_vars.summary_tokenizer.decode(koboldai_vars.summary_tokenizer.encode(text)[:1000])
    output = tpool.execute(summarizer, text, max_length=max_length, min_length=min_length, do_sample=False)[0]['summary_text']
    transformers.generation_utils.GenerationMixin._get_stopping_criteria = temp
    logger.debug("Time to summarize: {}".format(time.time()-start_time))

    #move model back to CPU to save precious vram
    torch.cuda.empty_cache()
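
Reviewer note: summarize() briefly puts the stock _get_stopping_criteria back (the original was saved in old_transfomers_functions by patch_transformers()) so KoboldAI's custom stoppers do not interfere with the summarization pipeline, then re-applies the patched version afterwards. A sketch of that save/swap/restore pattern; the try/finally is an extra safety measure not present in the code above:

    # Swap the original stopping-criteria builder back in for the summarizer call.
    temp = transformers.generation_utils.GenerationMixin._get_stopping_criteria
    transformers.generation_utils.GenerationMixin._get_stopping_criteria = \
        old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria']
    try:
        output = tpool.execute(summarizer, text, max_length=max_length,
                               min_length=min_length, do_sample=False)[0]['summary_text']
    finally:
        # Restore KoboldAI's patched version no matter what happened above.
        transformers.generation_utils.GenerationMixin._get_stopping_criteria = temp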

View File

@@ -554,6 +554,19 @@ gensettingstf = [
        'children': [{'text': 'Use Local Only', 'value': 0}, {'text':'Prefer Local','value':1}, {'text':'Prefer Horde', 'value':2}, {'text':'Use Horde Only', 'value':3}, {'text':'Use Local SD-WebUI API', 'value':4}]
    },
    {
        "UI_V2_Only": True,
        "uitype": "text",
        "unit": "text",
        "label": "Img API URL",
        "id": "img_gen_api_url",
        "default": "",
        "tooltip": "The URL to use when selecting Use Local SD-WebUI API setting in Image Priority",
        "menu_path": "Interface",
        "sub_path": "Images",
        "classname": "user",
        "name": "img_gen_api_url"
    },
    {
        "UI_V2_Only": True,
        "uitype": "toggle",
        "unit": "bool",

View File

@@ -985,6 +985,7 @@ class user_settings(settings):
        self.beep_on_complete = False
        self.img_gen_priority = 1
        self.show_budget = False
        self.img_gen_api_url = "http://127.0.0.1:7860/"
        self.cluster_requested_models = [] # The models which we allow to generate during cluster mode
@@ -1287,7 +1288,7 @@ class KoboldStoryRegister(object):
        if type(json_data) == str:
            import json
            json_data = json.loads(json_data)
        self.action_count = json_data['action_count']
        self.action_count = int(json_data['action_count'])
        #JSON forces keys to be strings, so let's fix that
        temp = {}
        data_to_send = []

View File

@@ -1784,6 +1784,13 @@ function world_info_entry(data) {
    //First let's get the id of the element we're on so we can restore it after removing the object
    var original_focus = document.activeElement.id;

    if (!(document.getElementById("world_info_folder_"+data.folder))) {
        folder = document.createElement("div");
        //console.log("Didn't find folder " + data.folder);
    } else {
        folder = document.getElementById("world_info_folder_"+data.folder);
    }
    if (document.getElementById("world_info_"+data.uid)) {
        world_info_card = document.getElementById("world_info_"+data.uid);
    } else {
@@ -1791,6 +1798,7 @@ function world_info_entry(data) {
        world_info_card = world_info_card_template.cloneNode(true);
        world_info_card.id = "world_info_"+data.uid;
        world_info_card.setAttribute("uid", data.uid);
        folder.append(world_info_card);
    }
    if (data.used_in_game) {
        world_info_card.classList.add("used_in_game");
@@ -2049,12 +2057,6 @@ function world_info_entry(data) {
    constant.checked = data.constant;
    constant.classList.remove("pulse");

    if (!(document.getElementById("world_info_folder_"+data.folder))) {
        folder = document.createElement("div");
        //console.log("Didn't find folder " + data.folder);
    } else {
        folder = document.getElementById("world_info_folder_"+data.folder);
    }

    //Let's figure out the order to insert this card
    var found = false;
    var moved = false;
@@ -2888,6 +2890,11 @@ function push_selection_to_world_info() {
menu.classList.add("open");
}
document.getElementById("story_flyout_tab_wi").onclick();
if (~("root" in world_info_folder_data)) {
world_info_folder_data["root"] = [];
world_info_folder(world_info_folder_data);
}
create_new_wi_entry("root");
document.getElementById("world_info_entry_text_-1").value = getSelectionText();
}

View File

@@ -1,9 +1,9 @@
{% for item in settings %}
    {% if item["menu_path"] == menu and item['sub_path'] == sub_path %}
        {% if 'extra_classes' in item %}
            <div class="setting_container {{ item['extra_classes'] }}">
            <div id="{{ item['name'] }}_card" class="setting_container {{ item['extra_classes'] }}">
        {% else %}
            <div class="setting_container">
            <div id="{{ item['name'] }}_card" class="setting_container">
        {% endif %}
        <!---Top Row---->
        <span class="setting_label">