Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)

Commit: Merge branch 'UI2' of https://github.com/ebolam/KoboldAI into UI2

aiserver.py (162 changed lines)

@@ -64,7 +64,7 @@ from utils import debounce
 import utils
 import koboldai_settings
 import torch
-from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, modeling_utils
+from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, modeling_utils, AutoModelForTokenClassification
 from transformers import __version__ as transformers_version
 import transformers
 try:
@@ -73,6 +73,11 @@ except:
     pass
 import transformers.generation_utils

+# Text2img
+import base64
+from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps
+from io import BytesIO
+
 global tpu_mtj_backend


@@ -1618,8 +1623,8 @@ def get_cluster_models(msg):
     # If the client settings file doesn't exist, create it
     # Write API key to file
     os.makedirs('settings', exist_ok=True)
-    if path.exists(get_config_filename(koboldai_vars.model_selected)):
-        with open(get_config_filename(koboldai_vars.model_selected), "r") as file:
+    if path.exists(get_config_filename(model)):
+        with open(get_config_filename(model), "r") as file:
             js = json.load(file)
             if 'online_model' in js:
                 online_model = js['online_model']
@@ -1630,7 +1635,7 @@ def get_cluster_models(msg):
                 changed=True
     if changed:
         js={}
-        with open(get_config_filename(koboldai_vars.model_selected), "w") as file:
+        with open(get_config_filename(model), "w") as file:
             js["apikey"] = koboldai_vars.oaiapikey
             file.write(json.dumps(js, indent=3))
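
Note: both hunks above switch the settings-file lookup from koboldai_vars.model_selected to the function's model argument. A minimal sketch of the same read-modify-write round-trip, with a hypothetical get_config_filename standing in for the aiserver.py helper:

import json
import os
from os import path

def get_config_filename(model_name):
    # hypothetical stand-in; the real helper lives in aiserver.py
    return os.path.join("settings", model_name.replace("/", "_") + ".settings")

def save_api_key(model, apikey):
    os.makedirs('settings', exist_ok=True)
    js = {}
    # load the existing per-model settings file, if present
    if path.exists(get_config_filename(model)):
        with open(get_config_filename(model), "r") as file:
            js = json.load(file)
    # update just the key we own, then write everything back
    js["apikey"] = apikey
    with open(get_config_filename(model), "w") as file:
        file.write(json.dumps(js, indent=3))
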
@@ -1674,7 +1679,7 @@ def patch_transformers_download():

         if bar != "":
             try:
-                print(bar, end="\r")
+                print(bar, end="\n")
                 emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", "&nbsp;")}, broadcast=True, room="UI_1")
                 eventlet.sleep(seconds=0)
             except:
@@ -1712,10 +1717,12 @@ def patch_transformers_download():
                 desc=f"Downloading {file_name}" if file_name is not None else "Downloading",
                 file=Send_to_socketio(),
             )
+            koboldai_vars.total_download_chunks = total
             for chunk in r.iter_content(chunk_size=1024):
                 if chunk: # filter out keep-alive new chunks
                     if url[-11:] != 'config.json':
                         progress.update(len(chunk))
+                        koboldai_vars.downloaded_chunks += len(chunk)
                     temp_file.write(chunk)
             if url[-11:] != 'config.json':
                 progress.close()
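
Note: the two added lines mirror download progress into koboldai_vars so the web UI can draw its own progress bar alongside tqdm. The pattern in isolation (a sketch; progress_vars is a hypothetical object carrying the two counters):

import requests
from tqdm import tqdm

def download_file(url, dest, progress_vars):
    r = requests.get(url, stream=True)
    total = int(r.headers.get("content-length", 0))
    progress_vars.total_download_chunks = total  # UI-visible total, in bytes
    progress_vars.downloaded_chunks = 0
    progress = tqdm(total=total, unit="B", unit_scale=True, desc="Downloading")
    with open(dest, "wb") as temp_file:
        for chunk in r.iter_content(chunk_size=1024):
            if chunk:  # filter out keep-alive chunks
                progress.update(len(chunk))
                progress_vars.downloaded_chunks += len(chunk)  # UI-visible progress
                temp_file.write(chunk)
    progress.close()
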
@@ -1768,6 +1775,8 @@ def patch_transformers_download():

 def patch_transformers():
     global transformers
+    global old_transfomers_functions
+    old_transfomers_functions = {}

     patch_transformers_download()

@@ -1784,9 +1793,11 @@ def patch_transformers():
         if not args.no_aria2:
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
+    old_transfomers_functions['PreTrainedModel.from_pretrained'] = PreTrainedModel.from_pretrained
     PreTrainedModel.from_pretrained = new_from_pretrained
     if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
         old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
+        old_transfomers_functions['modeling_utils.get_checkpoint_shard_files'] = old_get_checkpoint_shard_files
         def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
             utils.num_shards = utils.get_num_shards(index_filename)
             utils.from_pretrained_index_filename = index_filename
@@ -1814,6 +1825,7 @@ def patch_transformers():
             if max_pos > self.weights.size(0):
                 self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
             return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
+        old_transfomers_functions['XGLMSinusoidalPositionalEmbedding.forward'] = XGLMSinusoidalPositionalEmbedding.forward
         XGLMSinusoidalPositionalEmbedding.forward = new_forward


@@ -1833,6 +1845,7 @@ def patch_transformers():
             self.model = OPTModel(config)
             self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
             self.post_init()
+        old_transfomers_functions['OPTForCausalLM.__init__'] = OPTForCausalLM.__init__
         OPTForCausalLM.__init__ = new_init


@@ -2117,6 +2130,7 @@ def patch_transformers():
                     break
             return self.regeneration_required or self.halt
     old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
+    old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria'] = old_get_stopping_criteria
     def new_get_stopping_criteria(self, *args, **kwargs):
         stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
         global tokenizer
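
Note: each patch in patch_transformers() now stashes the original callable in old_transfomers_functions under its dotted name, so a patch can be temporarily undone later (UI_2_generate_image below does exactly that around a stock summarization pipeline). The save/patch/restore pattern in isolation, with a hypothetical target class:

old_functions = {}

class Greeter:
    def greet(self):
        return "hello"

def patch(cls, name, replacement):
    # remember the original under a dotted key before overwriting it
    old_functions["{}.{}".format(cls.__name__, name)] = getattr(cls, name)
    setattr(cls, name, replacement)

def restore(cls, name):
    setattr(cls, name, old_functions["{}.{}".format(cls.__name__, name)])

patch(Greeter, "greet", lambda self: "patched")
assert Greeter().greet() == "patched"
restore(Greeter, "greet")
assert Greeter().greet() == "hello"
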
@@ -2171,7 +2185,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
     if not utils.HAS_ACCELERATE:
         disk_layers = None
     koboldai_vars.reset_model()
-    koboldai_vars.cluster_requested_models = online_model
+    koboldai_vars.cluster_requested_models = [online_model] if isinstance(online_model, str) else online_model
     koboldai_vars.noai = False
     if not use_breakmodel_args:
         set_aibusy(True)
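
Note: the changed line normalizes online_model so downstream code can always treat cluster_requested_models as a list, whether the caller passed one model name or several. The idiom in isolation:

def normalize_models(online_model):
    # wrap a bare string in a list; pass lists through unchanged
    return [online_model] if isinstance(online_model, str) else online_model

assert normalize_models("KoboldAI/fairseq-dense-13B") == ["KoboldAI/fairseq-dense-13B"]
assert normalize_models(["a", "b"]) == ["a", "b"]
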
@@ -8134,6 +8148,138 @@ def get_model_size(model_name):
 def UI_2_save_revision(data):
     koboldai_vars.save_revision()

+#==================================================================#
+# Generate Image
+#==================================================================#
+@socketio.on("generate_image")
+def UI_2_generate_image(data):
+    koboldai_vars.generating_image = True
+    #get latest action
+    if len(koboldai_vars.actions) > 0:
+        action = koboldai_vars.actions[-1]
+    else:
+        action = koboldai_vars.prompt
+    #Get matching world info entries
+    keys = []
+    for wi in koboldai_vars.worldinfo_v2:
+        for key in wi['key']:
+            if key in action:
+                #Check to make sure secondary keys are present if needed
+                if len(wi['keysecondary']) > 0:
+                    for keysecondary in wi['keysecondary']:
+                        if keysecondary in action:
+                            keys.append(key)
+                            break
+                    break
+                else:
+                    keys.append(key)
+                    break
+
+    #If we have 4 or more keys use those, otherwise fall back to summarization
+    if len(keys) < 4:
+        from transformers import pipeline as summary_pipeline
+        summarizer = summary_pipeline("summarization")
+        #text to summarize:
+        if len(koboldai_vars.actions) < 5:
+            text = "".join(koboldai_vars.actions[:-5]+[koboldai_vars.prompt])
+        else:
+            text = "".join(koboldai_vars.actions[:-5])
+        global old_transfomers_functions
+        temp = transformers.generation_utils.GenerationMixin._get_stopping_criteria
+        transformers.generation_utils.GenerationMixin._get_stopping_criteria = old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria']
+        keys = [summarizer(text, max_length=100, min_length=30, do_sample=False)[0]['summary_text']]
+        transformers.generation_utils.GenerationMixin._get_stopping_criteria = temp
+
+    art_guide = 'fantasy illustration, artstation, by jason felix by steve argyle by tyler jacobson by peter mohrbacher, cinematic lighting'
+
+    b64_data = text2img(", ".join(keys), art_guide = art_guide)
+    emit("Action_Image", {'b64': b64_data, 'prompt': ", ".join(keys)})
+
+
+@logger.catch
+def text2img(prompt,
+             art_guide = 'fantasy illustration, artstation, by jason felix by steve argyle by tyler jacobson by peter mohrbacher, cinematic lighting',
+             filename = "story_art.png"):
+    print("Generating Image")
+    koboldai_vars.generating_image = True
+    final_imgen_params = {
+        "n": 1,
+        "width": 512,
+        "height": 512,
+        "steps": 50,
+    }
+
+    final_submit_dict = {
+        "prompt": "{}, {}".format(prompt, art_guide),
+        "api_key": koboldai_vars.sh_apikey if koboldai_vars.sh_apikey != '' else "0000000000",
+        "params": final_imgen_params,
+    }
+    logger.debug(final_submit_dict)
+    submit_req = requests.post('https://stablehorde.net/api/v1/generate/sync', json = final_submit_dict)
+    if submit_req.ok:
+        results = submit_req.json()
+        for iter in range(len(results)):
+            b64img = results[iter]["img"]
+            base64_bytes = b64img.encode('utf-8')
+            img_bytes = base64.b64decode(base64_bytes)
+            img = Image.open(BytesIO(img_bytes))
+            if len(results) > 1:
+                final_filename = f"{iter}_{filename}"
+            else:
+                final_filename = filename
+            img.save(final_filename)
+            print("Saved Image")
+        koboldai_vars.generating_image = False
+        return(b64img)
+    else:
+        koboldai_vars.generating_image = False
+        print(submit_req.text)
+
+def get_items_locations_from_text(text):
+    # load model and tokenizer
+    tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
+    model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
+    nlp = transformers.pipeline("ner", model=model, tokenizer=tokenizer)
+    # run NER over the input text
+    ner_results = nlp(text)
+    orgs = []
+    last_org_position = -2
+    loc = []
+    last_loc_position = -2
+    per = []
+    last_per_position = -2
+    for i, result in enumerate(ner_results):
+        if result['entity'] in ('B-ORG', 'I-ORG'):
+            if result['start']-1 <= last_org_position:
+                if result['start'] != last_org_position:
+                    orgs[-1] = "{} ".format(orgs[-1])
+                orgs[-1] = "{}{}".format(orgs[-1], result['word'].replace("##", ""))
+            else:
+                orgs.append(result['word'])
+            last_org_position = result['end']
+        elif result['entity'] in ('B-LOC', 'I-LOC'):
+            if result['start']-1 <= last_loc_position:
+                if result['start'] != last_loc_position:
+                    loc[-1] = "{} ".format(loc[-1])
+                loc[-1] = "{}{}".format(loc[-1], result['word'].replace("##", ""))
+            else:
+                loc.append(result['word'])
+            last_loc_position = result['end']
+        elif result['entity'] in ('B-PER', 'I-PER'):
+            if result['start']-1 <= last_per_position:
+                if result['start'] != last_per_position:
+                    per[-1] = "{} ".format(per[-1])
+                per[-1] = "{}{}".format(per[-1], result['word'].replace("##", ""))
+            else:
+                per.append(result['word'])
+            last_per_position = result['end']
+
+    print("Orgs: {}".format(orgs))
+    print("Locations: {}".format(loc))
+    print("People: {}".format(per))
+
 #==================================================================#
 # Test
 #==================================================================#
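
Note: the sync endpoint used above returns a list of generations whose img field is a base64-encoded PNG; the endpoint and payload are exactly as in the hunk, though the Stable Horde API has been revised since, so treat the URL as historical. The decode-and-save step as a self-contained sketch:

import base64
from io import BytesIO
from PIL import Image

def save_b64_png(b64img, filename="story_art.png"):
    # base64 string -> raw bytes -> PIL image -> file on disk
    img = Image.open(BytesIO(base64.b64decode(b64img.encode('utf-8'))))
    img.save(filename)
    return img

Note: get_items_locations_from_text() reassembles BERT wordpieces into surface strings: a token whose start position is adjacent to the previous entity's end is glued on (stripping the "##" continuation marker), with a space inserted only when the characters were not contiguous; otherwise a new entity begins. The merge logic for a single label, factored out as a sketch:

def merge_wordpieces(ner_results, label):
    # glue adjacent B-/I- tokens of one entity label back together
    out, last_end = [], -2
    for tok in ner_results:
        if tok['entity'] not in ('B-' + label, 'I-' + label):
            continue
        if tok['start'] - 1 <= last_end:
            sep = "" if tok['start'] == last_end else " "
            out[-1] = "{}{}{}".format(out[-1], sep, tok['word'].replace("##", ""))
        else:
            out.append(tok['word'])
        last_end = tok['end']
    return out
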
@@ -10919,6 +11065,7 @@ if __name__ == "__main__":
            try:
                cloudflare = str(localtunnel.stdout.readline())
                cloudflare = (re.search("(?P<url>https?:\/\/[^\s]+loca.lt)", cloudflare).group("url"))
+               koboldai_vars.cloudflare_link = cloudflare
                break
            except:
                attempts += 1
@@ -10928,12 +11075,15 @@ if __name__ == "__main__":
            print("LocalTunnel could not be created, falling back to cloudflare...")
            from flask_cloudflared import _run_cloudflared
            cloudflare = _run_cloudflared(port)
+           koboldai_vars.cloudflare_link = cloudflare
    elif(args.ngrok):
        from flask_ngrok import _run_ngrok
        cloudflare = _run_ngrok()
+       koboldai_vars.cloudflare_link = cloudflare
    elif(args.remote):
        from flask_cloudflared import _run_cloudflared
        cloudflare = _run_cloudflared(port)
+       koboldai_vars.cloudflare_link = cloudflare
    if(args.localtunnel or args.ngrok or args.remote):
        with open('cloudflare.log', 'w') as cloudflarelog:
            cloudflarelog.write("KoboldAI has finished loading and is available at the following link : " + cloudflare)

@@ -20,6 +20,7 @@ dependencies:
   - marshmallow>=3.13
   - apispec-webframeworks
   - loguru
+  - Pillow
   - pip:
     - flask-cloudflared
     - flask-ngrok

@@ -17,6 +17,7 @@ dependencies:
   - marshmallow>=3.13
   - apispec-webframeworks
   - loguru
+  - Pillow
   - pip:
     - --find-links https://download.pytorch.org/whl/rocm4.2/torch_stable.html
     - torch==1.10.*

@@ -111,7 +111,7 @@ class koboldai_vars(object):
     def reset_model(self):
         self._model_settings.reset_for_model_load()

-    def calc_ai_text(self, submitted_text="", method=2):
+    def calc_ai_text(self, submitted_text="", method=2, return_text=False):
         context = []
         token_budget = self.max_length
         used_world_info = []
@@ -285,6 +285,8 @@ class koboldai_vars(object):
             tokens = self.tokenizer.encode(text)

         self.context = context
+        if return_text:
+            return text
         return tokens, used_tokens, used_tokens+self.genamt

     def __setattr__(self, name, value):
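
Note: return_text=True hands callers the assembled context as one string instead of the usual (tokens, used_tokens, used_tokens+genamt) tuple, which is handy for inspecting the exact prompt. Hypothetical call sites:

# token path: what generation consumes
tokens, used_tokens, used_plus_gen = koboldai_vars.calc_ai_text(submitted_text="You enter the cave.")

# text path: inspect or display the exact prompt that would be sent
prompt_text = koboldai_vars.calc_ai_text(submitted_text="You enter the cave.", return_text=True)
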
@@ -493,13 +495,16 @@ class model_settings(settings):
             if self.tqdm.format_dict['rate'] is not None:
                 self.tqdm_rem_time = str(datetime.timedelta(seconds=int(float(self.total_layers-self.loaded_layers)/self.tqdm.format_dict['rate'])))
         #Setup TQDM for model downloading
+        elif name == "total_download_chunks" and 'tqdm' in self.__dict__:
+            self.tqdm.reset(total=value)
+            self.tqdm_progress = 0
         elif name == "downloaded_chunks" and 'tqdm' in self.__dict__:
             if value == 0:
                 self.tqdm.reset(total=self.total_download_chunks)
                 self.tqdm_progress = 0
             else:
                 self.tqdm.update(value-old_value)
-                self.tqdm_progress = round(float(self.downloaded_chunks)/float(self.total_download_chunks)*100, 1)
+                self.tqdm_progress = 0 if self.total_download_chunks==0 else round(float(self.downloaded_chunks)/float(self.total_download_chunks)*100, 1)
             if self.tqdm.format_dict['rate'] is not None:
                 self.tqdm_rem_time = str(datetime.timedelta(seconds=int(float(self.total_download_chunks-self.downloaded_chunks)/self.tqdm.format_dict['rate'])))
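
Note: the guarded percentage avoids a ZeroDivisionError when downloaded_chunks changes before total_download_chunks is known. The guard plus tqdm's rate-based remaining-time estimate, in isolation (a sketch):

import datetime
from tqdm import tqdm

def progress_stats(bar: tqdm, downloaded: int, total: int):
    # percent complete, tolerating an unknown (zero) total
    percent = 0 if total == 0 else round(float(downloaded) / float(total) * 100, 1)
    remaining = None
    rate = bar.format_dict['rate']  # units/sec; None until tqdm has samples
    if rate is not None:
        remaining = str(datetime.timedelta(seconds=int((total - downloaded) / rate)))
    return percent, remaining
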
@@ -738,7 +743,6 @@ class system_settings(settings):
         self.userscripts = [] # List of userscripts to load
         self.last_userscripts = [] # List of previous userscript filenames from the previous time userscripts were sent via usstatitems
         self.corescript = "default.lua" # Filename of corescript to load
-
         self.gpu_device = 0 # Which PyTorch device to use when using pure GPU generation
         self.savedir = os.getcwd()+"\\stories"
         self.hascuda = False # Whether torch has detected CUDA on the system
@@ -794,6 +798,8 @@ class system_settings(settings):
             print("Colab Check: {}".format(self.on_colab))
         self.horde_share = False
         self._horde_pid = None
+        self.sh_apikey = "" # API key to use for txt2img from the Stable Horde.
+        self.generating_image = False # The current status of image generation
         self.cookies = {} # cookies for colab since colab's URL changes, cookies are lost


@@ -877,6 +883,8 @@ class KoboldStoryRegister(object):
             temp = [self.actions[x]["Selected Text"] for x in list(self.actions)[i]]
             return temp
         else:
+            if i < 0:
+                return self.actions[self.action_count+i+1]["Selected Text"]
             return self.actions[i]["Selected Text"]

     def __setitem__(self, i, text):
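
Note: the added branch lets story[-1] fetch the most recent action even though actions live in a dict keyed by action number rather than a list. The index translation in isolation, assuming action_count holds the highest key in use:

def translate_index(i, action_count):
    # map Python-style negative indices onto dict keys 0..action_count
    return action_count + i + 1 if i < 0 else i

assert translate_index(-1, action_count=9) == 9  # most recent action
assert translate_index(-2, action_count=9) == 8
assert translate_index(3, action_count=9) == 3
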
@@ -15,4 +15,5 @@ accelerate
 flask_session
 marshmallow>=3.13
 apispec-webframeworks
 loguru
+Pillow

@@ -19,4 +19,5 @@ bleach==4.1.0
 flask-session
 marshmallow>=3.13
 apispec-webframeworks
 loguru
+Pillow

@@ -2201,6 +2201,10 @@ button.disabled {
   color: red;
 }

+.italics {
+  font-style: italic;
+}
+
 .within_max_length {
   color: var(--text_to_ai_color);
   font-weight: bold;
@@ -2370,4 +2374,14 @@ h2 .material-icons-outlined {

 input[type='range'] {
   border: none !important;
 }
+
+.settings_button[system_generating_image="true"] {
+  filter: brightness(35%);
+  cursor: not-allowed;
+  pointer-events: none;
+}
+
+.action_image {
+  width: var(--flyout_menu_width);
+}
@@ -29,6 +29,7 @@ socket.on('load_cookies', function(data){load_cookies(data)});
 socket.on('load_tweaks', function(data){load_tweaks(data);});
 socket.on("wi_results", updateWISearchListings);
 socket.on("request_prompt_config", configurePrompt);
+socket.on("Action_Image", function(data){Action_Image(data);});
 //socket.onAny(function(event_name, data) {console.log({"event": event_name, "class": data.classname, "data": data});});

 var presets = {};
@@ -1429,8 +1430,8 @@ function load_model() {
 	for (item of document.getElementById("oaimodel").selectedOptions) {
 		selected_models.push(item.value);
 	}
-	if (selected_models == []) {
-		selected_models = "";
+	if (selected_models == ['']) {
+		selected_models = [];
 	} else if (selected_models.length == 1) {
 		selected_models = selected_models[0];
 	}
@@ -1958,6 +1959,18 @@ function load_cookies(data) {
 	}
 }

+function Action_Image(data) {
+	var image = new Image();
+	image.src = 'data:image/png;base64,'+data['b64'];
+	image.setAttribute("title", data['prompt']);
+	image.classList.add("action_image");
+	image_area = document.getElementById("action image");
+	while (image_area.firstChild) {
+		image_area.removeChild(image_area.firstChild);
+	}
+	image_area.appendChild(image);
+}
+
 //--------------------------------------------UI to Server Functions----------------------------------
 function unload_userscripts() {
 	files_to_unload = document.getElementById('loaded_userscripts');
@@ -2917,7 +2930,8 @@ function assign_world_info_to_action(action_item, uid) {
 				//console.log(null);
 				var before_span = document.createElement("span");
 				before_span.textContent = before_highlight_text;
-				var hightlight_span = document.createElement("i");
+				var hightlight_span = document.createElement("span");
+				hightlight_span.classList.add("italics");
 				hightlight_span.textContent = highlight_text;
 				hightlight_span.title = worldinfo['content'];
 				var after_span = document.createElement("span");
@@ -2977,7 +2991,8 @@ function assign_world_info_to_action(action_item, uid) {
 				//console.log(null);
 				var before_span = document.createElement("span");
 				before_span.textContent = before_highlight_text;
-				var hightlight_span = document.createElement("i");
+				var hightlight_span = document.createElement("span");
+				hightlight_span.classList.add("italics");
 				hightlight_span.textContent = highlight_text;
 				hightlight_span.title = worldinfo['content'];
 				var after_span = document.createElement("span");

@@ -105,6 +105,10 @@

     </div>
     <span id="debug-dump" class="cursor" onclick="document.getElementById('debug-file-container').classList.remove('hidden');">Download debug dump</span>
+    <div id="Images">
+        <button class="settings_button var_sync_alt_system_generating_image" onclick="socket.emit('generate_image', {})">Generate Image</button>
+        <div id="action image"></div>
+    </div>
 </div>
 <div id="setting_menu_settings" class="hidden settings_category_area tab-target tab-target-settings">
     <div class="preset_area">

utils.py (15 changed lines)

@@ -190,6 +190,7 @@ class Send_to_socketio(object):
 def _download_with_aria2(aria2_config: str, total_length: int, directory: str = ".", user_agent=None, force_download=False, use_auth_token=None):
     import transformers
     lengths = {}
+    path = None
     s = requests.Session()
     s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
     bar = None
@@ -207,11 +208,10 @@ def _download_with_aria2(aria2_config: str, total_length: int, directory: str =
             if bar is not None:
                 bar.n = bar.total
                 bar.close()
+                koboldai_vars.downloaded_chunks = bar.total
             p.terminate()
             done = True
             break
-        if bar is None:
-            bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
         visited = set()
         for x in r:
             filename = x["files"][0]["path"]
@@ -220,7 +220,11 @@ def _download_with_aria2(aria2_config: str, total_length: int, directory: str =
         for k, v in lengths.items():
             if k not in visited:
                 lengths[k] = (v[1], v[1])
-        bar.n = sum(v[0] for v in lengths.values())
+        if bar is None:
+            bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
+        koboldai_vars.total_download_chunks = sum(v[1] for v in lengths.values())
+        koboldai_vars.downloaded_chunks = sum(v[0] for v in lengths.values())
+        bar.n = koboldai_vars.downloaded_chunks
         bar.update()
         time.sleep(0.1)
     path = f.name
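
Note: aria2 drops finished files from its status reports, so any file not visited in the current poll is pinned to (total, total) before summing; the bar position and the UI counters are then derived from the same lengths map. The accounting step in isolation:

def poll_totals(lengths, visited):
    # files absent from this poll have finished: downloaded == total
    for k, v in lengths.items():
        if k not in visited:
            lengths[k] = (v[1], v[1])
    downloaded = sum(v[0] for v in lengths.values())
    total = sum(v[1] for v in lengths.values())
    return downloaded, total

lengths = {"model-00001.bin": (512, 1024), "model-00002.bin": (700, 700)}
print(poll_totals(lengths, visited={"model-00001.bin"}))  # -> (1212, 1724)
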
@@ -229,8 +233,9 @@ def _download_with_aria2(aria2_config: str, total_length: int, directory: str =
         raise e
     finally:
         try:
-            if os.path.exists(path):
-                os.remove(path)
+            if path is not None:
+                if os.path.exists(path):
+                    os.remove(path)
         except OSError:
             pass
     code = p.wait()