mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
K80 test
This commit is contained in:
18
aiserver.py
18
aiserver.py
@@ -413,12 +413,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
|||||||
breakmodel.total_blocks = n_layers
|
breakmodel.total_blocks = n_layers
|
||||||
model.half().to('cpu')
|
model.half().to('cpu')
|
||||||
gc.collect()
|
gc.collect()
|
||||||
model.transformer.wte.to(breakmodel.gpu_device)
|
model.transformer.wte.to(breakmodel.embedding_device)
|
||||||
model.transformer.ln_f.to(breakmodel.gpu_device)
|
model.transformer.ln_f.to(breakmodel.layernormfinal_device)
|
||||||
if(hasattr(model, 'lm_head')):
|
if(hasattr(model, 'lm_head')):
|
||||||
model.lm_head.to(breakmodel.gpu_device)
|
model.lm_head.to(breakmodel.embedding_device)
|
||||||
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
|
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
|
||||||
model.transformer.wpe.to(breakmodel.gpu_device)
|
model.transformer.wpe.to(breakmodel.positional_device)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if(args.breakmodel_layers is not None):
|
if(args.breakmodel_layers is not None):
|
||||||
breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers))
|
breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers))
|
||||||
@@ -465,12 +465,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
|||||||
breakmodel.total_blocks = n_layers
|
breakmodel.total_blocks = n_layers
|
||||||
model.half().to('cpu')
|
model.half().to('cpu')
|
||||||
gc.collect()
|
gc.collect()
|
||||||
model.transformer.wte.to(breakmodel.gpu_device)
|
model.transformer.wte.to(breakmodel.embedding_device)
|
||||||
model.transformer.ln_f.to(breakmodel.gpu_device)
|
model.transformer.ln_f.to(breakmodel.layernormfinal_device)
|
||||||
if(hasattr(model, 'lm_head')):
|
if(hasattr(model, 'lm_head')):
|
||||||
model.lm_head.to(breakmodel.gpu_device)
|
model.lm_head.to(breakmodel.embedding_device)
|
||||||
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
|
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
|
||||||
model.transformer.wpe.to(breakmodel.gpu_device)
|
model.transformer.wpe.to(breakmodel.positional_device)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if(args.breakmodel_layers is not None):
|
if(args.breakmodel_layers is not None):
|
||||||
breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers))
|
breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers))
|
||||||
@@ -1229,7 +1229,7 @@ def generate(txt, min, max):
|
|||||||
# its first argument if we're using breakmodel, otherwise a string
|
# its first argument if we're using breakmodel, otherwise a string
|
||||||
# is fine
|
# is fine
|
||||||
if(vars.hascuda and vars.breakmodel):
|
if(vars.hascuda and vars.breakmodel):
|
||||||
gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.gpu_device)
|
gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.embedding_device)
|
||||||
else:
|
else:
|
||||||
gen_in = txt
|
gen_in = txt
|
||||||
|
|
||||||
|
@@ -229,11 +229,17 @@ class MaxSharedRamBlocksException(Exception):
|
|||||||
|
|
||||||
|
|
||||||
breakmodel = True
|
breakmodel = True
|
||||||
gpu_device = 'cuda'
|
devices = ['cpu', 'cuda']
|
||||||
total_blocks = 24
|
total_blocks = 24
|
||||||
ram_blocks = 7
|
ram_blocks = 7
|
||||||
max_shared_ram_blocks = None
|
max_shared_ram_blocks = None
|
||||||
|
|
||||||
|
# I highly suggest these all be set to the same device unless you really know what you're doing!
|
||||||
|
# (They can all be set to any CPU or GPU device, except layernormfinal_device which can only be a GPU device)
|
||||||
|
embedding_device = devices[1] # Dealing with text embedding is computationally expensive, I suggest you set this to your fastest device
|
||||||
|
positional_device = devices[1] # Only used for GPT-Neo (not used for GPT-J)
|
||||||
|
layernormfinal_device = devices[1] # This setting is unique in that this MUST be set to a GPU device, this cannot be set to 'cpu'
|
||||||
|
|
||||||
|
|
||||||
def new_forward(
|
def new_forward(
|
||||||
self,
|
self,
|
||||||
@@ -260,44 +266,13 @@ def new_forward(
|
|||||||
setattr(self,"extrastorage",{})
|
setattr(self,"extrastorage",{})
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
for i in range(ram_blocks):
|
||||||
|
self.h[i].to(devices[0])
|
||||||
|
|
||||||
for i in range(ram_blocks,len(self.h)):
|
for i in range(ram_blocks,len(self.h)):
|
||||||
self.h[i].to(gpu_device)
|
self.h[i].to(devices[1])
|
||||||
|
|
||||||
for i in range(ram_blocks):
|
|
||||||
self.h[i].to("cpu")
|
|
||||||
self.extrastorage[i] = copy.deepcopy(self.h[i])
|
|
||||||
smalltensor = torch.tensor(0).to(gpu_device)
|
|
||||||
for param1 in self.h[i].parameters():
|
|
||||||
param1.data = smalltensor
|
|
||||||
self.h[i].to(gpu_device)
|
|
||||||
|
|
||||||
for i in range(len(self.h)):
|
|
||||||
for param in self.h[i].parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
param.data = param.data.detach()
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
for i in range(ram_blocks):
|
|
||||||
for param in self.extrastorage[i].parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
if i < max_shared_ram_blocks:
|
|
||||||
try:
|
|
||||||
param.data = param.data.detach().pin_memory()
|
|
||||||
except:
|
|
||||||
raise MaxSharedRamBlocksException(i)
|
|
||||||
else:
|
|
||||||
param.data = param.data.detach()
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
if ram_blocks:
|
|
||||||
for param1,param2 in zip(self.h[0].parameters(),self.extrastorage[0].parameters()):
|
|
||||||
param1.data = param2.data.to(gpu_device, non_blocking=False).detach()
|
|
||||||
|
|
||||||
for param1,param2 in zip(self.h[ram_blocks-1].parameters(),self.extrastorage[ram_blocks-1].parameters()):
|
|
||||||
param1.data = param2.data.to(gpu_device, non_blocking=False).detach()
|
|
||||||
#END MODEL BREAK EDITS
|
|
||||||
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@@ -331,7 +306,7 @@ def new_forward(
|
|||||||
else:
|
else:
|
||||||
past_length = past_key_values[0][0].size(-2)
|
past_length = past_key_values[0][0].size(-2)
|
||||||
|
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
device = positional_device if breakmodel else input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||||
@@ -368,6 +343,8 @@ def new_forward(
|
|||||||
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
|
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
|
if breakmodel:
|
||||||
|
input_ids = input_ids.to(embedding_device)
|
||||||
inputs_embeds = self.wte(input_ids)
|
inputs_embeds = self.wte(input_ids)
|
||||||
|
|
||||||
if embs is not None and not (use_cache is not None and use_cache and past_key_values is not None and len(past_key_values) > 0 and past_key_values[0] is not None):
|
if embs is not None and not (use_cache is not None and use_cache and past_key_values is not None and len(past_key_values) > 0 and past_key_values[0] is not None):
|
||||||
@@ -382,7 +359,11 @@ def new_forward(
|
|||||||
if hasattr(self, 'rotary') and self.rotary:
|
if hasattr(self, 'rotary') and self.rotary:
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
else:
|
else:
|
||||||
|
if breakmodel:
|
||||||
|
position_ids = position_ids.to(positional_device)
|
||||||
position_embeds = self.wpe(position_ids)
|
position_embeds = self.wpe(position_ids)
|
||||||
|
if breakmodel:
|
||||||
|
position_embeds = position_embeds.to(embedding_device)
|
||||||
hidden_states = inputs_embeds + position_embeds
|
hidden_states = inputs_embeds + position_embeds
|
||||||
|
|
||||||
if token_type_ids is not None:
|
if token_type_ids is not None:
|
||||||
@@ -396,23 +377,8 @@ def new_forward(
|
|||||||
presents = () if use_cache else None
|
presents = () if use_cache else None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
|
||||||
|
|
||||||
if breakmodel:
|
|
||||||
copystream = torch.cuda.Stream(device=0,priority = -1)
|
|
||||||
|
|
||||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||||
|
|
||||||
if breakmodel:
|
|
||||||
if i in range(ram_blocks):
|
|
||||||
index1 = (i+1)%ram_blocks
|
|
||||||
for param1,param2 in zip(self.h[index1].parameters(),self.h[(i-1)%ram_blocks].parameters()):
|
|
||||||
param1.data = param2.data
|
|
||||||
for param1,param2 in zip(self.h[index1].parameters(),self.extrastorage[index1].parameters()):
|
|
||||||
with torch.cuda.stream(copystream):
|
|
||||||
torch.cuda.comm.broadcast(param2.data,out = [param1.data])
|
|
||||||
|
|
||||||
|
|
||||||
attn_type = self.config.attention_layers[i]
|
attn_type = self.config.attention_layers[i]
|
||||||
attn_mask = global_attention_mask
|
attn_mask = global_attention_mask
|
||||||
|
|
||||||
@@ -443,11 +409,13 @@ def new_forward(
|
|||||||
head_mask[i],
|
head_mask[i],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
if breakmodel:
|
||||||
|
device = devices[0] if i < ram_blocks else devices[1]
|
||||||
outputs = block(
|
outputs = block(
|
||||||
hidden_states,
|
hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
|
||||||
layer_past=layer_past,
|
layer_past=tuple(v.to(device) for v in layer_past if v is not None) if breakmodel and layer_past is not None else layer_past,
|
||||||
attention_mask=attn_mask,
|
attention_mask=attn_mask.to(device) if breakmodel and attn_mask is not None else attn_mask,
|
||||||
head_mask=head_mask[i],
|
head_mask=head_mask[i].to(device) if breakmodel and head_mask[i] is not None else head_mask[i],
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
@@ -460,18 +428,11 @@ def new_forward(
|
|||||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||||
|
|
||||||
|
|
||||||
if breakmodel:
|
|
||||||
if i in range(ram_blocks):
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
if breakmodel:
|
if breakmodel:
|
||||||
del copystream
|
hidden_states = hidden_states.to(layernormfinal_device)
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
hidden_states = self.ln_f(hidden_states)
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
if breakmodel:
|
||||||
|
hidden_states = hidden_states.to(embedding_device)
|
||||||
|
|
||||||
hidden_states = hidden_states.view(*output_shape)
|
hidden_states = hidden_states.view(*output_shape)
|
||||||
# Add last hidden state
|
# Add last hidden state
|
||||||
@@ -480,7 +441,6 @@ def new_forward(
|
|||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
||||||
|
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=presents,
|
past_key_values=presents,
|
||||||
|
Reference in New Issue
Block a user