This commit is contained in:
somebody
2022-09-19 19:50:33 -05:00
10 changed files with 220 additions and 20 deletions

View File

@@ -64,7 +64,7 @@ from utils import debounce
import utils
import koboldai_settings
import torch
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, modeling_utils
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, modeling_utils, AutoModelForTokenClassification
from transformers import __version__ as transformers_version
import transformers
try:
@@ -73,6 +73,11 @@ except:
pass
import transformers.generation_utils
# Text2img
import base64
from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps
from io import BytesIO
global tpu_mtj_backend
@@ -1618,8 +1623,8 @@ def get_cluster_models(msg):
# If the client settings file doesn't exist, create it
# Write API key to file
os.makedirs('settings', exist_ok=True)
if path.exists(get_config_filename(koboldai_vars.model_selected)):
with open(get_config_filename(koboldai_vars.model_selected), "r") as file:
if path.exists(get_config_filename(model)):
with open(get_config_filename(model), "r") as file:
js = json.load(file)
if 'online_model' in js:
online_model = js['online_model']
@@ -1630,7 +1635,7 @@ def get_cluster_models(msg):
changed=True
if changed:
js={}
with open(get_config_filename(koboldai_vars.model_selected), "w") as file:
with open(get_config_filename(model), "w") as file:
js["apikey"] = koboldai_vars.oaiapikey
file.write(json.dumps(js, indent=3))
@@ -1674,7 +1679,7 @@ def patch_transformers_download():
if bar != "":
try:
print(bar, end="\r")
print(bar, end="\n")
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1")
eventlet.sleep(seconds=0)
except:
@@ -1712,10 +1717,12 @@ def patch_transformers_download():
desc=f"Downloading {file_name}" if file_name is not None else "Downloading",
file=Send_to_socketio(),
)
koboldai_vars.total_download_chunks = total
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
if url[-11:] != 'config.json':
progress.update(len(chunk))
koboldai_vars.downloaded_chunks += len(chunk)
temp_file.write(chunk)
if url[-11:] != 'config.json':
progress.close()
@@ -1768,6 +1775,8 @@ def patch_transformers_download():
def patch_transformers():
global transformers
global old_transfomers_functions
old_transfomers_functions = {}
patch_transformers_download()
@@ -1784,9 +1793,11 @@ def patch_transformers():
if not args.no_aria2:
utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
old_transfomers_functions['PreTrainedModel.from_pretrained'] = PreTrainedModel.from_pretrained
PreTrainedModel.from_pretrained = new_from_pretrained
if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
old_transfomers_functions['modeling_utils.get_checkpoint_shard_files'] = old_get_checkpoint_shard_files
def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
utils.num_shards = utils.get_num_shards(index_filename)
utils.from_pretrained_index_filename = index_filename
@@ -1814,6 +1825,7 @@ def patch_transformers():
if max_pos > self.weights.size(0):
self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
old_transfomers_functions['XGLMSinusoidalPositionalEmbedding.forward'] = XGLMSinusoidalPositionalEmbedding.forward
XGLMSinusoidalPositionalEmbedding.forward = new_forward
@@ -1833,6 +1845,7 @@ def patch_transformers():
self.model = OPTModel(config)
self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
self.post_init()
old_transfomers_functions['OPTForCausalLM.__init__'] = OPTForCausalLM.__init__
OPTForCausalLM.__init__ = new_init
@@ -2117,6 +2130,7 @@ def patch_transformers():
break
return self.regeneration_required or self.halt
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria'] = old_get_stopping_criteria
def new_get_stopping_criteria(self, *args, **kwargs):
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
global tokenizer
@@ -2171,7 +2185,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
if not utils.HAS_ACCELERATE:
disk_layers = None
koboldai_vars.reset_model()
koboldai_vars.cluster_requested_models = online_model
koboldai_vars.cluster_requested_models = [online_model] if isinstance(online_model, str) else online_model
koboldai_vars.noai = False
if not use_breakmodel_args:
set_aibusy(True)
@@ -8134,6 +8148,138 @@ def get_model_size(model_name):
def UI_2_save_revision(data):
koboldai_vars.save_revision()
#==================================================================#
# Generate Image
#==================================================================#
@socketio.on("generate_image")
def UI_2_generate_image(data):
    """Socket.IO handler for "generate_image": build a prompt and request an image.

    The prompt is made from the world info keys that occur in the latest
    action (or in the story prompt when there are no actions yet).  When
    fewer than four keys match, a summary produced by a transformers
    "summarization" pipeline is used as the prompt instead.  The resulting
    base64 image is sent back to the client in an "Action_Image" event.

    data -- payload sent by the client; not used.
    """
    koboldai_vars.generating_image = True
    try:
        #get latest action
        if len(koboldai_vars.actions) > 0:
            action = koboldai_vars.actions[-1]
        else:
            action = koboldai_vars.prompt
        #Get matching world info entries
        keys = []
        for wi in koboldai_vars.worldinfo_v2:
            for key in wi['key']:
                if key not in action:
                    continue
                #Only the first primary key found in the action is considered.
                #It is kept when the entry has no secondary keys, or when at
                #least one of its secondary keys is present as well.
                secondary_keys = wi['keysecondary']
                if len(secondary_keys) == 0 or any(k in action for k in secondary_keys):
                    keys.append(key)
                break

        #If we have >= 4 keys, use those, otherwise use summarization
        if len(keys) < 4:
            from transformers import pipeline as summary_pipeline
            summarizer = summary_pipeline("summarization")
            #text to summarize:
            # NOTE(review): [:-5] selects everything *except* the last five
            # actions; if the intent is "the five most recent actions" this
            # should probably be [-5:] -- confirm before changing.
            if len(koboldai_vars.actions) < 5:
                text = "".join(koboldai_vars.actions[:-5]+[koboldai_vars.prompt])
            else:
                text = "".join(koboldai_vars.actions[:-5])
            #Temporarily put back the un-patched stopping criteria hook while
            #the summarizer runs, and always restore the patched one afterwards,
            #even if summarization raises.
            generation_mixin = transformers.generation_utils.GenerationMixin
            patched_get_stopping_criteria = generation_mixin._get_stopping_criteria
            generation_mixin._get_stopping_criteria = old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria']
            try:
                keys = [summarizer(text, max_length=100, min_length=30, do_sample=False)[0]['summary_text']]
            finally:
                generation_mixin._get_stopping_criteria = patched_get_stopping_criteria

        #This must be a plain string: a trailing comma here would turn it into
        #a 1-tuple whose repr would end up inside the prompt.
        art_guide = 'fantasy illustration, artstation, by jason felix by steve argyle by tyler jacobson by peter mohrbacher, cinematic lighting'
        image_prompt = ", ".join(keys)
        b64_data = text2img(image_prompt, art_guide = art_guide)
    finally:
        #Never leave the UI stuck in the "generating" state after a failure
        koboldai_vars.generating_image = False
    #text2img yields None when no image was produced; don't send a broken image
    if b64_data is not None:
        emit("Action_Image", {'b64': b64_data, 'prompt': image_prompt})
@logger.catch
def text2img(prompt,
             art_guide = 'fantasy illustration, artstation, by jason felix by steve argyle by tyler jacobson by peter mohrbacher, cinematic lighting',
             filename = "story_art.png"):
    """Request an image for *prompt* from the Stable Horde sync endpoint.

    prompt    -- text describing the image content.
    art_guide -- style description appended to the prompt.  Normally a
                 string; a list/tuple of strings is joined with ", ".
    filename  -- file the decoded image is saved to.  When the service
                 returns several images each one is saved as
                 "<index>_<filename>".

    Returns the base64 text of the last image received, or None when the
    request failed or the response contained no images.
    koboldai_vars.generating_image is True while this runs and is always
    reset to False on exit, including when an exception is raised.
    """
    print("Generating Image")
    koboldai_vars.generating_image = True
    try:
        if not isinstance(art_guide, str):
            art_guide = ", ".join(art_guide)
        final_imgen_params = {
            "n": 1,
            "width": 512,
            "height": 512,
            "steps": 50,
        }
        final_submit_dict = {
            "prompt": "{}, {}".format(prompt, art_guide),
            "api_key": koboldai_vars.sh_apikey if koboldai_vars.sh_apikey != '' else "0000000000",
            "params": final_imgen_params,
        }
        logger.debug(final_submit_dict)
        # NOTE(review): no timeout is passed, so this call blocks until the
        # service answers -- consider adding one once a sensible limit is known.
        submit_req = requests.post('https://stablehorde.net/api/v1/generate/sync', json = final_submit_dict)
        if not submit_req.ok:
            print(submit_req.text)
            return None

        results = submit_req.json()
        b64img = None  # stays None if the service returned an empty list
        for index, result in enumerate(results):
            b64img = result["img"]
            img_bytes = base64.b64decode(b64img.encode('utf-8'))
            img = Image.open(BytesIO(img_bytes))
            if len(results) > 1:
                final_filename = f"{index}_{filename}"
            else:
                final_filename = filename
            img.save(final_filename)
        if b64img is not None:
            print("Saved Image")
        return b64img
    finally:
        koboldai_vars.generating_image = False
def get_items_locations_from_text(text):
    """Run the "dslim/bert-base-NER" pipeline over *text* and print what it finds.

    Tokens tagged as organisation, location or person are collected into
    three lists.  A token that starts at, or one character after, the end
    of the previous token of the same kind is merged into that previous
    entry (with a space when there was a one-character gap, and with the
    "##" word-piece marker removed).  The three lists are printed; nothing
    is returned.
    """
    # load model and tokenizer
    ner_tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
    ner_model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
    ner = transformers.pipeline("ner", model=ner_model, tokenizer=ner_tokenizer)

    # Map every handled tag onto the bucket it belongs to; any other tag is ignored.
    bucket_for_tag = {
        'B-ORG': 'org', 'I-ORG': 'org',
        'B-LOC': 'loc', 'I-LOC': 'loc',
        'B-PER': 'per', 'I-PER': 'per',
    }
    found = {'org': [], 'loc': [], 'per': []}
    # End offset of the last token seen per bucket; -2 guarantees that the
    # very first token of a bucket is never treated as a continuation.
    last_end = {'org': -2, 'loc': -2, 'per': -2}

    for token in ner(text):
        bucket = bucket_for_tag.get(token['entity'])
        if bucket is None:
            continue
        names = found[bucket]
        previous_end = last_end[bucket]
        if token['start']-1 <= previous_end:
            # Continuation of the previous entity of this kind
            if token['start'] != previous_end:
                names[-1] = "{} ".format(names[-1])
            names[-1] = "{}{}".format(names[-1], token['word'].replace("##", ""))
        else:
            names.append(token['word'])
        last_end[bucket] = token['end']

    print("Orgs: {}".format(found['org']))
    print("Locations: {}".format(found['loc']))
    print("People: {}".format(found['per']))
#==================================================================#
# Test
#==================================================================#
@@ -10919,6 +11065,7 @@ if __name__ == "__main__":
try:
cloudflare = str(localtunnel.stdout.readline())
cloudflare = (re.search("(?P<url>https?:\/\/[^\s]+loca.lt)", cloudflare).group("url"))
koboldai_vars.cloudflare_link = cloudflare
break
except:
attempts += 1
@@ -10928,12 +11075,15 @@ if __name__ == "__main__":
print("LocalTunnel could not be created, falling back to cloudflare...")
from flask_cloudflared import _run_cloudflared
cloudflare = _run_cloudflared(port)
koboldai_vars.cloudflare_link = cloudflare
elif(args.ngrok):
from flask_ngrok import _run_ngrok
cloudflare = _run_ngrok()
koboldai_vars.cloudflare_link = cloudflare
elif(args.remote):
from flask_cloudflared import _run_cloudflared
cloudflare = _run_cloudflared(port)
koboldai_vars.cloudflare_link = cloudflare
if(args.localtunnel or args.ngrok or args.remote):
with open('cloudflare.log', 'w') as cloudflarelog:
cloudflarelog.write("KoboldAI has finished loading and is available at the following link : " + cloudflare)

View File

@@ -20,6 +20,7 @@ dependencies:
- marshmallow>=3.13
- apispec-webframeworks
- loguru
- Pillow
- pip:
- flask-cloudflared
- flask-ngrok

View File

@@ -17,6 +17,7 @@ dependencies:
- marshmallow>=3.13
- apispec-webframeworks
- loguru
- Pillow
- pip:
- --find-links https://download.pytorch.org/whl/rocm4.2/torch_stable.html
- torch==1.10.*

View File

@@ -111,7 +111,7 @@ class koboldai_vars(object):
def reset_model(self):
self._model_settings.reset_for_model_load()
def calc_ai_text(self, submitted_text="", method=2):
def calc_ai_text(self, submitted_text="", method=2, return_text=False):
context = []
token_budget = self.max_length
used_world_info = []
@@ -285,6 +285,8 @@ class koboldai_vars(object):
tokens = self.tokenizer.encode(text)
self.context = context
if return_text:
return text
return tokens, used_tokens, used_tokens+self.genamt
def __setattr__(self, name, value):
@@ -493,13 +495,16 @@ class model_settings(settings):
if self.tqdm.format_dict['rate'] is not None:
self.tqdm_rem_time = str(datetime.timedelta(seconds=int(float(self.total_layers-self.loaded_layers)/self.tqdm.format_dict['rate'])))
#Setup TQDP for model downloading
elif name == "total_download_chunks" and 'tqdm' in self.__dict__:
self.tqdm.reset(total=value)
self.tqdm_progress = 0
elif name == "downloaded_chunks" and 'tqdm' in self.__dict__:
if value == 0:
self.tqdm.reset(total=self.total_download_chunks)
self.tqdm_progress = 0
else:
self.tqdm.update(value-old_value)
self.tqdm_progress = round(float(self.downloaded_chunks)/float(self.total_download_chunks)*100, 1)
self.tqdm_progress = 0 if self.total_download_chunks==0 else round(float(self.downloaded_chunks)/float(self.total_download_chunks)*100, 1)
if self.tqdm.format_dict['rate'] is not None:
self.tqdm_rem_time = str(datetime.timedelta(seconds=int(float(self.total_download_chunks-self.downloaded_chunks)/self.tqdm.format_dict['rate'])))
@@ -738,7 +743,6 @@ class system_settings(settings):
self.userscripts = [] # List of userscripts to load
self.last_userscripts = [] # List of previous userscript filenames from the previous time userscripts were send via usstatitems
self.corescript = "default.lua" # Filename of corescript to load
self.gpu_device = 0 # Which PyTorch device to use when using pure GPU generation
self.savedir = os.getcwd()+"\\stories"
self.hascuda = False # Whether torch has detected CUDA on the system
@@ -794,6 +798,8 @@ class system_settings(settings):
print("Colab Check: {}".format(self.on_colab))
self.horde_share = False
self._horde_pid = None
self.sh_apikey = "" # API key to use for txt2img from the Stable Horde.
self.generating_image = False #The current status of image generation
self.cookies = {} #cookies for colab since colab's URL changes, cookies are lost
@@ -877,6 +883,8 @@ class KoboldStoryRegister(object):
temp = [self.actions[x]["Selected Text"] for x in list(self.actions)[i]]
return temp
else:
if i < 0:
return self.actions[self.action_count+i+1]["Selected Text"]
return self.actions[i]["Selected Text"]
def __setitem__(self, i, text):

View File

@@ -15,4 +15,5 @@ accelerate
flask_session
marshmallow>=3.13
apispec-webframeworks
loguru
loguru
Pillow

View File

@@ -19,4 +19,5 @@ bleach==4.1.0
flask-session
marshmallow>=3.13
apispec-webframeworks
loguru
loguru
Pillow

View File

@@ -2201,6 +2201,10 @@ button.disabled {
color: red;
}
.italics {
font-style: italic;
}
.within_max_length {
color: var(--text_to_ai_color);
font-weight: bold;
@@ -2370,4 +2374,14 @@ h2 .material-icons-outlined {
input[type='range'] {
border: none !important;
}
.settings_button[system_generating_image="true"] {
filter: brightness(35%);
cursor: not-allowed;
pointer-events:none;
}
.action_image {
width: var(--flyout_menu_width);
}

View File

@@ -29,6 +29,7 @@ socket.on('load_cookies', function(data){load_cookies(data)});
socket.on('load_tweaks', function(data){load_tweaks(data);});
socket.on("wi_results", updateWISearchListings);
socket.on("request_prompt_config", configurePrompt);
socket.on("Action_Image", function(data){Action_Image(data);});
//socket.onAny(function(event_name, data) {console.log({"event": event_name, "class": data.classname, "data": data});});
var presets = {};
@@ -1429,8 +1430,8 @@ function load_model() {
for (item of document.getElementById("oaimodel").selectedOptions) {
selected_models.push(item.value);
}
if (selected_models == []) {
selected_models = "";
if (selected_models == ['']) {
selected_models = [];
} else if (selected_models.length == 1) {
selected_models = selected_models[0];
}
@@ -1958,6 +1959,18 @@ function load_cookies(data) {
}
}
// Show the image sent by the server in the "action image" container,
// replacing whatever was displayed there before.
//   data['b64']    - base64 image data, displayed through a PNG data URL
//   data['prompt'] - prompt text, shown as the image's tooltip
function Action_Image(data) {
	var image = new Image();
	image.src = 'data:image/png;base64,'+data['b64'];
	image.setAttribute("title", data['prompt']);
	image.classList.add("action_image");
	// Declared locally: assigning without var would create an implicit global.
	var image_area = document.getElementById("action image");
	if (!image_area) {
		// Container not present in this page/template; nothing to update.
		return;
	}
	while (image_area.firstChild) {
		image_area.removeChild(image_area.firstChild);
	}
	image_area.appendChild(image);
}
//--------------------------------------------UI to Server Functions----------------------------------
function unload_userscripts() {
files_to_unload = document.getElementById('loaded_userscripts');
@@ -2917,7 +2930,8 @@ function assign_world_info_to_action(action_item, uid) {
//console.log(null);
var before_span = document.createElement("span");
before_span.textContent = before_highlight_text;
var hightlight_span = document.createElement("i");
var hightlight_span = document.createElement("span");
hightlight_span.classList.add("italics");
hightlight_span.textContent = highlight_text;
hightlight_span.title = worldinfo['content'];
var after_span = document.createElement("span");
@@ -2977,7 +2991,8 @@ function assign_world_info_to_action(action_item, uid) {
//console.log(null);
var before_span = document.createElement("span");
before_span.textContent = before_highlight_text;
var hightlight_span = document.createElement("i");
var hightlight_span = document.createElement("span");
hightlight_span.classList.add("italics");
hightlight_span.textContent = highlight_text;
hightlight_span.title = worldinfo['content'];
var after_span = document.createElement("span");

View File

@@ -105,6 +105,10 @@
</div>
<span id="debug-dump" class="cursor" onclick="document.getElementById('debug-file-container').classList.remove('hidden');">Download debug dump</span>
<div id="Images">
<button class="settings_button var_sync_alt_system_generating_image" onclick="socket.emit('generate_image', {})">Generate Image</button>
<div id="action image"></div>
</div>
</div>
<div id="setting_menu_settings" class="hidden settings_category_area tab-target tab-target-settings">
<div class="preset_area">

View File

@@ -190,6 +190,7 @@ class Send_to_socketio(object):
def _download_with_aria2(aria2_config: str, total_length: int, directory: str = ".", user_agent=None, force_download=False, use_auth_token=None):
import transformers
lengths = {}
path = None
s = requests.Session()
s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
bar = None
@@ -207,11 +208,10 @@ def _download_with_aria2(aria2_config: str, total_length: int, directory: str =
if bar is not None:
bar.n = bar.total
bar.close()
koboldai_vars.downloaded_chunks = bar.total
p.terminate()
done = True
break
if bar is None:
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
visited = set()
for x in r:
filename = x["files"][0]["path"]
@@ -220,7 +220,11 @@ def _download_with_aria2(aria2_config: str, total_length: int, directory: str =
for k, v in lengths.items():
if k not in visited:
lengths[k] = (v[1], v[1])
bar.n = sum(v[0] for v in lengths.values())
if bar is None:
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
koboldai_vars.total_download_chunks = sum(v[1] for v in lengths.values())
koboldai_vars.downloaded_chunks = sum(v[0] for v in lengths.values())
bar.n = koboldai_vars.downloaded_chunks
bar.update()
time.sleep(0.1)
path = f.name
@@ -229,8 +233,9 @@ def _download_with_aria2(aria2_config: str, total_length: int, directory: str =
raise e
finally:
try:
if os.path.exists(path):
os.remove(path)
if path is not None:
if os.path.exists(path):
os.remove(path)
except OSError:
pass
code = p.wait()