左神算法笔记——BFPRT算法

2020-04-09  本文已影响0人  yaco

简介

通常我们需要在一大堆数中求前k大的数。比如在搜索引擎中求当天用户点击次数排名前10000的热词,在文本特征选择中求值按从大到小排名前k个文本
等问题,都涉及到一个核心问题,即TOP-K问题

那么这种问题就会有一个比较好的算法,叫做BFPTR算法,又称为中位数的中位数算法,它的最坏时间复杂度为O(N),它是由Blum、Floyd、Pratt、Rivest、Tarjan提出。该算法的思想是修改快速选择算法的主元选取方法,提高算法在最坏情况下的时间复杂度。
参考——BFPRT算法原理

问题描述

给定一个数组arr和k,返回数组中第k小的数

常规思路——堆排序

代码

    // ############################----------手撕BFPRT----------###################################
    // --------------------方法一: 堆排(对比BFPRT)---------------------------
    // 获取前k小的数构成数组返回
    public static int[] getMinKNumsByHeap(int[] arr, int k) {
        if(k < 1 || k > arr.length) return null;  // k越界返回null
        int[] minArr = new int[k];
        // 抽取arr的前k个元素构成大顶堆
        for (int i = 0; i < k; i++) {
            heapInsert(minArr,arr[i],i);
        }
        // 然后从arr数组的第k个元素开始
        for (int i = k; i < arr.length; i++) {
            // 如果当前元素小于堆顶,则插入
            if(arr[i] < minArr[0]) {
                minArr[0] = arr[i];
                // 调整数组,重新构造大顶堆
                heapfiy(minArr,0,k);
            }
        }
        return minArr;
    }

    // 调整数组为大顶堆
    private static void heapfiy(int[] minArr, int index, int heapSize) {
        int left = 2 * index + 1;
        int right = 2 * index + 2;
        int largest = index;
        while(left < heapSize){
            if(minArr[index] < minArr[left]){
                largest = left;
            }
            if(right < heapSize && minArr[right] > minArr[largest]){
                largest = right;
            }
            if(largest == index){
                // 没有调整
                break;
            }
            swap(minArr,index,largest);
            // 更新各个指针的位置
            index = largest;
            left = index * 2 + 1;
            right = index * 2 + 2;
        }
    }

    private static void heapInsert(int[] minArr, int value, int index) {
        minArr[index] = value;
        while(index != 0){
            int parent = (index - 1) / 2;
            if(minArr[index] > minArr[parent]){
                swap(minArr,index,parent);
                index = parent;
            }else{
                break;
            }
        }
    }

    private static void swap(int[] arr, int x, int y){
        int temp = arr[x];
        arr[x] = arr[y];
        arr[y] = temp;
    }

    
    //******************      测试         ************************
    public static void main(String[] args) {
        // 生成随机数组,用作测试
        int[] arr = ArrayTestUntil.generateRandomArray(10000,10000);

        long s1 = System.currentTimeMillis();
        int[] res = getMinKNumsByHeap(arr, 20);
        System.out.println(Arrays.toString(res));
        System.out.println("普通归并排序用时:  " + (System.currentTimeMillis() - s1) + "ms");
    }

BFPRT算法

基本思路:

    /**
     * 解法二: 利用BFPRT算法实现O(N)级别的时间复杂度
     */
    /**
     * 函数入口: 获取前k小的元素数组
     * @param arr
     * @param k
     * @return
     */
    public static int[] getMinKNumsByBFPRT(int[] arr, int k) {
        // 首先同样进行判断,如果K越界,则直接返回数组
        if (k < 1 || k > arr.length) {
            return arr;
        }
        // 获取第k小的元素值
        int minKth = getMinKthByBFPRT(arr, k);
        // 创建可以容纳k这么大的数组
        int[] res = new int[k];
        int index = 0;
        // 将数组元素向res中添
        for (int i = 0; i != arr.length; i++) {
            if (arr[i] < minKth) {
                res[index++] = arr[i];
            }
        }
        // 有可能走到arr[i] 刚好等于minKth的位置,则后面一路相等即可
        for (; index != res.length; index++) {
            res[index] = minKth;
        }
        // 返回结果
        return res;
    }

    /**
     * 获取数组中第k小的元素值(同时将数组元素已经按照中间的值拍好序了)
     * @param arr
     * @param K
     * @return
     */
    public static int getMinKthByBFPRT(int[] arr, int K) {
        // 利用复制好的数组执行
        int[] copyArr = copyArray(arr);
        // 挑选第k小的值
        return select(copyArr, 0, copyArr.length - 1, K - 1);
    }

    /**
     * 复制数组(不破坏原来数组的结构)
     * @param arr
     * @return
     */
    public static int[] copyArray(int[] arr) {
        int[] res = new int[arr.length];
        for (int i = 0; i != res.length; i++) {
            res[i] = arr[i];
        }
        return res;
    }

    /**
     * 从给定数组中挑选出中位数
     * @param arr      指定数组
     * @param begin    起始位置
     * @param end      终止位置
     * @param i        选取第i小的元素
     * @return
     */
    public static int select(int[] arr, int begin, int end, int i) {
        if (begin == end) {
            return arr[begin];
        }
        // 实现了一个递归调用,获取中位数(全局最好的pivot)
        int pivot = medianOfMedians(arr, begin, end);
        // 用这个中位数实现快排
        int[] pivotRange = partition(arr, begin, end, pivot);
        // 如果刚刚好查找的数就在中间位置,直接返回arr[i]
        if (i >= pivotRange[0] && i <= pivotRange[1]) {
            return arr[i];
        } else if (i < pivotRange[0]) {
            // 如果i位置小于less,则向左进行递归调用
            return select(arr, begin, pivotRange[0] - 1, i);
        } else {
            // 如果位置大于more,则向右进行递归调用
            return select(arr, pivotRange[1] + 1, end, i);
        }
    }

    /**
     * 快速获取中位数(全局最优的快排输入值)
     * @param arr
     * @param begin
     * @param end
     * @return
     */
    public static int medianOfMedians(int[] arr, int begin, int end) {
        // 数组总长度
        int num = end - begin + 1;
        // 每五个一组,查看是否有多余的数,有的话则单独成一位
        int offset = num % 5 == 0 ? 0 : 1;
        // 创建存储每五个数据排序后中位数的数组
        int[] mArr = new int[num / 5 + offset];
        // 遍历此数组
        for (int i = 0; i < mArr.length; i++) {
            // 当前mArr来源自原来数组中的起始位置
            int beginI = begin + i * 5;
            // 当前mArr来源自原来数组中的终止位置
            int endI = beginI + 4;
            // 计算出当前i位置5个数排序后的中位数
            mArr[i] = getMedian(arr, beginI, Math.min(end, endI));
        }
        // 在这些中位数的点中,挑选出排好序之后的中位数返回
        return select(mArr, 0, mArr.length - 1, mArr.length / 2);
    }

    /**
     * 快排主体: 小于pivotValue放在左边,等于pivotValue放在中间,大于pivotValue放在右边
     * @param arr
     * @param begin
     * @param end
     * @param pivotValue    快排选取的元素
     * @return
     */
    public static int[] partition(int[] arr, int begin, int end, int pivotValue) {
        int small = begin - 1;
        int cur = begin;
        int big = end + 1;
        while (cur != big) {
            if (arr[cur] < pivotValue) {
                swap(arr, ++small, cur++);
            } else if (arr[cur] > pivotValue) {
                swap(arr, cur, --big);
            } else {
                cur++;
            }
        }
        int[] range = new int[2];
        range[0] = small + 1;
        range[1] = big - 1;
        return range;
    }

    /**
     * 获取数组排序后的中位数
     * @param arr
     * @param begin
     * @param end
     * @return
     */
    public static int getMedian(int[] arr, int begin, int end) {
        insertionSort(arr, begin, end);
        int sum = end + begin;
        int mid = (sum / 2) + (sum % 2);
        return arr[mid];
    }

    /**
     * 实现简单的插入排序
     * @param arr
     * @param begin
     * @param end
     */
    public static void insertionSort(int[] arr, int begin, int end) {
        for (int i = begin + 1; i != end + 1; i++) {
            for (int j = i; j != begin; j--) {
                if (arr[j - 1] > arr[j]) {
                    swap(arr, j - 1, j);
                } else {
                    break;
                }
            }
        }
    }

    /**
     * 交换
     * @param arr
     * @param index1
     * @param index2
     */
    public static void swap(int[] arr, int index1, int index2) {
        int tmp = arr[index1];
        arr[index1] = arr[index2];
        arr[index2] = tmp;
    }

    public static void printArray(int[] arr) {
        for (int i = 0; i != arr.length; i++) {
            System.out.print(arr[i] + " ");
        }
        System.out.println();
    }

    public static void main(String[] args) {
        int[] arr = ArrayTestUntil.generateRandomArray(10000,10000);
        // sorted : { 1, 1, 1, 1, 2, 2, 2, 3, 3, 5, 5, 5, 6, 6, 6, 7, 9, 9, 9 }


        long s1 = System.currentTimeMillis();
        printArray(getMinKNumsByHeap(arr, 2000));
        System.out.println("普通归并排序用时:  " + (System.currentTimeMillis() - s1) + "ms");

        long s2 = System.currentTimeMillis();
        printArray(getMinKNumsByBFPRT(arr, 2000));
        System.out.println("BFPRT用时:  " + (System.currentTimeMillis() - s2) + "ms");
    }

上一篇 下一篇

猜你喜欢

热点阅读