中位数的中位数 - Median of Medians,选取近似

2020-08-18  本文已影响0人  devilisdevil

问题描述

median of medians是一种中位数(近似)选取算法,常用于其他选择算法中(主要是QuickSelect算法中)进行pivot元素的选取。

在如QuickSelect这样的选择算法中,我们pivot元素的选取对于我们算法的效率有很大的影响,median of medians算法可以帮助我们在线性时间内选出中位数附近的元素,这使得QuickSelect的最坏时间复杂度由O(n^2)降为O(n)

算法分析

(主要来自自己对wiki的翻译和理解)

Median of medians算法也被称作BFPRT算法(由Blum、Floyd、Pratt、Rivest、Tarjan五人提出),可以对照我github上的代码或者wiki上的伪代码来看。我的代码基本是伪代码的C++直接翻译。

算法一共4个函数:

#ifndef BFPRT_H
#define BFPRT_H

/**
 * put the k-th element at L[k] (0 indexed)
 */
void select(int L[], int left, int right, int k);

/**
 * actual median of medians algo
 */
int pivot(int L[], int left, int right);

/**
 * three way partition: ---[=]=[=]+++
 */
int partition(int L[], int left, int right, int pivot_index, int k);

/*
 * <= 5 elements, insertion sort, and pick the middle as partition index
 */
int partition5(int L[], int left, int right);

#endif

pivot 函数

median of medians思想的核心即是在pivot函数,函数如下:

int pivot(int L[], int left, int right) {
  if (right - left < 5) return partition5(L, left, right);
  for (int i = left; i <= right; i += 5) {
    int sub_right = min(i + 4, right);
    int median5 = partition5(L, i, sub_right);
    swap(L[median5], L[left + (i - left) / 5]);
  }

  // approximate median index
  int mid = (right - left) / 10 + left + 1;
  select(L, left, left + (right - left) / 5, mid);
  return mid;
}

函数在[left, right]范围内寻找适合作为pivot的数(也就是median附近的数),返回这个数的下标。
如果少于等于5个元素,调用partition5获取,partition5使用插入排序,然后将中间元素作为pivot,返回其下标。
将范围内元素每5个一组进行划分,对每一组进行partition5操作,获取每一组的median,这一步也就是获取medians
最后用select从上一步获得的medians中进一步获取median,也就是median of medians

select函数

void select(int L[], int left, int right, int k) {
  while (true) {
    if (left == right) return;
    int pivot_index = pivot(L, left, right);
    pivot_index = partition(L, left, right, pivot_index, k);
    if (k == pivot_index)
      return;
    else if (k < pivot_index)
      right = pivot_index - 1;
    else
      left = pivot_index + 1;
  }
}

partition 函数

partition函数采用的是三路划分(three way partition),普通的两路划分也是可以的,三路书写跟两路都差不多简单,三路可以根据提供的k来判断返回最左边的pivot值的下标/最右边的pivot值的下标/中间的pivot值的下标 -- 直接就是k (看最后函数返回那几行懂了)

int partition(int L[], int left, int right, int pivot_index, int k) {
  int pivot_value = L[pivot_index];
  swap(L[pivot_index], L[right]);
  // ---[=]=[=]+++
  int store_index = left;
  for (int i = left; i < right; ++i) {
    if (L[i] < pivot_value) {
      swap(L[store_index++], L[i]);
    }
  }
  int store_index_eq = store_index;
  for (int i = store_index; i < right; ++i) {
    if (L[i] == pivot_value) {
      swap(L[store_index_eq++], L[i]);
    }
  }
  swap(L[right], L[store_index_eq]);
  if (k < store_index)
    return store_index;
  else if (k < store_index_eq)
    return k;
  else
    return store_index_eq;
}

partition5 函数

排序(这里采用的插入排序)然后返回中间元素的下标(left + right)/2即可

int partition5(int L[], int left, int right) {
  for (int i = left + 1; i <= right; ++i) {
    int j = i;
    while (j > left && L[j - 1] > L[j]) {
      swap(L[j - 1], L[j]);
      --j;
    }
  }
  return (left + right) / 2;  // return middle index (the median index)
}

最坏时间复杂度

分析最坏时间复杂度的关键在于知道选取的分割点pivot能够将[left, right]范围内的数最坏分成几比几。

(下面考虑最坏情况:选取的pivot使得下一次partition时元素尽可能多,即本次partition减少元素尽可能少)
pivot是median of medians of [left, right],设范围内有n个数,分成n/5份,n/10份的中位数小于pivotn/10份的中位数大于pivot,所以至少有3n/10个数小于pivot,所以最坏情况就是本次partition只能排除3n/10个元素,下一次partition需要在剩下的7n/10中寻找。

示例代码

详细源代码在: github,里面除了上面这样的几个函数分开的代码实现,还有一个简化的版本,将上面的几个函数合并到了一起,见文件BFPRT-simple.cpp

参考

上一篇下一篇

猜你喜欢

热点阅读