
Permeability Code Analysis

2018-03-05  程序猪小羊

1 Load the labels

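import numpy as np
import pandas as pd
## load the labels
# adr_input (the data root directory) is defined earlier in the script (not shown)
ts_label_adr = adr_input + '/permlty/ts_permeability.xlsx'
file_label = pd.ExcelFile(ts_label_adr)
labels = file_label.parse(header=None)  # first sheet, no header row
print(type(labels))  # a pandas DataFrame
# to reuse an existing index instead of drawing a new one: indx = np.load('index.npy')
# draw 1000 distinct rows out of 8300; np.random.randint could repeat rows
indx = np.random.choice(8300, size=1000, replace=False)
np.save('index.npy', indx)  # keep the index so the matching rows can be reloaded later
labels_sub = labels.iloc[indx]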
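`np.random.randint` samples with replacement. Drawing 1,000 indices from 8,300 rows therefore almost certainly produces repeats, and a repeated sample can end up in both the training split and the test split. The sketch below uses only NumPy and the 8,300/1,000 sizes from the code above. It counts the repeats and compares them with `Generator.choice(..., replace=False)`, which never repeats an index.

```python
import numpy as np

N_ROWS, N_PICK = 8300, 1000  # sizes taken from the label-loading code above

rng = np.random.default_rng(0)  # seeded only so the run is reproducible

# Sampling with replacement, the way np.random.randint behaves
with_repl = rng.integers(0, N_ROWS, size=N_PICK)
n_repeats = N_PICK - len(np.unique(with_repl))

# Sampling without replacement: each row index appears at most once
without_repl = rng.choice(N_ROWS, size=N_PICK, replace=False)
n_repeats_fixed = N_PICK - len(np.unique(without_repl))

print(f"repeated indices with replacement:    {n_repeats}")
print(f"repeated indices without replacement: {n_repeats_fixed}")
```

At these sizes the birthday effect predicts roughly 58 repeated indices with replacement and, by construction, none without replacement.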

2 Input processing

Each sample: 100×100×100
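At 100×100×100 per sample, each volume holds 10^6 values, so the 1,000 sampled volumes hold 10^9 values in total. The post does not say how the volumes are stored. The sketch below computes the array size for a few assumed dtypes, which shows why releasing the raw arrays after the split (the `del(...)` call further down) matters.

```python
import numpy as np

def dataset_bytes(n_samples: int, side: int, dtype) -> int:
    """Bytes needed to hold n_samples cubic volumes of side**3 values each."""
    return n_samples * side ** 3 * np.dtype(dtype).itemsize

# The dtypes are assumptions; the original does not state the storage type
for dt in (np.uint8, np.float32, np.float64):
    size = dataset_bytes(1000, 100, dt)
    print(f"{np.dtype(dt).name:>8}: {size / 1e9:.0f} GB")
# prints 1 GB for uint8, 4 GB for float32, 8 GB for float64
```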

4 Split data into train, test and validation sets

def shuffle_split(data, label):
    '''
    Shuffle and split data into training, validation and test sets (60/20/20).
    '''
    data = np.array(data)
    label = np.array(label)
    # scale labels by 1e11, presumably to bring small permeability values to order one
    label = 10**11 * label
    assert len(data) == len(label), 'data and labels must have the same length'
    indx = np.random.permutation(len(data))
    test_size = int(0.2 * len(data))
    test_indx = indx[:test_size]           # first 20%: test
    train_indx = indx[test_size:]
    val_indx = train_indx[:test_size]      # next 20%: validation
    train_indx = train_indx[test_size:]    # remaining 60%: training
    train_dat, train_tar = data[train_indx], label[train_indx]
    val_dat, val_tar = data[val_indx], label[val_indx]
    test_dat, test_tar = data[test_indx], label[test_indx]
    return train_dat, train_tar, val_dat, val_tar, test_dat, test_tar

train_dat, train_tar, val_dat, val_tar, test_dat, test_tar = shuffle_split(data, labels_sub)

del(data, file_label, labels_sub) # free memory
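`shuffle_split` takes the first 20% of a random permutation as the test set, the next 20% as the validation set, and the remaining 60% as the training set. The sketch below reimplements only that index logic, in a hypothetical helper `split_indices`. It checks the split sizes for the 1,000 sampled volumes and confirms that the three index sets do not overlap.

```python
import numpy as np

def split_indices(n, test_frac=0.2, seed=None):
    """Index logic of shuffle_split: test first, then validation, then training."""
    indx = np.random.default_rng(seed).permutation(n)
    test_size = int(test_frac * n)
    test_indx = indx[:test_size]
    val_indx = indx[test_size:2 * test_size]  # equals train_indx[:test_size] in shuffle_split
    train_indx = indx[2 * test_size:]         # equals train_indx[test_size:] in shuffle_split
    return train_indx, val_indx, test_indx

train_i, val_i, test_i = split_indices(1000, seed=0)
print(len(train_i), len(val_i), len(test_i))  # 600 200 200

# No sample may appear in more than one split
assert not set(train_i) & set(val_i)
assert not set(train_i) & set(test_i)
assert not set(val_i) & set(test_i)
```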
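for epoch in range(num_epoch):
    print(epoch)
    for dat, tar in train_loader:
        optimizer.zero_grad()  # reset gradients for every batch, not once per epoch
        # Variable is no longer needed in PyTorch >= 0.4; tensors can be passed directly
        structure = dat.view(-1, 100, 100, 100)  # one 100x100x100 volume per sample
        permeability = tar.view(-1, 1)           # one scalar label per sample
        # forward pass, loss, backward() and optimizer.step() are not shown in the original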
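The loop above stops before the forward and backward passes. The sketch below is a minimal complete loop built on stated assumptions. It uses a tiny hypothetical 3D CNN (`TinyPermNet`, not the author's model), mean-squared-error loss, the Adam optimizer, and random 16×16×16 volumes in place of the real 100×100×100 samples so that it runs quickly. If the real model is built from `nn.Conv3d`, it expects input of shape (N, C, D, H, W), so a single-channel volume needs an explicit channel dimension, as in `view(-1, 1, side, side, side)`.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
SIDE = 16  # stand-in for 100 so the sketch runs quickly

# Hypothetical data: 64 random volumes with one scalar label each
volumes = torch.rand(64, SIDE, SIDE, SIDE)
targets = torch.rand(64)
train_loader = DataLoader(TensorDataset(volumes, targets), batch_size=8, shuffle=True)

class TinyPermNet(nn.Module):
    """Hypothetical 3D CNN regressor; not the model used in the original post."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv3d(1, 4, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool3d(1),  # -> (N, 4, 1, 1, 1)
        )
        self.head = nn.Linear(4, 1)

    def forward(self, x):
        return self.head(self.features(x).flatten(1))  # -> (N, 1)

model = TinyPermNet()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

num_epoch = 3
epoch_losses = []
for epoch in range(num_epoch):
    running = 0.0
    for dat, tar in train_loader:
        structure = dat.view(-1, 1, SIDE, SIDE, SIDE)  # add the channel dimension
        permeability = tar.view(-1, 1)
        optimizer.zero_grad()                         # reset gradients every batch
        loss = criterion(model(structure), permeability)
        loss.backward()                               # compute gradients
        optimizer.step()                              # update the weights
        running += loss.item() * dat.size(0)
    epoch_losses.append(running / len(train_loader.dataset))
    print(f"epoch {epoch}: mean training loss {epoch_losses[-1]:.4f}")
```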

What I have been given so far: the extracted features and the .npy index, so I only need to load the corresponding permeability data.
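Because the features come with the saved `.npy` index, the matching permeability values can be selected by loading that index and indexing the label table with it. The sketch below assumes that the label table's row order matches the original 8,300-row spreadsheet. It uses a small synthetic `DataFrame` and a temporary directory, so it runs without the real `ts_permeability.xlsx`.

```python
import os
import tempfile

import numpy as np
import pandas as pd

# Synthetic stand-in for the label spreadsheet: one permeability value per row
labels = pd.DataFrame({0: np.linspace(1e-12, 1e-10, 8300)})

rng = np.random.default_rng(0)
indx = rng.choice(len(labels), size=1000, replace=False)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "index.npy")
    np.save(path, indx)    # written when the features were extracted
    saved = np.load(path)  # reloaded later alongside the features

labels_sub = labels.iloc[saved]  # rows in the same order as the feature index
print(labels_sub.shape)          # (1000, 1)
```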
