行为序列建模:MIMN系列2——消费Kafka实时预测代码实战
关键词:MIMN
,Kafka
内容摘要
- MIMN原理整体提要解析(系列1已完结)
- MIMN源码速览(系列1已完结)
- MIMN中参数维护方式总结(系列1已完结)
- 在风控场景下,MIMN的训练,部署代码实战
本文主要是MIMN代码实战,理论部分见行为序列建模:MIMN系列1——原理初探和源码解析
业务背景交代
用户的行为序列不仅可以用户表征用户的兴趣,从而对用户进行商品/广告推荐,也可以表征出用户的行为模式,从而对该模式进行打分形成风控预警。本文举例的业务背景为,用户的行为标签序列,如
user: behavior_label_1, behavior_label_2, behavior_label_3, behavior_label_4 ...
所有行为标签根据时间的先后顺序排开,预测用户是否会在未来达成某种风险行为标签,比如risk_ behavior_label_1,risk_ behavior_label_2,risk_ behavior_label_3...中的任何一个,或者其他Y值标签,最终的目标是基于用户的历史序列表征出embedding,然后预测给出用户的坏率。传统的做法是对用户序列进行截断和padding,比如最近1年最大100个事件,灌入类似LSTM进行二分类预测,而本次实战就是要基于MIMN实现,允许用户历史所有事件序列,并且能够对接事件消息中间件实现实时预测。
MIMN如何运用于风控
风控业务和MIMN的推荐业务从模型任务上基本相同都是二分类,实际上推荐,风控从算法任务上是相通的,不少推荐的论文都被应用在风控中,比如MMOE等算法,而在MIMN场景,推荐和风控是有一个明显区别的,推荐场景的MIMN是有一组候选商品的(Target AD),给出最高得分的候选商品输出给用户推荐,而风控场景没有候选,如果没有候选则没有基于M矩阵的读头输出,没有DIN Attention的标的,因此风控场景下MIMN对应的这个输入实体应该是类似风险行为标签的,比如risk_ behavior_label_1,risk_ behavior_label_2,risk_ behavior_label_3...中的任何一个,用风险行为标签去记忆里面读输出,基于注意力提取历史信息中和风险行为标签相关的信息,本案例采用的策略是所有风险标签的embedding求平均值作为Target AD,因此候选商品是定死不变的,召回是固定的,从而一蹴而就模型可以实时预测,当用户有新的行为进入时,更新S和M矩阵,然后Target AD再次输入MIMN单元直接预测,及用户每有一个新的序列进来,模型的评分都可随之响应。
MIMN代码实战
本案例的目标是MIMN对接Kafka中的实时序列标签,实现实时预测,其中S和M的记忆信息存储到MySQL(或者其他外部存储都行),所有的行为标签都已经经过word2vec进行embedding初始化,另外本案例取消了最后全连接层除了S和M矩阵的其他上下文特征,相当于最后只有M读头的输出和S和targat的DIN Attention的拼接进入分类模型,在字节跳动的技术博客里面也是仅有这两个输入取消了其他内容,如图
(1)MIMN单元源码改造
MIMN的源码写的质量很高,仅有一个地方需要增加点内容,需要在MIMN单元的call里面允许输入一个S矩阵,以及输出一个S矩阵,代码如下
def __call__(self, x, prev_state, prev_s_state=None):
prev_read_vector_list = prev_state["read_vector_list"]
# TODO S矩阵
if prev_s_state:
self.channel_rnn_state = prev_s_state["channel_rnn_state"]
# TODO [[batch_size, 32],[batch_size, 32],[batch_size, 32],[batch_size, 32]]
self.channel_rnn_state = [tf.squeeze(x, axis=0) for x in tf.split(self.channel_rnn_state, self.memory_size)]
如上增加了prev_s_state,赋值给channel_rnn_state,代替和源码中类实例化中的初始化
self.channel_rnn_state = [self.channel_rnn.zero_state(batch_size, tf.float32) for i in range(memory_size)]
在训练阶段还是使用初始化的channel_rnn_state,但是在部署阶段手动传入S矩阵的信息,否则S矩阵无法和外部进行交互。同样的在call输出里面也增加channel_rnn_state
return read_output, {
"controller_state": controller_state,
"read_vector_list": read_vector_list,
"w_list": w_list,
"M": M,
"key_M": key_M,
"w_aggre": w_aggre
}, output_list, self.channel_rnn_state
这样S矩阵可传入和输出,从而可以和外部存储交互,摆脱tensorflow图。
(2)增量预测代码新增
该部分是核心,增量预测是模型应用阶段,训练部分的任务仅仅是训练出各种GRU和全连接的参数固定在tensorflow图上,增量部分涉及M和S,以及MIMN单元的交互运用。
该部分的工作是将训练的图上200个MIMN拍成的MIMN序列,从中抽出一个MIMN单元,预测阶段仅对一个标签进行预测,另外所有需要模型学习初始化的参数都要和训练的图保持一致,使用tensorflow的tf.train.init_from_checkpoint
实现网络重构
tf.train.init_from_checkpoint(ckpt_dir_or_file, assignment_map)
该方法传入一个检查点路径,和一个映射,实现将在检查点图里面的变量传递到新图上的同名变量,其中映射记录了检查点中的变量名和新图中变量名的对应关系
增量部分代码如下
class Model_MIMN:
"""
将存量MIMN网络中的全连接,din,MIU和NTM的参数拿出来,单独对一个序列元素进行更新state和channel_rnn_state
"""
def __init__(self, seq_length, seq_embedding_size, basic_embedding_size, memory_size, memory_vector_dim,
mem_induction=1, util_reg=1, rnn_hidden_size=16, batch_size=1, read_head_num=1):
# TODO 部署阶段序列长度1,实时更新
self.input_seq = tf.placeholder(tf.int32, [batch_size, seq_length], name="input_seq")
with tf.variable_scope("embedding"):
# TODO 从网络恢复
self.rule_embedding_var = tf.get_variable("embedding", dtype=tf.float32, shape=[350, seq_embedding_size],
trainable=False)
# [None, 1, 32]
self.rule_embedding = tf.nn.embedding_lookup(self.rule_embedding_var, self.input_seq)
# [emb_size]
# TODO 从网络恢复
self.target_rule_embedding_mean = tf.get_variable("target_embedding", shape=[seq_embedding_size],
trainable=False)
# [1, emb_size]
self.target_rule_batch = tf.expand_dims(self.target_rule_embedding_mean, 0)
with tf.variable_scope('init'):
# TODO 第一个序列元素进来的时候需要使用模型训练得到的初始化
self.read_vector_list = [
expand(tf.tanh(learned_init(memory_vector_dim)), dim=0, N=batch_size)
for i in range(read_head_num)]
self.M = expand(
tf.tanh(tf.get_variable('init_M', [memory_size, memory_vector_dim],
initializer=tf.random_normal_initializer(mean=0.0, stddev=1e-5),
trainable=False)),
dim=0, N=batch_size)
self.key_M = expand(
tf.tanh(tf.get_variable('key_M', [memory_size, memory_vector_dim],
initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))),
dim=0, N=batch_size)
# TODO m矩阵和s矩阵全部从外部输入
self.prev_m_state = {
"controller_state": tf.placeholder_with_default(tf.zeros(shape=[batch_size, memory_vector_dim]),
[batch_size, memory_vector_dim],
name="prev_controller_state"),
"read_vector_list": [
tf.placeholder_with_default(self.read_vector_list[0], [batch_size, memory_vector_dim],
name="prev_read_vector_list")],
"M": tf.placeholder_with_default(self.M, [batch_size, memory_size, memory_vector_dim],
name="prev_M"),
"w_aggre": tf.placeholder_with_default(tf.zeros(shape=[batch_size, memory_size]),
[batch_size, memory_size], name="prev_w_aggre"),
"key_M": tf.placeholder_with_default(self.key_M, [batch_size, memory_size, memory_vector_dim],
name="prev_key_M")}
self.prev_s_state = {"channel_rnn_state": tf.placeholder_with_default(
tf.zeros(shape=[memory_size, batch_size, memory_vector_dim]), [memory_size, batch_size, memory_vector_dim],
name="prev_channel_rnn_state")}
# 其他参数
self.seq_embedding_size = seq_embedding_size
self.rnn_hidden_size = rnn_hidden_size
self.seq_length = seq_length
self.reg = util_reg
self.memory_size = memory_size
self.mem_induction = mem_induction
self.util_reg = util_reg
# TODO 从网络恢复一个MIMN单元
self.cell = MIMNCell(controller_units=self.rnn_hidden_size, memory_size=self.memory_size,
memory_vector_dim=self.seq_embedding_size, read_head_num=1, write_head_num=1,
reuse=False, output_dim=self.rnn_hidden_size, clip_value=20,
mem_induction=self.mem_induction, util_reg=self.util_reg, batch_size=batch_size)
# TODO 新进来一个元素之后更新state,输出不要
_, self.m_state, temp_output_list, self.s_state = self.cell(self.rule_embedding[:, 0, :],
self.prev_m_state, self.prev_s_state)
# 把所有状态拆出来,否则在图pb拿不到
self.controller_state = tf.identity(self.m_state["controller_state"], name="controller_state")
self.read_vector_list = tf.identity(self.m_state["read_vector_list"][0], name="read_vector_list")
self.M = tf.identity(self.m_state["M"], name="M")
self.key_M = tf.identity(self.m_state["key_M"], name="key_M")
self.w_aggre = tf.identity(self.m_state["w_aggre"], name="w_aggre")
self.channel_rnn_state = tf.identity(tf.concat([tf.expand_dims(x, axis=0) for x in self.s_state], axis=0),
name="channel_rnn_state")
# TODO 再将target输入一遍,拿到read_out,state不要
read_out, _, _, _ = self.cell(self.target_rule_batch, self.m_state, {"channel_rnn_state": self.s_state})
if self.mem_induction == 1:
channel_memory_tensor = tf.concat(temp_output_list, 1)
multi_channel_hist = din_attention(self.target_rule_batch, channel_memory_tensor, self.rnn_hidden_size,
None, stag='pal')
inp = tf.concat([read_out, tf.squeeze(multi_channel_hist, axis=1)], 1)
else:
inp = tf.concat([read_out], 1)
# fc
bn1 = inp
# 一层
dnn1 = tf.layers.dense(bn1, 200, activation=None, name='f1')
dnn1 = tf.nn.relu(dnn1)
# 二层
dnn2 = tf.layers.dense(dnn1, 80, activation=None, name='f2')
dnn2 = tf.nn.relu(dnn2)
# 输出维度2,1和0
dnn3 = tf.layers.dense(dnn2, 2, activation=None, name='f3')
# 输出
self.y_hat = tf.nn.softmax(dnn3) + 0.00000001
self.prob = tf.identity(self.y_hat, name="prob")
其中所有初始化基本照搬源码,从而可以保证原图中的变量命名和新图一致,另外所有的记忆信息全部使用tf.identity
重命名,保证在图里面可以将这些变量取出来,从而可以存储到数据库。
使用以下初始化方法实例化增量模型网络
assignment_map = {
# embedding
"embedding/embedding": "embedding/embedding",
"embedding/target_embedding": "embedding/target_embedding",
# NTM controller
"controller/gru_cell/gates/kernel": "controller/gru_cell/gates/kernel",
"controller/gru_cell/gates/bias": "controller/gru_cell/gates/bias",
"controller/gru_cell/candidate/kernel": "controller/gru_cell/candidate/kernel",
"controller/gru_cell/candidate/bias": "controller/gru_cell/candidate/bias",
...
}
seq_embedding_size = 32
rnn_hidden_size = seq_embedding_size
memory_vector_dim = seq_embedding_size
tf.reset_default_graph()
model = Model_MIMN(seq_length=1, seq_embedding_size=seq_embedding_size, basic_embedding_size=80,
memory_size=8, memory_vector_dim=32, mem_induction=1, util_reg=1,
rnn_hidden_size=rnn_hidden_size, batch_size=1)
with tf.Session() as sess:
tf.train.init_from_checkpoint(os.path.join(ROOT_PATH, "ckpt"), assignment_map)
sess.run(tf.global_variables_initializer())
其中ckpt是训练代码输出的ckpt文件,训练部分代码为源码给出,稍作修改即可在本业务数据上跑通,该内容本文省略。
重构增量模型网络之后,重新保存一个新的cpkt和冻结图,该网络仅包含一个MIMN单元,且记忆信息依赖外部输入进来,这是我们部署在线上真正需要的。
另外第二个需要注意点记忆信息的初始化,使用tf.placeholder_with_default
实现,可以允许模型预测的时候传和不传,当用户冷启动时需要初始化状态
tf.placeholder_with_default(input, shape, name=None)
如果传入input,代表当该占位符没有传值时,就以这个input作为占位符的值,如果占位符传值时,就和普通的tf.placeholder功能一致
(3)实时消费Kafka预测
现在部署对接Kafka,MIMN的论文是将记忆更新和在线ctr预测分为连个独立的系统,而本案例直接将两个流程合并,一方式是因为召回集是固定的,另一方面本案例的序列更新并不频繁,来一个更新一个模型撑的住。采用Java代码读取冻结图进行部署,代码demo如下
public static Map<String, Double> oneMIMNCell(String userName, int seqNo) throws SQLException {
int [][] inputNo = {{seqNo}};
JSONObject prevState = MySQLUtils.getState(entName);
JSONObject state = null;
if (prevState.isEmpty()) {
state = MIMN.getInstance().getMIMNRes(inputNo);
} else {
state = MIMN.getInstance().getMIMNRes(inputNo,
JSON.parseObject(prevState.getString("key_M"), float[][][].class),
JSON.parseObject(prevState.getString("M"), float[][][].class),
JSON.parseObject(prevState.getString("controller_state"), float[][].class),
JSON.parseObject(prevState.getString("read_vector_list"), float[][].class),
JSON.parseObject(prevState.getString("w_aggre"), float[][].class),
JSON.parseObject(prevState.getString("channel_rnn_state"), float[][][].class));
}
JSONObject stateInfo = new JSONObject();
stateInfo.put("ent_name", entName);
stateInfo.put("controller_state", state.getString("controller_state"));
stateInfo.put("read_vector_list", state.getString("read_vector_list"));
stateInfo.put("w_aggre", state.getString("w_aggre"));
stateInfo.put("key_M", state.getString("key_M"));
stateInfo.put("channel_rnn_state", state.getString("channel_rnn_state"));
stateInfo.put("M", state.getString("M"));
stateInfo.put("prob", state.getDouble("prob"));
stateInfo.put("update_cnt", (prevState.containsKey("update_cnt") ? prevState.getIntValue("update_cnt") : 0) + 1);
stateInfo.put("update_time", LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")));
// 写入mysql
MySQLUtils.saveState(stateInfo);
// 日志信息
Map<String, Double> map = new HashMap<>();
map.put("score", state.getDouble("prob"));
map.put("prev_score", prevState.getDouble("score"));
map.put("score_inc", state.getDouble("prob") - (prevState.getDouble("score") == null ? 0.0 : prevState.getDouble("score")));
return map;
}
现在MySQL里面获取该用户最新的状态,如果为空则全部使用初始化状态,否则手动传入状态,计算完成后将结果在写到MySQL。
在增量之前需要先刷存量,每个用户的所有序列都要刷一遍增量MIMN网络到MySQL数据库,增量部分指定Kafka的offset进行消费,该offset对应存量的截止时间点,增量消费Kafka的代码demo如下
public static void main(String[] args) {
Properties props = initConfig();
KafkaConsumer<String, String> consumer = new KafkaConsumer<>(props);
String offsetString = Config.getString("seekOffset");
consumer.subscribe(Collections.singletonList("topicName"));
if (offsetString != null && !offsetString.equals("null")) {
long offset = Long.parseLong(offsetString);
consumer.poll(5000L);
consumer.seek(new TopicPartition("topicName", 0), offset);
LOGGER.info("offset从{}开始消费", offset);
}
while (true) {
try {
ConsumerRecords<String, String> records = consumer.poll(5000L);
if (records != null && !records.isEmpty()) {
for (TopicPartition partition : records.partitions()) {
List<ConsumerRecord<String, String>> partitionRecords = records.records(partition);
long currentOffset = 0;
for (ConsumerRecord<String, String> record : partitionRecords) {
JSONObject json = JSONObject.parseObject((String) record.value());
JSONArray userRule = parseEntNameRule(json);
if (!userRule.isEmpty()) {
for (int i = 0; i < userRule.size(); i++) {
String ruleNo = userRule.getJSONObject(i).getString("rule_no");
String userName = userRule.getJSONObject(i).getString("ent_name");
int ruleId = ruleIndex.containsKey(ruleNo) ? ruleIndex.get(ruleNo) : ruleIndex.get("UKN");
Map<String, Double> map = MIMN.oneMIMNCell(userName, ruleId);
}
}
currentOffset = record.offset();
}
long lastoffset = currentOffset + 1;
consumer.seek(partition, lastoffset);
consumer.commitSync(Collections.singletonMap(partition, new OffsetAndMetadata(lastoffset)));
}
}
} catch (Exception e) {
LOGGER.info("kafka消费异常!", e);
System.exit(1);
}
}
}
消费日志如下
2023-06-24 20:48:37 INFO FlowJob:86 - xxx 更新成功,行为标签:xxx 导致分数-1.3(4.7->3.4)
2023-06-24 20:48:37 INFO FlowJob:86 - xxx 更新成功,行为标签:xxx 导致分数+0.2(22.0->22.2)
2023-06-24 20:50:02 INFO FlowJob:86 - xxx 更新成功,行为标签:xxx 导致分数+1.3(19.7->21.0)
2023-06-24 20:50:02 INFO FlowJob:86 - xxx 更新成功,行为标签:xxx 导致分数-3.9(21.0->17.1)
2023-06-24 20:50:02 INFO FlowJob:86 - xxx 更新成功,行为标签:xxx 导致分数-5.2(17.1->11.9)
MySQL存储的记忆信息如下,矩阵存储为JSONArray
记忆信息存储