自然语言处理AI人工智能与数学之美

序列标注任务常用方法

2022-06-27  本文已影响0人  晓柒NLP与药物设计

1. HMM

1.1 模型原理

HMM中,有5个基本元素:{N,M,A,B,π},结合序列标志任务(NER)对其的概念定义为:

而以上的这些元素,都是可以从训练语料集中统计出来的。最后根据这些统计值,应用维特比(viterbi)算法,算出词语序列背后的标注序列了,命名实体识别本质上就是序列标注,只需要定义好对应的标签以及模式串,就可以从标注序列中提取出实体

假设对于t时刻的一个词W_t公式就可写作:
P(x_t)=P(x_t|x_{t-1})P(x_{t-1})\\ P(x_t|x_{t-1})=\frac{count(x_{t-1}\rightarrow x)}{count(x_{t-1})}
齐次马尔科夫性假设:隐藏的马尔科夫链在任意时刻t的状态只依赖于其前一时刻的状态,与其他时刻的状态及观测无关,也与时刻t无关

观测独立性假设:假设任意时刻的观测只依赖于该时刻的马尔科夫链的状态,与其他观测即状态无关.观测概率的公式可以表达如下:
P(y_t)=P(y_t|x_t)P(x_t)\\ P(y_t|x_t)=\frac{count(x \rightarrow y)}{count(x)}
将发射概率和转移概率相结合,得到整个句子最后的公式:
P(x,y)=P(y_1|start)\prod_{t=1}^{L-1}P(y_{t+1}|y_t)P(end|y_t)\prod_{t=1}^LP(x_t|y_t)

1.2 模型实现

class Model(object):
    def __init__(self, hidden_status):
        # param hidden_status: int, 隐状态数
        self.hmm_N = hidden_status
        # 状态转移概率矩阵 A[i][j]表示从i状态转移到j状态的概率
        self.hmm_A = torch.zeros(self.hmm_N, self.hmm_N)
        # 初始状态概率  Pi[i]表示初始时刻为状态i的概率
        self.hmm_pi = torch.zeros(self.hmm_N)
    def _build_corpus_map(self, sentences_list):
        char2id = {}
        for sentence in sentences_list:
            for word in sentence:
                if word not in char2id:
                    char2id[word] = len(char2id)
        return char2id
    def _init_emission(self):
        self.hmm_M = len(self.word2id)
        # 观测概率矩阵, B[i][j]表示i状态下生成j观测的概率
        self.hmm_B = torch.zeros(self.hmm_N, self.hmm_M)
    def train(self, sentences_list, tags_list):
        """
        参数:
            sentences_list: list,其中每个元素由字组成的列表,如 ['担','任','科','员']
            tags_list: list,其中每个元素是由对应的标注组成的列表,如 ['O','O','B-TITLE', 'E-TITLE']
        """
        start_time = time.time()
        assert len(sentences_list) == len(tags_list), "the lens of tag_lists is not eq to word_lists"
        logger.info('开始构建token字典...')
        self.word2id = self._build_corpus_map(sentences_list)
        self.tag2id = self._build_corpus_map(tags_list)
        self.id2tag = dict((id_, tag) for tag, id_ in self.tag2id.items())
        logger.info('训练语料总数:{}'.format(len(sentences_list)))
        logger.info('词典总数:{}'.format(len(self.word2id)))
        logger.info('标签总数:{}'.format(len(self.tag2id)))

        assert self.hmm_N == len(self.tag2id), "hidden_status is {}, but total tag is {}".\
            format(self.hmm_N, len(self.tag2id))
        self._init_emission()
        logger.info('构建词典完成{:>.4f}s'.format(time.time()-start_time))
        logger.info('开始构建转移概率矩阵...')
        # 估计转移概率矩阵
        for tags in tqdm(tags_list):
            seq_len = len(tags)
            for i in range(seq_len - 1):
                current_tagid = self.tag2id[tags[i]]
                next_tagid = self.tag2id[tags[i+1]]
                self.hmm_A[current_tagid][next_tagid] += 1.
        # 问题:如果某元素没有出现过,该位置为0,这在后续的计算中是不允许的
        # 解决方法:我们将等于0的概率加上很小的数
        self.hmm_A[self.hmm_A == 0.] = 1e-10
        self.hmm_A = self.hmm_A / self.hmm_A.sum(axis=1, keepdims=True)
        logger.info('完成转移概率矩阵构建. {:>.4f}s'.format(time.time() - start_time))
        logger.info('开始构建观测概率矩阵...')
        # 估计观测概率矩阵
        for tags, sentence in tqdm(zip(tags_list, sentences_list)):
            assert len(tags) == len(sentence), \
                "the lens of tag_list is not eq to word_list"
            for tag, word in zip(tags, sentence):
                tag_id = self.tag2id[tag]
                word_id = self.word2id[word]
                self.hmm_B[tag_id][word_id] += 1.
        self.hmm_B[self.hmm_B == 0.] = 1e-10
        self.hmm_B = self.hmm_B / self.hmm_B.sum(axis=1, keepdims=True)
        logger.info('完成观测概率矩阵构建. {:>.4f}s'.format(time.time() - start_time))
        logger.info('初始化初识状态概率...')
        # 估计初始状态概率
        for tags in tqdm(tags_list):
            init_tagid = self.tag2id[tags[0]]
            self.hmm_pi[init_tagid] += 1.
        self.hmm_pi[self.hmm_pi == 0.] = 1e-10
        self.hmm_pi = self.hmm_pi / self.hmm_pi.sum()
        logger.info('完成初始状态概率构建. {:>.4f}s'.format(time.time() - start_time))
    def predict(self, sentences_list):
        pred_tag_lists = []
        for sentence in tqdm(sentences_list):
            pred_tag_list = self.decoding(sentence)
            pred_tag_lists.append(pred_tag_list)
        return pred_tag_lists
    def decoding(self, word_list):
        """
        使用维特比算法对给定观测序列求状态序列, 这里就是对字组成的序列,求其对应的标注。
        维特比算法实际是用动态规划解隐马尔可夫模型预测问题,即用动态规划求概率最大路径(最优路径)
        这时一条路径对应着一个状态序列
        """
        A = torch.log(self.hmm_A)
        B = torch.log(self.hmm_B)
        Pi = torch.log(self.hmm_pi)
        # 初始化 维比特矩阵viterbi 它的维度为[状态数, 序列长度]
        seq_len = len(word_list)
        viterbi = torch.zeros(self.hmm_N, seq_len)
        # 等解码的时候,我们用backpointer进行回溯,以求出最优路径
        backpointer = torch.zeros(self.hmm_N, seq_len).long()
        start_wordid = self.word2id.get(word_list[0], None)
        Bt = B.t()
        if start_wordid is None:
            # 如果字不再字典里,则假设状态的概率分布是均匀的
            bt = torch.log(torch.ones(self.hmm_N) / self.hmm_N)
        else:
            bt = Bt[start_wordid]
        viterbi[:, 0] = Pi + bt
        backpointer[:, 0] = -1
        for step in range(1, seq_len):
            wordid = self.word2id.get(word_list[step], None)
            # 处理字不在字典中的情况
            # bt是在t时刻字为wordid时,状态的概率分布
            if wordid is None:
                # 如果字不再字典里,则假设状态的概率分布是均匀的
                bt = torch.log(torch.ones(self.hmm_N) / self.hmm_N)
            else:
                bt = Bt[wordid]  # 否则从观测概率矩阵中取bt
            for tag_id in range(len(self.tag2id)):
                max_prob, max_id = torch.max(
                    viterbi[:, step - 1] + A[:, tag_id],
                    dim=0
                )
                viterbi[tag_id, step] = max_prob + bt[tag_id]
                backpointer[tag_id, step] = max_id
        # 终止, t=seq_len 即 viterbi[:, seq_len]中的最大概率,就是最优路径的概率
        best_path_prob, best_path_pointer = torch.max(
            viterbi[:, seq_len - 1], dim=0
        )
        # 回溯,求最优路径
        best_path_pointer = best_path_pointer.item()
        best_path = [best_path_pointer]
        for back_step in range(seq_len - 1, 0, -1):
            best_path_pointer = backpointer[best_path_pointer, back_step]
            best_path_pointer = best_path_pointer.item()
            best_path.append(best_path_pointer)
        # 将tag_id组成的序列转化为tag
        assert len(best_path) == len(word_list)
        tag_list = [self.id2tag[id_] for id_ in reversed(best_path)]
        return tag_list

2. CRF

2.1 模型原理

相对于HMM,CRF有两个优势

x=(x_1,...,x_m)是观测序列,s=(s_1, ...,s_m)是状态序列,wCRF模型的参数,则s的条件概率是:
p(s|x;w)=\frac{exp(wF(x,s))}{\sum_{\bar{s}}exp(wF(x,\bar{s}))}
其中F(x,s)CRF特征函数集,加上正则化项,在做对数变换就得到Loss
L(w)=-\sum_{i=1}^nlogp(s_i|x_i,w)+\frac{\lambda_2}{2}|w|^2+\lambda_1|w|
CRF训练的目的是求解令p(s|x,w)最大化的w

2.2 模型实现(使用TorchCRF第三方库)

$ pip install TorchCRF
import torch
from TorchCRF import CRF
device = "cuda"
batch_size = 2
sequence_size = 3
num_labels = 5
mask = torch.ByteTensor([[1, 1, 1], [1, 1, 0]]).to(device) # (batch_size. sequence_size)
labels = torch.LongTensor([[0, 2, 3], [1, 4, 1]]).to(device)  # (batch_size, sequence_size)
hidden = torch.randn((batch_size, sequence_size, num_labels), requires_grad=True).to(device)
crf = CRF(num_labels)
Computing log-likelihood (used where forward)
crf.forward(hidden, labels, mask)
>>> tensor([-7.6204, -3.6124], device='cuda:0', grad_fn=<ThSubBackward>)
crf.viterbi_decode(hidden, mask)
>>> [[0, 2, 2], [4, 0]]

3. BiLSTM+CRF

3.1 模型原理

应用于NER中的BiLSTM-CRF模型主要由Embedding层(主要有词向量,字向量以及一些额外特征),双向LSTM层,以及最后的CRF层构成。实验结果表明biLSTM-CRF已经达到或者超过了基于丰富特征的CRF模型,成为目前基于深度学习的NER方法中的最主流模型。在特征方面,该模型继承了深度学习方法的优势,无需特征工程,使用词向量以及字符向量就可以达到很好的效果,如果有高质量的词典特征,能够进一步获得提高

微信图片_20220627150456.png

3.2 模型实现

class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.tagset_size = config.num_label + 2
        self.hidden_dim = config.hidden_size
        self.start_tag_id = config.num_label
        self.end_tag_id = config.num_label + 1
        self.device = config.device
        self.embedding = nn.Embedding(config.vocab_size, config.emb_size, padding_idx=config.vocab_size - 1) 
        torch.nn.init.uniform_(self.embedding.weight, -0.10, 0.10)
        self.encoder = nn.LSTM(config.emb_size, config.hidden_size, batch_first=True, bidirectional=True)
        self.decoder = nn.LSTM(config.hidden_size * 2, config.hidden_size, batch_first=True, bidirectional=True)
        self.linear = nn.Linear(2*config.hidden_size, self.tagset_size)
        self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))
        self.transitions.data[self.start_tag_id, :] = -10000.
        self.transitions.data[:, self.end_tag_id] = -10000.
        self.hidden = self.init_hidden()
    def init_hidden(self):
        return (torch.randn(2, 1, self.hidden_dim).to(self.device),
                torch.randn(2, 1, self.hidden_dim).to(self.device))
    def _get_lstm_features(self, input_ids):
        embeds = self.embedding(input_ids).view(1, input_ids.shape[1], -1)
        self.encoder.flatten_parameters()
        self.decoder.flatten_parameters()
        encoder_out, _ = self.encoder(embeds, self.hidden)
        decoder_out, _ = self.decoder(encoder_out, self.hidden)
        decoder_out = decoder_out.view(input_ids.shape[1], -1)
        lstm_logits = self.linear(decoder_out)
        return lstm_logits
    def log_sum_exp(self, smat):
        # 每一列的最大数
        vmax = smat.max(dim=0, keepdim=True).values
        return torch.log(torch.sum(torch.exp(smat - vmax), axis=0, keepdim=True)) + vmax
    def _forward_alg(self, feats):
        alphas = torch.full((1, self.tagset_size), -10000.).to(self.device)
        # 初始化分值分布. START_TAG是log(1)=0, 其他都是很小的值 "-10000"
        alphas[0][self.start_tag_id] = 0.
        # Iterate through the sentence
        for feat in feats:
            # log_sum_exp()内三者相加会广播: 当前各状态的分值分布(列向量) + 发射分值(行向量) + 转移矩阵(方形矩阵)
            # 相加所得矩阵的物理意义见log_sum_exp()函数的注释; 然后按列求log_sum_exp得到行向量
            alphas = self.log_sum_exp(alphas.T + self.transitions + feat.unsqueeze(0))
        # 最后转到EOS,发射分值为0,转移分值为列向量 self.transitions[:, self.end_tag_id]
        score = self.log_sum_exp(alphas.T + 0 + self.transitions[:, self.end_tag_id].view(-1, 1))
        return score.flatten()
    def _score_sentence(self, feats, tags):
        # Gives the score of a provided tag sequence
        score = torch.zeros(1).to(self.device)
        tags = torch.cat([torch.tensor([self.start_tag_id], dtype=torch.long).to(self.device), tags])
        for i, feat in enumerate(feats):
            # emit = X0,start + x1,label + ... + xn-2,label + (xn-1, end[0])
            # trans = 每一个状态的转移状态
            score += self.transitions[tags[i], tags[i+1]] + feat[tags[i + 1]]
        # 加上到END_TAG的转移
        score += self.transitions[tags[-1], self.end_tag_id]
        return score
    def _viterbi_decode(self, feats):
        backtrace = []  # 回溯路径;  backtrace[i][j] := 第i帧到达j状态的所有路径中, 得分最高的那条在i-1帧是神马状态
        alpha = torch.full((1, self.tagset_size), -10000.).to(self.device)
        alpha[0][self.start_tag_id] = 0
        for frame in feats:
            smat = alpha.T + frame.unsqueeze(0) + self.transitions
            backtrace.append(smat.argmax(0))  # 当前帧每个状态的最优"来源"
            alpha = smat.max(dim=0, keepdim=True).values
        # Transition to STOP_TAG
        smat = alpha.T + 0 + self.transitions[:, self.end_tag_id].view(-1, 1)
        best_tag_id = smat.flatten().argmax().item()
        best_score = smat.max(dim=0, keepdim=True).values.item()
        best_path = [best_tag_id]
        for bptrs_t in reversed(backtrace[1:]):  # 从[1:]开始,去掉开头的 START_TAG
            best_tag_id = bptrs_t[best_tag_id].item()
            best_path.append(best_tag_id)
        best_path.reverse()
        return best_score, best_path  # 返回最优路径分值 和 最优路径
    def forward(self, sentence_ids, tags_ids):
        tags_ids = tags_ids.view(-1)
        feats = self._get_lstm_features(sentence_ids)
        forward_score = self._forward_alg(feats)
        gold_score = self._score_sentence(feats, tags_ids)
        outputs = (forward_score - gold_score, )
        _, tag_seq = self._viterbi_decode(feats)
        outputs = (tag_seq, ) + outputs
        return outputs
    def predict(self, sentence_ids):
        lstm_feats = self._get_lstm_features(sentence_ids)
        _, tag_seq = self._viterbi_decode(lstm_feats)
        return tag_seq

4. IDCNN+CRF

4.1 模型原理

正常CNN的filter,都是作用在输入矩阵一片连续的区域上,不断sliding做卷积。dilated CNN为这个filter增加了一个dilation width,作用在输入矩阵的时候,会skip所有dilation width中间的输入数据;而filter本身的大小保持不变,这样filter获取到了更广阔的输入矩阵上的数据,看上去就像是膨胀了一般。具体使用时,dilated width会随着层数的增加而指数增加。这样随着层数的增加,参数数量是线性增加的,而receptive field却是指数增加的,可以很快覆盖到全部的输入数据。

图中可见感受域是以指数速率扩大的。原始感受域是位于中心点的1x1区域:

对应在文本上,输入是一个一维的向量,每个元素是一个character embedding

IDCNN对输入句子的每一个字生成一个logits,这里就和BiLSTM模型输出logits完全一样,加入CRF层,用Viterbi算法解码出标注结果,在BiLSTM或者IDCNN这样的网络模型末端接上CRF层是序列标注的一个很常见的方法。BiLSTM或者IDCNN计算出的是每个词的各标签概率,而CRF层引入序列的转移概率,最终计算出loss反馈回网络

4.2 模型实现

class Model(nn.Module):

    def __init__(self, config):
        super(Model, self).__init__()
        self.tagset_size = config.num_label + 2
        self.hidden_dim = config.hidden_size
        self.start_tag_id = config.num_label
        self.end_tag_id = config.num_label + 1
        self.device = config.device
        self.embedding = nn.Embedding(config.vocab_size, config.emb_size, padding_idx=config.vocab_size - 1)
        torch.nn.init.uniform_(self.embedding.weight, -0.10, 0.10)
        self.encoder = nn.LSTM(config.emb_size, config.hidden_size, batch_first=True, bidirectional=True)
        self.decoder = nn.LSTM(config.hidden_size * 2, config.hidden_size, batch_first=True, bidirectional=True)
        self.linear = nn.Linear(2*config.hidden_size, self.tagset_size)
        self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))
        self.transitions.data[self.start_tag_id, :] = -10000.
        self.transitions.data[:, self.end_tag_id] = -10000.
        self.hidden = self.init_hidden()
    def init_hidden(self):
        return (torch.randn(2, 1, self.hidden_dim).to(self.device),
                torch.randn(2, 1, self.hidden_dim).to(self.device))
    def _get_lstm_features(self, input_ids):
        embeds = self.embedding(input_ids).view(1, input_ids.shape[1], -1)
        self.encoder.flatten_parameters()
        self.decoder.flatten_parameters()
        encoder_out, _ = self.encoder(embeds, self.hidden)
        decoder_out, _ = self.decoder(encoder_out, self.hidden)
        decoder_out = decoder_out.view(input_ids.shape[1], -1)
        lstm_logits = self.linear(decoder_out)
        return lstm_logits
    def log_sum_exp(self, smat):
        # 每一列的最大数
        vmax = smat.max(dim=0, keepdim=True).values
        # return (smat - vmax).exp().sum(axis=0, keepdim=True).log() + vmax
        return torch.log(torch.sum(torch.exp(smat - vmax), axis=0, keepdim=True)) + vmax
    def _forward_alg(self, feats):
        # Do the forward algorithm to compute the partition function
        alphas = torch.full((1, self.tagset_size), -10000.).to(self.device)
        # 初始化分值分布. START_TAG是log(1)=0, 其他都是很小的值 "-10000"
        alphas[0][self.start_tag_id] = 0.
        # Iterate through the sentence
        for feat in feats:
            # log_sum_exp()内三者相加会广播: 当前各状态的分值分布(列向量) + 发射分值(行向量) + 转移矩阵(方形矩阵)
            # 相加所得矩阵的物理意义见log_sum_exp()函数的注释; 然后按列求log_sum_exp得到行向量
            alphas = self.log_sum_exp(alphas.T + self.transitions + feat.unsqueeze(0))
        # 最后转到EOS,发射分值为0,转移分值为列向量 self.transitions[:, self.end_tag_id]
        score = self.log_sum_exp(alphas.T + 0 + self.transitions[:, self.end_tag_id].view(-1, 1))
        return score.flatten()
    def _score_sentence(self, feats, tags):
        # Gives the score of a provided tag sequence
        score = torch.zeros(1).to(self.device)
        tags = torch.cat([torch.tensor([self.start_tag_id], dtype=torch.long).to(self.device), tags])
        for i, feat in enumerate(feats):
            # emit = X0,start + x1,label + ... + xn-2,label + (xn-1, end[0])
            # trans = 每一个状态的转移状态
            score += self.transitions[tags[i], tags[i+1]] + feat[tags[i + 1]]
        # 加上到END_TAG的转移
        score += self.transitions[tags[-1], self.end_tag_id]
        return score
    def _viterbi_decode(self, feats):
        backtrace = []  # 回溯路径;  backtrace[i][j] := 第i帧到达j状态的所有路径中, 得分最高的那条在i-1帧是神马状态
        alpha = torch.full((1, self.tagset_size), -10000.).to(self.device)
        alpha[0][self.start_tag_id] = 0
        for frame in feats:
            smat = alpha.T + frame.unsqueeze(0) + self.transitions
            backtrace.append(smat.argmax(0))  # 当前帧每个状态的最优"来源"
            alpha = smat.max(dim=0, keepdim=True).values
        # Transition to STOP_TAG
        smat = alpha.T + 0 + self.transitions[:, self.end_tag_id].view(-1, 1)
        best_tag_id = smat.flatten().argmax().item()
        best_score = smat.max(dim=0, keepdim=True).values.item()
        best_path = [best_tag_id]
        for bptrs_t in reversed(backtrace[1:]):  # 从[1:]开始,去掉开头的 START_TAG
            best_tag_id = bptrs_t[best_tag_id].item()
            best_path.append(best_tag_id)
        best_path.reverse()
        return best_score, best_path  # 返回最优路径分值 和 最优路径
    def forward(self, sentence_ids, tags_ids):
        tags_ids = tags_ids.view(-1)
        feats = self._get_lstm_features(sentence_ids)
        forward_score = self._forward_alg(feats)
        gold_score = self._score_sentence(feats, tags_ids)
        outputs = (forward_score - gold_score, )
        _, tag_seq = self._viterbi_decode(feats)
        outputs = (tag_seq, ) + outputs
        return outputs
    def predict(self, sentence_ids):
        lstm_feats = self._get_lstm_features(sentence_ids)
        _, tag_seq = self._viterbi_decode(lstm_feats)
        return tag_seq

5. BERT+CRF

5.1 模型原理

BERT模型+全连接层:BERT的encoding vector通过FC layer映射到标签集合后,单个token的output vector经过Softmax处理,每一维度的数值就表示该token的词性为某一词性的概率。基于此数据便可计算loss并训练模型。但根据BiLSTM+CRF模型的启发,在BERT+FC layer的基础上增加CRF layer加入一些约束来保证最终的预测结果是有效的。这些约束可以在训练数据时被CRF层自动学习得到,从而减少预测错误的概率

5.2 模型实现

class BertCrfForNer(BertPreTrainedModel):
    def __init__(self, config):
        super(BertCrfForNer, self).__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(num_tags=config.num_labels, batch_first=True)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,labels=None):
        outputs =self.bert(input_ids = input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids)
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        outputs = (logits,)
        if labels is not None:
            loss = self.crf(emissions = logits, tags=labels, mask=attention_mask)
            outputs =(-1*loss,)+outputs
        return outputs # (loss), scores

6. BERT+BiLSTM+CRF

6.1 模型原理

BiLSTM+CRF优点是泛化能力强;缺点是需要大量的标注样本。在样本很少的情况下,效果会很不理想。为了更快速地实现一个实体提取器,提高系统易用性,可以采用迁移学习的思想,在先验知识的基础上进行模型训练,从而使用BERT+BiLSTM+CRF

同样的,输入是wordPiece tokenizer得到的tokenid,进入Bert预训练模型抽取丰富的文本特征得到batch\_size*max_seq\_len * emb\_size的输出向量,输出向量过BiLSTM从中提取实体识别所需的特征,得到batch\_size * max\_seq\_len * (2*hidden\_size)的向量,最终进入CRF层进行解码,计算最优的标注序列

6.2 模型实现

class BertBiLstmCrf(nn.Module):
    def __init__(self, vocab_size, emb_size, hidden_size, out_size, drop_out=0.1, use_pretrained_w2v=False):
        super(BertBiLstmCrf, self).__init__()
        self.bert_path = get_chinese_wwm_ext_pytorch_path()
        self.bert_config = BeitConfig.from_pretrained(self.bert_path)
        self.bert = BertModel.from_pretrained(self.bert_path)
        emb_size = 768
        for param in self.bert.parameters():
            param.requires_grad = True
        self.bilstm = nn.LSTM(emb_size, hidden_size, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size*2, out_size)
        self.dropout = nn.Dropout(drop_out)
        self.transition = nn.Parameter(torch.ones(out_size, out_size) * 1 / out_size)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, x, lengths):
        emb = self.bert(x)[0]
        emb = nn.utils.rnn.pack_padded_sequence(emb, lengths, batch_first=True)
        emb, _ = self.bilstm(emb)
        output, _ = nn.utils.rnn.pad_packed_sequence(emb, batch_first=True, padding_value=0., total_length=x.shape[1])
        output = self.dropout(output)
        emission = self.fc(output)
        batch_size, max_len, out_size = emission.size()
        crf_scores = emission.unsqueeze(2).expand(-1, -1, out_size, -1) + self.transition.unsqueeze(0)
        return crf_scores

NLP新人,欢迎大家一起交流,互相学习,共同成长~~

上一篇下一篇

猜你喜欢

热点阅读