"""
Tests for Numba's support of Python ``set`` objects: set literals,
construction, methods and operators in nopython mode, unboxing of Python
sets into native sets, and reflection of native mutations back onto the
originating Python sets.
"""

import unittest
from collections import namedtuple
import contextlib
import itertools
import math
import random
import sys

import numpy as np

from numba.core.compiler import compile_isolated, Flags, errors
from numba import jit
from numba.core import types
from numba.tests.support import (TestCase, enable_pyobj_flags, MemoryLeakMixin,
                                 tag, compile_function)


Point = namedtuple('Point', ('a', 'b'))


def _build_set_literal_usecase(code, args):
    """
    Fill the ``%(initializer)s`` slot of the source template *code* with the
    comma-separated reprs of *args*, then compile the ``build_set`` function
    it defines.
    """
    code = code % {'initializer': ', '.join(repr(arg) for arg in args)}
    return compile_function('build_set', code, globals())


def set_literal_return_usecase(args):
    """Compile a function returning the set literal ``{*args}``."""
    code = """if 1:
    def build_set():
        return {%(initializer)s}
    """
    return _build_set_literal_usecase(code, args)


def set_literal_convert_usecase(args):
    """Compile a function building the set literal ``{*args}`` and
    returning it converted to a list."""
    code = """if 1:
    def build_set():
        my_set = {%(initializer)s}
        return list(my_set)
    """
    return _build_set_literal_usecase(code, args)


def empty_constructor_usecase():
    s = set()
    s.add(1)
    return len(s)


def constructor_usecase(arg):
    s = set(arg)
    return len(s)


def iterator_usecase(arg):
    s = set(arg)
    l = []
    for v in s:
        l.append(v)
    return l


def update_usecase(a, b, c):
    s = set()
    s.update(a)
    s.update(b)
    s.update(c)
    return list(s)


def bool_usecase(arg):
    # Remove one element to allow for empty sets.
    s = set(arg[1:])
    return bool(s)


def remove_usecase(a, b):
    s = set(a)
    for v in b:
        s.remove(v)
    return list(s)


def discard_usecase(a, b):
    s = set(a)
    for v in b:
        s.discard(v)
    return list(s)


def add_discard_usecase(a, u, v):
    s = set(a)
    for i in range(1000):
        s.add(u)
        s.discard(v)
    return list(s)


def pop_usecase(a):
    s = set(a)
    l = []
    while len(s) > 0:
        l.append(s.pop())
    return l


def contains_usecase(a, b):
    s = set(a)
    l = []
    for v in b:
        l.append(v in s)
    return l


def difference_update_usecase(a, b):
    s = set(a)
    s.difference_update(set(b))
    return list(s)


def intersection_update_usecase(a, b):
    s = set(a)
    s.intersection_update(set(b))
    return list(s)


def symmetric_difference_update_usecase(a, b):
    s = set(a)
    s.symmetric_difference_update(set(b))
    return list(s)


def isdisjoint_usecase(a, b):
    return set(a).isdisjoint(set(b))


def issubset_usecase(a, b):
    return set(a).issubset(set(b))


def issuperset_usecase(a, b):
    return set(a).issuperset(set(b))


def clear_usecase(a):
    s = set(a)
    s.clear()
    return len(s), list(s)


def copy_usecase(a):
    s = set(a)
    ss = s.copy()
    s.pop()
    return len(ss), list(ss)


def copy_usecase_empty(a):
    s = set(a)
    s.clear()
    ss = s.copy()
    s.add(42)
    return len(ss), list(ss)


def copy_usecase_deleted(a, b):
    s = set(a)
    s.remove(b)
    ss = s.copy()
    s.pop()
    return len(ss), list(ss)


def difference_usecase(a, b):
    sa = set(a)
    s = sa.difference(set(b))
    return list(s)


def intersection_usecase(a, b):
    sa = set(a)
    s = sa.intersection(set(b))
    return list(s)


def symmetric_difference_usecase(a, b):
    sa = set(a)
    s = sa.symmetric_difference(set(b))
    return list(s)


def union_usecase(a, b):
    sa = set(a)
    s = sa.union(set(b))
    return list(s)


def set_return_usecase(a):
    s = set(a)
    return s


def make_operator_usecase(op):
    """Compile ``operator_usecase(a, b)`` computing ``set(a) <op> set(b)``
    and returning the result as a list."""
    code = """if 1:
    def operator_usecase(a, b):
        s = set(a) %(op)s set(b)
        return list(s)
    """ % dict(op=op)
    return compile_function('operator_usecase', code, globals())


def make_inplace_operator_usecase(op):
    """Compile ``inplace_operator_usecase(a, b)`` applying the in-place
    operator *op* to an alias of ``set(a)``; it returns both the alias and
    the original name so in-place semantics can be checked."""
    code = """if 1:
    def inplace_operator_usecase(a, b):
        sa = set(a)
        sb = set(b)
        sc = sa
        sc %(op)s sb
        return list(sc), list(sa)
    """ % dict(op=op)
    return compile_function('inplace_operator_usecase', code, globals())


def make_comparison_usecase(op):
    """Compile ``comparison_usecase(a, b)`` returning
    ``set(a) <op> set(b)``."""
    code = """if 1:
    def comparison_usecase(a, b):
        return set(a) %(op)s set(b)
    """ % dict(op=op)
    return compile_function('comparison_usecase', code, globals())


def noop(x):
    pass


def unbox_usecase(x):
    """
    Expect a set of numbers
    """
    res = 0
    for v in x:
        res += v
    return res


def unbox_usecase2(x):
    """
    Expect a set of tuples
    """
    res = 0
    for v in x:
        res += len(v)
    return res


def unbox_usecase3(x):
    """
    Expect a (number, set of numbers) tuple.
    """
    a, b = x
    res = a
    for v in b:
        res += v
    return res


def unbox_usecase4(x):
    """
    Expect a (number, set of tuples) tuple.
    """
    a, b = x
    res = a
    for v in b:
        res += len(v)
    return res


def reflect_simple(sa, sb):
    sa.add(42)
    sa.update(sb)
    return sa, len(sa), len(sb)


def reflect_conditional(sa, sb):
    # `sa` may or may not actually reflect a Python set
    if len(sb) > 1:
        sa = set((11., 22., 33., 44.))
    sa.add(42.)
    sa.update(sb)
    # Combine with a non-reflected set (to check method typing)
    sc = set((55., 66.))
    sa.symmetric_difference_update(sc)
    return sa, len(sa), len(sb)


def reflect_exception(s):
    s.add(42)
    raise ZeroDivisionError


def reflect_dual(sa, sb):
    sa.add(sb.pop())
    return sa is sb


def unique_usecase(src):
    seen = set()
    res = []
    for v in src:
        if v not in seen:
            seen.add(v)
            res.append(v)
    return res


class BaseTest(MemoryLeakMixin, TestCase):
    """
    Shared fixtures: a seeded RNG (seed 42) for reproducible sample data,
    sample-array generators, and helpers comparing results whose element
    order is unspecified.
    """

    def setUp(self):
        super(BaseTest, self).setUp()
        self.rnd = random.Random(42)

    def _range(self, stop):
        # Subclasses override this to change the element type of samples.
        return np.arange(int(stop))

    def _random_choice(self, seq, n):
        """
        Choose *n* possibly duplicate items from sequence.
        """
        # Materialize the population once rather than on every draw; the
        # sequence of random choices is unchanged.
        population = list(seq)
        l = [self.rnd.choice(population) for _ in range(n)]
        if isinstance(seq, np.ndarray):
            return np.array(l, dtype=seq.dtype)
        else:
            return l

    def duplicates_array(self, n):
        """
        Get a 1d array with many duplicate values.
        """
        a = self._range(np.sqrt(n))
        return self._random_choice(a, n)

    def sparse_array(self, n):
        """
        Get a 1d array with values spread around.
        """
        # Note two calls to sparse_array() should generate reasonable overlap
        a = self._range(n ** 1.3)
        return self._random_choice(a, n)

    def _assert_equal_unordered(self, a, b):
        """
        Compare *a* and *b*, treating lists as unordered (they are sorted
        before comparison) and recursing element-wise into tuples.
        """
        if isinstance(a, tuple):
            self.assertIsInstance(b, tuple)
            # zip() truncates silently, so check the lengths explicitly.
            self.assertEqual(len(a), len(b))
            for u, v in zip(a, b):
                self._assert_equal_unordered(u, v)
        elif isinstance(a, list):
            self.assertIsInstance(b, list)
            self.assertPreciseEqual(sorted(a), sorted(b))
        else:
            self.assertPreciseEqual(a, b)

    def unordered_checker(self, pyfunc):
        """
        Return a ``check(*args)`` callable comparing *pyfunc* against its
        nopython-jitted version with `_assert_equal_unordered`.
        """
        cfunc = jit(nopython=True)(pyfunc)

        def check(*args):
            expected = pyfunc(*args)
            got = cfunc(*args)
            self._assert_equal_unordered(expected, got)
        return check


class TestSetLiterals(BaseTest):

    def test_build_set(self, flags=enable_pyobj_flags):
        pyfunc = set_literal_return_usecase((1, 2, 3, 2))
        self.run_nullary_func(pyfunc, flags=flags)

    def test_build_heterogeneous_set(self, flags=enable_pyobj_flags):
        pyfunc = set_literal_return_usecase((1, 2.0, 3j, 2))
        self.run_nullary_func(pyfunc, flags=flags)
        pyfunc = set_literal_return_usecase((2.0, 2))
        got, expected = self.run_nullary_func(pyfunc, flags=flags)
        self.assertIs(type(got.pop()), type(expected.pop()))

    def test_build_set_nopython(self):
        arg = list(self.sparse_array(50))
        pyfunc = set_literal_convert_usecase(arg)
        cfunc = jit(nopython=True)(pyfunc)

        expected = pyfunc()
        got = cfunc()
        self.assertPreciseEqual(sorted(expected), sorted(got))


class TestSets(BaseTest):

    def test_constructor(self):
        pyfunc = empty_constructor_usecase
        cfunc = jit(nopython=True)(pyfunc)
        self.assertPreciseEqual(cfunc(), pyfunc())
        pyfunc = constructor_usecase
        cfunc = jit(nopython=True)(pyfunc)

        def check(arg):
            self.assertPreciseEqual(pyfunc(arg), cfunc(arg))

        check((1, 2, 3, 2, 7))
        check(self.duplicates_array(200))
        check(self.sparse_array(200))

    def test_set_return(self):
        pyfunc = set_return_usecase
        cfunc = jit(nopython=True)(pyfunc)

        arg = (1, 2, 3, 2, 7)
        self.assertEqual(cfunc(arg), set(arg))

    def test_iterator(self):
        pyfunc = iterator_usecase
        check = self.unordered_checker(pyfunc)

        check((1, 2, 3, 2, 7))
        check(self.duplicates_array(200))
        check(self.sparse_array(200))

    def test_update(self):
        pyfunc = update_usecase
        check = self.unordered_checker(pyfunc)

        a, b, c = (1, 2, 4, 9), (2, 3, 5, 11, 42), (4, 5, 6, 42)
        check(a, b, c)

        a = self.sparse_array(50)
        b = self.duplicates_array(50)
        c = self.sparse_array(50)
        check(a, b, c)

    def test_bool(self):
        pyfunc = bool_usecase
        check = self.unordered_checker(pyfunc)

        check([1])
        check([1, 2])
        check([False, False])
        check([True, False])

    def test_remove(self):
        pyfunc = remove_usecase
        check = self.unordered_checker(pyfunc)

        a = (1, 2, 3, 5, 8, 42)
        b = (5, 2, 8)
        check(a, b)

    def test_remove_error(self):
        # References are leaked on exception
        self.disable_leak_check()

        pyfunc = remove_usecase
        cfunc = jit(nopython=True)(pyfunc)
        with self.assertRaises(KeyError):
            cfunc((1, 2, 3), (5, ))

    def test_refcounted_types_forbidden(self):
        # References are leaked on exception
        self.disable_leak_check()

        pyfunc = constructor_usecase
        cfunc = jit(nopython=True)(pyfunc)
        with self.assertRaises(errors.NumbaValueError) as raises:
            cfunc("abc")

        excstr = str(raises.exception)
        self.assertIn("Use of reference counted items in 'set()'", excstr)
        self.assertIn("offending type is: 'unicode_type'", excstr)

    def test_discard(self):
        pyfunc = discard_usecase
        check = self.unordered_checker(pyfunc)

        a = (1, 2, 3, 5, 8, 42)
        b = (5, 2, 8)
        check(a, b)
        a = self.sparse_array(50)
        b = self.sparse_array(50)
        check(a, b)

    def test_add_discard(self):
        """
        Check that the insertion logic does not create an infinite lookup
        chain with deleted entries (insertion should happen at the first
        deleted entry, not at the free entry at the end of the chain).
        See issue #1913.
        """
        pyfunc = add_discard_usecase
        check = self.unordered_checker(pyfunc)
        check((1,), 5, 5)

    def test_pop(self):
        pyfunc = pop_usecase
        check = self.unordered_checker(pyfunc)

        check((2, 3, 55, 11, 8, 42))
        check(self.sparse_array(50))

    def test_contains(self):
        pyfunc = contains_usecase
        cfunc = jit(nopython=True)(pyfunc)

        def check(a, b):
            self.assertPreciseEqual(pyfunc(a, b), cfunc(a, b))

        a = (1, 2, 3, 5, 42)
        b = (5, 2, 8, 3)
        check(a, b)

    def _test_xxx_update(self, pyfunc):
        """Check an in-place update usecase on small fixed inputs and on
        every pairing of sparse arrays of sizes 0, 50 and 500."""
        check = self.unordered_checker(pyfunc)

        a, b = (1, 2, 4, 11), (2, 3, 5, 11, 42)
        check(a, b)

        sizes = (0, 50, 500)
        for na, nb in itertools.product(sizes, sizes):
            a = self.sparse_array(na)
            b = self.sparse_array(nb)
            check(a, b)

    def test_difference_update(self):
        self._test_xxx_update(difference_update_usecase)

    def test_intersection_update(self):
        self._test_xxx_update(intersection_update_usecase)

    def test_symmetric_difference_update(self):
        self._test_xxx_update(symmetric_difference_update_usecase)

    def _test_comparator(self, pyfunc):
        """Check a set-comparison usecase on every pairing of the
        intersection, difference, union and symmetric difference of two
        fixed sets."""
        cfunc = jit(nopython=True)(pyfunc)

        def check(a, b):
            self.assertPreciseEqual(pyfunc(a, b), cfunc(a, b))

        a, b = map(set, [(1, 2, 4, 11), (2, 3, 5, 11, 42)])
        args = [a & b, a - b, a | b, a ^ b]
        args = [tuple(x) for x in args]
        for a, b in itertools.product(args, args):
            check(a, b)

    def test_isdisjoint(self):
        self._test_comparator(isdisjoint_usecase)

    def test_issubset(self):
        self._test_comparator(issubset_usecase)

    def test_issuperset(self):
        self._test_comparator(issuperset_usecase)

    def test_clear(self):
        pyfunc = clear_usecase
        check = self.unordered_checker(pyfunc)

        check((1, 2, 4, 11))
        check(self.sparse_array(50))

    def test_copy(self):
        # Source set doesn't have any deleted entries
        pyfunc = copy_usecase
        check = self.unordered_checker(pyfunc)
        check((1, 2, 4, 11))
        check(self.sparse_array(50))

        pyfunc = copy_usecase_empty
        check = self.unordered_checker(pyfunc)
        check((1,))

        # Source set has deleted entries
        pyfunc = copy_usecase_deleted
        check = self.unordered_checker(pyfunc)
        check((1, 2, 4, 11), 2)
        a = self.sparse_array(50)
        check(a, a[len(a) // 2])

    def _test_set_operator(self, pyfunc):
        """Check a binary set-operation usecase on small fixed inputs and
        on every pairing of sparse arrays of sizes 0, 50 and 500."""
        check = self.unordered_checker(pyfunc)

        a, b = (1, 2, 4, 11), (2, 3, 5, 11, 42)
        check(a, b)

        sizes = (0, 50, 500)
        for na, nb in itertools.product(sizes, sizes):
            a = self.sparse_array(na)
            b = self.sparse_array(nb)
            check(a, b)

    def test_difference(self):
        self._test_set_operator(difference_usecase)

    def test_intersection(self):
        self._test_set_operator(intersection_usecase)

    def test_symmetric_difference(self):
        self._test_set_operator(symmetric_difference_usecase)

    def test_union(self):
        self._test_set_operator(union_usecase)

    def test_and(self):
        self._test_set_operator(make_operator_usecase('&'))

    def test_or(self):
        self._test_set_operator(make_operator_usecase('|'))

    def test_sub(self):
        self._test_set_operator(make_operator_usecase('-'))

    def test_xor(self):
        self._test_set_operator(make_operator_usecase('^'))

    def test_eq(self):
        self._test_set_operator(make_comparison_usecase('=='))

    def test_ne(self):
        self._test_set_operator(make_comparison_usecase('!='))

    def test_le(self):
        self._test_set_operator(make_comparison_usecase('<='))

    def test_lt(self):
        self._test_set_operator(make_comparison_usecase('<'))

    def test_ge(self):
        self._test_set_operator(make_comparison_usecase('>='))

    def test_gt(self):
        self._test_set_operator(make_comparison_usecase('>'))

    def test_iand(self):
        self._test_set_operator(make_inplace_operator_usecase('&='))

    def test_ior(self):
        self._test_set_operator(make_inplace_operator_usecase('|='))

    def test_isub(self):
        self._test_set_operator(make_inplace_operator_usecase('-='))

    def test_ixor(self):
        self._test_set_operator(make_inplace_operator_usecase('^='))


class OtherTypesTest(object):
    """
    Mixin of basic set tests to re-run with other key types; concrete
    subclasses combine it with BaseTest and override ``_range``.
    """

    def test_constructor(self):
        pyfunc = empty_constructor_usecase
        cfunc = jit(nopython=True)(pyfunc)
        self.assertPreciseEqual(cfunc(), pyfunc())
        pyfunc = constructor_usecase
        cfunc = jit(nopython=True)(pyfunc)

        def check(arg):
            self.assertPreciseEqual(pyfunc(arg), cfunc(arg))

        check(self.duplicates_array(200))
        check(self.sparse_array(200))

    def test_iterator(self):
        pyfunc = iterator_usecase
        check = self.unordered_checker(pyfunc)

        check(self.duplicates_array(200))
        check(self.sparse_array(200))

    def test_update(self):
        pyfunc = update_usecase
        check = self.unordered_checker(pyfunc)

        a = self.sparse_array(50)
        b = self.duplicates_array(50)
        c = self.sparse_array(50)
        check(a, b, c)


class TestFloatSets(OtherTypesTest, BaseTest):
    """
    Test sets with floating-point keys.
    """
    # Only a few basic tests here, as the sanity of most operations doesn't
    # depend on the key type.

    def _range(self, stop):
        return np.arange(stop, dtype=np.float32) * np.float32(0.1)


class TestTupleSets(OtherTypesTest, BaseTest):
    """
    Test sets with tuple keys.
    """

    def _range(self, stop):
        a = np.arange(stop, dtype=np.int64)
        b = a & 0x5555555555555555
        c = (a & 0xaaaaaaaa).astype(np.int32)
        d = ((a >> 32) & 1).astype(np.bool_)
        return list(zip(b, c, d))


class TestUnboxing(BaseTest):
    """
    Test unboxing of Python sets into native Numba sets.
    """

    @contextlib.contextmanager
    def assert_type_error(self, msg):
        """
        Assert that the managed block raises TypeError and, when *msg* is
        not None, that the exception text matches the regex *msg*.
        """
        with self.assertRaises(TypeError) as raises:
            yield
        if msg is not None:
            # assertRegexpMatches was a deprecated alias removed in
            # Python 3.12; assertRegex is the supported spelling.
            self.assertRegex(str(raises.exception), msg)

    def check_unary(self, pyfunc):
        """Return a ``check(arg)`` callable comparing *pyfunc* with its
        nopython-jitted version."""
        cfunc = jit(nopython=True)(pyfunc)

        def check(arg):
            expected = pyfunc(arg)
            got = cfunc(arg)
            self.assertPreciseEqual(got, expected)
        return check

    def test_numbers(self):
        check = self.check_unary(unbox_usecase)
        check(set([1, 2]))
        check(set([1j, 2.5j]))
        # Check allocation and sizing
        check(set(range(100)))

    def test_tuples(self):
        check = self.check_unary(unbox_usecase2)
        check(set([(1, 2), (3, 4)]))
        check(set([(1, 2j), (3, 4j)]))

    def test_set_inside_tuple(self):
        check = self.check_unary(unbox_usecase3)
        check((1, set([2, 3, 4])))

    def test_set_of_tuples_inside_tuple(self):
        check = self.check_unary(unbox_usecase4)
        check((1, set([(2,), (3,)])))

    def test_errors(self):
        # Error checking should ensure the set is homogeneous
        msg = "can't unbox heterogeneous set"
        pyfunc = noop
        cfunc = jit(nopython=True)(pyfunc)
        val = set([1, 2.5])
        with self.assert_type_error(msg):
            cfunc(val)
        # The set hasn't been changed (bogus reflecting)
        self.assertEqual(val, set([1, 2.5]))
        with self.assert_type_error(msg):
            cfunc(set([1, 2j]))
        # Same when the set is nested in a tuple or namedtuple
        with self.assert_type_error(msg):
            cfunc((1, set([1, 2j])))
        with self.assert_type_error(msg):
            cfunc(Point(1, set([1, 2j])))
        # Tuples of different size.
        # Note the check is really on the tuple side.
        lst = set([(1,), (2, 3)])
        # Depending on which tuple is examined first, we could get
        # an IndexError or a ValueError.
        with self.assertRaises((IndexError, ValueError)):
            cfunc(lst)


class TestSetReflection(BaseTest):
    """
    Test reflection of native Numba sets on Python set objects.
    """

    def check_reflection(self, pyfunc):
        """
        Run *pyfunc* and its jitted version on copies of the same inputs and
        check that results, mutations of the first argument, result identity
        and reference counts all agree.
        """
        cfunc = jit(nopython=True)(pyfunc)
        samples = [(set([1., 2., 3., 4.]), set([0.])),
                   (set([1., 2., 3., 4.]), set([5., 6., 7., 8., 9.])),
                   ]
        for dest, src in samples:
            expected = set(dest)
            got = set(dest)
            pyres = pyfunc(expected, src)
            with self.assertRefCount(got, src):
                cres = cfunc(got, src)
                self.assertPreciseEqual(cres, pyres)
                self.assertPreciseEqual(expected, got)
                self.assertEqual(pyres[0] is expected, cres[0] is got)
                # Drop the results before the refcount check on exit.
                del pyres, cres

    def test_reflect_simple(self):
        self.check_reflection(reflect_simple)

    def test_reflect_conditional(self):
        self.check_reflection(reflect_conditional)

    def test_reflect_exception(self):
        """
        When the function exits with an exception, sets should still be
        reflected.
        """
        pyfunc = reflect_exception
        cfunc = jit(nopython=True)(pyfunc)
        s = set([1, 2, 3])
        with self.assertRefCount(s):
            with self.assertRaises(ZeroDivisionError):
                cfunc(s)
            self.assertPreciseEqual(s, set([1, 2, 3, 42]))

    def test_reflect_same_set(self):
        """
        When the same set object is reflected twice, behaviour should
        be consistent.
        """
        pyfunc = reflect_dual
        cfunc = jit(nopython=True)(pyfunc)
        pyset = set([1, 2, 3])
        cset = pyset.copy()
        expected = pyfunc(pyset, pyset)
        got = cfunc(cset, cset)
        self.assertPreciseEqual(expected, got)
        self.assertPreciseEqual(pyset, cset)
        self.assertPreciseEqual(sys.getrefcount(pyset), sys.getrefcount(cset))

    def test_reflect_clean(self):
        """
        When the set wasn't mutated, no reflection should take place.
        """
        cfunc = jit(nopython=True)(noop)
        # Use a complex, as Python integers can be cached
        s = set([12.5j])
        ids = [id(x) for x in s]
        cfunc(s)
        self.assertEqual([id(x) for x in s], ids)


class TestExamples(BaseTest):
    """
    Examples of using sets.
    """

    def test_unique(self):
        pyfunc = unique_usecase
        check = self.unordered_checker(pyfunc)

        check(self.duplicates_array(200))
        check(self.sparse_array(200))

    def test_type_coercion_from_update(self):
        # see issue #6621
        def impl():
            i = np.uint64(1)
            R = set()
            R.update({1, 2, 3})
            R.add(i)
            return R
        check = self.unordered_checker(impl)
        check()


if __name__ == '__main__':
    unittest.main()