pytorch-crf

2021-05-17  三方斜阳

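Official documentation:
pytorch-crf — pytorch-crf 0.7.2 documentation
A conditional random field (CRF) model implemented in PyTorch, based on the CRF module from AllenNLP. For an explanation of how CRFs work, see this post: CRF-条件随机场 - 简书 (jianshu.com)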

1. Installation:

pip install pytorch-crf

2. Importing the module:

import torch
from torchcrf import CRF
num_tags = 5  # number of tags
model = CRF(num_tags, batch_first=True)  # batch_first=True: inputs are (batch_size, seq_length, ...)
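To make the model's parameters concrete: as far as I can tell from the torchcrf source, a `CRF` holds three learnable tensors, `start_transitions` (size num_tags), `end_transitions` (size num_tags) and `transitions` (num_tags x num_tags), where `transitions[i][j]` scores a move from tag i to tag j. The minimal pure-Python sketch below (no torch needed; the helper `sequence_score` and the toy numbers are hypothetical, not part of the library) shows how these combine with the emission scores into the unnormalized score of one tag sequence.

```python
# Minimal pure-Python sketch (hypothetical helper, not the torchcrf API):
# unnormalized score of one tag sequence under a linear-chain CRF.
# start, trans and end play the roles of CRF.start_transitions,
# CRF.transitions and CRF.end_transitions.

def sequence_score(emissions, tags, start, trans, end):
    """emissions: seq_len lists of num_tags floats; tags: seq_len tag ids."""
    # First position: start transition + emission score of the first tag
    score = start[tags[0]] + emissions[0][tags[0]]
    # Later positions: transition from the previous tag + emission score
    for i in range(1, len(tags)):
        score += trans[tags[i - 1]][tags[i]] + emissions[i][tags[i]]
    # Finally, the transition from the last tag into the end state
    return score + end[tags[-1]]


# Toy example with 2 tags and 2 positions (made-up numbers)
emissions = [[1.0, 0.0], [0.0, 2.0]]
start, end = [0.5, 0.0], [0.0, 0.25]
trans = [[0.0, 1.0], [0.0, 0.0]]
score = sequence_score(emissions, [0, 1], start, trans, end)
print(score)  # 0.5 + 1.0 + 1.0 + 2.0 + 0.25 = 4.75
```

A higher score means the CRF considers that tag sequence more plausible; turning scores into probabilities requires normalizing over all possible sequences.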

3. Computing the log likelihood of a tag sequence:

seq_length = 3  # maximum sequence length in a batch
batch_size = 2  # number of samples in the batch
emissions = torch.randn(batch_size, seq_length, num_tags)  # (batch_size, seq_length, num_tags)
>>
tensor([[[ 0.3920, -2.0889,  1.0805, -0.6806, -0.0954],
         [ 0.1010,  0.2014, -0.0918, -0.7187, -1.2575],
         [-0.6948,  0.0528, -1.9853,  0.1679, -0.7857]],

        [[-1.0272, -0.2852, -0.5759,  1.3462,  0.7249],
         [ 0.6465,  0.1241, -0.9154, -0.6966, -0.0647],
         [-1.4029, -1.0029, -1.1149,  0.9312,  0.0092]]])
>>
tags = torch.tensor([[0, 2, 3], [1, 4, 1]], dtype=torch.long)  # (batch_size, seq_length)
model(emissions, tags)
>>
tensor(-9.8121, grad_fn=<SumBackward0>)
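`model(emissions, tags)` returns the conditional log likelihood log p(tags | emissions): the score of the given tag path minus log Z, the log-sum-exp of the scores of every possible tag sequence. By default the per-sample values are summed over the batch (hence `grad_fn=<SumBackward0>`); I believe the `reduction` argument (`'none'`, `'sum'`, `'mean'`, `'token_mean'`) changes this, and for training one usually minimizes the negative, e.g. `loss = -model(emissions, tags)`. The pure-Python sketch below is an illustration, not torchcrf's code: it computes log Z with the forward algorithm and checks it against brute-force enumeration on hypothetical random parameters.

```python
import itertools
import math
import random


def logsumexp(values):
    # Numerically stable log(sum(exp(v) for v in values))
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def sequence_score(emissions, tags, start, trans, end):
    # Unnormalized score of one tag sequence (start + emissions + transitions + end)
    score = start[tags[0]] + emissions[0][tags[0]]
    for i in range(1, len(tags)):
        score += trans[tags[i - 1]][tags[i]] + emissions[i][tags[i]]
    return score + end[tags[-1]]


def log_partition(emissions, start, trans, end):
    # Forward algorithm: alpha[j] is the log-sum-exp of the scores of all
    # prefixes that end in tag j at the current position.
    num_tags = len(start)
    alpha = [start[j] + emissions[0][j] for j in range(num_tags)]
    for i in range(1, len(emissions)):
        alpha = [
            logsumexp([alpha[k] + trans[k][j] for k in range(num_tags)]) + emissions[i][j]
            for j in range(num_tags)
        ]
    return logsumexp([alpha[j] + end[j] for j in range(num_tags)])


def log_likelihood(emissions, tags, start, trans, end):
    # log p(tags | emissions) = score(tags) - log Z
    return sequence_score(emissions, tags, start, trans, end) - log_partition(
        emissions, start, trans, end
    )


# Hypothetical random parameters: 3 positions, 5 tags
rng = random.Random(0)
num_tags, seq_len = 5, 3
emissions = [[rng.gauss(0, 1) for _ in range(num_tags)] for _ in range(seq_len)]
start = [rng.uniform(-0.1, 0.1) for _ in range(num_tags)]
end = [rng.uniform(-0.1, 0.1) for _ in range(num_tags)]
trans = [[rng.uniform(-0.1, 0.1) for _ in range(num_tags)] for _ in range(num_tags)]

# Brute force: log-sum-exp over all num_tags ** seq_len tag sequences
all_paths = list(itertools.product(range(num_tags), repeat=seq_len))
brute_log_z = logsumexp([sequence_score(emissions, p, start, trans, end) for p in all_paths])
forward_log_z = log_partition(emissions, start, trans, end)
llh = log_likelihood(emissions, [0, 2, 3], start, trans, end)
print(brute_log_z, forward_log_z, llh)
```

The forward algorithm needs on the order of seq_len x num_tags^2 operations instead of num_tags^seq_len, which is why brute force is only usable as a check on tiny examples.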

4. If the input contains padding:

# mask has shape (batch_size, seq_length)
# the second (last) sample has length 2; its final position is padding
mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8)
model(emissions, tags, mask=mask)
>>
tensor(-8.7959, grad_fn=<SumBackward0>)
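With a mask, padded positions contribute nothing: the second sample's log likelihood is computed as if that sequence had only its first two steps. The pure-Python sketch below illustrates this under the assumption that padding only appears at the end of a sequence (torchcrf itself additionally requires the mask's first timestep to be on for every sample); `masked_log_likelihood` and the random parameters are hypothetical, not part of the library.

```python
import math
import random


def logsumexp(values):
    # Numerically stable log(sum(exp(v) for v in values))
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def masked_log_likelihood(emissions, tags, mask, start, trans, end):
    """log p(tags | emissions) for one sequence, ignoring steps where mask is 0.

    Assumes mask[0] == 1 and that padding only occurs at the end.
    """
    num_tags = len(start)
    score = start[tags[0]] + emissions[0][tags[0]]                 # gold-path score
    alpha = [start[j] + emissions[0][j] for j in range(num_tags)]  # forward variables
    last = 0  # index of the last real (unpadded) position
    for i in range(1, len(tags)):
        if not mask[i]:
            break  # padding reached: later steps are ignored
        score += trans[tags[i - 1]][tags[i]] + emissions[i][tags[i]]
        alpha = [
            logsumexp([alpha[k] + trans[k][j] for k in range(num_tags)]) + emissions[i][j]
            for j in range(num_tags)
        ]
        last = i
    # The end transition is applied after the last real tag, not after the padding
    score += end[tags[last]]
    return score - logsumexp([alpha[j] + end[j] for j in range(num_tags)])


# Hypothetical random parameters: 5 tags, one sample padded from length 2 to 3
rng = random.Random(1)
num_tags = 5
emissions = [[rng.gauss(0, 1) for _ in range(num_tags)] for _ in range(3)]
start = [rng.uniform(-0.1, 0.1) for _ in range(num_tags)]
end = [rng.uniform(-0.1, 0.1) for _ in range(num_tags)]
trans = [[rng.uniform(-0.1, 0.1) for _ in range(num_tags)] for _ in range(num_tags)]

padded = masked_log_likelihood(emissions, [1, 4, 1], [1, 1, 0], start, trans, end)
truncated = masked_log_likelihood(emissions[:2], [1, 4], [1, 1], start, trans, end)
print(padded, truncated)  # the two values coincide
```

Because of this, whatever tag or emission value sits in a padded slot has no effect on the result.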

5. Decoding (finding the most likely tag sequence for each sample):

model.decode(emissions, mask=mask)
>>
[[1, 4, 4], [4, 1]]
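`decode` returns, for each sample, the highest-scoring tag sequence as a plain Python list, truncated to the sample's unpadded length, which is why the second result has only two tags. torchcrf finds it with the Viterbi algorithm; the pure-Python sketch below (hypothetical function names, made-up random scores) implements that dynamic program and compares it with brute-force search over all tag sequences.

```python
import itertools
import random


def argmax(values):
    # Index of the largest element
    return max(range(len(values)), key=values.__getitem__)


def viterbi_decode(emissions, start, trans, end, length=None):
    """Best tag sequence for the first `length` positions (default: all)."""
    length = len(emissions) if length is None else length
    num_tags = len(start)
    best = [start[j] + emissions[0][j] for j in range(num_tags)]  # best score ending in tag j
    history = []  # history[i - 1][j]: best previous tag given tag j at position i
    for i in range(1, length):
        new_best, pointers = [], []
        for j in range(num_tags):
            candidates = [best[k] + trans[k][j] for k in range(num_tags)]
            k = argmax(candidates)
            pointers.append(k)
            new_best.append(candidates[k] + emissions[i][j])
        best = new_best
        history.append(pointers)
    # Add the end transitions, pick the best last tag, then follow the back-pointers
    path = [argmax([best[j] + end[j] for j in range(num_tags)])]
    for pointers in reversed(history):
        path.append(pointers[path[-1]])
    return path[::-1]


def sequence_score(emissions, tags, start, trans, end):
    # Unnormalized score of one tag sequence, used for the brute-force check
    score = start[tags[0]] + emissions[0][tags[0]]
    for i in range(1, len(tags)):
        score += trans[tags[i - 1]][tags[i]] + emissions[i][tags[i]]
    return score + end[tags[-1]]


# Hypothetical random parameters: 5 tags, a batch of two samples of length 3,
# the second padded after 2 positions (like mask [[1, 1, 1], [1, 1, 0]])
rng = random.Random(2)
num_tags = 5
batch = [[[rng.gauss(0, 1) for _ in range(num_tags)] for _ in range(3)] for _ in range(2)]
start = [rng.uniform(-0.1, 0.1) for _ in range(num_tags)]
end = [rng.uniform(-0.1, 0.1) for _ in range(num_tags)]
trans = [[rng.uniform(-0.1, 0.1) for _ in range(num_tags)] for _ in range(num_tags)]
lengths = [3, 2]

decoded = [viterbi_decode(em, start, trans, end, n) for em, n in zip(batch, lengths)]

# Brute-force check: best sequence among all num_tags ** length candidates
brute = [
    list(max(itertools.product(range(num_tags), repeat=n),
             key=lambda p: sequence_score(em[:n], p, start, trans, end)))
    for em, n in zip(batch, lengths)
]
print(decoded, brute)
```

Viterbi replaces the log-sum-exp of the forward algorithm with a max and remembers which previous tag achieved it, so the best path can be read back from the end.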