简单感知机
2018-11-25 本文已影响0人
瞿大官人
简单感知机的实现
from matplotlib import pyplot as plt
from matplotlib import animation
import copy
# 训练集
training_set = [[(3, 3), 1], [(4, 3), 1], [(1, 1), -1],[(1, 2), -1],[(2, 2), -1]] # 训练数据集
history = [] # 用来记录每次更新过后的w,b
# 获取Y的值
def getY(number):
if number >=0:
return 1;
return -1;
#获取 w*x+b的值
def getFuncationResult(w,x,b):
return x[0][0] * w[0] + x[0][1] * w[1] + b;
# 计算最终有效的W值
def calculateFunction(w,b):
while True:
foreach = 0;
for first in training_set:
y = getFuncationResult(w, first, b)
num = y * first[1]
if num <= 0:
w = caculateW(w, first[1], first);
b = cacualteB(b, first[1])
history.append([copy.copy(w), b])
break;
foreach = foreach + 1;
if foreach == training_set.__len__():
break;
return w;
def caculateW(oldW,oldY,selectPoint):
list=[]
list.append(oldW[0] + oldY * selectPoint[0][0])
list.append(oldW[1] + oldY * selectPoint[0][1])
return list;
def cacualteB(oldB, oldY):
return oldB + oldY;
if __name__ == "__main__":
#for i in range(1000): # 迭代1000遍
#if not check(): break # 如果已正确分类,则结束迭代
# 以下代码是将迭代过程可视化
# 首先建立我们想要做成动画的图像figure, 坐标轴axis,和plot element
calculateFunction([0,0],0)
fig = plt.figure()
ax = plt.axes(xlim=(0, 2), ylim=(-2, 2))
line, = ax.plot([], [], 'g', lw=2) # 画一条线
label = ax.text([], [], '')
def init():
line.set_data([], [])
x, y, x_, y_ = [], [], [], []
for p in training_set:
if p[1] > 0:
x.append(p[0][0]) # 存放yi=1的点的x1坐标
y.append(p[0][1]) # 存放yi=1的点的x2坐标
else:
x_.append(p[0][0]) # 存放yi=-1的点的x1坐标
y_.append(p[0][1]) # 存放yi=-1的点的x2坐标
plt.plot(x, y, 'bo', x_, y_, 'rx') # 在图里yi=1的点用点表示,yi=-1的点用叉表示
plt.axis([-6, 6, -6, 6]) # 横纵坐标上下限
plt.grid(True) # 显示网格
plt.xlabel('x1') # 这里我修改了原文表示
plt.ylabel('x2') # 为了和原理中表达方式一致,横纵坐标应该是x1,x2
plt.title('Perceptron Algorithm (www.hankcs.com)') # 给图一个标题:感知机算法
return line, label
def animate(i):
global history, ax, line, label
w = history[i][0]
b = history[i][1]
if w[1] == 0: return line, label
# 因为图中坐标上下限为-6~6,所以我们在横坐标为-7和7的两个点之间画一条线就够了,这里代码中的xi,yi其实是原理中的x1,x2
x1 = -7
y1 = -(b + w[0] * x1) / w[1]
x2 = 7
y2 = -(b + w[0] * x2) / w[1]
line.set_data([x1, x2], [y1, y2]) # 设置线的两个点
x1 = 0
y1 = -(b + w[0] * x1) / w[1]
label.set_text(history[i])
label.set_position([x1, y1])
return line, label
print("参数w,b更新过程:", history)
anim = animation.FuncAnimation(fig, animate, init_func=init, frames=len(history), interval=500, repeat=False,
blit=True)
plt.show()
深度截图_选择区域_20181125193218.png