From 4d9eab378508fb216bba869de3473664fefb7f59 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Thu, 23 Sep 2021 20:57:18 -0400 Subject: [PATCH 01/10] K80 test --- aiserver.py | 18 +++++----- breakmodel.py | 94 +++++++++++++++------------------------------------ 2 files changed, 36 insertions(+), 76 deletions(-) diff --git a/aiserver.py b/aiserver.py index 9375034b..9a48cd9b 100644 --- a/aiserver.py +++ b/aiserver.py @@ -413,12 +413,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): breakmodel.total_blocks = n_layers model.half().to('cpu') gc.collect() - model.transformer.wte.to(breakmodel.gpu_device) - model.transformer.ln_f.to(breakmodel.gpu_device) + model.transformer.wte.to(breakmodel.embedding_device) + model.transformer.ln_f.to(breakmodel.layernormfinal_device) if(hasattr(model, 'lm_head')): - model.lm_head.to(breakmodel.gpu_device) + model.lm_head.to(breakmodel.embedding_device) if(not hasattr(model.config, 'rotary') or not model.config.rotary): - model.transformer.wpe.to(breakmodel.gpu_device) + model.transformer.wpe.to(breakmodel.positional_device) gc.collect() if(args.breakmodel_layers is not None): breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers)) @@ -465,12 +465,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): breakmodel.total_blocks = n_layers model.half().to('cpu') gc.collect() - model.transformer.wte.to(breakmodel.gpu_device) - model.transformer.ln_f.to(breakmodel.gpu_device) + model.transformer.wte.to(breakmodel.embedding_device) + model.transformer.ln_f.to(breakmodel.layernormfinal_device) if(hasattr(model, 'lm_head')): - model.lm_head.to(breakmodel.gpu_device) + model.lm_head.to(breakmodel.embedding_device) if(not hasattr(model.config, 'rotary') or not model.config.rotary): - model.transformer.wpe.to(breakmodel.gpu_device) + model.transformer.wpe.to(breakmodel.positional_device) gc.collect() if(args.breakmodel_layers is not None): breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers)) @@ -1229,7 +1229,7 @@ def generate(txt, min, max): # its first argument if we're using breakmodel, otherwise a string # is fine if(vars.hascuda and vars.breakmodel): - gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.gpu_device) + gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.embedding_device) else: gen_in = txt diff --git a/breakmodel.py b/breakmodel.py index c5bdde28..61c931b9 100644 --- a/breakmodel.py +++ b/breakmodel.py @@ -229,11 +229,17 @@ class MaxSharedRamBlocksException(Exception): breakmodel = True -gpu_device = 'cuda' +devices = ['cpu', 'cuda'] total_blocks = 24 ram_blocks = 7 max_shared_ram_blocks = None +# I highly suggest these all be set to the same device unless you really know what you're doing! 
+# (They can all be set to any CPU or GPU device, except layernormfinal_device which can only be a GPU device) +embedding_device = devices[1] # Dealing with text embedding is computationally expensive, I suggest you set this to your fastest device +positional_device = devices[1] # Only used for GPT-Neo (not used for GPT-J) +layernormfinal_device = devices[1] # This setting is unique in that this MUST be set to a GPU device, this cannot be set to 'cpu' + def new_forward( self, @@ -260,44 +266,13 @@ def new_forward( setattr(self,"extrastorage",{}) torch.cuda.empty_cache() + for i in range(ram_blocks): + self.h[i].to(devices[0]) + for i in range(ram_blocks,len(self.h)): - self.h[i].to(gpu_device) + self.h[i].to(devices[1]) - for i in range(ram_blocks): - self.h[i].to("cpu") - self.extrastorage[i] = copy.deepcopy(self.h[i]) - smalltensor = torch.tensor(0).to(gpu_device) - for param1 in self.h[i].parameters(): - param1.data = smalltensor - self.h[i].to(gpu_device) - for i in range(len(self.h)): - for param in self.h[i].parameters(): - param.requires_grad = False - param.data = param.data.detach() - gc.collect() - torch.cuda.empty_cache() - - for i in range(ram_blocks): - for param in self.extrastorage[i].parameters(): - param.requires_grad = False - if i < max_shared_ram_blocks: - try: - param.data = param.data.detach().pin_memory() - except: - raise MaxSharedRamBlocksException(i) - else: - param.data = param.data.detach() - gc.collect() - torch.cuda.empty_cache() - - if ram_blocks: - for param1,param2 in zip(self.h[0].parameters(),self.extrastorage[0].parameters()): - param1.data = param2.data.to(gpu_device, non_blocking=False).detach() - - for param1,param2 in zip(self.h[ram_blocks-1].parameters(),self.extrastorage[ram_blocks-1].parameters()): - param1.data = param2.data.to(gpu_device, non_blocking=False).detach() - #END MODEL BREAK EDITS output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -331,7 +306,7 @@ def new_forward( else: past_length = past_key_values[0][0].size(-2) - device = input_ids.device if input_ids is not None else inputs_embeds.device + device = positional_device if breakmodel else input_ids.device if input_ids is not None else inputs_embeds.device if position_ids is None: position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) @@ -368,6 +343,8 @@ def new_forward( head_mask = self.get_head_mask(head_mask, self.config.num_layers) if inputs_embeds is None: + if breakmodel: + input_ids = input_ids.to(embedding_device) inputs_embeds = self.wte(input_ids) if embs is not None and not (use_cache is not None and use_cache and past_key_values is not None and len(past_key_values) > 0 and past_key_values[0] is not None): @@ -382,7 +359,11 @@ def new_forward( if hasattr(self, 'rotary') and self.rotary: hidden_states = inputs_embeds else: + if breakmodel: + position_ids = position_ids.to(positional_device) position_embeds = self.wpe(position_ids) + if breakmodel: + position_embeds = position_embeds.to(embedding_device) hidden_states = inputs_embeds + position_embeds if token_type_ids is not None: @@ -396,23 +377,8 @@ def new_forward( presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - - - if breakmodel: - copystream = torch.cuda.Stream(device=0,priority = -1) - for i, (block, layer_past) in 
enumerate(zip(self.h, past_key_values)): - if breakmodel: - if i in range(ram_blocks): - index1 = (i+1)%ram_blocks - for param1,param2 in zip(self.h[index1].parameters(),self.h[(i-1)%ram_blocks].parameters()): - param1.data = param2.data - for param1,param2 in zip(self.h[index1].parameters(),self.extrastorage[index1].parameters()): - with torch.cuda.stream(copystream): - torch.cuda.comm.broadcast(param2.data,out = [param1.data]) - - attn_type = self.config.attention_layers[i] attn_mask = global_attention_mask @@ -443,11 +409,13 @@ def new_forward( head_mask[i], ) else: + if breakmodel: + device = devices[0] if i < ram_blocks else devices[1] outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attn_mask, - head_mask=head_mask[i], + hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states, + layer_past=tuple(v.to(device) for v in layer_past if v is not None) if breakmodel and layer_past is not None else layer_past, + attention_mask=attn_mask.to(device) if breakmodel and attn_mask is not None else attn_mask, + head_mask=head_mask[i].to(device) if breakmodel and head_mask[i] is not None else head_mask[i], use_cache=use_cache, output_attentions=output_attentions, ) @@ -460,18 +428,11 @@ def new_forward( all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - if breakmodel: - if i in range(ram_blocks): - torch.cuda.synchronize() - torch.cuda.empty_cache() - if breakmodel: - del copystream - - torch.cuda.empty_cache() - - + hidden_states = hidden_states.to(layernormfinal_device) hidden_states = self.ln_f(hidden_states) + if breakmodel: + hidden_states = hidden_states.to(embedding_device) hidden_states = hidden_states.view(*output_shape) # Add last hidden state @@ -480,7 +441,6 @@ def new_forward( if not return_dict: return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=presents, From 0937bb33e7f72d7ca1ac252138d175dd23a23881 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Sat, 2 Oct 2021 12:19:37 -0400 Subject: [PATCH 02/10] Clarify licensing for breakmodel.py --- breakmodel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/breakmodel.py b/breakmodel.py index 61c931b9..1c5ae3fd 100644 --- a/breakmodel.py +++ b/breakmodel.py @@ -1,10 +1,10 @@ ''' This is a MODIFIED version of arrmansa's low VRAM patch. 
https://github.com/arrmansa/Basic-UI-for-GPT-J-6B-with-low-vram/blob/main/GPT-J-6B-Low-Vram-UI.ipynb +The ORIGINAL version of the patch is released under the Apache License 2.0 Copyright 2021 arrmansa Copyright 2021 finetuneanon Copyright 2018 The Hugging Face team -Released under the Apache License 2.0 Apache License From a283d34b2731abfe7f5f1e939117491f0755cedb Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 5 Oct 2021 09:38:57 -0400 Subject: [PATCH 03/10] Multiple GPU support --- aiserver.py | 146 +++++++++++++++++++++++++++++--------------------- breakmodel.py | 89 +++++++++++++++++++----------- 2 files changed, 144 insertions(+), 91 deletions(-) diff --git a/aiserver.py b/aiserver.py index 92d33f8f..46ec25d5 100644 --- a/aiserver.py +++ b/aiserver.py @@ -178,6 +178,88 @@ def getmodelname(): modelname = vars.model return modelname +#==================================================================# +# Breakmodel configuration functions +#==================================================================# +def device_list(n_layers, primary=None, selected=None): + device_count = torch.cuda.device_count() + if(device_count < 2): + primary = None + gpu_blocks = breakmodel.gpu_blocks + (device_count - len(breakmodel.gpu_blocks))*[0] + print(f"{colors.YELLOW} DEVICE ID | LAYERS | DEVICE NAME{colors.END}") + for i in range(device_count): + name = torch.cuda.get_device_name(i) + if(len(name) > 47): + name = "..." + name[-44:] + row_color = colors.END + sep_color = colors.YELLOW + print(f"{row_color}{colors.YELLOW + '->' + row_color if i == selected else ' '} {'(primary)' if i == primary else ' '*9} {i:3} {sep_color}|{row_color} {gpu_blocks[i]:3} {sep_color}|{row_color} {name}{colors.END}") + row_color = colors.END + sep_color = colors.YELLOW + print(f"{row_color} {' '*9} N/A {sep_color}|{row_color} {n_layers:3} {sep_color}|{row_color} (CPU){colors.END}") + +def device_config(model): + global breakmodel, generator + import breakmodel + n_layers = model.config.num_layers + model.half().to('cpu') + gc.collect() + if(args.breakmodel_layers is not None): + breakmodel.gpu_blocks = [n_layers - max(0, min(n_layers, args.breakmodel_layers))] + else: + device_count = torch.cuda.device_count() + if(device_count > 1): + print(colors.CYAN + "\nPlease select one of your GPUs to be your primary GPU.") + print("VRAM usage in your primary GPU will be higher than for your other ones.") + print("It is recommended you make your fastest GPU your primary GPU.") + device_list(n_layers) + while(True): + primaryselect = input("device ID> ") + if(primaryselect.isnumeric() and 0 <= int(primaryselect) < device_count): + breakmodel.primary_device = int(primaryselect) + else: + print(f"{colors.RED}Please enter an integer between 0 and {device_count-1}.{colors.END}") + else: + breakmodel.primary_device = 0 + + print(colors.PURPLE + "\nIf you don't have enough VRAM to run the model on a single GPU") + print("you can split the model between your CPU and your GPU(s), or between") + print("multiple GPUs if you have more than one.") + print("By putting more 'layers' on a GPU or CPU, more computations will be") + print("done on that device and more VRAM or RAM will be required on that device") + print("(roughly proportional to number of layers).") + print("It should be noted that GPUs are orders of magnitude faster than the CPU.") + print(f"This model has{colors.YELLOW} {n_layers} {colors.PURPLE}layers.{colors.END}\n") + + for i in range(device_count): + device_list(n_layers, primary=breakmodel.primary_device, selected=i) 
+ print(f"{colors.CYAN}\nHow many of the remaining{colors.YELLOW} {n_layers} {colors.CYAN}layers would you like to put into device {i}?\nYou can also enter -1 to allocate all remaining layers to this device.{colors.END}\n") + while(True): + layerselect = input("# of layers> ") + if((layerselect.isnumeric() or layerselect.strip() == '-1') and -1 <= int(layerselect) <= n_layers): + layerselect = int(layerselect) + layerselect = n_layers if layerselect == -1 else layerselect + breakmodel.gpu_blocks.append(layerselect) + n_layers -= layerselect + break + else: + print(f"{colors.RED}Please enter an integer between -1 and {n_layers}.{colors.END}") + if(n_layers == 0): + break + + print(colors.PURPLE + "\nFinal device configuration:") + device_list(n_layers) + + model.transformer.wte.to(breakmodel.primary_device) + model.transformer.ln_f.to(breakmodel.primary_device) + if(hasattr(model, 'lm_head')): + model.lm_head.to(breakmodel.primary_device) + if(not hasattr(model.config, 'rotary') or not model.config.rotary): + model.transformer.wpe.to(breakmodel.primary_device) + gc.collect() + GPTNeoModel.forward = breakmodel.new_forward + generator = model.generate + #==================================================================# # Startup #==================================================================# @@ -414,36 +496,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): if(vars.usegpu): generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=0) elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel) - import breakmodel - n_layers = model.config.num_layers - breakmodel.total_blocks = n_layers - model.half().to('cpu') - gc.collect() - model.transformer.wte.to(breakmodel.embedding_device) - model.transformer.ln_f.to(breakmodel.layernormfinal_device) - if(hasattr(model, 'lm_head')): - model.lm_head.to(breakmodel.embedding_device) - if(not hasattr(model.config, 'rotary') or not model.config.rotary): - model.transformer.wpe.to(breakmodel.positional_device) - gc.collect() - if(args.breakmodel_layers is not None): - breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers)) - else: - print(colors.CYAN + "\nHow many layers would you like to put into system RAM?") - print("The more of them you put into system RAM, the slower it will run,") - print("but it will require less VRAM") - print("(roughly proportional to number of layers).") - print(f"This model has{colors.YELLOW} {n_layers} {colors.CYAN}layers.{colors.END}\n") - while(True): - layerselect = input("# of layers> ") - if(layerselect.isnumeric() and 0 <= int(layerselect) <= n_layers): - breakmodel.ram_blocks = int(layerselect) - break - else: - print(f"{colors.RED}Please enter an integer between 0 and {n_layers}.{colors.END}") - print(f"{colors.PURPLE}Will commit{colors.YELLOW} {breakmodel.ram_blocks} {colors.PURPLE}of{colors.YELLOW} {n_layers} {colors.PURPLE}layers to system RAM.{colors.END}") - GPTNeoModel.forward = breakmodel.new_forward - generator = model.generate + device_config(model) else: generator = pipeline('text-generation', model=model, tokenizer=tokenizer) else: @@ -465,37 +518,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): if(vars.usegpu): generator = pipeline('text-generation', model=vars.model, device=0) elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel) - import breakmodel model = AutoModelForCausalLM.from_pretrained(vars.model) - n_layers = model.config.num_layers - breakmodel.total_blocks = n_layers - model.half().to('cpu') - gc.collect() - 
model.transformer.wte.to(breakmodel.embedding_device) - model.transformer.ln_f.to(breakmodel.layernormfinal_device) - if(hasattr(model, 'lm_head')): - model.lm_head.to(breakmodel.embedding_device) - if(not hasattr(model.config, 'rotary') or not model.config.rotary): - model.transformer.wpe.to(breakmodel.positional_device) - gc.collect() - if(args.breakmodel_layers is not None): - breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers)) - else: - print(colors.CYAN + "\nHow many layers would you like to put into system RAM?") - print("The more of them you put into system RAM, the slower it will run,") - print("but it will require less VRAM") - print("(roughly proportional to number of layers).") - print(f"This model has{colors.YELLOW} {n_layers} {colors.CYAN}layers.{colors.END}\n") - while(True): - layerselect = input("# of layers> ") - if(layerselect.isnumeric() and 0 <= int(layerselect) <= n_layers): - breakmodel.ram_blocks = int(layerselect) - break - else: - print(f"{colors.RED}Please enter an integer between 0 and {n_layers}.{colors.END}") - print(f"{colors.PURPLE}Will commit{colors.YELLOW} {breakmodel.ram_blocks} {colors.PURPLE}of{colors.YELLOW} {n_layers} {colors.PURPLE}layers to system RAM.{colors.END}") - GPTNeoModel.forward = breakmodel.new_forward - generator = model.generate + device_config(model) else: generator = pipeline('text-generation', model=vars.model) else: @@ -1245,7 +1269,7 @@ def generate(txt, min, max): # its first argument if we're using breakmodel, otherwise a string # is fine if(vars.hascuda and vars.breakmodel): - gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.embedding_device) + gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.primary_device) else: gen_in = txt diff --git a/breakmodel.py b/breakmodel.py index 1c5ae3fd..905768a3 100644 --- a/breakmodel.py +++ b/breakmodel.py @@ -215,6 +215,8 @@ import torch import torch.cuda.comm import copy import gc +import itertools +import bisect from transformers.modeling_outputs import BaseModelOutputWithPast @@ -222,23 +224,9 @@ from transformers.utils import logging logger = logging.get_logger(__name__) -class MaxSharedRamBlocksException(Exception): - def __init__(self, i: int): - self.corrected_max_shared_ram_blocks = i - super().__init__('max_shared_ram_blocks is set too high, please set it to '+str(i)) - - breakmodel = True -devices = ['cpu', 'cuda'] -total_blocks = 24 -ram_blocks = 7 -max_shared_ram_blocks = None - -# I highly suggest these all be set to the same device unless you really know what you're doing! 
-# (They can all be set to any CPU or GPU device, except layernormfinal_device which can only be a GPU device) -embedding_device = devices[1] # Dealing with text embedding is computationally expensive, I suggest you set this to your fastest device -positional_device = devices[1] # Only used for GPT-Neo (not used for GPT-J) -layernormfinal_device = devices[1] # This setting is unique in that this MUST be set to a GPU device, this cannot be set to 'cpu' +gpu_blocks = [] +primary_device = 0 def new_forward( @@ -256,21 +244,41 @@ def new_forward( return_dict=None, embs=None, ): - global max_shared_ram_blocks + assert len(gpu_blocks) <= torch.cuda.device_count() + assert sum(gpu_blocks) <= len(self.h) + ram_blocks = len(self.h) - sum(gpu_blocks) + cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks)) if breakmodel: - if max_shared_ram_blocks is None: - max_shared_ram_blocks = total_blocks - if not hasattr(self, 'extrastorage'): setattr(self,"extrastorage",{}) torch.cuda.empty_cache() for i in range(ram_blocks): - self.h[i].to(devices[0]) + self.h[i].to("cpu") + self.extrastorage[i] = copy.deepcopy(self.h[i]) + smalltensor = torch.tensor(0).to(primary_device) + for param1 in self.h[i].parameters(): + param1.data = smalltensor + self.h[i].to(primary_device) + for param in self.extrastorage[i].parameters(): + param.requires_grad = False + param.data = param.data.detach().pin_memory() + gc.collect() + torch.cuda.empty_cache() - for i in range(ram_blocks,len(self.h)): - self.h[i].to(devices[1]) + if ram_blocks: + for param1,param2 in zip(self.h[0].parameters(),self.extrastorage[0].parameters()): + param1.data = param2.data.to(primary_device, non_blocking=False).detach() + + for param1,param2 in zip(self.h[ram_blocks-1].parameters(),self.extrastorage[ram_blocks-1].parameters()): + param1.data = param2.data.to(primary_device, non_blocking=False).detach() + + i = ram_blocks + for j in range(len(gpu_blocks)): + for _ in range(gpu_blocks[j]): + self.h[i].to(j) + i += 1 @@ -306,7 +314,7 @@ def new_forward( else: past_length = past_key_values[0][0].size(-2) - device = positional_device if breakmodel else input_ids.device if input_ids is not None else inputs_embeds.device + device = primary_device if breakmodel else input_ids.device if input_ids is not None else inputs_embeds.device if position_ids is None: position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) @@ -344,7 +352,7 @@ def new_forward( if inputs_embeds is None: if breakmodel: - input_ids = input_ids.to(embedding_device) + input_ids = input_ids.to(primary_device) inputs_embeds = self.wte(input_ids) if embs is not None and not (use_cache is not None and use_cache and past_key_values is not None and len(past_key_values) > 0 and past_key_values[0] is not None): @@ -360,10 +368,10 @@ def new_forward( hidden_states = inputs_embeds else: if breakmodel: - position_ids = position_ids.to(positional_device) + position_ids = position_ids.to(primary_device) position_embeds = self.wpe(position_ids) if breakmodel: - position_embeds = position_embeds.to(embedding_device) + position_embeds = position_embeds.to(primary_device) hidden_states = inputs_embeds + position_embeds if token_type_ids is not None: @@ -377,8 +385,21 @@ def new_forward( presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None + + if breakmodel and ram_blocks: + copystream = 
torch.cuda.Stream(device=0,priority = -1) + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if breakmodel: + if i in range(ram_blocks): + index1 = (i+1)%ram_blocks + for param1,param2 in zip(self.h[index1].parameters(),self.h[(i-1)%ram_blocks].parameters()): + param1.data = param2.data + for param1,param2 in zip(self.h[index1].parameters(),self.extrastorage[index1].parameters()): + with torch.cuda.stream(copystream): + torch.cuda.comm.broadcast(param2.data,out = [param1.data]) + attn_type = self.config.attention_layers[i] attn_mask = global_attention_mask @@ -410,7 +431,7 @@ def new_forward( ) else: if breakmodel: - device = devices[0] if i < ram_blocks else devices[1] + device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks) outputs = block( hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states, layer_past=tuple(v.to(device) for v in layer_past if v is not None) if breakmodel and layer_past is not None else layer_past, @@ -428,11 +449,19 @@ def new_forward( all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if breakmodel: + if i in range(ram_blocks): + torch.cuda.synchronize() + torch.cuda.empty_cache() + if breakmodel: - hidden_states = hidden_states.to(layernormfinal_device) + if ram_blocks: + del copystream + torch.cuda.empty_cache() + hidden_states = hidden_states.to(primary_device) hidden_states = self.ln_f(hidden_states) if breakmodel: - hidden_states = hidden_states.to(embedding_device) + hidden_states = hidden_states.to(primary_device) hidden_states = hidden_states.view(*output_shape) # Add last hidden state From 231621e7c28cb50f93e7ab9fe384c5422ddc76be Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 5 Oct 2021 09:45:12 -0400 Subject: [PATCH 04/10] Use AutoModelForCausalLM for custom models with a model_type --- aiserver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiserver.py b/aiserver.py index 46ec25d5..0a090110 100644 --- a/aiserver.py +++ b/aiserver.py @@ -487,7 +487,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): model_config = open(vars.custmodpth + "/config.json", "r") js = json.load(model_config) if("model_type" in js): - model = vars.custmodpth + model = AutoModelForCausalLM.from_pretrained(vars.custmodpth) else: model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth) tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth) From f9e6a6da17ad52778c9863749d7876fd7f6d68b7 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 5 Oct 2021 10:25:06 -0400 Subject: [PATCH 05/10] Slightly increased performance in breakmodel mode Commit a283d34b2731abfe7f5f1e939117491f0755cedb made breakmodel mode slower. Performance has been restored to how it was before that commit. 
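
The gain comes from no longer copying a layer's cached keys/values to a GPU
when the copy is unnecessary. A minimal sketch of the guard, assuming
breakmodel mode is active (illustrative only; maybe_move_past is a made-up
helper name, not something defined in this patch):

    # Only move a cached layer_past to `device` (a GPU index) when the layer
    # is GPU-resident and the cache is not already on that GPU.
    def maybe_move_past(layer_past, device, i, ram_blocks):
        if layer_past is None or i < ram_blocks or not len(layer_past):
            return layer_past  # leave RAM-streamed layers' caches untouched
        if layer_past[0].device.index == device:
            return layer_past  # already on the right GPU; skip the copy
        return tuple(v.to(device) for v in layer_past if v is not None)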
--- breakmodel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/breakmodel.py b/breakmodel.py index 905768a3..73e40222 100644 --- a/breakmodel.py +++ b/breakmodel.py @@ -434,7 +434,7 @@ def new_forward( device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks) outputs = block( hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states, - layer_past=tuple(v.to(device) for v in layer_past if v is not None) if breakmodel and layer_past is not None else layer_past, + layer_past=tuple(v.to(device) for v in layer_past if v is not None) if breakmodel and layer_past is not None and i >= ram_blocks and len(layer_past) and layer_past[0].device.index != device else layer_past, attention_mask=attn_mask.to(device) if breakmodel and attn_mask is not None else attn_mask, head_mask=head_mask[i].to(device) if breakmodel and head_mask[i] is not None else head_mask[i], use_cache=use_cache, From fb90a7ed175167812b46ae8a901bd39cd223d0db Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 5 Oct 2021 10:31:28 -0400 Subject: [PATCH 06/10] Change the help text for breakmodel to be more helpful --- aiserver.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/aiserver.py b/aiserver.py index 0a090110..ddce311d 100644 --- a/aiserver.py +++ b/aiserver.py @@ -322,10 +322,14 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): vars.breakmodel = True elif(vars.hascuda): if(vars.bmsupported): - print(colors.YELLOW + "You're using a model that supports GPU-CPU hybrid generation!\nCurrently only GPT-Neo models and GPT-J-6B support this feature.") + print(colors.YELLOW + "You're using a model that supports hybrid generation!") + print("This feature allows you to split the model between the CPU and GPU(s)") + print("(slower than GPU-only but uses less VRAM) or between multiple GPUs") + print("(allowing you to use the combined VRAM of all your GPUs).") + print("Currently only GPT-Neo and GPT-J models support this feature.") print("{0}Use GPU or CPU for generation?: (Default GPU){1}".format(colors.CYAN, colors.END)) if(vars.bmsupported): - print(f" 1 - GPU\n 2 - CPU\n 3 - Both (slower than GPU-only but uses less VRAM)\n") + print(f" 1 - GPU\n 2 - CPU\n 3 - Hybrid generation\n") else: print(" 1 - GPU\n 2 - CPU\n") genselected = False From a1e4405aa68635cd5abcd1e56c5dfb7f2a5cc16c Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 5 Oct 2021 10:36:51 -0400 Subject: [PATCH 07/10] Automatically use breakmodel instead of GPU-only where supported There's really no reason to use GPU-only mode if breakmodel is supported because breakmodel can run in GPU-only mode too. 
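
For example, a configuration that places every layer on a single GPU behaves
the same as GPU-only mode (a sketch with a hypothetical layer count, using the
module-level settings introduced in PATCH 03):

    import breakmodel

    n_layers = 32                       # hypothetical model depth
    breakmodel.primary_device = 0       # fastest GPU
    breakmodel.gpu_blocks = [n_layers]  # all layers on GPU 0, none in system RAM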
--- aiserver.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/aiserver.py b/aiserver.py index ddce311d..c3dd4292 100644 --- a/aiserver.py +++ b/aiserver.py @@ -314,12 +314,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): genselected = True vars.usegpu = True vars.breakmodel = False + if(vars.bmsupported): + vars.usegpu = False + vars.breakmodel = True if(args.cpu): vars.usegpu = False vars.breakmodel = False - if(vars.bmsupported and args.breakmodel): - vars.usegpu = False - vars.breakmodel = True elif(vars.hascuda): if(vars.bmsupported): print(colors.YELLOW + "You're using a model that supports hybrid generation!") @@ -327,9 +327,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): print("(slower than GPU-only but uses less VRAM) or between multiple GPUs") print("(allowing you to use the combined VRAM of all your GPUs).") print("Currently only GPT-Neo and GPT-J models support this feature.") - print("{0}Use GPU or CPU for generation?: (Default GPU){1}".format(colors.CYAN, colors.END)) - if(vars.bmsupported): - print(f" 1 - GPU\n 2 - CPU\n 3 - Hybrid generation\n") + print("{0}Use hybrid generation or CPU-only generation?: (Default hybrid){1}".format(colors.CYAN, colors.END)) + print(f" 1 - Hybrid generation\n 2 - CPU\n") else: print(" 1 - GPU\n 2 - CPU\n") genselected = False @@ -342,17 +341,18 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): vars.usegpu = True genselected = True elif(genselect.isnumeric() and int(genselect) == 1): - vars.breakmodel = False - vars.usegpu = True - genselected = True + if(vars.bmsupported): + vars.breakmodel = True + vars.usegpu = False + genselected = True + else: + vars.breakmodel = False + vars.usegpu = True + genselected = True elif(genselect.isnumeric() and int(genselect) == 2): vars.breakmodel = False vars.usegpu = False genselected = True - elif(vars.bmsupported and genselect.isnumeric() and int(genselect) == 3): - vars.breakmodel = True - vars.usegpu = False - genselected = True else: print("{0}Please enter a valid selection.{1}".format(colors.RED, colors.END)) From 91352ea9f1f09480c56cc02a6602ebf2bf6797d8 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 5 Oct 2021 11:22:09 -0400 Subject: [PATCH 08/10] Change the command line flags for breakmodel --- aiserver.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/aiserver.py b/aiserver.py index c3dd4292..e6f9d798 100644 --- a/aiserver.py +++ b/aiserver.py @@ -204,7 +204,15 @@ def device_config(model): n_layers = model.config.num_layers model.half().to('cpu') gc.collect() - if(args.breakmodel_layers is not None): + if(args.layers is not None): + try: + breakmodel.gpu_blocks = list(map(int, args.layers.split(','))) + assert len(gpu_blocks) <= torch.cuda.device_count() + assert sum(gpu_blocks) <= n_layers + except: + print("WARNING: --layers is malformatted. Please use the --help option to see correct usage of --layers. 
Defaulting to all layers on device 0.", file=sys.stderr) + breakmodel.gpu_blocks = [n_layers] + elif(args.breakmodel_layers is not None): breakmodel.gpu_blocks = [n_layers - max(0, min(n_layers, args.breakmodel_layers))] else: device_count = torch.cuda.device_count() @@ -270,8 +278,9 @@ parser.add_argument("--remote", action='store_true', help="Optimizes KoboldAI fo parser.add_argument("--model", help="Specify the Model Type to skip the Menu") parser.add_argument("--path", help="Specify the Path for local models (For model NeoCustom or GPT2Custom)") parser.add_argument("--cpu", action='store_true', help="By default unattended launches are on the GPU use this option to force CPU usage.") -parser.add_argument("--breakmodel", action='store_true', help="For models that support GPU-CPU hybrid generation, use this feature instead of GPU or CPU generation") -parser.add_argument("--breakmodel_layers", type=int, help="Specify the number of layers to commit to system RAM if --breakmodel is used") +parser.add_argument("--breakmodel", action='store_true', help=argparse.SUPPRESS) +parser.add_argument("--breakmodel_layers", type=int, help=argparse.SUPPRESS) +parser.add_argument("--layers", type=str, help="If using a model that supports hybrid generation, this is a comma-separated list that specifies how many layers to put on each GPU device. For example to put 8 layers on device 0, 9 layers on device 1 and 11 layers on device 2, use --layers 8,9,11") parser.add_argument("--override_delete", action='store_true', help="Deleting stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow deleting stories if using --remote and prevent deleting stories otherwise.") parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.") parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.") @@ -304,6 +313,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="") vars.hascuda = torch.cuda.is_available() vars.bmsupported = vars.model in ("EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-2.7B", "NeoCustom") + if(args.breakmodel is not None and args.breakmodel): + print("WARNING: --breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --layers is used (see --help for details).", file=sys.stderr) + if(args.breakmodel_layers is not None): + print("WARNING: --breakmodel_layers is deprecated. Use --layers instead (see --help for details).", file=sys.stderr) + if(not vars.bmsupported and (args.layers is not None or args.breakmodel_layers is not None)): + print("WARNING: This model does not support hybrid generation. 
--layers will be ignored.", file=sys.stderr) if(vars.hascuda): print("{0}FOUND!{1}".format(colors.GREEN, colors.END)) else: From aa59f8b4b22145034a2b4ced660ad584c2e4639d Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 5 Oct 2021 11:29:47 -0400 Subject: [PATCH 09/10] Fix CPU layers not displaying correctly when using --layers --- aiserver.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/aiserver.py b/aiserver.py index e6f9d798..0bbe17e4 100644 --- a/aiserver.py +++ b/aiserver.py @@ -207,13 +207,16 @@ def device_config(model): if(args.layers is not None): try: breakmodel.gpu_blocks = list(map(int, args.layers.split(','))) - assert len(gpu_blocks) <= torch.cuda.device_count() - assert sum(gpu_blocks) <= n_layers + assert len(breakmodel.gpu_blocks) <= torch.cuda.device_count() + assert sum(breakmodel.gpu_blocks) <= n_layers + n_layers -= sum(breakmodel.gpu_blocks) except: print("WARNING: --layers is malformatted. Please use the --help option to see correct usage of --layers. Defaulting to all layers on device 0.", file=sys.stderr) breakmodel.gpu_blocks = [n_layers] + n_layers = 0 elif(args.breakmodel_layers is not None): breakmodel.gpu_blocks = [n_layers - max(0, min(n_layers, args.breakmodel_layers))] + n_layers -= sum(breakmodel.gpu_blocks) else: device_count = torch.cuda.device_count() if(device_count > 1): From 3649ba9fa4f67ff05369d725a25e1cd4860da29d Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Wed, 6 Oct 2021 12:04:56 -0400 Subject: [PATCH 10/10] Breakmodel's CUDA stream should be on primary device --- breakmodel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/breakmodel.py b/breakmodel.py index 73e40222..5724d4e2 100644 --- a/breakmodel.py +++ b/breakmodel.py @@ -387,7 +387,7 @@ def new_forward( all_hidden_states = () if output_hidden_states else None if breakmodel and ram_blocks: - copystream = torch.cuda.Stream(device=0,priority = -1) + copystream = torch.cuda.Stream(device=primary_device, priority=-1) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
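
The layer-to-device mapping used throughout this series (PATCH 03 introduces
it, PATCH 08 exposes it through --layers) can be illustrated with a small
standalone sketch; the 32-layer depth below is hypothetical, and only the
gpu_blocks/bisect logic mirrors the patched new_forward:

    import bisect
    import itertools

    gpu_blocks = [8, 9, 11]                  # e.g. --layers 8,9,11
    n_layers = 32                            # hypothetical model depth
    ram_blocks = n_layers - sum(gpu_blocks)  # 4 layers stay in system RAM
    cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))  # (8, 17, 28)
    primary_device = 0

    for i in range(n_layers):
        # RAM-resident blocks are streamed through the primary GPU; the rest
        # go to whichever GPU their position falls into.
        device = primary_device if i < ram_blocks else bisect.bisect_right(
            cumulative_gpu_blocks, i - ram_blocks)
        print(f"layer {i:2} -> device {device}")

With these (hypothetical) numbers, layers 0-3 are kept in system RAM and run
through GPU 0, layers 4-11 live on GPU 0, layers 12-20 on GPU 1, and layers
21-31 on GPU 2.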