Merge pull request #232 from VE-FORBRYDERNE/mkultra

Universal mkultra-based soft prompt tuner
This commit is contained in:
henk717 2022-10-22 14:13:42 +02:00 committed by GitHub
commit 351fb3c80b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 1099 additions and 6 deletions

View File

@ -452,6 +452,7 @@ def emit(*args, **kwargs):
return _emit(*args, **kwargs)
except AttributeError:
return socketio.emit(*args, **kwargs)
utils.emit = emit
# marshmallow/apispec setup
from apispec import APISpec
@ -879,7 +880,7 @@ def device_config(config):
print(f"{colors.RED}Please enter an integer between -1 and {n_layers}.{colors.END}")
logger.init_ok("Final device configuration:", status="Info")
device_list(n_layers)
device_list(n_layers, primary=breakmodel.primary_device)
# If all layers are on the same device, use the old GPU generation mode
while(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] == 0):
@ -1360,6 +1361,8 @@ def general_startup(override_args=None):
args = parser.parse_args(shlex.split(os.environ["KOBOLDAI_ARGS"]))
else:
args = parser.parse_args()
utils.args = args
set_logger_verbosity(args.verbosity)
quiesce_logger(args.quiesce)
@ -1796,7 +1799,9 @@ def patch_transformers():
if not args.no_aria2:
utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
PreTrainedModel.from_pretrained = new_from_pretrained
if(not hasattr(PreTrainedModel, "_kai_patched")):
PreTrainedModel.from_pretrained = new_from_pretrained
PreTrainedModel._kai_patched = True
if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
@ -2662,7 +2667,9 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
if not args.no_aria2:
utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
PreTrainedModel.from_pretrained = new_from_pretrained
if(not hasattr(PreTrainedModel, "_kai_patched")):
PreTrainedModel.from_pretrained = new_from_pretrained
PreTrainedModel._kai_patched = True
if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):

prompt_tuner.py — normal file, 1082 lines added

File diff suppressed because it is too large (Load Diff)

View File

@ -27,6 +27,7 @@ except ImportError:
HAS_ACCELERATE = False
vars = None
args = None
num_shards: Optional[int] = None
current_shard = 0
from_pretrained_model_name = ""
@ -40,6 +41,8 @@ named_buffers: Optional[List[tuple]] = None
default_sampler_order = [6, 0, 1, 2, 3, 4, 5]
emit = None
#==================================================================#
# Decorator to prevent a function's actions from being run until
# at least x seconds have passed without the function being called
@ -198,6 +201,7 @@ def _download_with_aria2(aria2_config: str, total_length: int, directory: str =
pass
import transformers
aria2_port = 6799 if vars is None else vars.aria2_port
lengths = {}
s = requests.Session()
s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
@ -208,9 +212,9 @@ def _download_with_aria2(aria2_config: str, total_length: int, directory: str =
with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
f.write(aria2_config)
f.flush()
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(vars.aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", directory, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {use_auth_token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", directory, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {use_auth_token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
while p.poll() is None:
r = s.post(f"http://localhost:{vars.aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
r = s.post(f"http://localhost:{aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
if not r:
s.close()
if bar is not None:
@ -602,4 +606,4 @@ def get_missing_module_names(model: PreTrainedModel, names: List[str]) -> List[s
else:
recurse(c[1], head=name + ".")
recurse(model)
return missing_names
return missing_names