程序员

SMO算法实现

2020-07-17  本文已影响0人  lightt

这里根据SMO算法原论文中的伪代码实现了SMO算法。算法和数据已经上传到了git

伪代码

target = desired output vector
point = training point matrix
procedure takeStep(i1,i2)
    if (i1 == i2) return 0
    alph1 = Lagrange multiplier for i1
    y1 = target[i1]
    E1 = SVM output on point[i1] – y1 (check in error cache)
    s = y1*y2
    Compute L, H via equations (13) and (14)
    if (L == H)
        return 0
    k11 = kernel(point[i1],point[i1])
    k12 = kernel(point[i1],point[i2])
    k22 = kernel(point[i2],point[i2])
    eta = k11+k22-2*k12
    if (eta > 0)
    {
        a2 = alph2 + y2*(E1-E2)/eta
        if (a2 < L) a2 = L
        else if (a2 > H) a2 = H
    }
    else
    {
        Lobj = objective function at a2=L
        Hobj = objective function at a2=H
        if (Lobj < Hobj-eps)
            a2 = L
        else if (Lobj > Hobj+eps)
            a2 = H
        else
            a2 = alph2
    }
    if (|a2-alph2| < eps*(a2+alph2+eps))
        return 0
    a1 = alph1+s*(alph2-a2)
    Update threshold to reflect change in Lagrange multipliers
    Update weight vector to reflect change in a1 & a2, if SVM is linear
    Update error cache using new Lagrange multipliers
    Store a1 in the alpha array
    Store a2 in the alpha array
    return 1
endprocedure

procedure examineExample(i2)
    y2 = target[i2]
    alph2 = Lagrange multiplier for i2
    E2 = SVM output on point[i2] – y2 (check in error cache)
    r2 = E2*y2
    if ((r2 < -tol && alph2 < C) || (r2 > tol && alph2 > 0))
    {
        if (number of non-zero & non-C alpha > 1)
        {
            i1 = result of second choice heuristic (section 2.2)
            if takeStep(i1,i2)
                return 1
        }
        loop over all non-zero and non-C alpha, starting at a random point
        {
            i1 = identity of current alpha
            if takeStep(i1,i2)
                return 1
        }
        loop over all possible i1, starting at a random point
        {
            i1 = loop variable
            if takeStep(i1,i2)
                return 1
        }
    }
    return 0
endprocedure

main routine:
    numChanged = 0
    examineAll = 1
    while (numChanged > 0 | examineAll)
    {
        numChanged = 0;
        if (examineAll)
            loop I over all training examples
            numChanged += examineExample(I)
        else
            loop I over examples where alpha is not 0 & not C
            numChanged += examineExample(I)
        if (examineAll == 1)
            examineAll = 0
        else if (numChanged == 0)
            examineAll = 1
    }

python实现

# @Author  : lightXu
# @File    : smo_paper.py

import numpy as np
import matplotlib.pyplot as plt
import random
import copy


class OptStruct:
    """
    Mutable container for the state shared by the SMO routines.

    Parameters:
        data_x - training samples; ``data_x.shape[0]`` is taken as the sample count
        label - target values, stored without copying
        C - upper bound applied to every multiplier
        toler - tolerance used when checking the KKT conditions
    """

    def __init__(self, data_x, label, C, toler):
        sample_count = data_x.shape[0]

        # Hyper-parameters.
        self.C, self.toler = C, toler

        # Training data (references, not copies).
        self.X, self.label = data_x, label
        self.row = sample_count

        # Optimisation state: every multiplier, the threshold and every
        # cached error start at zero.
        self.b = 0
        self.alpha = np.zeros(sample_count)
        self.e_cache = np.zeros(sample_count)


def cal_Ek(ost, k):
    """
    Evaluate the current linear decision function on sample ``k``.

    Parameters:
        ost - state object providing ``X``, ``label``, ``alpha`` and ``b``
        k - index of the sample to evaluate
    Returns:
        (Ek, fxk) - the error ``fxk - label[k]`` and the output ``fxk``,
        each passed through ``round_float``
    """
    signed_alpha = ost.alpha * ost.label
    # Inner product of every training sample with sample k (linear kernel).
    kernel_column = np.dot(ost.X, ost.X[k, :])
    fxk = np.dot(signed_alpha.T, kernel_column) + ost.b
    Ek = fxk - ost.label[k]
    return round_float(Ek), round_float(fxk)


def round_float(value):
    """Return ``value`` rounded to 8 decimal places."""
    decimals = 8
    return round(value, decimals)


def load_data(file_name):
    """
    Read a tab-separated numeric data set.

    Every non-blank line holds the feature values followed by the label in
    the last column.

    Parameters:
        file_name - path of the text file to read
    Returns:
        (data_x, label) - float arrays holding the features and the labels
    """
    data_x = []
    data_y = []

    with open(file_name, 'r') as f:
        for line in f:
            fields = line.strip().split('\t')
            if fields == ['']:
                # Blank (e.g. trailing) line: nothing to parse.
                continue
            data_x.append(fields[:-1])
            data_y.append(fields[-1])

    # The builtin ``float`` replaces the ``np.float`` alias, which was removed
    # in NumPy 1.24; both mean float64.
    data_x = np.array(data_x, dtype=float)
    label = np.array(data_y, dtype=float)

    return data_x, label


def select_j_random(i, m):
    """
    Draw a random index from ``[0, m)`` that differs from ``i``.

    Parameters:
        i - index that must not be returned
        m - exclusive upper bound of the index range
    Returns:
        j - a random index with ``j != i``
    Raises:
        ValueError - if the range contains no index other than ``i``
    """
    # Without this guard the rejection loop below would never terminate
    # (e.g. m == 1 and i == 0), or would hand back an index outside an
    # empty range.
    if m <= 0 or (m <= 1 and i == 0):
        raise ValueError("no index other than %r exists in range(%r)" % (i, m))

    j = i
    while j == i:
        j = int(random.uniform(0, m))

    return j


def select_j(eligible_list, i, ost, Ei):
    """
    Second-choice heuristic: pick the candidate whose error is farthest from ``Ei``.

    Parameters:
        eligible_list - iterable of candidate sample indices
        i - index of the sample already chosen; never returned
        ost - state object passed through to ``cal_Ek``
        Ei - error of sample ``i``
    Returns:
        (max_k, Ej) - the chosen index and its error
    Raises:
        ValueError - if no candidate other than ``i`` is available
    """
    candidates = [k for k in eligible_list if k != i]
    if not candidates:
        raise ValueError("select_j needs at least one candidate other than i")

    # Evaluate every candidate's error exactly once.
    errors = [cal_Ek(ost, k)[0] for k in candidates]
    positions = range(len(candidates))

    if Ei < 0:
        # Negative Ei: the largest error maximises |Ei - Ej|.
        best = max(positions, key=errors.__getitem__)
    elif Ei > 0:
        # Positive Ei: the smallest error maximises |Ei - Ej|.
        best = min(positions, key=errors.__getitem__)
    else:
        # Ei == 0: pick the largest magnitude. Selecting by position avoids
        # looking up an absolute value in the list of signed errors, which
        # fails whenever the winning error is negative.
        best = max(positions, key=lambda pos: abs(errors[pos]))

    return candidates[best], errors[best]


def updateEk(ost, k):
    """
    Recompute the error of sample ``k`` and write it into ``ost.e_cache``.

    Parameters:
        ost - state object holding the error cache
        k - index of the sample whose cached error is refreshed
    Returns:
        None
    """
    # cal_Ek returns (error, output); only the error is cached.
    ost.e_cache[k] = cal_Ek(ost, k)[0]


def clip_alpha(alpha, L, H):
    """
    Clamp ``alpha`` into ``[L, H]``.

    The upper bound is applied first and the lower bound second, so the
    lower bound wins if the two bounds are inverted.
    """
    capped = H if alpha > H else alpha
    return L if capped < L else capped


def cal_w(data_x, label, alpha):
    """
    Build the linear weight vector from the multipliers.

    Returns ``np.dot((alpha * label).T, data_x)``, i.e. the sum of the
    samples weighted by ``alpha[i] * label[i]``.
    """
    signed_alpha = alpha * label
    return np.dot(signed_alpha.T, data_x)


def objective_func(ost, i1, i2, alpha1, alpha2, L, H):
    """
    Evaluate the two-variable SMO objective at both ends of the feasible segment.

    Terms that do not depend on the two multipliers are dropped, so only the
    difference between the two returned values is meaningful.

    Parameters:
        ost - state object providing ``X``, ``label`` and ``alpha``
        i1, i2 - indices of the two samples being optimised
        alpha1, alpha2 - multipliers of i1 and i2 before the step; expected to
            equal ``ost.alpha[i1]`` and ``ost.alpha[i2]``
        L, H - lower and upper end of the feasible range of the second multiplier
    Returns:
        (obj_L, obj_H) - objective value with the second multiplier at L and at H
    """
    y1 = ost.label[i1]
    y2 = ost.label[i2]
    s = y1 * y2

    x1 = ost.X[i1, :]
    x2 = ost.X[i2, :]
    k11 = np.dot(x1, x1)
    k12 = np.dot(x1, x2)
    k22 = np.dot(x2, x2)

    # Decision values WITHOUT the threshold. The paper writes this term as
    # (E + b) under its f(x) = w.x - b convention; computing it directly
    # avoids any dependence on the sign convention used for b.
    signed_alpha = ost.alpha * ost.label
    u1 = np.dot(signed_alpha, np.dot(ost.X, x1))
    u2 = np.dot(signed_alpha, np.dot(ost.X, x2))

    f1 = y1 * (u1 - y1) - alpha1 * k11 - s * alpha2 * k12
    # The last term pairs alpha2 with k22 (sample i2 against itself).
    f2 = y2 * (u2 - y2) - s * alpha1 * k12 - alpha2 * k22

    # Value of the first multiplier when the second one sits at L / at H.
    L1 = alpha1 + s * (alpha2 - L)
    H1 = alpha1 + s * (alpha2 - H)

    obj_L = L1 * f1 + L * f2 + 0.5 * L1 * L1 * k11 + 0.5 * L * L * k22 + s * L * L1 * k12
    obj_H = H1 * f1 + H * f2 + 0.5 * H1 * H1 * k11 + 0.5 * H * H * k22 + s * H * H1 * k12

    return obj_L, obj_H


def take_step(ost, i1, i2, E2):
    """
    Jointly optimise the multipliers of samples ``i1`` and ``i2``.

    Parameters:
        ost - state object; ``alpha``, ``b`` and ``e_cache`` are updated in
            place when a step is taken
        i1, i2 - indices of the two samples
        E2 - current error of sample ``i2``
    Returns:
        1 if the multipliers were changed, 0 if no progress was possible
    """
    if i1 == i2:
        return 0

    eps = 0.0001  # shared tolerance for objective ties and minimum step size

    alpha1 = ost.alpha[i1].copy()
    alpha2 = ost.alpha[i2].copy()
    y1 = ost.label[i1]
    y2 = ost.label[i2]
    E1, _ = cal_Ek(ost, i1)
    s = y1 * y2

    # Ends of the feasible segment for the second multiplier.
    if y1 != y2:
        L = max(0, alpha2 - alpha1)
        H = min(ost.C, ost.C + alpha2 - alpha1)
    else:
        L = max(0, alpha2 + alpha1 - ost.C)
        H = min(ost.C, alpha2 + alpha1)
    if L == H:
        return 0

    k11 = np.dot(ost.X[i1, :], ost.X[i1, :])
    k12 = np.dot(ost.X[i1, :], ost.X[i2, :])
    k22 = np.dot(ost.X[i2, :], ost.X[i2, :])
    eta = round_float(k11 + k22 - 2 * k12)

    if eta > 0:
        # Unconstrained optimum along the constraint line, clipped to [L, H].
        a2 = alpha2 + y2 * (E1 - E2) / eta
        if a2 < L:
            a2 = L
        if a2 > H:
            a2 = H
    else:
        # Degenerate curvature: compare the objective at the two ends. The
        # pre-step multipliers are passed, because objective_func derives the
        # matching first multiplier from them.
        Lobj, Hobj = objective_func(ost, i1, i2, alpha1, alpha2, L, H)
        if Lobj < Hobj - eps:
            a2 = L
        elif Lobj > Hobj + eps:
            a2 = H
        else:
            a2 = alpha2

    if abs(a2 - alpha2) < eps:
        return 0

    a1 = alpha1 + s * (alpha2 - a2)

    # Round once so the threshold is computed from exactly the values that
    # get stored.
    a1 = round_float(a1)
    a2 = round_float(a2)

    # Threshold candidates use the CHANGE in each multiplier (new - old).
    delta1 = y1 * (a1 - alpha1)
    delta2 = y2 * (a2 - alpha2)
    b1 = ost.b - E1 - delta1 * k11 - delta2 * k12
    b2 = ost.b - E2 - delta1 * k12 - delta2 * k22

    # Prefer the candidate belonging to a NEW multiplier that is non-bound.
    if 0 < a1 < ost.C:
        ost.b = b1
    elif 0 < a2 < ost.C:
        ost.b = b2
    else:
        ost.b = (b1 + b2) / 2.0

    # Store the multipliers before refreshing the error cache, so the cached
    # errors reflect the updated model.
    ost.alpha[i1] = a1
    ost.alpha[i2] = a2

    updateEk(ost, i1)
    updateEk(ost, i2)

    return 1


def violate_kkt(ost, alpha2, E2, fx2, y2):
    """
    Report whether a sample violates the KKT conditions by more than ``ost.toler``.

    Parameters:
        ost - state object providing ``toler`` and ``C``
        alpha2 - multiplier of the sample
        E2 - error of the sample; not used by the check, kept for interface
            compatibility
        fx2 - decision value of the sample
        y2 - label of the sample
    Returns:
        truthy if any of the three toleranced conditions is violated
    """
    # For labels of +/-1, y2 * fx2 - 1 == y2 * (fx2 - y2) == y2 * E2.
    margin = y2 * fx2 - 1
    tol = ost.toler

    # alpha == 0 requires margin >= 0 (checked to within the tolerance).
    at_zero = (not margin >= -tol) and alpha2 == 0
    # 0 < alpha < C requires margin == 0.
    non_bound = (not abs(margin) <= tol) and 0 < alpha2 < ost.C
    # alpha == C requires margin <= 0.
    at_c = (not margin <= tol) and alpha2 == ost.C

    return at_zero or non_bound or at_c


def examine_example(ost, i2):
    """
    Try to make progress on sample ``i2`` if it violates the KKT conditions.

    Partners are tried in three stages: the ``select_j`` heuristic, then the
    multipliers that are neither 0 nor C in random order, then every sample
    in random order.

    Parameters:
        ost - state object shared by the SMO routines
        i2 - index of the sample to examine
    Returns:
        1 as soon as one ``take_step`` call succeeds, otherwise 0
    """
    label2 = ost.label[i2]
    multiplier2 = ost.alpha[i2]
    E2, fx2 = cal_Ek(ost, i2)

    # Nothing to do for a sample that already satisfies the KKT conditions.
    if not violate_kkt(ost, multiplier2, E2, fx2, label2):
        return 0

    non_bound = np.where((ost.alpha != 0) & (ost.alpha != ost.C))[0]

    # Stage 1: heuristic partner, only when more than one such multiplier exists.
    if len(non_bound) > 1:
        i1, _ = select_j(non_bound, i2, ost, E2)
        if take_step(ost, i1, i2, E2):
            return 1

    # Stage 2 then stage 3: draw partners at random without replacement,
    # first from the non-bound indices, then from all samples.
    for pool in (non_bound.copy().tolist(), list(range(0, ost.row))):
        while pool:
            i1 = random.choice(pool)
            if take_step(ost, i1, i2, E2):
                return 1
            pool.remove(i1)

    return 0


def main(dataMatIn, classLabels, C, toler, maxIter):
    """
    Outer loop of SMO: alternate full passes and non-bound passes.

    Parameters:
        dataMatIn - training samples
        classLabels - target values
        C - upper bound for every multiplier
        toler - tolerance for the KKT check
        maxIter - maximum number of passes over the data
    Returns:
        (b, alpha) - the threshold and the multiplier array after training
    """
    ost = OptStruct(dataMatIn, classLabels, C, toler)
    iter_num = 0
    num_changed = 0
    examine_all = 1

    # The parentheses matter: ``and`` binds tighter than ``or``, so without
    # them a pending full pass would run even after maxIter is reached.
    while iter_num < maxIter and (num_changed > 0 or examine_all):
        num_changed = 0
        if examine_all:
            # Full pass: every training example is checked against the KKT
            # conditions.
            for i in range(ost.row):
                num_changed = num_changed + examine_example(ost, i)
                print("全样本遍历:第%d次迭代 样本:%d, alpha优化次数:%d" % (iter_num, i, num_changed))
        else:
            # Non-bound pass: only examples whose multiplier lies strictly
            # between 0 and C are checked.
            non_bound_index = np.where((0 < ost.alpha) & (ost.alpha < C))[0]
            for i in non_bound_index:
                num_changed = num_changed + examine_example(ost, i)
                print("非边界:第%d次迭代 样本:%d, alpha优化次数:%d" % (iter_num, i, num_changed))
        iter_num += 1

        # After a full pass switch to non-bound passes; return to a full pass
        # once a non-bound pass changes nothing.
        if examine_all:
            examine_all = 0
        elif num_changed == 0:
            examine_all = 1

    return ost.b, ost.alpha


def show_classifier(data_x, label, w, b, alpha, seed):
    """
    Plot the samples, the separating line and the support vectors.

    Parameters:
        data_x - samples; columns 0 and 1 are plotted
        label - labels; rows equal to 1 are drawn green, rows equal to -1 pink
        w - two weights of the separating line
        b - threshold of the separating line
        alpha - multipliers; every sample with a non-zero value is circled
            and its index printed
        seed - not used by the plotting code
    """
    # Positive samples first, then negative ones.
    for class_value, colour in ((1, 'green'), (-1, 'pink')):
        members = data_x[np.where(label == class_value)[0]]
        plt.scatter(members[:, 0], members[:, 1],
                    s=30, alpha=0.7, c=colour)

    # Separating line w0 * x + w1 * y + b = 0 across the range of column 0.
    right = np.max(data_x, axis=0)[0]
    left = np.min(data_x, axis=0)[0]
    w0, w1 = w
    b = float(b)
    line_y_right = (-b - w0 * right) / w1
    line_y_left = (-b - w0 * left) / w1
    plt.plot([right, left], [line_y_right, line_y_left])

    # Circle the support vectors.
    for index, multiplier in enumerate(alpha):
        if abs(multiplier) > 0:
            print(index)
            px, py = data_x[index]
            plt.scatter([px], [py], s=150, c='none', alpha=0.7, linewidth=1.5, edgecolor='red')

    plt.show()
    plt.close()


def show_classifier1(data_x, label):
    """
    Scatter-plot the samples by class.

    Rows whose label equals 1 are drawn green and rows whose label equals -1
    pink, using columns 0 and 1 of ``data_x``.
    """
    # Positive samples first, then negative ones.
    for class_value, colour in ((1, 'green'), (-1, 'pink')):
        members = data_x[np.where(label == class_value)[0]]
        plt.scatter(members[:, 0], members[:, 1],
                    s=30, alpha=0.7, c=colour)

    plt.show()


if __name__ == '__main__':
    # Seed the stdlib RNG so the random partner choices made during training
    # are repeatable.
    seed = 10
    random.seed(seed)
    # Tab-separated file; the last column of every row is the label.
    dataMat, labelMat = load_data('testSet.txt')
    # Positional arguments: C=0.6, toler=0.0001, maxIter=100.
    b, alphas = main(dataMat, labelMat, 0.6, 0.0001, 100)
    # Weight vector of the linear classifier built from the multipliers.
    w = cal_w(dataMat, labelMat, alphas)
    print(w, b)
    show_classifier(dataMat, labelMat, w, b, alphas, seed)

分类结果如下:


smo.png

补充

选择第一个变量时,需要判断该样本是否违反KKT条件,这部分的原理可以参考博客。论文作者在判断时引入了容差eps(对应代码中的toler),只要求KKT条件在该容差范围内成立,从而加快训练;推导过程大家直接看代码(violate_kkt函数)就行了。

上一篇下一篇

猜你喜欢

热点阅读