Merge pull request #232 from VE-FORBRYDERNE/mkultra
Universal mkultra-based soft prompt tuner
Commit 351fb3c80b
aiserver.py (13 changed lines)
@@ -452,6 +452,7 @@ def emit(*args, **kwargs):
         return _emit(*args, **kwargs)
     except AttributeError:
         return socketio.emit(*args, **kwargs)
+utils.emit = emit

 # marshmallow/apispec setup
 from apispec import APISpec
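The hunk above publishes the `emit` wrapper through the shared `utils` module (whose new module-level `emit = None` slot appears in the utils.py diff below), so code that only imports `utils` can emit Socket.IO events. A minimal self-contained sketch of the pattern, assuming the same two-module layout; `flask_socketio` is the real library, the rest is illustrative:

import utils  # shared module with a module-level `emit = None` slot

from flask_socketio import SocketIO, emit as _emit

socketio = SocketIO()

def emit(*args, **kwargs):
    try:
        # flask_socketio.emit only works inside a Socket.IO request
        # context; outside of one it raises AttributeError
        return _emit(*args, **kwargs)
    except AttributeError:
        # fall back to the server-level emit
        return socketio.emit(*args, **kwargs)

utils.emit = emit  # other modules can now call utils.emit(...)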
@@ -879,7 +880,7 @@ def device_config(config):
             print(f"{colors.RED}Please enter an integer between -1 and {n_layers}.{colors.END}")

     logger.init_ok("Final device configuration:", status="Info")
-    device_list(n_layers)
+    device_list(n_layers, primary=breakmodel.primary_device)

     # If all layers are on the same device, use the old GPU generation mode
     while(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] == 0):
@@ -1360,6 +1361,8 @@ def general_startup(override_args=None):
         args = parser.parse_args(shlex.split(os.environ["KOBOLDAI_ARGS"]))
     else:
         args = parser.parse_args()

+    utils.args = args
+
     set_logger_verbosity(args.verbosity)
     quiesce_logger(args.quiesce)
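This hunk shares the parsed argument namespace through `utils.args`, mirroring the new `args = None` slot added to utils.py below. For context, the surrounding code reads arguments either from the `KOBOLDAI_ARGS` environment variable (split with `shlex.split` so shell-style quoting survives) or from the process arguments. A condensed sketch, with a placeholder `--verbosity` flag standing in for the real option set:

import argparse
import os
import shlex

import utils  # module-level `args` slot read by helpers elsewhere

parser = argparse.ArgumentParser()
parser.add_argument("--verbosity", type=int, default=0)  # placeholder flag

if "KOBOLDAI_ARGS" in os.environ:
    # split the env var like a shell would, so quoted values stay intact
    args = parser.parse_args(shlex.split(os.environ["KOBOLDAI_ARGS"]))
else:
    args = parser.parse_args()

utils.args = args  # expose the parsed flags to the utils module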
@@ -1796,7 +1799,9 @@ def patch_transformers():
         if not args.no_aria2:
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
-    PreTrainedModel.from_pretrained = new_from_pretrained
+    if(not hasattr(PreTrainedModel, "_kai_patched")):
+        PreTrainedModel.from_pretrained = new_from_pretrained
+        PreTrainedModel._kai_patched = True
     if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
         old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
         def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
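This hunk (and its twin in `load_model` below) turns a bare monkey-patch of `PreTrainedModel.from_pretrained` into an idempotent one: a sentinel attribute, `_kai_patched`, records that the class is already wrapped, so running the patch path a second time no longer stacks the aria2 wrapper on top of itself. A minimal sketch of the sentinel-guard pattern, using a stand-in class rather than the real `transformers` one:

class PreTrainedModelStub:  # stand-in for transformers.PreTrainedModel
    @classmethod
    def from_pretrained(cls, name, *args, **kwargs):
        return f"loaded {name}"

def patch_from_pretrained():
    old_from_pretrained = PreTrainedModelStub.from_pretrained.__func__

    def new_from_pretrained(cls, name, *args, **kwargs):
        print("hook runs once per call")  # e.g. the aria2 download hook
        return old_from_pretrained(cls, name, *args, **kwargs)

    # Sentinel guard: only install the wrapper the first time through
    if not hasattr(PreTrainedModelStub, "_kai_patched"):
        PreTrainedModelStub.from_pretrained = classmethod(new_from_pretrained)
        PreTrainedModelStub._kai_patched = True

patch_from_pretrained()
patch_from_pretrained()  # second call is a no-op thanks to the sentinel
print(PreTrainedModelStub.from_pretrained("gpt2"))  # hook prints exactly once

Without the guard, the second `patch_from_pretrained()` call would capture the already-wrapped method as `old_from_pretrained`, and every later load would run the hook twice.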
@@ -2662,7 +2667,9 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
         if not args.no_aria2:
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
-    PreTrainedModel.from_pretrained = new_from_pretrained
+    if(not hasattr(PreTrainedModel, "_kai_patched")):
+        PreTrainedModel.from_pretrained = new_from_pretrained
+        PreTrainedModel._kai_patched = True
     if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
         old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
         def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
[One file's diff suppressed because it is too large.]
utils.py (10 changed lines)
@@ -27,6 +27,7 @@ except ImportError:
     HAS_ACCELERATE = False

 vars = None
+args = None
 num_shards: Optional[int] = None
 current_shard = 0
 from_pretrained_model_name = ""
@@ -40,6 +41,8 @@ named_buffers: Optional[List[tuple]] = None

 default_sampler_order = [6, 0, 1, 2, 3, 4, 5]

+emit = None
+
 #==================================================================#
 # Decorator to prevent a function's actions from being run until
 # at least x seconds have passed without the function being called
@@ -198,6 +201,7 @@ def _download_with_aria2(aria2_config: str, total_length: int, directory: str =
         pass

     import transformers
+    aria2_port = 6799 if vars is None else vars.aria2_port
     lengths = {}
     s = requests.Session()
     s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
@@ -208,9 +212,9 @@ def _download_with_aria2(aria2_config: str, total_length: int, directory: str =
     with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
         f.write(aria2_config)
         f.flush()
-        p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(vars.aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", directory, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {use_auth_token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+        p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", directory, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {use_auth_token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
         while p.poll() is None:
-            r = s.post(f"http://localhost:{vars.aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
+            r = s.post(f"http://localhost:{aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
             if not r:
                 s.close()
                 if bar is not None:
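The two replaced lines complete the change introduced at `@@ -198`: the port now comes from the local `aria2_port` (defaulting to 6799 when `vars` is unset) instead of reaching through `vars`, so the downloader also works before the server state exists. The surrounding code launches `aria2c` with its JSON-RPC interface enabled and polls `aria2.tellActive` until no transfers remain. A trimmed sketch of that polling loop; the function name, default secret, and port are placeholders, and `aria2.tellActive` is the real aria2 RPC method:

import subprocess
import requests

def wait_for_aria2(input_file: str, secret: str = "kai", port: int = 6799):
    # Start aria2c with its JSON-RPC server enabled; downloads listed in input_file
    p = subprocess.Popen(
        ["aria2c", "--enable-rpc=true", f"--rpc-secret={secret}",
         "--rpc-listen-port", str(port), "-i", input_file],
        stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
    )
    s = requests.Session()
    # Retry early requests until the RPC server is up, as the original does
    s.mount("http://", requests.adapters.HTTPAdapter(
        max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
    while p.poll() is None:  # while aria2c is still running
        # aria2.tellActive returns the list of in-progress downloads
        r = s.post(
            f"http://localhost:{port}/jsonrpc",
            json={"jsonrpc": "2.0", "id": "kai",
                  "method": "aria2.tellActive", "params": [f"token:{secret}"]},
        ).json()["result"]
        if not r:
            break  # nothing active: all downloads finished
    s.close()
    # (the real code goes on to shut aria2c down and verify file lengths)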
@@ -602,4 +606,4 @@ def get_missing_module_names(model: PreTrainedModel, names: List[str]) -> List[s
         else:
             recurse(c[1], head=name + ".")
     recurse(model)
-    return missing_names
+    return missing_names
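The removed and re-added `return missing_names` lines are textually identical, which usually indicates a newline-at-end-of-file fix. Only the tail of `get_missing_module_names` is visible in the hunk, but the shape is clear from the context lines: an inner `recurse` walks the submodule tree, extending dotted names with `head=name + "."`. A hedged reconstruction of that traversal; the matching logic inside the loop is an assumption, and only the recursion skeleton comes from the diff (plain `torch.nn.Module` is used here to keep the sketch self-contained):

from typing import List
from torch import nn

def get_missing_module_names(model: nn.Module, names: List[str]) -> List[str]:
    """Return the entries of `names` that match no submodule path in `model`."""
    missing_names = list(names)

    def recurse(module: nn.Module, head: str = ""):
        for c in module.named_children():  # c is a (name, child_module) pair
            name = head + c[0]
            if name in missing_names:
                missing_names.remove(name)  # found: stop treating it as missing
            else:
                recurse(c[1], head=name + ".")

    recurse(model)
    return missing_names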