[PyTorch]中的随机mask以及根据句子的长度进行mask

2019-06-10  本文已影响0人  VanJordan
def drop_tokens(embeddings, word_dropout):
    batch, length, size = embeddings.size()
    mask = embeddings.new_empty(batch, length)
    mask = mask.bernoulli_(1 - word_dropout)
    embeddings = embeddings * mask.unsqueeze(-1).expand_as(embeddings).float()
    return embeddings, mask
上一篇 下一篇

猜你喜欢

热点阅读