mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Add Original UI status for TPU connection/loading
This commit is contained in:
@@ -283,7 +283,7 @@ class Send_to_socketio(object):
|
||||
print('\r' + bar, end='')
|
||||
time.sleep(0.01)
|
||||
try:
|
||||
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1")
|
||||
socketio.emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1")
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -1799,7 +1799,7 @@ def patch_transformers_download():
|
||||
if bar != "" and [ord(num) for num in bar] != [27, 91, 65]: #No idea why we're getting the 27, 1, 65 character set, just killing to so we can move on
|
||||
try:
|
||||
print('\r' + bar, end='')
|
||||
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1")
|
||||
socketio.emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1")
|
||||
eventlet.sleep(seconds=0)
|
||||
except:
|
||||
pass
|
||||
@@ -3061,12 +3061,13 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
||||
elif(koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
|
||||
global tpu_mtj_backend
|
||||
import tpu_mtj_backend
|
||||
|
||||
tpu_mtj_backend.socketio = socketio
|
||||
if(koboldai_vars.model == "TPUMeshTransformerGPTNeoX"):
|
||||
koboldai_vars.badwordsids = koboldai_vars.badwordsids_neox
|
||||
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||
if koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (not koboldai_vars.custmodpth or not os.path.isdir(koboldai_vars.custmodpth)):
|
||||
raise FileNotFoundError(f"The specified model path {repr(koboldai_vars.custmodpth)} is not the path to a valid folder")
|
||||
import tpu_mtj_backend
|
||||
if(koboldai_vars.model == "TPUMeshTransformerGPTNeoX"):
|
||||
tpu_mtj_backend.pad_token_id = 2
|
||||
tpu_mtj_backend.koboldai_vars = koboldai_vars
|
||||
|
@@ -54,6 +54,8 @@ from mesh_transformer.util import to_bf16
|
||||
import time
|
||||
|
||||
|
||||
tqdm_print = None
|
||||
|
||||
params: Dict[str, Any] = {}
|
||||
|
||||
__seed = random.randrange(2**64)
|
||||
@@ -116,10 +118,29 @@ def show_spinner():
|
||||
bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')])
|
||||
i = 0
|
||||
while True:
|
||||
if i % 2 == 0:
|
||||
socketio.emit('from_server', {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, broadcast=True, room="UI_1")
|
||||
else:
|
||||
socketio.emit('from_server', {'cmd': 'model_load_status', 'data': "Connecting to TPU...." }, broadcast=True, room="UI_1")
|
||||
bar.update(i)
|
||||
time.sleep(0.1)
|
||||
i += 1
|
||||
|
||||
class Send_to_socketio(object):
|
||||
def write(self, bar):
|
||||
bar = bar.replace("\r", "").replace("\n", "").replace(chr(0), "")
|
||||
if bar != "" and [ord(num) for num in bar] != [27, 91, 65]: #No idea why we're getting the 27, 1, 65 character set, just killing to so we can move on
|
||||
#logger.info(bar)
|
||||
print('\r' + bar, end='')
|
||||
time.sleep(0.01)
|
||||
try:
|
||||
socketio.emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1")
|
||||
except:
|
||||
pass
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
|
||||
__F = TypeVar("__F", bound=Callable)
|
||||
__T = TypeVar("__T")
|
||||
@@ -991,7 +1012,10 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
|
||||
}
|
||||
|
||||
tqdm_length = len(static_mapping) + config["layers"]*len(layer_mapping)
|
||||
bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint")
|
||||
if socketio is None:
|
||||
bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint")
|
||||
else:
|
||||
bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint", file=Send_to_socketio)
|
||||
koboldai_vars.status_message = "Loading TPU"
|
||||
koboldai_vars.total_layers = tqdm_length
|
||||
koboldai_vars.loaded_layers = 0
|
||||
@@ -1269,7 +1293,10 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
|
||||
num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
|
||||
else:
|
||||
num_tensors = len(model_dict)
|
||||
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
|
||||
if socketio is None:
|
||||
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
|
||||
else:
|
||||
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=Send_to_socketio)
|
||||
koboldai_vars.status_message = "Loading model"
|
||||
koboldai_vars.loaded_layers = 0
|
||||
koboldai_vars.total_layers = num_tensors
|
||||
|
Reference in New Issue
Block a user