Merge pull request #112 from VE-FORBRYDERNE/lazy-loader

Fix lazy loader
This commit is contained in:
henk717 2022-04-09 03:16:30 +02:00 committed by GitHub
commit a060219ff7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 3 deletions

View File

@ -30,6 +30,7 @@ import markdown
import bleach
import itertools
import bisect
import functools
from collections.abc import Iterable
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
@ -1092,6 +1093,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
try:
last_storage_key = None
f = None
current_offset = 0
for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
storage_key = model_dict[key].key
if storage_key != last_storage_key:
@ -1099,10 +1101,14 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
if isinstance(f, zipfile.ZipExtFile):
f.close()
f = z.open(f"archive/data/{storage_key}")
current_offset = f.tell()
current_offset = 0
if current_offset != model_dict[key].seek_offset:
f.seek(model_dict[key].seek_offset - current_offset, 1)
f.read(model_dict[key].seek_offset - current_offset)
current_offset = model_dict[key].seek_offset
device = device_map[key]
size = functools.reduce(lambda x, y: x * y, model_dict[key].shape, 1)
dtype = model_dict[key].dtype
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
#print(f"Transferring <{key}> to {'(CPU)' if device == 'cpu' else '[device ' + str(device) + ']'} ... ", end="", flush=True)
model_dict[key] = model_dict[key].materialize(f, map_location="cpu")
if convert_to_float16 and vars.hascuda and (vars.breakmodel or vars.usegpu) and model_dict[key].dtype is torch.float32:
@ -1111,6 +1117,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
model_dict[key] = model_dict[key].to(torch.float32)
model_dict[key] = model_dict[key].to(device)
#print("OK", flush=True)
current_offset += nbytes
finally:
if isinstance(f, zipfile.ZipExtFile):
f.close()

View File

@ -1176,7 +1176,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
continue
storage_key = model_dict[key].key
if storage_key != last_storage_key:
if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
last_storage_key = storage_key
if isinstance(f, zipfile.ZipExtFile):
f.close()