From 1bb2d2621cec85f924a123c1fb179a79127aed6f Mon Sep 17 00:00:00 2001 From: somebody Date: Mon, 3 Jul 2023 17:12:07 -0500 Subject: [PATCH] Make TPU in line with new lazyload behavior --- tpu_mtj_backend.py | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index ec69f66d..a5fd9d69 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -1196,6 +1196,7 @@ def load_model(path: str, model_type: str, badwordsids=koboldai_settings.badword if utils.num_shards is not None: utils.current_shard += 1 + for key in sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)): model_spec_key = max((k for k in model_spec.keys() if key.endswith(k)), key=len, default=None) @@ -1210,31 +1211,16 @@ def load_model(path: str, model_type: str, badwordsids=koboldai_settings.badword koboldai_vars.loaded_layers += 1 continue - storage_key = model_dict[key].key - if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset: - last_storage_key = storage_key - if isinstance(f, zipfile.ZipExtFile): - f.close() - try: - f = z.open(f"archive/data/{storage_key}") - except: - f = z.open(f"{zipfolder}/data/{storage_key}") - current_offset = 0 - if current_offset != model_dict[key].seek_offset: - f.read(model_dict[key].seek_offset - current_offset) - current_offset = model_dict[key].seek_offset spec = model_spec[model_spec_key] transforms = set(spec.get("transforms", ())) + if not isinstance(model_dict[key], lazy_loader.LazyTensor): error = f"Duplicate key {repr(key)}" print("\n\nERROR: " + error, file=sys.stderr) raise RuntimeError(error) - size = functools.reduce(lambda x, y: x * y, model_dict[key].shape, 1) - dtype = model_dict[key].dtype - nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3) - tensor = model_dict[key].materialize(f, map_location="cpu") + + tensor = model_dict[key].materialize(map_location="cpu") model_dict[key] = tensor.to("meta") - current_offset += nbytes # MTJ requires certain mathematical operations to be performed # on tensors in order for them to be in the correct format