Range Tree C++实现

2023-08-24  本文已影响0人  叶迎宪

参考资料：https://www.cs.umd.edu/class/fall2020/cmsc420-0201/Lects/lect17-range-tree.pdf
代码基于 https://github.com/mocherson/range-tree 修改而来。

#ifndef RANGE_TREE_2D_H
#define RANGE_TREE_2D_H 1

#include <vector>

// A point in the plane with single-precision coordinates. It is a plain
// aggregate, so it can be brace-initialised: MyFloatPoint{1.0f, 2.0f}.
struct MyFloatPoint
{
    float x;    // primary key of RangeTree2D (sorted with xless)
    float y;    // key of the auxiliary RangeTree1D trees (sorted with yless)
};

// One-dimensional range tree over point indices, keyed by y.
// A leaf holds a single point. Each internal node stores the index of the
// lower-median point of its range and always has exactly two children.
// RangeTree2D builds one of these per node as its auxiliary y-structure.
class RangeTree1D
{
public:
    friend class RangeTree2D;

    RangeTree1D();
    ~RangeTree1D();

    // Each node owns its children through raw pointers and deletes them in
    // the destructor. A compiler-generated member-wise copy would therefore
    // lead to a double delete, so copying is disabled.
    RangeTree1D(const RangeTree1D&) = delete;
    RangeTree1D& operator=(const RangeTree1D&) = delete;

public:
    // Builds the subtree over `idx` (indices into *data_). The split logic
    // assumes idx is ordered by y, as RangeTree2D::Build arranges before
    // calling it. Returns false when idx is empty.
    bool Build(std::vector<size_t>& idx);

    // Appends to `result` every point in this subtree whose y lies in
    // [query_low.y, query_high.y]; x is not tested here. [ylow, yhigh] is the
    // y-extent this node is known to cover.
    void RangeQuery(MyFloatPoint& query_low, MyFloatPoint& query_high,
        std::vector<MyFloatPoint>& result, float ylow, float yhigh);
    // Appends every point of this subtree to `result`, leftmost leaf first.
    void TraversalInorder(std::vector<MyFloatPoint>& result);
    // True when the node has no children, i.e. it represents a single point.
    bool IsLeaf();

protected:
    size_t keypos_;          // index into *data_ of this node's key (median) point
    RangeTree1D* lchild_;    // owned; points up to and including the key, in y order
    RangeTree1D* rchild_;    // owned; points after the key, in y order

    // Shared point storage; the RangeTree2D constructor points it at
    // RangeTree2D's own static copy of the input.
    static std::vector<MyFloatPoint>* data_;
};


// Two-dimensional range tree for axis-aligned rectangle queries.
// The primary tree splits points by x. Every node also owns a RangeTree1D
// over the same points ordered by y.
//
// Usage: construct with the points, call BuildTree(), then RangeQuery().
//
// NOTE: the point store data_ is static and shared by every RangeTree2D.
// Constructing a second tree replaces the points the first tree indexes,
// so keep at most one tree alive at a time.
class RangeTree2D
{
public:
    RangeTree2D(const std::vector<MyFloatPoint>& in_data);
    ~RangeTree2D();

    // The tree owns its children and y-tree through raw pointers that the
    // destructor deletes. A member-wise copy would double-delete them, so
    // copying is disabled.
    RangeTree2D(const RangeTree2D&) = delete;
    RangeTree2D& operator=(const RangeTree2D&) = delete;

    // Builds the tree from the points passed to the constructor.
    // Returns false when there are no points.
    bool BuildTree();

    // Clears `result`, then fills it with every point p such that
    // query_low.x <= p.x <= query_high.x and query_low.y <= p.y <= query_high.y.
    void RangeQuery(MyFloatPoint &query_low, MyFloatPoint &query_high,
        std::vector<MyFloatPoint>& result);

protected:
    RangeTree2D();

    // Recursively builds the subtree over idx (indices ordered by x). It then
    // re-sorts idx by y to build this node's ytree_.
    bool Build(std::vector<size_t>& idx);
    // Recursive query; [xlow, xhigh] is the x-extent covered by this node.
    void RangeQuery(MyFloatPoint& query_low, MyFloatPoint& query_high,
        std::vector<MyFloatPoint>& result, float xlow, float xhigh);
    // True when the node has no children, i.e. it represents a single point.
    bool IsLeaf();

protected:
    size_t keypos_;          // index into data_ of this node's key (median) point
    RangeTree2D* lchild_;    // owned; points up to and including the key, in x order
    RangeTree2D* rchild_;    // owned; points after the key, in x order
    RangeTree1D* ytree_;     // owned; this subtree's points keyed by y
    // All subtrees share this single copy of the input points.
    static std::vector<MyFloatPoint>  data_;
};

#endif
#include "RangeTree2D.h"
#include <algorithm>
#include <limits>
#include <numeric>

// Extent of the root cell for the x- and y-searches.
// The "query contains the whole cell" shortcut in RangeQuery assumes every
// point lies inside this cell. With finite bounds such as [0, 1000], data
// outside that box is silently reported for queries that do not contain it.
// For example, x = 1005 is returned for the query x in [990, 1000].
// Infinite bounds hold for every finite coordinate, so no assumption about
// the data range is needed.
const float kXMin = -std::numeric_limits<float>::infinity();
const float kXMax = std::numeric_limits<float>::infinity();
const float kYMin = -std::numeric_limits<float>::infinity();
const float kYMax = std::numeric_limits<float>::infinity();

// Storage for the static members declared in the header. RangeTree2D holds
// the single copy of the input points. RangeTree1D only keeps a pointer to
// it, which the RangeTree2D constructor sets.
std::vector<MyFloatPoint>* RangeTree1D::data_ = nullptr;
std::vector<MyFloatPoint> RangeTree2D::data_;

// Function-pointer type matching the point comparators below (not currently referenced in this file).
using CompFun = bool (*)(const MyFloatPoint&, const MyFloatPoint&);

static bool xless(const MyFloatPoint &l, const MyFloatPoint &r)
{
    return l.x < r.x;
}

static bool yless(const MyFloatPoint& l, const MyFloatPoint& r)
{
    return l.y < r.y;
}

// Returns the permutation of indices 0..data.size()-1 that orders `data`
// according to `comp`. `data` itself is not modified.
template <typename T>
std::vector<size_t> SortIndex(const std::vector<T>& data, bool (*comp)(const T&, const T&)) 
{
    // Seed with a size_t so iota increments in size_t rather than int.
    std::vector<size_t> idx(data.size());
    std::iota(idx.begin(), idx.end(), static_cast<size_t>(0));

    // Call std::sort explicitly. The unqualified call only compiled because
    // argument-dependent lookup happened to find std through the iterator
    // type, which is not guaranteed.
    std::sort(idx.begin(), idx.end(),
        [&data, comp](size_t lhs, size_t rhs) { return comp(data[lhs], data[rhs]); });

    return idx;
}

// Sorts `idx` in place so that data[idx[0]], data[idx[1]], ... is ordered by
// `comp`. Also returns a copy of the sorted indices.
template <typename T>
std::vector<size_t> SortIndex(const std::vector<T>& data, std::vector<size_t>& idx, 
    bool(*comp)(const T&, const T&)) 
{
    // Call std::sort explicitly rather than relying on argument-dependent
    // lookup through the iterator type.
    std::sort(idx.begin(), idx.end(),
        [&data, comp](size_t lhs, size_t rhs) { return comp(data[lhs], data[rhs]); });

    return idx;
}

/////////////////////////////////////////////
// class RangeTree1D
/////////////////////////////////////////////

// Creates a childless node. keypos_ holds the largest size_t value as a
// placeholder until Build() assigns a real index.
RangeTree1D::RangeTree1D()
    : keypos_(static_cast<size_t>(-1)),
      lchild_(nullptr),
      rchild_(nullptr)
{
}

// Recursively frees both owned subtrees; deleting a null pointer is a no-op.
RangeTree1D::~RangeTree1D()
{
    delete lchild_;
    delete rchild_;
}

// Builds this subtree over `idx` (indices into *data_, expected in y order).
// The lower-median index becomes this node's key. Indices up to and including
// it go left; the remainder go right. Returns false for an empty idx.
bool RangeTree1D::Build(std::vector<size_t>& idx)
{
    const size_t count = idx.size();
    if (count == 0)
        return false;

    // Lower median. For even counts (count - 1) / 2 == count / 2 - 1, and for
    // odd counts it is the exact middle, so one formula covers both cases.
    const size_t mid = (count - 1) / 2;
    keypos_ = idx[mid];

    if (count == 1)
    {
        // Single point: this node is a leaf.
        lchild_ = nullptr;
        rchild_ = nullptr;
        return true;
    }

    std::vector<size_t> lower(idx.begin(), idx.begin() + mid + 1);
    std::vector<size_t> upper(idx.begin() + mid + 1, idx.end());

    lchild_ = new RangeTree1D();
    lchild_->Build(lower);
    rchild_ = new RangeTree1D();
    rchild_->Build(upper);

    return true;
}

// Appends to `result` the points of this subtree whose y is within
// [query_low.y, query_high.y]. [ylow, yhigh] is the y-extent covered by this
// node. The x coordinate is not examined here.
void RangeTree1D::RangeQuery(MyFloatPoint& query_low, MyFloatPoint& query_high,
    std::vector<MyFloatPoint>& result, float ylow, float yhigh)
{
    // Leaf: test the single stored point directly.
    if (IsLeaf())
    {
        const MyFloatPoint& candidate = (*data_)[keypos_];
        if (candidate.y >= query_low.y && candidate.y <= query_high.y)
            result.push_back(candidate);
        return;
    }

    // The node's whole y-extent lies inside the query: report everything below it.
    if (ylow >= query_low.y && yhigh <= query_high.y)
    {
        TraversalInorder(result);
        return;
    }

    // The node's y-extent is disjoint from the query: nothing to report.
    if (ylow > query_high.y || yhigh < query_low.y)
        return;

    // Partial overlap: split the extent at this node's key and recurse.
    const float split = (*data_)[keypos_].y;
    lchild_->RangeQuery(query_low, query_high, result, ylow, split);
    rchild_->RangeQuery(query_low, query_high, result, split, yhigh);
}

void RangeTree1D::TraversalInorder(std::vector<MyFloatPoint>& result)
{
    if (this->IsLeaf())
        result.push_back((*data_)[keypos_]);
    else
    {
        lchild_->TraversalInorder(result);
        rchild_->TraversalInorder(result);
    }
}

// A node is a leaf when it has neither a left nor a right child.
bool RangeTree1D::IsLeaf()
{
    return !lchild_ && !rchild_;
}

/////////////////////////////////////////////
// class RangeTree2D
/////////////////////////////////////////////

// Copies the input points into the shared static store and points the 1-D
// trees at it. The tree itself is built later by BuildTree().
// keypos_ was previously left uninitialized. It now gets the same
// "no index yet" placeholder as RangeTree1D.
// NOTE: data_ is static, so constructing another RangeTree2D replaces the
// points that any existing tree refers to.
RangeTree2D::RangeTree2D(const std::vector<MyFloatPoint>& in_data)
    : keypos_(static_cast<size_t>(-1)), lchild_(NULL), rchild_(NULL), ytree_(NULL)
{
    data_ = in_data;
    RangeTree1D::data_ = &data_;
}

// Creates an empty internal node; used by Build() for child subtrees.
// keypos_ was previously left uninitialized. It now gets the same
// placeholder as RangeTree1D until Build() assigns a real index.
RangeTree2D::RangeTree2D() 
    : keypos_(static_cast<size_t>(-1)), lchild_(NULL), rchild_(NULL), ytree_(NULL)
{
}

// Frees both owned x-subtrees and this node's auxiliary y-tree.
// Deleting a null pointer is a no-op, so no checks are needed.
RangeTree2D::~RangeTree2D()
{
    delete lchild_;
    delete rchild_;
    delete ytree_;
}

bool RangeTree2D::BuildTree()
{
    std::vector<size_t> idx = SortIndex(data_, xless);
    
    return Build(idx);
}

// Builds this subtree over `idx` (indices into data_ in x order).
// The lower-median index becomes the node's key, with the left half including
// it. Afterwards idx is re-sorted by y in place, and this node's auxiliary
// 1-D tree is built from it. Returns false for an empty idx.
bool RangeTree2D::Build(std::vector<size_t>& idx)
{
    const size_t count = idx.size();
    if (count == 0)
        return false;

    // (count - 1) / 2 is the lower median for both even and odd counts.
    const size_t mid = (count - 1) / 2;
    keypos_ = idx[mid];

    if (count == 1)
    {
        lchild_ = nullptr;
        rchild_ = nullptr;
    }
    else
    {
        std::vector<size_t> lower(idx.begin(), idx.begin() + mid + 1);
        std::vector<size_t> upper(idx.begin() + mid + 1, idx.end());

        lchild_ = new RangeTree2D();
        lchild_->Build(lower);
        rchild_ = new RangeTree2D();
        rchild_->Build(upper);
    }

    // Auxiliary structure: the same points, re-sorted by y.
    // Note that this reorders the caller's idx in place.
    ytree_ = new RangeTree1D();
    SortIndex(data_, idx, yless);
    ytree_->Build(idx);

    return true;
}

// Clears `result`, then fills it with every point p satisfying
// query_low.x <= p.x <= query_high.x and query_low.y <= p.y <= query_high.y.
void RangeTree2D::RangeQuery(MyFloatPoint &query_low, MyFloatPoint &query_high,
    std::vector<MyFloatPoint>& result)
{
    result.clear();

    // A successfully built tree always has a y-tree at its root. If ytree_ is
    // missing, BuildTree() was never called or there were no points. In that
    // case the root looks like a leaf whose keypos_ is not a valid index, and
    // searching would read data_ out of bounds.
    if (ytree_ == NULL || data_.empty())
        return;

    RangeQuery(query_low, query_high, result, kXMin, kXMax);
}

// Recursive worker: appends the points of this subtree that lie inside the
// query rectangle. [xlow, xhigh] is the x-extent covered by this node.
void RangeTree2D::RangeQuery(MyFloatPoint &query_low, MyFloatPoint &query_high,
    std::vector<MyFloatPoint>& result, float xlow, float xhigh)
{
    // Leaf: test the single stored point against both query intervals.
    if (IsLeaf())
    {
        const MyFloatPoint& candidate = data_[keypos_];
        const bool in_x = candidate.x >= query_low.x && candidate.x <= query_high.x;
        const bool in_y = candidate.y >= query_low.y && candidate.y <= query_high.y;
        if (in_x && in_y)
            result.push_back(candidate);
        return;
    }

    // The query's x-interval covers this node's whole x-extent, so only y
    // needs filtering: delegate to the auxiliary y-tree.
    if (xlow >= query_low.x && xhigh <= query_high.x)
    {
        ytree_->RangeQuery(query_low, query_high, result, kYMin, kYMax);
        return;
    }

    // The node's x-extent is disjoint from the query: nothing to report.
    if (xlow > query_high.x || xhigh < query_low.x)
        return;

    // Partial overlap: split the extent at this node's key and recurse.
    const float split = data_[keypos_].x;
    lchild_->RangeQuery(query_low, query_high, result, xlow, split);
    rchild_->RangeQuery(query_low, query_high, result, split, xhigh);
}

// A node is a leaf when it has neither a left nor a right child.
bool RangeTree2D::IsLeaf()
{
    return !lchild_ && !rchild_;
}
上一篇下一篇

猜你喜欢

热点阅读