Merge pull request #112 from VE-FORBRYDERNE/lazy-loader
Fix lazy loader
This commit is contained in:
commit
a060219ff7
11
aiserver.py
11
aiserver.py
|
@ -30,6 +30,7 @@ import markdown
|
|||
import bleach
|
||||
import itertools
|
||||
import bisect
|
||||
import functools
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
|
||||
|
||||
|
@ -1092,6 +1093,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
|||
try:
|
||||
last_storage_key = None
|
||||
f = None
|
||||
current_offset = 0
|
||||
for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
|
||||
storage_key = model_dict[key].key
|
||||
if storage_key != last_storage_key:
|
||||
|
@ -1099,10 +1101,14 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
|||
if isinstance(f, zipfile.ZipExtFile):
|
||||
f.close()
|
||||
f = z.open(f"archive/data/{storage_key}")
|
||||
current_offset = f.tell()
|
||||
current_offset = 0
|
||||
if current_offset != model_dict[key].seek_offset:
|
||||
f.seek(model_dict[key].seek_offset - current_offset, 1)
|
||||
f.read(model_dict[key].seek_offset - current_offset)
|
||||
current_offset = model_dict[key].seek_offset
|
||||
device = device_map[key]
|
||||
size = functools.reduce(lambda x, y: x * y, model_dict[key].shape, 1)
|
||||
dtype = model_dict[key].dtype
|
||||
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
|
||||
#print(f"Transferring <{key}> to {'(CPU)' if device == 'cpu' else '[device ' + str(device) + ']'} ... ", end="", flush=True)
|
||||
model_dict[key] = model_dict[key].materialize(f, map_location="cpu")
|
||||
if convert_to_float16 and vars.hascuda and (vars.breakmodel or vars.usegpu) and model_dict[key].dtype is torch.float32:
|
||||
|
@ -1111,6 +1117,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
|||
model_dict[key] = model_dict[key].to(torch.float32)
|
||||
model_dict[key] = model_dict[key].to(device)
|
||||
#print("OK", flush=True)
|
||||
current_offset += nbytes
|
||||
finally:
|
||||
if isinstance(f, zipfile.ZipExtFile):
|
||||
f.close()
|
||||
|
|
|
@ -1176,7 +1176,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
|
|||
continue
|
||||
|
||||
storage_key = model_dict[key].key
|
||||
if storage_key != last_storage_key:
|
||||
if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
|
||||
last_storage_key = storage_key
|
||||
if isinstance(f, zipfile.ZipExtFile):
|
||||
f.close()
|
||||
|
|
Loading…
Reference in New Issue