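An average-case O(N) algorithm for picking the top K of N numbers
To pick the top k of n numbers, you can use a pruned quicksort (quickselect). Run one round of comparisons to find the final position i of the n-th number, so that everything before it is smaller and everything after it is larger. If i is before k, look for the top k-i among positions i+1 to n; if i is after k, look for the top k among positions 1 to i-1; if i equals k, you are done.
People online say the complexity is O(N), which struck me as too fast. But I couldn't prove it either way: neither that it is O(N) nor that it isn't. So I wrote a program that computes the average number of comparisons and plotted the result. Visually it should be O(N), because the curve is just too straight.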
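The procedure described above can be sketched roughly as follows. This is a minimal illustration, not the program used later in this post: the class name QuickSelectSketch and its method names are made up for the example, it assumes a Lomuto-style partition with the last element of the current range as the pivot, and it treats "top k" as the k-th smallest, matching the "smaller before, larger after" ordering in the text.

```java
import java.util.Arrays;

class QuickSelectSketch {
    /** Returns the k-th smallest element of a (k is 1-based); reorders a in place. */
    static int select(int[] a, int k) {
        if (k < 1 || k > a.length) {
            throw new IllegalArgumentException("k must be in 1..a.length");
        }
        int lo = 0, hi = a.length - 1;
        int target = k - 1; // 0-based index where the answer must end up
        while (true) {
            int p = partition(a, lo, hi);
            if (p == target) {
                return a[p];      // pivot landed exactly on position k: done
            } else if (p < target) {
                lo = p + 1;       // answer lies to the right of the pivot
            } else {
                hi = p - 1;       // answer lies to the left of the pivot
            }
        }
    }

    /** Puts a[hi] at its final position within a[lo..hi] and returns that position. */
    static int partition(int[] a, int lo, int hi) {
        int pivot = a[hi];
        int i = lo;
        for (int j = lo; j < hi; j++) { // exactly hi - lo comparisons
            if (a[j] < pivot) {
                swap(a, i, j);
                i++;
            }
        }
        swap(a, i, hi);
        return i;
    }

    static void swap(int[] a, int i, int j) {
        int t = a[i];
        a[i] = a[j];
        a[j] = t;
    }

    public static void main(String[] args) {
        int[] data = {9, 1, 8, 2, 7, 3, 6, 4, 5};
        System.out.println(select(data, 3));
        // After the call, the first 3 slots hold the 3 smallest values in some order.
        System.out.println(Arrays.toString(Arrays.copyOf(data, 3)));
    }
}
```

Each partition of a range of m elements costs m-1 comparisons, which is the quantity the analysis below counts.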
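public class TopK {
    private final int n, k;
    // nSelK[i-1][j-1] caches counts(i, j); a negative value means "not computed yet".
    // Only about k*(n-k) elements are eventually computed.
    private double[][] nSelK = null;

    public TopK(int n, int k) {
        if (k < 1 || n < k) {
            throw new IllegalArgumentException("Require 1 <= k <= n");
        }
        this.n = n;
        this.k = k;
    }

    /**
     * Fills the cache with the values needed for selecting top k from n.
     * @return Return this for convenience.
     */
    public TopK compute() {
        nSelK = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                nSelK[i][j] = -1.0;
            }
        }
        counts(n, k);
        return this;
    }

    /**
     * Average counts of comparison when selecting top k from n.
     * Only k >= 1 and n >= k is valid.
     */
    private double counts(int n, int k) {
        double r = nSelK[n - 1][k - 1];
        if (r < 0) { // Get in first time. Compute.
            r = 0.0;
            if (!(n == 1 && k == 1)) { // A single element needs no comparison.
                // Pivot lands at position i < k: continue with top k-i of the n-i elements after it.
                for (int i = 1; i < k; i++) {
                    r += counts(n - i, k - i);
                }
                // Pivot lands at position i > k: continue with top k of the i-1 elements before it.
                for (int i = k + 1; i <= n; i++) {
                    r += counts(i - 1, k);
                }
                // Each position has probability 1/n; this round costs n-1 comparisons.
                r = r / n + n - 1;
            }
            nSelK[n - 1][k - 1] = r;
            // System.out.printf("Got %d-%d: %f: \n", n, k, r);
        }
        return r;
    }

    /** Prints counts(i, k) for i = k..n, one row per i. */
    public void printCol() {
        for (int i = k; i <= n; i++) {
            System.out.printf("%d\t%d\t%f \n", i, k, nSelK[i - 1][k - 1]);
        }
    }

    /** Prints the whole cache triangle; entries that were never needed show as -1.00. */
    public void printAll() {
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= i; j++) {
                System.out.printf("%.2f\t", nSelK[i - 1][j - 1]);
            }
            System.out.println();
        }
    }

    public static void main(String[] args) {
        int n = 100, k = 5;
        new TopK(n, k).compute().printCol();
    }
}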
The output is:
[Figure: top5(n=100).png, plot of the table below]
n k average comparisons
5 5 5.433333
6 5 8.333333
7 5 11.185714
8 5 13.954762
9 5 16.646032
10 5 19.270635
11 5 21.838961
12 5 24.359668
13 5 26.839777
14 5 29.284965
15 5 31.699850
16 5 34.088220
17 5 36.453211
18 5 38.797446
19 5 41.123134
20 5 43.432155
21 5 45.726120
22 5 48.006420
23 5 50.274260
24 5 52.530697
25 5 54.776658
26 5 57.012962
27 5 59.240336
28 5 61.459428
29 5 63.670820
30 5 65.875031
31 5 68.072534
32 5 70.263754
33 5 72.449077
34 5 74.628856
35 5 76.803413
36 5 78.973041
37 5 81.138011
38 5 83.298571
39 5 85.454951
40 5 87.607364
41 5 89.756004
42 5 91.901054
43 5 94.042683
44 5 96.181048
45 5 98.316298
46 5 100.448567
47 5 102.577986
48 5 104.704674
49 5 106.828744
50 5 108.950302
51 5 111.069448
52 5 113.186275
53 5 115.300871
54 5 117.413321
55 5 119.523704
56 5 121.632093
57 5 123.738560
58 5 125.843171
59 5 127.945991
60 5 130.047079
61 5 132.146492
62 5 134.244286
63 5 136.340512
64 5 138.435220
65 5 140.528456
66 5 142.620266
67 5 144.710693
68 5 146.799778
69 5 148.887560
70 5 150.974076
71 5 153.059363
72 5 155.143455
73 5 157.226385
74 5 159.308185
75 5 161.388885
76 5 163.468513
77 5 165.547100
78 5 167.624670
79 5 169.701251
80 5 171.776867
81 5 173.851541
82 5 175.925298
83 5 177.998160
84 5 180.070147
85 5 182.141282
86 5 184.211583
87 5 186.281069
88 5 188.349761
89 5 190.417675
90 5 192.484829
91 5 194.551240
92 5 196.616924
93 5 198.681897
94 5 200.746174
95 5 202.809769
96 5 204.872698
97 5 206.934974
98 5 208.996610
99 5 211.057620
100 5 213.118015
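As an independent sanity check on the table, one can also run the algorithm on random inputs and count the comparisons directly. The sketch below is only an illustration under stated assumptions, not the program that produced the numbers above: the class name QuickSelectSimulation, the trial count, and the seed are made up, and the partition is assumed to be a Lomuto-style one with the last element as the pivot. With enough trials, the average for n=100, k=5 should land close to the 213.118015 in the last row.

```java
import java.util.Random;

class QuickSelectSimulation {
    static long comparisons; // running total of element-vs-pivot comparisons

    /** Lomuto-style partition with a[hi] as pivot; counts every comparison. */
    static int partition(int[] a, int lo, int hi) {
        int pivot = a[hi];
        int i = lo;
        for (int j = lo; j < hi; j++) {
            comparisons++;
            if (a[j] < pivot) {
                int t = a[i]; a[i] = a[j]; a[j] = t;
                i++;
            }
        }
        int t = a[i]; a[i] = a[hi]; a[hi] = t;
        return i;
    }

    /** Moves the k-th smallest element (1-based k) to index k-1. */
    static void select(int[] a, int k) {
        int lo = 0, hi = a.length - 1, target = k - 1;
        while (true) {
            int p = partition(a, lo, hi);
            if (p == target) return;
            if (p < target) lo = p + 1; else hi = p - 1;
        }
    }

    /** Average comparisons over `trials` uniformly shuffled inputs of size n. */
    static double averageComparisons(int n, int k, int trials, long seed) {
        Random rnd = new Random(seed);
        int[] a = new int[n];
        comparisons = 0;
        for (int t = 0; t < trials; t++) {
            for (int i = 0; i < n; i++) a[i] = i;
            for (int i = n - 1; i > 0; i--) { // Fisher-Yates shuffle
                int j = rnd.nextInt(i + 1);
                int tmp = a[i]; a[i] = a[j]; a[j] = tmp;
            }
            select(a, k);
        }
        return (double) comparisons / trials;
    }

    public static void main(String[] args) {
        System.out.printf("%.3f%n", averageComparisons(100, 5, 20000, 42L));
    }
}
```

The simulated value is a noisy estimate, so it will only agree with the table to within sampling error.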
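After seeing the results, I actually figured out how to prove it. The slope is clearly no more than 4 and the line roughly passes through the origin, so it is enough to show that the comparison count satisfies f(n,k) < 4n; that proves the complexity is O(n). By induction: the base case is f(1,1) = 0 < 4. Now take s >= 2 and assume f(n,k) < 4n holds for all n = 1 to s-1 and every k <= n. The pivot lands on each of the s positions with probability 1/s, so
f(s,k) = s-1 + (1/s) * (f(s-1,k-1) + f(s-2,k-2) + ... + f(s-k+1,1) + f(k,k) + f(k+1,k) + ... + f(s-1,k)).
By the induction hypothesis,
f(s-1,k-1) + f(s-2,k-2) + ... + f(s-k+1,1) <= 4(s-1) + 4(s-2) + ... + 4(s-k+1) = 2(2s-k)(k-1)
f(k,k) + f(k+1,k) + ... + f(s-1,k) <= 4k + 4(k+1) + ... + 4(s-1) = 2(s+k-1)(s-k)
and because the two sums together contain s-1 >= 1 terms, their total is strictly below the total of the right-hand sides.
So f(s,k) < s-1 + (1/s) * (2(2s-k)(k-1) + 2(s+k-1)(s-k)) = s-1 + (1/s)(2s^2 + 4sk - 4k^2 - 6s + 4k) = s-7 + 4k/s + (1/s)(2s^2 + 4sk - 4k^2).
Here 2s^2 + 4sk - 4k^2 = -4(k - s/2)^2 + 3s^2, which reaches its maximum 3s^2 at k = s/2, and 4k/s <= 4 because k <= s. Therefore
f(s,k) < s-7 + 4 + (1/s) * 3s^2 = 4s-3 < 4s.
[QED]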
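I ran several sets of parameters, and by eye f(n) never exceeded 3n, so for small k a bound like f(n) < 3n may well be provable. It cannot hold for every k, though: when k is near n/2, the average comparison count of this procedure is known to grow like (2 + 2 ln 2)n, about 3.39n.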
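Checking the ratio f(n,k)/n over all k, rather than by eye for a few parameters, is cheap. The sketch below is a hypothetical helper (the class name SelectBoundCheck and its methods are made up for illustration) that fills in the same recurrence as counts(n, k) bottom-up and reports the largest ratio it finds. Under this recurrence the ratio stays below 4, in line with the proof, but it does climb above 3 once k is near n/2; for example, f(100,50) comes out above 300.

```java
class SelectBoundCheck {
    /**
     * f[n][k] = average comparisons to select top k from n,
     * using the same recurrence as counts(n, k). Row 0 and column 0 are unused.
     */
    static double[][] table(int maxN) {
        double[][] f = new double[maxN + 1][maxN + 1];
        for (int n = 2; n <= maxN; n++) { // f[1][1] stays 0
            for (int k = 1; k <= n; k++) {
                double sum = 0.0;
                for (int i = 1; i < k; i++) {
                    sum += f[n - i][k - i]; // pivot landed before position k
                }
                for (int m = k; m < n; m++) {
                    sum += f[m][k];         // pivot landed after position k
                }
                f[n][k] = sum / n + (n - 1);
            }
        }
        return f;
    }

    /** Largest value of f[n][k] / n over 1 <= k <= n <= maxN. */
    static double maxRatio(double[][] f, int maxN) {
        double best = 0.0;
        for (int n = 1; n <= maxN; n++) {
            for (int k = 1; k <= n; k++) {
                best = Math.max(best, f[n][k] / n);
            }
        }
        return best;
    }

    public static void main(String[] args) {
        int maxN = 200;
        double[][] f = table(maxN);
        System.out.printf("max f(n,k)/n for n <= %d: %.4f%n", maxN, maxRatio(f, maxN));
        System.out.printf("f(100,50) = %.4f%n", f[100][50]);
    }
}
```

The triple loop is O(N^3), which is fine for N in the low hundreds; prefix sums along columns and diagonals would bring it down to O(N^2) if larger N were needed.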