通过TF Keras实现双塔模型

2021-11-23  本文已影响0人  郭彦超
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

# Show which TensorFlow version is installed.
tf.__version__
'2.4.1'

1、数据预处理

# Load the data set and show how many rows carry each f_act value.
df = pd.read_csv("my_test.csv")
df["f_act"].value_counts()
0    2112700
1      17616
Name: f_act, dtype: int64
# Keep every f_act == 1 row and a random 1% of the f_act == 0 rows.
df1 = df[df["f_act"] == 1]
df0 = df[df["f_act"] == 0].sample(frac=0.01)

# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0;
# pd.concat builds the same frame (df1 rows first, original index kept).
df_sample = pd.concat([df1, df0])
df_sample["f_act"].value_counts()
0    21127
1    17616
Name: f_act, dtype: int64
def add_index_column(param_df, column_name):
    """Add an integer-code column named ``<column_name>_idx`` to *param_df*.

    Each distinct value of ``param_df[column_name]`` is numbered 0, 1, 2, ...
    in order of first appearance, and the new column holds that number for
    every row. The frame is modified in place; nothing is returned.

    Args:
        param_df: DataFrame to extend.
        column_name: name of the existing column to encode.
    """
    code_by_value = {}
    for code, distinct_value in enumerate(param_df[column_name].unique()):
        code_by_value[distinct_value] = code
    encoded = param_df[column_name].map(code_by_value)
    param_df[f"{column_name}_idx"] = encoded
    
# Add a "<name>_idx" integer-code column to df_sample for every feature the
# model consumes. The new columns are appended in the order listed here.
for feature_name in (
    "product_total_pv_3d",
    "product_valid_work_users",
    "product_total_pv_7d",
    "f_user_grade_type",
    "product_price",
    "f_user_identity",
    "f_user_last_active_time",
    "user_product_common_valid_count",
    "user_score",
    "f_user_recharge_xd_90",
    "user_login_days_7",
    "user_id",
    "product_id",
):
    add_index_column(df_sample, feature_name)
def _index_count(idx_column):
    """Return the largest code in ``df_sample[idx_column]`` plus one."""
    return df_sample[idx_column].max() + 1


# Number of codes per feature; get_model() passes these to the Embedding
# layers as their first argument.
num_product_total_pv_3d = _index_count("product_total_pv_3d_idx")
num_product_valid_work_users = _index_count("product_valid_work_users_idx")
num_product_total_pv_7d = _index_count("product_total_pv_7d_idx")
num_f_user_grade_type = _index_count("f_user_grade_type_idx")
num_product_price = _index_count("product_price_idx")
num_f_user_identity = _index_count("f_user_identity_idx")

num_f_user_last_active_time = _index_count("f_user_last_active_time_idx")
num_user_product_common_valid_count = _index_count("user_product_common_valid_count_idx")
num_user_score = _index_count("user_score_idx")
num_f_user_recharge_xd_90 = _index_count("f_user_recharge_xd_90_idx")
num_user_login_days_7 = _index_count("user_login_days_7_idx")
num_user_id = _index_count("user_id_idx")
num_product_id = _index_count("product_id_idx")
# Alternative sampling (disabled): df_sample = df.sample(frac=0.1)

# Read the label by indexing rather than DataFrame.pop: pop would delete
# "f_act" from df_sample, making the label unavailable to any later cell that
# reads it from df_sample. X is built from an explicit column list, so the
# label cannot end up among the features.
y = df_sample["f_act"]

feature_columns = [
    "product_total_pv_3d_idx",
    "product_valid_work_users_idx",
    "product_total_pv_7d_idx",
    "f_user_grade_type_idx",
    "product_price_idx",
    "f_user_identity_idx",
    "f_user_last_active_time_idx",
    "user_product_common_valid_count_idx",
    "user_score_idx",
    "f_user_recharge_xd_90_idx",
    "user_login_days_7_idx",
    "user_id_idx",
    "product_id_idx",
]
X = df_sample[feature_columns]
X.columns
Index(['product_total_pv_3d_idx', 'product_valid_work_users_idx',
       'product_total_pv_7d_idx', 'f_user_grade_type_idx', 'product_price_idx',
       'f_user_identity_idx', 'f_user_last_active_time_idx',
       'user_product_common_valid_count_idx', 'user_score_idx',
       'f_user_recharge_xd_90_idx', 'user_login_days_7_idx', 'user_id_idx',
       'product_id_idx'],
      dtype='object')

2、构建双塔模型

def get_model():
    """Build the two-tower DNN with the Keras functional API.

    Each tower embeds its integer-coded inputs, concatenates the embeddings
    and projects them through 2048 -> 512 -> 128 ReLU Dense layers. The two
    128-wide tower outputs are combined by a dot product, which feeds a
    single sigmoid unit.

    Embedding vocabulary sizes are read from the module-level ``num_*``
    variables, so those must be defined before this function is called.

    Returns:
        keras.models.Model: a model with 13 inputs of shape (batch, 1) --
        eight user features followed by five product features, in the order
        of the ``inputs=`` list at the end of this function -- and one
        output of shape (batch, 1, 1).
    """

    # Inputs: one scalar index per feature.
    user_id = layers.Input(shape=(1,), name="user_id")
    user_login_days_7 = layers.Input(shape=(1,), name="user_login_days_7")
    f_user_recharge_xd_90 = layers.Input(shape=(1,), name="f_user_recharge_xd_90")
    user_score = layers.Input(shape=(1,), name="user_score")
    user_product_common_valid_count = layers.Input(shape=(1,), name="user_product_common_valid_count")
    f_user_last_active_time = layers.Input(shape=(1,), name="f_user_last_active_time")
    f_user_identity = layers.Input(shape=(1,), name="f_user_identity")
    f_user_grade_type = layers.Input(shape=(1,), name="f_user_grade_type")

    product_price = layers.Input(shape=(1,), name="product_price")
    product_id = layers.Input(shape=(1,), name="product_id")
    product_total_pv_3d = layers.Input(shape=(1,), name="product_total_pv_3d")
    product_total_pv_7d = layers.Input(shape=(1,), name="product_total_pv_7d")
    product_valid_work_users = layers.Input(shape=(1,), name="product_valid_work_users")

    # User tower. The embedding part could be improved with word2vec, see
    # https://stackoverflow.com/questions/58311682/how-to-concatenate-embeddings-with-variable-length-inputs-in-keras
    # Every Embedding maps (batch, 1) to (batch, 1, dim); concatenation joins
    # them along the last axis.
    user_vector = layers.concatenate([
        layers.Embedding(num_user_id, 10000)(user_id),
        layers.Embedding(num_user_login_days_7, 2)(user_login_days_7),
        layers.Embedding(num_f_user_recharge_xd_90, 5)(f_user_recharge_xd_90),
        layers.Embedding(num_user_score, 10)(user_score),
        layers.Embedding(num_user_product_common_valid_count, 20)(user_product_common_valid_count),
        layers.Embedding(num_f_user_last_active_time, 20)(f_user_last_active_time),
        layers.Embedding(num_f_user_identity, 10)(f_user_identity),
        layers.Embedding(num_f_user_grade_type, 5)(f_user_grade_type),
    ])
    user_vector = layers.Dense(2048, activation='relu')(user_vector)
    user_vector = layers.Dense(512, activation='relu')(user_vector)
    # Named so the tower output can be retrieved later via get_layer().
    user_vector = layers.Dense(128, activation='relu',
                               name="user_embedding", kernel_regularizer='l2')(user_vector)

    # Product tower.
    product_vector = layers.concatenate([
        layers.Embedding(num_product_id, 2000)(product_id),
        layers.Embedding(num_product_price, 10)(product_price),
        layers.Embedding(num_product_total_pv_3d, 20)(product_total_pv_3d),
        layers.Embedding(num_product_valid_work_users, 20)(product_valid_work_users),
        layers.Embedding(num_product_total_pv_7d, 20)(product_total_pv_7d),
    ])
    product_vector = layers.Dense(2048, activation='relu')(product_vector)
    product_vector = layers.Dense(512, activation='relu')(product_vector)
    product_vector = layers.Dense(128, activation='relu',
                                  name="product_embedding", kernel_regularizer='l2')(product_vector)

    # Dot product of the user embedding and the product embedding.
    # Both tensors are (batch, 1, 128), so the sum must run over the LAST
    # axis. Summing over axis 1 only removes the length-1 axis and leaves the
    # 128 element-wise products un-summed. keepdims=True yields (batch, 1, 1),
    # which keeps the model's output shape at (batch, 1, 1).
    dot_user_product = tf.reduce_sum(user_vector * product_vector, axis=-1, keepdims=True)

    # Scale and shift the similarity, then squash it to a probability.
    output = layers.Dense(1, activation='sigmoid')(dot_user_product)

    return keras.models.Model(
        inputs=[
            user_id,
            user_login_days_7,
            f_user_recharge_xd_90,
            user_score,
            user_product_common_valid_count,
            f_user_last_active_time,
            f_user_identity,
            f_user_grade_type,
            product_id,
            product_price,
            product_total_pv_3d,
            product_valid_work_users,
            product_total_pv_7d,
        ],
        outputs=[output],
    )

3、模型训练

# Build the network and print its symbolic input tensors.
model = get_model()
print(model.input)

# Alternative configuration (disabled): MeanSquaredError loss with RMSprop.
# model.compile(loss=tf.keras.losses.MeanSquaredError(),
#               optimizer=keras.optimizers.RMSprop())
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# One Series per model input, in the same order as the ``inputs=`` list of
# the Model built in get_model().
input_columns = [
    "user_id_idx",
    "user_login_days_7_idx",
    "f_user_recharge_xd_90_idx",
    "user_score_idx",
    "user_product_common_valid_count_idx",
    "f_user_last_active_time_idx",
    "f_user_identity_idx",
    "f_user_grade_type_idx",
    "product_id_idx",
    "product_price_idx",
    "product_total_pv_3d_idx",
    "product_valid_work_users_idx",
    "product_total_pv_7d_idx",
]
fit_x_train = [X[column] for column in input_columns]

from datetime import datetime

# Give every run its own TensorBoard directory, named after the start time.
TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now())
run_log_dir = "./logs/logs_" + TIMESTAMP
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=run_log_dir)

# Train for 5 epochs in mini-batches of 32, logging to TensorBoard.
fit_options = dict(
    batch_size=32,
    epochs=5,
    verbose=1,
    callbacks=[tensorboard_callback],
)
history = model.fit(x=fit_x_train, y=y, **fit_options)
history
[<KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'user_id')>, <KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'user_login_days_7')>, <KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'f_user_recharge_xd_90')>, <KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'user_score')>, <KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'user_product_common_valid_count')>, <KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'f_user_last_active_time')>, <KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'f_user_identity')>, <KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'f_user_grade_type')>, <KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'product_id')>, <KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'product_price')>, <KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'product_total_pv_3d')>, <KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'product_valid_work_users')>, <KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'product_total_pv_7d')>]
Epoch 1/5
909/909 [==============================] - 604s 663ms/step - loss: 1.1246 - accuracy: 0.6392
Epoch 2/5
909/909 [==============================] - 613s 674ms/step - loss: 0.3892 - accuracy: 0.8272
Epoch 3/5
909/909 [==============================] - 610s 671ms/step - loss: 0.2689 - accuracy: 0.8650
Epoch 4/5
909/909 [==============================] - 608s 669ms/step - loss: 0.2350 - accuracy: 0.8742
Epoch 5/5
909/909 [==============================] - 612s 673ms/step - loss: 0.2082 - accuracy: 0.8870

# Print the model architecture.
model.summary()
Model: "model_11"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
user_id (InputLayer)            [(None, 1)]          0                                            
__________________________________________________________________________________________________
user_login_days_7 (InputLayer)  [(None, 1)]          0                                            
__________________________________________________________________________________________________
f_user_recharge_xd_90 (InputLay [(None, 1)]          0                                            
__________________________________________________________________________________________________
user_score (InputLayer)         [(None, 1)]          0                                            
__________________________________________________________________________________________________
user_product_common_valid_count [(None, 1)]          0                                            
__________________________________________________________________________________________________
f_user_last_active_time (InputL [(None, 1)]          0                                            
__________________________________________________________________________________________________
f_user_identity (InputLayer)    [(None, 1)]          0                                            
__________________________________________________________________________________________________
f_user_grade_type (InputLayer)  [(None, 1)]          0                                            
__________________________________________________________________________________________________
product_id (InputLayer)         [(None, 1)]          0                                            
__________________________________________________________________________________________________
product_price (InputLayer)      [(None, 1)]          0                                            
__________________________________________________________________________________________________
product_total_pv_3d (InputLayer [(None, 1)]          0                                            
__________________________________________________________________________________________________
product_valid_work_users (Input [(None, 1)]          0                                            
__________________________________________________________________________________________________
product_total_pv_7d (InputLayer [(None, 1)]          0                                            
__________________________________________________________________________________________________
embedding_52 (Embedding)        (None, 1, 10000)     38450000    user_id[0][0]                    
__________________________________________________________________________________________________
embedding_53 (Embedding)        (None, 1, 2)         18          user_login_days_7[0][0]          
__________________________________________________________________________________________________
embedding_54 (Embedding)        (None, 1, 5)         35          f_user_recharge_xd_90[0][0]      
__________________________________________________________________________________________________
embedding_55 (Embedding)        (None, 1, 10)        3030        user_score[0][0]                 
__________________________________________________________________________________________________
embedding_56 (Embedding)        (None, 1, 20)        8440        user_product_common_valid_count[0
__________________________________________________________________________________________________
embedding_57 (Embedding)        (None, 1, 20)        400         f_user_last_active_time[0][0]    
__________________________________________________________________________________________________
embedding_58 (Embedding)        (None, 1, 10)        40          f_user_identity[0][0]            
__________________________________________________________________________________________________
embedding_59 (Embedding)        (None, 1, 5)         25          f_user_grade_type[0][0]          
__________________________________________________________________________________________________
embedding_60 (Embedding)        (None, 1, 2000)      29008000    product_id[0][0]                 
__________________________________________________________________________________________________
embedding_61 (Embedding)        (None, 1, 10)        90          product_price[0][0]              
__________________________________________________________________________________________________
embedding_62 (Embedding)        (None, 1, 20)        3400        product_total_pv_3d[0][0]        
__________________________________________________________________________________________________
embedding_63 (Embedding)        (None, 1, 20)        17460       product_valid_work_users[0][0]   
__________________________________________________________________________________________________
embedding_64 (Embedding)        (None, 1, 20)        6180        product_total_pv_7d[0][0]        
__________________________________________________________________________________________________
concatenate_8 (Concatenate)     (None, 1, 10072)     0           embedding_52[0][0]               
                                                                 embedding_53[0][0]               
                                                                 embedding_54[0][0]               
                                                                 embedding_55[0][0]               
                                                                 embedding_56[0][0]               
                                                                 embedding_57[0][0]               
                                                                 embedding_58[0][0]               
                                                                 embedding_59[0][0]               
__________________________________________________________________________________________________
concatenate_9 (Concatenate)     (None, 1, 2070)      0           embedding_60[0][0]               
                                                                 embedding_61[0][0]               
                                                                 embedding_62[0][0]               
                                                                 embedding_63[0][0]               
                                                                 embedding_64[0][0]               
__________________________________________________________________________________________________
dense_18 (Dense)                (None, 1, 2048)      20629504    concatenate_8[0][0]              
__________________________________________________________________________________________________
dense_20 (Dense)                (None, 1, 2048)      4241408     concatenate_9[0][0]              
__________________________________________________________________________________________________
dense_19 (Dense)                (None, 1, 512)       1049088     dense_18[0][0]                   
__________________________________________________________________________________________________
dense_21 (Dense)                (None, 1, 512)       1049088     dense_20[0][0]                   
__________________________________________________________________________________________________
user_embedding (Dense)          (None, 1, 128)       65664       dense_19[0][0]                   
__________________________________________________________________________________________________
product_embedding (Dense)       (None, 1, 128)       65664       dense_21[0][0]                   
__________________________________________________________________________________________________
tf.math.multiply_4 (TFOpLambda) (None, 1, 128)       0           user_embedding[0][0]             
                                                                 product_embedding[0][0]          
__________________________________________________________________________________________________
tf.math.reduce_sum_4 (TFOpLambd (None, 128)          0           tf.math.multiply_4[0][0]         
__________________________________________________________________________________________________
tf.expand_dims_4 (TFOpLambda)   (None, 1, 128)       0           tf.math.reduce_sum_4[0][0]       
__________________________________________________________________________________________________
dense_22 (Dense)                (None, 1, 1)         129         tf.expand_dims_4[0][0]           
==================================================================================================
Total params: 94,597,663
Trainable params: 94,597,663
Non-trainable params: 0
__________________________________________________________________________________________________

4、模型预估

# Column names, in the order of the model's inputs.
predict_columns = [
    "user_id_idx",
    "user_login_days_7_idx",
    "f_user_recharge_xd_90_idx",
    "user_score_idx",
    "user_product_common_valid_count_idx",
    "f_user_last_active_time_idx",
    "f_user_identity_idx",
    "f_user_grade_type_idx",
    "product_id_idx",
    "product_price_idx",
    "product_total_pv_3d_idx",
    "product_valid_work_users_idx",
    "product_total_pv_7d_idx",
]

# Draw a random 1% of the rows and keep at most the first 10 of them.
inputs = df_sample.sample(frac=0.01)[predict_columns].head(10)

# Score the small batch: one Series per model input.
model.predict([inputs[column] for column in predict_columns])

array([[[0.09251767]],

       [[0.00407892]],

       [[0.07737359]],

       [[0.67664975]],

       [[0.6161837 ]],

       [[0.5397055 ]],

       [[0.43793055]],

       [[0.18323252]],

       [[0.06416044]],

       [[0.5696974 ]]], dtype=float32)
# Column names, in the order of the model's inputs.
evaluate_columns = [
    "user_id_idx",
    "user_login_days_7_idx",
    "f_user_recharge_xd_90_idx",
    "user_score_idx",
    "user_product_common_valid_count_idx",
    "f_user_last_active_time_idx",
    "f_user_identity_idx",
    "f_user_grade_type_idx",
    "product_id_idx",
    "product_price_idx",
    "product_total_pv_3d_idx",
    "product_valid_work_users_idx",
    "product_total_pv_7d_idx",
]

# Draw a random 1% of the rows and keep at most the first 10 of them.
inputs = df_sample.sample(frac=0.01)[evaluate_columns].head(10)

# Take the labels from y, matched by row index, instead of selecting "f_act"
# from df_sample: that lookup raises KeyError whenever the label column has
# been removed from df_sample (y was created with df_sample.pop("f_act")).
inputs = inputs.assign(f_act=y.loc[inputs.index])

# NOTE: these rows come from the same df_sample that X was built from, so
# this measures the fit on training data, not on held-out data.
score = model.evaluate(
    [inputs[column] for column in evaluate_columns],
    inputs["f_act"],
)
score

5、导出User塔和Item塔(以下代码以User塔为例,Item塔的导出方式相同)

# Sub-model that maps the first eight model inputs (the user features) to
# the output of the "user_embedding" layer.
user_layer_model = keras.models.Model(
    inputs=model.input[:8],
    outputs=model.get_layer("user_embedding").output,
)

# User feature columns, in the order of the sub-model's inputs.
user_feature_columns = [
    "user_id_idx",
    "user_login_days_7_idx",
    "f_user_recharge_xd_90_idx",
    "user_score_idx",
    "user_product_common_valid_count_idx",
    "f_user_last_active_time_idx",
    "f_user_identity_idx",
    "f_user_grade_type_idx",
]

user_embeddings = []
# Kept simple: users and products are not de-duplicated here.
for _row_label, row in df_sample.sample(frac=0.01).iterrows():
    # The sub-model is called one row at a time, so every feature value is
    # reshaped to a (1, 1) array.
    user_input = [np.reshape(row[column], [1, 1]) for column in user_feature_columns]
    user_embedding = user_layer_model(user_input)

    # Store the vector as one comma-separated string.
    embedding_str = ",".join(str(component) for component in user_embedding.numpy().flatten())
    user_embeddings.append([row["user_id"], embedding_str])

df_user_embedding = pd.DataFrame(user_embeddings, columns=["user_id", "user_embedding"])
df_user_embedding.head()
 

上一篇 下一篇

猜你喜欢

热点阅读