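# Python argparse example: command-line options for a model training script.
# Published 2022-09-11 (page view counter at capture time: 0).
# Author: 吃醋不吃辣的雷儿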
import argparse
import os
def parse_common_args(parser):
    """Register the options shared by training and testing on ``parser``.

    Args:
        parser: argparse.ArgumentParser to extend in place.

    Returns:
        The same parser object, for call chaining.
    """
    # (flags, keyword options) in registration order; the order determines
    # both the --help listing and the attribute order of the parsed Namespace.
    common_options = (
        (('--model_type',),
         dict(type=str, default='base_model', help='used in model_entry.py')),
        (('--dataset',),
         dict(type=str, default='base_dataset', help='used in data_entry.py')),
        (('--save_prefix',),
         dict(type=str, default='save_prefix',
              help='the prefix of training or test model to save')),
        (('--load_model_path',),
         dict(type=str, default='/checkpoints/base_model_0.pth',
              help='model path to load')),
        (('--validation_set',),
         dict(type=str, default='/data/base_dataset/validation.csv',
              help='validation set')),
        (('--seed',),
         dict(type=int, default=1)),
    )
    for flags, options in common_options:
        parser.add_argument(*flags, **options)
    return parser
def parse_train_args(parser):
    """Register the training-only options on top of the common ones.

    Args:
        parser: argparse.ArgumentParser to extend in place.

    Returns:
        The parser returned by ``parse_common_args``, extended with the
        optimizer, output-location, data and schedule options.
    """
    parser = parse_common_args(parser)
    add = parser.add_argument

    # Optimizer hyper-parameters (help texts reference SGD and Adam).
    add('--lr', type=float, default=1e-4, help='learning_rate')
    add('--momentum', type=float, default=0.9,
        help='momentum for sgd, alpha parameter for adam')
    add('--beta', type=float, default=0.999, help='beta parameter for adam')
    add('--weight_decay', '--wd', type=float, default=0.0,
        help='weight decay')

    # Output location, training data and schedule.
    add('--model_dir', type=str, default='',
        help='if blank then auto generated')
    add('--train_set', type=str, default='/data/base_dataset/train.csv',
        help='train set')
    add('--batch_size', type=int, default=32, help='batch size')
    add('--training_epoch', type=int, default=1000,
        help='total training epochs')
    return parser
def get_train_args(argv=None):
    """Build the training parser and parse command-line arguments.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (e.g. for tests or programmatic use). The
            default ``None`` keeps the original command-line behavior.

    Returns:
        argparse.Namespace holding the common and training options.
    """
    parser = parse_train_args(argparse.ArgumentParser())
    # parse_known_args() instead of parse_args(): the original author notes
    # that parse_args() errors inside Jupyter notebooks (presumably because
    # the kernel is launched with its own command-line flags). This is NOT
    # identical to parse_args(): unrecognized arguments are silently dropped,
    # so a misspelled option name goes unreported.
    args, _unknown = parser.parse_known_args(argv)
    return args
def get_train_model_dir(args):
    """Resolve the checkpoint directory for a training run and ensure it exists.

    When ``args.model_dir`` is blank (the ``--model_dir`` default), the
    directory is auto-generated as ``checkpoints/<model_type>_<save_prefix>``
    relative to the current working directory, as the option's help text
    ("if blank then auto generated") promises. A non-blank user-supplied
    path is kept as is.

    Args:
        args: Namespace with ``model_type`` and ``save_prefix`` string
            attributes and, optionally, ``model_dir`` (as produced by
            ``get_train_args``). ``args.model_dir`` is set in place.

    Returns:
        The resolved directory path (also stored on ``args.model_dir``).
    """
    model_dir = getattr(args, 'model_dir', '') or os.path.join(
        'checkpoints', args.model_type + '_' + args.save_prefix)
    # os.makedirs replaces `os.system('mkdir -p ' + model_dir)`. No shell is
    # involved, so a save_prefix containing spaces or shell metacharacters
    # cannot break or inject into the command, and the call is portable.
    # exist_ok=True also removes the check-then-create race.
    os.makedirs(model_dir, exist_ok=True)
    args.model_dir = model_dir
    return model_dir
if __name__ == "__main__":
    # Parse the command line, create the checkpoint directory and echo the
    # configuration only when run as a script (also true in a Jupyter
    # cell). Importing this module for its parser helpers then has no side
    # effects on sys.argv handling, the filesystem or stdout.
    train_args = get_train_args()
    get_train_model_dir(train_args)
    print("args:", train_args)