import heapq as hq
import itertools

import numpy as np

from numba import jit, typed
from numba.core.compiler import Flags
from numba.tests.support import TestCase, CompilationCache, MemoryLeakMixin

no_pyobj_flags = Flags()
no_pyobj_flags.nrt = True


# Thin module-level wrappers around the heapq functions. The tests below
# call them directly as the pure-Python reference and also pass them to
# jit(nopython=True) to obtain the compiled counterpart.

def heapify(x):
    return hq.heapify(x)


def heappop(heap):
    return hq.heappop(heap)


def heappush(heap, item):
    return hq.heappush(heap, item)


def heappushpop(heap, item):
    return hq.heappushpop(heap, item)


def heapreplace(heap, item):
    return hq.heapreplace(heap, item)


def nsmallest(n, iterable):
    return hq.nsmallest(n, iterable)


def nlargest(n, iterable):
    return hq.nlargest(n, iterable)


class _TestHeapq(MemoryLeakMixin):
    """Shared heapq tests, parameterised over the list implementation.

    Concrete subclasses set the ``listimpl`` class attribute to the
    callable used to build every list handed to the functions under test.
    """

    def setUp(self):
        super(_TestHeapq, self).setUp()
        self.ccache = CompilationCache()
        # Fixed seed so the randomly generated inputs are reproducible.
        self.rnd = np.random.RandomState(42)

    def test_heapify_basic_sanity(self):
        # Heapify the same data with the pure-Python and the jitted
        # function and check that both leave the list in the same order.
        pyfunc = heapify
        cfunc = jit(nopython=True)(pyfunc)

        a = [1, 3, 5, 7, 9, 2, 4, 6, 8, 0]
        b = self.listimpl(a)

        pyfunc(a)
        cfunc(b)
        self.assertPreciseEqual(a, list(b))

        # includes non-finite elements
        element_pool = [3.142, -10.0, 5.5, np.nan, -np.inf, np.inf]

        # list which may contain duplicate elements
        for x in itertools.combinations_with_replacement(element_pool, 6):
            a = list(x)
            b = self.listimpl(a)

            pyfunc(a)
            cfunc(b)
            self.assertPreciseEqual(a, list(b))

        # single element list
        for element in element_pool:
            a = [element]
            b = self.listimpl(a)

            pyfunc(a)
            cfunc(b)
            self.assertPreciseEqual(a, list(b))

        # elements are tuples
        a = [(3, 33), (1, 11), (2, 22)]
        b = self.listimpl(a)

        pyfunc(a)
        cfunc(b)
        self.assertPreciseEqual(a, list(b))

    def check_invariant(self, heap):
        """Assert that every non-root item of *heap* is >= its parent.

        The parent of the item at index ``pos`` is taken to be the item
        at index ``(pos - 1) >> 1``.
        """
        for pos, item in enumerate(heap):
            if pos:
                parentpos = (pos - 1) >> 1
                self.assertTrue(heap[parentpos] <= item)

    def test_push_pop(self):
        # inspired by
        # https://github.com/python/cpython/blob/e42b7051/Lib/test/test_heapq.py
        pyfunc_heappush = heappush
        cfunc_heappush = jit(nopython=True)(pyfunc_heappush)

        pyfunc_heappop = heappop
        cfunc_heappop = jit(nopython=True)(pyfunc_heappop)

        # Both lists start with one float element.
        heap = self.listimpl([-1.0])
        data = self.listimpl([-1.0])
        self.check_invariant(heap)
        for _ in range(256):
            item = self.rnd.randn(1).item(0)
            data.append(item)
            cfunc_heappush(heap, item)
            self.check_invariant(heap)
        results = []
        while heap:
            item = cfunc_heappop(heap)
            self.check_invariant(heap)
            results.append(item)
        # Popping until empty must yield the pushed items in sorted order.
        data_sorted = data[:]
        data_sorted.sort()
        self.assertPreciseEqual(list(data_sorted), results)
        self.check_invariant(results)

    def test_heapify(self):
        # inspired by
        # https://github.com/python/cpython/blob/e42b7051/Lib/test/test_heapq.py
        pyfunc = heapify
        cfunc = jit(nopython=True)(pyfunc)

        for size in list(range(1, 30)) + [20000]:
            heap = self.listimpl(self.rnd.random_sample(size))
            cfunc(heap)
            self.check_invariant(heap)

    def test_heapify_exceptions(self):
        pyfunc = heapify
        cfunc = jit(nopython=True)(pyfunc)

        # Exceptions leak references
        self.disable_leak_check()

        with self.assertTypingError() as e:
            cfunc((1, 5, 4))

        msg = 'heap argument must be a list'
        self.assertIn(msg, str(e.exception))

        with self.assertTypingError() as e:
            cfunc(self.listimpl([1 + 1j, 2 - 3j]))

        msg = ("'<' not supported between instances "
               "of 'complex' and 'complex'")
        self.assertIn(msg, str(e.exception))

    def test_heappop_basic_sanity(self):
        pyfunc = heappop
        cfunc = jit(nopython=True)(pyfunc)

        def a_variations():
            yield [1, 3, 5, 7, 9, 2, 4, 6, 8, 0]
            yield [(3, 33), (1, 111), (2, 2222)]
            yield np.full(5, fill_value=np.nan).tolist()
            yield np.linspace(-10, -5, 100).tolist()

        for a in a_variations():
            heapify(a)
            b = self.listimpl(a)

            # `a` shrinks on every pop; the iteration count is fixed up
            # front so each original element is popped exactly once.
            for _ in range(len(a)):
                val_py = pyfunc(a)
                val_c = cfunc(b)
                self.assertPreciseEqual(a, list(b))
                self.assertPreciseEqual(val_py, val_c)

    def test_heappop_exceptions(self):
        pyfunc = heappop
        cfunc = jit(nopython=True)(pyfunc)

        # Exceptions leak references
        self.disable_leak_check()

        with self.assertTypingError() as e:
            cfunc((1, 5, 4))

        msg = 'heap argument must be a list'
        self.assertIn(msg, str(e.exception))

    def iterables(self):
        """Yield sample inputs, each built with ``self.listimpl``.

        Yields an integer list, then the same linspace array as
        generated, reversed, and shuffled in place with ``self.rnd``.
        """
        yield self.listimpl([1, 3, 5, 7, 9, 2, 4, 6, 8, 0])
        a = np.linspace(-10, 2, 23)
        yield self.listimpl(a)
        yield self.listimpl(a[::-1])
        self.rnd.shuffle(a)
        yield self.listimpl(a)

    def test_heappush_basic(self):
        pyfunc_push = heappush
        cfunc_push = jit(nopython=True)(pyfunc_push)

        pyfunc_pop = heappop
        cfunc_pop = jit(nopython=True)(pyfunc_pop)

        for iterable in self.iterables():
            # Computed before the first element is removed below.
            expected = sorted(iterable)
            heap = self.listimpl([iterable.pop(0)])  # must initialise heap

            for value in iterable:
                cfunc_push(heap, value)

            got = [cfunc_pop(heap) for _ in range(len(heap))]
            self.assertPreciseEqual(expected, got)

    def test_heappush_exceptions(self):
        pyfunc = heappush
        cfunc = jit(nopython=True)(pyfunc)

        # Exceptions leak references
        self.disable_leak_check()

        with self.assertTypingError() as e:
            cfunc((1, 5, 4), 6)

        msg = 'heap argument must be a list'
        self.assertIn(msg, str(e.exception))

        with self.assertTypingError() as e:
            cfunc(self.listimpl([1, 5, 4]), 6.0)

        msg = 'heap type must be the same as item type'
        self.assertIn(msg, str(e.exception))

    def test_nsmallest_basic(self):
        pyfunc = nsmallest
        cfunc = jit(nopython=True)(pyfunc)

        for iterable in self.iterables():
            # NOTE(review): the loop variable `n` is never used below;
            # both calls pass the literal 1, so every iteration repeats
            # the n == 1 case. Presumably `n` was intended. TODO confirm
            # the jitted implementation agrees with CPython for n <= 0
            # and n >= len(iterable) before switching.
            for n in range(-5, len(iterable) + 3):
                expected = pyfunc(1, iterable)
                got = cfunc(1, iterable)
                self.assertPreciseEqual(expected, got)

        # n is boolean
        out = cfunc(False, self.listimpl([3, 2, 1]))
        self.assertPreciseEqual(out, [])

        out = cfunc(True, self.listimpl([3, 2, 1]))
        self.assertPreciseEqual(out, [1])

        # iterable is not a list
        out = cfunc(2, (6, 5, 4, 3, 2, 1))
        self.assertPreciseEqual(out, [1, 2])

        out = cfunc(3, np.arange(6))
        self.assertPreciseEqual(out, [0, 1, 2])

    def test_nlargest_basic(self):
        pyfunc = nlargest
        cfunc = jit(nopython=True)(pyfunc)

        for iterable in self.iterables():
            # NOTE(review): the loop variable `n` is never used below;
            # both calls pass the literal 1, so every iteration repeats
            # the n == 1 case. Presumably `n` was intended. TODO confirm
            # the jitted implementation agrees with CPython for n <= 0
            # and n >= len(iterable) before switching.
            for n in range(-5, len(iterable) + 3):
                expected = pyfunc(1, iterable)
                got = cfunc(1, iterable)
                self.assertPreciseEqual(expected, got)

        # n is boolean
        out = cfunc(False, self.listimpl([3, 2, 1]))
        self.assertPreciseEqual(out, [])

        out = cfunc(True, self.listimpl([3, 2, 1]))
        self.assertPreciseEqual(out, [3])

        # iterable is not a list
        out = cfunc(2, (6, 5, 4, 3, 2, 1))
        self.assertPreciseEqual(out, [6, 5])

        out = cfunc(3, np.arange(6))
        self.assertPreciseEqual(out, [5, 4, 3])

    def _assert_typing_error(self, cfunc):
        """Check the typing errors shared by nsmallest and nlargest.

        *cfunc* is called once with a float as ``n`` and once with an
        integer as ``iterable``; each call must raise a typing error
        containing the corresponding message.
        """
        # Exceptions leak references
        self.disable_leak_check()

        with self.assertTypingError() as e:
            cfunc(2.2, self.listimpl([3, 2, 1]))

        msg = "First argument 'n' must be an integer"
        self.assertIn(msg, str(e.exception))

        with self.assertTypingError() as e:
            cfunc(2, 100)

        msg = "Second argument 'iterable' must be iterable"
        self.assertIn(msg, str(e.exception))

    def test_nsmallest_exceptions(self):
        pyfunc = nsmallest
        cfunc = jit(nopython=True)(pyfunc)
        self._assert_typing_error(cfunc)

    def test_nlargest_exceptions(self):
        pyfunc = nlargest
        cfunc = jit(nopython=True)(pyfunc)
        self._assert_typing_error(cfunc)

    def test_heapreplace_basic(self):
        pyfunc = heapreplace
        cfunc = jit(nopython=True)(pyfunc)

        a = [1, 3, 5, 7, 9, 2, 4, 6, 8, 0]

        heapify(a)
        b = self.listimpl(a)

        for item in [-4, 4, 14]:
            pyfunc(a, item)
            cfunc(b, item)
            self.assertPreciseEqual(a, list(b))

        # Float data containing a NaN and an infinity.
        a = np.linspace(-3, 13, 20)
        a[4] = np.nan
        a[-1] = np.inf
        a = a.tolist()

        heapify(a)
        b = self.listimpl(a)

        for item in [-4.0, 3.142, -np.inf, np.inf]:
            pyfunc(a, item)
            cfunc(b, item)
            self.assertPreciseEqual(a, list(b))

    def test_heapreplace_exceptions(self):
        pyfunc = heapreplace
        cfunc = jit(nopython=True)(pyfunc)

        # Exceptions leak references
        self.disable_leak_check()

        with self.assertTypingError() as e:
            cfunc((1, 5, 4), -1)

        msg = 'heap argument must be a list'
        self.assertIn(msg, str(e.exception))

        with self.assertTypingError() as e:
            cfunc(self.listimpl([1, 5, 4]), -1.0)

        msg = 'heap type must be the same as item type'
        self.assertIn(msg, str(e.exception))

    def heapiter(self, heap):
        """Yield items popped from *heap* until it raises IndexError.

        Uses the module-level pure-Python ``heappop`` wrapper, not a
        jitted function, and consumes *heap* in the process.
        """
        try:
            while True:
                yield heappop(heap)
        except IndexError:
            pass

    def test_nbest(self):
        # inspired by
        # https://github.com/python/cpython/blob/e42b7051/Lib/test/test_heapq.py
        cfunc_heapify = jit(nopython=True)(heapify)
        cfunc_heapreplace = jit(nopython=True)(heapreplace)

        data = self.rnd.choice(range(2000), 1000).tolist()
        heap = self.listimpl(data[:10])
        cfunc_heapify(heap)

        # Keep the 10 largest items seen so far in the heap.
        for item in data[10:]:
            if item > heap[0]:
                cfunc_heapreplace(heap, item)

        self.assertPreciseEqual(list(self.heapiter(list(heap))),
                                sorted(data)[-10:])

    def test_heapsort(self):
        # inspired by
        # https://github.com/python/cpython/blob/e42b7051/Lib/test/test_heapq.py
        cfunc_heapify = jit(nopython=True)(heapify)
        cfunc_heappush = jit(nopython=True)(heappush)
        cfunc_heappop = jit(nopython=True)(heappop)

        for trial in range(100):
            # Ensure consistency of typing, use float64 as it's double
            # everywhere
            values = np.arange(5, dtype=np.float64)
            data = self.listimpl(self.rnd.choice(values, 10))
            # Odd trials build the heap with heapify, even trials with
            # repeated pushes.
            if trial & 1:
                heap = data[:]
                cfunc_heapify(heap)
            else:
                heap = self.listimpl([data[0]])
                for item in data[1:]:
                    cfunc_heappush(heap, item)
            heap_sorted = [cfunc_heappop(heap) for _ in range(10)]
            self.assertPreciseEqual(heap_sorted, sorted(data))

    def test_nsmallest(self):
        # inspired by
        # https://github.com/python/cpython/blob/e42b7051/Lib/test/test_heapq.py
        pyfunc = nsmallest
        cfunc = jit(nopython=True)(pyfunc)

        data = self.listimpl(self.rnd.choice(range(2000), 1000))

        for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
            self.assertPreciseEqual(list(cfunc(n, data)), sorted(data)[:n])

    def test_nlargest(self):
        # inspired by
        # https://github.com/python/cpython/blob/e42b7051/Lib/test/test_heapq.py
        pyfunc = nlargest
        cfunc = jit(nopython=True)(pyfunc)

        data = self.listimpl(self.rnd.choice(range(2000), 1000))

        for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
            self.assertPreciseEqual(list(cfunc(n, data)),
                                    sorted(data, reverse=True)[:n])

    def test_nbest_with_pushpop(self):
        # inspired by
        # https://github.com/python/cpython/blob/e42b7051/Lib/test/test_heapq.py
        pyfunc_heappushpop = heappushpop
        cfunc_heappushpop = jit(nopython=True)(pyfunc_heappushpop)

        pyfunc_heapify = heapify
        cfunc_heapify = jit(nopython=True)(pyfunc_heapify)

        # Ensure consistency of typing, use float64 as it's double everywhere
        values = np.arange(2000, dtype=np.float64)
        data = self.listimpl(self.rnd.choice(values, 1000))
        heap = data[:10]
        cfunc_heapify(heap)

        for item in data[10:]:
            cfunc_heappushpop(heap, item)

        self.assertPreciseEqual(list(self.heapiter(list(heap))),
                                sorted(data)[-10:])

    def test_heappushpop(self):
        # inspired by
        # https://github.com/python/cpython/blob/e42b7051/Lib/test/test_heapq.py
        pyfunc = heappushpop
        cfunc = jit(nopython=True)(pyfunc)

        h = self.listimpl([1.0])
        x = cfunc(h, 10.0)
        self.assertPreciseEqual((list(h), x), ([10.0], 1.0))
        self.assertPreciseEqual(type(h[0]), float)
        self.assertPreciseEqual(type(x), float)

        h = self.listimpl([10])
        x = cfunc(h, 9)
        self.assertPreciseEqual((list(h), x), ([10], 9))

        h = self.listimpl([10])
        x = cfunc(h, 11)
        self.assertPreciseEqual((list(h), x), ([11], 10))

    def test_heappushpop_exceptions(self):
        pyfunc = heappushpop
        cfunc = jit(nopython=True)(pyfunc)

        # Exceptions leak references
        self.disable_leak_check()

        with self.assertTypingError() as e:
            cfunc((1, 5, 4), -1)

        msg = 'heap argument must be a list'
        self.assertIn(msg, str(e.exception))

        with self.assertTypingError() as e:
            cfunc(self.listimpl([1, 5, 4]), False)

        msg = 'heap type must be the same as item type'
        self.assertIn(msg, str(e.exception))


class TestHeapqReflectedList(_TestHeapq, TestCase):
    """Test heapq with reflected lists"""

    listimpl = list


class TestHeapqTypedList(_TestHeapq, TestCase):
    """Test heapq with typed lists"""

    listimpl = typed.List