基于spark实现TFIDF

2018-03-21  本文已影响0人  yxwithu

上一段实习的时候用spark手写了一个tfidf,下面贴上代码并和spark中的源码进行比较。
输入文本(demo):

文档1:a b c d e f g
文档2:a b c d e f 
文档3:a b c d e
文档4:a b c d
文档5:a b c 
文档6:a b 
文档7:a

输出结果:

代码分析
主要有以下几个步骤:

  1. 读取文件到JavaRDD<String>中
  2. mapToPair将每行文本映射为doc <标题 : 单词[]>中,后者为分词后的单词数组
  3. mapValues获取每个文档的词频
  4. 将文档数进行广播,用于计算idf
  5. 类似于wordCount, 先将doc中的每个文本对应的去重单词出现次数置为1,然后aggregateByKey统计每个单词出现的文档数,用对应的求idf的公式,就可以求出idf了
  6. 将表示每个词idf的RDD<map> collect到driver,再进行广播,进行每个文档的tfIdf计算
  7. 最后写入输出文件

和spark Mllib中tf-idf实现方法的对比
源码中也是将tf计算和idf计算分隔开的,tf计算时也是用了HashMap但是使用了hash函数(hashcode取余numfeatures)将词映射到了一个int作为Key.在计算idf时每个文档使用了一个词语大小的向量来保存每个词是否出现过,累加这些向量就得到了整个数据集中每个词语出现的文档数,即IDF,再利用公式计算,不过源码中使用的是log即以e为底而不是以10为底

源码中也是用广播的形式将TF和IDF联系起来

public class GenerateTags {

    public static void main(String[] args) throws IOException{
        SparkConf conf = new SparkConf().setMaster("local").setAppName("test");
//        SparkConf conf = new SparkConf().setAppName("video-tags");
        JavaSparkContext sc = new JavaSparkContext(conf);
        System.setProperty("hadoop.home.dir", "D:\\winutils");
        JavaRDD<String> lines = sc.textFile("C:\\Users\\YANGXIN\\Desktop\\test.txt");

        //得到每个文档标题和对应的词串
        JavaPairRDD<String, String[]> docs = lines.mapToPair(new PairFunction<String, String, String[]>() {
            @Override
            public Tuple2<String, String[]> call(String s) throws Exception {
                String[] doc = s.split(":");
                String title = doc[0];
                String[] words = doc[1].split(" ");
                return new Tuple2<String, String[]>(title, words);
            }
        });

        //得到每个文档的词频
        JavaPairRDD<String, Map<String, Double>> docTF = docs.mapValues(new Function<String[], Map<String, Double>>() {
            @Override
            public Map<String, Double> call(String[] strings) throws Exception {
                Map<String, Double> map = new HashMap<String, Double>();
                int sum = strings.length;
                for(String str : strings){
                    double cnt = map.containsKey(str) ? map.get(str) : 1;
                    map.put(str, cnt);
                }
                for(String str : map.keySet()){
                    map.replace(str, map.get(str) / sum);
                }
                return map;
            }
        });

        //文档数
        final Broadcast<Long> docCnt = sc.broadcast(docs.count());

        //得到每个词的idf值
        JavaPairRDD<String, Integer> ones = docs.flatMapToPair(new PairFlatMapFunction<Tuple2<String, String[]>, String, Integer>() {
            @Override
            public Iterable<Tuple2<String, Integer>> call(Tuple2<String, String[]> stringTuple2) throws Exception {
                List<Tuple2<String, Integer>> list = new ArrayList<Tuple2<String, Integer>>();
                Set<String> set = new HashSet<String>();
                for(String str : stringTuple2._2()){
                    set.add(str);
                }
                for(String str : set){
                    list.add(new Tuple2<>(str, 1));
                }
                return list;
            }
        });

        //每个单词在多少个文档中出现了
        JavaPairRDD<String, Integer> wordDocCnt= ones.aggregateByKey(0, new Function2<Integer, Integer, Integer>() {
            @Override
            public Integer call(Integer integer, Integer integer2) throws Exception { //同partition下的处理
                return integer + integer2;
            }
        }, new Function2<Integer, Integer, Integer>() {
            @Override
            public Integer call(Integer integer, Integer integer2) throws Exception { //不同partition下的处理
                return integer + integer2;
            }
        });

        JavaPairRDD<String, Double> wordIdf = wordDocCnt.mapValues(new Function<Integer, Double>() {
            @Override
            public Double call(Integer integer) throws Exception {
                return Math.log10((docCnt.getValue() + 1) * 1.0 / (integer + 1));  //计算逆文档频率
            }
        });

        //广播idf值,进行tf-idf计算
        Map<String, Double> idfs = wordIdf.collectAsMap();
        final Broadcast<Map<String, Double>> idfMap = sc.broadcast(idfs);

        //计算每个文档的tf-idf向量
        JavaPairRDD<String, TreeMap<Double, String>> TfIdf = docTF.mapValues(new Function<Map<String, Double>, TreeMap<Double, String>>() {
            @Override
            public TreeMap<Double, String> call(Map<String, Double> stringDoubleMap) throws Exception {
                TreeMap<Double, String> map = new TreeMap<Double, String>();
                for(Map.Entry<String, Double> entry : stringDoubleMap.entrySet()){
                    String word = entry.getKey();
                    Double tf = entry.getValue();
                    Double idf = idfMap.getValue().get(word);
                    map.put(tf * idf, word);
                }
                return map;
            }
        });

        TfIdf.saveAsTextFile("C:\\Users\\YANGXIN\\Desktop\\result");
    }

参考文献:
https://github.com/endymecy/spark-ml-source-analysis/blob/master/%E7%89%B9%E5%BE%81%E6%8A%BD%E5%8F%96%E5%92%8C%E8%BD%AC%E6%8D%A2/TF-IDF.md

上一篇下一篇

猜你喜欢

热点阅读