"""Entry point for distributed (NCCL) training.

Parses command-line arguments, persists them to ``<exp_dir>/args.pkl`` (or
merges them with a previously saved copy when resuming), initializes the
``torch.distributed`` process group from environment variables
(``init_method='env://'``), and runs ``steps.trainer.Trainer``.
"""
import argparse
import logging
import os
import pickle
from pathlib import Path

import torch
import torch.distributed as dist

from config import MyParser
from steps import trainer

ARGS_FILENAME = "args.pkl"


def _resolve_args(
    args: argparse.Namespace, exp_dir: Path, is_main_process: bool
) -> argparse.Namespace:
    """Return the effective run arguments, loading or saving ``args.pkl``.

    When ``args.resume`` is truthy, the Namespace pickled by an earlier run is
    loaded and every value parsed from the current command line overrides it;
    keys present only in the saved file are kept. Otherwise the current
    arguments are pickled (by the main process only) for future resumes.

    Args:
        args: Arguments parsed from the current command line.
        exp_dir: Experiment directory holding ``args.pkl``.
        is_main_process: Whether this process should write ``args.pkl``.

    Returns:
        The merged Namespace when resuming, else ``args`` unchanged.
    """
    args_path = exp_dir / ARGS_FILENAME
    if args.resume:
        # NOTE: pickle.load can execute arbitrary code embedded in the file;
        # only resume from experiment directories you trust.
        with open(args_path, "rb") as f:
            saved = vars(pickle.load(f))
        # Current values always win (same as the original per-key compare
        # loop), so args.resume is carried over automatically.
        # NOTE(review): parser *defaults* also override saved values here, not
        # just flags passed explicitly — confirm that is intended.
        saved.update(vars(args))
        return argparse.Namespace(**saved)
    if is_main_process:
        # Every launched process runs this script; let only one write the file
        # so they don't race on it.
        with open(args_path, "wb") as f:
            pickle.dump(args, f)
    return args


def main() -> None:
    """Parse arguments, set up the process group, and run training."""
    formatter = (
        "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)
    # NOTE(review): presumably a no-op this early (nothing allocated on the GPU
    # yet); kept from the original.
    torch.cuda.empty_cache()

    args = MyParser().parse_args()
    logging.info(args)

    # Validate explicitly (assert is stripped under -O); an empty exp_dir
    # previously made the fresh-run path write "/args.pkl".
    if not args.exp_dir:
        raise ValueError("exp_dir must be a non-empty path")
    exp_dir = Path(args.exp_dir)
    exp_dir.mkdir(exist_ok=True, parents=True)
    logging.info("exp_dir: %s", exp_dir)

    dist.init_process_group(backend="nccl", init_method="env://")
    try:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        args = _resolve_args(args, exp_dir, is_main_process=(rank == 0))
        # The global rank is not a valid GPU index on multi-node runs; use the
        # per-node LOCAL_RANK set by the launcher, falling back to the global
        # rank (the original behavior) when it is absent.
        local_rank = int(os.environ.get("LOCAL_RANK", rank))
        torch.cuda.set_device(local_rank)
        # NOTE(review): Trainer still receives the global rank, as before;
        # confirm it does not also use it as a device index.
        my_trainer = trainer.Trainer(args, world_size, rank)
        my_trainer.train()
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()