learning

python argparse

2022-09-11  本文已影响0人  吃醋不吃辣的雷儿
import argparse
import os

def parse_common_args(parser):
  """Register the options shared by training and testing on *parser*.

  Adds --model_type, --dataset, --save_prefix, --load_model_path,
  --validation_set and --seed (in that order) and returns the same parser
  object so calls can be chained.
  """
  # (flag, type, default, help) for every documented string-valued option.
  shared_options = (
    ('--model_type', str, 'base_model', 'used in model_entry.py'),
    ('--dataset', str, 'base_dataset', 'used in data_entry.py'),
    ('--save_prefix', str, 'save_prefix', 'the prefix of training or test model to save'),
    ('--load_model_path', str, '/checkpoints/base_model_0.pth', 'model path to load'),
    ('--validation_set', str, '/data/base_dataset/validation.csv', 'validation set'),
  )
  for flag, value_type, default_value, help_text in shared_options:
    parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

  # A '--gpus' option (nargs='+', type=int) was sketched here but is disabled.
  # '--seed' has no help text, so it is registered on its own.
  parser.add_argument('--seed', type=int, default=1)
  return parser

def parse_train_args(parser):
  """Register the common options plus the training-only options on *parser*.

  First delegates to parse_common_args, then adds optimizer settings
  (--lr, --momentum, --beta, --weight_decay/--wd), data/output locations
  (--model_dir, --train_set) and loop settings (--batch_size,
  --training_epoch). Returns the parser.
  """
  parser = parse_common_args(parser)

  # (flag aliases, type, default, help), registered in this order.
  training_options = [
    (('--lr',), float, 1e-4, 'learning_rate'),
    (('--momentum',), float, 0.9, 'momentum for sgd, alpha parameter for adam'),
    (('--beta',), float, 0.999, 'beta parameter for adam'),
    (('--weight_decay', '--wd'), float, 0.0, 'weight decay'),
    (('--model_dir',), str, '', 'if blank then auto generated'),
    (('--train_set',), str, '/data/base_dataset/train.csv', 'train set'),
    (('--batch_size',), int, 32, 'batch size'),
    (('--training_epoch',), int, 1000, 'total training epochs'),
  ]
  for aliases, value_type, default_value, help_text in training_options:
    parser.add_argument(*aliases, type=value_type, default=default_value, help=help_text)
  return parser

def get_train_args(argv=None):
  """Build the training argument parser and parse the command line.

  Args:
    argv: optional list of argument strings to parse. When None (the
      default), argparse reads ``sys.argv[1:]`` exactly as before; passing
      a list makes the function usable from tests and notebooks.

  Returns:
    argparse.Namespace holding every common and training option.

  Note:
    parse_known_args() is used instead of parse_args() because the original
    author found parse_args() errors inside Jupyter notebooks (presumably
    because the kernel passes its own command-line flags -- TODO confirm).
    This is NOT equivalent to parse_args(): any unrecognised or misspelled
    option is silently discarded instead of being reported.
  """
  parser = parse_train_args(argparse.ArgumentParser())
  args, _unknown = parser.parse_known_args(argv)
  return args

def get_train_model_dir(args):
  """Resolve the checkpoint directory for a training run and ensure it exists.

  If ``args.model_dir`` is blank or missing (blank is the default of
  ``--model_dir``, whose help text reads "if blank then auto generated"),
  it is set to ``checkpoints/<model_type>_<save_prefix>``. A non-blank value
  supplied by the user is kept as-is. The directory, including any parents,
  is created when missing.

  Args:
    args: namespace with string attributes ``model_type`` and
      ``save_prefix``; ``model_dir`` is optional.

  Returns:
    The resolved directory path, which is also stored on ``args.model_dir``.

  Raises:
    FileExistsError: if the path exists but is not a directory.
  """
  model_dir = getattr(args, 'model_dir', '')
  if not model_dir:
    model_dir = os.path.join('checkpoints', args.model_type + '_' + args.save_prefix)
  # os.makedirs replaces `os.system('mkdir -p ' + path)`: no shell is
  # involved, so names containing spaces or shell metacharacters are safe;
  # it works on every OS; and exist_ok avoids the exists()/mkdir race.
  os.makedirs(model_dir, exist_ok=True)
  args.model_dir = model_dir
  return model_dir

# Script body. It runs at import time (there is no `if __name__ == "__main__":`
# guard): parse the training options, let get_train_model_dir set
# train_args.model_dir and ensure that directory exists, then echo the namespace.
# NOTE(review): importing this module therefore parses sys.argv and may create
# a directory; consider a __main__ guard if it is ever imported elsewhere.
train_args = get_train_args()
get_train_model_dir(train_args)
print("args:", train_args)
上一篇下一篇

猜你喜欢

热点阅读