【Python 秘籍】找到最大或最小的 N 个元素

问题

我们想在某个集合中找出最大或最小的 N 个元素。

解决方案

heapq 模块中有两个函数——nlargest()和 nsmallest()——它们正是我们需要的。例如:

import heapq

nums = [1, 8, 2, 23, 7, -4, 18, 23, 42, 37, 2]
print(heapq.nlargest(3, nums))  # 打印结果[42, 37, 23]
print(heapq.nsmallest(3, nums))  # 打印结果[-4, 1, 2]

这两个函数都可以接受一个参数 key,从而允许它们工作在更加复杂的数据结构之上。例如:

portfolio = [
    {'name': 'IBM', 'shares': 100, 'price': 135.79},
    {'name': 'AAPL', 'shares': 50, 'price': 208.97},
    {'name': 'FB', 'shares': 200, 'price': 188.45},
    {'name': 'HPQ', 'shares': 35, 'price': 19.54},
    {'name': 'INTC', 'shares': 45, 'price': 46.84},
    {'name': 'AMZN', 'shares': 75, 'price': 1824.34}
]

cheap = heapq.nsmallest(3, portfolio, key=lambda s: s['price'])
expensive = heapq.nlargest(3, portfolio, key=lambda s: s['price'])

高级用法

如果正在寻找最大或最小的 N 个元素,且这个 N 相对集合来说很小,那么下面这些函数可以提供更好的性能。这些函数首先会在底层将数据转化成列表,且元素会以堆的顺序排列。例如:

>>> nums = {1, 8, 2, 23, 7, -4, 18, 23, 42, 37, 2}
>>> import heapq
>>> heap = list(nums)
>>> heapq.heapify(heap)  # 以O(N)复杂度将一个列表原地修改为堆结构
>>> heap
[-4, 1, 18, 2, 8, 42, 37, 23, 7]
>>>

堆最重要的特性就是 heap[0] 总是最小或者最大的那个元素(heapq 模块实现的是最小堆,所以 heap[0] 为最小元素)。此外,接下来的元素可依次通过 heapq.heappop()方法轻松找到。该方法会将第一个元素(最小的)弹出,然后以第二小的元素取而代之(这个操作的复杂度是 O(logN),N 代表堆的大小)。例如,要找到第 3 小的元素,可以这样做:

>>> heapq.heappop(heap)
-4
>>> heapq.heappop(heap)
1
>>> heapq.heappop(heap)
2

当所要找的元素数量相对较小时,函数 nlargest()和 nsmallest() 才是最适用的。如果只是简单地想找到最小或最大的元素(N=1 时),那么用 min()和 max() 会更加快。同样,如果 N 和集合本身的大小差不多大,通常更快的方法是先对容器对象排序,然后做切片操作(例如,使用 sorted(items)[:N] 或者 sorted(items)[-N:])。应该要注意的是,nlargest()和 nsmallest() 的实际实现会根据使用它们的方式而有所不同,可能会相应作出一些优化措施(比如,当 N 的大小同输入大小很接近时,就会采用排序的方法)。

通常在优秀的算法和数据结构相关的书籍里都能找到堆数据结构的实现方法。另外,可以参考 heapq 模块,深入了解其底层实现的细节。

补充说明

当这个集合非常大,超过可用内存空间时,可以使用一个固定大小的堆来解决问题。例如:我们需要求大量数据中的最大的 k 个元素,我们可以用最小堆先迭代前 k 个元素来建立一个最小堆,之后的每一个元素如果小于堆顶则跳过,否则替换堆顶元素并重新调整堆。

import heapq

class TopK:
    def __init__(self, iter_obj, k):
        self.minheap = []
        self.maxsize = k
        self.iter_obj = iter_obj
		
    def push(self, item):
        if len(self.minheap) >= self.maxsize:
            min_item = self.minheap[0]
            if item > min_item:
                heapq.heapreplace(self.minheap, item)
        else:
            heapq.heappush(self.minheap, item)
			
    def get_topk(self):
        for item in self.iter_obj:
            self.push(item)
        return self.minheap

对于一个直接 list 化会报 MemoryError 错误的迭代器,我们可以如下调用:

print(sorted(TopK(range(100000001), 100).get_topk()))