数据结构和算法分析

选择第k小元素(分治法)

2019-11-27  本文已影响0人  张的笔记本

1、伪码

image.png
原理视频:https://www.bilibili.com/video/av7134874?p=29
2、实现
typedef pair<int, int> pairs;  // Not used by Select; kept so other code relying on this alias still compiles.

/*
 * Select the k-th smallest element (divide and conquer).
 *
 * The pivot is chosen by sorting each full group of 5 elements, collecting
 * the middle element of every group, and recursively selecting the median
 * of those group medians. The input is then split around the pivot and the
 * search continues in the part that contains rank k.
 *
 * Parameters:
 *   a - the values to search. Taken by value because the function sorts
 *       pieces of its own copy.
 *   k - 1-based rank of the wanted element. The caller must guarantee
 *       1 <= k <= a.size(); this is not checked, and an out-of-range k
 *       leads to an out-of-bounds access.
 *
 * Returns the value that would be at position k if a were sorted ascending.
 *
 * Duplicate values are handled: elements equal to the pivot are counted
 * instead of being assumed unique, so no element is lost and the ranks
 * passed to the recursive calls always match the sub-vector sizes.
 */
int Select(vector<int> a, int k)
{
    const int total = static_cast<int>(a.size());

    // Base case: fewer than 5 elements, sort directly and index.
    if (total < 5) {
        sort(a.begin(), a.end());
        return a[k - 1];
    }

    // Sort every full group of 5 and remember its middle element.
    // The trailing (total % 5) elements form no group; they only take
    // part in the partition step below.
    const int groups = total / 5;
    vector<int> medians;
    medians.reserve(groups);
    for (int g = 0; g < groups; ++g) {
        vector<int>::iterator first = a.begin() + g * 5;
        sort(first, first + 5);
        medians.push_back(*(first + 2));
    }

    // Pivot: the (groups / 2 + 1)-th smallest of the group medians.
    const int m = Select(medians, groups / 2 + 1);

    // Three-way partition around the pivot. Only the number of elements
    // equal to m is needed, not the elements themselves.
    vector<int> sx;       // elements strictly smaller than m
    vector<int> sd;       // elements strictly larger than m
    int equalCount = 0;   // elements equal to m (at least 1: m comes from a)
    for (vector<int>::const_iterator it = a.begin(); it != a.end(); ++it) {
        if (*it < m)
            sx.push_back(*it);
        else if (m < *it)
            sd.push_back(*it);
        else
            ++equalCount;
    }

    // Ranks 1..below are in sx, the next equalCount ranks are all m,
    // and the remaining ranks are in sd.
    const int below = static_cast<int>(sx.size());
    if (k <= below)
        return Select(sx, k);
    if (k <= below + equalCount)
        return m;
    return Select(sd, k - below - equalCount);
}

原理参见屈婉玲老师《算法设计与分析》 ORZ

上一篇 下一篇

猜你喜欢

热点阅读