Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Allow TPU Colab to load sharded HF models
@@ -27,6 +27,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 '''
 
+import utils
+
 import multiprocessing
 from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
 import progressbar
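The companion change to utils.py is not shown on this page; judging by how the loader reads and updates them below, the diff appears to rely on two module-level globals roughly like the following sketch (only the two names are confirmed by the code in this diff; the types and defaults are assumptions):

# Hypothetical reconstruction of the shard-tracking state assumed in utils.py.
# Not part of this diff; only the names num_shards and current_shard are
# confirmed by the loader code below.
from typing import Optional

num_shards: Optional[int] = None  # total checkpoint shards; None when the model is unsharded
current_shard: int = 0            # 1-based index of the shard currently loading; 0 = not started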
@@ -1163,8 +1165,11 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
     last_storage_key = None
     f = None
     current_offset = 0
-    print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
-    for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
+    if utils.current_shard == 0:
+        print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
+    if utils.num_shards is not None:
+        utils.current_shard += 1
+    for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
 
         # Some model weights are used by transformers but not by MTJ.
         # We have to materialize these weights anyways because
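The hunk above makes the loader shard-aware in two ways: the parameter-count banner now prints only before the first shard, and the tqdm progress bar gains a running shard counter. A minimal runnable sketch of that logic, using a stand-in utils namespace and dummy tensor keys (both hypothetical, not from the repository):

# Standalone sketch of the shard-gated banner and progress description.
from types import SimpleNamespace
from tqdm import tqdm

utils = SimpleNamespace(num_shards=3, current_shard=0)

def load_one_shard(keys):
    if utils.current_shard == 0:
        print("Banner printed once, before the first shard is processed.")
    if utils.num_shards is not None:
        utils.current_shard += 1  # becomes the 1-based index of the shard now loading
    desc = "Loading model tensors" + (
        f" (shard {utils.current_shard}/{utils.num_shards})"
        if utils.num_shards is not None else ""
    )
    for key in tqdm(sorted(keys), desc=desc):
        pass  # tensor materialization elided

for shard_keys in [["a", "b"], ["c"], ["d", "e"]]:
    load_one_shard(shard_keys)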
@@ -1225,6 +1230,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                 np.empty(params["cores_per_replica"]),
             )
 
+    if utils.num_shards is not None and utils.current_shard < utils.num_shards:
+        return
+
     # Check for tensors that MTJ needs that were not provided in the
     # HF model
     for mk, mv in network.state["params"].items():
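The early return added here means the post-load validation (the check for tensors MTJ needs that the HF model did not provide) runs only after the final shard has been materialized. A hedged sketch of that control flow; the caller shape and the names load_shard, validate_params, and shards are illustrative, not from the repository:

# Sketch: the load hook is invoked once per checkpoint shard, and the
# final validation step runs only on the last invocation.
from types import SimpleNamespace

utils = SimpleNamespace(num_shards=None, current_shard=0)

def load_shard(tensors):
    if utils.num_shards is not None:
        utils.current_shard += 1
    print(f"materializing {len(tensors)} tensors")
    if utils.num_shards is not None and utils.current_shard < utils.num_shards:
        return  # more shards to come; defer final validation
    validate_params()

def validate_params():
    print("checking for tensors MTJ needs that the HF model did not provide")

shards = [{"w1": 1}, {"w2": 2}, {"w3": 3}]
utils.num_shards = len(shards)
for tensors in shards:
    load_shard(tensors)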