mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Fix for web based model loading
This commit is contained in:
16
aiserver.py
16
aiserver.py
@@ -53,6 +53,7 @@ from utils import debounce
|
||||
import utils
|
||||
import structures
|
||||
import torch
|
||||
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
if lupa.LUA_VERSION[:2] != (5, 4):
|
||||
@@ -321,6 +322,12 @@ class vars:
|
||||
|
||||
utils.vars = vars
|
||||
|
||||
class Send_to_socketio(object):
    """Write-only, file-like sink for progress-bar output.

    An instance is handed to ``tqdm`` through its ``file=`` argument while
    model tensors are loaded. Every chunk of text written to it is echoed to
    the console and also broadcast to connected web clients as a
    ``model_load_status`` message.
    """

    def write(self, bar):
        """Echo ``bar`` to stdout and broadcast it to the web UI.

        Args:
            bar: The progress text produced by the writer (tqdm renders the
                bar as a string and passes it here).
        """
        # Flush explicitly: the text ends without a newline and this object
        # exposes no ``flush`` for the writer to call, so buffered console
        # output could otherwise lag behind the actual progress.
        print(bar, end="", flush=True)
        # Short pause kept from the original implementation.
        # NOTE(review): presumably lets the server deliver pending messages
        # before the next emit -- confirm.
        time.sleep(0.01)
        try:
            emit('from_server', {'cmd': 'model_load_status', 'data': bar}, broadcast=True)
        except (RuntimeError, AttributeError):
            # Flask-SocketIO's context-aware ``emit`` needs an active
            # request / Socket.IO event context. When a model is loaded
            # outside of one (for example at start-up) the broadcast is not
            # possible; the console echo above has already happened, so the
            # web notification is treated as best-effort instead of
            # aborting the model load.
            pass
|
||||
|
||||
# Set logging level to reduce chatter from Flask
|
||||
import logging
|
||||
log = logging.getLogger('werkzeug')
|
||||
@@ -986,6 +993,10 @@ def load_model(use_gpu=True, key='', gpu_layers=None, initial_load=False):
|
||||
vars.noai = False
|
||||
if not initial_load:
|
||||
set_aibusy(True)
|
||||
if vars.model != 'ReadOnly':
|
||||
emit('from_server', {'cmd': 'model_load_status', 'data': "Loading {}".format(vars.model)}, broadcast=True)
|
||||
# A short sleep is needed here; without it the server does not send the emit above (cause unknown)
|
||||
time.sleep(0.1)
|
||||
if gpu_layers is not None:
|
||||
args.breakmodel_gpulayers = gpu_layers
|
||||
|
||||
@@ -1254,7 +1265,7 @@ def load_model(use_gpu=True, key='', gpu_layers=None, initial_load=False):
|
||||
else:
|
||||
num_tensors = len(device_map)
|
||||
print(flush=True)
|
||||
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
|
||||
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=Send_to_socketio())
|
||||
|
||||
with zipfile.ZipFile(f, "r") as z:
|
||||
try:
|
||||
@@ -1871,6 +1882,9 @@ def load_model(use_gpu=True, key='', gpu_layers=None, initial_load=False):
|
||||
final_startup()
|
||||
if not initial_load:
|
||||
set_aibusy(False)
|
||||
print("Sending model window close")
|
||||
emit('from_server', {'cmd': 'hide_model_name'}, broadcast=True)
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
# Set up Flask routes
|
||||
|
Reference in New Issue
Block a user