Added the start of an alternative multi-gen mode (linear/sequential instead of parallel). Non-functional yet.

Continued the stub for in-UI soft prompt training.
Removed the old XLSX-to-preset-file conversion code.
ebolam
2022-12-05 13:50:49 -05:00
parent 457b7a46c4
commit 280c35b452
11 changed files with 246 additions and 205 deletions
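For context, "linear instead of parallel" multi-gen means producing the N requested sequences one generation call at a time rather than as a single batch of N. The sketch below is illustrative only and assumes hypothetical helpers (generate_one_batch, on_sequence_ready) rather than functions from this codebase.

    # Illustrative sketch; generate_one_batch / on_sequence_ready are hypothetical helpers.
    def multi_gen_parallel(prompt, numseqs, generate_one_batch):
        # Existing behaviour: one batched call produces all candidates at once.
        return generate_one_batch(prompt, batch_count=numseqs)

    def multi_gen_linear(prompt, numseqs, generate_one_batch, on_sequence_ready):
        # Alternative behaviour: one candidate per call, so each result can be
        # surfaced (and picked) as soon as it finishes instead of waiting for all.
        outputs = []
        for _ in range(numseqs):
            seq = generate_one_batch(prompt, batch_count=1)[0]
            on_sequence_ready(seq)
            outputs.append(seq)
        return outputs

This mirrors the change to core_generate below, where the batched raw_generate call is wrapped in a loop over koboldai_vars.numseqs when alt_multi_gen is enabled and the per-pass batch_count drops to 1.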


@@ -2452,6 +2452,37 @@ def reset_model_settings():
koboldai_vars.revision = None
koboldai_vars.lazy_load = True
def unload_model():
global model
global generator
global model_config
global tokenizer
#We need to wipe out the existing model and refresh the cuda cache
model = None
generator = None
model_config = None
koboldai_vars.online_model = ''
with torch.no_grad():
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="torch.distributed.reduce_op is deprecated")
for tensor in gc.get_objects():
try:
if torch.is_tensor(tensor):
tensor.set_(torch.tensor((), device=tensor.device, dtype=tensor.dtype))
except:
pass
gc.collect()
try:
with torch.no_grad():
torch.cuda.empty_cache()
except:
pass
#Reload our badwords
koboldai_vars.badwordsids = koboldai_settings.badwordsids_default
def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=False, online_model="", use_breakmodel_args=False, breakmodel_args_default_to_cpu=False, url=None, use_8_bit=False):
global model
global generator
@@ -2490,29 +2521,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
if breakmodel_args_default_to_cpu and disk_layers is None:
disk_layers = args.breakmodel_disklayers = 0
#We need to wipe out the existing model and refresh the cuda cache
unload_model()
model = None
generator = None
model_config = None
koboldai_vars.online_model = ''
with torch.no_grad():
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="torch.distributed.reduce_op is deprecated")
for tensor in gc.get_objects():
try:
if torch.is_tensor(tensor):
tensor.set_(torch.tensor((), device=tensor.device, dtype=tensor.dtype))
except:
pass
gc.collect()
try:
with torch.no_grad():
torch.cuda.empty_cache()
except:
pass
#Reload our badwords
koboldai_vars.badwordsids = koboldai_settings.badwordsids_default
if online_model == "":
koboldai_vars.configname = getmodelname()
@@ -5244,97 +5253,103 @@ def core_generate(text: list, _min: int, _max: int, found_entries: set, is_core:
with torch.no_grad():
already_generated = 0
numseqs = koboldai_vars.numseqs
total_gens = None
while True:
# The reason this is a loop is due to how Dynamic WI works. We
# cannot simply add the WI to the context mid-generation, so we
# stop early, and then insert WI, then continue generating. That
# stopping and continuing is this loop.
for i in range(koboldai_vars.numseqs if koboldai_vars.alt_multi_gen else 1):
while True:
# The reason this is a loop is due to how Dynamic WI works. We
# cannot simply add the WI to the context mid-generation, so we
# stop early, and then insert WI, then continue generating. That
# stopping and continuing is this loop.
start_time = time.time()
result = raw_generate(
gen_in[0],
max_new=koboldai_vars.genamt,
do_streaming=koboldai_vars.output_streaming,
do_dynamic_wi=koboldai_vars.dynamicscan,
batch_count=numseqs,
batch_count=numseqs if not koboldai_vars.alt_multi_gen else 1,
# Real max length is handled by CoreStopper.
bypass_hf_maxlength=koboldai_vars.dynamicscan,
is_core=True,
)
logger.debug("core_generate: run raw_generate pass {} {}s".format(already_generated, time.time()-start_time))
genout = result.encoded
already_generated += len(genout[0])
try:
assert already_generated <= koboldai_vars.genamt
except AssertionError:
print("AlreadyGenerated", already_generated)
print("genamt", koboldai_vars.genamt)
raise
if result.is_whole_generation:
break
# Generation stopped; why?
# If we have been told to halt, we have reached our target token
# amount (controlled by halt), or Dynamic WI has not told us to
# stop temporarily to insert WI, we can assume that we are done
# generating. We shall break.
if model.core_stopper.halt or not model.core_stopper.regeneration_required:
break
# Now we are doing stuff for Dynamic WI.
assert genout.ndim >= 2
assert genout.shape[0] == koboldai_vars.numseqs
if(koboldai_vars.lua_koboldbridge.generated_cols and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols):
raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({koboldai_vars.generated_tkns} != {koboldai_vars.lua_koboldbridge.generated_cols})")
if(already_generated != koboldai_vars.generated_tkns):
print("already_generated: {}".format(already_generated))
print("generated_tkns: {}".format(koboldai_vars.generated_tkns))
raise RuntimeError("WI scanning error")
for r in range(koboldai_vars.numseqs):
for c in range(already_generated):
assert koboldai_vars.lua_koboldbridge.generated[r+1][c+1] is not None
genout[r][genout.shape[-1] - already_generated + c] = koboldai_vars.lua_koboldbridge.generated[r+1][c+1]
encoded = []
for i in range(koboldai_vars.numseqs):
txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
#winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars.actions)
#txt, _, _ = calcsubmitbudget(len(koboldai_vars.actions), winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
txt, _, _, _found_entries = koboldai_vars.calc_ai_text(submitted_text=txt, send_context=False)
found_entries[i].update(_found_entries)
encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
max_length = len(max(encoded, key=len))
encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
genout = torch.cat(
(
encoded,
genout[..., -already_generated:],
),
dim=-1
)
if(koboldai_vars.sp is not None):
soft_tokens = torch.arange(
model.config.vocab_size,
model.config.vocab_size + koboldai_vars.sp.shape[0],
device=genout.device,
)
genout = torch.cat((soft_tokens.tile(koboldai_vars.numseqs, 1), genout), dim=-1)
logger.debug("core_generate: run raw_generate pass {} {}s".format(already_generated, time.time()-start_time))
assert genout.shape[-1] + koboldai_vars.genamt - already_generated <= koboldai_vars.max_length
gen_in = genout
genout = result.encoded
numseqs = 1
already_generated += len(genout[0])
try:
assert already_generated <= koboldai_vars.genamt
except AssertionError:
print("AlreadyGenerated", already_generated)
print("genamt", koboldai_vars.genamt)
raise
if result.is_whole_generation:
break
# Generation stopped; why?
# If we have been told to halt, we have reached our target token
# amount (controlled by halt), or Dynamic WI has not told us to
# stop temporarily to insert WI, we can assume that we are done
# generating. We shall break.
if model.core_stopper.halt or not model.core_stopper.regeneration_required:
break
# Now we are doing stuff for Dynamic WI.
assert genout.ndim >= 2
assert genout.shape[0] == koboldai_vars.numseqs
if(koboldai_vars.lua_koboldbridge.generated_cols and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols):
raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({koboldai_vars.generated_tkns} != {koboldai_vars.lua_koboldbridge.generated_cols})")
if(already_generated != koboldai_vars.generated_tkns):
print("already_generated: {}".format(already_generated))
print("generated_tkns: {}".format(koboldai_vars.generated_tkns))
raise RuntimeError("WI scanning error")
for r in range(koboldai_vars.numseqs):
for c in range(already_generated):
assert koboldai_vars.lua_koboldbridge.generated[r+1][c+1] is not None
genout[r][genout.shape[-1] - already_generated + c] = koboldai_vars.lua_koboldbridge.generated[r+1][c+1]
encoded = []
for i in range(koboldai_vars.numseqs):
txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
#winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars.actions)
#txt, _, _ = calcsubmitbudget(len(koboldai_vars.actions), winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
txt, _, _, _found_entries = koboldai_vars.calc_ai_text(submitted_text=txt, send_context=False)
found_entries[i].update(_found_entries)
encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
max_length = len(max(encoded, key=len))
encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
genout = torch.cat(
(
encoded,
genout[..., -already_generated:],
),
dim=-1
)
if(koboldai_vars.sp is not None):
soft_tokens = torch.arange(
model.config.vocab_size,
model.config.vocab_size + koboldai_vars.sp.shape[0],
device=genout.device,
)
genout = torch.cat((soft_tokens.tile(koboldai_vars.numseqs, 1), genout), dim=-1)
assert genout.shape[-1] + koboldai_vars.genamt - already_generated <= koboldai_vars.max_length
gen_in = genout
numseqs = 1
if total_gens is None:
total_gens = genout
else:
total_gens = torch.cat((total_gens, genout))
return genout, already_generated
return total_gens, already_generated
class GenerationResult:
def __init__(
@@ -6043,6 +6058,7 @@ def generate(txt, minimum, maximum, found_entries=None):
try:
start_time = time.time()
genout, already_generated = tpool.execute(core_generate, txt, minimum, maximum, found_entries)
print(genout)
logger.debug("Generate: core_generate time {}s".format(time.time()-start_time))
except Exception as e:
if(issubclass(type(e), lupa.LuaError)):
@@ -9613,6 +9629,55 @@ def UI_2_privacy_mode(data):
if data['password'] == koboldai_vars.privacy_password:
koboldai_vars.privacy_mode = False
#==================================================================#
# Soft Prompt Tuning
#==================================================================#
@socketio.on("create_new_softprompt")
@logger.catch
def UI_2_create_new_softprompt(data):
logger.info("Soft Prompt Dataset: {}".format(data))
from prompt_tuner import BasicTrainer
trainer = BasicTrainer(None, quiet=koboldai_vars.quiet)
trainer.data.ckpt_path = koboldai_vars.model
trainer.get_hf_checkpoint_metadata()
trainer.data.save_file = "{}.mtjsp".format("".join(x for x in data['sp_title'] if x.isalnum() or x in [" ", "-", "_"]))
trainer.data.prompt_method = "tokens"
tokenizer = trainer.get_tokenizer()
if trainer.data.newlinemode == "s": # Handle fairseq-style newlines if required
initial_softprompt = data['sp_prompt'].replace("\n", "</s>")
trainer.data.initial_softprompt = tokenizer.encode(
data['sp_prompt'], max_length=int(2e9), truncation=True
)
trainer.tokenize_dataset(dataset_path=data['sp_dataset'],
output_file="softprompts/{}.npy".format("".join(x for x in data['sp_title'] if x.isalnum() or x in [" ", "-", "_"])),
batch_size=2048 if 'batch_size' not in data else data['batch_size'],
epochs=1 if 'epochs' not in data else data['epochs'])
trainer.data.dataset_file = "softprompts/{}.npy".format("".join(x for x in data['sp_title'] if x.isalnum() or x in [" ", "-", "_"]))
trainer.data.gradient_accumulation_steps = 16 if 'gradient_accumulation_steps' not in data else data['gradient_accumulation_steps']
trainer.data.stparams = {
"lr": 3e-5,
"max_grad_norm": 10.0,
"weight_decay": 0.1,
"warmup": 0.1,
"end_lr_multiplier": 0.1,
"save_every": 50,
}
unload_model()
trainer.train(breakmodel_primary_device=breakmodel.primary_device,
breakmodel_gpulayers=breakmodel.gpu_blocks,
breakmodel_disklayers=breakmodel.disk_blocks)
output_file = "softprompts/{}.zip".format("".join(x for x in data['sp_title'] if x.isalnum() or x in [" ", "-", "_"]))
name = data['sp_title']
author = data['sp_author']
supported = koboldai_vars.model
description = data['sp_description']
trainer.export_to_kobold(output_file, name, author, supported, description)
output_file = "softprompts/{}.json".format("".join(x for x in data['sp_title'] if x.isalnum() or x in [" ", "-", "_"]))
trainer.export_to_mkultra(output_file, name, description)
#==================================================================#
# Test


@@ -1,60 +0,0 @@
import pandas as pd
import sys
output = []
sheet_mapper = {"KAI-ADAPTED 13B": "13B", "KAI-ADAPTED 6B": "6B", 'KAI-CUSTOM': 'Custom'}
for file in ['KoboldAI Settings (6B).xlsx', 'KoboldAI Settings (13B).xlsx', 'KoboldAI Settings (Custom).xlsx', 'KoboldAI Settings (Original).xlsx']:
presets = pd.read_excel("preset Files/{}".format(file), None)
for sheet in presets:
df = presets[sheet]
if sheet in sheet_mapper:
sheet = sheet_mapper[sheet]
df = df.dropna(axis=1, how='all')
df = df.rename(columns={"Unnamed: 0": "setting"})
df = pd.melt(df, id_vars=['setting'])
df = df.rename(columns={"variable": "preset"})
df['fix'] = df['value'].str.replace(" (KAI)", "", regex=False)
df.loc[~df['fix'].isnull(), 'value'] = df['fix']
df = df.drop(columns=['fix'])
df.loc[df['setting']=='Samplers Order', 'value'] = df['value'].str.replace("Temp", "5", regex=False)
df.loc[df['setting']=='Samplers Order', 'value'] = df['value'].str.replace("K", "0", regex=False)
df.loc[df['setting']=='Samplers Order', 'value'] = df['value'].str.replace("TFS", "3", regex=False)
df.loc[df['setting']=='Samplers Order', 'value'] = df['value'].str.replace("A", "1", regex=False)
df.loc[df['setting']=='Samplers Order', 'value'] = df['value'].str.replace("Typ", "4", regex=False)
df.loc[df['setting']=='Samplers Order', 'value'] = df['value'].str.replace("P", "2", regex=False)
settings_mapper = {'Temperature': 'temp', 'Output Length': 'genamt', 'Repetition Penalty': 'rep_pen',
'Top P': 'top_p', 'Top K': 'top_k', 'Tail-Free': 'tfs', 'Repetition Penalty Range': 'rep_pen_range',
'Repetition Penalty Slope': 'rep_pen_slope', 'Typical': 'typical', 'Top A': 'top_a',
'Samplers Order': 'sampler_order', 'Description of settings from the author': 'description',
'Author': 'Author', 'Model Type': 'Model Type',
'Description of settings from NovelAI': 'description', 'Model Size': "Model Size"
}
df['setting'] = df['setting'].map(settings_mapper)
try:
df = df.pivot(index='preset', columns='setting', values='value')
except:
print(file)
display(df)
raise
df['Model Type'] = df['Model Type'].str.replace(", ", ",").str.split(",")
df.loc[:, 'Model Category'] = sheet
output.append(df)
#output[sheet] = df.to_json(orient="index")
df = pd.concat(output)
df = df.reset_index(drop=False)
df['uid'] = df.index
df = df.explode("Model Type")
df['description'] = df['description'].str.strip()
with open("official.presets", "w") as f:
f.write(df.reset_index(drop=True).to_json(orient='records'))


@@ -36,4 +36,5 @@ dependencies:
- ansi2html
- flask_compress
- ijson
- bitsandbytes
- ftfy


@@ -34,3 +34,4 @@ dependencies:
- ansi2html
- flask_compress
- ijson
- ftfy


@@ -411,6 +411,22 @@ gensettingstf = [
"sub_path": "Other", "sub_path": "Other",
"classname": "system", "classname": "system",
"name": "alt_gen", "name": "alt_gen",
"ui_level": 2
},
{
"uitype": "toggle",
"unit": "bool",
"label": "Alt Multi Gen",
"id": "alt_multi_gen",
"min": 0,
"max": 1,
"step": 1,
"default": 0,
"tooltip": "Runs Gens per Action one at a time so you can select one if you like it without having to wait.",
"menu_path": "Settings",
"sub_path": "Other",
"classname": "model",
"name": "alt_multi_gen",
"ui_level": 2 "ui_level": 2
}, },
{ {


@@ -649,6 +649,7 @@ class model_settings(settings):
</div>""" # Custom Welcome Text </div>""" # Custom Welcome Text
self.welcome = self.welcome_default self.welcome = self.welcome_default
self.koboldai_vars = koboldai_vars self.koboldai_vars = koboldai_vars
self.alt_multi_gen = False
def reset_for_model_load(self):
self.max_length = 2048 # Maximum number of tokens to submit per action


@@ -3,7 +3,7 @@ import os
import sys
import math
import numpy as np
import termcolor
from logger import logger
import contextlib
import traceback
import random
@@ -70,21 +70,24 @@ def patch_transformers_download():
class Send_to_socketio(object):
def write(self, bar):
bar = bar.replace("\r", "").replace("\n", "")
if bar != "":
if bar != "" and [ord(num) for num in bar] != [27, 91, 65]: #No idea why we're getting the 27, 1, 65 character set, just killing to so we can move on
try:
print(bar, end="\r")
if utils.emit is not None:
utils.emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", "&nbsp;")}, broadcast=True)
print('\r' + bar, end='')
socketio.emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", "&nbsp;")}, broadcast=True, room="UI_1")
eventlet.sleep(seconds=0)
except:
pass
def flush(self):
pass
def http_get(
url: str,
temp_file: transformers.utils.hub.BinaryIO,
temp_file,
proxies=None,
resume_size=0,
headers: transformers.utils.hub.Optional[transformers.utils.hub.Dict[str, str]] = None,
headers=None,
file_name: transformers.utils.hub.Optional[str] = None,
file_name=None,
):
"""
Download remote file. Do not gobble up errors.
@@ -108,13 +111,18 @@ def patch_transformers_download():
desc=f"Downloading {file_name}" if file_name is not None else "Downloading", desc=f"Downloading {file_name}" if file_name is not None else "Downloading",
file=Send_to_socketio(), file=Send_to_socketio(),
) )
koboldai_vars.status_message = "Download Model"
koboldai_vars.total_download_chunks = total
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
if url[-11:] != 'config.json':
progress.update(len(chunk))
koboldai_vars.downloaded_chunks += len(chunk)
temp_file.write(chunk)
if url[-11:] != 'config.json':
progress.close()
koboldai_vars.status_message = ""
transformers.utils.hub.http_get = http_get
@@ -195,18 +203,18 @@ def device_list(n_layers, primary=None, selected=None):
if(device_count < 2):
primary = None
gpu_blocks = breakmodel.gpu_blocks + (device_count - len(breakmodel.gpu_blocks))*[0]
print(f"{colors.YELLOW} DEVICE ID | LAYERS | DEVICE NAME{colors.END}")
logger.info(" DEVICE ID | LAYERS | DEVICE NAME{colors.END}")
for i in range(device_count):
name = torch.cuda.get_device_name(i)
if(len(name) > 47):
name = "..." + name[-44:]
row_color = colors.END
sep_color = colors.YELLOW
print(f"{row_color}{colors.YELLOW + '->' + row_color if i == selected else ' '} {'(primary)' if i == primary else ' '*9} {i:3} {sep_color}|{row_color} {gpu_blocks[i]:3} {sep_color}|{row_color} {name}{colors.END}")
logger.info(f"{'(primary)' if i == primary else ' '*9} {i:3} | {gpu_blocks[i]:3} | {name}")
row_color = colors.END
sep_color = colors.YELLOW
print(f"{row_color}{colors.YELLOW + '->' + row_color if -1 == selected else ' '} {' '*9} N/A {sep_color}|{row_color} {breakmodel.disk_blocks:3} {sep_color}|{row_color} (Disk cache){colors.END}")
logger.info(f" {' '*9} N/A | {breakmodel.disk_blocks:3} | (Disk cache)")
print(f"{row_color} {' '*9} N/A {sep_color}|{row_color} {n_layers:3} {sep_color}|{row_color} (CPU){colors.END}")
logger.info(f" {' '*9} N/A | {n_layers:3} | (CPU)")
def move_model_to_devices(model, usegpu, gpu_device):
@@ -440,12 +448,12 @@ class TrainerBase(abc.ABC):
@property
def lazy_load_spec(self):
print("WARNING: `TrainerData.lazy_load_spec` is currently unused", file=sys.stderr)
logger.warning("WARNING: `TrainerData.lazy_load_spec` is currently unused")
return self.__lazy_load_spec
@lazy_load_spec.setter
def lazy_load_spec(self, value: Optional[dict]):
print("WARNING: `TrainerData.lazy_load_spec` is currently unused", file=sys.stderr)
logger.warning("WARNING: `TrainerData.lazy_load_spec` is currently unused")
self.__lazy_load_spec = value
@property
@@ -465,7 +473,7 @@ class TrainerBase(abc.ABC):
self.data = self.TrainerData()
self._spmodule: Optional[str] = None
if universe is not None:
print("WARNING: The `universe` argument of `TrainerBase.__init__` is currently unused", file=sys.stderr)
logger.warning("WARNING: The `universe` argument of `TrainerBase.__init__` is currently unused")
def raise_configuration_error(self, msg, **kwargs):
if "quiet" not in kwargs:
@@ -608,14 +616,11 @@ class TrainerBase(abc.ABC):
self.data.params["max_batch_size"] - self.data.soft_in_dim,
)
assert batch_size >= 0
print(
termcolor.colored(
"\nIf you see a warning somewhere below about token indices, ignore it. That warning is normal.\n",
"magenta",
)
)
logger.info(
"\nIf you see a warning somewhere below about token indices, ignore it. That warning is normal.\n"
)
print("Batch size:", batch_size)
logger.info("Batch size: {}".format(batch_size))
print(termcolor.colored("Tokenizing your dataset...\n", "magenta"))
logger.info("Tokenizing your dataset...\n")
if not isinstance(dataset_path, str):
files = [dataset_path]
@@ -632,7 +637,7 @@ class TrainerBase(abc.ABC):
eos = tokenizer.decode(self.data.params["eos_token"])
for path in files:
if isinstance(path, str):
f = open(path)
f = open(path, 'r', encoding='utf-8')
else:
f = path
try:
@@ -645,7 +650,7 @@ class TrainerBase(abc.ABC):
if isinstance(path, str):
f.close()
print("Dataset size (in tokens):", len(tokens))
logger.info("Dataset size (in tokens): {}".format(len(tokens)))
if len(tokens) < batch_size + 1:
self.raise_configuration_error(
"Your dataset is too small! The number of tokens has to be greater than the batch size. Try increasing the epochs.",
@@ -653,7 +658,7 @@ class TrainerBase(abc.ABC):
)
tail = len(tokens) % (batch_size + 1)
if tail:
print(
logger.info(
f"We're removing the last {tail} tokens from your dataset to make the length a multiple of {batch_size+1}."
)
tokens = tokens[:-tail]
@@ -671,7 +676,7 @@ class TrainerBase(abc.ABC):
axis=0,
)
tokens = tokens[: math.ceil(epochs * sequences_per_epoch)]
print(f"Total sequences in your dataset: {tokens.shape[0]}")
logger.info(f"Total sequences in your dataset: {tokens.shape[0]}")
if isinstance(output_file, str):
f = open(output_file, "w")
@@ -698,7 +703,7 @@ class TrainerBase(abc.ABC):
self.data.params["max_batch_size"] = 2048 self.data.params["max_batch_size"] = 2048
if not os.path.exists(self.data.save_file): if not os.path.exists(self.data.save_file):
print("We are starting a brand new soft-tuning session.\n") logger.info("We are starting a brand new soft-tuning session.\n")
self.startup(step=-1) self.startup(step=-1)
if self.data.soft_in_dim <= 0: if self.data.soft_in_dim <= 0:
self.raise_configuration_error( self.raise_configuration_error(
@@ -718,7 +723,7 @@ class TrainerBase(abc.ABC):
opt_state = z["opt_state"] opt_state = z["opt_state"]
except AssertionError: except AssertionError:
self.raise_configuration_error("MKUSP file is corrupted.", code=14) self.raise_configuration_error("MKUSP file is corrupted.", code=14)
print(f"We're resuming a previous soft-tuning session at step {step+1}.\n") logger.info(f"We're resuming a previous soft-tuning session at step {step+1}.\n")
self.startup(step=step + 1) self.startup(step=step + 1)
soft_embeddings = z["tensor"] soft_embeddings = z["tensor"]
@@ -785,7 +790,7 @@ class TrainerBase(abc.ABC):
num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
else:
num_tensors = len(device_map)
print(flush=True)
#print(flush=True)
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=Send_to_socketio())
with zipfile.ZipFile(f, "r") as z:


@@ -24,4 +24,5 @@ psutil
ansi2html
flask_compress
ijson
bitsandbytes
ftfy


@@ -25,4 +25,5 @@ diffusers
psutil
ansi2html
flask_compress
ijson
ftfy


@@ -2600,6 +2600,15 @@ function process_log_message(full_data) {
}
//--------------------------------------------UI to Server Functions----------------------------------
function create_new_softprompt() {
socket.emit("create_new_softprompt", {"sp_title": document.getElementById("sp_title").value,
"sp_prompt": document.getElementById("sp_prompt").value,
"sp_dataset": document.getElementById("sp_dataset").value,
"sp_author": document.getElementById("sp_author").value,
"sp_description": document.getElementById("sp_description").value
});
}
async function download_story_to_json() {
//document.getElementById('download_iframe').src = 'json';
downloaded = false;


@@ -85,7 +85,7 @@
<button type="button" class="btn btn-primary popup_load_cancel_button" onclick="closePopups();">Cancel</button> <button type="button" class="btn btn-primary popup_load_cancel_button" onclick="closePopups();">Cancel</button>
</div> </div>
</div> </div>
<!---------------- Story overwrite screen ----------------------> <!---------------- Private Mode Unlock screen ---------------------->
<div id="privacy_mode" class="popup-window popup"> <div id="privacy_mode" class="popup-window popup">
<div class="title"> <div class="title">
<div class="popuptitletext">Locked</div> <div class="popuptitletext">Locked</div>
@@ -267,7 +267,8 @@
</div>
<div id="shortcut-container"></div>
</div>
<!---------------- Shortcuts ------------------->
<!---------------- Softprompt Trainer ------------------->
<div id="sp-trainer-popup" class="popup-window popup"> <div id="sp-trainer-popup" class="popup-window popup">
<div class="title"> <div class="title">
<div class="popuptitletext">Softprompt Trainer</div> <div class="popuptitletext">Softprompt Trainer</div>
@@ -282,7 +283,7 @@
</form>
</div>
<div class="popup_load_cancel">
<button type="button" class="btn btn-primary popup_load_cancel_button" onclick="closePopups();">Ok</button>
<button type="button" class="btn btn-primary popup_load_cancel_button" onclick="create_new_softprompt(); closePopups();">Ok</button>
</div>
</div>