机器学习与数据挖掘

在Weka中实现流形学习Isomap中的距离计算

2017-03-19  本文已影响161人  Daisy丶

最近因为项目需求,需要在weka上实现流形距离计算。因为weka没有提供流形学习的包,而smile提供了,于是我参考smile的等距映射(Isomap)实现,重写了一个可在weka上使用的流形距离计算类。

欧氏距离是最常用的距离度量,但是在数据集不具有全局线性结构时,欧氏距离就不是一种合理的数据距离度量,一般使用拓扑流形结构来度量高维度的非线性数据。这种方法通常用来对数据进行降维,也被称为流形学习。

定义1:
流形两点间x1, x2的线段长度定义为 L(x1, x2) = exp(d(x1, x2) / σ) -1
定义2:
将数据点看作无向有权图G=(V, E),其中V是顶点集合,E是边的集合;Pij表示图上连接数据点Xi, Xj的所有路径的集合,则Xi, Xj的流形距离为 MD(xi, xj) = min_{p∈Pij} ∑_{k=1}^{|p|-1} L(pk, pk+1),其中pk是路径p上的第k个顶点,|p|是路径p的顶点数

算法流程:

for i = 1,2,3...m do
    确定xi的k个最近邻
    将xi与k个最近邻的距离设为定义的距离公式,与自己的距离设为0,与其他点距离设为-1
    将这些数值添加进入邻接矩阵
end

根据邻接矩阵构建一个有权无向图的对象
使用dijkstra最短距离求出图上任意两点的最短距离

ManifoldDistance.java

import weka.core.EuclideanDistance;
import weka.core.Instances;

import java.util.*;

/**
 * Created by Administrator on 2017/3/15.
 */
public class ManifoldDistance {
    private final Instances data;
    private final int k;
    private final double sigma;
    private double[][] matrix;
    private Graph graph = new Graph();

    /**
     * Creates a manifold (geodesic) distance calculator for a data set.
     *
     * @param data  the instances whose pairwise manifold distances are wanted
     * @param k     number of nearest neighbours each instance is linked to
     * @param sigma scaling parameter of the edge length
     *              {@code exp(d / sigma) - 1}
     */
    public ManifoldDistance(Instances data, int k, double sigma) {
        this.data = data;
        this.k = k;
        this.sigma = sigma;
    }

    public Instances getData() {
        return data;
    }

    public int getK() {
        return k;
    }

    public double getSigma() {
        return sigma;
    }

    /**
     * Returns the adjacency matrix produced by {@link #build()}.
     *
     * @return the internal matrix (not a copy), or {@code null} if
     *         {@link #build()} has not been called yet
     */
    public double[][] getMatrix() {
        return matrix;
    }

    /**
     * Builds the symmetric k-nearest-neighbour adjacency matrix of the data.
     *
     * <p>Entry {@code [i][j]} holds the edge length
     * {@code exp(d(i, j) / sigma) - 1} when {@code j} is one of the k nearest
     * neighbours of {@code i} or vice versa, {@code 0.0} on the diagonal and
     * {@code -1.0} for every pair that is not adjacent.
     *
     * @return the adjacency matrix
     */
    private double[][] constructWeightMatrix() {
        int num = this.data.numInstances();
        double[][] weightMatrix = new double[num][num];
        // Start from "no edge anywhere" so that a genuine zero-length edge
        // (duplicate instances) is never mistaken for an unset cell.
        for (int i = 0; i < num; i++) {
            Arrays.fill(weightMatrix[i], -1.0);
            weightMatrix[i][i] = 0.0;
        }

        EuclideanDistance calculateDistance = new EuclideanDistance(this.data);
        for (int i = 0; i < num; i++) {
            // The instance itself is not a candidate; otherwise its zero
            // distance would always take one of the k neighbour slots.
            Map<Integer, Double> candidates = new HashMap<>();
            for (int j = 0; j < num; j++) {
                if (i != j) {
                    double dist = calculateDistance.distance(this.data.instance(i), this.data.instance(j));
                    candidates.put(j, Math.exp(dist / this.sigma) - 1);
                }
            }

            for (int neighbor : nearestNeighbor(candidates)) {
                double weight = candidates.get(neighbor);
                // The graph is undirected: an edge chosen from either end
                // is recorded in both cells.
                weightMatrix[i][neighbor] = weight;
                weightMatrix[neighbor][i] = weight;
            }
        }
        return weightMatrix;
    }

    /**
     * Selects the k nearest neighbours from a set of candidate distances.
     *
     * @param temp edge length from the current instance to every other
     *             instance, keyed by instance index
     * @return indices of the (at most) k candidates with the smallest
     *         distance, nearest first
     */
    private List<Integer> nearestNeighbor(Map<Integer, Double> temp) {
        List<Map.Entry<Integer, Double>> ranked = new ArrayList<>(temp.entrySet());
        // Ascending by distance; ties are broken by index so the result does
        // not depend on hash iteration order.
        ranked.sort((a, b) -> {
            int byDistance = Double.compare(a.getValue(), b.getValue());
            return byDistance != 0 ? byDistance : Integer.compare(a.getKey(), b.getKey());
        });

        int limit = Math.max(0, Math.min(this.k, ranked.size()));
        List<Integer> index = new ArrayList<>(limit);
        for (Map.Entry<Integer, Double> entry : ranked.subList(0, limit)) {
            index.add(entry.getKey());
        }
        return index;
    }

    /**
     * Computes the adjacency matrix and the matching undirected weighted graph.
     * Vertex ids in the graph are the decimal instance indices.
     */
    public void build() {
        this.matrix = constructWeightMatrix();
        int num = this.matrix.length;

        Graph built = new Graph();
        for (int i = 0; i < num; i++) {
            // The matrix is symmetric, so listing row i once already yields
            // both directions of every undirected edge.
            List<Vertex> outgoing = new ArrayList<>();
            for (int j = 0; j < num; j++) {
                // Any non-negative off-diagonal entry is an edge; -1 marks
                // "not adjacent".
                if (i != j && this.matrix[i][j] >= 0) {
                    outgoing.add(new Vertex(Integer.toString(j), this.matrix[i][j]));
                }
            }
            built.addVertex(Integer.toString(i), outgoing);
        }
        this.graph = built;
    }

    /**
     * Returns the Dijkstra shortest-path (manifold) distance between two
     * instances.
     *
     * @param start index of the first instance, as a decimal string
     * @param end   index of the second instance, as a decimal string
     * @return the summed edge length along the shortest path, {@code 0.0} when
     *         both ids denote the same instance, or
     *         {@link Double#POSITIVE_INFINITY} when no path connects them
     * @throws IllegalStateException    if {@link #build()} has not been called
     * @throws NumberFormatException    if an id is not a decimal integer
     * @throws IllegalArgumentException if an id is outside the data set
     */
    public double getDistance(String start, String end) {
        if (this.matrix == null) {
            throw new IllegalStateException("build() must be called before getDistance()");
        }
        int from = Integer.parseInt(start);
        int to = Integer.parseInt(end);
        int num = this.matrix.length;
        if (from < 0 || from >= num || to < 0 || to >= num) {
            throw new IllegalArgumentException(
                    "Instance index out of range [0, " + num + "): " + start + ", " + end);
        }

        // Use the canonical spelling of the index as the graph key.
        String fromId = Integer.toString(from);
        String toId = Integer.toString(to);

        // The graph returns the path backwards and without the start vertex.
        List<String> path = new ArrayList<>(this.graph.getShortestPath(fromId, toId));
        if (path.isEmpty() && from != to) {
            // Different components of the neighbourhood graph.
            return Double.POSITIVE_INFINITY;
        }
        path.add(fromId);
        Collections.reverse(path);

        double mDist = 0.0;
        for (int i = 0; i < path.size() - 1; i++) {
            int m = Integer.parseInt(path.get(i));
            int n = Integer.parseInt(path.get(i + 1));
            mDist += this.matrix[m][n];
        }

        System.out.println("shortest path:" + path);
        return mDist;
    }
}

Graph.java

import java.util.*;

/**
 * Created by Administrator on 2017/3/14.
 */

/**
 * Weighted graph stored as adjacency lists, with a Dijkstra shortest-path
 * query. Edge weights are expected to be non-negative.
 */
class Graph {

    private final Map<String, List<Vertex>> vertices;

    public Graph() {
        this.vertices = new HashMap<>();
    }

    /**
     * Registers a vertex together with its outgoing edges, replacing any
     * edges previously registered under the same id.
     *
     * @param character id of the vertex
     * @param vertex    outgoing edges; each element carries the id of the
     *                  neighbour and the weight of the edge leading to it
     */
    public void addVertex(String character, List<Vertex> vertex) {
        this.vertices.put(character, vertex);
    }

    /**
     * Finds a shortest path between two vertices with Dijkstra's algorithm.
     *
     * @param start  id of the source vertex
     * @param finish id of the target vertex
     * @return the vertex ids on the path in reverse order (target first) and
     *         without the source; an empty list when {@code start} equals
     *         {@code finish}, when either id is unknown, or when the target
     *         cannot be reached
     */
    public List<String> getShortestPath(String start, String finish) {
        if (!vertices.containsKey(start) || !vertices.containsKey(finish)) {
            return new ArrayList<>();
        }

        final Map<String, Double> distances = new HashMap<>();
        final Map<String, String> previous = new HashMap<>();
        final Set<String> settled = new HashSet<>();
        final PriorityQueue<Vertex> frontier = new PriorityQueue<>();

        distances.put(start, 0.0);
        frontier.add(new Vertex(start, 0.0));

        while (!frontier.isEmpty()) {
            String current = frontier.poll().getId();
            // Improvements are pushed as fresh queue entries rather than
            // updated in place, so older entries for a settled vertex are
            // simply skipped here.
            if (!settled.add(current)) {
                continue;
            }

            if (Objects.equals(current, finish)) {
                // Walk the predecessor chain back to (but excluding) start.
                final List<String> path = new ArrayList<>();
                for (String step = current; previous.get(step) != null; step = previous.get(step)) {
                    path.add(step);
                }
                return path;
            }

            List<Vertex> neighbors = vertices.get(current);
            if (neighbors == null) {
                // Referenced as a neighbour but never registered: no way out.
                continue;
            }

            double base = distances.get(current);
            for (Vertex neighbor : neighbors) {
                String target = neighbor.getId();
                if (settled.contains(target)) {
                    continue;
                }
                double alt = base + neighbor.getDistance();
                double known = distances.getOrDefault(target, Double.POSITIVE_INFINITY);
                // Strict comparison keeps the first path found among ties and
                // rejects NaN or infinite candidates.
                if (alt < known) {
                    distances.put(target, alt);
                    previous.put(target, current);
                    frontier.add(new Vertex(target, alt));
                }
            }
        }

        // Queue drained without reaching finish: no path exists.
        return new ArrayList<>();
    }
}

Vertex.java

/**
 * Created by Administrator on 2017/3/14.
 */

class Vertex implements Comparable<Vertex> {

    private String id;
    private Double distance;

    public Vertex(String id, Double distance) {
        super();
        this.id = id;
        this.distance = distance;
    }

    public String getId() {
        return id;
    }

    public Double getDistance() {
        return distance;
    }

    public void setId(String id) {
        this.id = id;
    }

    public void setDistance(Double distance) {
        this.distance = distance;
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result
                + ((distance == null) ? 0 : distance.hashCode());
        result = prime * result + ((id == null) ? 0 : id.hashCode());
        return result;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj)
            return true;
        if (obj == null)
            return false;
        if (getClass() != obj.getClass())
            return false;
        Vertex other = (Vertex) obj;
        if (distance == null) {
            if (other.distance != null)
                return false;
        } else if (!distance.equals(other.distance))
            return false;
        if (id == null) {
            if (other.id != null)
                return false;
        } else if (!id.equals(other.id))
            return false;
        return true;
    }

    @Override
    public String toString() {
        return "Vertex [id=" + id + ", distance=" + distance + "]";
    }

    @Override
    public int compareTo(Vertex o) {
        if (this.distance < o.distance)
            return -1;
        else if (this.distance > o.distance)
            return 1;
        else
            return this.getId().compareTo(o.getId());
    }

}

Demo.java

import weka.core.Instances;

import java.io.FileReader;
import java.io.IOException;

/**
 * Created by Administrator on 2017/3/15.
 */
/**
 * Small driver: loads an ARFF file, builds the manifold distance structure,
 * prints the adjacency matrix and queries one pair of instances in both
 * directions.
 */
public class Demo {
    /**
     * @param args unused
     * @throws IOException if the ARFF file cannot be opened or read
     */
    public static void main(String[] args) throws IOException {
        Instances data;
        // The data set is read completely inside the constructor, so the
        // reader can be released as soon as it returns.
        try (FileReader reader = new FileReader("Test/Manifold/cpu.arff")) {
            data = new Instances(reader);
        }

        ManifoldDistance manifold = new ManifoldDistance(data, 20, 2);
        manifold.build();
        for (double[] aMtx : manifold.getMatrix()) {
            for (double v : aMtx) {
                System.out.print(v + "   ");
            }
            System.out.println();
        }

        // The graph is undirected, so both calls should print the same value.
        System.out.println(manifold.getDistance("10", "71"));
        System.out.println(manifold.getDistance("71", "10"));
    }
}
上一篇下一篇

猜你喜欢

热点阅读