Keras model code snippets

2022-03-29  FreeTheWorld

A record of some commonly used model code, kept for reference.

1 In-batch sampling inside a two-tower model

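import tensorflow as tf
from tensorflow.keras import backend


class SampleInBatch(tf.keras.layers.Layer):
    """Pairs every user with its own item plus two items shuffled from the same batch."""

    def __init__(self, name):
        super(SampleInBatch, self).__init__(name=name)

    def call(self, inputs, training=None):
        # Fall back to the global Keras learning phase when `training` is not passed.
        if training is None:
            training = backend.learning_phase()

        if training:
            # Note: the original vectors must be aligned (row i of each is one positive pair).
            usr_vec, item_vec = inputs
            # Take the indices from the actual batch so a smaller final batch still works.
            batch_indices = tf.range(tf.shape(item_vec)[0])
            indices1, indices2 = tf.random.shuffle(batch_indices), tf.random.shuffle(batch_indices)
            item_vec = tf.concat([item_vec,
                                  tf.gather(item_vec, axis=0, indices=indices1),
                                  tf.gather(item_vec, axis=0, indices=indices2)], axis=0)

            # Repeat the users three times to match [original items, shuffle 1, shuffle 2].
            usr_vec = tf.tile(usr_vec, [3, 1])
            return usr_vec, item_vec
        return inputs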

Key points
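The row layout that SampleInBatch produces during training can be sketched without TensorFlow. This is only an illustration: the helper name sample_in_batch, the string stand-ins for vectors, and the seeded random.Random are assumptions made for the example, not part of the original layer.

```python
import random


def sample_in_batch(usr_vecs, item_vecs, rng):
    """Return (users, items): each user appears 3x, with 1 aligned item + 2 shuffled items."""
    if len(usr_vecs) != len(item_vecs):
        raise ValueError("usr_vecs and item_vecs must be row-aligned")
    batch_size = len(item_vecs)
    indices1 = list(range(batch_size))
    indices2 = list(range(batch_size))
    rng.shuffle(indices1)  # plays the role of tf.random.shuffle
    rng.shuffle(indices2)
    # Same order as the tf.concat call: original items, then the two shuffled copies.
    items = (list(item_vecs)
             + [item_vecs[i] for i in indices1]
             + [item_vecs[i] for i in indices2])
    # Same effect as tf.tile(usr_vec, [3, 1]): the whole batch repeated three times.
    users = list(usr_vecs) * 3
    return users, items


# Hypothetical 4-row batch; strings stand in for embedding vectors.
users, items = sample_in_batch(["u0", "u1", "u2", "u3"],
                               ["i0", "i1", "i2", "i3"],
                               random.Random(0))
for user, item in zip(users, items):
    print(user, item)
```

With a batch of B rows the output has 3B rows: the first B are the aligned (positive) pairs and the remaining 2B pair each user with an item drawn from elsewhere in the batch. The layer itself does not emit labels, so downstream code has to rely on this ordering. Because the shuffles ignore the row index, a user can occasionally be paired with its own item in one of the shuffled blocks.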

2 Deep & Cross Network implementation

import tensorflow as tf


class Cross(tf.keras.layers.Layer):
    def __init__(self, projection_dim=None):
        super(Cross, self).__init__()
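        # projection_dim cuts the parameter count by factorizing the weight matrix
        # into two smaller matrices, usually with little effect on model quality.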
        self._projection_dim = projection_dim

    def build(self, input_shape):
        last_dim = input_shape[-1]

        if self._projection_dim is None:
            self._dense = tf.keras.layers.Dense(last_dim, use_bias=True)
        else:
            self._dense_u = tf.keras.layers.Dense(self._projection_dim, use_bias=False)
            self._dense_v = tf.keras.layers.Dense(last_dim, use_bias=True)
        self.built = True

    def call(self, x0, x=None):
        if x is None:
            x = x0
        if self._projection_dim is None:
            prod_output = self._dense(x)
        else:
            prod_output = self._dense_v(self._dense_u(x))

        return x0 * prod_output + x


def cross_module(inputs):
    # Concatenate the input features, then stack two cross layers on top of x0.
    x0 = tf.keras.layers.concatenate(inputs)
    x1 = Cross()(x0, x0)
    x2 = Cross()(x0, x1)
    return x2
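To make the cross step concrete, here is a small framework-free sketch of the formula used in Cross.call, x0 * (W x + b) + x, together with the parameter counts implied by the two branches of build. The helper names (dense, cross, cross_param_count), the identity kernel, and the example sizes 64 and 8 are assumptions chosen to keep the arithmetic easy to check by hand. Unlike cross_module, where each Cross() has its own weights, the sketch reuses one kernel for both steps.

```python
def dense(x, weights, bias):
    """y = W x + b for a single example, with W given as a list of rows."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, bias)]


def cross(x0, x, weights, bias):
    """One cross step: element-wise x0 * (W x + b), plus the residual x."""
    prod = dense(x, weights, bias)
    return [a * p + r for a, p, r in zip(x0, prod, x)]


def cross_param_count(dim, projection_dim=None):
    """Trainable parameters of one cross layer, following Cross.build."""
    if projection_dim is None:
        return dim * dim + dim  # full kernel + bias
    # U (dim x r, no bias) + V (r x dim) + bias of V
    return dim * projection_dim + projection_dim * dim + dim


x0 = [1.0, 2.0]
W = [[1.0, 0.0], [0.0, 1.0]]  # identity kernel, only to keep the numbers simple
b = [0.0, 0.0]

x1 = cross(x0, x0, W, b)  # [1*1 + 1, 2*2 + 2] = [2.0, 6.0]
x2 = cross(x0, x1, W, b)  # [1*2 + 2, 2*6 + 6] = [4.0, 18.0]
print(x1, x2)

print(cross_param_count(64))     # 64*64 + 64 = 4160
print(cross_param_count(64, 8))  # 64*8 + 8*64 + 64 = 1088
```

With the identity kernel, two stacked steps already yield terms up to the third power of the inputs (18 = 2^3 + 2^2 + 2^2 + 2), which shows how stacking cross layers raises the degree of feature interaction. The parameter counts show why a small projection_dim is attractive when the concatenated input is wide.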

References:
https://github.com/tensorflow/recommenders/blob/v0.6.0/tensorflow_recommenders/layers/feature_interaction/dcn.py#L23-L194
https://arxiv.org/pdf/2008.13535.pdf

(To be continued...)
