import time
import os, random
import math, pickle
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from data import gigaspeech
from models import voicecraft
from .trainer_utils import DistributedDynamicBatchSampler, StatefulDistributedSampler, AverageMeter, print_model_info
from .optim import ScaledAdam, Eden
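
# Per-process training driver for VoiceCraft. Builds the model, the
# GigaSpeech dataloaders, the optimizer/scheduler, and the AMP grad scaler,
# wraps the model in DistributedDataParallel, and runs the
# train/validate/checkpoint cycle. One Trainer is created per GPU process;
# rank 0 additionally owns TensorBoard logging, checkpointing, and early
# stopping.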
class Trainer:
    def __init__(self, args, world_size, rank):
        self.start_time = time.time()
        self.args = args
        self.world_size, self.rank = world_size, rank
        self.device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
        if self.rank == 0:
            self.writer = SummaryWriter(args.exp_dir)
        self.seed_everything(seed=self.args.seed)

        self.meters = self._setup_meters()
        self.progress, self.total_progress = self._setup_progress()
        self.model, self.trainables, self.optim_states, self.scheduler_states = self._setup_models()
        self.train_dataset_length, self.train_sampler, self.train_loader, self.valid_loader = self._setup_dataloader()
        if self.args.num_steps is not None:
            self.total_step = self.args.num_steps
            self.args.num_epochs = math.ceil(self.total_step / math.floor(self.train_dataset_length / self.args.batch_size)) if not self.args.dynamic_batching else None
        else:
            self.total_step = int(math.floor(self.train_dataset_length / self.args.batch_size)) * self.args.num_epochs
        self.optimizer, self.scheduler = self._setup_optimizer()
        self.scaler = torch.cuda.amp.GradScaler()
        self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.rank], find_unused_parameters=False)

        if self.rank == 0:
            self.early_stop_accu_steps = 0
            if self.args.dynamic_batching:
                logging.info(f"max number of tokens per GPU in a training batch: {self.args.max_num_tokens}, max number of tokens per GPU in an inference batch: {self.args.val_max_num_tokens}")
            else:
                logging.info(f"batch size (summed over all GPUs): {self.args.batch_size}")
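
    # Main training loop. Each pass over train_loader is one epoch; every
    # batch is split into gradient-accumulation micro-batches, loss/accuracy
    # statistics are all-reduced across ranks, batches whose loss is NaN on
    # any rank are skipped, and the loop periodically logs, validates, and
    # checkpoints.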
    def train(self):
        flag = True
        skip_flag = False
        data_start_time = time.time()
        while flag:
            self.train_sampler.set_epoch(self.progress['epoch'])
            for i, batch in enumerate(self.train_loader):
                data_end_time = time.time()
                self.model.train()
                if self.progress['step'] > self.total_step:
                    flag = False
                    self.validate_and_save()
                    if self.rank == 0:
                        self.writer.close()
                    break
                if isinstance(self.scheduler, Eden):
                    self.scheduler.step_epoch(self.progress['step'] // self.args.pseudo_epoch_size + 1)
                if self.args.optimizer_name == "ScaledAdam":
                    cur_lr = self.scheduler.get_last_lr()[0]
                else:
                    lrs = [param_group['lr'] for param_group in self.optimizer.param_groups]
                    assert lrs[0] == lrs[1]
                    cur_lr = lrs[0]
                if self.rank == 0 and self.progress['step'] % self.args.tb_write_every_n_steps == 0:
                    self.writer.add_scalar("train/lr", cur_lr, self.progress['step'])
                    if hasattr(self, "wandb"):  # optional: nothing in this file creates self.wandb
                        self.wandb.log({"train/lr": cur_lr}, step=self.progress['step'])

                all_inds = list(range(len(batch['y'])))
                sum_losses = 0
                sum_top10acc = 0
                sum_ntoken = 0
                sum_top10acc_cbi = [0 for _ in range(self.args.n_codebooks)]
                for j in range(self.args.gradient_accumulation_steps):
                    cur_ind = all_inds[j::self.args.gradient_accumulation_steps]
                    cur_batch = {key: batch[key][cur_ind] for key in batch}
                    with torch.cuda.amp.autocast(dtype=torch.float16 if self.args.precision == "float16" else torch.float32):
                        out = self.model(cur_batch)
                    if out is None:
                        continue
                    record_loss = out['loss'].detach().to(self.rank)
                    top10acc = out['top10acc'].to(self.rank)
                    effective_ntoken = out['effective_ntoken'].to(self.rank)
                    is_nan = torch.tensor(int(torch.isnan(record_loss).any()), dtype=torch.float32, device=self.rank)

                    dist.all_reduce(record_loss, op=dist.ReduceOp.SUM)
                    dist.all_reduce(top10acc, op=dist.ReduceOp.SUM)
                    dist.all_reduce(effective_ntoken, op=dist.ReduceOp.SUM)
                    dist.all_reduce(is_nan, op=dist.ReduceOp.SUM)

                    # check if loss is nan
                    if is_nan.item() > 0:
                        logging.info(f"loss at step {self.progress['step']} is nan, therefore skip this batch")
                        skip_flag = True
                        continue

                    sum_losses += record_loss.item()
                    sum_top10acc += top10acc.item()
                    sum_ntoken += effective_ntoken.item()

                    if 'top10acc_by_codebook' in out:
                        for cb in range(self.args.n_codebooks):
                            top10acc_cbi = out['top10acc_by_codebook'][cb]
                            dist.all_reduce(top10acc_cbi, op=dist.ReduceOp.SUM)
                            sum_top10acc_cbi[cb] += top10acc_cbi.item()

                    if self.rank == 0:
                        average_loss = sum_losses / sum_ntoken
                        average_top10acc = sum_top10acc / sum_ntoken
                        self.meters['train_loss'].update(average_loss, batch['x'].shape[0] * self.world_size)
                        self.meters['train_top10acc'].update(average_top10acc, batch['x'].shape[0] * self.world_size)
                        average_top10acc_cbi = [sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks for cb in range(self.args.n_codebooks)]
                        for cb in range(self.args.n_codebooks):
                            self.meters[f'train_top10acc_cb{cb+1}'].update(average_top10acc_cbi[cb], batch['x'].shape[0] * self.world_size)
                        if self.progress['step'] % self.args.tb_write_every_n_steps == 0:
                            self.writer.add_scalar('train/loss', average_loss, self.progress['step'])
                            self.writer.add_scalar('train/top10acc', average_top10acc, self.progress['step'])
                            self.writer.add_scalar("train/ntokens", sum_ntoken, self.progress['step'])
                            for cb in range(self.args.n_codebooks):
                                self.writer.add_scalar(f'train/top10acc_cb{cb+1}', average_top10acc_cbi[cb], self.progress['step'])

                    if self.args.optimizer_name == "ScaledAdam":
                        self.scaler.scale(out['loss']).backward()
                    else:
                        self.scaler.scale(out['loss'] / out['effective_ntoken']).backward()

                if skip_flag:
                    self.optimizer.zero_grad()
                    skip_flag = False
                    continue

                if self.args.optimizer_name != "ScaledAdam":
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip_val)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()

                if self.args.optimizer_name == "ScaledAdam":
                    self.scheduler.step_batch(self.progress['step'])
                else:
                    self.scheduler.step()

                if self.rank == 0:
                    self.meters['data_time'].update(data_end_time - data_start_time)
                    self.meters['train_time'].update(time.time() - data_end_time)
                    if self.progress['step'] % self.args.tb_write_every_n_steps == 0:
                        self.writer.add_scalar("train/data_time", data_end_time - data_start_time, self.progress['step'])
                        self.writer.add_scalar("train/train_time", time.time() - data_end_time, self.progress['step'])

                # logging
                if self.progress['step'] % self.args.print_every_n_steps == 0:
                    log_out = {}
                    log_out['cur_epoch'] = f"{self.progress['epoch']}/{self.args.num_epochs}" if self.args.num_epochs is not None else f"{self.progress['epoch']}"
                    log_out['cur_step'] = f"{int(self.progress['cur_step']+1)}"
                    log_out['total_step'] = f"{self.progress['step']}/{self.args.num_steps}"
                    log_out['lr'] = f"{cur_lr:.7f}"
                    log_out['ntokens'] = f"{sum_ntoken}"
                    for key in self.meters:
                        if self.meters[key].val != 0 or self.meters[key].avg != 0:
                            log_out[key] = f"{self.meters[key].val:.4f} ({self.meters[key].avg:.4f})" if isinstance(self.meters[key].val, float) else f"{self.meters[key].val}"
                    logging.info(log_out)
                    if np.isnan(self.meters['train_loss'].avg):
                        logging.warning("training diverged...")
                        raise RuntimeError("training diverged...")

                # validation and save models
                if self.progress['step'] % self.args.val_every_n_steps == 0:
                    dist.barrier()
                    self.validate_and_save()

                self.progress['step'] += 1
                self.progress['cur_step'] += 1
                data_start_time = time.time()

            self.progress['epoch'] += 1
            self.progress['cur_step'] = 0  # reset cur_step to be 0

        dist.destroy_process_group()
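
    # Runs validation, applies the early-stopping rule on rank 0, and saves
    # both the rolling checkpoint (bundle.pth) and, on improvement, the
    # best-so-far checkpoint (best_bundle.pth) plus the progress pickle.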
    def validate_and_save(self):
        self.model.eval()
        score = self.validate(self.valid_loader)

        if self.rank == 0:
            if self.args.early_stop_threshold > 0:
                if self.progress['best_score'] - score < self.args.early_stop_threshold:
                    self.early_stop_accu_steps += self.args.val_every_n_steps
                    if self.early_stop_accu_steps >= self.args.early_stop_step - 1:
                        logging.info(f"early stop based on self.args.early_stop_threshold: {self.args.early_stop_threshold}, and self.args.early_stop_step: {self.args.early_stop_step}")
                        logging.info(f"best validation score at step: {self.progress['best_step']}, and the score is {self.progress['best_score']:.4f}")
                        dist.destroy_process_group()
                        raise RuntimeError("early stop")
                else:
                    self.early_stop_accu_steps = 0

            if score < self.progress['best_score']:
                self.progress['best_step'] = self.progress['step']
                self.progress['best_score'] = score
                save_path = os.path.join(self.args.exp_dir, "best_bundle.pth")
                torch.save(
                    {
                        "model": self.model.module.state_dict(),
                        "optimizer": self.optimizer.state_dict(),
                        "scheduler": self.scheduler.state_dict(),
                        "config": self.args,
                        "phn2num": self.train_loader.dataset.phn2num
                    }, save_path
                )
                logging.info(f"save *best* models at {save_path} at global step {self.progress['step']}")

            self._save_progress()
            save_path = os.path.join(self.args.exp_dir, "bundle.pth")
            torch.save(
                {
                    "model": self.model.module.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "scheduler": self.scheduler.state_dict(),
                    "config": self.args,
                    "phn2num": self.train_loader.dataset.phn2num
                }, save_path
            )
            logging.info(f"save models, indices, acc and other statistics at {save_path} and {self.args.exp_dir}/progress.pkl at global step {self.progress['step']}")

        dist.barrier()
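
    # Accumulates loss and top-10 accuracy over the validation set, sums the
    # totals across ranks with all_reduce, and returns the per-token loss on
    # rank 0 (None on all other ranks).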
    def validate(self, valid_loader=None, hide_progress=True):
        if valid_loader is None:
            valid_loader = self.valid_loader
        self.model.eval()

        start_val_time = time.time()
        sum_losses = 0
        sum_top10acc = 0
        sum_ntoken = 0
        sum_top10acc_cbi = [0 for _ in range(self.args.n_codebooks)]

        with torch.no_grad():
            for i, batch in enumerate(tqdm(valid_loader, disable=hide_progress)):
                out = self.model(batch)
                sum_losses += out['loss']
                sum_top10acc += out['top10acc']
                sum_ntoken += out['effective_ntoken']
                if 'top10acc_by_codebook' in out:
                    for cb in range(self.args.n_codebooks):
                        sum_top10acc_cbi[cb] += out['top10acc_by_codebook'][cb]

        dist.all_reduce(sum_losses, op=dist.ReduceOp.SUM)
        dist.all_reduce(sum_top10acc, op=dist.ReduceOp.SUM)
        dist.all_reduce(sum_ntoken, op=dist.ReduceOp.SUM)
        if 'top10acc_by_codebook' in out:
            for cb in range(self.args.n_codebooks):
                dist.all_reduce(sum_top10acc_cbi[cb], op=dist.ReduceOp.SUM)

        if self.rank == 0:
            val_loss = sum_losses / sum_ntoken
            val_top10acc = sum_top10acc / sum_ntoken
            # logging
            self.meters['val_loss'].update(val_loss)
            logging.info(f"val loss: {val_loss:.5f}")
            self.writer.add_scalar("val/loss", val_loss, self.progress['step'])

            self.meters['val_top10acc'].update(val_top10acc)
            logging.info(f"val top10acc: {val_top10acc:.5f}")
            self.writer.add_scalar("val/top10acc", val_top10acc, self.progress['step'])

            for cb in range(self.args.n_codebooks):
                average_top10acc_cbi = sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks
                self.meters[f'val_top10acc_cb{cb+1}'].update(average_top10acc_cbi)
                self.writer.add_scalar(f'val/top10acc_cb{cb+1}', average_top10acc_cbi, self.progress['step'])

            logging.info(f"validation takes: {time.time() - start_val_time:.2f}s")
            logging.info(f"Step [{self.progress['step']}/{self.total_step}]\tTime elapsed {(time.time() - self.start_time)/3600.:.2f}h, Val Loss: {val_loss:.4f}, Val Top10Acc: {val_top10acc:.4f}")
            return val_loss.item()
        else:
            return None
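
    # One AverageMeter per tracked statistic, including a top-10 accuracy
    # meter for each of the n_codebooks codec streams.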
    def _setup_meters(self):
        meters = {}
        meter_names = ['train_loss', 'val_loss', 'train_top10acc', 'val_top10acc', 'data_time', 'train_time']
        meter_names += ['train_dur_loss', 'train_dur_acc', 'val_dur_loss', 'val_dur_acc']
        meter_names += [f'train_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)]
        meter_names += [f'val_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)]
        for name in meter_names:
            meters[name] = AverageMeter()
        return meters
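
    # Fresh progress counters, or the counters recorded in progress.pkl when
    # resuming a previous run.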
    def _setup_progress(self):
        progress = {}
        progress['best_step'] = 1
        progress['best_score'] = np.inf  # this records loss value
        progress['step'] = 1
        progress['epoch'] = 1
        progress['cur_step'] = 0  # step in the current epoch, for resuming the sampler
        total_progress = []
        # if self.args.resume or self.args.validate:
        if self.args.resume:
            progress_pkl = "%s/progress.pkl" % self.args.exp_dir
            with open(progress_pkl, "rb") as f:
                total_progress = pickle.load(f)
                progress['best_step'], progress['best_score'], progress['step'], progress['epoch'], progress['cur_step'], _ = total_progress[-1]
            if self.rank == 0:
                logging.info("\nResume training from:")
                logging.info("epoch = %s" % progress['epoch'])
                logging.info("cur_step = %s" % progress['cur_step'])
                logging.info("step = %s" % progress['step'])
                logging.info("best_step = %s" % progress['best_step'])
                logging.info("best_score = %s" % progress['best_score'])
        return progress, total_progress
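
    # Appends a snapshot of the counters to progress.pkl so a later run can
    # restore them and resume the stateful sampler mid-epoch.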
    def _save_progress(self):
        self.total_progress.append([self.progress['best_step'], self.progress['best_score'], int(self.progress['step']+1), self.progress['epoch'], int(self.progress['cur_step']+1), time.time() - self.start_time])
        with open("%s/progress.pkl" % self.args.exp_dir, "wb") as f:
            pickle.dump(self.total_progress, f)
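
    # GigaSpeech loaders. With dynamic batching each batch is capped by a
    # per-GPU token budget; otherwise the global batch size is split evenly
    # across ranks.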
    def _setup_dataloader(self):
        assert self.args.dataset == 'gigaspeech', "only gigaspeech is supported for now"
        train_dataset, val_dataset = gigaspeech.dataset(self.args, 'train'), gigaspeech.dataset(self.args, 'validation')

        if self.args.dynamic_batching:
            train_sampler = DistributedDynamicBatchSampler(train_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=train_dataset.lengths_list, verbose=True, epoch=0)
            valid_sampler = DistributedDynamicBatchSampler(val_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=val_dataset.lengths_list, verbose=True, epoch=0)
        else:
            train_sampler = StatefulDistributedSampler(train_dataset, self.args.batch_size // self.world_size, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True)
            valid_sampler = DistributedSampler(val_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=False, seed=self.args.seed, drop_last=False)

        if self.progress['step'] > 1:
            train_sampler.set_epoch_resume(self.progress['epoch'], self.progress['cur_step'])

        if self.args.dynamic_batching:
            train_loader = torch.utils.data.DataLoader(train_dataset,
                            batch_sampler=train_sampler,
                            num_workers=self.args.num_workers // self.world_size,
                            collate_fn=train_dataset.collate, persistent_workers=True
                            )
            valid_loader = torch.utils.data.DataLoader(val_dataset,
                            batch_sampler=valid_sampler,
                            num_workers=self.args.num_workers // self.world_size,
                            collate_fn=val_dataset.collate, persistent_workers=True
                            )
        else:
            train_loader = torch.utils.data.DataLoader(train_dataset,
                            batch_size=self.args.batch_size // self.world_size, sampler=train_sampler,
                            num_workers=self.args.num_workers // self.world_size,
                            collate_fn=train_dataset.collate, persistent_workers=True
                            )
            valid_loader = torch.utils.data.DataLoader(val_dataset,
                            batch_size=self.args.batch_size // self.world_size, sampler=valid_sampler,
                            num_workers=self.args.num_workers // self.world_size,
                            collate_fn=val_dataset.collate, persistent_workers=True
                            )
        return len(train_dataset), train_sampler, train_loader, valid_loader
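
    # Builds VoiceCraft, restores weights when resuming (or from
    # args.load_model_from), and, for the AdamW path, groups parameters so
    # biases, embeddings, and layernorm weights are excluded from weight
    # decay.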
    def _setup_models(self):
        model = voicecraft.VoiceCraft(self.args)

        if self.rank == 0:
            logging.info(model)
            logging.info("model parameters")
            print_model_info(model)

        if self.progress['step'] > 1:
            bundle = torch.load(os.path.join(self.args.exp_dir, "bundle.pth"), map_location="cpu")
            model.load_state_dict(bundle['model'])
            optim_states = bundle['optimizer']
            scheduler_states = bundle['scheduler']
            if self.rank == 0:
                logging.info("loaded parameters and data indices from epoch %d, global step %d" % (self.progress['epoch'], self.progress['step']))
            del bundle['model']
        else:
            optim_states = None
            scheduler_states = None

        if self.args.load_model_from is not None and self.progress['step'] <= 1:
            sd = torch.load(self.args.load_model_from, map_location="cpu")['model']
            model.load_state_dict(sd)
            del sd

        if self.args.optimizer_name == "ScaledAdam":
            trainables = [p for p in model.parameters() if p.requires_grad]
        else:
            no_decay = [".bias", ".audio_embeddings.weight", ".text_embeddings.weight", ".norm.weight", ".norm1.weight", ".norm2.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
                    "weight_decay": 0.0,
                },
            ]
            if len(optimizer_grouped_parameters[1]['params']) == 0:
                logging.info("no bias, embedding, or layernorm parameters matched the no-decay list, which should not happen; check the model's parameter names")
                trainables = optimizer_grouped_parameters[0]
            else:
                trainables = optimizer_grouped_parameters
        model.to(self.device)

        return model, trainables, optim_states, scheduler_states
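
    # ScaledAdam is paired with the Eden scheduler; AdamW gets a linear
    # warmup over warmup_fraction * total_step steps followed by a linear
    # decay to zero at total_step.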
    def _setup_optimizer(self):
        if self.args.optimizer_name == "ScaledAdam":
            parameters_names = []
            parameters_names.append([n for n, p in self.model.named_parameters() if p.requires_grad])
            optimizer = ScaledAdam(
                self.trainables,
                lr=self.args.lr,
                betas=(0.9, 0.95),
                clipping_scale=2.0,
                parameters_names=parameters_names,
                show_dominant_parameters=False,
                clipping_update_period=self.args.clipping_update_period,
            )
            scheduler = Eden(optimizer, self.args.reduce_lr_start_step, self.args.reduce_lr_start_epoch, warmup_batches=self.total_step * self.args.warmup_fraction)
        else:
            optimizer = AdamW(self.trainables, lr=self.args.lr)
            warmup_steps = self.total_step * self.args.warmup_fraction

            def lr_lambda(current_step: int):
                if current_step < warmup_steps:
                    return float(current_step) / float(max(1, warmup_steps))
                return max(
                    0.0, float(self.total_step - current_step) / float(max(1, self.total_step - warmup_steps))
                )

            scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)

        # if resume
        if self.progress['step'] > 1:
            optimizer.load_state_dict(self.optim_states)
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()
            del self.optim_states
            scheduler.load_state_dict(self.scheduler_states)

        optimizer.zero_grad()
        return optimizer, scheduler
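
    # Seeds Python, NumPy, and torch (CPU and CUDA) RNGs and forces
    # deterministic cuDNN kernels so runs are reproducible.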
    def seed_everything(self, seed=1):
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
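
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the launch flow and helper names below are
# assumptions, not part of this file). Trainer expects a process group to
# exist already, e.g. when launched with torchrun:
#
#   import torch
#   import torch.distributed as dist
#
#   dist.init_process_group(backend="nccl")
#   rank, world_size = dist.get_rank(), dist.get_world_size()
#   torch.cuda.set_device(rank)
#   args = parse_args()  # hypothetical config parser; must provide exp_dir,
#                        # seed, lr, batch_size, n_codebooks, etc.
#   Trainer(args, world_size, rank).train()
# ---------------------------------------------------------------------------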