diff --git a/modeling/inference_models/generic_hf_torch/class.py b/modeling/inference_models/generic_hf_torch/class.py
index f879cb37..c65b3ab6 100644
--- a/modeling/inference_models/generic_hf_torch/class.py
+++ b/modeling/inference_models/generic_hf_torch/class.py
@@ -59,8 +59,6 @@ class model_backend(HFTorchInferenceModel):
             # Also, lazy loader doesn't support GPT-2 models
             self.lazy_load = False
 
-        # If we're using torch_lazy_loader, we need to get breakmodel config
-        # early so that it knows where to load the individual model tensors
         logger.debug(
             "lazy_load: {} hascuda: {} breakmodel: {} nobreakmode: {}".format(
                 self.lazy_load,
@@ -70,6 +68,16 @@ class model_backend(HFTorchInferenceModel):
             )
         )
 
+        # If we're using torch_lazy_loader, we need to get breakmodel config
+        # early so that it knows where to load the individual model tensors
+        if (
+            self.lazy_load
+            and utils.koboldai_vars.hascuda
+            and utils.koboldai_vars.breakmodel
+            and not utils.koboldai_vars.nobreakmodel
+        ):
+            self.breakmodel_device_config(self.model_config)
+
         if self.lazy_load:
             # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
             tf_kwargs.pop("low_cpu_mem_usage", None)
diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index 1fb78717..c1bcdf0b 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -1,17 +1,13 @@
 from __future__ import annotations
+from dataclasses import dataclass
 
-import gc
 import os
 import time
 import bisect
-import zipfile
-import functools
 import itertools
 import traceback
 import contextlib
-from accelerate.big_modeling import load_checkpoint_and_dispatch
-from accelerate.utils.modeling import infer_auto_device_map, load_checkpoint_in_model
-from tqdm.auto import tqdm
+from torch import nn
 from typing import Dict, List, Optional, Union
 
 import torch
@@ -41,17 +37,36 @@ from modeling.inference_model import (
     use_core_manipulations,
 )
 
-try:
-    import accelerate.utils
-except ModuleNotFoundError as e:
-    if not utils.koboldai_vars.use_colab_tpu:
-        raise e
-
 # When set to true, messages will appear in the console if samplers are not
 # changing the scores. Keep in mind some samplers don't always change the
 # scores for each token.
 LOG_SAMPLER_NO_EFFECT = False
 
+class BreakmodelConfig:
+    def __init__(self) -> None:
+        self.disk_blocks = 0
+        self.gpu_blocks = []
+        self.primary_device = 0 if torch.cuda.device_count() > 0 else "cpu"
+
+    def get_device_map(self, model: nn.Module) -> dict:
+        ram_blocks = len(utils.layers_module_names) - sum(self.gpu_blocks)
+        cumulative_gpu_blocks = tuple(itertools.accumulate(self.gpu_blocks))
+        device_map = {}
+
+        for name in utils.layers_module_names:
+            layer = int(name.rsplit(".", 1)[1])
+            device = (
+                ("disk" if layer < self.disk_blocks else "cpu")
+                if layer < ram_blocks
+                else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
+            )
+            device_map[name] = device
+
+        for name in utils.get_missing_module_names(model, list(device_map.keys())):
+            device_map[name] = self.primary_device
+
+        return device_map
+
 
 class HFTorchInferenceModel(HFInferenceModel):
     def __init__(self) -> None:
@@ -80,6 +95,16 @@ class HFTorchInferenceModel(HFInferenceModel):
             post_token_probs=True,
         )
         self._old_stopping_criteria = None
+        self.breakmodel_config = BreakmodelConfig()
+
+    def set_input_parameters(self, parameters):
+        ret = super().set_input_parameters(parameters)
+
+        # Hook onto input param setting for setting breakmodel stuff
+        self.breakmodel_config.gpu_blocks = self.layers
+        self.breakmodel_config.disk_blocks = self.disk_layers
+
+        return ret
 
     def _apply_warpers(
         self, scores: torch.Tensor, input_ids: torch.Tensor
@@ -278,17 +303,20 @@ class HFTorchInferenceModel(HFInferenceModel):
 
         # Try to determine model type from either AutoModel or falling back to legacy
         try:
-            print("[HUGE SKELETON] LOADING FROM PRETRAINED")
+            with lazy_loader.use_lazy_load(dematerialized_modules=True):
+                metamodel = AutoModelForCausalLM.from_config(self.model_config)
+                device_map = self.breakmodel_config.get_device_map(metamodel)
+
             with lazy_loader.use_lazy_load(
                 enable=True,
                 # DO NOT DEMATERIALIZE MODULES / INIT WEIGHTS EMPTY!!! IT WILL EXPLODE!!!!!!!
-                # dematerialized_modules=True,
                 dematerialized_modules=False,
             ):
+                print(device_map)
                 model = AutoModelForCausalLM.from_pretrained(
                     location,
-                    device_map="auto",
-                    # max_memory={0: "10GiB", 1: "7GiB", "cpu": "20GiB"},
+                    # device_map="auto",
+                    device_map=device_map,
                     offload_folder="accelerate-disk-cache",
                     torch_dtype=torch.float16,
                     **tf_kwargs,
@@ -389,18 +417,19 @@ class HFTorchInferenceModel(HFInferenceModel):
             yield False
 
     def breakmodel_device_list(self, n_layers, primary=None, selected=None):
-        return
-        # TODO: Find a better place for this or rework this
-
         device_count = torch.cuda.device_count()
         if device_count < 2:
             primary = None
+
         logger.debug("n_layers: {}".format(n_layers))
-        logger.debug("gpu blocks: {}".format(breakmodel.gpu_blocks))
-        gpu_blocks = breakmodel.gpu_blocks + (
-            device_count - len(breakmodel.gpu_blocks)
+        logger.debug("gpu blocks: {}".format(self.breakmodel_config.gpu_blocks))
+
+        gpu_blocks = self.breakmodel_config.gpu_blocks + (
+            device_count - len(self.breakmodel_config.gpu_blocks)
         ) * [0]
+
         print(f"{Colors.YELLOW}       DEVICE ID  |  LAYERS  |  DEVICE NAME{Colors.END}")
+
         for i in range(device_count):
             name = torch.cuda.get_device_name(i)
             if len(name) > 47:
@@ -410,75 +439,70 @@ class HFTorchInferenceModel(HFInferenceModel):
             print(
                 f"{row_color}{Colors.YELLOW + '->' + row_color if i == selected else '  '} {'(primary)' if i == primary else ' '*9} {i:3}  {sep_color}|{row_color}  {gpu_blocks[i]:3}  {sep_color}|{row_color}  {name}{Colors.END}"
             )
+
         row_color = Colors.END
         sep_color = Colors.YELLOW
         print(
-            f"{row_color}{Colors.YELLOW + '->' + row_color if -1 == selected else '  '} {' '*9}  N/A  {sep_color}|{row_color}  {breakmodel.disk_blocks:3}  {sep_color}|{row_color}  (Disk cache){Colors.END}"
+            f"{row_color}{Colors.YELLOW + '->' + row_color if -1 == selected else '  '} {' '*9}  N/A  {sep_color}|{row_color}  {self.breakmodel_config.disk_blocks:3}  {sep_color}|{row_color}  (Disk cache){Colors.END}"
        )
        print(
            f"{row_color}   {' '*9}  N/A  {sep_color}|{row_color}  {n_layers:3}  {sep_color}|{row_color}  (CPU){Colors.END}"
        )

    def breakmodel_device_config(self, config):
-        # TODO: Find a better place for this or rework this
-        return
-
-        global breakmodel, generator
-        import breakmodel
-
        n_layers = utils.num_layers(config)
-        logger.debug("gpu blocks before modification: {}".format(breakmodel.gpu_blocks))
+        logger.debug("gpu blocks before modification: {}".format(self.breakmodel_config.gpu_blocks))
        if utils.args.cpu:
-            breakmodel.gpu_blocks = [0] * n_layers
+            self.breakmodel_config.gpu_blocks = [0] * n_layers
            return
-        elif breakmodel.gpu_blocks == []:
+        elif self.breakmodel_config.gpu_blocks == []:
            logger.info("Breakmodel not specified, assuming GPU 0")
-            breakmodel.gpu_blocks = [n_layers]
+            self.breakmodel_config.gpu_blocks = [n_layers]
            n_layers = 0
        else:
            s = n_layers
-            for i in range(len(breakmodel.gpu_blocks)):
-                if breakmodel.gpu_blocks[i] <= -1:
-                    breakmodel.gpu_blocks[i] = s
+            for i in range(len(self.breakmodel_config.gpu_blocks)):
+                if self.breakmodel_config.gpu_blocks[i] <= -1:
+                    self.breakmodel_config.gpu_blocks[i] = s
                    break
                else:
-                    s -= breakmodel.gpu_blocks[i]
-            assert sum(breakmodel.gpu_blocks) <= n_layers
-            n_layers -= sum(breakmodel.gpu_blocks)
-            if breakmodel.disk_blocks is not None:
-                assert breakmodel.disk_blocks <= n_layers
-                n_layers -= breakmodel.disk_blocks
+                    s -= self.breakmodel_config.gpu_blocks[i]
+            assert sum(self.breakmodel_config.gpu_blocks) <= n_layers
+            n_layers -= sum(self.breakmodel_config.gpu_blocks)
+            if self.breakmodel_config.disk_blocks is not None:
+                assert self.breakmodel_config.disk_blocks <= n_layers
+                n_layers -= self.breakmodel_config.disk_blocks

        logger.init_ok("Final device configuration:", status="Info")
-        self.breakmodel_device_list(n_layers, primary=breakmodel.primary_device)
+        self.breakmodel_device_list(n_layers, primary=self.breakmodel_config.primary_device)

        with open(
            "settings/{}.breakmodel".format(self.model_name.replace("/", "_")), "w"
        ) as file:
            file.write(
                "{}\n{}".format(
-                    ",".join(map(str, breakmodel.gpu_blocks)), breakmodel.disk_blocks
+                    ",".join(map(str, self.breakmodel_config.gpu_blocks)), self.breakmodel_config.disk_blocks
                )
            )

        # If all layers are on the same device, use the old GPU generation mode
-        while len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] == 0:
-            breakmodel.gpu_blocks.pop()
+        while len(self.breakmodel_config.gpu_blocks) and self.breakmodel_config.gpu_blocks[-1] == 0:
+            self.breakmodel_config.gpu_blocks.pop()
        self.breakmodel = True
-        if len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (
+        if len(self.breakmodel_config.gpu_blocks) and self.breakmodel_config.gpu_blocks[-1] in (
            -1,
            utils.num_layers(config),
        ):
            logger.debug("All layers on same GPU. Breakmodel disabled")
            self.breakmodel = False
            self.usegpu = True
-            utils.koboldai_vars.gpu_device = len(breakmodel.gpu_blocks) - 1
+            utils.koboldai_vars.gpu_device = len(self.breakmodel_config.gpu_blocks) - 1
            return

-        if not breakmodel.gpu_blocks:
+        if not self.breakmodel_config.gpu_blocks:
            logger.warning("Nothing assigned to a GPU, reverting to CPU only mode")
            self.breakmodel = False
            self.usegpu = False
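
The core of this patch is BreakmodelConfig.get_device_map, which replaces the
earlier device_map="auto" with KoboldAI's own layer allocation. A minimal
standalone sketch of the same math follows; the layer names and the function
name here are hypothetical stand-ins for utils.layers_module_names and the
surrounding class, not part of the patch:

    # Sketch of the get_device_map allocation, assuming layer module names
    # end in their index (e.g. transformer.h.<n>). The first disk_blocks
    # layers go to disk, the rest of the ram_blocks region to CPU RAM, and
    # the remainder is spread over GPUs according to gpu_blocks.
    import bisect
    import itertools

    def get_device_map(layer_names, gpu_blocks, disk_blocks):
        ram_blocks = len(layer_names) - sum(gpu_blocks)
        cumulative = tuple(itertools.accumulate(gpu_blocks))
        device_map = {}
        for name in layer_names:
            layer = int(name.rsplit(".", 1)[1])
            if layer < ram_blocks:
                device_map[name] = "disk" if layer < disk_blocks else "cpu"
            else:
                # First GPU whose cumulative block count exceeds this
                # layer's offset into the GPU-resident region.
                device_map[name] = bisect.bisect_right(cumulative, layer - ram_blocks)
        return device_map

    # Six layers: one on disk, one in CPU RAM, two on GPU 0, two on GPU 1.
    layers = [f"transformer.h.{i}" for i in range(6)]
    print(get_device_map(layers, gpu_blocks=[2, 2], disk_blocks=1))
    # {'transformer.h.0': 'disk', 'transformer.h.1': 'cpu',
    #  'transformer.h.2': 0, 'transformer.h.3': 0,
    #  'transformer.h.4': 1, 'transformer.h.5': 1}

In the patch itself, modules that are not per-layer blocks (embeddings, final
layer norm, and so on) are then filled in via utils.get_missing_module_names
and pinned to the primary device.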
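breakmodel_device_config keeps the old -1 sentinel semantics: the first -1
entry in gpu_blocks absorbs every layer not claimed by an earlier entry. A
small sketch of just that normalization step (normalize_gpu_blocks is an
illustrative name, not from the patch):

    def normalize_gpu_blocks(gpu_blocks, n_layers):
        remaining = n_layers
        for i, blocks in enumerate(gpu_blocks):
            if blocks <= -1:
                # The sentinel takes all layers left over at this point.
                gpu_blocks[i] = remaining
                break
            remaining -= blocks
        assert sum(gpu_blocks) <= n_layers
        return gpu_blocks

    print(normalize_gpu_blocks([8, -1], 32))  # [8, 24]: GPU 1 takes the rest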
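For context, the load path in _load follows a two-phase pattern: build a
weightless "meta" model just to enumerate modules, derive a device map from
it, then let from_pretrained place the real tensors. A rough sketch of the
same idea using plain Hugging Face/accelerate calls instead of KoboldAI's
lazy_loader (which this sketch does not reproduce); "gpt2" and the trivial
all-CPU map are placeholders for illustration:

    import torch
    from accelerate import init_empty_weights
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("gpt2")
    with init_empty_weights():
        # No weight tensors are allocated; the model lives on the "meta"
        # device so its module layout can be inspected cheaply.
        metamodel = AutoModelForCausalLM.from_config(config)

    # The patch derives this dict from BreakmodelConfig.get_device_map(metamodel);
    # {"": "cpu"} is the degenerate "whole model on one device" form.
    device_map = {"": "cpu"}
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        device_map=device_map,
        offload_folder="accelerate-disk-cache",  # spill target for "disk" entries
        torch_dtype=torch.float16,
    )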