"""
Test problems in nested calls.

Usually due to invalid type conversion between function boundaries.
Each module-level function below is a small "outer" function that calls a
jitted "inner" function with a particular argument-passing pattern
(keywords, defaults, *args, implicit casts, generated_jit); the test case
compiles the outer function and compares it against plain Python.
"""

from numba import int32, int64
from numba import jit, generated_jit
from numba.core import types
# NOTE(review): `tag` is imported but not used anywhere in this module.
from numba.tests.support import TestCase, tag
import unittest


# Inner function returning its three arguments unchanged as a tuple.
@jit(nopython=True)
def f_inner(a, b, c):
    return a, b, c


# Mixes positional and out-of-order keyword arguments: the result is
# (x, z, y) because `y` is bound to `c` and `z` to `b`.
def f(x, y, z):
    return f_inner(x, c=y, b=z)


# Inner function with default values for `b` (2) and `c` (3).
@jit(nopython=True)
def g_inner(a, b=2, c=3):
    return a, b, c


# Exercises defaults: the first call leaves `c` defaulted, giving
# (x, y, 3); the second leaves `b` defaulted, giving (z, 2, x).
def g(x, y, z):
    return g_inner(x, b=y), g_inner(a=z, c=x)


# Inner function with a defaulted first parameter and a *args tail;
# returns `a` and the tuple of extra positional arguments.
@jit(nopython=True)
def star_inner(a=5, *b):
    return a, b


# Calls a *args-accepting function: once with only a keyword (empty
# *args, result (x, ())) and once with extra positionals ((x, (y, z))).
def star(x, y, z):
    return star_inner(a=x), star_inner(x, y, z)


# Calls with *-unpacking at the call site; `y` and `z` are expected to be
# tuples (the test passes one-element tuples).
def star_call(x, y, z):
    return star_inner(x, *y), star_inner(*z)


@jit(nopython=True)
def argcast_inner(a, b):
    if b:
        # Here `a` is unified to int64 (from int32 originally)
        a = int64(0)
    return a


# Passes an int32 into argcast_inner, whose body may rebind the argument
# variable to an int64 (regression scenario for issue #1488).
def argcast(a, b):
    return argcast_inner(int32(a), b)


# Generated function: the Python body runs on argument *types* and returns
# the implementation to compile. Complex `x` gets `x + y`, anything else
# gets `x - y`; `z` is passed through. Defaults: y=5, z=6.
# NOTE(review): `generated_jit` is deprecated in newer Numba releases;
# confirm it is still available in the Numba version this suite targets.
@generated_jit(nopython=True)
def generated_inner(x, y=5, z=6):
    if isinstance(x, types.Complex):
        def impl(x, y, z):
            return x + y, z
    else:
        def impl(x, y, z):
            return x - y, z
    return impl


# Calls the generated function with `y` left at its default (5) and `z`
# given by keyword.
def call_generated(a, b):
    return generated_inner(a, z=b)


class TestNestedCall(TestCase):

    def compile_func(self, pyfunc, objmode=False):
        """
        Compile `pyfunc` and return `(compiled_function, check)`.

        `check(*args, **kwargs)` calls both the pure-Python `pyfunc` and
        the compiled version with the same arguments and asserts the
        results are precisely equal. Compilation uses `forceobj=True`
        when `objmode` is true, otherwise `nopython=True`.
        """
        def check(*args, **kwargs):
            expected = pyfunc(*args, **kwargs)
            # `f` here is the local compiled function assigned below (it
            # shadows the module-level `f`); it is looked up when `check`
            # is called, by which time it has been bound.
            result = f(*args, **kwargs)
            self.assertPreciseEqual(result, expected)
        flags = dict(forceobj=True) if objmode else dict(nopython=True)
        f = jit(**flags)(pyfunc)
        return f, check

    def test_boolean_return(self):
        """
        Test that a boolean returned by a nested jitted call can be used
        directly as a branch condition in the caller.
        """
        @jit(nopython=True)
        def inner(x):
            return not x

        @jit(nopython=True)
        def outer(x):
            if inner(x):
                return True
            else:
                return False

        self.assertFalse(outer(True))
        self.assertTrue(outer(False))

    # The `objmode` parameter lets the *_objmode variants below reuse this
    # test body; unittest itself calls it with the default (nopython).
    def test_named_args(self, objmode=False):
        """
        Test a nested function call with named (keyword) arguments.
        """
        cfunc, check = self.compile_func(f, objmode)
        check(1, 2, 3)
        check(1, y=2, z=3)

    def test_named_args_objmode(self):
        self.test_named_args(objmode=True)

    def test_default_args(self, objmode=False):
        """
        Test a nested function call using default argument values.
        """
        cfunc, check = self.compile_func(g, objmode)
        check(1, 2, 3)
        check(1, y=2, z=3)

    def test_default_args_objmode(self):
        self.test_default_args(objmode=True)

    def test_star_args(self):
        """
        Test a nested function call to a function with *args in its signature.
        """
        cfunc, check = self.compile_func(star)
        check(1, 2, 3)

    def test_star_call(self, objmode=False):
        """
        Test a function call with a *args.
        """
        cfunc, check = self.compile_func(star_call, objmode)
        # The second and third arguments are tuples unpacked with `*`.
        check(1, (2,), (3,))

    def test_star_call_objmode(self):
        self.test_star_call(objmode=True)

    def test_argcast(self):
        """
        Issue #1488: implicitly casting an argument variable should not
        break nested calls.
        """
        cfunc, check = self.compile_func(argcast)
        # b=0 leaves `a` untouched; b=1 takes the int64 rebinding branch.
        check(1, 0)
        check(1, 1)

    def test_call_generated(self):
        """
        Test a nested function call to a generated jit function.
        """
        cfunc = jit(nopython=True)(call_generated)
        # Non-complex input takes the `x - y` branch with the default y=5.
        self.assertPreciseEqual(cfunc(1, 2), (-4, 2))
        # Complex input takes the `x + y` branch with the default y=5.
        self.assertPreciseEqual(cfunc(1j, 2), (1j + 5, 2))


if __name__ == '__main__':
    unittest.main()