Merge pull request #232 from VE-FORBRYDERNE/mkultra

Universal mkultra-based soft prompt tuner

Commit 351fb3c80b
@@ -452,6 +452,7 @@ def emit(*args, **kwargs):
         return _emit(*args, **kwargs)
     except AttributeError:
         return socketio.emit(*args, **kwargs)
+utils.emit = emit
 
 # marshmallow/apispec setup
 from apispec import APISpec
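This wrapper prefers Flask-SocketIO's request-bound _emit and falls back to socketio.emit outside a request context; the added utils.emit = emit line then exposes the wrapper through the utils module so code that should not import aiserver (presumably including the soft prompt tuner this PR adds) can still push updates to the UI. A consumer-side sketch, assuming some other module wants to report progress (the event name and payload shape are illustrative, not taken from this diff):

    import utils

    def report_progress(step: int, total: int) -> None:
        # utils.emit stays None until aiserver assigns its wrapper, so a
        # standalone run without the web UI simply skips the socket update.
        if utils.emit is not None:
            utils.emit("from_server", {"cmd": "progress", "data": f"{step}/{total}"}, broadcast=True)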
@@ -879,7 +880,7 @@ def device_config(config):
                 print(f"{colors.RED}Please enter an integer between -1 and {n_layers}.{colors.END}")
 
     logger.init_ok("Final device configuration:", status="Info")
-    device_list(n_layers)
+    device_list(n_layers, primary=breakmodel.primary_device)
 
     # If all layers are on the same device, use the old GPU generation mode
     while(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] == 0):
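Passing primary=breakmodel.primary_device lets the printed device summary mark which GPU breakmodel treats as the primary device. device_list itself is not shown in this diff; a hypothetical helper accepting that keyword could look like the sketch below (the signature, the gpu_blocks parameter, and the output format are all assumptions):

    def device_list(n_layers, primary=None, gpu_blocks=()):
        # Hypothetical sketch: one line per GPU, flagging the primary device.
        for device, blocks in enumerate(gpu_blocks):
            marker = " (primary)" if device == primary else ""
            print(f"GPU {device}{marker}: {blocks} of {n_layers} layers")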
@@ -1361,6 +1362,8 @@ def general_startup(override_args=None):
     else:
         args = parser.parse_args()
 
+    utils.args = args
+
     set_logger_verbosity(args.verbosity)
     quiesce_logger(args.quiesce)
     if args.customsettings:
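Publishing the parsed argparse namespace as utils.args mirrors the utils.emit assignment earlier in this diff: code living in utils can consult command-line flags (the no_aria2 check appears in the hunks below) without importing aiserver and creating a circular dependency. A small consumer-side sketch, under the assumption that the namespace may or may not have been injected yet (the function name is illustrative):

    import utils

    def aria2_allowed() -> bool:
        # Treat a missing namespace or a missing flag as "aria2 allowed",
        # matching the default behaviour when the flag is absent.
        return not getattr(utils.args, "no_aria2", False) if utils.args is not None else True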
@@ -1796,7 +1799,9 @@ def patch_transformers():
         if not args.no_aria2:
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
-    PreTrainedModel.from_pretrained = new_from_pretrained
+    if(not hasattr(PreTrainedModel, "_kai_patched")):
+        PreTrainedModel.from_pretrained = new_from_pretrained
+        PreTrainedModel._kai_patched = True
     if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
         old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
         def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
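The new hasattr guard makes this monkey patch idempotent: PreTrainedModel.from_pretrained is only swapped out once, so calling patch_transformers() again (or running the identical block inside load_model() in the next hunk) cannot wrap the already-patched method a second time. The guard pattern in isolation, using a toy class in place of PreTrainedModel (all names below are placeholders):

    class Model:
        @classmethod
        def from_pretrained(cls, name):
            return f"{cls.__name__} loaded {name}"

    def patch_once():
        old_from_pretrained = Model.from_pretrained.__func__

        @classmethod
        def new_from_pretrained(cls, name):
            # pre-download hooks (e.g. the aria2 hook) would run here
            return old_from_pretrained(cls, name)

        if not hasattr(Model, "_kai_patched"):
            Model.from_pretrained = new_from_pretrained
            Model._kai_patched = True

    patch_once()
    patch_once()  # no-op: the guard prevents double wrapping
    print(Model.from_pretrained("gpt2"))  # -> "Model loaded gpt2"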
@@ -2662,7 +2667,9 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
         if not args.no_aria2:
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
-    PreTrainedModel.from_pretrained = new_from_pretrained
+    if(not hasattr(PreTrainedModel, "_kai_patched")):
+        PreTrainedModel.from_pretrained = new_from_pretrained
+        PreTrainedModel._kai_patched = True
     if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
         old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
         def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
File diff suppressed because it is too large

utils.py: 8 changed lines
@@ -27,6 +27,7 @@ except ImportError:
     HAS_ACCELERATE = False
 
 vars = None
+args = None
 num_shards: Optional[int] = None
 current_shard = 0
 from_pretrained_model_name = ""
@@ -40,6 +41,8 @@ named_buffers: Optional[List[tuple]] = None
 
 default_sampler_order = [6, 0, 1, 2, 3, 4, 5]
 
+emit = None
+
 #==================================================================#
 # Decorator to prevent a function's actions from being run until
 # at least x seconds have passed without the function being called
@@ -198,6 +201,7 @@ def _download_with_aria2(aria2_config: str, total_length: int, directory: str =
         pass
 
     import transformers
+    aria2_port = 6799 if vars is None else vars.aria2_port
     lengths = {}
     s = requests.Session()
     s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
@@ -208,9 +212,9 @@ def _download_with_aria2(aria2_config: str, total_length: int, directory: str =
     with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
         f.write(aria2_config)
         f.flush()
-        p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(vars.aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", directory, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {use_auth_token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+        p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", directory, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {use_auth_token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
         while p.poll() is None:
-            r = s.post(f"http://localhost:{vars.aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
+            r = s.post(f"http://localhost:{aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
             if not r:
                 s.close()
                 if bar is not None:
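These two utils.py hunks work together with the aria2_port fallback introduced in the previous hunk: both direct reads of vars.aria2_port become the local aria2_port, which defaults to 6799 when no vars object has been injected, so _download_with_aria2 keeps working when utils is used outside aiserver. The fallback in isolation (the constant 6799 comes from the diff; the helper name is illustrative):

    DEFAULT_ARIA2_PORT = 6799

    def resolve_aria2_port(vars_obj=None) -> int:
        # Mirrors "6799 if vars is None else vars.aria2_port" from the patch.
        return DEFAULT_ARIA2_PORT if vars_obj is None else vars_obj.aria2_port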