Fix for web based model loading
This commit is contained in:
parent 1e139594a9
commit c984f4412d

aiserver.py (16 changes)
@@ -53,6 +53,7 @@ from utils import debounce
 import utils
 import structures
 import torch
+from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
 
 
 if lupa.LUA_VERSION[:2] != (5, 4):
@@ -321,6 +322,12 @@ class vars:
 
 utils.vars = vars
 
+class Send_to_socketio(object):
+    def write(self, bar):
+        print(bar, end="")
+        time.sleep(0.01)
+        emit('from_server', {'cmd': 'model_load_status', 'data': bar}, broadcast=True)
+
 # Set logging level to reduce chatter from Flask
 import logging
 log = logging.getLogger('werkzeug')
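Send_to_socketio works because tqdm only needs a duck-typed stream: any object with a write() method can be passed as file=, and tqdm substitutes a no-op flush() when none is defined. A minimal standalone sketch of that contract (CapturingWriter and the surrounding names are illustrative, not part of this commit):

# Standalone sketch: capture tqdm's progress text with a write()-only object,
# the same contract Send_to_socketio relies on. CapturingWriter is a made-up
# stand-in that collects chunks instead of emitting them over socket.io.
import time
from tqdm import tqdm

class CapturingWriter:
    def __init__(self):
        self.chunks = []

    def write(self, text):
        # tqdm pushes each bar redraw (carriage-return prefixed) through here.
        self.chunks.append(text)

writer = CapturingWriter()
for _ in tqdm(range(5), desc="Loading model tensors", file=writer):
    time.sleep(0.05)

print(repr("".join(writer.chunks)))  # the raw text tqdm produced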
@@ -986,6 +993,10 @@ def load_model(use_gpu=True, key='', gpu_layers=None, initial_load=False):
     vars.noai = False
     if not initial_load:
         set_aibusy(True)
+        if vars.model != 'ReadOnly':
+            emit('from_server', {'cmd': 'model_load_status', 'data': "Loading {}".format(vars.model)}, broadcast=True)
+            #Have to add a sleep so the server will send the emit for some reason
+            time.sleep(0.1)
     if gpu_layers is not None:
         args.breakmodel_gpulayers = gpu_layers
 
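The "#Have to add a sleep" comment matches documented Flask-SocketIO behaviour: under an async framework such as eventlet, a queued emit() only goes out once the running task yields, and a monkey-patched time.sleep() is one way to yield. A hedged sketch of the pattern in isolation (the app, event name, and placeholder work below are assumptions, not code from this commit):

# Sketch: flush a broadcast status message before starting blocking work.
import time
from flask import Flask
from flask_socketio import SocketIO, emit

app = Flask(__name__)
socketio = SocketIO(app)

@socketio.on('load_model')
def handle_load_model(msg):
    # Queue a status line for every connected client.
    emit('from_server', {'cmd': 'model_load_status', 'data': 'Loading model'},
         broadcast=True)
    # Yield briefly so the emit above is actually delivered before the
    # blocking work begins; socketio.sleep() is the portable way to yield.
    socketio.sleep(0.1)
    time.sleep(2)  # placeholder for the real model load

if __name__ == '__main__':
    socketio.run(app)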
@@ -1254,7 +1265,7 @@ def load_model(use_gpu=True, key='', gpu_layers=None, initial_load=False):
             else:
                 num_tensors = len(device_map)
             print(flush=True)
-            utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
+            utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=Send_to_socketio())
 
             with zipfile.ZipFile(f, "r") as z:
                 try:
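With file=Send_to_socketio(), every bar redraw is broadcast as a model_load_status message, so any socket.io client can render loading progress; the JavaScript hunk at the end of the diff does exactly that in the browser. For a quick check from Python instead, a python-socketio client can subscribe to the same event (the URL and payload shape are assumptions based on this diff):

# Sketch: a python-socketio client that prints the streamed progress text.
# Assumes the server is reachable at localhost:5000 and emits 'from_server'
# messages shaped like the ones added in this commit.
import socketio

sio = socketio.Client()

@sio.on('from_server')
def on_from_server(msg):
    if msg.get('cmd') == 'model_load_status':
        print(msg['data'], end='', flush=True)

sio.connect('http://localhost:5000')
sio.wait()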
@@ -1871,6 +1882,9 @@ def load_model(use_gpu=True, key='', gpu_layers=None, initial_load=False):
     final_startup()
     if not initial_load:
         set_aibusy(False)
+        print("Sending model window close")
+        emit('from_server', {'cmd': 'hide_model_name'}, broadcast=True)
+        time.sleep(0.1)
 
 
 # Set up Flask routes
@@ -2541,10 +2541,11 @@ $(document).ready(function(){
 		$("#showmodelnamecontainer").removeClass("hidden");
 	} else if(msg.cmd == 'hide_model_name') {
 		$("#showmodelnamecontainer").addClass("hidden");
+		//console.log("Closing window");
 	} else if(msg.cmd == 'model_load_status') {
 		$("#showmodelnamecontent").html("<div class=\"flex\"><div class=\"loadlistpadding\"></div><div class=\"loadlistitem\" style='align: left'>" + msg.data + "</div></div>");
 		$("#showmodelnamecontainer").removeClass("hidden");
-		console.log(msg.data);
+		//console.log(msg.data);
 	} else if(msg.cmd == 'oai_engines') {
 		RemoveAllButFirstOption($("#oaimodel")[0]);
 		for (const engine of msg.data) {