2022-05-11 21:45:38 +02:00
from threading import Timer
2021-05-11 01:17:10 +02:00
import re
2022-05-11 03:28:13 +02:00
import shutil
import json
import subprocess
import tempfile
import requests
2022-05-13 23:00:10 +02:00
import requests . adapters
import time
2022-06-17 00:45:11 +02:00
from transformers import __version__ as transformers_version
2022-06-19 00:16:56 +02:00
from transformers import PreTrainedModel
2022-06-17 00:45:11 +02:00
import packaging . version
2022-05-13 23:00:10 +02:00
from tqdm . auto import tqdm
2022-05-11 03:28:13 +02:00
import os
2022-05-14 05:32:16 +02:00
import itertools
2022-06-19 00:16:56 +02:00
from typing import List , Optional
2021-05-07 20:32:10 +02:00
2022-06-17 00:45:11 +02:00
# True only when the installed transformers version is at least 4.20.0.dev0
# AND the `accelerate` package can be imported.
HAS_ACCELERATE = packaging.version.parse(transformers_version) >= packaging.version.parse("4.20.0.dev0")
try:
    import accelerate
except ImportError:
    HAS_ACCELERATE = False

# Shared settings object; expected to be assigned by the application before
# the helpers that read `vars.newlinemode` / `vars.aria2_port` are called.
# NOTE(review): this name shadows the `vars` builtin inside this module.
vars = None

# Shard-loading progress state. Nothing in the visible part of this file
# writes these; presumably they are updated by the model loader - TODO confirm.
num_shards: Optional[int] = None
current_shard = 0

# State remembered about the current `from_pretrained` call.
# NOTE(review): presumably filled in by a from_pretrained patch defined
# elsewhere - confirm against the caller.
from_pretrained_model_name = ""
from_pretrained_index_filename: Optional[str] = None
from_pretrained_kwargs = {}

# Progress bar object shared across the module (None when no bar is active).
bar = None

# Module-name bookkeeping for the loaded model; None until populated.
layers_module_names: Optional[List[str]] = None
module_names: Optional[List[str]] = None
named_buffers: Optional[List[tuple]] = None

# Default ordering of the six sampler IDs (0 through 5).
default_sampler_order = [0, 1, 2, 3, 4, 5]
2021-05-11 01:17:10 +02:00
#==================================================================#
# Decorator to prevent a function's actions from being run until
# at least x seconds have passed without the function being called
#==================================================================#
2021-05-07 20:32:10 +02:00
def debounce(wait):
    """Build a decorator that postpones calls until `wait` quiet seconds pass.

    Each call to the decorated function cancels the previously scheduled
    timer (if there is one) and schedules a new one, so the wrapped function
    only runs once `wait` seconds have elapsed without another call. It then
    runs on the timer thread with the arguments of the most recent call.

    The decorated function always returns None immediately; the wrapped
    function's return value is discarded. The pending ``threading.Timer`` is
    available as the ``t`` attribute of the decorated function.
    """
    import functools

    def decorator(fun):
        # functools.wraps keeps the wrapped function's __name__/__doc__ so
        # debugging output and introspection still identify the real target.
        @functools.wraps(fun)
        def debounced(*args, **kwargs):
            try:
                debounced.t.cancel()
            except AttributeError:
                # First call: no timer has been scheduled yet.
                pass
            debounced.t = Timer(wait, fun, args, kwargs)
            debounced.t.start()

        return debounced

    return decorator
#==================================================================#
# Replace fancy quotes and apostrophes with standard ones
#==================================================================#
def fixquotes(txt):
    """Return `txt` with typographic quotes and backticks made plain.

    Curly double quotes become ``"``; the curly apostrophe and the backtick
    become ``'``.
    """
    substitutions = (
        ("“", '"'),
        ("”", '"'),
        ("’", "'"),
        ("`", "'"),
    )
    for fancy, plain in substitutions:
        txt = txt.replace(fancy, plain)
    return txt
#==================================================================#
# Trims text back to its last sentence-ending punctuation mark
#==================================================================#
def trimincompletesentence(txt):
    """Cut `txt` off after its last sentence-ending punctuation mark.

    The last occurrence of ``.``, ``!`` or ``?`` marks the end of the final
    complete sentence; a double quote directly after it is kept so a closing
    quotation is not lost. Text that contains none of these marks is
    returned unchanged.
    """
    # Find last instance of punctuation (Borrowed from Clover-Edition by cloveranon)
    last_punctuation = max(txt.rfind("."), txt.rfind("!"), txt.rfind("?"))
    if last_punctuation < 0:
        # No sentence terminator at all. Bail out before the quote check so a
        # leading '"' at index 0 is not mistaken for a closing quote (which
        # used to truncate the whole text down to a single quote character).
        return txt
    end = last_punctuation + 1
    # Is this the end of a quote? The slice is safe at the end of the string.
    if txt[end:end + 1] == '"':
        end += 1
    return txt[:end]
#==================================================================#
# Replaces each double newline with a single newline
#==================================================================#
def replaceblanklines(txt):
    """Replace every non-overlapping pair of newlines in `txt` with one newline.

    This is a single left-to-right pass, so a run of three newlines still
    leaves two behind.
    """
    pieces = txt.split("\n\n")
    return "\n".join(pieces)
#==================================================================#
# Strips special symbol characters from text
#==================================================================#
2021-08-19 13:18:01 +02:00
def removespecialchars(txt, vars=None):
    """Delete special symbol characters from `txt`.

    The characters ``# / @ % { } + = ~ | ^`` are always removed. The angle
    brackets ``<`` and ``>`` are removed as well when `vars` is None or
    ``vars.actionmode`` equals 0; for any other action mode they are kept.
    """
    keep_angle_brackets = vars is not None and vars.actionmode != 0
    if keep_angle_brackets:
        pattern = r"[#/@%{}+=~|\^]"
    else:
        pattern = r"[#/@%<>{}+=~|\^]"
    return re.sub(pattern, "", txt)
#==================================================================#
# If the next action follows a sentence closure, add a space
#==================================================================#
2021-05-14 08:24:05 +02:00
def addsentencespacing(txt, vars):
    """Prefix `txt` with a space when the preceding story text lacks one.

    The preceding text is the last entry of ``vars.actions`` or, when there
    are no actions, ``vars.prompt``. `txt` is returned untouched when it is
    empty, already starts with whitespace, or the last action is blank.
    """
    # Don't add sentence spacing if submission is empty or starts with whitespace
    if not txt or txt != txt.lstrip():
        return txt

    # Pick the text that comes right before this submission
    if len(vars.actions) > 0:
        previous = vars.actions[vars.actions.get_last_key()]
        if len(previous) == 0:
            # Last action is blank, this should never happen, but
            # since it did let's bail out.
            return txt
    else:
        previous = vars.prompt

    lastchar = previous[-1] if len(previous) else ""
    if lastchar == " ":
        return txt
    return " " + txt
2021-10-23 17:30:48 +02:00
def singlelineprocessing(txt, vars):
    """Apply single-line mode to `txt`.

    ``vars.regex_sl`` matches are removed from `txt` first. A newline is then
    appended unless the preceding story text (the last entry of
    ``vars.actions``, or ``vars.prompt`` when there are no actions) already
    ends with one. If the last action is blank the stripped text is returned
    without appending anything.
    """
    txt = vars.regex_sl.sub('', txt)

    if len(vars.actions) > 0:
        previous = vars.actions[vars.actions.get_last_key()]
        if len(previous) == 0:
            # Last action is blank, this should never happen, but
            # since it did let's bail out.
            return txt
    else:
        previous = vars.prompt

    lastchar = previous[-1] if len(previous) else ""
    if lastchar == "\n":
        return txt
    return txt + "\n"
2021-05-22 11:28:40 +02:00
#==================================================================#
# Cleans string for use in file name
#==================================================================#
def cleanfilename(filename):
    """Return `filename` without path separators and trailing whitespace.

    Every forward slash and backslash is dropped, then whitespace at the end
    of the result is stripped.
    """
    for separator in ("/", "\\"):
        filename = filename.replace(separator, "")
    return filename.rstrip()
2021-05-11 01:17:10 +02:00
2022-02-12 19:23:59 +01:00
#==================================================================#
# Newline substitution for fairseq models
#==================================================================#
def encodenewlines(txt):
    """Encode newlines in `txt` for the active newline mode.

    When the module-level ``vars.newlinemode`` is ``"s"`` every newline is
    replaced by ``</s>``; in any other mode `txt` is returned unchanged.
    """
    if vars.newlinemode != "s":
        return txt
    return txt.replace('\n', "</s>")
def decodenewlines(txt):
    """Decode ``</s>`` markers in `txt` for the active newline mode.

    With the module-level ``vars.newlinemode`` set to ``"s"`` each ``</s>``
    becomes a newline; with ``"ns"`` each ``</s>`` is removed. Any other mode
    returns `txt` unchanged.
    """
    mode = vars.newlinemode
    if mode == "s":
        return txt.replace("</s>", '\n')
    if mode == "ns":
        return txt.replace("</s>", '')
    return txt
2022-05-11 03:28:13 +02:00
#==================================================================#
2022-05-13 07:03:38 +02:00
# Returns number of layers given an HF model config
#==================================================================#
def num_layers(config):
    """Return the layer count stored on a model config, or None.

    The attributes ``num_layers``, ``n_layer`` and ``num_hidden_layers`` are
    checked in that order and the value of the first one present is
    returned. None is returned when the config has none of them.
    """
    for attribute in ("num_layers", "n_layer", "num_hidden_layers"):
        if hasattr(config, attribute):
            return getattr(config, attribute)
    return None
2022-05-13 07:03:38 +02:00
2022-05-11 03:28:13 +02:00
#==================================================================#
2022-05-13 05:51:40 +02:00
# Downloads huggingface checkpoints using aria2c if possible
2022-05-11 03:28:13 +02:00
#==================================================================#
2022-07-22 19:58:20 +02:00
from flask_socketio import emit
class Send_to_socketio(object):
    """Minimal file-like sink that forwards progress text to socket.io clients.

    Only ``write`` is implemented; instances are passed as the ``file``
    argument of a tqdm progress bar in this module.
    """

    def write(self, bar):
        """Print `bar` and broadcast it as a ``model_load_status`` message.

        Delivery is best effort: any ordinary error raised while printing or
        emitting is ignored so that progress reporting can never abort the
        operation being reported on.
        """
        # Short pause before every write (kept from the original behaviour).
        time.sleep(0.01)
        try:
            print(bar)
            emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", "&nbsp;")}, broadcast=True)
        except Exception:
            # Deliberately swallowed (best effort), but unlike a bare
            # `except:` this lets KeyboardInterrupt/SystemExit propagate.
            pass
2022-05-11 21:51:48 +02:00
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
    """Pre-download a model's PyTorch weights into the transformers cache with aria2c.

    The function returns without doing anything when aria2c is not installed,
    when `local_files_only` is set, when `pretrained_model_name_or_path` is a
    local path or a remote URL, when `proxies` is given, or when neither a
    single-file checkpoint nor a sharded checkpoint index can be found.
    Otherwise the weight files that are not cached yet (all of them when
    `force_download` is true) are downloaded by an aria2c subprocess into
    temporary ``kai-tempfile.*`` files inside the cache directory, renamed to
    their cache file names, and given a ``.json`` sidecar holding their URL
    and ETag. Progress is polled over aria2's JSON-RPC interface on
    ``vars.aria2_port`` and shown through a tqdm bar.

    Parameters mirror the keyword arguments of ``from_pretrained``; extra
    keyword arguments are accepted and ignored.

    Raises:
        EnvironmentError: `use_auth_token` is true (but not a string) and no
            stored huggingface token exists.
        OSError: aria2c exited with a non-zero code before the download was
            seen to complete.
    """
    import transformers
    import transformers.modeling_utils
    from huggingface_hub import HfFolder
    if shutil.which("aria2c") is None:  # Don't do anything if aria2 is not installed
        return
    if local_files_only:  # If local_files_only is true, we obviously don't need to download anything
        return
    if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index") or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path):
        return
    if proxies:
        print("WARNING: KoboldAI does not support using aria2 to download models from huggingface.co through a proxy. Disabling aria2 download mode.")
        return
    if use_auth_token:
        if isinstance(use_auth_token, str):
            token = use_auth_token
        else:
            token = HfFolder.get_token()
            if token is None:
                raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
    _cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
    sharded = False
    headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
    if use_auth_token:
        # Use the resolved token: `use_auth_token` itself may just be True,
        # which would otherwise produce the header "Bearer True".
        headers["authorization"] = f"Bearer {token}"

    def is_cached(url):
        # True when the transformers cache already holds the file for `url`.
        try:
            transformers.file_utils.get_from_cache(url, cache_dir=cache_dir, local_files_only=True)
        except (FileNotFoundError, transformers.file_utils.EntryNotFoundError):
            return False
        return True

    while True:  # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
        try:
            filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
        except AttributeError:
            return
        url = transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, filename, revision=revision, mirror=mirror)
        if is_cached(url) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
            break
        if sharded:
            return
        else:
            sharded = True
    if not sharded:  # If the model has a pytorch_model.bin file, that's the only file to download
        filenames = [transformers.modeling_utils.WEIGHTS_NAME]
    else:  # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
        map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
        with open(map_filename) as f:
            map_data = json.load(f)
        filenames = set(map_data["weight_map"].values())
    urls = [transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, n, revision=revision, mirror=mirror) for n in filenames]
    if not force_download:
        urls = [u for u in urls if not is_cached(u)]
        if not urls:
            return
    # Two HEAD requests per URL: the unredirected one exposes the ETag
    # headers, the redirected one is the source of Content-Length below.
    etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
    # Kept under its own name so the request-header dict is not clobbered.
    response_headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
    filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
    for n in filenames:
        # Remove leftovers of earlier aria2 runs (and, when forcing, the cached copies)
        stale_paths = [
            os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2"),
            os.path.join(_cache_dir, "kai-tempfile." + n),
        ]
        if force_download:
            stale_paths.append(os.path.join(_cache_dir, n + ".json"))
            stale_paths.append(os.path.join(_cache_dir, n))
        for stale_path in stale_paths:
            if os.path.exists(stale_path):
                os.remove(stale_path)
    total_length = sum(int(h["Content-Length"]) for h in response_headers)
    lengths = {}
    # aria2 input file: one URI per line followed by an indented option line
    aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
    s = requests.Session()
    s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
    bar = None
    done = False
    secret = os.urandom(17).hex()
    p = None
    config_path = None
    try:
        with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
            # Remember the path right away so the cleanup below always
            # removes this file and never some unrelated, stale path.
            config_path = f.name
            f.write(aria2_config)
            f.flush()
            aria2_command = [
                "aria2c", "-x", "10", "-s", "10", "-j", "10",
                "--enable-rpc=true", f"--rpc-secret={secret}",
                "--rpc-listen-port", str(vars.aria2_port),
                "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite",
                "--auto-file-renaming=false",
                "-d", _cache_dir, "-i", f.name,
                "-U", transformers.file_utils.http_user_agent(user_agent),
            ]
            if not force_download:
                aria2_command.append("-c")
            if use_auth_token:
                # No shell is involved, so the value must not be wrapped in
                # quotes: they would be sent as part of the header itself.
                aria2_command.append(f"--header=Authorization: Bearer {token}")
            p = subprocess.Popen(aria2_command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            while p.poll() is None:
                r = s.post(f"http://localhost:{vars.aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
                if not r:
                    # No active downloads are left: everything has finished
                    s.close()
                    if bar is not None:
                        bar.n = bar.total
                        bar.close()
                    p.terminate()
                    done = True
                    break
                if bar is None:
                    bar = tqdm(total=total_length, desc="[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio())
                visited = set()
                for x in r:
                    filename = x["files"][0]["path"]
                    lengths[filename] = (int(x["completedLength"]), int(x["totalLength"]))
                    visited.add(filename)
                for k, v in lengths.items():
                    if k not in visited:
                        # No longer reported as active, so count it as complete
                        lengths[k] = (v[1], v[1])
                bar.n = sum(v[0] for v in lengths.values())
                bar.update()
                time.sleep(0.1)
    except Exception:
        # `p` is still None when Popen itself failed
        if p is not None:
            p.terminate()
        raise
    finally:
        if config_path is not None:
            try:
                os.remove(config_path)
            except OSError:
                pass
    code = p.wait()
    if not done and code:
        raise OSError(f"aria2 exited with exit code {code}")
    for u, t, n in zip(urls, etags, filenames):
        os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
        with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
            json.dump({"url": u, "etag": t}, f)
2022-05-13 05:51:40 +02:00
#==================================================================#
# Given the path to a pytorch_model.bin.index.json, returns how many
# shards there are in the model
#==================================================================#
def get_num_shards(filename):
    """Count the distinct shard files listed in a checkpoint index file.

    `filename` is the path of a JSON file whose ``weight_map`` entry maps
    weight names to shard file names; the number of different shard file
    names is returned.
    """
    with open(filename) as index_file:
        weight_map = json.load(index_file)["weight_map"]
    return len({*weight_map.values()})
2022-05-14 05:32:16 +02:00
#==================================================================#
# Given the name/path of a sharded model and the path to a
# pytorch_model.bin.index.json, returns a list of weight names in the
# sharded model. Requires lazy loader to be enabled to work properly
#==================================================================#
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
    """Return the weight names found in every shard of a sharded checkpoint.

    The shard files are resolved with
    ``transformers.modeling_utils.get_checkpoint_shard_files`` from the index
    file `filename`; each shard is then opened with ``torch.load`` on the CPU
    and its keys are collected, in shard order, into a single list. Extra
    keyword arguments are accepted and ignored.
    """
    import transformers.modeling_utils
    import torch
    shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(
        pretrained_model_name_or_path,
        filename,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        resume_download=resume_download,
        local_files_only=local_files_only,
        use_auth_token=use_auth_token,
        user_agent=user_agent,
        revision=revision,
        mirror=mirror,
    )
    weight_names = []
    for shard_path in shard_paths:
        weight_names.extend(torch.load(shard_path, map_location="cpu").keys())
    return weight_names
2022-06-17 00:45:11 +02:00
2022-06-19 00:16:56 +02:00
#==================================================================#
# Given a PreTrainedModel, returns the list of module names that correspond
# to the model's hidden layers.
#==================================================================#
def get_layers_module_names(model: PreTrainedModel) -> List[str]:
    """Collect the dotted names of the modules that look like hidden layers.

    A child module counts as a layer when its own name is numeric and its
    class name ends in ``Block`` or ``Layer``. The module tree is walked
    depth-first; modules recognised as layers are not descended into.
    """
    layer_suffixes = ("Block", "Layer")
    names: List[str] = []

    def recurse(module, head=""):
        for child_name, child in module.named_children():
            qualified = head + child_name
            looks_like_layer = child_name.isnumeric() and child.__class__.__name__.endswith(layer_suffixes)
            if looks_like_layer:
                names.append(qualified)
            else:
                recurse(child, head=qualified + ".")

    recurse(model)
    return names
2022-06-19 00:16:56 +02:00
#==================================================================#
# Given a PreTrainedModel, returns the module name that corresponds
# to the model's input embeddings.
#==================================================================#
def get_input_embeddings_module_name(model: PreTrainedModel) -> str:
    """Return the dotted module name of the model's input embeddings.

    The module tree is searched depth-first for the object returned by
    ``model.get_input_embeddings()`` (compared by identity) and the name of
    the first match is returned. None is returned when the embeddings module
    is not found among the model's descendants.
    """
    embeddings = model.get_input_embeddings()

    def recurse(module, head=""):
        for child_name, child in module.named_children():
            name = head + child_name
            if child is embeddings:
                return name
            # Only stop when the subtree really contains the embeddings;
            # returning unconditionally here would skip every sibling after
            # the first child.
            found = recurse(child, head=name + ".")
            if found is not None:
                return found
        return None

    return recurse(model)
#==================================================================#
# Given a PreTrainedModel and a list of module names, returns a list
# of module names such that the union of the set of modules given as input
# and the set of modules returned as output contains all modules in the model.
#==================================================================#
def get_missing_module_names(model: PreTrainedModel, names: List[str]) -> List[str]:
    """Return names of the modules that `names` does not already cover.

    A module is covered when its dotted name equals one of `names` or lies
    underneath one of them. Uncovered modules without children are returned
    directly; uncovered modules with children are descended into, so that
    `names` together with the result accounts for every module in the model.
    """
    missing_names: List[str] = []

    def is_covered(name):
        # Compare whole dotted components: a plain string prefix test would
        # wrongly treat "h.10" as being covered by "h.1".
        return any(name == n or name.startswith(n + ".") for n in names)

    def recurse(module, head=""):
        for child_name, child in module.named_children():
            name = head + child_name
            if is_covered(name):
                continue
            if next(child.named_children(), None) is None:
                missing_names.append(name)
            else:
                recurse(child, head=name + ".")

    recurse(model)
    return missing_names