KD树维护平面点集求近邻

2019-07-17  本文已影响0人  学无止境1980

一道KD树模板题:HDU4347 The Closest M Points(给定K维空间中的n个点,对每次询问输出距离询问点最近的m个点,按距离由近到远排列)

AC代码:

#include <bits/stdc++.h>
using namespace std;

const int MAXN = 50000; // capacity of the point[] and kdTree[] arrays
const int MAXK = 5;     // capacity of a coordinate's component array

int n;      // number of points read for the current data set
int m;      // number of nearest points to keep for the current query
int K;      // number of coordinate components actually in use
int cmpIdx; // component index that Coordinate::operator< compares on

// A point with up to MAXK integer components.
struct Coordinate {
    int d[MAXK];

    // Orders two coordinates by the single component selected by the
    // global cmpIdx; all other components are ignored.
    bool operator < (const Coordinate &rhs) const {
        return d[cmpIdx] < rhs.d[cmpIdx];
    }
};

Coordinate o;           // the query point of the current query
Coordinate point[MAXN]; // the input points; reordered while the tree is built

// One KD-tree node, stored by index inside kdTree[].
struct Node {
    int ls;                // index of the left child; 0 means "no child"
    int rs;                // index of the right child; 0 means "no child"
    Coordinate coordinate; // the point stored at this node
};

Node kdTree[MAXN]; // node pool; build_kdTree hands out slots in creation order
int nodeCnt;       // number of slots of kdTree[] handed out so far

// Returns the squared Euclidean distance between a and b, taken over the
// first K components.
// NOTE(review): the sum is accumulated in int; this assumes the inputs are
// small enough that K squared differences fit in int -- confirm against the
// expected coordinate range.
int calc_distance(const Coordinate &a, const Coordinate &b) {
    int sum = 0;
    for (int axis = 0; axis < K; ++axis) {
        const int delta = a.d[axis] - b.d[axis];
        sum += delta * delta;
    }
    return sum;
}

// Comparator that orders coordinates by their squared distance to the query
// point o: a sorts before b when a is strictly closer to o than b is.
struct distance_less {
    // const-qualified so the comparator can be invoked through a const
    // comparator object; it holds no state of its own.
    bool operator ()(const Coordinate &a, const Coordinate &b) const {
        return calc_distance(a, o) < calc_distance(b, o);
    }
};

// Max-heap keyed on distance to o: top() is the farthest candidate kept so far.
priority_queue <Coordinate, vector<Coordinate>, distance_less > Q;

// Scratch stack; main drains Q into it so results print nearest-first.
stack <Coordinate> S;

// Offers p to the candidate heap Q and then trims Q back to at most m
// entries, so only the m candidates nearest to o are retained.
void push(const Coordinate &p) {
    Q.push(p);
    // Still within the limit: nothing to evict.
    if (Q.size() <= static_cast<size_t>(m)) return;
    // One too many: drop the candidate that is farthest from o.
    Q.pop();
}

int build_kdTree(int l, int r, int idx) {
    if (l >= r) return 0;
    int m = (l + r) >> 1;
    cmpIdx = idx;
    nth_element(point + l, point + m, point + r); // 可以忽略坐标值相等的情况

    int ret = nodeCnt++;
    kdTree[ret].coordinate = point[m];
    kdTree[ret].ls = build_kdTree(l, m, (idx + 1) % K);
    kdTree[ret].rs = build_kdTree(m + 1, r, (idx + 1) % K);
    return ret;
}

void query_nn(int u, int idx) {
    cmpIdx = idx;
    int next_u = kdTree[u].ls, next_v = kdTree[u].rs;
    if (!(o < kdTree[u].coordinate)) swap(next_u, next_v);
    if (next_u) query_nn(next_u, (idx + 1) % K);

    /* 回溯 */
    push(kdTree[u].coordinate);
    int d = o.d[idx] - kdTree[u].coordinate.d[idx];
    if (next_v && calc_distance(Q.top(), o) >= d * d)
        query_nn(next_v, (idx + 1) % K);
}

int main() {
    while (scanf("%d %d", &n, &K) != EOF) {
        for (int i=0; i<n; i++) 
            for (int j=0; j<K; j++) scanf("%d", &point[i].d[j]);
        nodeCnt = 0;
        build_kdTree(0, n, 0);

        int t;
        scanf("%d", &t);
        while (t--) {
            for (int i=0; i<K; i++) scanf("%d", &o.d[i]);
            scanf("%d", &m);
            query_nn(0, 0);
            printf("the closest %d points are:\n", m);
            while (!Q.empty()) {
                S.push(Q.top());
                Q.pop();
            }
            while (!S.empty()) {
                for (int i=0; i<K; i++) 
                    printf(i < K-1 ? "%d " : "%d\n", S.top().d[i]);
                S.pop();
            }
        }
    }
    return 0;
}
上一篇 下一篇

猜你喜欢

热点阅读