import itertools
import unittest

import numpy as np

from numba import typeof, njit
from numba.core import types, lowering
from numba.core.compiler import compile_isolated, Flags
from numba.tests.support import TestCase


# NOTE: the helper functions below are compiled by the tests.  Their exact
# statement structure (which branch assigns, which branch returns) is kept
# as-is on purpose; do not "simplify" them.

def return_double_or_none(x):
    """
    Return None if `x` is truthy, otherwise 1.2.

    Both branches assign to one local which is returned by a single
    `return` statement.
    """
    if x:
        ret = None
    else:
        ret = 1.2
    return ret


def return_different_statement(x):
    """
    Return None if `x` is truthy, otherwise 1.2.

    Each branch has its own `return` statement.
    """
    if x:
        return None
    else:
        return 1.2


def return_bool_optional_or_none(x, y):
    """
    Return True when `x == 2`, `z` when `x == 1` and None otherwise,
    where `z` is False if `y` is truthy and None if it is not.
    """
    if y:
        z = False
    else:
        z = None
    if x == 2:
        # A boolean
        return True
    elif x == 1:
        # A runtime optional
        return z
    else:
        # None
        return None


def is_this_a_none(x):
    """
    Return `x - 1` if `x` is truthy, otherwise `x + 1`.

    The result is selected through `is None` / `is not None` checks on a
    local that is None when `x` is truthy and `x` itself otherwise.
    """
    if x:
        val_or_none = None
    else:
        val_or_none = x

    if val_or_none is None:
        return x - 1

    if val_or_none is not None:
        return x + 1


def a_is_b(a, b):
    """
    Note that in nopython mode this operation does not make much sense,
    because we don't have objects anymore.
    `a is b` is always False if not operating on None and Optional type
    """
    return a is b


def a_is_not_b(a, b):
    """
    This is `not (a is b)`
    """
    return a is not b


class TestOptional(TestCase):

    # NOTE(review): presumably read by the numba test runner to keep this
    # class out of parallel test runs -- confirm against the runner.
    _numba_parallel_test_ = False

    def test_return_double_or_none(self):
        pyfunc = return_double_or_none
        cres = compile_isolated(pyfunc, [types.boolean])
        cfunc = cres.entry_point

        for v in [True, False]:
            self.assertPreciseEqual(pyfunc(v), cfunc(v))

    def test_return_different_statement(self):
        pyfunc = return_different_statement
        cres = compile_isolated(pyfunc, [types.boolean])
        cfunc = cres.entry_point

        for v in [True, False]:
            self.assertPreciseEqual(pyfunc(v), cfunc(v))

    def test_return_bool_optional_or_none(self):
        pyfunc = return_bool_optional_or_none
        cres = compile_isolated(pyfunc, [types.int32, types.int32])
        cfunc = cres.entry_point

        # Cover every (x, y) combination so that all three return
        # statements are hit with both values of `z`.
        for x, y in itertools.product((0, 1, 2), (0, 1)):
            self.assertPreciseEqual(pyfunc(x, y), cfunc(x, y))

    def test_is_this_a_none(self):
        pyfunc = is_this_a_none
        cres = compile_isolated(pyfunc, [types.intp])
        cfunc = cres.entry_point

        for v in [-1, 0, 1, 2]:
            self.assertPreciseEqual(pyfunc(v), cfunc(v))

    def test_is_this_a_none_objmode(self):
        pyfunc = is_this_a_none
        flags = Flags()
        flags.force_pyobject = True
        cres = compile_isolated(pyfunc, [types.intp], flags=flags)
        cfunc = cres.entry_point
        # Make sure the function really was compiled in object mode.
        self.assertTrue(cres.objectmode)
        for v in [-1, 0, 1, 2]:
            self.assertPreciseEqual(pyfunc(v), cfunc(v))

    def test_a_is_b_intp(self):
        pyfunc = a_is_b
        cres = compile_isolated(pyfunc, [types.intp, types.intp])
        cfunc = cres.entry_point
        # integer identity relies on `==`
        self.assertTrue(cfunc(1, 1))
        self.assertFalse(cfunc(1, 2))

    def test_a_is_not_b_intp(self):
        pyfunc = a_is_not_b
        cres = compile_isolated(pyfunc, [types.intp, types.intp])
        cfunc = cres.entry_point
        # integer identity relies on `==`
        self.assertFalse(cfunc(1, 1))
        self.assertTrue(cfunc(1, 2))

    def test_optional_float(self):
        def pyfunc(x, y):
            if y is None:
                return x
            else:
                return x + y

        cfunc = njit("(float64, optional(float64))")(pyfunc)
        self.assertAlmostEqual(pyfunc(1., 12.3), cfunc(1., 12.3))
        self.assertAlmostEqual(pyfunc(1., None), cfunc(1., None))

    def test_optional_array(self):
        def pyfunc(x, y):
            if y is None:
                return x
            else:
                y[0] += x
                return y[0]

        cfunc = njit("(float32, optional(float32[:]))")(pyfunc)
        # `pyfunc` mutates its array argument, so the interpreted and the
        # compiled call each get their own copy, compared afterwards.
        cy = np.array([12.3], dtype=np.float32)
        py = cy.copy()
        self.assertAlmostEqual(pyfunc(1., py), cfunc(1., cy))
        np.testing.assert_almost_equal(py, cy)
        self.assertAlmostEqual(pyfunc(1., None), cfunc(1., None))

    def test_optional_array_error(self):
        def pyfunc(y):
            return y[0]

        cfunc = njit("(optional(int32[:]),)")(pyfunc)
        with self.assertRaises(TypeError) as raised:
            cfunc(None)
        self.assertIn('expected array(int32, 1d, A), got None',
                      str(raised.exception))

        y = np.array([0xabcd], dtype=np.int32)
        self.assertEqual(cfunc(y), pyfunc(y))

    def test_optional_array_attribute(self):
        """
        Check that we can access attribute of an optional
        """
        def pyfunc(arr, do_it):
            opt = None
            if do_it:  # forces `opt` to be an optional of arr
                opt = arr
            return opt.shape[0]

        cfunc = njit(pyfunc)
        arr = np.arange(5)
        self.assertEqual(pyfunc(arr, True), cfunc(arr, True))

    def test_assign_to_optional(self):
        """
        Check that we can assign to a variable of optional type
        """
        @njit
        def make_optional(val, get_none):
            if get_none:
                ret = None
            else:
                ret = val
            return ret

        @njit
        def foo(val, run_second):
            a = make_optional(val, True)
            if run_second:
                a = make_optional(val, False)
            return a

        self.assertIsNone(foo(123, False))
        self.assertEqual(foo(231, True), 231)

    def test_optional_thru_omitted_arg(self):
        """
        Issue 1868
        """

        def pyfunc(x=None):
            if x is None:
                x = 1
            return x

        cfunc = njit(pyfunc)
        self.assertEqual(pyfunc(), cfunc())
        self.assertEqual(pyfunc(3), cfunc(3))

    def test_optional_unpack(self):
        """
        Issue 2171
        """
        def pyfunc(x):
            if x is None:
                return
            else:
                a, b = x
                return a, b

        tup = types.Tuple([types.intp] * 2)
        opt_tup = types.Optional(tup)
        sig = (opt_tup,)
        cfunc = njit(sig)(pyfunc)
        self.assertEqual(pyfunc(None), cfunc(None))
        self.assertEqual(pyfunc((1, 2)), cfunc((1, 2)))

    def test_many_optional_none_returns(self):
        """
        Issue #4058
        """
        @njit
        def foo(maybe):
            lx = None
            if maybe:
                lx = 10
            return 1, lx

        def work():
            # Call `foo` many times, each call returning a tuple that
            # holds None.
            for _ in range(20000):
                maybe = False
                _ = foo(maybe)

        # this caused "Fatal Python error: deallocating None" as there was no
        # incref being made on the returned None.
        work()


if __name__ == '__main__':
    unittest.main()