推荐算法

行为序列建模:MIMN系列2——消费Kafka实时预测代码实战

2023-06-23  本文已影响0人  xiaogp

关键词:MIMNKafka

内容摘要

本文主要是MIMN代码实战,理论部分见行为序列建模:MIMN系列1——原理初探和源码解析


业务背景交代

用户的行为序列不仅可以用户表征用户的兴趣,从而对用户进行商品/广告推荐,也可以表征出用户的行为模式,从而对该模式进行打分形成风控预警。本文举例的业务背景为,用户的行为标签序列,如

user: behavior_label_1, behavior_label_2, behavior_label_3, behavior_label_4 ...

所有行为标签根据时间的先后顺序排开,预测用户是否会在未来达成某种风险行为标签,比如risk_ behavior_label_1,risk_ behavior_label_2,risk_ behavior_label_3...中的任何一个,或者其他Y值标签,最终的目标是基于用户的历史序列表征出embedding,然后预测给出用户的坏率。传统的做法是对用户序列进行截断和padding,比如最近1年最大100个事件,灌入类似LSTM进行二分类预测,而本次实战就是要基于MIMN实现,允许用户历史所有事件序列,并且能够对接事件消息中间件实现实时预测。


MIMN如何运用于风控

风控业务和MIMN的推荐业务从模型任务上基本相同都是二分类,实际上推荐,风控从算法任务上是相通的,不少推荐的论文都被应用在风控中,比如MMOE等算法,而在MIMN场景,推荐和风控是有一个明显区别的,推荐场景的MIMN是有一组候选商品的(Target AD),给出最高得分的候选商品输出给用户推荐,而风控场景没有候选,如果没有候选则没有基于M矩阵的读头输出,没有DIN Attention的标的,因此风控场景下MIMN对应的这个输入实体应该是类似风险行为标签的,比如risk_ behavior_label_1,risk_ behavior_label_2,risk_ behavior_label_3...中的任何一个,用风险行为标签去记忆里面读输出,基于注意力提取历史信息中和风险行为标签相关的信息,本案例采用的策略是所有风险标签的embedding求平均值作为Target AD,因此候选商品是定死不变的,召回是固定的,从而一蹴而就模型可以实时预测,当用户有新的行为进入时,更新S和M矩阵,然后Target AD再次输入MIMN单元直接预测,及用户每有一个新的序列进来,模型的评分都可随之响应。


MIMN代码实战

本案例的目标是MIMN对接Kafka中的实时序列标签,实现实时预测,其中S和M的记忆信息存储到MySQL(或者其他外部存储都行),所有的行为标签都已经经过word2vec进行embedding初始化,另外本案例取消了最后全连接层除了S和M矩阵的其他上下文特征,相当于最后只有M读头的输出和S和targat的DIN Attention的拼接进入分类模型,在字节跳动的技术博客里面也是仅有这两个输入取消了其他内容,如图

风控中的MIMN模型架构图
(1)MIMN单元源码改造

MIMN的源码写的质量很高,仅有一个地方需要增加点内容,需要在MIMN单元的call里面允许输入一个S矩阵,以及输出一个S矩阵,代码如下

    def __call__(self, x, prev_state, prev_s_state=None):
        prev_read_vector_list = prev_state["read_vector_list"]
        # TODO S矩阵
        if prev_s_state:
            self.channel_rnn_state = prev_s_state["channel_rnn_state"]
            # TODO [[batch_size, 32],[batch_size, 32],[batch_size, 32],[batch_size, 32]]
            self.channel_rnn_state = [tf.squeeze(x, axis=0) for x in tf.split(self.channel_rnn_state, self.memory_size)]

如上增加了prev_s_state,赋值给channel_rnn_state,代替和源码中类实例化中的初始化

self.channel_rnn_state = [self.channel_rnn.zero_state(batch_size, tf.float32) for i in range(memory_size)]

在训练阶段还是使用初始化的channel_rnn_state,但是在部署阶段手动传入S矩阵的信息,否则S矩阵无法和外部进行交互。同样的在call输出里面也增加channel_rnn_state

return read_output, {
            "controller_state": controller_state,
            "read_vector_list": read_vector_list,
            "w_list": w_list,
            "M": M,
            "key_M": key_M, 
            "w_aggre": w_aggre
        }, output_list, self.channel_rnn_state

这样S矩阵可传入和输出,从而可以和外部存储交互,摆脱tensorflow图。

(2)增量预测代码新增

该部分是核心,增量预测是模型应用阶段,训练部分的任务仅仅是训练出各种GRU和全连接的参数固定在tensorflow图上,增量部分涉及M和S,以及MIMN单元的交互运用。
该部分的工作是将训练的图上200个MIMN拍成的MIMN序列,从中抽出一个MIMN单元,预测阶段仅对一个标签进行预测,另外所有需要模型学习初始化的参数都要和训练的图保持一致,使用tensorflow的tf.train.init_from_checkpoint实现网络重构

tf.train.init_from_checkpoint(ckpt_dir_or_file, assignment_map)
该方法传入一个检查点路径,和一个映射,实现将在检查点图里面的变量传递到新图上的同名变量,其中映射记录了检查点中的变量名和新图中变量名的对应关系

增量部分代码如下

class Model_MIMN:
    """
    将存量MIMN网络中的全连接,din,MIU和NTM的参数拿出来,单独对一个序列元素进行更新state和channel_rnn_state
    """

    def __init__(self, seq_length, seq_embedding_size, basic_embedding_size, memory_size, memory_vector_dim,
                 mem_induction=1, util_reg=1, rnn_hidden_size=16, batch_size=1, read_head_num=1):
        # TODO 部署阶段序列长度1,实时更新
        self.input_seq = tf.placeholder(tf.int32, [batch_size, seq_length], name="input_seq")
        with tf.variable_scope("embedding"):
            # TODO 从网络恢复
            self.rule_embedding_var = tf.get_variable("embedding", dtype=tf.float32, shape=[350, seq_embedding_size],
                                                      trainable=False)
            # [None, 1, 32]
            self.rule_embedding = tf.nn.embedding_lookup(self.rule_embedding_var, self.input_seq)
            # [emb_size]
            # TODO 从网络恢复
            self.target_rule_embedding_mean = tf.get_variable("target_embedding", shape=[seq_embedding_size],
                                                              trainable=False)
            # [1, emb_size]
            self.target_rule_batch = tf.expand_dims(self.target_rule_embedding_mean, 0)

        with tf.variable_scope('init'):
            # TODO 第一个序列元素进来的时候需要使用模型训练得到的初始化
            self.read_vector_list = [
                expand(tf.tanh(learned_init(memory_vector_dim)), dim=0, N=batch_size)
                for i in range(read_head_num)]
            self.M = expand(
                tf.tanh(tf.get_variable('init_M', [memory_size, memory_vector_dim],
                                        initializer=tf.random_normal_initializer(mean=0.0, stddev=1e-5),
                                        trainable=False)),
                dim=0, N=batch_size)
            self.key_M = expand(
                tf.tanh(tf.get_variable('key_M', [memory_size, memory_vector_dim],
                                        initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))),
                dim=0, N=batch_size)

        # TODO m矩阵和s矩阵全部从外部输入
        self.prev_m_state = {
            "controller_state": tf.placeholder_with_default(tf.zeros(shape=[batch_size, memory_vector_dim]),
                                                            [batch_size, memory_vector_dim],
                                                            name="prev_controller_state"),
            "read_vector_list": [
                tf.placeholder_with_default(self.read_vector_list[0], [batch_size, memory_vector_dim],
                                            name="prev_read_vector_list")],
            "M": tf.placeholder_with_default(self.M, [batch_size, memory_size, memory_vector_dim],
                                             name="prev_M"),
            "w_aggre": tf.placeholder_with_default(tf.zeros(shape=[batch_size, memory_size]),
                                                   [batch_size, memory_size], name="prev_w_aggre"),
            "key_M": tf.placeholder_with_default(self.key_M, [batch_size, memory_size, memory_vector_dim],
                                                 name="prev_key_M")}

        self.prev_s_state = {"channel_rnn_state": tf.placeholder_with_default(
            tf.zeros(shape=[memory_size, batch_size, memory_vector_dim]), [memory_size, batch_size, memory_vector_dim],
            name="prev_channel_rnn_state")}

        # 其他参数
        self.seq_embedding_size = seq_embedding_size
        self.rnn_hidden_size = rnn_hidden_size
        self.seq_length = seq_length
        self.reg = util_reg
        self.memory_size = memory_size
        self.mem_induction = mem_induction
        self.util_reg = util_reg

        # TODO 从网络恢复一个MIMN单元
        self.cell = MIMNCell(controller_units=self.rnn_hidden_size, memory_size=self.memory_size,
                             memory_vector_dim=self.seq_embedding_size, read_head_num=1, write_head_num=1,
                             reuse=False, output_dim=self.rnn_hidden_size, clip_value=20,
                             mem_induction=self.mem_induction, util_reg=self.util_reg, batch_size=batch_size)
        # TODO 新进来一个元素之后更新state,输出不要
        _, self.m_state, temp_output_list, self.s_state = self.cell(self.rule_embedding[:, 0, :],
                                                                    self.prev_m_state, self.prev_s_state)
        # 把所有状态拆出来,否则在图pb拿不到
        self.controller_state = tf.identity(self.m_state["controller_state"], name="controller_state")
        self.read_vector_list = tf.identity(self.m_state["read_vector_list"][0], name="read_vector_list")
        self.M = tf.identity(self.m_state["M"], name="M")
        self.key_M = tf.identity(self.m_state["key_M"], name="key_M")
        self.w_aggre = tf.identity(self.m_state["w_aggre"], name="w_aggre")
        self.channel_rnn_state = tf.identity(tf.concat([tf.expand_dims(x, axis=0) for x in self.s_state], axis=0),
                                             name="channel_rnn_state")

        # TODO 再将target输入一遍,拿到read_out,state不要
        read_out, _, _, _ = self.cell(self.target_rule_batch, self.m_state, {"channel_rnn_state": self.s_state})

        if self.mem_induction == 1:
            channel_memory_tensor = tf.concat(temp_output_list, 1)
            multi_channel_hist = din_attention(self.target_rule_batch, channel_memory_tensor, self.rnn_hidden_size,
                                               None, stag='pal')
            inp = tf.concat([read_out, tf.squeeze(multi_channel_hist, axis=1)], 1)
        else:
            inp = tf.concat([read_out], 1)

        # fc
        bn1 = inp
        # 一层
        dnn1 = tf.layers.dense(bn1, 200, activation=None, name='f1')
        dnn1 = tf.nn.relu(dnn1)

        # 二层
        dnn2 = tf.layers.dense(dnn1, 80, activation=None, name='f2')
        dnn2 = tf.nn.relu(dnn2)

        # 输出维度2,1和0
        dnn3 = tf.layers.dense(dnn2, 2, activation=None, name='f3')
        # 输出
        self.y_hat = tf.nn.softmax(dnn3) + 0.00000001
        self.prob = tf.identity(self.y_hat, name="prob")

其中所有初始化基本照搬源码,从而可以保证原图中的变量命名和新图一致,另外所有的记忆信息全部使用tf.identity重命名,保证在图里面可以将这些变量取出来,从而可以存储到数据库。
使用以下初始化方法实例化增量模型网络

assignment_map = {
        # embedding
        "embedding/embedding": "embedding/embedding",
        "embedding/target_embedding": "embedding/target_embedding",
        # NTM controller
        "controller/gru_cell/gates/kernel": "controller/gru_cell/gates/kernel",
        "controller/gru_cell/gates/bias": "controller/gru_cell/gates/bias",
        "controller/gru_cell/candidate/kernel": "controller/gru_cell/candidate/kernel",
        "controller/gru_cell/candidate/bias": "controller/gru_cell/candidate/bias",
        ...
    }

    seq_embedding_size = 32
    rnn_hidden_size = seq_embedding_size
    memory_vector_dim = seq_embedding_size
    tf.reset_default_graph()
    model = Model_MIMN(seq_length=1, seq_embedding_size=seq_embedding_size, basic_embedding_size=80,
                       memory_size=8, memory_vector_dim=32, mem_induction=1, util_reg=1,
                       rnn_hidden_size=rnn_hidden_size, batch_size=1)

    with tf.Session() as sess:
        tf.train.init_from_checkpoint(os.path.join(ROOT_PATH, "ckpt"), assignment_map)
        sess.run(tf.global_variables_initializer())

其中ckpt是训练代码输出的ckpt文件,训练部分代码为源码给出,稍作修改即可在本业务数据上跑通,该内容本文省略。
重构增量模型网络之后,重新保存一个新的cpkt和冻结图,该网络仅包含一个MIMN单元,且记忆信息依赖外部输入进来,这是我们部署在线上真正需要的。
另外第二个需要注意点记忆信息的初始化,使用tf.placeholder_with_default实现,可以允许模型预测的时候传和不传,当用户冷启动时需要初始化状态

tf.placeholder_with_default(input, shape, name=None)
如果传入input,代表当该占位符没有传值时,就以这个input作为占位符的值,如果占位符传值时,就和普通的tf.placeholder功能一致

(3)实时消费Kafka预测

现在部署对接Kafka,MIMN的论文是将记忆更新和在线ctr预测分为连个独立的系统,而本案例直接将两个流程合并,一方式是因为召回集是固定的,另一方面本案例的序列更新并不频繁,来一个更新一个模型撑的住。采用Java代码读取冻结图进行部署,代码demo如下

public static Map<String, Double> oneMIMNCell(String userName, int seqNo) throws SQLException {
        int [][] inputNo = {{seqNo}};
        JSONObject prevState = MySQLUtils.getState(entName);
        JSONObject state = null;
        if (prevState.isEmpty()) {
            state = MIMN.getInstance().getMIMNRes(inputNo);
        } else {
            state = MIMN.getInstance().getMIMNRes(inputNo,
                    JSON.parseObject(prevState.getString("key_M"), float[][][].class),
                    JSON.parseObject(prevState.getString("M"), float[][][].class),
                    JSON.parseObject(prevState.getString("controller_state"), float[][].class),
                    JSON.parseObject(prevState.getString("read_vector_list"), float[][].class),
                    JSON.parseObject(prevState.getString("w_aggre"), float[][].class),
                    JSON.parseObject(prevState.getString("channel_rnn_state"), float[][][].class));
        }
        JSONObject stateInfo = new JSONObject();
        stateInfo.put("ent_name", entName);
        stateInfo.put("controller_state", state.getString("controller_state"));
        stateInfo.put("read_vector_list", state.getString("read_vector_list"));
        stateInfo.put("w_aggre", state.getString("w_aggre"));
        stateInfo.put("key_M", state.getString("key_M"));
        stateInfo.put("channel_rnn_state", state.getString("channel_rnn_state"));
        stateInfo.put("M", state.getString("M"));
        stateInfo.put("prob", state.getDouble("prob"));
        stateInfo.put("update_cnt", (prevState.containsKey("update_cnt") ? prevState.getIntValue("update_cnt") : 0) + 1);
        stateInfo.put("update_time", LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")));
        // 写入mysql
        MySQLUtils.saveState(stateInfo);
        // 日志信息
        Map<String, Double> map = new HashMap<>();
        map.put("score", state.getDouble("prob"));
        map.put("prev_score", prevState.getDouble("score"));
        map.put("score_inc", state.getDouble("prob") - (prevState.getDouble("score") == null ? 0.0 : prevState.getDouble("score")));
        return map;
    }

现在MySQL里面获取该用户最新的状态,如果为空则全部使用初始化状态,否则手动传入状态,计算完成后将结果在写到MySQL。
在增量之前需要先刷存量,每个用户的所有序列都要刷一遍增量MIMN网络到MySQL数据库,增量部分指定Kafka的offset进行消费,该offset对应存量的截止时间点,增量消费Kafka的代码demo如下

public static void main(String[] args) {
        Properties props = initConfig();
        KafkaConsumer<String, String> consumer = new KafkaConsumer<>(props);
        String offsetString = Config.getString("seekOffset");
        consumer.subscribe(Collections.singletonList("topicName"));
        if (offsetString != null && !offsetString.equals("null")) {
            long offset = Long.parseLong(offsetString);
            consumer.poll(5000L);
            consumer.seek(new TopicPartition("topicName", 0), offset);
            LOGGER.info("offset从{}开始消费", offset);
        }

        while (true) {
            try {
                ConsumerRecords<String, String> records = consumer.poll(5000L);
                if (records != null && !records.isEmpty()) {
                    for (TopicPartition partition : records.partitions()) {
                        List<ConsumerRecord<String, String>> partitionRecords = records.records(partition);
                        long currentOffset = 0;
                        for (ConsumerRecord<String, String> record : partitionRecords) {
                            JSONObject json = JSONObject.parseObject((String) record.value());
                            JSONArray userRule = parseEntNameRule(json);
                            if (!userRule.isEmpty()) {
                                for (int i = 0; i < userRule.size(); i++) {
                                    String ruleNo = userRule.getJSONObject(i).getString("rule_no");
                                    String userName = userRule.getJSONObject(i).getString("ent_name");
                                    int ruleId = ruleIndex.containsKey(ruleNo) ? ruleIndex.get(ruleNo) : ruleIndex.get("UKN");
                                    Map<String, Double> map = MIMN.oneMIMNCell(userName, ruleId);
                                }
                            }
                            currentOffset = record.offset();
                        }
                        long lastoffset = currentOffset + 1;
                        consumer.seek(partition, lastoffset);
                        consumer.commitSync(Collections.singletonMap(partition, new OffsetAndMetadata(lastoffset)));
                    }
                }
            } catch (Exception e) {
                LOGGER.info("kafka消费异常!", e);
                System.exit(1);
            }
        }
    }

消费日志如下

2023-06-24 20:48:37 INFO  FlowJob:86 - xxx 更新成功,行为标签:xxx 导致分数-1.3(4.7->3.4)
2023-06-24 20:48:37 INFO  FlowJob:86 - xxx 更新成功,行为标签:xxx 导致分数+0.2(22.0->22.2)
2023-06-24 20:50:02 INFO  FlowJob:86 - xxx 更新成功,行为标签:xxx 导致分数+1.3(19.7->21.0)
2023-06-24 20:50:02 INFO  FlowJob:86 - xxx 更新成功,行为标签:xxx 导致分数-3.9(21.0->17.1)
2023-06-24 20:50:02 INFO  FlowJob:86 - xxx 更新成功,行为标签:xxx 导致分数-5.2(17.1->11.9)

MySQL存储的记忆信息如下,矩阵存储为JSONArray

记忆信息存储
上一篇下一篇

猜你喜欢

热点阅读