Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Synced 2025-02-26 00:17:41 +01:00
Merge pull request #112 from VE-FORBRYDERNE/lazy-loader
Fix lazy loader
Commit: a060219ff7
aiserver.py (11 changed lines)
@@ -30,6 +30,7 @@ import markdown
 import bleach
 import itertools
 import bisect
+import functools
 from collections.abc import Iterable
 from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
 
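The functools import added above backs an element-count computation introduced later in this diff (a functools.reduce over a tensor's shape). A minimal standalone sketch of that expression, with numel as a hypothetical helper name:

    import functools

    # Product of a shape's dimensions; the initializer 1 makes the
    # empty shape () of a scalar yield an element count of 1.
    def numel(shape):
        return functools.reduce(lambda x, y: x * y, shape, 1)

    print(numel((3, 4, 5)))  # 60
    print(numel(()))         # 1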
@@ -1092,6 +1093,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     try:
         last_storage_key = None
         f = None
+        current_offset = 0
         for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
             storage_key = model_dict[key].key
             if storage_key != last_storage_key:
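For context: the loop above visits keys sorted by (storage key, seek offset), so together with the new current_offset counter each archive entry can be streamed front to back in one pass. A small illustration of that ordering, with Spec and the sample offsets invented for the example:

    from collections import namedtuple

    Spec = namedtuple("Spec", ["key", "seek_offset"])

    # Hypothetical stand-in for model_dict: tensor name -> storage entry and offset.
    model_dict = {
        "b": Spec("0", 4096),
        "a": Spec("0", 0),
        "c": Spec("1", 0),
    }

    # Grouping by storage key and ordering by offset keeps reads forward-only.
    for name in sorted(model_dict, key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
        print(name, model_dict[name])
    # a Spec(key='0', seek_offset=0)
    # b Spec(key='0', seek_offset=4096)
    # c Spec(key='1', seek_offset=0)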
@@ -1099,10 +1101,14 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                 if isinstance(f, zipfile.ZipExtFile):
                     f.close()
                 f = z.open(f"archive/data/{storage_key}")
-                current_offset = f.tell()
+                current_offset = 0
             if current_offset != model_dict[key].seek_offset:
-                f.seek(model_dict[key].seek_offset - current_offset, 1)
+                f.read(model_dict[key].seek_offset - current_offset)
+                current_offset = model_dict[key].seek_offset
             device = device_map[key]
+            size = functools.reduce(lambda x, y: x * y, model_dict[key].shape, 1)
+            dtype = model_dict[key].dtype
+            nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
             #print(f"Transferring <{key}> to {'(CPU)' if device == 'cpu' else '[device ' + str(device) + ']'} ... ", end="", flush=True)
             model_dict[key] = model_dict[key].materialize(f, map_location="cpu")
             if convert_to_float16 and vars.hascuda and (vars.breakmodel or vars.usegpu) and model_dict[key].dtype is torch.float32:
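The swap of f.read() for f.seek() is the heart of the fix: zipfile.ZipExtFile had no seek() at all before Python 3.7, and even where it exists a forward seek on a compressed entry is emulated by decompressing and discarding data, so reading the gap explicitly and tracking the position by hand is the portable choice. Likewise, a freshly opened entry always starts at position 0, which is why current_offset = f.tell() became current_offset = 0. A runnable sketch of the pattern, with the archive contents and the 16-byte target offset invented for the example:

    import io
    import zipfile

    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w") as z:
        z.writestr("archive/data/0", b"\x00" * 16 + b"payload")

    with zipfile.ZipFile(buf) as z:
        f = z.open("archive/data/0")  # a fresh entry starts at offset 0
        current_offset = 0
        target = 16                   # stands in for model_dict[key].seek_offset
        if current_offset != target:
            f.read(target - current_offset)  # skip forward by discarding bytes
            current_offset = target
        print(f.read())               # b'payload'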
@@ -1111,6 +1117,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                 model_dict[key] = model_dict[key].to(torch.float32)
             model_dict[key] = model_dict[key].to(device)
             #print("OK", flush=True)
+            current_offset += nbytes
     finally:
         if isinstance(f, zipfile.ZipExtFile):
             f.close()
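The new current_offset += nbytes line advances the manual position past the tensor that materialize() just consumed from the stream (assuming materialize reads exactly the tensor's storage). The nbytes expression converts a dtype's per-element bit width to bytes; a sketch with storage_nbytes as a hypothetical helper name:

    import functools
    import torch

    # Bytes occupied by a tensor's raw storage. torch.finfo/torch.iinfo report
    # bits per element; >> 3 converts bits to bytes; bool is stored as one byte.
    def storage_nbytes(shape, dtype):
        size = functools.reduce(lambda x, y: x * y, shape, 1)
        if dtype is torch.bool:
            return size
        info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
        return size * (info.bits >> 3)

    print(storage_nbytes((1024, 1024), torch.float16))  # 2097152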
@@ -1176,7 +1176,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             continue
 
         storage_key = model_dict[key].key
-        if storage_key != last_storage_key:
+        if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
             last_storage_key = storage_key
             if isinstance(f, zipfile.ZipExtFile):
                 f.close()
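This last hunk sits inside load_model (apparently the TPU/MTJ checkpoint path, judging by the driver_version default in the hunk header). The added or model_dict[key].seek_offset < current_offset clause handles backward movement: a zip entry stream cannot cheaply rewind, so when the next tensor begins before the current read position the loader closes and reopens the entry, which resets both the stream and the manual counter to 0. Roughly, assembled from the hunks above (z, f, and the bookkeeping variables as in the surrounding diff):

    if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
        last_storage_key = storage_key
        if isinstance(f, zipfile.ZipExtFile):
            f.close()
        f = z.open(f"archive/data/{storage_key}")  # reopening restarts at offset 0
        current_offset = 0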