import time
import os, random
import torch
import math, pickle
from tqdm import tqdm
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
import torch.nn as nn
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from torch.utils.data.distributed import DistributedSampler
import logging
from data import gigaspeech
from models import voicecraft
from .trainer_utils import DistributedDynamicBatchSampler, StatefulDistributedSampler, AverageMeter, print_model_info
from .optim import ScaledAdam, Eden


class Trainer:
    def __init__(self, args, world_size, rank):
        self.start_time = time.time()
        self.args = args
        self.world_size, self.rank = world_size, rank
        self.device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
        if self.rank == 0:
            self.writer = SummaryWriter(args.exp_dir)
        self.seed_everything(seed=self.args.seed)
        self.meters = self._setup_meters()
        self.progress, self.total_progress = self._setup_progress()

        self.model, self.trainables, self.optim_states, self.scheduler_states = self._setup_models()

        self.train_dataset_length, self.train_sampler, self.train_loader, self.valid_loader = self._setup_dataloader()
        if self.args.num_steps is not None:
            self.total_step = self.args.num_steps
            self.args.num_epochs = math.ceil(self.total_step / math.floor(self.train_dataset_length / self.args.batch_size)) if not self.args.dynamic_batching else None
        else:
            self.total_step = int(math.floor(self.train_dataset_length / self.args.batch_size)) * self.args.num_epochs

        self.optimizer, self.scheduler = self._setup_optimizer()
        self.scaler = torch.cuda.amp.GradScaler()
        self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.rank], find_unused_parameters=False)

        if self.rank == 0:
            self.early_stop_accu_steps = 0
            if self.args.dynamic_batching:
                logging.info(f"max number of tokens per GPU in a training batch: {self.args.max_num_tokens}, max number of tokens per GPU in an inference batch: {self.args.val_max_num_tokens}")
            else:
                logging.info(f"batch size (summed over all GPUs): {self.args.batch_size}")

    def train(self):
        flag = True
        skip_flag = False
        data_start_time = time.time()
        while flag:
            self.train_sampler.set_epoch(self.progress['epoch'])
            for i, batch in enumerate(self.train_loader):
                data_end_time = time.time()
                self.model.train()
                if self.progress['step'] > self.total_step:
                    flag = False
                    self.validate_and_save()
                    if self.rank == 0:
                        self.writer.close()
                    break
                if isinstance(self.scheduler, Eden):
                    self.scheduler.step_epoch(self.progress['step'] // self.args.pseudo_epoch_size + 1)
                if self.args.optimizer_name == "ScaledAdam":
                    cur_lr = self.scheduler.get_last_lr()[0]
                else:
                    lrs = [param_group['lr'] for param_group in self.optimizer.param_groups]
                    assert lrs[0] == lrs[1]
                    cur_lr = lrs[0]

                if self.rank == 0 and self.progress['step'] % self.args.tb_write_every_n_steps == 0:
                    self.writer.add_scalar("train/lr", cur_lr, self.progress['step'])
                    # self.wandb is only set if a wandb run was attached elsewhere; guard so
                    # TensorBoard-only runs do not crash on an undefined attribute
                    if getattr(self, "wandb", None) is not None:
                        self.wandb.log({"train/lr": cur_lr}, step=self.progress['step'])

                all_inds = list(range(len(batch['y'])))
                sum_losses = 0
                sum_top10acc = 0
                sum_ntoken = 0
                sum_top10acc_cbi = [0 for _ in range(self.args.n_codebooks)]
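                # Gradient accumulation: the collated batch is split into
                # `gradient_accumulation_steps` interleaved micro-batches via strided
                # indexing (all_inds[j::G]); per-micro-batch losses, top-10 accuracies,
                # and token counts are summed across micro-batches and all ranks before
                # being normalized. The forward pass runs under autocast in float16 when
                # args.precision == "float16", otherwise in float32.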
                for j in range(self.args.gradient_accumulation_steps):
                    cur_ind = all_inds[j::self.args.gradient_accumulation_steps]
                    cur_batch = {key: batch[key][cur_ind] for key in batch}
                    with torch.cuda.amp.autocast(dtype=torch.float16 if self.args.precision == "float16" else torch.float32):
                        out = self.model(cur_batch)

                    record_loss = out['loss'].detach().to(self.rank)
                    top10acc = out['top10acc'].to(self.rank)
                    effective_ntoken = out['effective_ntoken'].to(self.rank)
                    is_nan = torch.tensor(int(torch.isnan(record_loss).any()), dtype=torch.float32, device=self.rank)

                    dist.all_reduce(record_loss, op=dist.ReduceOp.SUM)
                    dist.all_reduce(top10acc, op=dist.ReduceOp.SUM)
                    dist.all_reduce(effective_ntoken, op=dist.ReduceOp.SUM)
                    dist.all_reduce(is_nan, op=dist.ReduceOp.SUM)

                    # check if loss is nan on any rank; if so, skip this batch everywhere
                    if is_nan.item() > 0:
                        logging.info(f"loss at step {self.progress['step']} is nan, therefore skip this batch")
                        skip_flag = True
                        continue

                    sum_losses += record_loss.item()
                    sum_top10acc += top10acc.item()
                    sum_ntoken += effective_ntoken.item()

                    if 'top10acc_by_codebook' in out:
                        for cb in range(self.args.n_codebooks):
                            top10acc_cbi = out['top10acc_by_codebook'][cb]
                            dist.all_reduce(top10acc_cbi, op=dist.ReduceOp.SUM)
                            sum_top10acc_cbi[cb] += top10acc_cbi.item()

                    if self.rank == 0:
                        average_loss = sum_losses / sum_ntoken
                        average_top10acc = sum_top10acc / sum_ntoken
                        self.meters['train_loss'].update(average_loss, batch['x'].shape[0] * self.world_size)
                        self.meters['train_top10acc'].update(average_top10acc, batch['x'].shape[0] * self.world_size)
                        average_top10acc_cbi = [sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks for cb in range(self.args.n_codebooks)]
                        for cb in range(self.args.n_codebooks):
                            self.meters[f'train_top10acc_cb{cb+1}'].update(average_top10acc_cbi[cb], batch['x'].shape[0] * self.world_size)

                        if self.progress['step'] % self.args.tb_write_every_n_steps == 0:
                            self.writer.add_scalar('train/loss', average_loss, self.progress['step'])
                            self.writer.add_scalar('train/top10acc', average_top10acc, self.progress['step'])
                            self.writer.add_scalar("train/ntokens", sum_ntoken, self.progress['step'])
                            for cb in range(self.args.n_codebooks):
                                self.writer.add_scalar(f'train/top10acc_cb{cb+1}', average_top10acc_cbi[cb], self.progress['step'])

                    if self.args.optimizer_name == "ScaledAdam":
                        self.scaler.scale(out['loss']).backward()
                    else:
                        self.scaler.scale(out['loss'] / out['effective_ntoken']).backward()

                if skip_flag:
                    self.optimizer.zero_grad()
                    skip_flag = False
                    continue
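                # AMP bookkeeping: gradients were produced through GradScaler, so on the
                # AdamW path they are unscaled and clipped to args.gradient_clip_val before
                # the scaler performs the optimizer step and updates its loss scale.
                # ScaledAdam applies its own adaptive clipping, so unscale_/clip are skipped.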
                if self.args.optimizer_name != "ScaledAdam":
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip_val)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()

                if self.args.optimizer_name == "ScaledAdam":
                    self.scheduler.step_batch(self.progress['step'])
                else:
                    self.scheduler.step()

                if self.rank == 0:
                    self.meters['data_time'].update(data_end_time - data_start_time)
                    self.meters['train_time'].update(time.time() - data_end_time)
                    if self.progress['step'] % self.args.tb_write_every_n_steps == 0:
                        self.writer.add_scalar("train/data_time", data_end_time - data_start_time, self.progress['step'])
                        self.writer.add_scalar("train/train_time", time.time() - data_end_time, self.progress['step'])

                    # logging
                    if self.progress['step'] % self.args.print_every_n_steps == 0:
                        log_out = {}
                        log_out['cur_epoch'] = f"{self.progress['epoch']}/{self.args.num_epochs}" if self.args.num_epochs is not None else f"{self.progress['epoch']}"
                        log_out['cur_step'] = f"{int(self.progress['cur_step']+1)}"
                        log_out['total_step'] = f"{self.progress['step']}/{self.args.num_steps}"
                        log_out['lr'] = f"{cur_lr:.7f}"
                        log_out['ntokens'] = f"{sum_ntoken}"
                        for key in self.meters:
                            if self.meters[key].val != 0 or self.meters[key].avg != 0:
                                log_out[key] = f"{self.meters[key].val:.4f} ({self.meters[key].avg:.4f})" if isinstance(self.meters[key].val, float) else f"{self.meters[key].val}"
                        logging.info(log_out)
                        if np.isnan(self.meters['train_loss'].avg):
                            logging.warning("training diverged...")
                            raise RuntimeError("training diverged...")

                # validation and save models
                if self.progress['step'] % self.args.val_every_n_steps == 0:
                    dist.barrier()
                    self.validate_and_save()

                self.progress['step'] += 1
                self.progress['cur_step'] += 1

                data_start_time = time.time()
            self.progress['epoch'] += 1
            self.progress['cur_step'] = 0  # reset cur_step to be 0
        dist.destroy_process_group()

    def validate_and_save(self):
        self.model.eval()

        score = self.validate(self.valid_loader)

        if self.rank == 0:
            if self.args.early_stop_threshold > 0:
                if self.progress['best_score'] - score < self.args.early_stop_threshold:
                    self.early_stop_accu_steps += self.args.val_every_n_steps
                    if self.early_stop_accu_steps >= self.args.early_stop_step - 1:
                        logging.info(f"early stop based on self.args.early_stop_threshold: {self.args.early_stop_threshold}, and self.args.early_stop_step: {self.args.early_stop_step}")
                        logging.info(f"best validation score at step: {self.progress['best_step']}, and the score is {self.progress['best_score']:.4f}")
                        dist.destroy_process_group()
                        raise RuntimeError("early stop")
                else:
                    self.early_stop_accu_steps = 0

            if score < self.progress['best_score']:
                self.progress['best_step'] = self.progress['step']
                self.progress['best_score'] = score
                save_path = os.path.join(self.args.exp_dir, "best_bundle.pth")
                torch.save(
                    {
                        "model": self.model.module.state_dict(),
                        "optimizer": self.optimizer.state_dict(),
                        "scheduler": self.scheduler.state_dict(),
                        "config": self.args,
                        "phn2num": self.train_loader.dataset.phn2num
                    }, save_path
                )
                logging.info(f"save *best* models at {save_path} at global step {self.progress['step']}")
            self._save_progress()
            save_path = os.path.join(self.args.exp_dir, "bundle.pth")
            torch.save(
                {
                    "model": self.model.module.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "scheduler": self.scheduler.state_dict(),
                    "config": self.args,
                    "phn2num": self.train_loader.dataset.phn2num
                }, save_path
            )
            logging.info(f"save models, indices, acc and other statistics at {save_path} and {self.args.exp_dir}/progress.pkl at global step {self.progress['step']}")

        dist.barrier()
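
    # Validation runs on every rank: per-rank sums of loss, top-10 accuracy, and token
    # counts are all-reduced, then only rank 0 normalizes, logs to TensorBoard, and
    # returns the scalar validation loss; the other ranks return None.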
    def validate(self, valid_loader=None, hide_progress=True):
        if valid_loader is None:
            valid_loader = self.valid_loader
        self.model.eval()

        start_val_time = time.time()
        sum_losses = 0
        sum_top10acc = 0
        sum_ntoken = 0
        sum_top10acc_cbi = [0 for _ in range(self.args.n_codebooks)]

        with torch.no_grad():
            for i, batch in enumerate(tqdm(valid_loader, disable=hide_progress)):
                out = self.model(batch)
                sum_losses += out['loss']
                sum_top10acc += out['top10acc']
                sum_ntoken += out['effective_ntoken']
                if 'top10acc_by_codebook' in out:
                    for cb in range(self.args.n_codebooks):
                        sum_top10acc_cbi[cb] += out['top10acc_by_codebook'][cb]

        dist.all_reduce(sum_losses, op=dist.ReduceOp.SUM)
        dist.all_reduce(sum_top10acc, op=dist.ReduceOp.SUM)
        dist.all_reduce(sum_ntoken, op=dist.ReduceOp.SUM)
        if 'top10acc_by_codebook' in out:
            for cb in range(self.args.n_codebooks):
                dist.all_reduce(sum_top10acc_cbi[cb], op=dist.ReduceOp.SUM)

        if self.rank == 0:
            val_loss = sum_losses / sum_ntoken
            val_top10acc = sum_top10acc / sum_ntoken
            # logging
            self.meters['val_loss'].update(val_loss)
            logging.info(f"val loss: {val_loss:.5f}")
            self.writer.add_scalar("val/loss", val_loss, self.progress['step'])

            self.meters['val_top10acc'].update(val_top10acc)
            logging.info(f"val top10acc: {val_top10acc:.5f}")
            self.writer.add_scalar("val/top10acc", val_top10acc, self.progress['step'])
            for cb in range(self.args.n_codebooks):
                average_top10acc_cbi = sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks
                self.meters[f'val_top10acc_cb{cb+1}'].update(average_top10acc_cbi)
                self.writer.add_scalar(f'val/top10acc_cb{cb+1}', average_top10acc_cbi, self.progress['step'])

            logging.info(f"validation takes: {time.time() - start_val_time:.2f}s")
            logging.info(f"Step [{self.progress['step']}/{self.total_step}]\t Time elapsed {(time.time() - self.start_time)/3600.:.2f}h, Val Loss: {val_loss:.4f}, Val Top10Acc: {val_top10acc:.4f}")
            return val_loss.item()
        else:
            return None

    def _setup_meters(self):
        meters = {}
        meter_names = ['train_loss', 'val_loss', 'train_top10acc', 'val_top10acc', 'data_time', 'train_time']
        meter_names += ['train_dur_loss', 'train_dur_acc', 'val_dur_loss', 'val_dur_acc']
        meter_names += [f'train_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)]
        meter_names += [f'val_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)]
        for name in meter_names:
            meters[name] = AverageMeter()
        return meters

    def _setup_progress(self):
        progress = {}
        progress['best_step'] = 1
        progress['best_score'] = np.inf  # this records loss value
        progress['step'] = 1
        progress['epoch'] = 1
        progress['cur_step'] = 0  # step in the current epoch, for resuming the sampler
        total_progress = []
        # if self.args.resume or self.args.validate:
        if self.args.resume:
            progress_pkl = "%s/progress.pkl" % self.args.exp_dir
            with open(progress_pkl, "rb") as f:
                total_progress = pickle.load(f)
            progress['best_step'], progress['best_score'], progress['step'], progress['epoch'], progress['cur_step'], _ = total_progress[-1]
            if self.rank == 0:
                logging.info("\nResume training from:")
                logging.info("  epoch = %s" % progress['epoch'])
                logging.info("  cur_step = %s" % progress['cur_step'])
                logging.info("  step = %s" % progress['step'])
                logging.info("  best_step = %s" % progress['best_step'])
                logging.info("  best_score = %s" % progress['best_score'])
        return progress, total_progress

    def _save_progress(self):
        self.total_progress.append([self.progress['best_step'], self.progress['best_score'], int(self.progress['step']+1), self.progress['epoch'], int(self.progress['cur_step']+1), time.time() - self.start_time])
        with open("%s/progress.pkl" % self.args.exp_dir, "wb") as f:
            pickle.dump(self.total_progress, f)
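
    # Data loading: with dynamic batching each GPU receives a token-budget batch built by
    # DistributedDynamicBatchSampler (bounded by args.max_num_tokens / args.val_max_num_tokens);
    # otherwise the global args.batch_size is split evenly across GPUs, and
    # StatefulDistributedSampler lets a resumed run skip the first `cur_step` batches of
    # the interrupted epoch.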
    def _setup_dataloader(self):
        assert self.args.dataset == 'gigaspeech', "only gigaspeech is supported for now"
        train_dataset, val_dataset = gigaspeech.dataset(self.args, 'train'), gigaspeech.dataset(self.args, 'validation')

        if self.args.dynamic_batching:
            train_sampler = DistributedDynamicBatchSampler(train_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=train_dataset.lengths_list, verbose=True, epoch=0)
            valid_sampler = DistributedDynamicBatchSampler(val_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=val_dataset.lengths_list, verbose=True, epoch=0)
        else:
            train_sampler = StatefulDistributedSampler(train_dataset, self.args.batch_size//self.world_size, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True)
            valid_sampler = DistributedSampler(val_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=False, seed=self.args.seed, drop_last=False)

        if self.progress['step'] > 1:
            train_sampler.set_epoch_resume(self.progress['epoch'], self.progress['cur_step'])

        if self.args.dynamic_batching:
            train_loader = torch.utils.data.DataLoader(train_dataset,
                                                       batch_sampler=train_sampler,
                                                       num_workers=self.args.num_workers//self.world_size,
                                                       collate_fn=train_dataset.collate,
                                                       persistent_workers=True)
            valid_loader = torch.utils.data.DataLoader(val_dataset,
                                                       batch_sampler=valid_sampler,
                                                       num_workers=self.args.num_workers//self.world_size,
                                                       collate_fn=val_dataset.collate,
                                                       persistent_workers=True)
        else:
            train_loader = torch.utils.data.DataLoader(train_dataset,
                                                       batch_size=self.args.batch_size//self.world_size,
                                                       sampler=train_sampler,
                                                       num_workers=self.args.num_workers//self.world_size,
                                                       collate_fn=train_dataset.collate,
                                                       persistent_workers=True)
            valid_loader = torch.utils.data.DataLoader(val_dataset,
                                                       batch_size=self.args.batch_size//self.world_size,
                                                       sampler=valid_sampler,
                                                       num_workers=self.args.num_workers//self.world_size,
                                                       collate_fn=val_dataset.collate,
                                                       persistent_workers=True)
        return len(train_dataset), train_sampler, train_loader, valid_loader

    def _setup_models(self):
        model = voicecraft.VoiceCraft(self.args)

        if self.rank == 0:
            logging.info(model)
            logging.info("model parameters")
            print_model_info(model)

        if self.progress['step'] > 1:
            bundle = torch.load(os.path.join(self.args.exp_dir, "bundle.pth"), map_location="cpu")
            model.load_state_dict(bundle['model'])
            optim_states = bundle['optimizer']
            scheduler_states = bundle['scheduler']
            if self.rank == 0:
                logging.info("loaded parameters and data indices from epoch %d, global step %d" % (self.progress['epoch'], self.progress['step']))
            del bundle['model']
        else:
            optim_states = None
            scheduler_states = None

        if self.args.load_model_from is not None and self.progress['step'] <= 1:
            sd = torch.load(self.args.load_model_from, map_location="cpu")['model']
            model.load_state_dict(sd)
            del sd

        if self.args.optimizer_name == "ScaledAdam":
            trainables = [p for p in model.parameters() if p.requires_grad]
        else:
            no_decay = [".bias", ".audio_embeddings.weight", ".text_embeddings.weight", ".norm.weight", ".norm1.weight", ".norm2.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
                    "weight_decay": 0.0,
                },
            ]
            if len(optimizer_grouped_parameters[1]['params']) == 0:
                logging.info("no embedding, bias, or layernorm parameters matched the no-decay list, which is unexpected; check model parameter names")
                trainables = optimizer_grouped_parameters[0]
            else:
                trainables = optimizer_grouped_parameters
        model.to(self.device)

        return model, trainables, optim_states, scheduler_states
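
    # Optimizer setup: either ScaledAdam paired with the Eden schedule (both provided by
    # .optim), or AdamW with a LambdaLR that warms up linearly over
    # args.warmup_fraction of the total steps and then decays linearly to zero.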
    def _setup_optimizer(self):
        if self.args.optimizer_name == "ScaledAdam":
            parameters_names = []
            parameters_names.append([n for n, p in self.model.named_parameters() if p.requires_grad])
            optimizer = ScaledAdam(
                self.trainables,
                lr=self.args.lr,
                betas=(0.9, 0.95),
                clipping_scale=2.0,
                parameters_names=parameters_names,
                show_dominant_parameters=False,
                clipping_update_period=self.args.clipping_update_period,
            )
            scheduler = Eden(optimizer, self.args.reduce_lr_start_step, self.args.reduce_lr_start_epoch, warmup_batches=self.total_step * self.args.warmup_fraction)
        else:
            optimizer = AdamW(self.trainables, lr=self.args.lr)
            warmup_steps = self.total_step * self.args.warmup_fraction

            def lr_lambda(current_step: int):
                if current_step < warmup_steps:
                    return float(current_step) / float(max(1, warmup_steps))
                return max(
                    0.0, float(self.total_step - current_step) / float(max(1, self.total_step - warmup_steps))
                )

            scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)

        # if resume
        if self.progress['step'] > 1:
            optimizer.load_state_dict(self.optim_states)
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()
            del self.optim_states

            scheduler.load_state_dict(self.scheduler_states)

        optimizer.zero_grad()
        return optimizer, scheduler

    def seed_everything(self, seed=1):
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
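
# Hypothetical usage sketch (illustration only): the real entry point that builds `args`
# from the project's config/argparse and spawns processes lives outside this module.
# Under a torchrun launch, a minimal driver could look roughly like:
#
#   import torch
#   import torch.distributed as dist
#   dist.init_process_group(backend="nccl")
#   rank, world_size = dist.get_rank(), dist.get_world_size()
#   torch.cuda.set_device(rank)
#   trainer = Trainer(args, world_size, rank)  # `args` built elsewhere
#   trainer.train()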