mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
TPUMTJ: Fix loading bar
I don't know why it works but I know it works
This commit is contained in:
@@ -712,7 +712,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
|
|||||||
if socketio is None:
|
if socketio is None:
|
||||||
bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint")
|
bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint")
|
||||||
else:
|
else:
|
||||||
bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint", file=utils.UIProgressBarFile())
|
bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint", file=utils.UIProgressBarFile(socketio.emit))
|
||||||
koboldai_vars.status_message = "Loading TPU"
|
koboldai_vars.status_message = "Loading TPU"
|
||||||
koboldai_vars.total_layers = tqdm_length
|
koboldai_vars.total_layers = tqdm_length
|
||||||
koboldai_vars.loaded_layers = 0
|
koboldai_vars.loaded_layers = 0
|
||||||
@@ -1021,10 +1021,11 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
|
|||||||
num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
|
num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
|
||||||
else:
|
else:
|
||||||
num_tensors = len(model_dict)
|
num_tensors = len(model_dict)
|
||||||
|
|
||||||
if socketio is None:
|
if socketio is None:
|
||||||
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
|
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
|
||||||
else:
|
else:
|
||||||
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=utils.UIProgressBarFile())
|
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=utils.UIProgressBarFile(socketio.emit))
|
||||||
koboldai_vars.status_message = "Loading model"
|
koboldai_vars.status_message = "Loading model"
|
||||||
koboldai_vars.loaded_layers = 0
|
koboldai_vars.loaded_layers = 0
|
||||||
koboldai_vars.total_layers = num_tensors
|
koboldai_vars.total_layers = num_tensors
|
||||||
|
7
utils.py
7
utils.py
@@ -630,6 +630,8 @@ def get_missing_module_names(model: PreTrainedModel, names: List[str]) -> List[s
|
|||||||
|
|
||||||
class UIProgressBarFile(object):
|
class UIProgressBarFile(object):
|
||||||
"""Write TQDM progress to the UI."""
|
"""Write TQDM progress to the UI."""
|
||||||
|
def __init__(self, emit_func=emit) -> None:
|
||||||
|
self.emit_func = emit_func
|
||||||
|
|
||||||
def write(self, bar):
|
def write(self, bar):
|
||||||
bar = bar.replace("\r", "").replace("\n", "").replace(chr(0), "")
|
bar = bar.replace("\r", "").replace("\n", "").replace(chr(0), "")
|
||||||
@@ -638,8 +640,9 @@ class UIProgressBarFile(object):
|
|||||||
print('\r' + bar, end='')
|
print('\r' + bar, end='')
|
||||||
time.sleep(0.01)
|
time.sleep(0.01)
|
||||||
try:
|
try:
|
||||||
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1")
|
self.emit_func('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1")
|
||||||
except:
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def flush(self):
|
def flush(self):
|
||||||
|
Reference in New Issue
Block a user