mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Merge pull request #128 from VE-FORBRYDERNE/opt
OPT breakmodel and TPU support
This commit is contained in:
20
utils.py
20
utils.py
@ -6,8 +6,11 @@ import subprocess
|
||||
import tempfile
|
||||
import requests
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
# Module-level state. Every name below starts out unset/zero at import time
# and is expected to be assigned by other code later.

# NOTE(review): this name shadows the builtin `vars()` inside this module;
# it is kept as-is because other modules may reference it by this name.
vars = None

# Presumably the total number of checkpoint shards of the model being loaded
# (None until known) -- TODO confirm against the code that assigns it.
num_shards: Optional[int] = None

# Presumably a counter of the shard currently being processed -- TODO confirm.
current_shard = 0
|
||||
|
||||
#==================================================================#
|
||||
# Decorator to prevent a function's actions from being run until
|
||||
@ -135,7 +138,13 @@ def decodenewlines(txt):
|
||||
return txt
|
||||
|
||||
#==================================================================#
|
||||
# Downloads sharded huggingface checkpoints using aria2c if possible
|
||||
# Returns number of layers given an HF model config
|
||||
#==================================================================#
|
||||
def num_layers(config):
    """Return the layer count stored on a model config object.

    The value is read from the first of these attributes that exists on
    ``config``, checked in this order: ``num_layers``, ``n_layer``,
    ``num_hidden_layers``.

    Args:
        config: Any object exposing at least one of the attributes above.

    Returns:
        The value of the first matching attribute, unmodified.

    Raises:
        AttributeError: If ``config`` has none of the three attributes
            (raised by the final ``num_hidden_layers`` lookup).
    """
    # hasattr() is used instead of a truthiness test so that an attribute
    # which is present but falsy (e.g. 0) is still returned as-is.
    for attr in ("num_layers", "n_layer"):
        if hasattr(config, attr):
            return getattr(config, attr)
    # Last candidate: accessed directly so a config with none of the known
    # attribute names fails loudly rather than returning a made-up value.
    return config.num_hidden_layers
|
||||
|
||||
#==================================================================#
|
||||
# Downloads huggingface checkpoints using aria2c if possible
|
||||
#==================================================================#
|
||||
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
|
||||
import transformers
|
||||
@ -227,3 +236,12 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
||||
os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
|
||||
with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
|
||||
json.dump({"url": u, "etag": t}, f)
|
||||
|
||||
#==================================================================#
|
||||
# Given the path to a pytorch_model.bin.index.json, returns how many
|
||||
# shards there are in the model
|
||||
#==================================================================#
|
||||
def get_num_shards(filename):
    """Return how many distinct shard files a checkpoint index refers to.

    Args:
        filename: Path to a JSON index file (e.g. a
            ``pytorch_model.bin.index.json``) whose top-level object has a
            ``"weight_map"`` object.

    Returns:
        The number of distinct values in ``weight_map``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the JSON object has no ``"weight_map"`` key.
    """
    # JSON text is UTF-8; pass the encoding explicitly so parsing does not
    # depend on the platform's locale default.
    with open(filename, encoding="utf-8") as f:
        map_data = json.load(f)
    # Several keys can share the same value, so collapse the values into a
    # set and count the distinct entries.
    return len(set(map_data["weight_map"].values()))
|
||||
|
Reference in New Issue
Block a user