Personalized Search with word2vec and Elasticsearch

2017-03-28  ginobefun

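In the earlier post "Notes on Learning word2vec" (word2vec学习小记) we looked at word2vec, a tool that builds on the neural network language model, optimizes it, and ultimately yields word vectors and a language model. In our product search system we use word2vec to compute user vectors and product vectors, and implement personalized search with Elasticsearch's function_score scoring mechanism and a custom script plugin.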

Background

Let's start with Wikipedia's definition and description of personalized search:

Personalized search refers to web search experiences that are tailored specifically to an individual's interests by incorporating information about the individual beyond specific query provided. Pitkow et al. describe two general approaches to personalizing search results, one involving modifying the user's query and the other re-ranking search results.

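Two important points follow from this:

  1. Personalized search has to take the user's preferences fully into account and show the content the user is interested in first;
  2. There are two main ways to implement it: modifying the query and re-ranking the search results.

For our e-commerce site, the point of personalized search is that when a user searches for a keyword such as 【卫衣】 (sweatshirt), the products the user is most interested in and most likely to buy (for example, the user's preferred brands or styles) are shown first, which improves the user experience and click-through conversion.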

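Design approach

  1. We previously had an earlier version of personalized search. It computed weights for some important attributes of users and products (such as brand, category, and gender), derived a correlation coefficient between each user and product, and re-ranked the results by that coefficient.
  2. That version did not perform very well. In my view the main reasons were: users do not value every product attribute equally, the set of product attributes we considered was incomplete, and relationships between products were not considered at all;
  3. In the new design we use time-ordered data, namely the user's browsing history, to capture the relationships between users and products and between products. The core idea is to train vector representations from the order in which products appear, much as a language model does with the order of words;
  4. Once we have the user vectors and product vectors, we can compute relevance from the distance between the vectors and show the products the user is interested in first;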
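
The last step above can be sketched as follows. This is a minimal illustration, not the production code: it assumes cosine similarity as the relevance measure (the one the plugin later in this post uses), and the function names and the 3-dimensional sample vectors are hypothetical.

```python
import math


def cosine(x, y):
    """Cosine similarity of two equal-length vectors; 0.0 if either has zero norm."""
    dot = sum(a * b for a, b in zip(x, y))
    norm_x = math.sqrt(sum(a * a for a in x))
    norm_y = math.sqrt(sum(b * b for b in y))
    if norm_x == 0 or norm_y == 0:
        return 0.0
    return dot / (norm_x * norm_y)


def rank_products(user_vec, product_vecs):
    """Return product ids sorted by descending similarity to the user vector."""
    return sorted(
        product_vecs,
        key=lambda skn: cosine(user_vec, product_vecs[skn]),
        reverse=True,
    )


# Hypothetical vectors, for illustration only.
user_vec = [1.0, 0.0, 0.0]
product_vecs = {
    "skn_a": [0.0, 1.0, 0.0],   # orthogonal to the user vector
    "skn_b": [2.0, 0.0, 0.0],   # same direction as the user vector
    "skn_c": [-1.0, 0.0, 0.0],  # opposite direction
}
ranked = rank_products(user_vec, product_vecs)
```

In the real system the ranking is not done in application code like this; the similarity is folded into the Elasticsearch score, as described in the sections below.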

Implementation details

Computing product vectors

-- One line per user: the product SKNs the user clicked in the last 30 days, space-separated.
select concat_ws(' ', collect_set(product_skn)) as skns
from
 (select uid, cast(product_skn as string) as product_skn, click_time_stamp
  from product_click_record
  where date_id <= $date_id and date_id >= $date_id_30_day_ago
  order by uid, click_time_stamp) as a
group by uid;

# Train 20-dimensional CBOW vectors on the exported file, treating each SKN as a word.
time ./word2vec -train $prepare_file -output $result_file -cbow 1 -size 20 \
  -window 8 -negative 25 -hs 0 -sample 1e-4 -threads 20 -iter 15
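
To make the data flow around word2vec concrete, here is a small sketch under stated assumptions. `build_sentences` is a hypothetical helper that produces the same kind of training lines as the query above, but sorts each user's clicks by timestamp explicitly, since Hive's `collect_set` removes duplicates and does not guarantee element order. `parse_word2vec_text` assumes the text output format of the original word2vec tool (a header line with vocabulary size and dimension, then one `word v1 v2 ...` line per word) and skips the `</s>` token if it is present.

```python
from collections import defaultdict


def build_sentences(click_records):
    """Group (uid, product_skn, click_time_stamp) records into one
    time-ordered, space-separated line of SKNs per user."""
    by_user = defaultdict(list)
    for uid, skn, ts in click_records:
        by_user[uid].append((ts, str(skn)))
    sentences = []
    for uid in sorted(by_user):
        clicks = sorted(by_user[uid])  # order by timestamp (ties broken by SKN)
        sentences.append(" ".join(skn for _, skn in clicks))
    return sentences


def parse_word2vec_text(text):
    """Parse word2vec text output into {skn: [float, ...]}."""
    lines = text.strip().splitlines()
    dim = int(lines[0].split()[1])  # header line: "<vocab_size> <dimension>"
    vectors = {}
    for line in lines[1:]:
        parts = line.split()
        if parts[0] == "</s>":  # sentence-boundary token, not a product
            continue
        values = [float(v) for v in parts[1:]]
        if len(values) == dim:
            vectors[parts[0]] = values
    return vectors


# Hypothetical click records: (uid, product_skn, click_time_stamp).
records = [(1, 5001, 30), (1, 5002, 10), (2, 5003, 5), (1, 5003, 20)]
sentences = build_sentences(records)

sample_output = "3 2\n</s> 0.0 0.0\n5001 0.5 -0.25\n5002 1.0 2.0\n"
vectors = parse_word2vec_text(sample_output)
```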

Computing user vectors

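# Build one avg(...) expression per vector dimension; a missing component counts as 0.
vec_list = []
for i in range(feature_length):
    vec_list.append("avg(coalesce(b.vec[%s], 0))" % (str(i)))
vec = ', '.join(vec_list)

-- %s is filled with the expressions built above: a user vector is the per-dimension
-- average of the vectors of the products the user clicked in the last 30 days.
select a.uid as uid, array(%s) as vec 
from 
 (select * from product_click_record where date_id <= $date_id and date_id >= $date_id_30_day_ago) as a
left outer join
 (select * from product_w2v where date_id = $date_id) as b
on a.product_skn = b.product_skn
group by a.uid;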
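
The same averaging can be written out in plain Python to show what the generated SQL computes. This is only an illustrative sketch: `user_vectors` and the sample data are hypothetical. As with the `left outer join` plus `coalesce(..., 0)` in the query, a click on a product that has no vector contributes zeros to the average, and repeated clicks on the same product count more than once.

```python
def user_vectors(clicks, product_vecs, dim):
    """Average the vectors of clicked products per user.

    clicks: iterable of (uid, product_skn) pairs, one per click record.
    product_vecs: {product_skn: [float, ...]} with vectors of length dim.
    """
    sums = {}
    counts = {}
    for uid, skn in clicks:
        vec = product_vecs.get(skn, [0.0] * dim)  # missing vector counts as zeros
        acc = sums.setdefault(uid, [0.0] * dim)
        for i in range(dim):
            acc[i] += vec[i]
        counts[uid] = counts.get(uid, 0) + 1
    return {uid: [v / counts[uid] for v in acc] for uid, acc in sums.items()}


# Hypothetical data: product "x" has no trained vector.
product_vecs = {"a": [3.0, 0.0], "b": [0.0, 3.0]}
clicks = [(1, "a"), (1, "b"), (1, "x"), (2, "a")]
result = user_vectors(clicks, product_vecs, dim=2)
```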

Adding a personalization score at search time

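// Parameters for the native scoring script registered by the plugin described below.
Map<String, Object> scriptParams = new HashMap<>();
// Index field that stores each product's feature vector.
scriptParams.put("field", "productFeatureVector");
// The user's vector as a comma-separated string (the plugin splits it on ",").
scriptParams.put("inputFeatureVector", userVector);
// Must match the version prefix stored with the product vectors.
scriptParams.put("version", version);
Script script = new Script("feature_vector_scoring_script", ScriptService.ScriptType.INLINE, "native", scriptParams);
functionScoreQueryBuilder.add(ScoreFunctionBuilders.scriptFunction(script));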
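
For readers who want to see the request this roughly corresponds to, here is a sketch that builds a search body as a Python dict. Several things are assumptions rather than facts from this post: the body uses the Elasticsearch 2.x style of script definition (`inline` / `lang` / `params`), the `match` query on a `productName` field is purely hypothetical, and `boost_mode` is set to `multiply` only as an example of combining the script score with the query score.

```python
import json


def build_personalized_query(keyword, user_vector, version):
    """Build a function_score search body that calls the native scoring script."""
    params = {
        "field": "productFeatureVector",
        # The plugin expects the vector as a comma-separated string.
        "inputFeatureVector": ",".join(str(v) for v in user_vector),
        "version": version,
    }
    return {
        "query": {
            "function_score": {
                # Hypothetical base query; the real service builds its own.
                "query": {"match": {"productName": keyword}},
                "functions": [
                    {
                        "script_score": {
                            "script": {
                                "inline": "feature_vector_scoring_script",
                                "lang": "native",
                                "params": params,
                            }
                        }
                    }
                ],
                "boost_mode": "multiply",
            }
        }
    }


body = build_personalized_query("卫衣", [0.1, -0.2, 0.3], "20170328")
print(json.dumps(body, ensure_ascii=False, indent=2))
```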

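The elasticsearch-feature-vector-scoring plugin

This is a plugin I wrote myself; see the project home page for usage details. Its core is a single class, so I have pasted its main code and comments here: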

public class FeatureVectorScoringSearchScript extends AbstractSearchScript {
    public static final ESLogger LOGGER = Loggers.getLogger("feature-vector-scoring");
    public static final String SCRIPT_NAME = "feature_vector_scoring_script";
    private static final double DEFAULT_BASE_CONSTANT = 1.0D;
    private static final double DEFAULT_FACTOR_CONSTANT = 1.0D;

    // field in index to store feature vector
    private String field;

    // version of feature vector, if it isn't null, it should match version of index
    private String version;

    // final_score = baseConstant + factorConstant * cos(X, Y)
    private double baseConstant;

    // final_score = baseConstant + factorConstant * cos(X, Y)
    private double factorConstant;

    // input feature vector
    private double[] inputFeatureVector;

    // cos(X, Y) = Σ(Xi * Yi) / ( sqrt(Σ(Xi * Xi)) * sqrt(Σ(Yi * Yi)) )
    // the inputFeatureVectorNorm is sqrt(Σ(Xi * Xi))
    private double inputFeatureVectorNorm;

    public static class ScriptFactory implements NativeScriptFactory {
        @Override
        public ExecutableScript newScript(@Nullable Map<String, Object> params) throws ScriptException {
            return new FeatureVectorScoringSearchScript(params);
        }

        @Override
        public boolean needsScores() {
            return false;
        }
    }

    private FeatureVectorScoringSearchScript(Map<String, Object> params) throws ScriptException {
        this.field = (String) params.get("field");
        String inputFeatureVectorStr = (String) params.get("inputFeatureVector");
        if (this.field == null || inputFeatureVectorStr == null || inputFeatureVectorStr.trim().length() == 0) {
            throw new ScriptException("Initialize script " + SCRIPT_NAME + " failed!");
        }

        this.version = (String) params.get("version");
        this.baseConstant = params.get("baseConstant") != null ? Double.parseDouble(params.get("baseConstant").toString()) : DEFAULT_BASE_CONSTANT;
        this.factorConstant = params.get("factorConstant") != null ? Double.parseDouble(params.get("factorConstant").toString()) : DEFAULT_FACTOR_CONSTANT;

        String[] inputFeatureVectorArr = inputFeatureVectorStr.split(",");
        int dimension = inputFeatureVectorArr.length;
        double sumOfSquare = 0.0D;
        this.inputFeatureVector = new double[dimension];
        double temp;
        for (int index = 0; index < dimension; index++) {
            temp = Double.parseDouble(inputFeatureVectorArr[index].trim());
            this.inputFeatureVector[index] = temp;
            sumOfSquare += temp * temp;
        }

        this.inputFeatureVectorNorm = Math.sqrt(sumOfSquare);
        LOGGER.debug("FeatureVectorScoringSearchScript.init, version:{}, norm:{}, baseConstant:{}, factorConstant:{}."
                , this.version, this.inputFeatureVectorNorm, this.baseConstant, this.factorConstant);
    }

    @Override
    public Object run() {
        if (this.inputFeatureVectorNorm == 0) {
            return this.baseConstant;
        }

        if (!doc().containsKey(this.field) || doc().get(this.field) == null) {
            LOGGER.error("cannot find field {}.", field);
            return this.baseConstant;
        }

        String docFeatureVectorStr = ((ScriptDocValues.Strings) doc().get(this.field)).getValue();
        return calculateScore(docFeatureVectorStr);
    }

    public double calculateScore(String docFeatureVectorStr) {
        // 1. check docFeatureVector
        if (docFeatureVectorStr == null) {
            return this.baseConstant;
        }

        docFeatureVectorStr = docFeatureVectorStr.trim();
        if (docFeatureVectorStr.isEmpty()) {
            return this.baseConstant;
        }

        // 2. check version and get feature vector array of document
        String[] docFeatureVectorArr;
        if (this.version != null) {
            String versionPrefix = version + "|";
            if (!docFeatureVectorStr.startsWith(versionPrefix)) {
                return this.baseConstant;
            }

            docFeatureVectorArr = docFeatureVectorStr.substring(versionPrefix.length()).split(",");
        } else {
            docFeatureVectorArr = docFeatureVectorStr.split(",");
        }

        // 3. check the dimension of input and document
        int dimension = this.inputFeatureVector.length;
        if (docFeatureVectorArr == null || docFeatureVectorArr.length != dimension) {
            return this.baseConstant;
        }

        // 4. calculate the relevance score of the two feature vector
        double sumOfSquare = 0.0D;
        double sumOfProduct = 0.0D;
        double tempValueInDouble;
        for (int i = 0; i < dimension; i++) {
            tempValueInDouble = Double.parseDouble(docFeatureVectorArr[i].trim());
            sumOfProduct += tempValueInDouble * this.inputFeatureVector[i];
            sumOfSquare += tempValueInDouble * tempValueInDouble;
        }

        if (sumOfSquare == 0) {
            return this.baseConstant;
        }

        double cosScore = sumOfProduct / (Math.sqrt(sumOfSquare) * inputFeatureVectorNorm);
        return this.baseConstant + this.factorConstant * cosScore;
    }
}
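
To try the scoring rule outside Elasticsearch, the logic of `calculateScore` can be ported to a few lines of Python. This is an illustrative re-implementation, not part of the plugin; the function name and the sample values are hypothetical. It follows the same steps: strip the `version|` prefix, check the dimension, and return `baseConstant + factorConstant * cos(X, Y)`, falling back to `baseConstant` whenever the document vector cannot be used.

```python
import math


def calculate_score(doc_value, input_vec, version=None, base=1.0, factor=1.0):
    """Score a stored vector string such as "20170328|0.1,0.2" against input_vec."""
    input_norm = math.sqrt(sum(x * x for x in input_vec))
    if input_norm == 0 or doc_value is None:
        return base
    doc_value = doc_value.strip()
    if not doc_value:
        return base

    # With a version given, the stored value must start with "<version>|".
    if version is not None:
        prefix = version + "|"
        if not doc_value.startswith(prefix):
            return base
        doc_value = doc_value[len(prefix):]

    parts = doc_value.split(",")
    if len(parts) != len(input_vec):
        return base

    doc_vec = [float(p) for p in parts]
    sum_of_square = sum(v * v for v in doc_vec)
    if sum_of_square == 0:
        return base

    sum_of_product = sum(d * x for d, x in zip(doc_vec, input_vec))
    cos_score = sum_of_product / (math.sqrt(sum_of_square) * input_norm)
    return base + factor * cos_score


same_direction = calculate_score("20170328|1,0", [1.0, 0.0], version="20170328")
opposite = calculate_score("20170328|-1,0", [1.0, 0.0], version="20170328")
wrong_version = calculate_score("20170101|1,0", [1.0, 0.0], version="20170328")
wrong_dimension = calculate_score("1,0,0", [1.0, 0.0])
```

Because the cosine lies in [-1, 1], the default constants (both 1.0) keep the final score in the range [0, 2], and any document that fails a check gets the neutral value 1.0.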

Summary and future improvements
