import copy
import itertools
import math
import random
import sys

import numpy as np

from numba.core.compiler import compile_isolated, Flags
from numba import jit, njit
from numba.core import types, utils, errors
import unittest
from numba import testing
from numba.tests.support import TestCase, MemoryLeakMixin, tag

from numba.misc.quicksort import make_py_quicksort, make_jit_quicksort
from numba.misc.mergesort import make_jit_mergesort
from numba.misc.timsort import make_py_timsort, make_jit_timsort, MergeRun


def make_temp_list(keys, n):
    """Return a list of *n* items, each equal to ``keys[0]``."""
    return [keys[0]] * n


def make_temp_array(keys, n):
    """Return an uninitialized array of *n* items with the dtype of *keys*."""
    return np.empty(n, keys.dtype)


# Timsort implementations under test: pure Python and JIT flavours, each
# built with either the list-based or the array-based temp-storage factory.
py_list_timsort = make_py_timsort(make_temp_list)
py_array_timsort = make_py_timsort(make_temp_array)
jit_list_timsort = make_jit_timsort(make_temp_list)
jit_array_timsort = make_jit_timsort(make_temp_array)

# Quicksort implementations under test (default arguments).
py_quicksort = make_py_quicksort()
jit_quicksort = make_jit_quicksort()


# ---------------------------------------------------------------------------
# Use cases. The tests below run each of these both as plain Python and
# through jit(nopython=True), and compare the results.
# ---------------------------------------------------------------------------

def sort_usecase(val):
    val.sort()


def argsort_usecase(val):
    return val.argsort()


def argsort_kind_usecase(val, is_stable=False):
    if is_stable:
        return val.argsort(kind='mergesort')
    else:
        return val.argsort(kind='quicksort')


def sorted_usecase(val):
    return sorted(val)


def sorted_reverse_usecase(val, b):
    return sorted(val, reverse=b)


def np_sort_usecase(val):
    return np.sort(val)


def np_argsort_usecase(val):
    return np.argsort(val)


def np_argsort_kind_usecase(val, is_stable=False):
    if is_stable:
        return np.argsort(val, kind='mergesort')
    else:
        return np.argsort(val, kind='quicksort')


def list_sort_usecase(n):
    # Returns (original list, sorted copy) for n seeded random floats.
    np.random.seed(42)
    l = []
    for i in range(n):
        l.append(np.random.random())
    ll = l[:]
    ll.sort()
    return l, ll


def list_sort_reverse_usecase(n, b):
    # Same as list_sort_usecase(), but the copy is sorted with reverse=b.
    np.random.seed(42)
    l = []
    for i in range(n):
        l.append(np.random.random())
    ll = l[:]
    ll.sort(reverse=b)
    return l, ll


class BaseSortingTest(object):
    """
    Shared input generators and assertions for the sorting test classes.

    Subclasses are expected to also derive from TestCase (the assert*
    methods are used here) and to provide ``array_factory``.
    """

    def random_list(self, n, offset=10):
        """Return the integers [offset, offset + n) in a seeded random order."""
        random.seed(42)
        l = list(range(offset, offset + n))
        random.shuffle(l)
        return l

    def sorted_list(self, n, offset=10):
        """Return the integers [offset, offset + n) in ascending order."""
        return list(range(offset, offset + n))

    def revsorted_list(self, n, offset=10):
        """Return the integers [offset, offset + n) in descending order."""
        return list(range(offset, offset + n))[::-1]

    def initially_sorted_list(self, n, m=None, offset=10):
        """
        Return *n* integers: *m* ascending ones (default ``n // 2``)
        followed by ``n - m`` shuffled ones.
        """
        if m is None:
            m = n // 2
        l = self.sorted_list(m, offset)
        l += self.random_list(n - m, offset=l[-1] + offset)
        return l

    def duprandom_list(self, n, factor=None, offset=10):
        """
        Return *n* integers containing duplicates, in a seeded random order.

        The values ``offset .. offset + n // factor`` are repeated
        ``factor + 1`` times and truncated to *n* items; *factor* defaults
        to ``int(sqrt(n))``.
        """
        random.seed(42)
        if factor is None:
            factor = int(math.sqrt(n))
        l = (list(range(offset, offset + (n // factor) + 1))
             * (factor + 1))[:n]
        assert len(l) == n
        random.shuffle(l)
        return l

    def dupsorted_list(self, n, factor=None, offset=10):
        """
        Return *n* integers containing duplicates, in ascending order.

        Built like duprandom_list(), then sorted instead of shuffled.
        """
        if factor is None:
            factor = int(math.sqrt(n))
        l = (list(range(offset, offset + (n // factor) + 1))
             * (factor + 1))[:n]
        assert len(l) == n, (len(l), n)
        l.sort()
        return l

    def assertSorted(self, orig, result):
        """Assert that *result* holds the items of *orig* in sorted order."""
        self.assertEqual(len(result), len(orig))
        # sorted() returns a list, so make sure we compare to another list
        self.assertEqual(list(result), sorted(orig))

    def assertSortedValues(self, orig, orig_values, result, result_values):
        """
        Assert that (*result*, *result_values*) is a stable sort by key of
        (*orig*, *orig_values*).
        """
        self.assertEqual(len(result), len(orig))
        self.assertEqual(list(result), sorted(orig))
        zip_sorted = sorted(zip(orig, orig_values), key=lambda x: x[0])
        zip_result = list(zip(result, result_values))
        self.assertEqual(zip_sorted, zip_result)
        # Check stability
        for i in range(len(zip_result) - 1):
            (k1, v1), (k2, v2) = zip_result[i], zip_result[i + 1]
            if k1 == k2:
                # Assuming values are unique, which is enforced by the tests
                self.assertLess(orig_values.index(v1), orig_values.index(v2))

    def fibo(self):
        """Yield the Fibonacci sequence 1, 1, 2, 3, 5, ... indefinitely."""
        a = 1
        b = 1
        while True:
            yield a
            a, b = b, a + b

    def make_sample_sorted_lists(self, n):
        """Return already-sorted sample lists of length *n*."""
        lists = []
        for offset in (20, 120):
            lists.append(self.sorted_list(n, offset))
            # Pass offset by keyword: the second positional parameter of
            # dupsorted_list() is *factor*, not *offset*.
            lists.append(self.dupsorted_list(n, offset=offset))
        return lists

    def make_sample_lists(self, n):
        """Return sample lists of length *n* with various orderings."""
        lists = []
        for offset in (20, 120):
            lists.append(self.sorted_list(n, offset))
            # Pass offset by keyword: the second positional parameter of
            # dupsorted_list() / duprandom_list() is *factor*, not *offset*.
            lists.append(self.dupsorted_list(n, offset=offset))
            lists.append(self.revsorted_list(n, offset))
            lists.append(self.duprandom_list(n, offset=offset))
        return lists


class BaseTimsortTest(BaseSortingTest):
    """
    Tests for the individual timsort routines.

    Subclasses provide ``timsort`` (the implementation under test) and
    ``array_factory`` (builds the container to sort from a list).
    """

    def merge_init(self, keys):
        """Create a merge state for *keys* with the timsort under test."""
        f = self.timsort.merge_init
        return f(keys)

    def test_binarysort(self):
        n = 20

        def check(l, n, start=0):
            res = self.array_factory(l)
            f(res, res, 0, n, start)
            self.assertSorted(l, res)

        f = self.timsort.binarysort
        l = self.sorted_list(n)
        check(l, n)
        check(l, n, n//2)
        l = self.revsorted_list(n)
        check(l, n)
        l = self.initially_sorted_list(n, n//2)
        check(l, n)
        check(l, n, n//2)
        l = self.revsorted_list(n)
        check(l, n)
        l = self.random_list(n)
        check(l, n)
        l = self.duprandom_list(n)
        check(l, n)

    def test_binarysort_with_values(self):
        n = 20
        # Unique values, as required by assertSortedValues()
        v = list(range(100, 100+n))

        def check(l, n, start=0):
            res = self.array_factory(l)
            res_v = self.array_factory(v)
            f(res, res_v, 0, n, start)
            self.assertSortedValues(l, v, res, res_v)

        f = self.timsort.binarysort
        l = self.sorted_list(n)
        check(l, n)
        check(l, n, n//2)
        l = self.revsorted_list(n)
        check(l, n)
        l = self.initially_sorted_list(n, n//2)
        check(l, n)
        check(l, n, n//2)
        l = self.revsorted_list(n)
        check(l, n)
        l = self.random_list(n)
        check(l, n)
        l = self.duprandom_list(n)
        check(l, n)

    def test_count_run(self):
        n = 16
        f = self.timsort.count_run

        def check(l, lo, hi):
            n, desc = f(self.array_factory(l), lo, hi)
            # Fully check invariants
            if desc:
                for k in range(lo, lo + n - 1):
                    a, b = l[k], l[k + 1]
                    self.assertGreater(a, b)
                if lo + n < hi:
                    self.assertLessEqual(l[lo + n - 1], l[lo + n])
            else:
                for k in range(lo, lo + n - 1):
                    a, b = l[k], l[k + 1]
                    self.assertLessEqual(a, b)
                if lo + n < hi:
                    self.assertGreater(l[lo + n - 1], l[lo + n], l)

        l = self.sorted_list(n, offset=100)
        check(l, 0, n)
        check(l, 1, n - 1)
        check(l, 1, 2)
        l = self.revsorted_list(n, offset=100)
        check(l, 0, n)
        check(l, 1, n - 1)
        check(l, 1, 2)
        l = self.random_list(n, offset=100)
        for i in range(len(l) - 1):
            check(l, i, n)
        l = self.duprandom_list(n, offset=100)
        for i in range(len(l) - 1):
            check(l, i, n)

    def test_gallop_left(self):
        n = 20
        f = self.timsort.gallop_left

        def check(l, key, start, stop, hint):
            k = f(key, l, start, stop, hint)
            # Fully check invariants
            self.assertGreaterEqual(k, start)
            self.assertLessEqual(k, stop)
            if k > start:
                self.assertLess(l[k - 1], key)
            if k < stop:
                self.assertGreaterEqual(l[k], key)

        def check_all_hints(l, key, start, stop):
            for hint in range(start, stop):
                check(l, key, start, stop, hint)

        def check_sorted_list(l):
            l = self.array_factory(l)
            for key in (l[5], l[15], l[0], -1000, l[-1], 1000):
                check_all_hints(l, key, 0, n)
                check_all_hints(l, key, 1, n - 1)
                check_all_hints(l, key, 8, n - 8)

        l = self.sorted_list(n, offset=100)
        check_sorted_list(l)
        l = self.dupsorted_list(n, offset=100)
        check_sorted_list(l)

    def test_gallop_right(self):
        n = 20
        f = self.timsort.gallop_right

        def check(l, key, start, stop, hint):
            k = f(key, l, start, stop, hint)
            # Fully check invariants
            self.assertGreaterEqual(k, start)
            self.assertLessEqual(k, stop)
            if k > start:
                self.assertLessEqual(l[k - 1], key)
            if k < stop:
                self.assertGreater(l[k], key)

        def check_all_hints(l, key, start, stop):
            for hint in range(start, stop):
                check(l, key, start, stop, hint)

        def check_sorted_list(l):
            l = self.array_factory(l)
            for key in (l[5], l[15], l[0], -1000, l[-1], 1000):
                check_all_hints(l, key, 0, n)
                check_all_hints(l, key, 1, n - 1)
                check_all_hints(l, key, 8, n - 8)

        l = self.sorted_list(n, offset=100)
        check_sorted_list(l)
        l = self.dupsorted_list(n, offset=100)
        check_sorted_list(l)

    def test_merge_compute_minrun(self):
        f = self.timsort.merge_compute_minrun

        for i in range(0, 64):
            self.assertEqual(f(i), i)
        for i in range(6, 63):
            if 2**i > sys.maxsize:
                break
            self.assertEqual(f(2**i), 32)
        for i in self.fibo():
            if i < 64:
                continue
            if i >= sys.maxsize:
                break
            k = f(i)
            self.assertGreaterEqual(k, 32)
            self.assertLessEqual(k, 64)
            if i > 500:
                # i/k is close to, but strictly less than, an exact power of 2
                quot = i // k
                p = 2 ** utils.bit_length(quot)
                self.assertLess(quot, p)
                self.assertGreaterEqual(quot, 0.9 * p)

    def check_merge_lo_hi(self, func, a, b):
        """
        Merge the sorted lists *a* and *b* with *func* (merge_lo or
        merge_hi) and check both the keys and the returned merge state.
        """
        na = len(a)
        nb = len(b)

        # Add sentinels at start and end, to check they weren't moved
        orig_keys = [42] + a + b + [-42]
        keys = self.array_factory(orig_keys)
        ms = self.merge_init(keys)
        ssa = 1
        ssb = ssa + na

        new_ms = func(ms, keys, keys, ssa, na, ssb, nb)
        self.assertEqual(keys[0], orig_keys[0])
        self.assertEqual(keys[-1], orig_keys[-1])
        self.assertSorted(orig_keys[1:-1], keys[1:-1])
        # Check the MergeState result
        self.assertGreaterEqual(len(new_ms.keys), len(ms.keys))
        self.assertGreaterEqual(len(new_ms.values), len(ms.values))
        self.assertIs(new_ms.pending, ms.pending)
        self.assertGreaterEqual(new_ms.min_gallop, 1)

    def test_merge_lo_hi(self):
        f_lo = self.timsort.merge_lo
        f_hi = self.timsort.merge_hi

        # The larger sizes exercise galloping
        for (na, nb) in [(12, 16), (40, 40), (100, 110), (1000, 1100)]:
            for a, b in itertools.product(self.make_sample_sorted_lists(na),
                                          self.make_sample_sorted_lists(nb)):
                self.check_merge_lo_hi(f_lo, a, b)
                self.check_merge_lo_hi(f_hi, b, a)

    def check_merge_at(self, a, b):
        """
        Push the sorted runs *a* and *b* on the merge stack and check that
        merge_at() merges them, both when they are the last two runs and
        when another run follows them.
        """
        f = self.timsort.merge_at
        # Prepare the array to be sorted
        na = len(a)
        nb = len(b)
        # Add sentinels at start and end, to check they weren't moved
        orig_keys = [42] + a + b + [-42]
        ssa = 1
        ssb = ssa + na

        stack_sentinel = MergeRun(-42, -42)

        def run_merge_at(ms, keys, i):
            new_ms = f(ms, keys, keys, i)
            self.assertEqual(keys[0], orig_keys[0])
            self.assertEqual(keys[-1], orig_keys[-1])
            self.assertSorted(orig_keys[1:-1], keys[1:-1])
            # Check stack state
            self.assertIs(new_ms.pending, ms.pending)
            self.assertEqual(ms.pending[i], (ssa, na + nb))
            self.assertEqual(ms.pending[0], stack_sentinel)
            return new_ms

        # First check with i == len(stack) - 2
        keys = self.array_factory(orig_keys)
        ms = self.merge_init(keys)
        # Push sentinel on stack, to check it wasn't touched
        ms = self.timsort.merge_append(ms, stack_sentinel)
        i = ms.n
        ms = self.timsort.merge_append(ms, MergeRun(ssa, na))
        ms = self.timsort.merge_append(ms, MergeRun(ssb, nb))
        ms = run_merge_at(ms, keys, i)
        self.assertEqual(ms.n, i + 1)

        # Now check with i == len(stack) - 3
        keys = self.array_factory(orig_keys)
        ms = self.merge_init(keys)
        # Push sentinel on stack, to check it wasn't touched
        ms = self.timsort.merge_append(ms, stack_sentinel)
        i = ms.n
        ms = self.timsort.merge_append(ms, MergeRun(ssa, na))
        ms = self.timsort.merge_append(ms, MergeRun(ssb, nb))
        # A last run (trivial here)
        last_run = MergeRun(ssb + nb, 1)
        ms = self.timsort.merge_append(ms, last_run)
        ms = run_merge_at(ms, keys, i)
        self.assertEqual(ms.n, i + 2)
        self.assertEqual(ms.pending[ms.n - 1], last_run)

    def test_merge_at(self):
        # The larger sizes exercise galloping
        for (na, nb) in [(12, 16), (40, 40), (100, 110), (500, 510)]:
            for a, b in itertools.product(self.make_sample_sorted_lists(na),
                                          self.make_sample_sorted_lists(nb)):
                self.check_merge_at(a, b)
                self.check_merge_at(b, a)

    def test_merge_force_collapse(self):
        f = self.timsort.merge_force_collapse

        # Test with runs of ascending sizes, then descending sizes
        sizes_list = [(8, 10, 15, 20)]
        sizes_list.append(sizes_list[0][::-1])

        for sizes in sizes_list:
            for chunks in itertools.product(
                    *(self.make_sample_sorted_lists(n) for n in sizes)):
                # Create runs of the given sizes
                orig_keys = sum(chunks, [])
                keys = self.array_factory(orig_keys)
                ms = self.merge_init(keys)
                pos = 0
                for c in chunks:
                    ms = self.timsort.merge_append(ms, MergeRun(pos, len(c)))
                    pos += len(c)
                # Sanity check
                self.assertEqual(sum(ms.pending[ms.n - 1]), len(keys))
                # Now merge the runs
                ms = f(ms, keys, keys)
                # Remaining run is the whole list
                self.assertEqual(ms.n, 1)
                self.assertEqual(ms.pending[0], MergeRun(0, len(keys)))
                # The list is now sorted
                self.assertSorted(orig_keys, keys)

    def test_run_timsort(self):
        f = self.timsort.run_timsort

        for size_factor in (1, 10):
            # Make lists to be sorted from three chunks of different kinds.
            sizes = (15, 30, 20)

            all_lists = [self.make_sample_lists(n * size_factor)
                         for n in sizes]
            for chunks in itertools.product(*all_lists):
                orig_keys = sum(chunks, [])
                keys = self.array_factory(orig_keys)
                f(keys)
                # The list is now sorted
                self.assertSorted(orig_keys, keys)

    def test_run_timsort_with_values(self):
        # Run timsort, but also with a values array
        f = self.timsort.run_timsort_with_values

        for size_factor in (1, 5):
            chunk_size = 80 * size_factor
            a = self.dupsorted_list(chunk_size)
            b = self.duprandom_list(chunk_size)
            c = self.revsorted_list(chunk_size)
            orig_keys = a + b + c
            orig_values = list(range(1000, 1000 + len(orig_keys)))

            keys = self.array_factory(orig_keys)
            values = self.array_factory(orig_values)
            f(keys, values)
            # This checks sort stability
            self.assertSortedValues(orig_keys, orig_values, keys, values)


class TestTimsortPurePython(BaseTimsortTest, TestCase):
    """Pure Python timsort, sorting lists."""

    timsort = py_list_timsort

    # Much faster than a Numpy array in pure Python
    array_factory = list


class TestTimsortArraysPurePython(BaseTimsortTest, TestCase):
    """Pure Python timsort, sorting int32 arrays."""

    timsort = py_array_timsort

    def array_factory(self, lst):
        return np.array(lst, dtype=np.int32)


class JITTimsortMixin(object):
    """Mixin selecting the JIT-compiled array timsort."""

    timsort = jit_array_timsort

    # Setting these to None overrides the inherited test methods of the
    # same name, so they are not run for the JIT variant.
    test_merge_at = None
    test_merge_force_collapse = None

    # NOTE: the mutable default for *_cache* is deliberate: it memoizes the
    # compiled wrappers across calls and instances.
    def wrap_with_mergestate(self, timsort, func, _cache={}):
        """
        Wrap *func* into another compiled function inserting a runtime-created
        mergestate as the first function argument.
        """
        key = timsort, func
        if key in _cache:
            return _cache[key]

        merge_init = timsort.merge_init

        @timsort.compile
        def wrapper(keys, values, *args):
            ms = merge_init(keys)
            res = func(ms, keys, values, *args)
            return res

        _cache[key] = wrapper
        return wrapper


class TestTimsortArrays(JITTimsortMixin, BaseTimsortTest, TestCase):
    """JIT-compiled timsort, sorting int32 arrays."""

    def array_factory(self, lst):
        return np.array(lst, dtype=np.int32)

    def check_merge_lo_hi(self, func, a, b):
        """
        Variant of the base check that creates the merge state inside a
        compiled wrapper; only the resulting keys are checked.
        """
        na = len(a)
        nb = len(b)

        func = self.wrap_with_mergestate(self.timsort, func)

        # Add sentinels at start and end, to check they weren't moved
        orig_keys = [42] + a + b + [-42]
        keys = self.array_factory(orig_keys)
        ssa = 1
        ssb = ssa + na

        func(keys, keys, ssa, na, ssb, nb)
        self.assertEqual(keys[0], orig_keys[0])
        self.assertEqual(keys[-1], orig_keys[-1])
        self.assertSorted(orig_keys[1:-1], keys[1:-1])


class BaseQuicksortTest(BaseSortingTest):
    """
    Tests for the individual quicksort routines.

    Subclasses provide ``quicksort``, ``make_quicksort`` and
    ``array_factory``.
    """

    # NOTE these tests assume a non-argsort quicksort.

    def test_insertion_sort(self):
        n = 20

        def check(l, n):
            # Sentinels at both ends, to check they weren't moved
            res = self.array_factory([9999] + l + [-9999])
            f(res, res, 1, n)
            self.assertEqual(res[0], 9999)
            self.assertEqual(res[-1], -9999)
            self.assertSorted(l, res[1:-1])

        f = self.quicksort.insertion_sort
        l = self.sorted_list(n)
        check(l, n)
        l = self.revsorted_list(n)
        check(l, n)
        l = self.initially_sorted_list(n, n//2)
        check(l, n)
        l = self.revsorted_list(n)
        check(l, n)
        l = self.random_list(n)
        check(l, n)
        l = self.duprandom_list(n)
        check(l, n)

    def test_partition(self):
        n = 20

        def check(l, n):
            # Sentinels at both ends, to check they weren't moved
            res = self.array_factory([9999] + l + [-9999])
            index = f(res, res, 1, n)
            self.assertEqual(res[0], 9999)
            self.assertEqual(res[-1], -9999)
            pivot = res[index]
            for i in range(1, index):
                self.assertLessEqual(res[i], pivot)
            # NOTE(review): the data occupies indices 1..n, so this range
            # leaves res[n] unchecked -- confirm whether that is intended.
            for i in range(index + 1, n):
                self.assertGreaterEqual(res[i], pivot)

        f = self.quicksort.partition
        l = self.sorted_list(n)
        check(l, n)
        l = self.revsorted_list(n)
        check(l, n)
        l = self.initially_sorted_list(n, n//2)
        check(l, n)
        l = self.revsorted_list(n)
        check(l, n)
        l = self.random_list(n)
        check(l, n)
        l = self.duprandom_list(n)
        check(l, n)

    def test_partition3(self):
        # Test the unused partition3() function
        n = 20

        def check(l, n):
            # Sentinels at both ends, to check they weren't moved
            res = self.array_factory([9999] + l + [-9999])
            lt, gt = f(res, 1, n)
            self.assertEqual(res[0], 9999)
            self.assertEqual(res[-1], -9999)
            pivot = res[lt]
            for i in range(1, lt):
                self.assertLessEqual(res[i], pivot)
            for i in range(lt, gt + 1):
                self.assertEqual(res[i], pivot)
            for i in range(gt + 1, n):
                self.assertGreater(res[i], pivot)

        f = self.quicksort.partition3
        l = self.sorted_list(n)
        check(l, n)
        l = self.revsorted_list(n)
        check(l, n)
        l = self.initially_sorted_list(n, n//2)
        check(l, n)
        l = self.revsorted_list(n)
        check(l, n)
        l = self.random_list(n)
        check(l, n)
        l = self.duprandom_list(n)
        check(l, n)

    def test_run_quicksort(self):
        f = self.quicksort.run_quicksort

        for size_factor in (1, 5):
            # Make lists to be sorted from two chunks of different kinds.
            sizes = (15, 20)

            all_lists = [self.make_sample_lists(n * size_factor)
                         for n in sizes]
            for chunks in itertools.product(*all_lists):
                orig_keys = sum(chunks, [])
                keys = self.array_factory(orig_keys)
                f(keys)
                # The list is now sorted
                self.assertSorted(orig_keys, keys)

    def test_run_quicksort_lt(self):
        # Reversed comparison: the sort should produce descending order
        def lt(a, b):
            return a > b

        f = self.make_quicksort(lt=lt).run_quicksort

        for size_factor in (1, 5):
            # Make lists to be sorted from two chunks of different kinds.
            sizes = (15, 20)

            all_lists = [self.make_sample_lists(n * size_factor)
                         for n in sizes]
            for chunks in itertools.product(*all_lists):
                orig_keys = sum(chunks, [])
                keys = self.array_factory(orig_keys)
                f(keys)
                # The list is now rev-sorted
                self.assertSorted(orig_keys, keys[::-1])

        # An imperfect comparison function, as LT(a, b) does not imply
        # not LT(b, a).
        # The sort should handle it gracefully.
        def lt_floats(a, b):
            return math.isnan(b) or a < b

        f = self.make_quicksort(lt=lt_floats).run_quicksort

        np.random.seed(42)
        for size in (5, 20, 50, 500):
            orig = np.random.random(size=size) * 100
            orig[np.random.random(size=size) < 0.1] = float('nan')
            orig_keys = list(orig)
            keys = self.array_factory(orig_keys)
            f(keys)
            non_nans = orig[~np.isnan(orig)]
            # Non-NaNs are sorted at the front
            self.assertSorted(non_nans, keys[:len(non_nans)])


class TestQuicksortPurePython(BaseQuicksortTest, TestCase):
    """Pure Python quicksort, sorting lists."""

    quicksort = py_quicksort
    make_quicksort = staticmethod(make_py_quicksort)

    # Much faster than a Numpy array in pure Python
    array_factory = list


class TestQuicksortArrays(BaseQuicksortTest, TestCase):
    """JIT-compiled quicksort, sorting float64 arrays."""

    quicksort = jit_quicksort
    make_quicksort = staticmethod(make_jit_quicksort)

    def array_factory(self, lst):
        return np.array(lst, dtype=np.float64)


class TestNumpySort(TestCase):
    """Compare jitted array sort/argsort use cases against Numpy."""

    def setUp(self):
        np.random.seed(42)

    def int_arrays(self):
        """Yield random integer arrays of various sizes."""
        for size in (5, 20, 50, 500):
            yield np.random.randint(99, size=size)

    def float_arrays(self):
        """Yield random float arrays of various sizes, then some with NaNs."""
        for size in (5, 20, 50, 500):
            yield np.random.random(size=size) * 100
        # Now with NaNs. Numpy sorts them at the end.
        for size in (5, 20, 50, 500):
            orig = np.random.random(size=size) * 100
            orig[np.random.random(size=size) < 0.1] = float('nan')
            yield orig

    def has_duplicates(self, arr):
        """
        Whether the array has duplicates. Takes NaNs into account.
        """
        if np.count_nonzero(np.isnan(arr)) > 1:
            return True
        if np.unique(arr).size < arr.size:
            return True
        return False

    def check_sort_inplace(self, pyfunc, cfunc, val):
        """Check that *cfunc* sorts a copy of *val* in place like *pyfunc*."""
        expected = copy.copy(val)
        got = copy.copy(val)
        pyfunc(expected)
        cfunc(got)
        self.assertPreciseEqual(got, expected)

    def check_sort_copy(self, pyfunc, cfunc, val):
        """Check that *cfunc* returns the same as *pyfunc*, leaving *val* intact."""
        orig = copy.copy(val)
        expected = pyfunc(val)
        got = cfunc(val)
        self.assertPreciseEqual(got, expected)
        # The original wasn't mutated
        self.assertPreciseEqual(val, orig)

    def check_argsort(self, pyfunc, cfunc, val, kwargs=None):
        """
        Check that *cfunc* argsorts *val*, leaving it intact.

        *kwargs* is an optional dict of keyword arguments passed to both
        *pyfunc* and *cfunc*.
        """
        if kwargs is None:
            kwargs = {}
        orig = copy.copy(val)
        expected = pyfunc(val, **kwargs)
        got = cfunc(val, **kwargs)
        self.assertPreciseEqual(orig[got], np.sort(orig),
                                msg="the array wasn't argsorted")
        # Numba and Numpy results may differ if there are duplicates
        # in the array
        if not self.has_duplicates(orig):
            self.assertPreciseEqual(got, expected)
        # The original wasn't mutated
        self.assertPreciseEqual(val, orig)

    def test_array_sort_int(self):
        pyfunc = sort_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for orig in self.int_arrays():
            self.check_sort_inplace(pyfunc, cfunc, orig)

    def test_array_sort_float(self):
        pyfunc = sort_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for orig in self.float_arrays():
            self.check_sort_inplace(pyfunc, cfunc, orig)

    def test_np_sort_int(self):
        pyfunc = np_sort_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for orig in self.int_arrays():
            self.check_sort_copy(pyfunc, cfunc, orig)

    def test_np_sort_float(self):
        pyfunc = np_sort_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for size in (5, 20, 50, 500):
            orig = np.random.random(size=size) * 100
            orig[np.random.random(size=size) < 0.1] = float('nan')
            self.check_sort_copy(pyfunc, cfunc, orig)

    def test_argsort_int(self):
        def check(pyfunc):
            cfunc = jit(nopython=True)(pyfunc)
            for orig in self.int_arrays():
                self.check_argsort(pyfunc, cfunc, orig)

        check(argsort_usecase)
        check(np_argsort_usecase)

    def test_argsort_kind_int(self):
        def check(pyfunc, is_stable):
            cfunc = jit(nopython=True)(pyfunc)
            for orig in self.int_arrays():
                self.check_argsort(pyfunc, cfunc, orig,
                                   dict(is_stable=is_stable))

        check(argsort_kind_usecase, is_stable=True)
        check(np_argsort_kind_usecase, is_stable=True)
        check(argsort_kind_usecase, is_stable=False)
        check(np_argsort_kind_usecase, is_stable=False)

    def test_argsort_float(self):
        def check(pyfunc):
            cfunc = jit(nopython=True)(pyfunc)
            for orig in self.float_arrays():
                self.check_argsort(pyfunc, cfunc, orig)

        check(argsort_usecase)
        check(np_argsort_usecase)

    # This test used to be named test_argsort_float as well, which replaced
    # the method above in the class namespace so that it never ran.
    def test_argsort_kind_float(self):
        def check(pyfunc, is_stable):
            cfunc = jit(nopython=True)(pyfunc)
            for orig in self.float_arrays():
                self.check_argsort(pyfunc, cfunc, orig,
                                   dict(is_stable=is_stable))

        check(argsort_kind_usecase, is_stable=True)
        check(np_argsort_kind_usecase, is_stable=True)
        check(argsort_kind_usecase, is_stable=False)
        check(np_argsort_kind_usecase, is_stable=False)


class TestPythonSort(TestCase):
    """Tests for list.sort() and sorted() in jitted code."""

    def test_list_sort(self):
        pyfunc = list_sort_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for size in (20, 50, 500):
            orig, ret = cfunc(size)
            self.assertEqual(sorted(orig), ret)
            self.assertNotEqual(orig, ret)   # sanity check

    def test_list_sort_reverse(self):
        pyfunc = list_sort_reverse_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for size in (20, 50, 500):
            for b in (False, True):
                orig, ret = cfunc(size, b)
                self.assertEqual(sorted(orig, reverse=b), ret)
                self.assertNotEqual(orig, ret)   # sanity check

    def test_sorted(self):
        pyfunc = sorted_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for size in (20, 50, 500):
            orig = np.random.random(size=size) * 100
            expected = sorted(orig)
            got = cfunc(orig)
            self.assertPreciseEqual(got, expected)
            self.assertNotEqual(list(orig), got)   # sanity check

    def test_sorted_reverse(self):
        pyfunc = sorted_reverse_usecase
        cfunc = jit(nopython=True)(pyfunc)
        size = 20

        orig = np.random.random(size=size) * 100
        for b in (False, True):
            expected = sorted(orig, reverse=b)
            got = cfunc(orig, b)
            self.assertPreciseEqual(got, expected)
            self.assertNotEqual(list(orig), got)   # sanity check


class TestMergeSort(TestCase):
    """Tests for the JIT merge sort used as a stable argsort."""

    def setUp(self):
        np.random.seed(321)

    def check_argsort_stable(self, sorter, low, high, count):
        """
        Check *sorter* against Numpy's mergesort argsort on *count* random
        integers drawn with ``np.random.randint(low, high, count)``.
        """
        # make data with high possibility of duplicated key
        data = np.random.randint(low, high, count)
        expect = np.argsort(data, kind='mergesort')
        got = sorter(data)
        np.testing.assert_equal(expect, got)

    def test_argsort_stable(self):
        arglist = [
            (-2, 2, 5),
            (-5, 5, 10),
            (0, 10, 101),
            (0, 100, 1003),
        ]
        imp = make_jit_mergesort(is_argsort=True)
        toplevel = imp.run_mergesort
        sorter = njit(lambda arr: toplevel(arr))
        for args in arglist:
            self.check_argsort_stable(sorter, *args)


def nop_compiler(x):
    """Identity "compiler": return *x* unchanged (pure Python reference)."""
    return x


class TestSortSlashSortedWithKey(MemoryLeakMixin, TestCase):
    """Tests for the key= and reverse= arguments of list.sort() / sorted()."""

    def test_01(self):
        # key passed in as an argument: None or a jitted function

        a = [3, 1, 4, 1, 5, 9]

        @njit
        def external_key(z):
            return 1. / z

        @njit
        def foo(x, key=None):
            new_x = x[:]
            new_x.sort(key=key)
            return sorted(x[:], key=key), new_x

        self.assertPreciseEqual(foo(a[:]), foo.py_func(a[:]))
        self.assertPreciseEqual(foo(a[:], external_key),
                                foo.py_func(a[:], external_key))

    def test_02(self):
        # key defined as a closure inside the jitted function

        a = [3, 1, 4, 1, 5, 9]

        @njit
        def foo(x):
            def closure_key(z):
                return 1. / z
            new_x = x[:]
            new_x.sort(key=closure_key)
            return sorted(x[:], key=closure_key), new_x

        self.assertPreciseEqual(foo(a[:]), foo.py_func(a[:]))

    def test_03(self):
        # key is a closure passed on to another function

        a = [3, 1, 4, 1, 5, 9]

        def gen(compiler):

            @compiler
            def bar(x, func):
                new_x = x[:]
                new_x.sort(key=func)
                return sorted(x[:], key=func), new_x

            @compiler
            def foo(x):
                def closure_escapee_key(z):
                    return 1. / z
                return bar(x, closure_escapee_key)

            return foo

        self.assertPreciseEqual(gen(njit)(a[:]), gen(nop_compiler)(a[:]))

    def test_04(self):
        # same as test_01, with strings

        a = ['a', 'b', 'B', 'b', 'C', 'A']

        @njit
        def external_key(z):
            return z.upper()

        @njit
        def foo(x, key=None):
            new_x = x[:]
            new_x.sort(key=key)
            return sorted(x[:], key=key), new_x

        self.assertPreciseEqual(foo(a[:]), foo.py_func(a[:]))
        self.assertPreciseEqual(foo(a[:], external_key),
                                foo.py_func(a[:], external_key))

    def test_05(self):
        # key combined with various truthy / falsy values for reverse

        a = ['a', 'b', 'B', 'b', 'C', 'A']

        @njit
        def external_key(z):
            return z.upper()

        @njit
        def foo(x, key=None, reverse=False):
            new_x = x[:]
            new_x.sort(key=key, reverse=reverse)
            return (sorted(x[:], key=key, reverse=reverse), new_x)

        for key, rev in itertools.product((None, external_key),
                                          (True, False, 1, -12, 0)):
            self.assertPreciseEqual(foo(a[:], key, rev),
                                    foo.py_func(a[:], key, rev))

    def test_optional_on_key(self):
        # A key that is either a function or None, depending on a runtime
        # value, is rejected at typing time.
        a = [3, 1, 4, 1, 5, 9]

        @njit
        def foo(x, predicate):
            if predicate:
                def closure_key(z):
                    return 1. / z
            else:
                closure_key = None

            new_x = x[:]
            new_x.sort(key=closure_key)

            return (sorted(x[:], key=closure_key), new_x)

        with self.assertRaises(errors.TypingError) as raises:
            TF = True
            foo(a[:], TF)

        msg = "Key must concretely be None or a Numba JIT compiled function"
        self.assertIn(msg, str(raises.exception))

    def test_exceptions_sorted(self):

        @njit
        def foo_sorted(x, key=None, reverse=False):
            return sorted(x[:], key=key, reverse=reverse)

        @njit
        def foo_sort(x, key=None, reverse=False):
            new_x = x[:]
            new_x.sort(key=key, reverse=reverse)
            return new_x

        @njit
        def external_key(z):
            return 1. / z

        a = [3, 1, 4, 1, 5, 9]

        for impl in (foo_sort, foo_sorted):

            # check illegal key
            with self.assertRaises(errors.TypingError) as raises:
                impl(a, key="illegal")

            expect = "Key must be None or a Numba JIT compiled function"
            self.assertIn(expect, str(raises.exception))

            # check illegal reverse
            with self.assertRaises(errors.TypingError) as raises:
                impl(a, key=external_key, reverse="go backwards")

            expect = "an integer is required for 'reverse'"
            self.assertIn(expect, str(raises.exception))


class TestArrayArgsort(MemoryLeakMixin, TestCase):
    """Tests specific to array.argsort"""

    def test_exceptions(self):

        @njit
        def nonliteral_kind(kind):
            np.arange(5).argsort(kind=kind)

        # check non-literal kind
        with self.assertRaises(errors.TypingError) as raises:
            # valid spelling but not literal
            nonliteral_kind('quicksort')

        expect = '"kind" must be a string literal'
        self.assertIn(expect, str(raises.exception))

        @njit
        def unsupported_kwarg():
            np.arange(5).argsort(foo='')

        with self.assertRaises(errors.TypingError) as raises:
            unsupported_kwarg()

        expect = "Unsupported keywords: ['foo']"
        self.assertIn(expect, str(raises.exception))


if __name__ == '__main__':
    unittest.main()