2022-05-11 21:45:38 +02:00
from threading import Timer
2021-05-11 01:17:10 +02:00
import re
2022-05-11 03:28:13 +02:00
import shutil
import json
import subprocess
import tempfile
2022-09-15 22:50:43 +02:00
from urllib . error import HTTPError
2022-05-11 03:28:13 +02:00
import requests
2022-05-13 23:00:10 +02:00
import requests . adapters
import time
2022-06-17 00:45:11 +02:00
from transformers import __version__ as transformers_version
2022-06-19 00:16:56 +02:00
from transformers import PreTrainedModel
2022-06-17 00:45:11 +02:00
import packaging . version
2022-05-13 23:00:10 +02:00
from tqdm . auto import tqdm
2022-05-11 03:28:13 +02:00
import os
2022-05-14 05:32:16 +02:00
import itertools
2022-09-15 19:37:50 +02:00
import hashlib
import huggingface_hub
2022-09-15 22:50:43 +02:00
import packaging . version
from pathlib import Path
2022-06-19 00:16:56 +02:00
from typing import List , Optional
2021-05-07 20:32:10 +02:00
2022-06-17 00:45:11 +02:00
# accelerate-based loading is only considered available when transformers is
# at least 4.20.0.dev0 AND the `accelerate` package can be imported.
_MIN_ACCELERATE_TRANSFORMERS = packaging.version.parse("4.20.0.dev0")
HAS_ACCELERATE = packaging.version.parse(transformers_version) >= _MIN_ACCELERATE_TRANSFORMERS
try:
    import accelerate  # probe only; not otherwise used in the code visible here
except ImportError:
    HAS_ACCELERATE = False
2022-02-12 19:23:59 +01:00
# Shared state object whose attributes (e.g. `newlinemode`, `aria2_port`) are
# read by functions in this module. None until it is assigned from outside
# (presumably by the server at startup - confirm against the caller).
vars = None

# Checkpoint-loading bookkeeping. These are not read by the functions in this
# module; presumably they are updated by the model loader - TODO confirm.
num_shards: Optional[int] = None
current_shard: int = 0
from_pretrained_model_name: str = ""
from_pretrained_index_filename: Optional[str] = None
from_pretrained_kwargs: dict = {}
bar = None  # module-level slot; _download_with_aria2 uses its own local `bar`

# Module / buffer name lists for the loaded model; None until filled in.
layers_module_names: Optional[List[str]] = None
module_names: Optional[List[str]] = None
named_buffers: Optional[List[tuple]] = None

# Default ordering of sampler ids (meaning of each id is defined elsewhere).
default_sampler_order: List[int] = [6, 0, 1, 2, 3, 4, 5]
2022-06-14 01:12:23 +02:00
2021-05-11 01:17:10 +02:00
#==================================================================#
# Decorator to prevent a function's actions from being run until
# at least x seconds have passed without the function being called
#==================================================================#
2021-05-07 20:32:10 +02:00
def debounce(wait):
    """Decorator factory: run the function only after calls stop for ``wait`` seconds.

    Each call to the decorated function cancels the previously scheduled call
    (if it has not fired yet) and schedules a new one on a
    ``threading.Timer``. Only the last call in a burst actually runs, ``wait``
    seconds after it was made, on the timer's thread. The decorated function
    itself returns None; the wrapped function's return value is discarded.

    Args:
        wait: delay in seconds handed to ``threading.Timer``.

    Returns:
        A decorator. The wrapper keeps the wrapped function's name and docstring.
    """
    import functools

    def decorator(fun):
        @functools.wraps(fun)
        def debounced(*args, **kwargs):
            def call_it():
                fun(*args, **kwargs)
            # The pending timer lives on the wrapper object, so every call of
            # this decorated function shares one debounce window.
            try:
                debounced.t.cancel()
            except AttributeError:
                pass  # first call: no timer has been scheduled yet
            debounced.t = Timer(wait, call_it)
            debounced.t.start()
        return debounced
    return decorator
#==================================================================#
# Replace fancy quotes and apostrophes with standard ones
#==================================================================#
# Translation table: typographic double quotes -> '"', typographic single
# quotes and backticks -> "'".
_QUOTE_TRANSLATION = str.maketrans({
    "“": '"',
    "”": '"',
    "‘": "'",
    "’": "'",
    "`": "'",
})


def fixquotes(txt):
    """Replace typographic quotes and apostrophes with plain ASCII ones.

    “ and ” become a double quote; ‘, ’ and the backtick become a single
    quote. The left single quote ‘ is included so that ‘quoted’ text is
    normalized symmetrically (previously only ’ was converted).

    Args:
        txt: text to normalize.

    Returns:
        The text with all such characters replaced in a single pass.
    """
    return txt.translate(_QUOTE_TRANSLATION)
#==================================================================#
# Trim text back to the end of its last complete sentence
#==================================================================#
def trimincompletesentence(txt):
    """Cut ``txt`` right after its last sentence-ending punctuation mark.

    The last '.', '!' or '?' is located (approach borrowed from Clover-Edition
    by cloveranon). If a double quote immediately follows it, the quote is
    kept as well. Text containing none of those marks is returned unchanged.

    Args:
        txt: text to trim.

    Returns:
        The trimmed text.
    """
    lastpunc = max(txt.rfind("."), txt.rfind("!"), txt.rfind("?"))
    if lastpunc < 0:
        # No sentence terminator at all: nothing to trim. Without this guard a
        # text starting with '"' was cut down to that lone quote character.
        return txt
    # Keep the closing quote of a quoted sentence.
    if lastpunc + 1 < len(txt) and txt[lastpunc + 1] == '"':
        lastpunc += 1
    return txt[:lastpunc + 1]
#==================================================================#
# Remove empty lines by collapsing consecutive newlines
#==================================================================#
def replaceblanklines(txt):
    """Remove blank lines by collapsing each run of newlines into one.

    A single ``str.replace("\\n\\n", "\\n")`` pass only halves long runs
    (three newlines still left a blank line), so the whole run is replaced.

    Args:
        txt: text to process.

    Returns:
        The text with no two consecutive newline characters.
    """
    return re.sub(r"\n{2,}", "\n", txt)
#==================================================================#
# Strip special characters such as #, /, @ and % from text
#==================================================================#
2021-08-19 13:18:01 +02:00
def removespecialchars(txt, vars=None):
    """Delete the characters # / @ % { } + = ~ | ^ (and usually < >) from ``txt``.

    '<' and '>' are also deleted unless ``vars`` is given and its
    ``actionmode`` is nonzero, in which case they are kept.
    """
    keep_angle_brackets = vars is not None and vars.actionmode != 0
    if keep_angle_brackets:
        pattern = r"[#/@%{}+=~|\^]"
    else:
        pattern = r"[#/@%<>{}+=~|\^]"
    return re.sub(pattern, "", txt)
#==================================================================#
# If the next action follows a sentence closure, add a space
#==================================================================#
2021-05-14 08:24:05 +02:00
def addsentencespacing(txt, vars):
    """Prefix ``txt`` with a space when it would otherwise run into the story.

    ``txt`` is returned untouched when it is empty or already starts with
    whitespace. Otherwise the last character of the most recent entry in
    ``vars.actions`` (or of ``vars.prompt`` when there are no actions) is
    examined, and a single space is prepended unless that character is
    already a space. If the most recent action is an empty string, ``txt`` is
    returned unchanged.
    """
    # Nothing to do for empty submissions or ones that begin with whitespace.
    if not txt or txt[0].isspace():
        return txt

    if len(vars.actions) > 0:
        previous = vars.actions[vars.actions.get_last_key()]
        if len(previous) == 0:
            # An empty last action should never happen; bail out if it does.
            return txt
        tail = previous[-1]
    else:
        prompt = vars.prompt
        tail = prompt[-1] if len(prompt) else ""

    if tail == " ":
        return txt
    return " " + txt
2021-10-23 17:30:48 +02:00
def singlelineprocessing(txt, vars):
    """Apply single-line cleanup to ``txt``.

    Every match of ``vars.regex_sl`` is removed from ``txt``. Then, unless
    the last character of the most recent entry in ``vars.actions`` (or of
    ``vars.prompt`` when there are no actions) is a newline, a newline is
    appended. If the most recent action is an empty string, the cleaned text
    is returned without that check.
    """
    cleaned = vars.regex_sl.sub('', txt)

    if len(vars.actions) == 0:
        prompt = vars.prompt
        tail = prompt[-1] if len(prompt) else ""
    else:
        previous = vars.actions[vars.actions.get_last_key()]
        if not len(previous):
            # An empty last action should never happen; bail out if it does.
            return cleaned
        tail = previous[-1]

    return cleaned if tail == "\n" else cleaned + "\n"
2021-05-22 11:28:40 +02:00
#==================================================================#
# Cleans string for use in file name
#==================================================================#
def cleanfilename(filename):
    """Return ``filename`` without '/' or '\\' characters and without trailing whitespace."""
    without_separators = filename.translate({ord("/"): None, ord("\\"): None})
    return without_separators.rstrip()
2021-05-11 01:17:10 +02:00
2022-02-12 19:23:59 +01:00
#==================================================================#
# Newline substitution for fairseq models
#==================================================================#
def encodenewlines(txt):
    """Encode newlines for models that use "</s>" as the line separator.

    When the module-level ``vars.newlinemode`` is "s", every newline in
    ``txt`` is replaced by "</s>"; in any other mode ``txt`` is returned
    unchanged.
    """
    if vars.newlinemode != "s":
        return txt
    return txt.replace('\n', "</s>")
def decodenewlines(txt):
    """Undo "</s>" newline encoding according to ``vars.newlinemode``.

    - "s":  each "</s>" becomes a newline;
    - "ns": each "</s>" is deleted;
    - any other mode: ``txt`` is returned unchanged.
    """
    mode = vars.newlinemode
    if mode == "s":
        replacement = '\n'
    elif mode == "ns":
        replacement = ''
    else:
        return txt
    return txt.replace("</s>", replacement)
2022-05-11 03:28:13 +02:00
#==================================================================#
2022-05-13 07:03:38 +02:00
# Returns number of layers given an HF model config
#==================================================================#
def num_layers(config):
    """Return the number of layers described by a Hugging Face model config.

    Args:
        config: either a config object or a plain dict (e.g. a parsed
            config.json).

    Object configs are probed for the attributes ``num_layers``, then
    ``n_layer``, then ``num_hidden_layers``; None is returned when none of
    them exists. Dict configs are probed for the key "n_layer" first (as
    before), then "num_layers" and "num_hidden_layers", so configs that use
    the other naming conventions no longer fail.

    Raises:
        KeyError: ``config`` is a dict containing none of those keys (the
            same exception the original "n_layer" lookup raised).
    """
    if isinstance(config, dict):
        for key in ("n_layer", "num_layers", "num_hidden_layers"):
            if key in config:
                return config[key]
        # Preserve the historical KeyError for dicts without any layer count.
        return config["n_layer"]
    for attr in ("num_layers", "n_layer", "num_hidden_layers"):
        if hasattr(config, attr):
            return getattr(config, attr)
    return None
2022-05-13 07:03:38 +02:00
2022-05-11 03:28:13 +02:00
#==================================================================#
2022-05-13 05:51:40 +02:00
# Downloads huggingface checkpoints using aria2c if possible
2022-05-11 03:28:13 +02:00
#==================================================================#
2022-07-22 19:58:20 +02:00
from flask_socketio import emit
2022-09-15 22:50:43 +02:00
def _download_with_aria2(aria2_config: bytes, total_length: int, directory: str = ".", user_agent=None, force_download=False, use_auth_token=None):
    """Download files with an aria2c subprocess while reporting progress.

    aria2c is started with its JSON-RPC interface on ``vars.aria2_port``,
    protected by a random secret. It is polled every 0.1 s with
    ``aria2.tellActive`` to drive a tqdm bar whose output is printed and
    forwarded to socket.io clients. Polling stops once aria2c reports no
    active downloads (aria2c is then terminated) or once the process exits.

    Args:
        aria2_config: contents of an aria2c input file (``-i``), already
            encoded to bytes: each URL followed by an indented ``out=`` line.
        total_length: expected total number of bytes, used as the bar total.
        directory: download directory passed to aria2c with ``-d``.
        user_agent: forwarded to ``transformers.file_utils.http_user_agent``.
        force_download: when False, ``-c`` (continue partial downloads) is passed.
        use_auth_token: if truthy, sent as a bearer token in an
            ``Authorization`` header.

    Raises:
        OSError: aria2c exited with a nonzero code before reporting that no
            downloads remained active.
    """
    import transformers

    class Send_to_socketio(object):
        # File-like sink for tqdm: echoes each progress line to stdout and
        # forwards it to socket.io clients. Reporting is best-effort, so any
        # failure here is ignored instead of aborting the download.
        def write(self, bar):
            bar = bar.replace("\r", "").replace("\n", "")
            if bar != "":
                try:
                    print('\r' + bar, end='')
                    try:
                        emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", "&nbsp;")}, broadcast=True)
                    except Exception:
                        pass
                    # NOTE(review): `eventlet` is not among this module's visible
                    # imports; if it is undefined at call time this raises
                    # NameError, which the handler below swallows. Confirm.
                    eventlet.sleep(seconds=0)
                except Exception:
                    pass

        def flush(self):
            pass

    lengths = {}  # aria2 file path -> (completed bytes, total bytes)
    s = requests.Session()
    # Retries presumably cover the window before aria2c's RPC server listens.
    s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
    bar = None
    done = False
    p = None
    path = None
    secret = os.urandom(17).hex()
    try:
        with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
            # Remember the name right away so the temp file is removed even if
            # starting or polling aria2c fails.
            path = f.name
            f.write(aria2_config)
            f.flush()
            command = [
                "aria2c", "-x", "10", "-s", "10", "-j", "10",
                "--enable-rpc=true", f"--rpc-secret={secret}",
                "--rpc-listen-port", str(vars.aria2_port),
                "--disable-ipv6", "--file-allocation=trunc",
                "--allow-overwrite", "--auto-file-renaming=false",
                "-d", directory, "-i", f.name,
                "-U", transformers.file_utils.http_user_agent(user_agent),
            ]
            if not force_download:
                command.append("-c")
            if use_auth_token:
                # No shell is involved, so the header must not be quoted:
                # quotes would reach aria2c literally and corrupt the header.
                command.append(f"--header=Authorization: Bearer {use_auth_token}")
            p = subprocess.Popen(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            while p.poll() is None:
                r = s.post(f"http://localhost:{vars.aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
                if not r:
                    # Nothing active anymore: the downloads are finished.
                    if bar is not None:
                        bar.n = bar.total
                        bar.close()
                    p.terminate()
                    done = True
                    break
                if bar is None:
                    bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio())
                visited = set()
                for x in r:
                    filename = x["files"][0]["path"]
                    lengths[filename] = (int(x["completedLength"]), int(x["totalLength"]))
                    visited.add(filename)
                # Files no longer reported as active are counted as complete.
                for k, v in lengths.items():
                    if k not in visited:
                        lengths[k] = (v[1], v[1])
                bar.n = sum(v[0] for v in lengths.values())
                bar.update(0)  # redraw without adding to the count just set
                time.sleep(0.1)
    except Exception:
        if p is not None:
            p.terminate()
        raise
    finally:
        s.close()
        if bar is not None:
            bar.close()  # no-op if already closed above
        if path is not None:
            try:
                os.remove(path)
            except OSError:
                pass
    code = p.wait()
    if not done and code:
        raise OSError(f"aria2 exited with exit code {code}")
def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
    """Pre-download model weights with aria2 into the huggingface_hub cache layout.

    Used for transformers >= 4.22. The function:
    1. Determines whether the repo has a single ``WEIGHTS_NAME`` file or a
       sharded checkpoint (``WEIGHTS_INDEX_NAME`` plus the shard files named in
       its ``weight_map``).
    2. Drops files that are already cached (unless ``force_download``).
    3. Resolves each remaining file's commit hash and ETag with a HEAD
       request, following huggingface_hub's ``hf_hub_download``.
    4. Downloads the missing blobs into
       ``<cache>/<repo folder>/blobs/<etag>`` with aria2c.

    Snapshot pointers are only created here for blobs that already existed.
    Pointers for freshly downloaded blobs are left for the later
    ``hf_hub_download`` call to create (presumably - confirm).

    Returns:
        None.

    Raises:
        EnvironmentError: ``use_auth_token=True`` but no saved token was found.
        RuntimeError: a HEAD request failed with an HTTP error other than
            EntryNotFound.
        requests.exceptions.HTTPError: the remote file does not exist
            (EntryNotFound); the miss is recorded under ``.no_exist``.
        OSError: missing commit/ETag headers, or aria2c failed.
        ValueError: no ETag could be obtained while ``force_download`` is set.
        huggingface_hub.file_download.LocalEntryNotFoundError: no ETag could be
            obtained and the file is not in the cache.
    """
    import transformers
    import transformers.modeling_utils
    from huggingface_hub import HfFolder

    file_download = huggingface_hub.file_download

    token = None
    if use_auth_token:
        if isinstance(use_auth_token, str):
            token = use_auth_token
        else:
            token = HfFolder.get_token()
            if token is None:
                raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
    _cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
    _revision = revision if revision is not None else huggingface_hub.constants.DEFAULT_REVISION
    headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
    if token is not None:
        # Send the resolved token: use_auth_token may just be the boolean True.
        headers["authorization"] = f"Bearer {token}"
    storage_folder = os.path.join(_cache_dir, file_download.repo_folder_name(repo_id=pretrained_model_name_or_path, repo_type="model"))
    os.makedirs(storage_folder, exist_ok=True)

    def is_cached(filename):
        # Offline lookup in the same cache directory the blobs are written to.
        try:
            huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=_cache_dir, revision=revision, local_files_only=True)
        except (ValueError, OSError):
            return False
        return True

    # Try to get the huggingface.co URL of the model's pytorch_model.bin or
    # pytorch_model.bin.index.json file.
    sharded = False
    while True:
        try:
            filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
        except AttributeError:
            return
        url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=revision)
        if is_cached(filename) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
            break
        if sharded:
            return
        sharded = True
    if not sharded:
        # A single pytorch_model.bin is the only file to download.
        filenames = [transformers.modeling_utils.WEIGHTS_NAME]
    else:
        # Download the index, then let aria2 fetch every shard it lists.
        map_filename = huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=_cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision)
        with open(map_filename) as f:
            map_data = json.load(f)
        filenames = sorted(set(map_data["weight_map"].values()))
    # Keep each URL paired with its own repo filename; filtering the URL list
    # separately from the filenames would misalign them.
    pending = [(huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=revision), n) for n in filenames]
    if not force_download:
        pending = [(u, n) for u, n in pending if not is_cached(n)]
    if not pending:
        return

    downloads = []  # (url, blob_path) pairs aria2 still has to fetch
    # This section is a modified version of hf_hub_download from huggingface_hub
    # See https://github.com/huggingface/huggingface_hub/blob/main/LICENSE for license
    for u, n in pending:
        relative_filename = os.path.join(*n.split("/"))
        # Reset per file so a failed HEAD never reuses the previous file's values.
        etag = None
        commit_hash = None
        if not local_files_only:
            try:
                r = file_download._request_wrapper(
                    method="HEAD",
                    url=u,
                    headers=headers,
                    allow_redirects=False,
                    follow_relative_redirects=True,
                    proxies=proxies,
                    timeout=10,
                )
                try:
                    r.raise_for_status()
                except requests.exceptions.HTTPError as e:
                    error_code = r.headers.get("X-Error-Code")
                    if error_code != "EntryNotFound":
                        raise RuntimeError(f"HEAD {u} failed with error code {r.status_code}") from e
                    # Remember that this file does not exist at this commit.
                    commit_hash = r.headers.get(file_download.HUGGINGFACE_HEADER_X_REPO_COMMIT)
                    if commit_hash is not None:
                        no_exist_file_path = Path(storage_folder) / ".no_exist" / commit_hash / relative_filename
                        no_exist_file_path.parent.mkdir(parents=True, exist_ok=True)
                        no_exist_file_path.touch()
                        file_download._cache_commit_hash_for_specific_revision(storage_folder, _revision, commit_hash)
                    raise
                commit_hash = r.headers.get(file_download.HUGGINGFACE_HEADER_X_REPO_COMMIT)
                if commit_hash is None:
                    raise OSError(
                        "Distant resource does not seem to be on huggingface.co (missing "
                        "commit header)."
                    )
                # We favor a custom header indicating the etag of the linked
                # resource, and fall back to the regular etag header.
                etag = r.headers.get(file_download.HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get("ETag")
                if etag is None:
                    raise OSError(
                        "Distant resource does not have an ETag, we won't be able to "
                        "reliably ensure reproducibility."
                    )
                etag = file_download._normalize_etag(etag)
            except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
                # Actually raise for those subclasses of ConnectionError
                raise
            except (
                requests.exceptions.ConnectionError,
                requests.exceptions.Timeout,
                file_download.OfflineModeIsEnabled,
            ):
                # Otherwise, our Internet connection is down; etag stays None.
                pass
        if etag is None:
            # In those cases, we cannot force download.
            if force_download:
                raise ValueError(
                    "We have no connection or you passed local_files_only, so "
                    "force_download is not an accepted option."
                )
            if file_download.REGEX_COMMIT_HASH.match(_revision):
                commit_hash = _revision
            else:
                ref_path = os.path.join(storage_folder, "refs", _revision)
                with open(ref_path) as f:
                    commit_hash = f.read()
            pointer_path = os.path.join(storage_folder, "snapshots", commit_hash, relative_filename)
            if os.path.exists(pointer_path):
                continue  # already available locally; move on to the next file
            if local_files_only:
                raise file_download.LocalEntryNotFoundError(
                    "Cannot find the requested files in the disk cache and "
                    "outgoing traffic has been disabled. To enable hf.co look-ups "
                    "and downloads online, set 'local_files_only' to False."
                )
            raise file_download.LocalEntryNotFoundError(
                "Connection error, and we cannot find the requested files in "
                "the disk cache. Please try again or make sure your Internet "
                "connection is on."
            )
        # From now on, etag and commit_hash are not None.
        blob_path = os.path.join(storage_folder, "blobs", etag)
        pointer_path = os.path.join(storage_folder, "snapshots", commit_hash, relative_filename)
        os.makedirs(os.path.dirname(blob_path), exist_ok=True)
        os.makedirs(os.path.dirname(pointer_path), exist_ok=True)
        # If the passed revision is a branch or tag name, store a ref to the commit.
        file_download._cache_commit_hash_for_specific_revision(storage_folder, _revision, commit_hash)
        if os.path.exists(pointer_path) and not force_download:
            continue
        if os.path.exists(blob_path) and not force_download:
            # We have the blob already, but not the pointer.
            file_download.logger.info("creating pointer to %s from %s", blob_path, pointer_path)
            file_download._create_relative_symlink(blob_path, pointer_path)
            continue
        # Some Windows versions do not allow for paths longer than 255 characters.
        # In this case, we must specify it is an extended path by using the "\\?\" prefix.
        if os.name == "nt" and len(os.path.abspath(blob_path)) > 255:
            blob_path = "\\\\?\\" + os.path.abspath(blob_path)
        downloads.append((u, blob_path))
    if not downloads:
        return

    # Follow redirects to learn each file's real size for the progress bar.
    content_headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u, _ in downloads]
    # Remove leftovers of a previously interrupted aria2 run.
    for _, blob_path in downloads:
        prefix, suffix = blob_path.rsplit(os.sep, 1)
        for leftover in (os.path.join(prefix, "kai-tempfile." + suffix + ".aria2"), os.path.join(prefix, "kai-tempfile." + suffix)):
            if os.path.exists(leftover):
                os.remove(leftover)
    total_length = sum(int(h["Content-Length"]) for h in content_headers)
    aria2_config = "\n".join(f"{u}\n out={os.path.join(prefix, 'kai-tempfile.' + suffix)}" for u, blob_path in downloads for prefix, suffix in [blob_path.rsplit(os.sep, 1)]).encode()
    _download_with_aria2(aria2_config, total_length, use_auth_token=token, user_agent=user_agent, force_download=force_download)
    for _, blob_path in downloads:
        prefix, suffix = blob_path.rsplit(os.sep, 1)
        # os.replace (unlike os.rename) also overwrites an existing blob on Windows.
        os.replace(os.path.join(prefix, "kai-tempfile." + suffix), os.path.join(prefix, suffix))
2022-09-15 22:50:43 +02:00
2022-09-15 19:37:50 +02:00
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
    """Pre-download a Hugging Face model's weight files with aria2c, if possible.

    Returns None without doing anything in any of these cases:
    - aria2c is not on PATH;
    - ``local_files_only`` is set;
    - ``pretrained_model_name_or_path`` is a local directory or file (or has a
      ``.index`` file) or is a remote URL;
    - proxies are configured (a warning is printed).

    For transformers >= 4.22.0.dev0, the work is delegated to
    ``_transformers22_aria2_hook``. Otherwise, files are stored in the legacy
    flat cache layout as ``sha256(url).sha256(etag)`` in the cache directory,
    each with a ``.json`` sidecar holding ``{"url", "etag"}``.

    The keyword arguments appear to mirror the download options of
    ``from_pretrained``; ``kwargs`` is only forwarded to the 4.22+ code path.

    Raises:
        EnvironmentError: ``use_auth_token=True`` but no saved token was found.
        OSError: aria2c failed (raised by ``_download_with_aria2``).
    """
    import transformers
    import transformers.modeling_utils
    from huggingface_hub import HfFolder
    if shutil.which("aria2c") is None:  # Don't do anything if aria2 is not installed
        return
    if local_files_only:  # If local_files_only is true, we obviously don't need to download anything
        return
    if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index") or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path):
        return
    if proxies:
        print("WARNING: KoboldAI does not support using aria2 to download models from huggingface.co through a proxy. Disabling aria2 download mode.")
        return
    if packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.22.0.dev0"):
        return _transformers22_aria2_hook(pretrained_model_name_or_path, force_download=force_download, cache_dir=cache_dir, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, **kwargs)
    token = None
    if use_auth_token:
        if isinstance(use_auth_token, str):
            token = use_auth_token
        else:
            token = HfFolder.get_token()
            if token is None:
                raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
    _cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
    sharded = False
    headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
    if token is not None:
        # Send the resolved token: use_auth_token may just be the boolean True.
        headers["authorization"] = f"Bearer {token}"

    def is_cached(url):
        # Look in _cache_dir (where this function writes its files), not in
        # huggingface_hub's own default cache directory.
        try:
            huggingface_hub.cached_download(url, cache_dir=_cache_dir, local_files_only=True)
        except ValueError:
            return False
        return True

    while True:  # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
        try:
            filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
        except AttributeError:
            return
        url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=revision)
        if is_cached(url) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
            break
        if sharded:
            return
        sharded = True
    if not sharded:  # If the model has a pytorch_model.bin file, that's the only file to download
        filenames = [transformers.modeling_utils.WEIGHTS_NAME]
    else:  # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
        map_filename = huggingface_hub.cached_download(url, cache_dir=_cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
        with open(map_filename) as f:
            map_data = json.load(f)
        filenames = set(map_data["weight_map"].values())
    urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=revision) for n in filenames]
    if not force_download:
        urls = [u for u in urls if not is_cached(u)]
    if not urls:
        return
    # ETag of each file (without following redirects), then the final
    # response headers (with redirects) to learn each file's size.
    etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
    headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
    # Legacy cache file names: sha256(url) + "." + sha256(etag).
    filenames = [hashlib.sha256(u.encode("utf-8")).hexdigest() + "." + hashlib.sha256(t.encode("utf-8")).hexdigest() for u, t in zip(urls, etags)]
    for n in filenames:
        # Remove leftovers of a previously interrupted aria2 run.
        path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
        if os.path.exists(path):
            os.remove(path)
        path = os.path.join(_cache_dir, "kai-tempfile." + n)
        if os.path.exists(path):
            os.remove(path)
        if force_download:
            path = os.path.join(_cache_dir, n + ".json")
            if os.path.exists(path):
                os.remove(path)
            path = os.path.join(_cache_dir, n)
            if os.path.exists(path):
                os.remove(path)
    total_length = sum(int(h["Content-Length"]) for h in headers)
    aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
    _download_with_aria2(aria2_config, total_length, directory=_cache_dir, use_auth_token=token, user_agent=user_agent, force_download=force_download)
    for u, t, n in zip(urls, etags, filenames):
        # os.replace also overwrites an existing destination on Windows.
        os.replace(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
        with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
            json.dump({"url": u, "etag": t}, f)
2022-05-13 05:51:40 +02:00
#==================================================================#
# Given the path to a pytorch_model.bin.index.json, returns how many
# shards there are in the model
#==================================================================#
def get_num_shards(filename):
    """Count the distinct shard files referenced by a checkpoint index.

    ``filename`` is the path to a ``pytorch_model.bin.index.json`` file. The
    result is the number of unique values in its ``weight_map`` mapping.
    """
    with open(filename) as index_file:
        index = json.load(index_file)
    return len({*index["weight_map"].values()})
2022-05-14 05:32:16 +02:00
#==================================================================#
# Given the name/path of a sharded model and the path to a
# pytorch_model.bin.index.json, returns a list of weight names in the
# sharded model. Requires lazy loader to be enabled to work properly
#==================================================================#
2022-09-15 19:37:50 +02:00
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
    """Return the names of all tensors stored in a sharded checkpoint.

    Despite its name, this returns a list of weight names, not a count. The
    list is the keys of every shard's state dict, shard by shard. Shard paths
    are resolved (and downloaded if needed) by
    ``transformers.modeling_utils.get_checkpoint_shard_files``, using the
    download options given here. Extra ``kwargs`` are ignored.

    Each shard is opened with ``torch.load(map_location="cpu")`` only to read
    its keys. The file header states the lazy loader must be enabled for this
    to work properly (presumably to avoid reading full tensors - confirm).
    """
    import transformers.modeling_utils
    import torch
    shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision)
    # chain.from_iterable loads shards one at a time. Unpacking the generator
    # with chain(*...) loaded every shard up front and kept all of their state
    # dicts alive until the final list was built.
    return list(itertools.chain.from_iterable(torch.load(p, map_location="cpu").keys() for p in shard_paths))
2022-06-17 00:45:11 +02:00
2022-06-19 00:16:56 +02:00
#==================================================================#
# Given a PreTrainedModel, returns the list of module names that correspond
# to the model's hidden layers.
#==================================================================#
def get_layers_module_names(model: PreTrainedModel) -> List[str]:
    """List the dotted names of the model's hidden-layer modules.

    The module tree is walked depth-first in ``named_children`` order. A child
    counts as a layer when its own name is numeric (an index inside a
    container) and its class name ends in "Block" or "Layer". Matching layers
    are reported but not descended into; every other child is searched
    recursively.
    """
    def walk(module, prefix):
        for local_name, child in module.named_children():
            full_name = prefix + local_name
            if local_name.isnumeric() and child.__class__.__name__.endswith(("Block", "Layer")):
                yield full_name
            else:
                yield from walk(child, full_name + ".")

    names: List[str] = list(walk(model, ""))
    return names
2022-06-19 00:16:56 +02:00
#==================================================================#
# Given a PreTrainedModel, returns the module name that corresponds
# to the model's input embeddings.
#==================================================================#
def get_input_embeddings_module_name(model: PreTrainedModel) -> Optional[str]:
    """Return the dotted module name of the model's input embeddings.

    The module tree is searched depth-first (in ``named_children`` order) for
    the object returned by ``model.get_input_embeddings()``. The search now
    continues into later siblings; previously only the first child at each
    level was explored, so embeddings that were not on that first-child path
    were never found.

    Returns:
        The dotted name, or None if the embeddings module is not found.
    """
    embeddings = model.get_input_embeddings()

    def recurse(module, head=""):
        for child_name, child in module.named_children():
            name = head + child_name
            if child is embeddings:
                return name
            found = recurse(child, head=name + ".")
            if found is not None:
                return found
        return None

    return recurse(model)
#==================================================================#
# Given a PreTrainedModel and a list of module names, returns a list
# of module names such that the union of the set of modules given as input
# and the set of modules returned as output contains all modules in the model.
#==================================================================#
def get_missing_module_names(model: PreTrainedModel, names: List[str]) -> List[str]:
    """List leaf modules not covered by ``names``.

    The union of the modules in ``names`` and the returned leaf modules covers
    every module of ``model``. A module is covered when its dotted name equals
    an entry of ``names`` or lies underneath one, i.e. starts with
    ``entry + "."``. Matching on a dotted boundary prevents e.g. "h.10" from
    being treated as covered by "h.1". Entries that are empty or already end
    in "." keep acting as raw prefixes.

    Uncovered modules with children are searched recursively. Uncovered
    modules without children are returned, in depth-first order.
    """
    names = list(names)
    exact = set(names)
    prefixes = tuple(n if n == "" or n.endswith(".") else n + "." for n in names)

    def is_covered(name):
        return name in exact or name.startswith(prefixes)

    missing_names: List[str] = []

    def recurse(module, head=""):
        for child_name, child in module.named_children():
            name = head + child_name
            if is_covered(name):
                continue
            if next(child.named_children(), None) is None:
                missing_names.append(name)
            else:
                recurse(child, head=name + ".")

    recurse(model)
    return missing_names