Add Original UI status for TPU connection/loading

This commit is contained in:
ebolam
2022-10-24 18:52:53 -04:00
parent 4dd7ba383d
commit 47c832fde4
2 changed files with 33 additions and 5 deletions

View File

@@ -283,7 +283,7 @@ class Send_to_socketio(object):
print('\r' + bar, end='')
time.sleep(0.01)
try:
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1")
socketio.emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1")
except:
pass
@@ -1799,7 +1799,7 @@ def patch_transformers_download():
if bar != "" and [ord(num) for num in bar] != [27, 91, 65]: #No idea why we're getting the 27, 1, 65 character set, just killing to so we can move on
try:
print('\r' + bar, end='')
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1")
socketio.emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1")
eventlet.sleep(seconds=0)
except:
pass
@@ -3061,12 +3061,13 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
elif(koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
global tpu_mtj_backend
import tpu_mtj_backend
tpu_mtj_backend.socketio = socketio
if(koboldai_vars.model == "TPUMeshTransformerGPTNeoX"):
koboldai_vars.badwordsids = koboldai_vars.badwordsids_neox
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
if koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (not koboldai_vars.custmodpth or not os.path.isdir(koboldai_vars.custmodpth)):
raise FileNotFoundError(f"The specified model path {repr(koboldai_vars.custmodpth)} is not the path to a valid folder")
import tpu_mtj_backend
if(koboldai_vars.model == "TPUMeshTransformerGPTNeoX"):
tpu_mtj_backend.pad_token_id = 2
tpu_mtj_backend.koboldai_vars = koboldai_vars

View File

@@ -54,6 +54,8 @@ from mesh_transformer.util import to_bf16
import time
# Module-level slot that starts out as None; nothing in this view assigns it
# (presumably filled in later by other code — TODO confirm).
tqdm_print = None
# String-keyed parameter dictionary; empty when the module is imported.
params: Dict[str, Any] = {}
# Random integer in [0, 2**64), drawn once at import time.
__seed = random.randrange(2**64)
@@ -116,10 +118,29 @@ def show_spinner():
# Console progress bar of unknown length: an elapsed-time widget followed by a bouncing bar.
bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='')])
i = 0
# No exit condition is visible in this loop; presumably it runs in a worker that
# is stopped from outside — TODO confirm against the caller.
while True:
# Alternate between a three-dot and a four-dot message on successive ticks so the
# status text sent to the "UI_1" room visibly changes.
# NOTE(review): socketio is dereferenced without a guard here, whereas the tqdm
# call sites in this file test `socketio is None` first — confirm it is always
# set before the spinner starts.
if i % 2 == 0:
socketio.emit('from_server', {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, broadcast=True, room="UI_1")
else:
socketio.emit('from_server', {'cmd': 'model_load_status', 'data': "Connecting to TPU...." }, broadcast=True, room="UI_1")
bar.update(i)
# One tick roughly every 0.1 seconds.
time.sleep(0.1)
i += 1
class Send_to_socketio(object):
    """File-like sink that mirrors progress-bar text to the console and the UI.

    Intended to be handed to a progress bar as its ``file=`` target: it offers
    the ``write``/``flush`` pair such a target needs.  Pass an *instance*
    (``Send_to_socketio()``); ``write`` is an instance method, so calling it on
    the class itself with a single string fails for lack of the ``bar``
    argument.
    """

    def write(self, bar):
        """Print one progress line and broadcast it as a ``model_load_status``.

        Carriage returns, newlines and NUL characters are stripped first.
        Nothing is printed or sent when the remaining text is empty or is
        exactly ESC ``[`` ``A`` (code points 27, 91, 65 -- the ANSI "cursor
        up" control sequence).

        Args:
            bar: Raw text chunk written by the progress bar.
        """
        bar = bar.replace("\r", "").replace("\n", "").replace(chr(0), "")
        # Skip empty chunks and the bare cursor-up sequence: neither carries
        # any text worth showing.
        if bar in ("", "\x1b[A"):
            return
        print('\r' + bar, end='')
        time.sleep(0.01)
        try:
            # NOTE(review): the replacement target may be a non-breaking space
            # that merely renders like a plain one -- verify the literal
            # before touching it.
            socketio.emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1")
        except Exception:
            # Best effort: a failed UI update must never abort model loading.
            # Only ``Exception`` is caught so that KeyboardInterrupt and
            # SystemExit still propagate.
            pass

    def flush(self):
        """Do nothing; every ``write`` is delivered immediately."""
        pass
__F = TypeVar("__F", bound=Callable)
__T = TypeVar("__T")
@@ -991,7 +1012,10 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
}
tqdm_length = len(static_mapping) + config["layers"]*len(layer_mapping)
bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint")
if socketio is None:
bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint")
else:
bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint", file=Send_to_socketio)
koboldai_vars.status_message = "Loading TPU"
koboldai_vars.total_layers = tqdm_length
koboldai_vars.loaded_layers = 0
@@ -1269,7 +1293,10 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
else:
num_tensors = len(model_dict)
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
if socketio is None:
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
else:
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=Send_to_socketio)
koboldai_vars.status_message = "Loading model"
koboldai_vars.loaded_layers = 0
koboldai_vars.total_layers = num_tensors