朴素贝叶斯分类器(Python实现+详细源码原理)

2021-04-09  本文已影响0人  Anthons

一、贝叶斯公式

1、贝叶斯公式的本质:<u>由因到果,由果推因</u>

2、贝叶斯公式:

[图片上传中...(wps6.png-5fd624-1618488341725-0)]


wps3.png

二、朴素贝叶斯

1、朴素贝叶斯公式

x1,x2,...xn为特征集合,y为分类结果

![ ](https://img.haomeiwen.com/i24930360/0005ce89d26c5c46.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240)

朴素贝叶斯假设各个特征之间相互独立


wps5.png

分母相同情况下,我们只要保证分子最大


wps6.png

三、朴素贝叶斯分类器 Python实现

训练数据集
long,not_long,sweet,not_sweet,yellow,not_yellow,species
400,100,350,150,450,50,banana
0,300,150,150,300,0,orange
100,100,150,50,50,150,other_fruit
测试数据集
long,sweet,yellow
not_long,not_sweet,not_yellow
not_long,sweet,not_yellow
not_long,sweet,yellow
not_long,sweet,yellow
not_long,not_sweet,not_yellow
long,not_sweet,not_yellow
long,not_sweet,not_yellow
long,not_sweet,not_yellow
long,not_sweet,not_yellow
long,not_sweet,yellow
not_long,not_sweet,yellow
not_long,not_sweet,yellow
long,not_sweet,not_yellow
not_long,not_sweet,yellow

"""
实现朴素贝叶斯分类器
"""
import pandas as pd


def count_total(data):
    count = {}
    total = 0

    for index in data.index:
        specie = data.loc[index, 'species']
        count[specie] = data.loc[index, 'sweet'] + data.loc[index, 'not_sweet']
        total += count[specie]
    return count, total


def cal_base_rates(categories, total):
    cal_base_rates = {}
    for label in categories:
        priori_prob = categories[label]/total
        cal_base_rates[label] = priori_prob
    return cal_base_rates


def likelihold_prob(data, count):
    likelihold = {}
    for index in data.index:
        attr_prob = {}
        specie = data.loc[index, 'species']
        attr_prob['long'] = data.loc[index, 'long']/count[specie]
        attr_prob['not_long'] = data.loc[index, 'not_long']/count[specie]
        attr_prob['sweet'] = data.loc[index, 'sweet']/count[specie]
        attr_prob['not_sweet'] = data.loc[index, 'not_sweet']/count[specie]
        attr_prob['yellow'] = data.loc[index, 'yellow']/count[specie]
        attr_prob['not_yellow'] = data.loc[index, 'not_yellow']/count[specie]
        likelihold[specie] = attr_prob

    return likelihold


def navie_bayes_classifier(data, length=None, sweetness=None, color=None):
    count, total = count_total(data)
    # print("各个水果的总数:" + str(count))
    priori_prob = cal_base_rates(count, total)
    # print("各种水果的先验概率:" + str(priori_prob))
    likelihold = likelihold_prob(data, count)
    # print("各个特征在各种水果中的概率:" + str(likelihold))
    # ep = evidence_prob(data)
    # print("各个特征的先验概率:" + str(ep))
    res = {}
    for lable in data['species']:
        prob = priori_prob[lable]
        prob *= likelihold[lable][length] * likelihold[lable][sweetness] * likelihold[lable][color]
        res[lable] = prob
    print("预测结果:" + str(res))
    res = sorted(res.items(), key=lambda kv: kv[1], reverse=True)
    return res[0][0]


def main():
    # 定义数据集
    datasets_train = pd.read_csv('fruitclass_train.csv')
    datasets_test = pd.read_csv('fruitclass_test.csv')

    for index in datasets_test.index:
        long = datasets_test.loc[index, 'long']
        sweet = datasets_test.loc[index, 'sweet']
        color = datasets_test.loc[index, 'yellow']
        print("特征值:[{0}, {1}, {2}]".format(long, sweet, color))
        res = navie_bayes_classifier(datasets_train, long, sweet, color)
        print("水果类别:" + res)


if __name__ == '__main__':
    main()

结果
特征值:[not_long, not_sweet, not_yellow]
预测结果:{'banana': 0.003, 'orange': 0.0, 'other_fruit': 0.018750000000000003}
水果类别:other_fruit
特征值:[not_long, sweet, not_yellow]
预测结果:{'banana': 0.006999999999999999, 'orange': 0.0, 'other_fruit': 0.05625000000000001}
水果类别:other_fruit
特征值:[not_long, sweet, yellow]
预测结果:{'banana': 0.063, 'orange': 0.15, 'other_fruit': 0.018750000000000003}
水果类别:orange
特征值:[not_long, sweet, yellow]
预测结果:{'banana': 0.063, 'orange': 0.15, 'other_fruit': 0.018750000000000003}
水果类别:orange
特征值:[not_long, not_sweet, not_yellow]
预测结果:{'banana': 0.003, 'orange': 0.0, 'other_fruit': 0.018750000000000003}
水果类别:other_fruit
特征值:[long, not_sweet, not_yellow]
预测结果:{'banana': 0.012, 'orange': 0.0, 'other_fruit': 0.018750000000000003}
水果类别:other_fruit
特征值:[long, not_sweet, not_yellow]
预测结果:{'banana': 0.012, 'orange': 0.0, 'other_fruit': 0.018750000000000003}
水果类别:other_fruit
特征值:[long, not_sweet, not_yellow]
预测结果:{'banana': 0.012, 'orange': 0.0, 'other_fruit': 0.018750000000000003}
水果类别:other_fruit
特征值:[long, not_sweet, not_yellow]
预测结果:{'banana': 0.012, 'orange': 0.0, 'other_fruit': 0.018750000000000003}
水果类别:other_fruit
特征值:[long, not_sweet, yellow]
预测结果:{'banana': 0.108, 'orange': 0.0, 'other_fruit': 0.00625}
水果类别:banana
特征值:[not_long, not_sweet, yellow]
预测结果:{'banana': 0.027, 'orange': 0.15, 'other_fruit': 0.00625}
水果类别:orange
特征值:[not_long, not_sweet, yellow]
预测结果:{'banana': 0.027, 'orange': 0.15, 'other_fruit': 0.00625}
水果类别:orange
特征值:[long, not_sweet, not_yellow]
预测结果:{'banana': 0.012, 'orange': 0.0, 'other_fruit': 0.018750000000000003}
水果类别:other_fruit
特征值:[not_long, not_sweet, yellow]
预测结果:{'banana': 0.027, 'orange': 0.15, 'other_fruit': 0.00625}
水果类别:orange

上一篇下一篇

猜你喜欢

热点阅读