推荐算法

行为序列建模:MIMN系列1——原理初探和源码解析

2023-06-21  本文已影响0人  xiaogp

关键词行为序列建模MIMNRNN神经图灵机Attention

内容摘要

本文主要是MIMN原理迅速扫描,实战部分见行为序列建模:MIMN系列2——消费Kafka实时预测代码实战


研究背景

本文受到字节跳动技术团队的一片博客《行为序列模型在抖音风控中的应用》的启发,在长序列建模中引入MIMN算法(Multi-channel user Interest Memory Network),进一步研究了阿里妈妈MIMN的论文和源码,将该算法成功部署到了风控业务系统,使得模型可以接受任意长度的历史序列对实体进行风险预测,同时引入外部存储记录在此之前所有的记忆状态,当有新的序列元素进入时,读写记录实时预测,简单而言相比于原始的通过滑窗限制序列长度的LSTM算法,MIMN具有两大优势:

MIMN原理迅速概括

MIMN论文涉及好几个独立的知识点,作者的创新是将这些技术串起来解决了一个实际的问题,其中设计的子模块包括NTM神经图灵机MIU记忆感知单元DIN注意力网络三个知识点,本文对于这三块不做展开,只在整体层面介绍下几大模块的最用,以及内部参数的更新维护方式,原论文地址

(1)模型输入输出介绍

下面先从模型的输入开始了解MIMN,左侧橙色是增存量序列构建记忆的过程,右侧是在线部署时的预测部分。


模型架构

对于增存量记忆构建部分,输入是历史所有序列元素,每个元素包括物品id和物品的其他上下文信息拼接的结果,序列元素输入的目的是维护了一个M矩阵S矩阵,对于每一个用户都有它对应的M和S矩阵,每来一个新的序列元素,都会对M和S进行更新

对于在线部署部分,输入是目标物品(Target Ad)历史记忆的读输出(Read Head)记忆感知模块和目标物品的Attention输出,以及其他上下文信息(Context Feas),四大输入拼接之后两层全连接在softmax得到0-1的输出,预测用户是否对目标物品有行为交互。

对于序列元素,是由历史到现在所有商品/广告形成的序列,细分的话有三种,一种是历史商品,一种是最后一个商品(或者是当前最新的一次行为商品),一种是目标商品(通过召回得到的候选商品),三个作用如下

搞清楚三种元素的区别基本MIMN大体上吃透一半了。

(2)模型部署介绍

模型部署也分为增存量记忆维护,和线上预测两个部分

模型部署

虚线下面是增存量记忆维护,增量和存量的行为序列产出UIC Server的M矩阵和S矩阵以及其他记忆信息,没来一个新的序列元素就更新UIC的内容,不需要全部从头开始重新计算记忆信息。虚线上的在线预测部分,简单而言就是根据目标物品信息,用户静态信息,再去UIC中拿到无延迟的记忆信息,预测得到用户对目标响应概率。这两个流程是完全解耦的,相当于UIC对实时预测部分是无延迟的,不再像传统RNN那样维护历史序列id,而是维护一个历史到现在为止的记忆矩阵在外部存储即可。


MIMN源码速览

下面进一步了解MIMN都从源码开始,源码地址,源码比较复杂涉及一些其他算法,挑一些重点记录一下。

(1)主模型框架类

模型的主类是Model_MIMN

class Model_MIMN(Model):
    def __init__(self, n_uid, n_mid, EMBEDDING_DIM, HIDDEN_SIZE, BATCH_SIZE, MEMORY_SIZE, SEQ_LEN=400, Mem_Induction=0,
                 Util_Reg=0, use_negsample=False, mask_flag=False):
        super(Model_MIMN, self).__init__(n_uid, n_mid, EMBEDDING_DIM, HIDDEN_SIZE,
                                         BATCH_SIZE, SEQ_LEN, use_negsample, Flag="MIMN")
        self.reg = Util_Reg
...

该类继承Model类,Model类主要包含输入序列id的embedding映射过程和最后的全连接过程,NTM,MIU,DIN Attention全部在子类Model_MIMN中。

class Model(object):
    def __init__(self, n_uid, n_mid, EMBEDDING_DIM, HIDDEN_SIZE, BATCH_SIZE, SEQ_LEN, use_negsample=False, Flag="DNN"):
        self.model_flag = Flag
        self.reg = False
        self.use_negsample = use_negsample
        with tf.name_scope('Inputs'):
        ...
        # Embedding layer
        with tf.name_scope('Embedding_layer'):
        ...
    # 基于之前网络的输出构造最后的全连接层
    def build_fcn_net(self, inp, use_dice=False):
        bn1 = tf.layers.batch_normalization(inputs=inp, name='bn1')
        ...

从功能上来说Model_MIMN的目的就是构造出最后一层全连接的输入inp,inp输入到全连接层,全连接包含batchNorm和两层全连接,和上图灰色的在线预测部分内容一致。

(2)MIMN单元

这是整个代码的核心,先看MIMN单元的实例化

cell = mimn.MIMNCell(controller_units=HIDDEN_SIZE, memory_size=MEMORY_SIZE, memory_vector_dim=2 * EMBEDDING_DIM,
                             read_head_num=1, write_head_num=1,
                             reuse=False, output_dim=HIDDEN_SIZE, clip_value=20, batch_size=BATCH_SIZE,
                             mem_induction=Mem_Induction, util_reg=Util_Reg)

在Model_MIMN中实例化了一个MIMN单元,而每一个序列的输入都会进这个MIMN单元,全局共享这个MIMN单元的模型参数,比如控制器和MIU中的GRU部分。在实例化MIMN单元的时候,这一段代码初始化了S矩阵

        if self.mem_induction > 0:
            self.channel_rnn = single_cell(self.memory_vector_dim)
            # TODO channel_rnn_state是S矩阵 [[256, 32], [256, 32], [256, 32], [256, 32]]
            self.channel_rnn_state = [self.channel_rnn.zero_state(batch_size, tf.float32) for i in range(memory_size)]
            self.channel_rnn_output = [tf.zeros(((batch_size, self.memory_vector_dim))) for i in range(memory_size)]

S矩阵为全0初始化,维度是[memory_size, batch_size, memory_dim],memory_size是记忆矩阵的高,memory_dim是记忆矩阵的宽,每个输入进来的样本都会有有一个自己的S矩阵。

下面初始化M矩阵的状态,当模型才开始训练和用户处于冷启动的时候,状态需要初始化,M矩阵比S矩阵复杂,会多一些相关的变量

state = cell.zero_state(BATCH_SIZE, tf.float32)

注意zero_state将BATCH_SIZE传进去,说明初始化和输入训练的用户数量有关,实际是每个用户都分配了一个初始化状态。举个例子看M矩阵的初始化

M = expand(
                tf.tanh(tf.get_variable('init_M', [self.memory_size, self.memory_vector_dim],
                                        initializer=tf.random_normal_initializer(mean=0.0, stddev=1e-5),
                                        trainable=False)),
                dim=0, N=batch_size)
def expand(x, dim, N):
    return tf.concat([tf.expand_dims(x, dim) for _ in range(N)], axis=dim)

对于每一个输入的用户,给他一个均值是0标准差是1e-5的随机(4,32)的初始化,然后复制batch_size(比如256)的份数,拼接成(256,4,32)的该batch下的init_M矩阵。由此可见虽然每个用户都给到一个单独的初始化M,但是他们初始化的结果是一模一样的,注意该变量trainable=False,不随着损失函数优化迭代。同理创建controller_state
,read_vector,w_list,M,key_M,w_aggre其他NTM需维护的变量,其中w_list包含了读头和写头。

(3)历史序列刷存量构建M和S矩阵

在MIMN单元实例化和MIMN的state初始化后,作者开始将历史200长度的序列灌入MIMN单元,代码如下

        for t in range(SEQ_LEN):
            output, state, temp_output_list = cell(self.item_his_eb[:, t, :], state)
            if mask_flag:
                # TODO mask的作用是修正状态,排除prepare阶段由于padding导致的state变动
                state = clear_mask_state(state, begin_state, begin_channel_rnn_output, self.mask, cell, t)
            # 记录下每个序列元素输出的output和status
            self.mimn_o.append(output)
            self.state_list.append(state)

代码里面通过item_his_eb[:, t, :]切片拿到了对应步长的序列元素,和当前的state一起输入MIMN单元,第一个元素对应的state是cell.zero_state得到的状态,后面的都是在循环中更新最新的state给下一个序列元素使用。注意这个for循环构造了一张tensorflow长图,及从第一个MIMN走到最后一个MIMN的路径,每一个样本,每一个批次进来的时候,都要经过这条路径,互不干扰,代码里面的self.state_list可以打印出来看一下,每一个样本的第一次state都是0初始化,不会存在参数继承的情况。
clear_mask_state函数是避免左边padding为0给state带来影响,代码如下

        def clear_mask_state(state, begin_state, begin_channel_rnn_state, mask, cell, t):
            # TODO mask[:, t] = [256, 1] => [256, 1]
            # TODO 如果mask是0相当于将controller_state重新置为begin_state,全0初始化,否则保持原样不变
            state["controller_state"] = (1 - tf.reshape(mask[:, t], (batch_size, 1))) * begin_state[
                "controller_state"] + tf.reshape(mask[:, t], (batch_size, 1)) * state["controller_state"]
            ...

以controller_state的计算为例,如果mask是0(代表padding了0),则左式保留controller_state打回原样成为begin_state,否则mask是1(代表不padding,是实际的序列元素),则左式删除,右式和state["controller_state"]没有差异保留模型对controller_state的更改。

(4)看看MIMN在做什么

下面深入这个cell(self.item_his_eb[:, t, :], state),看看MIMN在做什么,代码较长,挑提纲挈领的说。先看看这东西输入输出啥

def __call__(self, x, prev_state):
    return read_output, {
                "controller_state": controller_state,
                "read_vector_list": read_vector_list,
                "w_list": w_list,
                "M": M,
                "key_M": key_M,  # TODO key_M用完了之后没有修改
                "w_aggre": w_aggre,
                "sum_aggre": sum_aggre
             }, output_list

输入是当前步长的元素embedding和当前最新的state,输出是读M矩阵的输出,最新的状态,以及读S矩阵的输出,简单说一下三个输出的代码链路

总结数据输入MIMN单元之后,输出读M和S矩阵的输出,以及更新M和S矩阵的参数状态,其中读M和S矩阵的输出要输入最后的全连接模型进行ctr预测,更新M和S矩阵的参数状态需要输入给下一个序列元素进行记忆更新来表征用户的行为。

(5)MIMN单元的后处理,构造主模型输入

MIMN的输出需要准备构造为最终主模型的输入的,首先用拥有最新的state的MIMN单元将目标商品灌进来走一边,拿到读输出,来表征原始记忆信息,第二第三全部不要,只要read_out

read_out, _, _ = cell(self.item_eb, state)

然后拿到现在最新的读S矩阵的输出,和目标商品一起输入给DIN Attention,提取高阶特征

        if Mem_Induction == 1:
            channel_memory_tensor = tf.concat(temp_output_list, 1)
            multi_channel_hist = din_attention(self.item_eb, channel_memory_tensor, HIDDEN_SIZE, None, stag='pal')
            # TODO read_out是读取M矩阵输出的结果,multi_channel_hist是读取S矩阵输出的结果,其他都是目标商品自身特征和上下文特征
            inp = tf.concat([self.item_eb, self.item_his_eb_sum, read_out, tf.squeeze(multi_channel_hist),
                             mean_memory * self.item_eb], 1)

最终的inp包含read_out, tf.squeeze(multi_channel_hist)这两大主要特征,以及其他上下文特征。


最终输入构造

在回过头来看图示,很清楚了呀,Target Ad拿到M的Read Head,同时和最新的S一起输入Attention。inp最终输入全连接进行ctr预测。整个代码的概览结束,里面复杂的NTM和DIN Attention先不展开研究。


MIMN参数维护方式总结

作者的代码是训练部分,该代码的目的仅仅是训练出控制器GRU,MIU的GRU,DIN以及其他几个全连接的参数,保存在tensorflow网络中,而S和M矩阵虽然在里面也产出了,但是真正部署上线肯定是重新刷历史存量所有序列得到的,而不是采用padding和截取200的方式,示意图如下

模型参数如何保存

其中NTM的读写w权重直接基于cos相似度计算得到,得到后直接更新M矩阵,不需要保存,其他记忆部分都是保存到外部存储自行维护,而右侧部分全部是tensorflow图来维护,不需要手动维护,在线上环节,读取外部存储拿到记忆参数,输入给tensorflow图即可完成预测。
另外看一下记忆参数是如何初始化,以及如何更新的


参数的初始化和更新方式

其中有的初始化是需要模型学习的,在部署的时候需要在训练的网络中将它恢复出来,否则初始化不一样,有些初始化是0初始化是写死的,相对而言方便一点。

上一篇下一篇

猜你喜欢

热点阅读