How are the batches used to train a model generated?

2019-03-25  夕宝爸爸

The code below is taken from data_helper.py in https://github.com/RandolphVI/Multi-Label-Text-Classification

import numpy as np


def batch_iter(data, batch_size, num_epochs, shuffle=True):
    """
    Because it contains yield, this is not an ordinary function but a generator.
    Behavior: iterate over data for num_epochs epochs. At the start of each epoch,
    if shuffle=True, reshuffle data; then yield the (shuffled) data in consecutive
    batches of batch_size items. Each epoch yields ceil(len(data) / batch_size)
    batches; the last batch is smaller when len(data) is not a multiple of batch_size.

    Args:
        data: The data (anything np.array can convert; one row per sample)
        batch_size: The size of the data batch
        num_epochs: The number of epochs
        shuffle: Shuffle or not (default: True)
    Yields:
        A batch of at most batch_size samples (a slice of the NumPy array)
    """
    data = np.array(data)
    data_size = len(data)
    # Ceiling division: number of batches needed to cover every sample once
    num_batches_per_epoch = (data_size - 1) // batch_size + 1
    for epoch in range(num_epochs):
        # Shuffle the data at each epoch
        if shuffle:
            shuffle_indices = np.random.permutation(np.arange(data_size))
            shuffled_data = data[shuffle_indices]
        else:
            shuffled_data = data
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            # The last batch stops at data_size instead of overrunning it
            end_index = min((batch_num + 1) * batch_size, data_size)
            yield shuffled_data[start_index:end_index]
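batch_iter takes a single data argument, so features and labels have to be packed together first (for example as list(zip(x, y))). Depending on the NumPy version, np.array on such pairs of differently shaped arrays either produces an object array or raises an error. Below is a minimal sketch of an index-based alternative, assuming x and y are NumPy-compatible arrays of equal length. The function name batch_iter_xy and its seed parameter are hypothetical and not part of the original repository. It shuffles one index permutation per epoch and applies it to both arrays, so each sample stays paired with its label.

```python
import numpy as np


def batch_iter_xy(x, y, batch_size, num_epochs, shuffle=True, seed=None):
    """Yield (x_batch, y_batch) pairs, reshuffling once per epoch.

    Hypothetical variant of batch_iter: it permutes indices instead of
    packing features and labels into a single array.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    data_size = len(x)
    # Ceiling division: the last batch may be smaller than batch_size
    num_batches_per_epoch = (data_size + batch_size - 1) // batch_size
    rng = np.random.default_rng(seed)
    for _ in range(num_epochs):
        # One permutation per epoch, applied to both arrays so pairs stay aligned
        indices = rng.permutation(data_size) if shuffle else np.arange(data_size)
        for batch_num in range(num_batches_per_epoch):
            batch_idx = indices[batch_num * batch_size:(batch_num + 1) * batch_size]
            yield x[batch_idx], y[batch_idx]


# Usage: 10 samples, batch_size=4 -> 3 batches per epoch (sizes 4, 4, 2)
x = np.arange(10).reshape(10, 1)
y = np.arange(10) * 10
batches = list(batch_iter_xy(x, y, batch_size=4, num_epochs=2, seed=0))
print(len(batches))                      # 6
print([len(xb) for xb, _ in batches])    # [4, 4, 2, 4, 4, 2]
```

Shuffling indices rather than the data itself also avoids copying the whole dataset once per epoch before slicing. Only the selected rows of each batch are gathered.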