from collections import namedtuple

import numpy as np

from numba.tests.support import (TestCase, MemoryLeakMixin,
                                 skip_parfors_unsupported, captured_stdout)
from numba import njit, typed, literal_unroll, prange
from numba.core import types, errors, ir
from numba.testing import unittest
from numba.core.extending import overload
from numba.core.compiler_machinery import (PassManager, register_pass,
                                           FunctionPass, AnalysisPass)
from numba.core.compiler import CompilerBase
from numba.core.untyped_passes import (FixupArgs, TranslateByteCode,
                                       IRProcessing, InlineClosureLikes,
                                       SimplifyCFG, IterLoopCanonicalization,
                                       LiteralUnroll, PreserveIR)
from numba.core.typed_passes import (NopythonTypeInference, IRLegalization,
                                     NoPythonBackend, PartialTypeInference,
                                     NativeLowering)
from numba.core.ir_utils import (compute_cfg_from_blocks, flatten_labels)
from numba.core.types.functions import _header_lead

_X_GLOBAL = (10, 11)


class TestLiteralTupleInterpretation(MemoryLeakMixin, TestCase):
    """Checks that tuples built from constants are typed as tuples whose
    members are all ``types.Literal`` instances."""

    def check(self, func, var):
        """Assert that local ``var`` of the first compiled overload of
        ``func`` is a ``types.Tuple`` made only of literal types.

        Raises ``AssertionError`` (message containing "non literal") if any
        member type is not a literal.
        """
        cres = func.overloads[func.signatures[0]]
        ty = cres.fndesc.typemap[var]
        self.assertIsInstance(ty, types.Tuple)
        for subty in ty:
            self.assertIsInstance(subty, types.Literal, "non literal")

    def test_homogeneous_literal(self):
        @njit
        def foo():
            x = (1, 2, 3)
            return x[1]

        self.assertEqual(foo(), foo.py_func())
        self.check(foo, 'x')

    def test_heterogeneous_literal(self):
        @njit
        def foo():
            x = (1, 2, 3, 'a')
            return x[3]

        self.assertEqual(foo(), foo.py_func())
        self.check(foo, 'x')

    def test_non_literal(self):
        @njit
        def foo():
            x = (1, 2, 3, 'a', 1j)
            return x[4]

        self.assertEqual(foo(), foo.py_func())
        # the complex constant is not typed as a literal, so check() fails
        with self.assertRaises(AssertionError) as e:
            self.check(foo, 'x')

        self.assertIn("non literal", str(e.exception))


@register_pass(mutates_CFG=False, analysis_only=False)
class ResetTypeInfo(FunctionPass):
    """Compiler pass that clears ``typemap``, ``return_type`` and
    ``calltypes`` on the compiler state.

    Used by the partial-typing pipeline in ``TestLoopCanonicalisation`` to
    discard the partial typing results before ``NopythonTypeInference`` runs.
    """
    _name = "reset_the_type_information"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        state.typemap = None
        state.return_type = None
        state.calltypes = None
        # the state was modified
        return True


class TestLoopCanonicalisation(MemoryLeakMixin, TestCase):
    """Checks the ``IterLoopCanonicalization`` pass by compiling functions
    with custom pipelines that do / do not include it, then comparing the
    preserved IR (CFG shape) and the call types recorded in the fndesc."""

    def get_pipeline(use_canonicaliser, use_partial_typing=False):
        # NOTE: called directly in the class body (not via an instance) to
        # build the compiler classes below, hence no ``self`` parameter.
        class NewCompiler(CompilerBase):

            def define_pipelines(self):
                pm = PassManager("custom_pipeline")

                # untyped
                pm.add_pass(TranslateByteCode, "analyzing bytecode")
                pm.add_pass(IRProcessing, "processing IR")
                pm.add_pass(InlineClosureLikes,
                            "inline calls to locally defined closures")
                if use_partial_typing:
                    pm.add_pass(PartialTypeInference, "do partial typing")
                if use_canonicaliser:
                    pm.add_pass(IterLoopCanonicalization, "Canonicalise loops")
                pm.add_pass(SimplifyCFG, "Simplify the CFG")

                # typed
                if use_partial_typing:
                    pm.add_pass(ResetTypeInfo, "resets the type info state")

                pm.add_pass(NopythonTypeInference, "nopython frontend")

                # legalise
                pm.add_pass(IRLegalization, "ensure IR is legal")

                # preserve (stores the IR in metadata['preserved_ir'])
                pm.add_pass(PreserveIR, "save IR for later inspection")

                # lower
                pm.add_pass(NativeLowering, "native lowering")
                pm.add_pass(NoPythonBackend, "nopython mode backend")

                # finalise the contents
                pm.finalize()

                return [pm]
        return NewCompiler

    # generate variants
    LoopIgnoringCompiler = get_pipeline(False)
    LoopCanonicalisingCompiler = get_pipeline(True)
    TypedLoopCanonicalisingCompiler = get_pipeline(True, True)

    def _compare_cfg(self, a, b):
        """Assert the two function IRs have the same control flow graph
        (compared after relabelling blocks with ``flatten_labels``)."""
        a_cfg = compute_cfg_from_blocks(flatten_labels(a.blocks))
        b_cfg = compute_cfg_from_blocks(flatten_labels(b.blocks))
        self.assertEqual(a_cfg, b_cfg)

    @staticmethod
    def _find_getX(fd, op):
        """Return the ``ir.Expr`` keys of ``fd.calltypes`` whose ``op`` is
        ``op`` (e.g. "getiter", "getitem")."""
        return [x for x in fd.calltypes.keys()
                if isinstance(x, ir.Expr) and x.op == op]

    def test_simple_loop_in_depth(self):
        """ This heavily checks a simple loop transform """

        def get_info(pipeline):
            @njit(pipeline_class=pipeline)
            def foo(tup):
                acc = 0
                for i in tup:
                    acc += i
                return acc

            x = (1, 2, 3)
            self.assertEqual(foo(x), foo.py_func(x))
            cres = foo.overloads[foo.signatures[0]]
            func_ir = cres.metadata['preserved_ir']
            return func_ir, cres.fndesc

        ignore_loops_ir, ignore_loops_fndesc = \
            get_info(self.LoopIgnoringCompiler)
        canonicalise_loops_ir, canonicalise_loops_fndesc = \
            get_info(self.LoopCanonicalisingCompiler)

        # check CFG is the same
        self._compare_cfg(ignore_loops_ir, canonicalise_loops_ir)

        # check there's three more call types in the canonicalised one:
        # len(tuple arg)
        # range(of the len() above)
        # getitem(tuple arg, index)
        self.assertEqual(len(ignore_loops_fndesc.calltypes) + 3,
                         len(canonicalise_loops_fndesc.calltypes))

        il_getiters = self._find_getX(ignore_loops_fndesc, "getiter")
        self.assertEqual(len(il_getiters), 1)  # tuple iterator

        cl_getiters = self._find_getX(canonicalise_loops_fndesc, "getiter")
        self.assertEqual(len(cl_getiters), 1)  # loop range iterator

        cl_getitems = self._find_getX(canonicalise_loops_fndesc, "getitem")
        self.assertEqual(len(cl_getitems), 1)  # tuple getitem induced by loop

        # check the value of the untransformed IR getiter is now the value of
        # the transformed getitem
        self.assertEqual(il_getiters[0].value.name, cl_getitems[0].value.name)

        # check the type of the transformed IR getiter is a range iter
        range_inst = \
            canonicalise_loops_fndesc.calltypes[cl_getiters[0]].args[0]
        self.assertIsInstance(range_inst, types.RangeType)

    def test_transform_scope(self):
        """ This checks the transform, when there's no typemap, will happily
        transform a loop on something that's not tuple-like
        """
        def get_info(pipeline):
            @njit(pipeline_class=pipeline)
            def foo():
                acc = 0
                for i in [1, 2, 3]:
                    acc += i
                return acc

            self.assertEqual(foo(), foo.py_func())
            cres = foo.overloads[foo.signatures[0]]
            func_ir = cres.metadata['preserved_ir']
            return func_ir, cres.fndesc

        ignore_loops_ir, ignore_loops_fndesc = \
            get_info(self.LoopIgnoringCompiler)
        canonicalise_loops_ir, canonicalise_loops_fndesc = \
            get_info(self.LoopCanonicalisingCompiler)

        # check CFG is the same
        self._compare_cfg(ignore_loops_ir, canonicalise_loops_ir)

        # check there's three more call types in the canonicalised one:
        # len(literal list)
        # range(of the len() above)
        # getitem(literal list arg, index)
        self.assertEqual(len(ignore_loops_fndesc.calltypes) + 3,
                         len(canonicalise_loops_fndesc.calltypes))

        il_getiters = self._find_getX(ignore_loops_fndesc, "getiter")
        self.assertEqual(len(il_getiters), 1)  # list iterator

        cl_getiters = self._find_getX(canonicalise_loops_fndesc, "getiter")
        self.assertEqual(len(cl_getiters), 1)  # loop range iterator

        cl_getitems = self._find_getX(canonicalise_loops_fndesc, "getitem")
        self.assertEqual(len(cl_getitems), 1)  # list getitem induced by loop

        # check the value of the untransformed IR getiter is now the value of
        # the transformed getitem
        self.assertEqual(il_getiters[0].value.name, cl_getitems[0].value.name)

        # check the type of the transformed IR getiter is a range iter
        range_inst = \
            canonicalise_loops_fndesc.calltypes[cl_getiters[0]].args[0]
        self.assertIsInstance(range_inst, types.RangeType)

    @unittest.skip("Waiting for pass to be enabled for all tuples")
    def test_influence_of_typed_transform(self):
        """ This heavily checks a typed transformation only impacts tuple
        induced loops"""

        def get_info(pipeline):
            @njit(pipeline_class=pipeline)
            def foo(tup):
                acc = 0
                for i in range(4):
                    for y in tup:
                        for j in range(3):
                            acc += 1
                return acc

            x = (1, 2, 3)
            self.assertEqual(foo(x), foo.py_func(x))
            cres = foo.overloads[foo.signatures[0]]
            # The custom pipeline stores the IR via the PreserveIR pass, which
            # uses the 'preserved_ir' metadata key (not 'func_ir').
            func_ir = cres.metadata['preserved_ir']
            return func_ir, cres.fndesc

        ignore_loops_ir, ignore_loops_fndesc = \
            get_info(self.LoopIgnoringCompiler)
        canonicalise_loops_ir, canonicalise_loops_fndesc = \
            get_info(self.TypedLoopCanonicalisingCompiler)

        # check CFG is the same
        self._compare_cfg(ignore_loops_ir, canonicalise_loops_ir)

        # check there's three more call types in the canonicalised one:
        # len(tuple arg)
        # range(of the len() above)
        # getitem(tuple arg, index)
        self.assertEqual(len(ignore_loops_fndesc.calltypes) + 3,
                         len(canonicalise_loops_fndesc.calltypes))

        il_getiters = self._find_getX(ignore_loops_fndesc, "getiter")
        self.assertEqual(len(il_getiters), 3)  # 1 * tuple + 2 * loop range

        cl_getiters = self._find_getX(canonicalise_loops_fndesc, "getiter")
        self.assertEqual(len(cl_getiters), 3)  # 3 * loop range iterator

        cl_getitems = self._find_getX(canonicalise_loops_fndesc, "getitem")
        self.assertEqual(len(cl_getitems), 1)  # tuple getitem induced by loop

        # check the value of the untransformed IR getiter is now the value of
        # the transformed getitem
        self.assertEqual(il_getiters[1].value.name, cl_getitems[0].value.name)

        # check the type of the transformed IR getiter's are all range iter
        for x in cl_getiters:
            range_inst = canonicalise_loops_fndesc.calltypes[x].args[0]
            self.assertIsInstance(range_inst, types.RangeType)

    def test_influence_of_typed_transform_literal_unroll(self):
        """ This heavily checks a typed transformation only impacts loops with
        literal_unroll marker"""

        def get_info(pipeline):
            @njit(pipeline_class=pipeline)
            def foo(tup):
                acc = 0
                for i in range(4):
                    for y in literal_unroll(tup):
                        for j in range(3):
                            acc += 1
                return acc

            x = (1, 2, 3)
            self.assertEqual(foo(x), foo.py_func(x))
            cres = foo.overloads[foo.signatures[0]]
            func_ir = cres.metadata['preserved_ir']
            return func_ir, cres.fndesc

        ignore_loops_ir, ignore_loops_fndesc = \
            get_info(self.LoopIgnoringCompiler)
        canonicalise_loops_ir, canonicalise_loops_fndesc = \
            get_info(self.TypedLoopCanonicalisingCompiler)

        # check CFG is the same
        self._compare_cfg(ignore_loops_ir, canonicalise_loops_ir)

        # check there's three more call types in the canonicalised one:
        # len(tuple arg)
        # range(of the len() above)
        # getitem(tuple arg, index)
        self.assertEqual(len(ignore_loops_fndesc.calltypes) + 3,
                         len(canonicalise_loops_fndesc.calltypes))

        il_getiters = self._find_getX(ignore_loops_fndesc, "getiter")
        self.assertEqual(len(il_getiters), 3)  # 1 * tuple + 2 * loop range

        cl_getiters = self._find_getX(canonicalise_loops_fndesc, "getiter")
        self.assertEqual(len(cl_getiters), 3)  # 3 * loop range iterator

        cl_getitems = self._find_getX(canonicalise_loops_fndesc, "getitem")
        self.assertEqual(len(cl_getitems), 1)  # tuple getitem induced by loop

        # check the value of the untransformed IR getiter is now the value of
        # the transformed getitem
        self.assertEqual(il_getiters[1].value.name, cl_getitems[0].value.name)

        # check the type of the transformed IR getiter's are all range iter
        for x in cl_getiters:
            range_inst = canonicalise_loops_fndesc.calltypes[x].args[0]
            self.assertIsInstance(range_inst, types.RangeType)

    @unittest.skip("Waiting for pass to be enabled for all tuples")
    def test_lots_of_loops(self):
        """ This checks the transform is applied to each of several
        tuple-induced loops in one function """

        def get_info(pipeline):
            @njit(pipeline_class=pipeline)
            def foo(tup):
                acc = 0
                for i in tup:
                    acc += i
                    for j in tup + (4, 5, 6):
                        acc += 1 - j
                        if j > 5:
                            break
                    else:
                        acc -= 2
                for i in tup:
                    acc -= i % 2

                return acc

            x = (1, 2, 3)
            self.assertEqual(foo(x), foo.py_func(x))
            cres = foo.overloads[foo.signatures[0]]
            func_ir = cres.metadata['preserved_ir']
            return func_ir, cres.fndesc

        ignore_loops_ir, ignore_loops_fndesc = \
            get_info(self.LoopIgnoringCompiler)
        canonicalise_loops_ir, canonicalise_loops_fndesc = \
            get_info(self.LoopCanonicalisingCompiler)

        # check CFG is the same
        self._compare_cfg(ignore_loops_ir, canonicalise_loops_ir)

        # check there's three * N more call types in the canonicalised one:
        # len(tuple arg)
        # range(of the len() above)
        # getitem(tuple arg, index)
        self.assertEqual(len(ignore_loops_fndesc.calltypes) + 3 * 3,
                         len(canonicalise_loops_fndesc.calltypes))

    def test_inlined_loops(self):
        """ Checks a loop appearing from a closure """

        def get_info(pipeline):
            @njit(pipeline_class=pipeline)
            def foo(tup):
                def bar(n):
                    acc = 0
                    for i in range(n):
                        acc += 1
                    return acc

                acc = 0
                for i in tup:
                    acc += i
                    acc += bar(i)

                return acc

            x = (1, 2, 3)
            self.assertEqual(foo(x), foo.py_func(x))
            cres = foo.overloads[foo.signatures[0]]
            func_ir = cres.metadata['preserved_ir']
            return func_ir, cres.fndesc

        ignore_loops_ir, ignore_loops_fndesc = \
            get_info(self.LoopIgnoringCompiler)
        canonicalise_loops_ir, canonicalise_loops_fndesc = \
            get_info(self.LoopCanonicalisingCompiler)

        # check CFG is the same
        self._compare_cfg(ignore_loops_ir, canonicalise_loops_ir)

        # check there's 2 * N - 1 more call types in the canonicalised one:
        # The -1 comes from the closure being inlined and the call removed.
        # len(tuple arg)
        # range(of the len() above)
        # getitem(tuple arg, index)
        self.assertEqual(len(ignore_loops_fndesc.calltypes) + 5,
                         len(canonicalise_loops_fndesc.calltypes))


class TestMixedTupleUnroll(MemoryLeakMixin, TestCase):
    """Tests for ``literal_unroll`` over homogeneous and heterogeneous
    tuples: compiled results are compared against the pure-Python function,
    and unsupported patterns are checked for the expected error."""

    def test_01(self):
        # test a case which is already in loop canonical form
        @njit
        def foo(idx, z):
            a = (12, 12.7, 3j, 4, z, 2 * z)
            acc = 0
            for i in range(len(literal_unroll(a))):
                acc += a[i]
                if acc.real < 26:
                    acc -= 1
                else:
                    break
            return acc

        f = 9
        k = f

        self.assertEqual(foo(2, k), foo.py_func(2, k))

    def test_02(self):
        # same as test_01 but without the explicit loop canonicalisation
        @njit
        def foo(idx, z):
            x = (12, 12.7, 3j, 4, z, 2 * z)
            acc = 0
            for a in literal_unroll(x):
                acc += a
                if acc.real < 26:
                    acc -= 1
                else:
                    break
            return acc

        f = 9
        k = f

        self.assertEqual(foo(2, k), foo.py_func(2, k))

    def test_03(self):
        # two unrolls
        @njit
        def foo(idx, z):
            x = (12, 12.7, 3j, 4, z, 2 * z)
            y = ('foo', z, 2 * z)
            acc = 0
            for a in literal_unroll(x):
                acc += a
                if acc.real < 26:
                    acc -= 1
                else:
                    for t in literal_unroll(y):
                        acc += t is False
                    break
            return acc

        f = 9
        k = f

        self.assertEqual(foo(2, k), foo.py_func(2, k))

    def test_04(self):
        # mixed ref counted types
        @njit
        def foo(tup):
            acc = 0
            for a in literal_unroll(tup):
                acc += a.sum()
            return acc

        n = 10
        tup = (np.ones((n,)), np.ones((n, n)), np.ones((n, n, n)))
        self.assertEqual(foo(tup), foo.py_func(tup))

    def test_05(self):
        # mix unroll and static_getitem
        @njit
        def foo(tup1, tup2):
            acc = 0
            for a in literal_unroll(tup1):
                if a == 'a':
                    acc += tup2[0].sum()
                elif a == 'b':
                    acc += tup2[1].sum()
                elif a == 'c':
                    acc += tup2[2].sum()
                elif a == 12:
                    acc += tup2[3].sum()
                elif a == 3j:
                    acc += tup2[4].sum()
                else:
                    raise RuntimeError("Unreachable")
            return acc

        n = 10
        tup1 = ('a', 'b', 'c', 12, 3j,)
        tup2 = (np.ones((n,)), np.ones((n, n)), np.ones((n, n, n)),
                np.ones((n, n, n, n)), np.ones((n, n, n, n, n)))
        self.assertEqual(foo(tup1, tup2), foo.py_func(tup1, tup2))

    @unittest.skip("needs more clever branch prune")
    def test_06(self):
        # This won't work because both sides of the branch need typing as
        # neither can be pruned by the current pruner
        @njit
        def foo(tup):
            acc = 0
            str_buf = typed.List.empty_list(types.unicode_type)
            for a in literal_unroll(tup):
                if a == 'a':
                    str_buf.append(a)
                else:
                    acc += a
            return acc

        tup = ('a', 12)
        self.assertEqual(foo(tup), foo.py_func(tup))

    def test_07(self):
        # A mix bag of stuff as an arg to a function that unifies as `intp`.
        @njit
        def foo(tup):
            acc = 0
            for a in literal_unroll(tup):
                acc += len(a)
            return acc

        n = 10
        tup = (np.ones((n,)), np.ones((n, n)), "ABCDEFGHJI", (1, 2, 3),
               (1, 'foo', 2, 'bar'), {3, 4, 5, 6, 7})
        self.assertEqual(foo(tup), foo.py_func(tup))

    def test_08(self):
        # dispatch to functions
        @njit
        def foo(tup1, tup2):
            acc = 0
            for a in literal_unroll(tup1):
                if a == 'a':
                    acc += tup2[0]()
                elif a == 'b':
                    acc += tup2[1]()
                elif a == 'c':
                    acc += tup2[2]()
            return acc

        def gen(x):
            def impl():
                return x
            return njit(impl)

        tup1 = ('a', 'b', 'c', 12, 3j, ('f',))
        tup2 = (gen(1), gen(2), gen(3))
        self.assertEqual(foo(tup1, tup2), foo.py_func(tup1, tup2))

    def test_09(self):
        # illegal RHS, has a mixed tuple being index dynamically
        @njit
        def foo(tup1, tup2):
            acc = 0
            idx = 0
            for a in literal_unroll(tup1):
                if a == 'a':
                    acc += tup2[idx]
                elif a == 'b':
                    acc += tup2[idx]
                elif a == 'c':
                    acc += tup2[idx]
                idx += 1
            return idx, acc

        tup1 = ('a', 'b', 'c')
        tup2 = (1j, 1, 2)

        with self.assertRaises(errors.TypingError) as raises:
            foo(tup1, tup2)

        self.assertIn(_header_lead, str(raises.exception))

    def test_10(self):
        # dispatch on literals triggering @overload resolution
        def dt(value):
            if value == "apple":
                return 1
            elif value == "orange":
                return 2
            elif value == "banana":
                return 3
            elif value == 0xca11ab1e:
                return 0x5ca1ab1e + value

        @overload(dt, inline='always')
        def ol_dt(li):
            if isinstance(li, types.StringLiteral):
                value = li.literal_value
                if value == "apple":
                    def impl(li):
                        return 1
                elif value == "orange":
                    def impl(li):
                        return 2
                elif value == "banana":
                    def impl(li):
                        return 3
                return impl
            elif isinstance(li, types.IntegerLiteral):
                value = li.literal_value
                if value == 0xca11ab1e:
                    def impl(li):
                        # close over the dispatcher :)
                        return 0x5ca1ab1e + value
                return impl

        @njit
        def foo():
            acc = 0
            # 3390155550 == 0xca11ab1e
            for t in literal_unroll(('apple', 'orange', 'banana', 3390155550)):
                acc += dt(t)
            return acc

        self.assertEqual(foo(), foo.py_func())

    def test_11(self):

        @njit
        def foo():
            x = []
            z = ('apple', 'orange', 'banana')
            for i in range(len(literal_unroll(z))):
                t = z[i]
                if t == "apple":
                    x.append("0")
                elif t == "orange":
                    x.append(t)
                elif t == "banana":
                    x.append("2.0")
            return x

        self.assertEqual(foo(), foo.py_func())

    def test_11a(self):

        @njit
        def foo():
            x = typed.List()
            z = ('apple', 'orange', 'banana')
            for i in range(len(literal_unroll(z))):
                t = z[i]
                if t == "apple":
                    x.append("0")
                elif t == "orange":
                    x.append(t)
                elif t == "banana":
                    x.append("2.0")
            return x

        self.assertEqual(foo(), foo.py_func())

    def test_12(self):
        # unroll the same target twice
        @njit
        def foo(idx, z):
            a = (12, 12.7, 3j, 4, z, 2 * z)
            acc = 0
            for i in literal_unroll(a):
                acc += i
                if acc.real < 26:
                    acc -= 1
                else:
                    for x in literal_unroll(a):
                        acc += x
                    break
            if a[0] < 23:
                acc += 2
            return acc

        f = 9
        k = f

        self.assertEqual(foo(2, k), foo.py_func(2, k))

    def test_13(self):
        # nesting unrolls is illegal
        @njit
        def foo(idx, z):
            a = (12, 12.7, 3j, 4, z, 2 * z)
            acc = 0
            for i in literal_unroll(a):
                acc += i
                if acc.real < 26:
                    acc -= 1
                else:
                    for x in literal_unroll(a):
                        for j in literal_unroll(a):
                            acc += j
                        acc += x
                for x in literal_unroll(a):
                    acc += x
            for x in literal_unroll(a):
                acc += x
            if a[0] < 23:
                acc += 2
            return acc

        f = 9
        k = f

        with self.assertRaises(errors.UnsupportedError) as raises:
            foo(2, k)

        self.assertIn("Nesting of literal_unroll is unsupported",
                      str(raises.exception))

    def test_14(self):
        # unituple unroll can return derivative of the induction var

        @njit
        def foo():
            x = (1, 2, 3, 4)
            acc = 0
            for a in literal_unroll(x):
                acc += a
            return a

        self.assertEqual(foo(), foo.py_func())
def test_15(self): # mixed tuple unroll cannot return derivative of the induction var @njit def foo(x): acc = 0 for a in literal_unroll(x): acc += len(a) return a n = 5 tup = (np.ones((n,)), np.ones((n, n)), "ABCDEFGHJI", (1, 2, 3), (1, 'foo', 2, 'bar'), {3, 4, 5, 6, 7}) with self.assertRaises(errors.TypingError) as raises: foo(tup) self.assertIn("Cannot unify", str(raises.exception)) def test_16(self): # unituple slice and unroll is ok def dt(value): if value == 1000: return "a" elif value == 2000: return "b" elif value == 3000: return "c" elif value == 4000: return "d" @overload(dt, inline='always') def ol_dt(li): if isinstance(li, types.IntegerLiteral): value = li.literal_value if value == 1000: def impl(li): return "a" elif value == 2000: def impl(li): return "b" elif value == 3000: def impl(li): return "c" elif value == 4000: def impl(li): return "d" return impl @njit def foo(): x = (1000, 2000, 3000, 4000) acc = "" for a in literal_unroll(x[:2]): acc += dt(a) return acc self.assertEqual(foo(), foo.py_func()) def test_17(self): # mixed tuple slice and unroll is ok def dt(value): if value == 1000: return "a" elif value == 2000: return "b" elif value == 3000: return "c" elif value == 4000: return "d" elif value == 'f': return "EFF" @overload(dt, inline='always') def ol_dt(li): if isinstance(li, types.IntegerLiteral): value = li.literal_value if value == 1000: def impl(li): return "a" elif value == 2000: def impl(li): return "b" elif value == 3000: def impl(li): return "c" elif value == 4000: def impl(li): return "d" return impl elif isinstance(li, types.StringLiteral): value = li.literal_value if value == 'f': def impl(li): return "EFF" return impl @njit def foo(): x = (1000, 2000, 3000, 'f') acc = "" for a in literal_unroll(x[1:]): acc += dt(a) return acc self.assertEqual(foo(), foo.py_func()) def test_18(self): # unituple backwards slice @njit def foo(): x = (1000, 2000, 3000, 4000, 5000, 6000) count = 0 for a in literal_unroll(x[::-1]): count += 1 if a < 
3000: break return count self.assertEqual(foo(), foo.py_func()) def test_19(self): # mixed bag of refcounted @njit def foo(): acc = 0 l1 = [1, 2, 3, 4] l2 = [10, 20] tup = (l1, l2) a1 = np.arange(20) a2 = np.ones(5, dtype=np.complex128) tup = (l1, a1, l2, a2) for t in literal_unroll(tup): acc += len(t) return acc self.assertEqual(foo(), foo.py_func()) def test_20(self): # testing partial type inference survives as the list append in the # unrolled version is full inferable @njit def foo(): l = [] a1 = np.arange(20) a2 = np.ones(5, dtype=np.complex128) tup = (a1, a2) for t in literal_unroll(tup): l.append(t.sum()) return l self.assertEqual(foo(), foo.py_func()) def test_21(self): # unroll in closure that gets inlined @njit def foo(z): b = (23, 23.9, 6j, 8) def bar(): acc = 0 for j in literal_unroll(b): acc += j return acc outer_acc = 0 for x in (1, 2, 3, 4): outer_acc += bar() + x return outer_acc f = 9 k = f self.assertEqual(foo(k), foo.py_func(k)) def test_22(self): @njit def foo(z): a = (12, 12.7, 3j, 4, z, 2 * z) b = (23, 23.9, 6j, 8) def bar(): acc = 0 for j in literal_unroll(b): acc += j return acc acc = 0 # this loop is induced in `x` but `x` is not used, there is a nest # here by virtue of inlining for x in literal_unroll(a): acc += bar() return acc f = 9 k = f self.assertEqual(foo(k), foo.py_func(k)) def test_23(self): # unroll from closure that ends up banned as it leads to nesting @njit def foo(z): b = (23, 23.9, 6j, 8) def bar(): acc = 0 for j in literal_unroll(b): acc += j return acc outer_acc = 0 # this drives an inlined literal_unroll loop but also has access to # the induction variable, this is a nested literal_unroll so is # banned for x in literal_unroll(b): outer_acc += bar() + x return outer_acc f = 9 k = f with self.assertRaises(errors.UnsupportedError) as raises: foo(k) self.assertIn("Nesting of literal_unroll is unsupported", str(raises.exception)) def test_24(self): # unroll something unsupported @njit def foo(): for x in 
literal_unroll("ABCDE"): print(x) with self.assertRaises(errors.UnsupportedError) as raises: foo() msg = "argument should be a tuple or a list of constant values" self.assertIn(msg, str(raises.exception)) def test_25(self): # use unroll by reference/alias @njit def foo(): val = literal_unroll(((1, 2, 3), (2j, 3j), [1, 2], "xyz")) alias1 = val alias2 = alias1 lens = [] for x in alias2: lens.append(len(x)) return lens self.assertEqual(foo(), foo.py_func()) def test_26(self): # var defined in unrolled body escapes # untouched variable is untouched # read only variable is only read # mutated is muted correctly @njit def foo(z): a = (12, 12.7, 3j, 4, z, 2 * z) acc = 0 count = 0 untouched = 54 read_only = 17 mutated = np.empty((len(a),), dtype=np.complex128) for x in literal_unroll(a): acc += x mutated[count] = x count += 1 escape = count + read_only return escape, acc, untouched, read_only, mutated f = 9 k = f self.assertPreciseEqual(foo(k), foo.py_func(k)) @skip_parfors_unsupported def test_27(self): # parfors loop in unrolled loop @njit(parallel=True) def foo(z): a = (12, 12.7, 3j, 4, z, 2 * z) acc = 0 for x in literal_unroll(a): for k in prange(10): acc += 1 return acc f = 9 k = f self.assertEqual(foo(k), foo.py_func(k)) @skip_parfors_unsupported def test_28(self): # parfors reducing on the unrolled induction var @njit(parallel=True) def foo(z): a = (12, 12.7, 3j, 4, z, 2 * z) acc = 0 for x in literal_unroll(a): for k in prange(10): acc += x return acc f = 9 k = f # summation is unstable np.testing.assert_allclose(foo(k), foo.py_func(k)) @skip_parfors_unsupported def test_29(self): # This "works" but parfors is not producing a parallel loop # TODO: fix @njit(parallel=True) def foo(z): a = (12, 12.7, 3j, 4, z, 2 * z) acc = 0 for k in prange(10): for x in literal_unroll(a): acc += x return acc f = 9 k = f self.assertEqual(foo(k), foo.py_func(k)) def test_30(self): # function escaping containing an unroll @njit def foo(): const = 1234 def bar(t): acc = 0 a = (12, 12.7, 
3j, 4) for x in literal_unroll(a): acc += x + const return acc, t return [x for x in map(bar, (1, 2))] self.assertEqual(foo(), foo.py_func()) def test_31(self): # this is testing that generators can survive partial typing # invalid function escaping, map uses zip which can't handle the mixed # tuple @njit def foo(): const = 1234 def bar(t): acc = 0 a = (12, 12.7, 3j, 4) for x in literal_unroll(a): acc += x + const return acc, t return [x for x in map(bar, (1, 2j))] with self.assertRaises(errors.TypingError) as raises: foo() self.assertIn(_header_lead, str(raises.exception)) self.assertIn("zip", str(raises.exception)) def test_32(self): # test yielding from an unroll @njit def gen(a): for x in literal_unroll(a): yield x @njit def foo(): return [x for x in gen((1, 2.3, 4j,))] self.assertEqual(foo(), foo.py_func()) def test_33(self): # test yielding from unroll in escaping function that is consumed and # yields @njit def consumer(func, arg): yield func(arg) def get(cons): @njit def foo(): def gen(a): for x in literal_unroll(a): yield x return [next(x) for x in cons(gen, (1, 2.3, 4j,))] return foo cfunc = get(consumer) pyfunc = get(consumer.py_func).py_func self.assertEqual(cfunc(), pyfunc()) def test_34(self): # mixed bag, redefinition of tuple @njit def foo(): acc = 0 l1 = [1, 2, 3, 4] l2 = [10, 20] if acc - 2 > 3: tup = (l1, l2) else: a1 = np.arange(20) a2 = np.ones(5, dtype=np.complex128) tup = (l1, a1, l2, a2) for t in literal_unroll(tup): acc += len(t) return acc with self.assertRaises(errors.UnsupportedError) as raises: foo() self.assertIn("Invalid use of", str(raises.exception)) self.assertIn("found multiple definitions of variable", str(raises.exception)) class TestConstListUnroll(MemoryLeakMixin, TestCase): def test_01(self): @njit def foo(): a = [12, 12.7, 3j, 4] acc = 0 for i in range(len(literal_unroll(a))): acc += a[i] if acc.real < 26: acc -= 1 else: break return acc self.assertEqual(foo(), foo.py_func()) def test_02(self): # same as test_1 but without 
the explicit loop canonicalisation @njit def foo(): x = [12, 12.7, 3j, 4] acc = 0 for a in literal_unroll(x): acc += a if acc.real < 26: acc -= 1 else: break return acc self.assertEqual(foo(), foo.py_func()) def test_03(self): # two unrolls @njit def foo(): x = [12, 12.7, 3j, 4] y = ['foo', 8] acc = 0 for a in literal_unroll(x): acc += a if acc.real < 26: acc -= 1 else: for t in literal_unroll(y): acc += t is False break return acc self.assertEqual(foo(), foo.py_func()) def test_04(self): # two unrolls, one is a const list, one is a tuple @njit def foo(): x = [12, 12.7, 3j, 4] y = ('foo', 8) acc = 0 for a in literal_unroll(x): acc += a if acc.real < 26: acc -= 1 else: for t in literal_unroll(y): acc += t is False break return acc self.assertEqual(foo(), foo.py_func()) def test_05(self): # illegal, list has to be const @njit def foo(tup1, tup2): acc = 0 for a in literal_unroll(tup1): if a[0] > 1: acc += tup2[0].sum() return acc n = 10 tup1 = [np.zeros(10), np.zeros(10)] tup2 = (np.ones((n,)), np.ones((n, n)), np.ones((n, n, n)), np.ones((n, n, n, n)), np.ones((n, n, n, n, n))) with self.assertRaises(errors.UnsupportedError) as raises: foo(tup1, tup2) msg = "Invalid use of literal_unroll with a function argument" self.assertIn(msg, str(raises.exception)) def test_06(self): # illegal: list containing non const @njit def foo(): n = 10 tup = [np.ones((n,)), np.ones((n, n)), "ABCDEFGHJI", (1, 2, 3), (1, 'foo', 2, 'bar'), {3, 4, 5, 6, 7}] acc = 0 for a in literal_unroll(tup): acc += len(a) return acc with self.assertRaises(errors.UnsupportedError) as raises: foo() self.assertIn("Found non-constant value at position 0", str(raises.exception)) def test_7(self): # dispatch on literals triggering @overload resolution def dt(value): if value == "apple": return 1 elif value == "orange": return 2 elif value == "banana": return 3 elif value == 0xca11ab1e: return 0x5ca1ab1e + value @overload(dt, inline='always') def ol_dt(li): if isinstance(li, types.StringLiteral): value = 
li.literal_value if value == "apple": def impl(li): return 1 elif value == "orange": def impl(li): return 2 elif value == "banana": def impl(li): return 3 return impl elif isinstance(li, types.IntegerLiteral): value = li.literal_value if value == 0xca11ab1e: def impl(li): # close over the dispatcher :) return 0x5ca1ab1e + value return impl @njit def foo(): acc = 0 for t in literal_unroll(['apple', 'orange', 'banana', 3390155550]): acc += dt(t) return acc self.assertEqual(foo(), foo.py_func()) def test_8(self): @njit def foo(): x = [] z = ['apple', 'orange', 'banana'] for i in range(len(literal_unroll(z))): t = z[i] if t == "apple": x.append("0") elif t == "orange": x.append(t) elif t == "banana": x.append("2.0") return x self.assertEqual(foo(), foo.py_func()) def test_9(self): # unroll the same target twice @njit def foo(idx, z): a = [12, 12.7, 3j, 4] acc = 0 for i in literal_unroll(a): acc += i if acc.real < 26: acc -= 1 else: for x in literal_unroll(a): acc += x break if a[0] < 23: acc += 2 return acc f = 9 k = f self.assertEqual(foo(2, k), foo.py_func(2, k)) def test_10(self): # nesting unrolls is illegal @njit def foo(idx, z): a = (12, 12.7, 3j, 4, z, 2 * z) b = [12, 12.7, 3j, 4] acc = 0 for i in literal_unroll(a): acc += i if acc.real < 26: acc -= 1 else: for x in literal_unroll(a): for j in literal_unroll(b): acc += j acc += x for x in literal_unroll(a): acc += x for x in literal_unroll(a): acc += x if a[0] < 23: acc += 2 return acc f = 9 k = f with self.assertRaises(errors.UnsupportedError) as raises: foo(2, k) self.assertIn("Nesting of literal_unroll is unsupported", str(raises.exception)) def test_11(self): # homogeneous const list unroll can return derivative of the induction # var @njit def foo(): x = [1, 2, 3, 4] acc = 0 for a in literal_unroll(x): acc += a return a self.assertEqual(foo(), foo.py_func()) def test_12(self): # mixed unroll cannot return derivative of the induction var @njit def foo(): acc = 0 x = [1, 2, 'a'] for a in literal_unroll(x): 
acc += bool(a) return a with self.assertRaises(errors.TypingError) as raises: foo() self.assertIn("Cannot unify", str(raises.exception)) def test_13(self): # list slice is illegal @njit def foo(): x = [1000, 2000, 3000, 4000] acc = 0 for a in literal_unroll(x[:2]): acc += a return acc with self.assertRaises(errors.UnsupportedError) as raises: foo() self.assertIn("Invalid use of literal_unroll", str(raises.exception)) def test_14(self): # list mutate is illegal @njit def foo(): x = [1000, 2000, 3000, 4000] acc = 0 for a in literal_unroll(x): acc += a x.append(10) return acc with self.assertRaises(errors.TypingError) as raises: foo() self.assertIn("Unknown attribute 'append' of type Tuple", str(raises.exception)) class TestMore(TestCase): def test_invalid_use_of_unroller(self): @njit def foo(): x = (10, 20) r = 0 for a in literal_unroll(x, x): r += a return r with self.assertRaises(errors.UnsupportedError) as raises: foo() self.assertIn( "literal_unroll takes one argument, found 2", str(raises.exception), ) def test_non_constant_list(self): @njit def foo(y): x = [10, y] r = 0 for a in literal_unroll(x): r += a return r with self.assertRaises(errors.UnsupportedError) as raises: foo(10) self.assertIn( ("Found non-constant value at position 1 in a list argument to " "literal_unroll"), str(raises.exception) ) @unittest.skip("numba.literally not supported yet") def test_literally_constant_list(self): # FAIL. 
May need to consider it in a future PR from numba import literally @njit def foo(y): x = [10, literally(y)] r = 0 for a in literal_unroll(x): r += a return r # Found non-constant value at position 1 in a list argument to # literal_unroll foo(12) @njit def bar(): return foo(12) # Found non-constant value at position 1 in a list argument to # literal_unroll bar() @unittest.skip("inlining of foo doesn't have const prop so y isn't const") def test_inlined_unroll_list(self): @njit(inline='always') def foo(y): x = [10, y] r = 0 for a in literal_unroll(x): r += a return r @njit def bar(): return foo(12) self.assertEqual(bar(), 10 + 12) def test_unroll_tuple_arg(self): @njit def foo(y): x = (10, y) r = 0 for a in literal_unroll(x): r += a return r self.assertEqual(foo(12), foo.py_func(12)) self.assertEqual(foo(1.2), foo.py_func(1.2)) def test_unroll_tuple_arg2(self): @njit def foo(x): r = 0 for a in literal_unroll(x): r += a return r self.assertEqual(foo((12, 1.2)), foo.py_func((12, 1.2))) self.assertEqual(foo((12, 1.2)), foo.py_func((12, 1.2))) def test_unroll_tuple_alias(self): @njit def foo(): x = (10, 1.2) out = 0 for i in literal_unroll(x): j = i k = j out += j + k + i return out self.assertEqual(foo(), foo.py_func()) def test_unroll_tuple_nested(self): @njit def foo(): x = ((10, 1.2), (1j, 3.)) out = 0 for i in literal_unroll(x): for j in (i): out += j return out with self.assertRaises(errors.TypingError) as raises: foo() self.assertIn("getiter", str(raises.exception)) re = r".*Tuple\(int[0-9][0-9], float64\).*" self.assertRegexpMatches(str(raises.exception), re) def test_unroll_tuple_of_dict(self): @njit def foo(): x = {} x["a"] = 1 x["b"] = 2 y = {} y[3] = "c" y[4] = "d" for it in literal_unroll((x, y)): for k, v in it.items(): print(k, v) with captured_stdout() as stdout: foo() lines = stdout.getvalue().splitlines() self.assertEqual( lines, ['a 1', 'b 2', '3 c', '4 d'], ) def test_unroll_named_tuple(self): ABC = namedtuple('ABC', ['a', 'b', 'c']) @njit def foo(): 
abc = ABC(1, 2j, 3.4) out = 0 for i in literal_unroll(abc): out += i return out self.assertEqual(foo(), foo.py_func()) def test_unroll_named_tuple_arg(self): ABC = namedtuple('ABC', ['a', 'b', 'c']) @njit def foo(x): out = 0 for i in literal_unroll(x): out += i return out abc = ABC(1, 2j, 3.4) self.assertEqual(foo(abc), foo.py_func(abc)) def test_unroll_named_unituple(self): ABC = namedtuple('ABC', ['a', 'b', 'c']) @njit def foo(): abc = ABC(1, 2, 3) out = 0 for i in literal_unroll(abc): out += i return out self.assertEqual(foo(), foo.py_func()) def test_unroll_named_unituple_arg(self): ABC = namedtuple('ABC', ['a', 'b', 'c']) @njit def foo(x): out = 0 for i in literal_unroll(x): out += i return out abc = ABC(1, 2, 3) self.assertEqual(foo(abc), foo.py_func(abc)) def test_unroll_global_tuple(self): @njit def foo(): out = 0 for i in literal_unroll(_X_GLOBAL): out += i return out self.assertEqual(foo(), foo.py_func()) def test_unroll_freevar_tuple(self): x = (10, 11) @njit def foo(): out = 0 for i in literal_unroll(x): out += i return out self.assertEqual(foo(), foo.py_func()) def test_unroll_function_tuple(self): @njit def a(): return 1 @njit def b(): return 2 x = (a, b) @njit def foo(): out = 0 for f in literal_unroll(x): out += f() return out self.assertEqual(foo(), foo.py_func()) def test_unroll_indexing_list(self): # See issue #5477 @njit def foo(cont): i = 0 acc = 0 normal_list = [a for a in cont] heter_tuple = ('a', 25, 0.23, None) for item in literal_unroll(heter_tuple): acc += normal_list[i] i += 1 print(item) return i, acc data = [j for j in range(4)] # send stdout to nowhere, just check return values with captured_stdout(): self.assertEqual(foo(data), foo.py_func(data)) # now capture stdout for jit function and check with captured_stdout() as stdout: foo(data) lines = stdout.getvalue().splitlines() self.assertEqual( lines, ['a', '25', '0.23', 'None'], ) def test_unroller_as_freevar(self): mixed = (np.ones((1,)), np.ones((1, 1)), np.ones((1, 1, 1))) from 
numba import literal_unroll as freevar_unroll @njit def foo(): out = 0 for i in freevar_unroll(mixed): out += i.ndim return out self.assertEqual(foo(), foo.py_func()) def capture(real_pass): """ Returns a compiler pass that captures the mutation state reported by the pass used in the argument""" @register_pass(mutates_CFG=False, analysis_only=True) class ResultCapturer(AnalysisPass): _name = "capture_%s" % real_pass._name _real_pass = real_pass def __init__(self): FunctionPass.__init__(self) def run_pass(self, state): result = real_pass().run_pass(state) mutation_results = state.metadata.setdefault('mutation_results', {}) mutation_results[real_pass] = result return result return ResultCapturer class CapturingCompiler(CompilerBase): """ Simple pipeline that wraps passes with the ResultCapturer pass""" def define_pipelines(self): pm = PassManager("Capturing Compiler") def add_pass(x, y): return pm.add_pass(capture(x), y) add_pass(TranslateByteCode, "analyzing bytecode") add_pass(FixupArgs, "fix up args") add_pass(IRProcessing, "processing IR") add_pass(LiteralUnroll, "handles literal_unroll") # typing add_pass(NopythonTypeInference, "nopython frontend") # legalise add_pass(IRLegalization, "ensure IR is legal prior to lowering") # lower add_pass(NativeLowering, "native lowering") add_pass(NoPythonBackend, "nopython mode backend") pm.finalize() return [pm] class TestLiteralUnrollPassTriggering(TestCase): def test_literal_unroll_not_invoked(self): @njit(pipeline_class=CapturingCompiler) def foo(): acc = 0 for i in (1, 2, 3): acc += i return acc foo() cres = foo.overloads[foo.signatures[0]] self.assertFalse(cres.metadata['mutation_results'][LiteralUnroll]) def test_literal_unroll_is_invoked(self): @njit(pipeline_class=CapturingCompiler) def foo(): acc = 0 for i in literal_unroll((1, 2, 3)): acc += i return acc foo() cres = foo.overloads[foo.signatures[0]] self.assertTrue(cres.metadata['mutation_results'][LiteralUnroll]) def test_literal_unroll_is_invoked_via_alias(self): 
alias = literal_unroll @njit(pipeline_class=CapturingCompiler) def foo(): acc = 0 for i in alias((1, 2, 3)): acc += i return acc foo() cres = foo.overloads[foo.signatures[0]] self.assertTrue(cres.metadata['mutation_results'][LiteralUnroll]) def test_literal_unroll_assess_empty_function(self): @njit(pipeline_class=CapturingCompiler) def foo(): pass foo() cres = foo.overloads[foo.signatures[0]] self.assertFalse(cres.metadata['mutation_results'][LiteralUnroll]) def test_literal_unroll_not_in_globals(self): f = """def foo():\n\tpass""" l = {} exec(f, {}, l) foo = njit(pipeline_class=CapturingCompiler)(l['foo']) foo() cres = foo.overloads[foo.signatures[0]] self.assertFalse(cres.metadata['mutation_results'][LiteralUnroll]) def test_literal_unroll_globals_and_locals(self): f = """def foo():\n\tfor x in literal_unroll((1,)):\n\t\tpass""" l = {} exec(f, {}, l) foo = njit(pipeline_class=CapturingCompiler)(l['foo']) with self.assertRaises(errors.TypingError) as raises: foo() self.assertIn("Untyped global name 'literal_unroll'", str(raises.exception)) # same as above but now add literal_unroll to globals l = {} exec(f, {'literal_unroll': literal_unroll}, l) foo = njit(pipeline_class=CapturingCompiler)(l['foo']) foo() cres = foo.overloads[foo.signatures[0]] self.assertTrue(cres.metadata['mutation_results'][LiteralUnroll]) # same as above, but now with import from textwrap import dedent f = """ def gen(): from numba import literal_unroll def foo(): for x in literal_unroll((1,)): pass return foo bar = gen() """ l = {} exec(dedent(f), {}, l) foo = njit(pipeline_class=CapturingCompiler)(l['bar']) foo() cres = foo.overloads[foo.signatures[0]] self.assertTrue(cres.metadata['mutation_results'][LiteralUnroll]) # same as above, but now with import as something else from textwrap import dedent f = """ def gen(): from numba import literal_unroll as something_else def foo(): for x in something_else((1,)): pass return foo bar = gen() """ l = {} exec(dedent(f), {}, l) foo = 
njit(pipeline_class=CapturingCompiler)(l['bar']) foo() cres = foo.overloads[foo.signatures[0]] self.assertTrue(cres.metadata['mutation_results'][LiteralUnroll]) if __name__ == '__main__': unittest.main()