diff --git a/aiserver.py b/aiserver.py index 9375034b..9a48cd9b 100644 --- a/aiserver.py +++ b/aiserver.py @@ -413,12 +413,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): breakmodel.total_blocks = n_layers model.half().to('cpu') gc.collect() - model.transformer.wte.to(breakmodel.gpu_device) - model.transformer.ln_f.to(breakmodel.gpu_device) + model.transformer.wte.to(breakmodel.embedding_device) + model.transformer.ln_f.to(breakmodel.layernormfinal_device) if(hasattr(model, 'lm_head')): - model.lm_head.to(breakmodel.gpu_device) + model.lm_head.to(breakmodel.embedding_device) if(not hasattr(model.config, 'rotary') or not model.config.rotary): - model.transformer.wpe.to(breakmodel.gpu_device) + model.transformer.wpe.to(breakmodel.positional_device) gc.collect() if(args.breakmodel_layers is not None): breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers)) @@ -465,12 +465,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): breakmodel.total_blocks = n_layers model.half().to('cpu') gc.collect() - model.transformer.wte.to(breakmodel.gpu_device) - model.transformer.ln_f.to(breakmodel.gpu_device) + model.transformer.wte.to(breakmodel.embedding_device) + model.transformer.ln_f.to(breakmodel.layernormfinal_device) if(hasattr(model, 'lm_head')): - model.lm_head.to(breakmodel.gpu_device) + model.lm_head.to(breakmodel.embedding_device) if(not hasattr(model.config, 'rotary') or not model.config.rotary): - model.transformer.wpe.to(breakmodel.gpu_device) + model.transformer.wpe.to(breakmodel.positional_device) gc.collect() if(args.breakmodel_layers is not None): breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers)) @@ -1229,7 +1229,7 @@ def generate(txt, min, max): # its first argument if we're using breakmodel, otherwise a string # is fine if(vars.hascuda and vars.breakmodel): - gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.gpu_device) + gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.embedding_device) else: gen_in = txt diff --git a/breakmodel.py b/breakmodel.py index c5bdde28..61c931b9 100644 --- a/breakmodel.py +++ b/breakmodel.py @@ -229,11 +229,17 @@ class MaxSharedRamBlocksException(Exception): breakmodel = True -gpu_device = 'cuda' +devices = ['cpu', 'cuda'] total_blocks = 24 ram_blocks = 7 max_shared_ram_blocks = None +# I highly suggest these all be set to the same device unless you really know what you're doing! 
+# (They can all be set to any CPU or GPU device, except layernormfinal_device which can only be a GPU device)
+embedding_device = devices[1] # Dealing with text embedding is computationally expensive; I suggest you set this to your fastest device
+positional_device = devices[1] # Only used for GPT-Neo (not used for GPT-J)
+layernormfinal_device = devices[1] # This setting is unique in that it MUST be set to a GPU device; it cannot be set to 'cpu'
+
 
 def new_forward(
         self,
@@ -260,44 +266,13 @@ def new_forward(
         setattr(self,"extrastorage",{})
         torch.cuda.empty_cache()
 
+        for i in range(ram_blocks):
+            self.h[i].to(devices[0])
+
         for i in range(ram_blocks,len(self.h)):
-            self.h[i].to(gpu_device)
+            self.h[i].to(devices[1])
 
-        for i in range(ram_blocks):
-            self.h[i].to("cpu")
-            self.extrastorage[i] = copy.deepcopy(self.h[i])
-            smalltensor = torch.tensor(0).to(gpu_device)
-            for param1 in self.h[i].parameters():
-                param1.data = smalltensor
-            self.h[i].to(gpu_device)
-        for i in range(len(self.h)):
-            for param in self.h[i].parameters():
-                param.requires_grad = False
-                param.data = param.data.detach()
-            gc.collect()
-            torch.cuda.empty_cache()
-
-        for i in range(ram_blocks):
-            for param in self.extrastorage[i].parameters():
-                param.requires_grad = False
-                if i < max_shared_ram_blocks:
-                    try:
-                        param.data = param.data.detach().pin_memory()
-                    except:
-                        raise MaxSharedRamBlocksException(i)
-                else:
-                    param.data = param.data.detach()
-            gc.collect()
-            torch.cuda.empty_cache()
-
-        if ram_blocks:
-            for param1,param2 in zip(self.h[0].parameters(),self.extrastorage[0].parameters()):
-                param1.data = param2.data.to(gpu_device, non_blocking=False).detach()
-
-            for param1,param2 in zip(self.h[ram_blocks-1].parameters(),self.extrastorage[ram_blocks-1].parameters()):
-                param1.data = param2.data.to(gpu_device, non_blocking=False).detach()
-
         #END MODEL BREAK EDITS
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -331,7 +306,7 @@ def new_forward(
         else:
             past_length = past_key_values[0][0].size(-2)
 
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        device = positional_device if breakmodel else (input_ids.device if input_ids is not None else inputs_embeds.device)
         if position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
             position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
@@ -368,6 +343,8 @@ def new_forward(
         head_mask = self.get_head_mask(head_mask, self.config.num_layers)
 
         if inputs_embeds is None:
+            if breakmodel:
+                input_ids = input_ids.to(embedding_device)
             inputs_embeds = self.wte(input_ids)
 
         if embs is not None and not (use_cache is not None and use_cache and past_key_values is not None and len(past_key_values) > 0 and past_key_values[0] is not None):
@@ -382,7 +359,11 @@ def new_forward(
         if hasattr(self, 'rotary') and self.rotary:
             hidden_states = inputs_embeds
         else:
+            if breakmodel:
+                position_ids = position_ids.to(positional_device)
             position_embeds = self.wpe(position_ids)
+            if breakmodel:
+                position_embeds = position_embeds.to(embedding_device)
             hidden_states = inputs_embeds + position_embeds
 
         if token_type_ids is not None:
@@ -396,23 +377,8 @@ def new_forward(
         presents = () if use_cache else None
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
-
-
-        if breakmodel:
-            copystream = torch.cuda.Stream(device=0,priority = -1)
-
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-            if breakmodel:
-                if i in range(ram_blocks):
-                    index1 = (i+1)%ram_blocks
-                    for param1,param2 in zip(self.h[index1].parameters(),self.h[(i-1)%ram_blocks].parameters()):
-                        param1.data = param2.data
-                    for param1,param2 in zip(self.h[index1].parameters(),self.extrastorage[index1].parameters()):
-                        with torch.cuda.stream(copystream):
-                            torch.cuda.comm.broadcast(param2.data,out = [param1.data])
-
-
 
             attn_type = self.config.attention_layers[i]
             attn_mask = global_attention_mask
 
@@ -443,11 +409,13 @@ def new_forward(
                     head_mask[i],
                 )
             else:
+                if breakmodel:
+                    device = devices[0] if i < ram_blocks else devices[1]
                 outputs = block(
-                    hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=attn_mask,
-                    head_mask=head_mask[i],
+                    hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
+                    layer_past=tuple(v.to(device) for v in layer_past if v is not None) if breakmodel and layer_past is not None else layer_past,
+                    attention_mask=attn_mask.to(device) if breakmodel and attn_mask is not None else attn_mask,
+                    head_mask=head_mask[i].to(device) if breakmodel and head_mask[i] is not None else head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
                 )
@@ -460,18 +428,12 @@ def new_forward(
                 all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
 
 
-            if breakmodel:
-                if i in range(ram_blocks):
-                    torch.cuda.synchronize()
-        torch.cuda.empty_cache()
-        if breakmodel:
-            del copystream
-
-        torch.cuda.empty_cache()
-
-
+        if breakmodel:
+            hidden_states = hidden_states.to(layernormfinal_device)
         hidden_states = self.ln_f(hidden_states)
+        if breakmodel:
+            hidden_states = hidden_states.to(embedding_device)
 
         hidden_states = hidden_states.view(*output_shape)
         # Add last hidden state
 
@@ -480,7 +442,6 @@ def new_forward(
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
 
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=presents,
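
The device globals added to breakmodel.py above are read when the model is loaded, so they have to be configured before aiserver.py moves wte, ln_f, lm_head, and wpe onto them. A minimal sketch of how a caller might set them up, assuming a single GPU at 'cuda:0' (the device strings here are illustrative assumptions, not values from the patch):

import breakmodel

# Assumed single-GPU setup; adjust the device strings to your hardware.
breakmodel.devices = ['cpu', 'cuda:0']        # devices[0] holds the RAM blocks, devices[1] the GPU blocks
breakmodel.ram_blocks = 7                     # how many transformer blocks stay on devices[0]
breakmodel.embedding_device = 'cuda:0'        # wte (and lm_head, if present); fastest device recommended
breakmodel.positional_device = 'cuda:0'       # wpe; only consulted for GPT-Neo models
breakmodel.layernormfinal_device = 'cuda:0'   # ln_f; must be a GPU device, per the patch comments

Keeping the three per-module devices on the same GPU, as the comments in the patch suggest, avoids extra hidden-state copies between the embedding, final layer norm, and lm_head stages.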
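
For readers skimming the new_forward changes: the per-layer dispatch now amounts to moving each activation onto whichever device owns the next block. A reduced illustration of that technique (run_split_layers and its arguments are hypothetical stand-ins for self.h and the module globals; real transformer blocks return tuples rather than bare tensors):

def run_split_layers(hidden_states, blocks, ram_blocks, devices):
    # Blocks [0, ram_blocks) were moved to devices[0] (the CPU side);
    # the remaining blocks live on devices[1] (the GPU side).
    for i, block in enumerate(blocks):
        device = devices[0] if i < ram_blocks else devices[1]
        # Move the activations to the block's device, then run the block.
        hidden_states = block(hidden_states.to(device))
    return hidden_states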