mirror of https://github.com/KoboldAI/KoboldAI-Client.git
OPT breakmodel
aiserver.py (43 lines changed)
@@ -274,7 +274,7 @@ class vars:
     recentrngm  = None  # If a new random game was recently generated without Submitting after, this is the memory used (as a string), otherwise this is None
     useprompt   = False  # Whether to send the full prompt with every submit action
     breakmodel  = False  # For GPU users, whether to use both system RAM and VRAM to conserve VRAM while offering speedup compared to CPU-only
-    bmsupported = False  # Whether the breakmodel option is supported (GPT-Neo/GPT-J/XGLM only, currently)
+    bmsupported = False  # Whether the breakmodel option is supported (GPT-Neo/GPT-J/XGLM/OPT only, currently)
     nobreakmodel = False  # Something specifically requested Breakmodel to be disabled (For example a models config)
     smandelete  = False  # Whether stories can be deleted from inside the browser
     smanrename  = False  # Whether stories can be renamed from inside the browser
@@ -391,7 +391,7 @@ def device_list(n_layers, primary=None, selected=None):
 def device_config(config):
     global breakmodel, generator
     import breakmodel
-    n_layers = config.num_layers if hasattr(config, "num_layers") else config.n_layer
+    n_layers = utils.num_layers(config)
     if(args.breakmodel_gpulayers is not None):
         try:
             breakmodel.gpu_blocks = list(map(int, args.breakmodel_gpulayers.split(',')))
@@ -464,7 +464,7 @@ def device_config(config):
     # If all layers are on the same device, use the old GPU generation mode
     while(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] == 0):
         breakmodel.gpu_blocks.pop()
-    if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, config.num_layers if hasattr(config, "num_layers") else config.n_layer)):
+    if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, utils.num_layers(config))):
         vars.breakmodel = False
         vars.usegpu = True
         vars.gpu_device = len(breakmodel.gpu_blocks)-1
@@ -496,22 +496,33 @@ def move_model_to_devices(model):
             model.lm_head.to(breakmodel.primary_device)
         if(hasattr(model.transformer, 'wpe')):
             model.transformer.wpe.to(breakmodel.primary_device)
-    else:
+    elif(not hasattr(model.model, "decoder")):
         model.model.embed_tokens.to(breakmodel.primary_device)
         model.model.layer_norm.to(breakmodel.primary_device)
         model.lm_head.to(breakmodel.primary_device)
         model.model.embed_positions.to(breakmodel.primary_device)
+    else:
+        model.model.decoder.embed_tokens.to(breakmodel.primary_device)
+        if(model.model.decoder.project_in is not None):
+            model.model.decoder.project_in.to(breakmodel.primary_device)
+        if(model.model.decoder.project_out is not None):
+            model.model.decoder.project_out.to(breakmodel.primary_device)
+        model.model.decoder.embed_positions.to(breakmodel.primary_device)
     gc.collect()
     GPTNeoModel.forward = breakmodel.new_forward_neo
     if("GPTJModel" in globals()):
         GPTJModel.forward = breakmodel.new_forward_neo # type: ignore
     if("XGLMModel" in globals()):
         XGLMModel.forward = breakmodel.new_forward_xglm # type: ignore
+    if("OPTDecoder" in globals()):
+        OPTDecoder.forward = breakmodel.new_forward_opt # type: ignore
     generator = model.generate
     if(hasattr(model, "transformer")):
         breakmodel.move_hidden_layers(model.transformer)
-    else:
+    elif(not hasattr(model.model, "decoder")):
         breakmodel.move_hidden_layers(model.model, model.model.layers)
+    else:
+        breakmodel.move_hidden_layers(model.model.decoder, model.model.decoder.layers)
 
 #==================================================================#
 # Allow the models to override some settings
@@ -911,7 +922,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     loadsettings()
     print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
     vars.hascuda = torch.cuda.is_available()
-    vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm") and not vars.nobreakmodel
+    vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm", "opt") and not vars.nobreakmodel
     if(args.breakmodel is not None and args.breakmodel):
         print("WARNING: --breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --breakmodel_gpulayers is used (see --help for details).", file=sys.stderr)
     if(args.breakmodel_layers is not None):
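For reference, the strings checked by the bmsupported line appear to be Hugging Face model_type values; a quick way to confirm what a checkpoint reports, with an illustrative model name not taken from this commit:

# Hedged sketch: inspect the model_type string that the bmsupported check matches against.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("facebook/opt-350m")  # illustrative checkpoint
print(config.model_type)  # prints "opt", which the updated tuple above now accepts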
@@ -1123,6 +1134,10 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
             globals()[m] = getattr(__import__("transformers"), m)
         except:
             pass
+    try:
+        from transformers.models.opt.modeling_opt import OPTDecoder
+    except:
+        pass
     import transformers.generation_utils
     from transformers import __version__ as transformers_version
 
@@ -1253,8 +1268,10 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
             input_ids.clamp_(max=self.config.vocab_size-1)
             if(hasattr(self, "transformer")):
                 inputs_embeds = self.transformer.wte(input_ids)
-            else:
+            elif(not hasattr(model.model, "decoder")):
                 inputs_embeds = self.model.embed_tokens(input_ids)
+            else:
+                inputs_embeds = self.model.decoder.embed_tokens(input_ids)
             if(vars.sp is not None):
                 vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
                 inputs_embeds = torch.where(
@@ -1262,14 +1279,14 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                     vars.sp[shifted_input_ids.clamp(min=0)],
                     inputs_embeds,
                 )
-            if(not hasattr(self, "transformer")):
+            if(hasattr(self.model, "embed_scale")):
                 inputs_embeds *= self.model.embed_scale
             kwargs['inputs_embeds'] = inputs_embeds
             return old_forward(self, *args, **kwargs)
         cls.forward = new_causallm_forward
     for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
         patch_causallm(cls)
-    for c in ("GPTJForCausalLM", "XGLMForCausalLM"):
+    for c in ("GPTJForCausalLM", "XGLMForCausalLM", "OPTForCausalLM"):
         try:
             patch_causallm(getattr(__import__("transformers"), c))
         except:
@@ -1429,6 +1446,12 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
 
     def get_hidden_size_from_model(model):
+        try:
+            return int(model.model.decoder.project_in.in_features)
+        except:
+            try:
+                return int(model.model.decoder.embed_tokens.out_features)
+            except:
                 try:
                     return int(model.transformer.hidden_size)
                 except:
@@ -1490,7 +1513,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
             import shutil
             shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
         print("\n", flush=True)
-        with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(model_config.num_layers if hasattr(model_config, "num_layers") else model_config.n_layer) if vars.lazy_load else None, dematerialized_modules=True):
+        with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(utils.num_layers(model_config)) if vars.lazy_load else None, dematerialized_modules=True):
             if(vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
                 lowmem = {}
             if(os.path.isdir(vars.custmodpth)):
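To tie the aiserver.py changes together: breakmodel mode is driven by --breakmodel_gpulayers, whose comma-separated values become breakmodel.gpu_blocks (see the device_config hunk above), and whatever decoder layers are left over stay in system RAM. A minimal standalone sketch, with hypothetical layer counts:

# Minimal sketch of the gpu_blocks / ram_blocks split (all numbers hypothetical).
breakmodel_gpulayers = "20,12"                                 # e.g. the value passed via --breakmodel_gpulayers
gpu_blocks = list(map(int, breakmodel_gpulayers.split(',')))   # [20, 12] -> 20 layers on GPU 0, 12 on GPU 1
n_layers = 40                                                  # e.g. total decoder layers reported by the config
ram_blocks = n_layers - sum(gpu_blocks)                        # 8 layers kept in system RAM
print(gpu_blocks, ram_blocks)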
breakmodel.py (182 lines changed)
@@ -633,11 +633,11 @@ def new_forward_xglm(
             layer_outputs = decoder_layer(
                 hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
                 attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask,
-                encoder_hidden_states=encoder_hidden_states.to(device) if encoder_hidden_states is not None else None,
-                encoder_attention_mask=encoder_attention_mask.to(device) if encoder_attention_mask is not None else None,
-                layer_head_mask=((head_mask[idx].to(device) if head_mask[idx] is not None else None) if head_mask is not None else None),
+                encoder_hidden_states=encoder_hidden_states.to(device) if breakmodel and encoder_hidden_states is not None else encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask.to(device) if breakmodel and encoder_attention_mask is not None else encoder_attention_mask,
+                layer_head_mask=((head_mask[idx].to(device) if breakmodel and head_mask[idx] is not None else head_mask[idx]) if head_mask is not None else None),
                 cross_attn_layer_head_mask=(
-                    (cross_attn_head_mask[idx].to(device) if cross_attn_head_mask[idx] is not None else None) if cross_attn_head_mask is not None else None
+                    (cross_attn_head_mask[idx].to(device) if breakmodel and cross_attn_head_mask[idx] is not None else cross_attn_head_mask[idx]) if cross_attn_head_mask is not None else None
                 ),
                 past_key_value=tuple(v.to(device) for v in past_key_value if v is not None) if breakmodel and past_key_value is not None and i >= ram_blocks and len(past_key_value) and past_key_value[0].device.index != device else past_key_value,
                 output_attentions=output_attentions,
@@ -686,3 +686,177 @@ def new_forward_xglm(
         attentions=all_self_attns,
         cross_attentions=all_cross_attentions,
     )
+
+
+def new_forward_opt(
+    self,
+    input_ids=None,
+    attention_mask=None,
+    head_mask=None,
+    past_key_values=None,
+    inputs_embeds=None,
+    use_cache=None,
+    output_attentions=None,
+    output_hidden_states=None,
+    return_dict=None,
+):
+    assert len(gpu_blocks) <= torch.cuda.device_count()
+    assert sum(gpu_blocks) <= len(self.layers)
+    ram_blocks = len(self.layers) - sum(gpu_blocks)
+    cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
+
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # retrieve input_ids and inputs_embeds
+    if input_ids is not None and inputs_embeds is not None:
+        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+    elif input_ids is not None:
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+    elif inputs_embeds is not None:
+        input_shape = inputs_embeds.size()[:-1]
+    else:
+        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+    past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+    if inputs_embeds is None:
+        if breakmodel:
+            input_ids = input_ids.to(primary_device)
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    # embed positions
+    if breakmodel:
+        inputs_embeds = inputs_embeds.to(primary_device)
+    if attention_mask is None:
+        attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
+
+    positions = self.embed_positions(attention_mask)[:, past_key_values_length:, :]
+    if breakmodel:
+        positions = positions.to(primary_device)
+
+    attention_mask = self._prepare_decoder_attention_mask(
+        attention_mask, input_shape, inputs_embeds, past_key_values_length
+    )
+
+    if self.project_in is not None:
+        inputs_embeds = self.project_in(inputs_embeds)
+
+    hidden_states = inputs_embeds + positions
+
+    hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = () if use_cache else None
+
+    if breakmodel and ram_blocks:
+        copystream = torch.cuda.Stream(device=primary_device, priority=-1)
+
+    # check if head_mask has a correct number of layers specified if desired
+    for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
+        if attn_mask is not None:
+            if attn_mask.size()[0] != (len(self.layers)):
+                raise ValueError(
+                    f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                    f" {head_mask.size()[0]}."
+                )
+
+    for idx, decoder_layer in enumerate(self.layers):
+        i = idx
+        if breakmodel:
+            if i in range(ram_blocks):
+                index1 = (i+1)%ram_blocks
+                for param1,param2 in zip(self.layers[index1].parameters(),self.layers[(i-1)%ram_blocks].parameters()):
+                    param1.data = param2.data
+                for param1,param2 in zip(self.layers[index1].parameters(),self.extrastorage[index1].parameters()):
+                    with torch.cuda.stream(copystream):
+                        torch.cuda.comm.broadcast(param2.data,out = [param1.data])
+
+        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+        dropout_probability = random.uniform(0, 1)
+        if self.training and (dropout_probability < self.layerdrop):
+            continue
+
+        past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+        if self.gradient_checkpointing and self.training:
+
+            if use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    # None for past_key_value
+                    return module(*inputs, output_attentions, None)
+
+                return custom_forward
+
+            layer_outputs = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(decoder_layer),
+                hidden_states,
+                attention_mask,
+                head_mask[idx] if head_mask is not None else None,
+                None,
+            )
+        else:
+            if breakmodel:
+                device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks)
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask,
+                layer_head_mask=((head_mask[idx].to(device) if breakmodel and head_mask[idx] is not None else head_mask[idx]) if head_mask is not None else None),
+                past_key_value=tuple(v.to(device) for v in past_key_value if v is not None) if breakmodel and past_key_value is not None and i >= ram_blocks and len(past_key_value) and past_key_value[0].device.index != device else past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+        if breakmodel:
+            if i in range(ram_blocks):
+                torch.cuda.synchronize()
+                torch.cuda.empty_cache()
+
+    if breakmodel:
+        if ram_blocks:
+            del copystream
+        torch.cuda.empty_cache()
+        hidden_states = hidden_states.to(primary_device)
+    if self.project_out is not None:
+        hidden_states = self.project_out(hidden_states)
+    if breakmodel:
+        hidden_states = hidden_states.to(primary_device)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = next_decoder_cache if use_cache else None
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
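A small standalone sketch of the layer-to-device mapping used in the loop above, the same scheme new_forward_neo and new_forward_xglm use (the block sizes below are hypothetical):

# Standalone sketch of the layer -> device mapping (values hypothetical).
import bisect
import itertools

gpu_blocks = [2, 3]                 # 2 layers on GPU 0, 3 layers on GPU 1
ram_blocks = 1                      # 1 layer streamed from system RAM via the primary device
cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))  # (2, 5)

for i in range(ram_blocks + sum(gpu_blocks)):
    device = "primary" if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks)
    print(i, device)                # 0 -> primary, 1-2 -> GPU 0, 3-5 -> GPU 1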
utils.py (6 lines changed)
@@ -135,6 +135,12 @@ def decodenewlines(txt):
         return txt.replace("</s>", '\n')
     return txt
 
+#==================================================================#
+# Returns number of layers given an HF model config
+#==================================================================#
+def num_layers(config):
+    return config.num_layers if hasattr(config, "num_layers") else config.n_layer if hasattr(config, "n_layer") else config.num_hidden_layers
+
 #==================================================================#
 # Downloads huggingface checkpoints using aria2c if possible
 #==================================================================#
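A short usage sketch of the new helper; which attribute each config class exposes is inferred from the chain of hasattr checks above, and the checkpoint name is illustrative only:

# Hedged usage sketch of utils.num_layers (checkpoint name illustrative).
from transformers import AutoConfig
import utils

config = AutoConfig.from_pretrained("facebook/opt-350m")  # OPT configs expose num_hidden_layers
print(utils.num_layers(config))                           # GPT-Neo exposes num_layers; GPT-2/GPT-J expose n_layer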