Merge pull request #112 from VE-FORBRYDERNE/lazy-loader

Fix lazy loader
commit a060219ff7
henk717 2022-04-09 03:16:30 +02:00 committed by GitHub
2 changed files with 10 additions and 3 deletions


@@ -30,6 +30,7 @@ import markdown
 import bleach
 import itertools
 import bisect
+import functools
 from collections.abc import Iterable
 from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
@@ -1092,6 +1093,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     try:
         last_storage_key = None
         f = None
+        current_offset = 0
         for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
             storage_key = model_dict[key].key
             if storage_key != last_storage_key:
@@ -1099,10 +1101,14 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                 if isinstance(f, zipfile.ZipExtFile):
                     f.close()
                 f = z.open(f"archive/data/{storage_key}")
-                current_offset = f.tell()
+                current_offset = 0
             if current_offset != model_dict[key].seek_offset:
-                f.seek(model_dict[key].seek_offset - current_offset, 1)
+                f.read(model_dict[key].seek_offset - current_offset)
+                current_offset = model_dict[key].seek_offset
             device = device_map[key]
+            size = functools.reduce(lambda x, y: x * y, model_dict[key].shape, 1)
+            dtype = model_dict[key].dtype
+            nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
             #print(f"Transferring <{key}> to {'(CPU)' if device == 'cpu' else '[device ' + str(device) + ']'} ... ", end="", flush=True)
             model_dict[key] = model_dict[key].materialize(f, map_location="cpu")
             if convert_to_float16 and vars.hascuda and (vars.breakmodel or vars.usegpu) and model_dict[key].dtype is torch.float32:
@@ -1111,6 +1117,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                 model_dict[key] = model_dict[key].to(torch.float32)
             model_dict[key] = model_dict[key].to(device)
             #print("OK", flush=True)
+            current_offset += nbytes
     finally:
         if isinstance(f, zipfile.ZipExtFile):
             f.close()
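
What the patch does: rather than trusting f.tell() and a relative f.seek() on a zipfile.ZipExtFile, the loader now tracks its position in the stream itself. current_offset is reset to 0 whenever a storage member is opened, forward skips are performed by reading and discarding bytes, and after each tensor is materialized the offset advances by the tensor's serialized size, computed from its shape and dtype. A minimal sketch of those two pieces, assuming PyTorch dtypes (the helper names are mine, not the repo's):

    import functools
    import torch

    def tensor_nbytes(shape, dtype) -> int:
        # Element count times bytes per element; torch.bool is stored as one byte.
        numel = functools.reduce(lambda x, y: x * y, shape, 1)
        if dtype is torch.bool:
            return numel
        info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
        return numel * (info.bits >> 3)

    def skip_forward(f, current_offset: int, target_offset: int) -> int:
        # Forward-only positioning: read and discard the gap instead of
        # seeking, which also works on a streaming zip member.
        if current_offset != target_offset:
            f.read(target_offset - current_offset)
        return target_offset

After each materialize the loop adds nbytes to current_offset, so the next iteration knows exactly where the stream pointer sits without ever calling tell(). The second diff below applies the same bookkeeping to the TPU backend's load_model.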


@@ -1176,7 +1176,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         continue
     storage_key = model_dict[key].key
-    if storage_key != last_storage_key:
+    if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
         last_storage_key = storage_key
         if isinstance(f, zipfile.ZipExtFile):
             f.close()
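
Because the stream only ever moves forward, a tensor whose seek_offset lies behind current_offset can no longer be reached through the open handle; the added "or model_dict[key].seek_offset < current_offset" clause therefore forces the storage member to be reopened from the start. A self-contained demo of this forward-only behaviour (the archive layout is invented for illustration):

    import io
    import zipfile

    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w") as z:
        z.writestr("archive/data/0", bytes(range(16)))

    with zipfile.ZipFile(buf, "r") as z:
        f = z.open("archive/data/0")
        f.read(4)                      # skip forward by reading
        assert f.read(1) == b"\x04"
        f.close()
        f = z.open("archive/data/0")   # to go backwards, reopen
        assert f.read(1) == b"\x00"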