快速排序 --- 基础实践篇

2019-03-05  本文已影响0人  张虾米试错

本篇主要介绍快速排序的基本代码。

大纲

  1. 普通版的快速排序
  2. 改进版的快速排序
  3. 快速排序应用----求前K个最大的数

1. 普通版的快速排序

# -*- coding:utf-8 -*-

def partition(nums, left, right):
    low = left
    high = right
    privot = nums[low]
    while low < high:
        while low < high and nums[high] >= privot:
            high -= 1
        if low < high:
            nums[low], nums[high] = nums[high], nums[low]
        while low < high and nums[low] <= privot:
            low += 1
        if low < high:
            nums[high], nums[low] = nums[low], nums[high]
    #nums[low] = privot
    return low


def partition1(nums, left, right):
    low = left
    high = right
    privot = nums[low]
    while low < high:
        while low < high and nums[high] >= privot:
            high -= 1
        while low < high and nums[low] <= privot:
            low += 1

        if low < high:
            nums[low], nums[high] = nums[high], nums[low]

    nums[left], nums[low] = nums[low], nums[left]
    return low


def quick_sort(nums, left, right):
    if nums is None or len(nums) == 0:
        return 
    # 这里的判断条件已经是判断是否调出递归
    if left < right: 
        ind = partition(nums, left, right)
        quick_sort(nums, left, ind-1)#ind-1可能会更好
        quick_sort(nums, ind+1, right)

2. 改进版的快速排序

普通版本存在的问题:

Q1: 数组长度比较短

研究表明长度在5~25的数组,快排不如插入排序。所以可以对数组长度进行判断,然后再选择用哪一种排序方法。

Q2: pivot选择

当pivot选择不当,左右两边不平衡,这样导致快排的时间复杂度为o(n^2)。有一种方法是选取随机数作为枢轴,但是随机数的生成本身是一种代价,根本减少不了算法其余部分的平均运行时间。因此,我们可以使用左端,右端和中心的中值做为枢轴元。经验得知,选取左端,右端,中心元素的中值会减少了快排大约 14%的比较。

Q3: 数组中重复元素太多

大致的思想就是将数组分为三段:小于pviot,等于pviot,大于pviot

方案1
def partition(nums, left, right):
    q = left
    pivot = nums[right]
    # 这一段代码是找出第一个大于pviot的index
    # 右界范围也可以是right+1,不过因为pviot是nums[right]所以该代码上没有意义;但是如果pviot不是nums[right]则应该换成right+1
    for i in xrange(left, right):
        if nums[i] < pivot:
            nums[q], nums[i] = nums[i], nums[q]
            q += 1
    # q表示第一个大于等于pviot的index,因此一定会小于或等于right
    # 这一段代码跳过等于pviot的数
    t = q
    while t < right and nums[t] == pivot:
        t += 1
    # 这一段代码将大于pviot的数移到后面
    i = right
    while i > t - 1:
        if nums[i] == pivot:
            nums[t], nums[i] = nums[i], nums[t]
            t += 1
        i -= 1
    return q, t

def quick_sort(nums, left, right):
    if left < right:
        q, t = partition(nums, left, right)
        # q是第一个等于pivot的index,而t是第一个大于pivot的index
        quick_sort(nums, left, q-1)
        # 当q-1变成q后,竟然会陷入死循环???
        quick_sort(nums, t, right)
        
if __name__ == "__main__":
    nums = [6, 4, 2, 6, 8, 9, 5, 6, 7, 6]
    quick_sort(nums, 0, len(nums)-1)
    print nums

由于可能存在重复的元素,因此无法像普通的快速排序一样进行值对换;而这种遍历的方法与普通值对换的方法相比,可能就是多了一些交换的代价。其本质是一样的,都是把小于pviot的值换到左边,大于pviot的值换到右边。

好像确实如此哦,因为pviot是当前right的值,如果一直是q的话,那么pviot一直是最开始的pviot,那么前半段就会陷入死循环。

方案2
import random
class Solution:
    def quickSortArray(self, nums):
        def partition(nums, start, end):
            rand = random.randint(start, end)
            nums[rand], nums[start] = nums[start], nums[rand]
            key = nums[start]
            left, right = start, end 
            while left < right:
                while left < right and nums[right] >= key:
                    right -= 1
                nums[right], nums[left] = nums[left], nums[right]
                while left < right and nums[left]<=key:
                    left += 1
                nums[right], nums[left] = nums[left], nums[right]
            return left
            
        def sort(nums, left, right):
            if left < right:
                mid = partition(nums, left, right)
                sort(nums, left, mid-1)
                sort(nums, mid+1, right)
            
        n = len(nums)
        if n <=1:
            return nums
        #random.shuffle(nums)
        sort(nums, 0, n-1)
        return nums

if __name__ == "__main__":
    s = Solution()
    nums = [5,2,1,1,2,3,3,2,1]
    ans = s.quickSortArray(nums)
    print (ans)           

3. 快速排序应用----求前K个最大的数

def partition(nums, left, right):
    pviot = nums[right]
    p = left
    for i in xrange(left, right):
        # 由于本题是找第K大的数,因此这里是>
        if nums[i] > pviot:
            nums[i], nums[p] = nums[p], nums[i]
            p += 1
    nums[p], nums[right] = nums[right], nums[p]
    return p

def findKthLargest2(nums, k):
    if len(nums) == 1:
        return nums[0]
    left = 0
    right = len(nums) - 1
    p = partition(nums, left, right)
    while left < right:
        if k == p + 1:
            return nums[p]
        elif k < p + 1:
            return findKthLargest2(nums[left:p], k)
        else:
            return findKthLargest2(nums[p+1:right+1], k-p-1)

def findKthLargest3(nums, k):
    if len(nums) == 1:
        return nums[0]
    size = len(nums)
    left = 0
    right = len(nums) - 1
    while left <= right:
        p = partition(nums, left, right)
        if k == p + 1
            return nums[p]
        elif k > p + 1:
            left = p + 1
        else:
            right = p - 1


if __name__ == "__main__":
    nums = [3, 2, 1, 5, 6, 4]
    k = 2
    print findKthLargest2(nums, k)

参考资料

上一篇 下一篇

猜你喜欢

热点阅读