深度学习实战演练

第六章(1.6)机器学习实战——打造属于自己的贝叶斯分类器

2019-01-13  本文已影响3人  _两只橙_

github项目地址:https://github.com/liangzhicheng120/bayes

一、简介

[图片上传失败...(image-3fbd35-1547375244427)]

二、例子

三、实战代码

火影忍者 火影
火影忍者 秘传
火影忍者 大蛇丸
火影忍者 剧场版
火影忍者 动作
火影忍者 激斗
火影忍者 战斗
火影忍者 转生
火影忍者 佐助
火影忍者 村子
火影忍者 第六代火影
火影忍者 克拉
火影忍者 卡卡
火影忍者 带土
火影忍者 疾风
火影忍者 自来
火影忍者 火影忍者
火影忍者 仙人
火影忍者 六道
火影忍者 大战
火影忍者 九尾
火影忍者 忍者
火影忍者 究极
火影忍者 纲手
火影忍者 鸣人
火影忍者 木叶
火影忍者 忍术
火影忍者 秽土
火影忍者 宇智波
火影忍者 九尾妖狐
火影忍者 阿飞
海贼王 正文
海贼王 尾田
海贼王 海贼王
海贼王 弗兰奇
海贼王 草帽
海贼王 海贼
海贼王 武海
海贼王 事件
海贼王 悬赏
海贼王 第话
海贼王 梦想
海贼王 血型
海贼王 王下
海贼王 航路
海贼王 历史
海贼王 德雷斯
海贼王 船长
海贼王 恶魔
海贼王 路飞
海贼王 漫画
海贼王 超新星
海贼王 罗萨篇
海贼王 世界
海贼王 果实
海贼王 冥王
海贼王 荣一郎
海贼王 海贼团
海贼王 司法
海贼王 超人
海贼王 成为
海贼王 寻找
海贼王 传说
海贼王 海贼王
海贼王 中海
海贼王 罗杰
海贼王 秘宝
海贼王 留下
海贼王 伙伴
海贼王 ONE
海贼王 PIECE
海贼王 海贼
海贼王 志同道合
海贼王 扬起
海贼王 实现
龙珠 复活
龙珠 仙人
龙珠 武道
龙珠 得到
龙珠 军团
龙珠 找寻
龙珠 魔王
龙珠 饺子
龙珠 特典
龙珠 打败
龙珠 花梨
龙珠 缎带
龙珠 发售日期
龙珠 龙珠
龙珠 天津
龙珠 七龙珠
龙珠 比克
龙珠 天神
龙珠 修练
龙珠 悟空
龙珠 封入
龙珠 次郎
龙珠 拉夫
龙珠 封印
龙珠 许愿
龙珠 兵卫
龙珠 一武道
龙珠 动画
package com.xinrui.util;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.io.Charsets;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.log4j.Logger;

import com.hankcs.hanlp.HanLP;

/**
 * 贝叶斯计算器主体类
 */
public class Bayes {

    private static Logger logger = Logger.getLogger(Bayes.class);

    /**
     * 将原训练元组按类别划分
     * 
     * @param datas
     *            训练元组
     * @return Map<类别,属于该类别的训练元组>
     */
    public static Map<String, ArrayList<ArrayList<String>>> classifyByCategory(ArrayList<ArrayList<String>> datas) {
        if (datas == null) {
            return null;
        }

        Map<String, ArrayList<ArrayList<String>>> map = new HashMap<String, ArrayList<ArrayList<String>>>();
        ArrayList<String> singleTrainning = null;
        String classificaion = "";
        for (int i = 0; i < datas.size(); i++) {
            singleTrainning = datas.get(i);
            classificaion = singleTrainning.get(0);
            singleTrainning.remove(0);
            if (map.containsKey(classificaion)) {
                map.get(classificaion).add(singleTrainning);
            } else {
                ArrayList<ArrayList<String>> list = new ArrayList<ArrayList<String>>();
                list.add(singleTrainning);
                map.put(classificaion, list);
            }
        }

        return map;
    }

    /**
     * 在训练数据的基础上预测测试元组的类别
     * 
     * @param datas
     *            训练元组
     * @param testData
     *            测试元组
     * @return 测试元组的类别
     */
    public static String predictClassify(ArrayList<ArrayList<String>> datas, ArrayList<String> testData) {

        if (datas == null || testData == null) {
            return null;
        }

        int maxPIndex = -1;
        Map<String, ArrayList<ArrayList<String>>> map = classifyByCategory(datas);
        Object[] classes = map.keySet().toArray();
        double maxProbability = 0.0;
        for (int i = 0; i < map.size(); i++) {
            double p = 0.0;
            for (int j = 0; j < testData.size(); j++) {
                p += calProbabilityClassificationInKey(map, classes[i].toString(), testData.get(j));
            }
            if (p > maxProbability) {
                maxProbability = p;
                maxPIndex = i;
            }
        }

        return maxPIndex == -1 ? "其他" : classes[maxPIndex].toString();
    }

    /**
     * 在训练数据的基础上预测测试元组的类别
     * 
     * @param testData
     *            测试元组
     * @return 测试元组的类别
     * @throws Exception
     */
    public String predictClassify(ArrayList<String> testData, String mId) throws Exception {
        return predictClassify(read(mId), testData);
    }

    /**
     * 某一特征值在某一分类上的概率分布[ P(key|Classify) ]
     * 
     * @param classify
     *            某一分类特征向量集
     * @param value
     *            某一特征值
     * @return 概率分布
     */
    private static double calProbabilityKeyInClassification(ArrayList<ArrayList<String>> classify, String value) {
        if (classify == null || StringUtils.isEmpty(value)) {
            return 0.0;
        }
        int totleKeyCount = 0;
        int foundKeyCount = 0;
        ArrayList<String> featureVector = null; // 分类中的某一特征向量
        for (int i = 0; i < classify.size(); i++) {
            featureVector = classify.get(i);
            for (int j = 0; j < featureVector.size(); j++) {
                totleKeyCount++;
                if (featureVector.get(j).equalsIgnoreCase(value)) {
                    foundKeyCount++;
                }
            }
        }
        return totleKeyCount == 0 ? 0.0 : 1.0 * foundKeyCount / totleKeyCount;
    }

    /**
     * 获得某一分类的概率 [ P(Classify) ]
     * 
     * @param classes
     *            分类集合
     * @param classify
     *            某一特定分类
     * @return 某一分类的概率
     */
    private static double calProbabilityClassification(Map<String, ArrayList<ArrayList<String>>> map, String classify) {
        if (map == null | StringUtils.isEmpty(classify)) {
            return 0;
        }
        Object[] classes = map.keySet().toArray();
        int totleClassifyCount = 0;
        for (int i = 0; i < classes.length; i++) {
            totleClassifyCount += map.get(classes[i].toString()).size();
        }
        return 1.0 * map.get(classify).size() / totleClassifyCount;
    }

    /**
     * 获得关键词的总概率
     * 
     * @param map
     *            所有分类的数据集
     * @param key
     *            某一特征值
     * @return 某一特征值在所有分类数据集中的比率
     */
    private static double calProbabilityKey(Map<String, ArrayList<ArrayList<String>>> map, String key) {
        if (map == null || StringUtils.isEmpty(key)) {
            return 0;
        }
        int foundKeyCount = 0;
        int totleKeyCount = 0;
        Object[] classes = map.keySet().toArray();
        for (int i = 0; i < map.size(); i++) {
            ArrayList<ArrayList<String>> classify = map.get(classes[i]);
            ArrayList<String> featureVector = null; // 分类中的某一特征向量
            for (int j = 0; j < classify.size(); j++) {
                featureVector = classify.get(j);
                for (int k = 0; k < featureVector.size(); k++) {
                    totleKeyCount++;
                    if (featureVector.get(k).equalsIgnoreCase(key)) {
                        foundKeyCount++;
                    }
                }
            }
        }
        return totleKeyCount == 0 ? 0.0 : 1.0 * foundKeyCount / totleKeyCount;
    }

    /**
     * 计算在出现key的情况下,是分类classify的概率 [ P(Classify | key) ]
     * 
     * @param map
     *            所有分类的数据集
     * @param classify
     *            某一特定分类
     * @param key
     *            某一特定特征
     * @return P(Classify | key)
     */
    private static double calProbabilityClassificationInKey(Map<String, ArrayList<ArrayList<String>>> map, String classify, String key) {
        ArrayList<ArrayList<String>> classifyList = map.get(classify);
        double pkc = calProbabilityKeyInClassification(classifyList, key); // p(key|classify)
        double pc = calProbabilityClassification(map, classify); // p(classify)
        double pk = calProbabilityKey(map, key); // p(key)
        return pk == 0 ? 0 : pkc * pc / pk; // p(classify | key)
    }

    /**
     * 读取训练文档中的训练数据 并进行封装
     * 
     * @param filePath
     *            训练文档的路径
     * @return 训练数据集
     * @throws Exception
     */
    public static ArrayList<ArrayList<String>> read(String clzss) throws Exception {
        ArrayList<String> singleTrainning = null;
        ArrayList<ArrayList<String>> trainningSet = new ArrayList<ArrayList<String>>();
        List<String> datas = new ArrayList<String>(FileUtils.readLines(new File(clzss), Charsets.UTF_8));
        if (datas.size() == 0) {
            logger.error("[" + "模型文件加载错误" + "]" + clzss);
            throw new Exception("模型文件加载错误!");
        }
        for (int i = 0; i < datas.size(); i++) {
            String[] characteristicValues = datas.get(i).split(" ");
            singleTrainning = new ArrayList<String>();
            for (int j = 0; j < characteristicValues.length; j++) {
                if (StringUtils.isNotEmpty(characteristicValues[j])) {
                    singleTrainning.add(characteristicValues[j]);
                }
            }
            trainningSet.add(singleTrainning);
        }
        return trainningSet;
    }

    /**
     * 
     * @param fileName
     *            训练文件
     * @param size
     *            关键词个数
     */
    public static void trainBayes(String fileName, String mId, int size) {
        try {
            Bayes bayes = new Bayes();
            BufferedReader reader = new BufferedReader(new FileReader(fileName));
            String line = null;
            int total = 0;
            int right = 0;
            long start = System.currentTimeMillis();
            while ((line = reader.readLine()) != null) {
                ArrayList<String> testData = (ArrayList<String>) HanLP.extractKeyword(line, size);
                String classification = bayes.predictClassify(testData, mId);
                if (classification.equals(fileName.split("\\.")[0])) {
                    right += 1;
                }
                System.out.print("\n分类:" + classification);
                total++;
            }
            reader.close();
            long end = System.currentTimeMillis();
            System.out.println("正确分类:" + right);
            System.out.println("总行数:" + total);
            System.out.println("正确率:" + MathUtil.div(right, total, 4) * 100 + "%");
            System.out.println("程序运行时间: " + (end - start) / 1000 + "s");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

}

package com.xinrui.test;

import java.util.ArrayList;

import com.hankcs.hanlp.HanLP;
import com.xinrui.util.Bayes;

public class TestBayes {
    public static void main(String[] args) throws Exception {
        // 获取当前工程存放位置
        String path = TestBayes.class.getResource("").getPath();
        String classPath = path.substring(0, path.indexOf("/com/xinrui"));
        // 模型文件存放位置
        String modelName = classPath + "/model/classify_model.txt";
        ArrayList<ArrayList<String>> model = Bayes.read(modelName);
        // 抽取10个关键词组成一个元祖
        ArrayList<String> testData = (ArrayList<String>) HanLP
                .extractKeyword(
                        "时值“大海贼时代”,为了寻找传说中海贼王罗杰所留下的大秘宝“ONE PIECE”,无数海贼扬起旗帜,互相争斗。有一个梦想成为海盗的少年叫路飞,他因误食“恶魔果实”而成为了橡皮人,在获得超人能力的同时付出了一辈子无法游泳的代价。十年后,路飞为实现与因救他而断臂的香克斯的约定而出海,他在旅途中不断寻找志同道合的伙伴,开始了以成为海贼王为目标的伟大的冒险旅程[9]  ",
                        15);
        // 输出预测结果
        System.out.println(Bayes.predictClassify(model, testData));
    }
}

关注我的技术公众号,每天推送优质文章
关注我的音乐公众号,工作之余放松自己
微信扫一扫下方二维码即可关注:


音乐公众号
技术公众号
上一篇 下一篇

猜你喜欢

热点阅读