mirror of
				https://github.com/KoboldAI/KoboldAI-Client.git
				synced 2025-06-05 21:59:24 +02:00 
			
		
		
		
	Merge pull request #108 from VE-FORBRYDERNE/lazy-loader
Lazy loader Python 3.6 compatibility
This commit is contained in:
		@@ -95,7 +95,7 @@ class LazyTensor:
 | 
			
		||||
        nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
 | 
			
		||||
        if isinstance(checkpoint, zipfile.ZipFile):
 | 
			
		||||
            f = checkpoint.open(f"archive/data/{self.key}", "r")
 | 
			
		||||
            f.seek(self.seek_offset)
 | 
			
		||||
            f.read(self.seek_offset)
 | 
			
		||||
        else:
 | 
			
		||||
            f = checkpoint
 | 
			
		||||
        try:
 | 
			
		||||
 
 | 
			
		||||
@@ -887,6 +887,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
 | 
			
		||||
    output_shards = config["cores_per_replica"] // checkpoint_shards
 | 
			
		||||
 | 
			
		||||
    import torch
 | 
			
		||||
    import torch.utils.dlpack
 | 
			
		||||
    from tqdm.auto import tqdm
 | 
			
		||||
 | 
			
		||||
    move_xmap = jax.experimental.maps.xmap(
 | 
			
		||||
@@ -1154,12 +1155,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
 | 
			
		||||
    import torch_lazy_loader
 | 
			
		||||
    import torch
 | 
			
		||||
    from tqdm.auto import tqdm
 | 
			
		||||
    import functools
 | 
			
		||||
 | 
			
		||||
    def callback(model_dict, f, **_):
 | 
			
		||||
        with zipfile.ZipFile(f, "r") as z:
 | 
			
		||||
            try:
 | 
			
		||||
                last_storage_key = None
 | 
			
		||||
                f = None
 | 
			
		||||
                current_offset = 0
 | 
			
		||||
                print("\n\n\nThis model has  ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), "  parameters.\n")
 | 
			
		||||
                for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
 | 
			
		||||
 | 
			
		||||
@@ -1178,17 +1181,22 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
 | 
			
		||||
                        if isinstance(f, zipfile.ZipExtFile):
 | 
			
		||||
                            f.close()
 | 
			
		||||
                        f = z.open(f"archive/data/{storage_key}")
 | 
			
		||||
                    current_offset = f.tell()
 | 
			
		||||
                        current_offset = 0
 | 
			
		||||
                    if current_offset != model_dict[key].seek_offset:
 | 
			
		||||
                        f.seek(model_dict[key].seek_offset - current_offset, 1)
 | 
			
		||||
                        f.read(model_dict[key].seek_offset - current_offset)
 | 
			
		||||
                        current_offset = model_dict[key].seek_offset
 | 
			
		||||
                    spec = model_spec[key]
 | 
			
		||||
                    transforms = set(spec.get("transforms", ()))
 | 
			
		||||
                    if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor):
 | 
			
		||||
                        error = f"Duplicate key {repr(key)}"
 | 
			
		||||
                        print("\n\nERROR:  " + error, file=sys.stderr)
 | 
			
		||||
                        raise RuntimeError(error)
 | 
			
		||||
                    size = functools.reduce(lambda x, y: x * y, model_dict[key].shape, 1)
 | 
			
		||||
                    dtype = model_dict[key].dtype
 | 
			
		||||
                    nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
 | 
			
		||||
                    tensor = model_dict[key].materialize(f, map_location="cpu")
 | 
			
		||||
                    model_dict[key] = tensor.to("meta")
 | 
			
		||||
                    current_offset += nbytes
 | 
			
		||||
 | 
			
		||||
                    # MTJ requires certain mathematical operations to be performed
 | 
			
		||||
                    # on tensors in order for them to be in the correct format
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user