简单感知机

2018-11-25  本文已影响0人  瞿大官人

简单感知机的实现

from matplotlib import pyplot as plt
from matplotlib import animation
import copy
# 训练集
training_set = [[(3, 3), 1], [(4, 3), 1], [(1, 1), -1],[(1, 2), -1],[(2, 2), -1]]  # 训练数据集
history = []  # 用来记录每次更新过后的w,b
# 获取Y的值
def getY(number):

    if number >=0:
        return 1;
    return -1;
#获取 w*x+b的值
def getFuncationResult(w,x,b):
    return x[0][0] * w[0] + x[0][1] * w[1] + b;
# 计算最终有效的W值
def calculateFunction(w,b):

    while True:
        foreach = 0;
        for first in training_set:

            y = getFuncationResult(w, first, b)
            num = y * first[1]
            if num <= 0:
                w = caculateW(w, first[1], first);
                b = cacualteB(b, first[1])
                history.append([copy.copy(w), b])
                break;
            foreach = foreach + 1;
        if foreach == training_set.__len__():
            break;
    return w;


def caculateW(oldW,oldY,selectPoint):
    list=[]
    list.append(oldW[0] + oldY * selectPoint[0][0])
    list.append(oldW[1] + oldY * selectPoint[0][1])
    return list;
def cacualteB(oldB, oldY):

    return oldB + oldY;

if __name__ == "__main__":
    #for i in range(1000):  # 迭代1000遍
        #if not check(): break  # 如果已正确分类,则结束迭代
    # 以下代码是将迭代过程可视化
    # 首先建立我们想要做成动画的图像figure, 坐标轴axis,和plot element
    calculateFunction([0,0],0)
    fig = plt.figure()
    ax = plt.axes(xlim=(0, 2), ylim=(-2, 2))
    line, = ax.plot([], [], 'g', lw=2)  # 画一条线
    label = ax.text([], [], '')


    def init():
        line.set_data([], [])
        x, y, x_, y_ = [], [], [], []
        for p in training_set:
            if p[1] > 0:
                x.append(p[0][0])  # 存放yi=1的点的x1坐标
                y.append(p[0][1])  # 存放yi=1的点的x2坐标
            else:
                x_.append(p[0][0])  # 存放yi=-1的点的x1坐标
                y_.append(p[0][1])  # 存放yi=-1的点的x2坐标
        plt.plot(x, y, 'bo', x_, y_, 'rx')  # 在图里yi=1的点用点表示,yi=-1的点用叉表示
        plt.axis([-6, 6, -6, 6])  # 横纵坐标上下限
        plt.grid(True)  # 显示网格
        plt.xlabel('x1')  # 这里我修改了原文表示
        plt.ylabel('x2')  # 为了和原理中表达方式一致,横纵坐标应该是x1,x2
        plt.title('Perceptron Algorithm (www.hankcs.com)')  # 给图一个标题:感知机算法
        return line, label

    def animate(i):
        global history, ax, line, label
        w = history[i][0]
        b = history[i][1]
        if w[1] == 0: return line, label
        # 因为图中坐标上下限为-6~6,所以我们在横坐标为-7和7的两个点之间画一条线就够了,这里代码中的xi,yi其实是原理中的x1,x2
        x1 = -7
        y1 = -(b + w[0] * x1) / w[1]
        x2 = 7
        y2 = -(b + w[0] * x2) / w[1]
        line.set_data([x1, x2], [y1, y2])  # 设置线的两个点
        x1 = 0
        y1 = -(b + w[0] * x1) / w[1]
        label.set_text(history[i])
        label.set_position([x1, y1])
        return line, label


    print("参数w,b更新过程:", history)
    anim = animation.FuncAnimation(fig, animate, init_func=init, frames=len(history), interval=500, repeat=False,
                                   blit=True)
    plt.show()

深度截图_选择区域_20181125193218.png
上一篇下一篇

猜你喜欢

热点阅读