From defbb53b689fa448b4848eb257046934145c1d9b Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Fri, 13 May 2022 01:03:38 -0400
Subject: [PATCH] OPT breakmodel

---
 aiserver.py   |  49 ++++++++++----
 breakmodel.py | 182 ++++++++++++++++++++++++++++++++++++++++++++++++--
 utils.py      |   6 ++
 3 files changed, 220 insertions(+), 17 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index 6c9401e2..f988e6ba 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -274,7 +274,7 @@ class vars:
     recentrngm  = None   # If a new random game was recently generated without Submitting after, this is the memory used (as a string), otherwise this is None
     useprompt   = False  # Whether to send the full prompt with every submit action
     breakmodel  = False  # For GPU users, whether to use both system RAM and VRAM to conserve VRAM while offering speedup compared to CPU-only
-    bmsupported = False  # Whether the breakmodel option is supported (GPT-Neo/GPT-J/XGLM only, currently)
+    bmsupported = False  # Whether the breakmodel option is supported (GPT-Neo/GPT-J/XGLM/OPT only, currently)
     nobreakmodel = False  # Something specifically requested Breakmodel to be disabled (For example a models config)
     smandelete  = False  # Whether stories can be deleted from inside the browser
     smanrename  = False  # Whether stories can be renamed from inside the browser
@@ -391,7 +391,7 @@ def device_list(n_layers, primary=None, selected=None):
 def device_config(config):
     global breakmodel, generator
     import breakmodel
-    n_layers = config.num_layers if hasattr(config, "num_layers") else config.n_layer
+    n_layers = utils.num_layers(config)
     if(args.breakmodel_gpulayers is not None):
         try:
             breakmodel.gpu_blocks = list(map(int, args.breakmodel_gpulayers.split(',')))
@@ -464,7 +464,7 @@ def device_config(config):
     # If all layers are on the same device, use the old GPU generation mode
     while(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] == 0):
         breakmodel.gpu_blocks.pop()
-    if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, config.num_layers if hasattr(config, "num_layers") else config.n_layer)):
+    if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, utils.num_layers(config))):
         vars.breakmodel = False
         vars.usegpu = True
         vars.gpu_device = len(breakmodel.gpu_blocks)-1
@@ -496,22 +496,33 @@ def move_model_to_devices(model):
         model.lm_head.to(breakmodel.primary_device)
         if(hasattr(model.transformer, 'wpe')):
             model.transformer.wpe.to(breakmodel.primary_device)
-    else:
+    elif(not hasattr(model.model, "decoder")):
         model.model.embed_tokens.to(breakmodel.primary_device)
         model.model.layer_norm.to(breakmodel.primary_device)
         model.lm_head.to(breakmodel.primary_device)
         model.model.embed_positions.to(breakmodel.primary_device)
+    else:
+        model.model.decoder.embed_tokens.to(breakmodel.primary_device)
+        if(model.model.decoder.project_in is not None):
+            model.model.decoder.project_in.to(breakmodel.primary_device)
+        if(model.model.decoder.project_out is not None):
+            model.model.decoder.project_out.to(breakmodel.primary_device)
+        model.model.decoder.embed_positions.to(breakmodel.primary_device)
     gc.collect()
     GPTNeoModel.forward = breakmodel.new_forward_neo
     if("GPTJModel" in globals()):
         GPTJModel.forward = breakmodel.new_forward_neo # type: ignore
     if("XGLMModel" in globals()):
         XGLMModel.forward = breakmodel.new_forward_xglm # type: ignore
+    if("OPTDecoder" in globals()):
+        OPTDecoder.forward = breakmodel.new_forward_opt # type: ignore
     generator = model.generate
     if(hasattr(model, "transformer")):
         breakmodel.move_hidden_layers(model.transformer)
-    else:
+    elif(not hasattr(model.model, "decoder")):
         breakmodel.move_hidden_layers(model.model, model.model.layers)
+    else:
+        breakmodel.move_hidden_layers(model.model.decoder, model.model.decoder.layers)
 
 #==================================================================#
 #  Allow the models to override some settings
@@ -911,7 +922,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
         loadsettings()
     print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
     vars.hascuda = torch.cuda.is_available()
-    vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm") and not vars.nobreakmodel
+    vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm", "opt") and not vars.nobreakmodel
     if(args.breakmodel is not None and args.breakmodel):
         print("WARNING: --breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --breakmodel_gpulayers is used (see --help for details).", file=sys.stderr)
     if(args.breakmodel_layers is not None):
@@ -1123,6 +1134,10 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
             globals()[m] = getattr(__import__("transformers"), m)
         except:
             pass
+    try:
+        from transformers.models.opt.modeling_opt import OPTDecoder
+    except:
+        pass
     import transformers.generation_utils
     from transformers import __version__ as transformers_version
 
@@ -1253,8 +1268,10 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
             input_ids.clamp_(max=self.config.vocab_size-1)
             if(hasattr(self, "transformer")):
                 inputs_embeds = self.transformer.wte(input_ids)
-            else:
+            elif(not hasattr(model.model, "decoder")):
                 inputs_embeds = self.model.embed_tokens(input_ids)
+            else:
+                inputs_embeds = self.model.decoder.embed_tokens(input_ids)
             if(vars.sp is not None):
                 vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
                 inputs_embeds = torch.where(
@@ -1262,14 +1279,14 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                     vars.sp[shifted_input_ids.clamp(min=0)],
                     inputs_embeds,
                 )
-            if(not hasattr(self, "transformer")):
+            if(hasattr(self.model, "embed_scale")):
                 inputs_embeds *= self.model.embed_scale
             kwargs['inputs_embeds'] = inputs_embeds
             return old_forward(self, *args, **kwargs)
         cls.forward = new_causallm_forward
     for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
         patch_causallm(cls)
-    for c in ("GPTJForCausalLM", "XGLMForCausalLM"):
+    for c in ("GPTJForCausalLM", "XGLMForCausalLM", "OPTForCausalLM"):
         try:
             patch_causallm(getattr(__import__("transformers"), c))
         except:
@@ -1430,12 +1447,18 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
 
     def get_hidden_size_from_model(model):
         try:
-            return int(model.transformer.hidden_size)
+            return int(model.model.decoder.project_in.in_features)
         except:
             try:
-                return int(model.transformer.embed_dim)
+                return int(model.model.decoder.embed_tokens.out_features)
             except:
-                return int(model.lm_head.in_features)
+                try:
+                    return int(model.transformer.hidden_size)
+                except:
+                    try:
+                        return int(model.transformer.embed_dim)
+                    except:
+                        return int(model.lm_head.in_features)
 
     def maybe_low_cpu_mem_usage() -> Dict[str, Any]:
         if(packaging.version.parse(transformers_version) < packaging.version.parse("4.11.0")):
@@ -1490,7 +1513,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                 import shutil
                 shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
             print("\n", flush=True)
-            with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(model_config.num_layers if hasattr(model_config, "num_layers") else model_config.n_layer) if vars.lazy_load else None, dematerialized_modules=True):
+            with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(utils.num_layers(model_config)) if vars.lazy_load else None, dematerialized_modules=True):
                 if(vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
                     lowmem = {}
                 if(os.path.isdir(vars.custmodpth)):
diff --git a/breakmodel.py b/breakmodel.py
index 9818e6d9..262d39bd 100644
--- a/breakmodel.py
+++ b/breakmodel.py
@@ -633,11 +633,11 @@ def new_forward_xglm(
             layer_outputs = decoder_layer(
                 hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
                 attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask,
-                encoder_hidden_states=encoder_hidden_states.to(device) if encoder_hidden_states is not None else None,
-                encoder_attention_mask=encoder_attention_mask.to(device) if encoder_attention_mask is not None else None,
-                layer_head_mask=((head_mask[idx].to(device) if head_mask[idx] is not None else None) if head_mask is not None else None),
+                encoder_hidden_states=encoder_hidden_states.to(device) if breakmodel and encoder_hidden_states is not None else encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask.to(device) if breakmodel and encoder_attention_mask is not None else encoder_attention_mask,
+                layer_head_mask=((head_mask[idx].to(device) if breakmodel and head_mask[idx] is not None else head_mask[idx]) if head_mask is not None else None),
                 cross_attn_layer_head_mask=(
-                    (cross_attn_head_mask[idx].to(device) if cross_attn_head_mask[idx] is not None else None) if cross_attn_head_mask is not None else None
+                    (cross_attn_head_mask[idx].to(device) if breakmodel and cross_attn_head_mask[idx] is not None else cross_attn_head_mask[idx]) if cross_attn_head_mask is not None else None
                 ),
                 past_key_value=tuple(v.to(device) for v in past_key_value if v is not None) if breakmodel and past_key_value is not None and i >= ram_blocks and len(past_key_value) and past_key_value[0].device.index != device else past_key_value,
                 output_attentions=output_attentions,
@@ -686,3 +686,177 @@ def new_forward_xglm(
         attentions=all_self_attns,
         cross_attentions=all_cross_attentions,
     )
+
+
+def new_forward_opt(
+    self,
+    input_ids=None,
+    attention_mask=None,
+    head_mask=None,
+    past_key_values=None,
+    inputs_embeds=None,
+    use_cache=None,
+    output_attentions=None,
+    output_hidden_states=None,
+    return_dict=None,
+):
+    assert len(gpu_blocks) <= torch.cuda.device_count()
+    assert sum(gpu_blocks) <= len(self.layers)
+    ram_blocks = len(self.layers) - sum(gpu_blocks)
+    cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
+
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # retrieve input_ids and inputs_embeds
+    if input_ids is not None and inputs_embeds is not None:
+        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+    elif input_ids is not None:
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+    elif inputs_embeds is not None:
+        input_shape = inputs_embeds.size()[:-1]
+    else:
+        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+    past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+    if inputs_embeds is None:
+        if breakmodel:
+            input_ids = input_ids.to(primary_device)
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    # embed positions
+    if breakmodel:
+        inputs_embeds = inputs_embeds.to(primary_device)
+    if attention_mask is None:
+        attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
+
+    positions = self.embed_positions(attention_mask)[:, past_key_values_length:, :]
+    if breakmodel:
+        positions = positions.to(primary_device)
+
+    attention_mask = self._prepare_decoder_attention_mask(
+        attention_mask, input_shape, inputs_embeds, past_key_values_length
+    )
+
+    if self.project_in is not None:
+        inputs_embeds = self.project_in(inputs_embeds)
+
+    hidden_states = inputs_embeds + positions
+
+    hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = () if use_cache else None
+
+    if breakmodel and ram_blocks:
+        copystream = torch.cuda.Stream(device=primary_device, priority=-1)
+
+    # check if head_mask has a correct number of layers specified if desired
+    for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
+        if attn_mask is not None:
+            if attn_mask.size()[0] != (len(self.layers)):
+                raise ValueError(
+                    f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                    f" {head_mask.size()[0]}."
+                )
+
+    for idx, decoder_layer in enumerate(self.layers):
+        i = idx
+        if breakmodel:
+            if i in range(ram_blocks):
+                index1 = (i+1)%ram_blocks
+                for param1,param2 in zip(self.layers[index1].parameters(),self.layers[(i-1)%ram_blocks].parameters()):
+                    param1.data = param2.data
+                for param1,param2 in zip(self.layers[index1].parameters(),self.extrastorage[index1].parameters()):
+                    with torch.cuda.stream(copystream):
+                        torch.cuda.comm.broadcast(param2.data,out = [param1.data])
+
+        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+        dropout_probability = random.uniform(0, 1)
+        if self.training and (dropout_probability < self.layerdrop):
+            continue
+
+        past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+        if self.gradient_checkpointing and self.training:
+
+            if use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    # None for past_key_value
+                    return module(*inputs, output_attentions, None)
+
+                return custom_forward
+
+            layer_outputs = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(decoder_layer),
+                hidden_states,
+                attention_mask,
+                head_mask[idx] if head_mask is not None else None,
+                None,
+            )
+        else:
+            if breakmodel:
+                device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks)
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask,
+                layer_head_mask=((head_mask[idx].to(device) if breakmodel and head_mask[idx] is not None else head_mask[idx]) if head_mask is not None else None),
+                past_key_value=tuple(v.to(device) for v in past_key_value if v is not None) if breakmodel and past_key_value is not None and i >= ram_blocks and len(past_key_value) and past_key_value[0].device.index != device else past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+        if breakmodel:
+            if i in range(ram_blocks):
+                torch.cuda.synchronize()
+                torch.cuda.empty_cache()
+
+    if breakmodel:
+        if ram_blocks:
+            del copystream
+        torch.cuda.empty_cache()
+        hidden_states = hidden_states.to(primary_device)
+    if self.project_out is not None:
+        hidden_states = self.project_out(hidden_states)
+    if breakmodel:
+        hidden_states = hidden_states.to(primary_device)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = next_decoder_cache if use_cache else None
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
diff --git a/utils.py b/utils.py
index 0fdfa125..ea3cab94 100644
--- a/utils.py
+++ b/utils.py
@@ -135,6 +135,12 @@ def decodenewlines(txt):
         return txt.replace("</s>", '\n')
     return txt
 
+#==================================================================#
+#  Returns number of layers given an HF model config
+#==================================================================#
+def num_layers(config):
+    return config.num_layers if hasattr(config, "num_layers") else config.n_layer if hasattr(config, "n_layer") else config.num_hidden_layers
+
 #==================================================================#
 #  Downloads huggingface checkpoints using aria2c if possible
 #==================================================================#