https://docs.google.com/document/d/1qxA2wps0IhVRWULulQ55W4SGPMu2AE5MkBB37h8Dr58/edit#heading=h.t18sa5iilvre
第三轮:
value: 9, 8, 6, 8, 7.
label: A, A, B, A,B
输入一些node, node里有value和label, 找出K个value最大的, 但每个label不能超过M个. 比如K=3, M=2, 答案是 9, 8, 7. 8 > 7, 但是A的数量不能超过2.
A, A, B A B
先sort by value再从大到小scan,复杂度O(nlogn).
面试官又写了另外三个其它可行的复杂度, A:O(n), B: O(nlogk), C: 不记得了. 问我想再试试哪个, 我选了O(nlogk), 感觉像是heap. 一开始想错了, 面试官直接说这样不行, 很多人都这样想过但就是不对. 然后就是全程带着我做, 还让我把算法的步骤写在白板上, 说是思路更清楚. 我当时都有点儿懵,没见过这样的... 再加上还在想着上一轮bug的事儿, 精力有点不太集中. 后来想出来了, 也不难. 对每个label用heap选出M个最大的, 然后再从这些用heap选出K个最大的, 最后就是O(nlogk). 中午带着吃饭的大哥说这人好像是欧洲那边的, style确实不太一样, 有点儿严厉, 但还有点儿可爱, 挺有意思. 讨论完时间不多了, 就让写了O(nlogn)的解法, 简单问了问如何测试.
我感觉这个O(N)的算法,很可能是用quick select,对每个label都求出第m大,然后再merge成一个新的数组,求第k大
Top k的value,受label限制
第三轮:
value: 9, 8, 6, 8, 7.
label: A, A, B, A,B
输入一些node, node里有value和label, 找出K个value最大的, 但每个label不能超过M个. 比如K=3, M=2, 答案是 9, 8, 7. 8 > 7, 但是A的数量不能超过2.
A, A, B A B
先sort by value再从大到小scan,复杂度O(nlogn).
面试官又写了另外三个其它可行的复杂度, A:O(n), B: O(nlogk), C: 不记得了. 问我想再试试哪个, 我选了O(nlogk), 感觉像是heap. 一开始想错了, 面试官直接说这样不行, 很多人都这样想过但就是不对. 然后就是全程带着我做, 还让我把算法的步骤写在白板上, 说是思路更清楚. 我当时都有点儿懵,没见过这样的... 再加上还在想着上一轮bug的事儿, 精力有点不太集中. 后来想出来了, 也不难. 对每个label用heap选出M个最大的, 然后再从这些用heap选出K个最大的, 最后就是O(nlogk). 中午带着吃饭的大哥说这人好像是欧洲那边的, style确实不太一样, 有点儿严厉, 但还有点儿可爱, 挺有意思. 讨论完时间不多了, 就让写了O(nlogn)的解法, 简单问了问如何测试.
思路:
如题主描述
code
class Solution:
def topk(self, elements, k, m):
'''
:param elements: [val, label]
:param k:
:param m:
:return:
'''
# O(nlogn)
def _topk0(elements, k, m):
elements = sorted(elements, reverse=True)
ans = []
labels = {}
for val, lb in elements:
if labels.get(lb, 0) >= m:
continue
ans.append(val)
labels[lb] = labels.get(lb, 0) + 1
if len(ans) == k:
break
return ans
# O(nlogk)
# heap for each label(len == m)
# merge heap
def _topk1(elements, k, m):
from collections import defaultdict
import heapq
labels = defaultdict(list)
for val, label in elements:
pq = labels[label]
heapq.heappush(pq, val)
if len(pq) > m:
heapq.heappop(pq)
que = []
for pq in labels.values():
for e in pq:
heapq.heappush(que, e)
if len(que) > k:
heapq.heappop(que)
return que
# O(n)
# 用quick select,对每个label都求出第m大,然后再merge成一个新的数组,求第k大
def _topk2(elements, k, m):
def quick_select(arr, lo, hi, p):
i, j, k = lo, lo, hi - 1
pi = arr[p]
while j <= k:
if arr[j] == pi:
j += 1
elif arr[j] > pi:
arr[i], arr[j] = arr[j], arr[i]
i += 1
j += 1
else:
arr[j], arr[k] = arr[k], arr[j]
k -= 1
return i
def kth(nums, m):
lo, hi = 0, len(nums)
while lo < hi:
mid = (lo + hi) >> 1
mid = quick_select(nums, lo, hi, mid)
if mid == m - 1:
break
elif mid < m - 1:
lo = mid + 1
else:
hi = mid
from collections import defaultdict
labels = defaultdict(list)
for val, label in elements:
labels[label].append(val)
ans = []
for nums in labels.values():
kth(nums, m)
ans.extend(nums[:m])
kth(ans, k)
return ans[:k]
ans0 = _topk0(elements, k, m)
ans1 = _topk1(elements, k, m)
ans2 = _topk2(elements, k, m)
assert( sorted(ans0) == sorted(ans1) == sorted(ans2))
import random
for _ in range(100):
elements = []
for i in range(30):
elements.append([random.randint(1, 100), random.randint(0, 4)])
k = random.randint(5, 15)
m = random.randint(2, 5)
so = Solution()
so.topk(elements, k, m)
|
我感觉这个O(N)的算法,很可能是用quick select,对每个label都求出第m大,然后再merge成一个新的数组,求第k大