import warnings
import dis
from itertools import product
import numpy as np
from numba import njit, typed, objmode, prange
from numba.core.utils import PYVERSION
from numba.core import ir_utils, ir
from numba.core.errors import (
UnsupportedError, CompilerError, NumbaPerformanceWarning, TypingError,
)
from numba.tests.support import (
TestCase, unittest, captured_stdout, MemoryLeakMixin,
skip_parfors_unsupported, skip_unless_scipy,
)
class MyError(Exception):
pass
class TestTryBareExcept(TestCase):
"""Test the following pattern:
try:
except:
"""
def test_try_inner_raise(self):
@njit
def inner(x):
if x:
raise MyError
@njit
def udt(x):
try:
inner(x)
return "not raised"
except: # noqa: E722
return "caught"
self.assertEqual(udt(False), "not raised")
self.assertEqual(udt(True), "caught")
def test_try_state_reset(self):
@njit
def inner(x):
if x == 1:
raise MyError("one")
elif x == 2:
raise MyError("two")
@njit
def udt(x):
try:
inner(x)
res = "not raised"
except: # noqa: E722
res = "caught"
if x == 0:
inner(2)
return res
with self.assertRaises(MyError) as raises:
udt(0)
self.assertEqual(str(raises.exception), "two")
self.assertEqual(udt(1), "caught")
self.assertEqual(udt(-1), "not raised")
def _multi_inner(self):
@njit
def inner(x):
if x == 1:
print("call_one")
raise MyError("one")
elif x == 2:
print("call_two")
raise MyError("two")
elif x == 3:
print("call_three")
raise MyError("three")
else:
print("call_other")
return inner
def test_nested_try(self):
inner = self._multi_inner()
@njit
def udt(x, y, z):
try:
try:
print("A")
inner(x)
print("B")
except: # noqa: E722
print("C")
inner(y)
print("D")
except: # noqa: E722
print("E")
inner(z)
print("F")
# case 1
with self.assertRaises(MyError) as raises:
with captured_stdout() as stdout:
udt(1, 2, 3)
self.assertEqual(
stdout.getvalue().split(),
["A", "call_one", "C", "call_two", "E", "call_three"],
)
self.assertEqual(str(raises.exception), "three")
# case 2
with captured_stdout() as stdout:
udt(1, 0, 3)
self.assertEqual(
stdout.getvalue().split(),
["A", "call_one", "C", "call_other", "D"],
)
# case 3
with captured_stdout() as stdout:
udt(1, 2, 0)
self.assertEqual(
stdout.getvalue().split(),
["A", "call_one", "C", "call_two", "E", "call_other", "F"],
)
def test_loop_in_try(self):
inner = self._multi_inner()
@njit
def udt(x, n):
try:
print("A")
for i in range(n):
print(i)
if i == x:
inner(i)
except: # noqa: E722
print("B")
return i
# case 1
with captured_stdout() as stdout:
res = udt(3, 5)
self.assertEqual(
stdout.getvalue().split(),
["A", "0", "1", "2", "3", "call_three", "B"],
)
self.assertEqual(res, 3)
# case 2
with captured_stdout() as stdout:
res = udt(1, 3)
self.assertEqual(
stdout.getvalue().split(),
["A", "0", "1", "call_one", "B"],
)
self.assertEqual(res, 1)
# case 3
with captured_stdout() as stdout:
res = udt(0, 3)
self.assertEqual(
stdout.getvalue().split(),
["A", "0", "call_other", "1", "2"],
)
self.assertEqual(res, 2)
def test_raise_in_try(self):
@njit
def udt(x):
try:
print("A")
if x:
raise MyError("my_error")
print("B")
except: # noqa: E722
print("C")
return 321
return 123
# case 1
with captured_stdout() as stdout:
res = udt(True)
self.assertEqual(
stdout.getvalue().split(),
["A", "C"],
)
self.assertEqual(res, 321)
# case 2
with captured_stdout() as stdout:
res = udt(False)
self.assertEqual(
stdout.getvalue().split(),
["A", "B"],
)
self.assertEqual(res, 123)
def test_recursion(self):
@njit
def foo(x):
if x > 0:
try:
foo(x - 1)
except: # noqa: E722
print("CAUGHT")
return 12
if x == 1:
raise ValueError("exception")
with captured_stdout() as stdout:
res = foo(10)
self.assertIsNone(res)
self.assertEqual(stdout.getvalue().split(), ["CAUGHT",],)
def test_yield(self):
@njit
def foo(x):
if x > 0:
try:
yield 7
raise ValueError("exception") # never hit
except Exception:
print("CAUGHT")
@njit
def bar(z):
return next(foo(z))
with captured_stdout() as stdout:
res = bar(10)
self.assertEqual(res, 7)
self.assertEqual(stdout.getvalue().split(), [])
def test_closure2(self):
@njit
def foo(x):
def bar():
try:
raise ValueError("exception")
except: # noqa: E722
print("CAUGHT")
return 12
bar()
with captured_stdout() as stdout:
foo(10)
self.assertEqual(stdout.getvalue().split(), ["CAUGHT",],)
def test_closure3(self):
@njit
def foo(x):
def bar(z):
try:
raise ValueError("exception")
except: # noqa: E722
print("CAUGHT")
return z
return [x for x in map(bar, [1, 2, 3])]
with captured_stdout() as stdout:
res = foo(10)
self.assertEqual(res, [1, 2, 3])
self.assertEqual(stdout.getvalue().split(), ["CAUGHT",] * 3,)
def test_closure4(self):
@njit
def foo(x):
def bar(z):
if z < 0:
raise ValueError("exception")
return z
try:
return [x for x in map(bar, [1, 2, 3, x])]
except: # noqa: E722
print("CAUGHT")
with captured_stdout() as stdout:
res = foo(-1)
self.assertEqual(stdout.getvalue().strip(), "CAUGHT")
self.assertIsNone(res)
with captured_stdout() as stdout:
res = foo(4)
self.assertEqual(stdout.getvalue(), "")
self.assertEqual(res, [1, 2, 3, 4])
@skip_unless_scipy
def test_real_problem(self):
@njit
def foo():
a = np.zeros((4, 4))
try:
chol = np.linalg.cholesky(a)
except: # noqa: E722
print("CAUGHT")
return chol
with captured_stdout() as stdout:
foo()
self.assertEqual(stdout.getvalue().split(), ["CAUGHT",])
def test_for_loop(self):
@njit
def foo(n):
for i in range(n):
try:
if i > 5:
raise ValueError
except: # noqa: E722
print("CAUGHT")
else:
try:
try:
try:
if i > 5:
raise ValueError
except: # noqa: E722
print("CAUGHT1")
raise ValueError
except: # noqa: E722
print("CAUGHT2")
raise ValueError
except: # noqa: E722
print("CAUGHT3")
with captured_stdout() as stdout:
foo(10)
self.assertEqual(
stdout.getvalue().split(),
["CAUGHT",] * 4 + ["CAUGHT%s" % i for i in range(1, 4)],
)
def test_try_pass(self):
@njit
def foo(x):
try:
pass
except: # noqa: E722
pass
return x
res = foo(123)
self.assertEqual(res, 123)
def test_try_except_reraise(self):
@njit
def udt():
try:
raise ValueError("ERROR")
except: # noqa: E722
raise
with self.assertRaises(UnsupportedError) as raises:
udt()
self.assertIn(
"The re-raising of an exception is not yet supported.",
str(raises.exception),
)
class TestTryExceptCaught(TestCase):
def test_catch_exception(self):
@njit
def udt(x):
try:
print("A")
if x:
raise ZeroDivisionError("321")
print("B")
except Exception:
print("C")
print("D")
# case 1
with captured_stdout() as stdout:
udt(True)
self.assertEqual(
stdout.getvalue().split(),
["A", "C", "D"],
)
# case 2
with captured_stdout() as stdout:
udt(False)
self.assertEqual(
stdout.getvalue().split(),
["A", "B", "D"],
)
def test_return_in_catch(self):
@njit
def udt(x):
try:
print("A")
if x:
raise ZeroDivisionError
print("B")
r = 123
except Exception:
print("C")
r = 321
return r
print("D")
return r
# case 1
with captured_stdout() as stdout:
res = udt(True)
self.assertEqual(
stdout.getvalue().split(),
["A", "C"],
)
self.assertEqual(res, 321)
# case 2
with captured_stdout() as stdout:
res = udt(False)
self.assertEqual(
stdout.getvalue().split(),
["A", "B", "D"],
)
self.assertEqual(res, 123)
def test_save_caught(self):
@njit
def udt(x):
try:
if x:
raise ZeroDivisionError
r = 123
except Exception as e: # noqa: F841
r = 321
return r
return r
with self.assertRaises(UnsupportedError) as raises:
udt(True)
self.assertIn(
"Exception object cannot be stored into variable (e)",
str(raises.exception)
)
def test_try_except_reraise(self):
@njit
def udt():
try:
raise ValueError("ERROR")
except Exception:
raise
with self.assertRaises(UnsupportedError) as raises:
udt()
self.assertIn(
"The re-raising of an exception is not yet supported.",
str(raises.exception),
)
def test_try_except_reraise_chain(self):
@njit
def udt():
try:
raise ValueError("ERROR")
except Exception:
try:
raise
except Exception:
raise
with self.assertRaises(UnsupportedError) as raises:
udt()
self.assertIn(
"The re-raising of an exception is not yet supported.",
str(raises.exception),
)
def test_division_operator(self):
# This test that old-style implementation propagate exception
# to the exception handler properly.
@njit
def udt(y):
try:
1 / y
except Exception:
return 0xdead
else:
return 1 / y
self.assertEqual(udt(0), 0xdead)
self.assertEqual(udt(2), 0.5)
class TestTryExceptNested(TestCase):
"Tests for complicated nesting"
def check_compare(self, cfunc, pyfunc, *args, **kwargs):
with captured_stdout() as stdout:
pyfunc(*args, **kwargs)
expect = stdout.getvalue()
with captured_stdout() as stdout:
cfunc(*args, **kwargs)
got = stdout.getvalue()
self.assertEqual(
expect, got,
msg="args={} kwargs={}".format(args, kwargs)
)
def test_try_except_else(self):
@njit
def udt(x, y, z, p):
print('A')
if x:
print('B')
try:
print('C')
if y:
print('D')
raise MyError("y")
print('E')
except Exception: # noqa: F722
print('F')
try:
print('H')
try:
print('I')
if z:
print('J')
raise MyError('z')
print('K')
except Exception:
print('L')
else:
print('M')
except Exception:
print('N')
else:
print('O')
print('P')
else:
print('G')
print('Q')
print('R')
cases = list(product([True, False], repeat=4))
self.assertTrue(cases)
for x, y, z, p in cases:
self.check_compare(
udt, udt.py_func,
x=x, y=y, z=z, p=p,
)
def test_try_except_finally(self):
@njit
def udt(p, q):
try:
print('A')
if p:
print('B')
raise MyError
print('C')
except: # noqa: E722
print('D')
finally:
try:
print('E')
if q:
print('F')
raise MyError
except Exception:
print('G')
else:
print('H')
finally:
print('I')
cases = list(product([True, False], repeat=2))
self.assertTrue(cases)
for p, q in cases:
self.check_compare(
udt, udt.py_func,
p=p, q=q,
)
class TestTryExceptRefct(MemoryLeakMixin, TestCase):
def test_list_direct_raise(self):
@njit
def udt(n, raise_at):
lst = typed.List()
try:
for i in range(n):
if i == raise_at:
raise IndexError
lst.append(i)
except Exception:
return lst
else:
return lst
out = udt(10, raise_at=5)
self.assertEqual(list(out), list(range(5)))
out = udt(10, raise_at=10)
self.assertEqual(list(out), list(range(10)))
def test_list_indirect_raise(self):
@njit
def appender(lst, n, raise_at):
for i in range(n):
if i == raise_at:
raise IndexError
lst.append(i)
return lst
@njit
def udt(n, raise_at):
lst = typed.List()
lst.append(0xbe11)
try:
appender(lst, n, raise_at)
except Exception:
return lst
else:
return lst
out = udt(10, raise_at=5)
self.assertEqual(list(out), [0xbe11] + list(range(5)))
out = udt(10, raise_at=10)
self.assertEqual(list(out), [0xbe11] + list(range(10)))
def test_incompatible_refinement(self):
@njit
def udt():
try:
lst = typed.List()
print("A")
lst.append(0)
print("B")
lst.append("fda") # invalid type will cause typing error
print("C")
return lst
except Exception:
print("D")
with self.assertRaises(TypingError) as raises:
udt()
self.assertRegexpMatches(
str(raises.exception),
r"Cannot refine type|cannot safely cast unicode_type to int(32|64)"
)
class TestTryExceptOtherControlFlow(TestCase):
def test_yield(self):
@njit
def udt(n, x):
for i in range(n):
try:
if i == x:
raise ValueError
yield i
except Exception:
return
self.assertEqual(list(udt(10, 5)), list(range(5)))
self.assertEqual(list(udt(10, 10)), list(range(10)))
def test_objmode(self):
@njit
def udt():
try:
with objmode():
print(object())
except Exception:
return
with self.assertRaises(CompilerError) as raises:
udt()
if PYVERSION <= (3,7):
msg = ("unsupported control flow: due to return statements inside "
"with block")
else:
msg = ("unsupported control flow: with-context contains branches "
"(i.e. break/return/raise) that can leave the block ")
self.assertIn(
msg,
str(raises.exception),
)
def test_objmode_output_type(self):
def bar(x):
return np.asarray(list(reversed(x.tolist())))
@njit
def test_objmode():
x = np.arange(5)
y = np.zeros_like(x)
try:
with objmode(y='intp[:]'): # annotate return type
# this region is executed by object-mode.
y += bar(x)
except Exception:
pass
return y
with self.assertRaises(CompilerError) as raises:
test_objmode()
if PYVERSION == (3,7):
msg = ("unsupported control flow: due to return statements inside "
"with block")
else:
msg = ("unsupported control flow: with-context contains branches "
"(i.e. break/return/raise) that can leave the block ")
self.assertIn(
msg,
str(raises.exception),
)
@unittest.skipIf(PYVERSION < (3, 9), "Python 3.9+ only")
def test_reraise_opcode_unreachable(self):
# The opcode RERAISE was added in python 3.9, there should be no
# supported way to actually reach it. This test just checks that an
# exception is present to deal with if it is reached in a case known
# to produce this opcode.
def pyfunc():
try:
raise Exception
except Exception:
raise ValueError("ERROR")
for inst in dis.get_instructions(pyfunc):
if inst.opname == 'RERAISE':
break
else:
self.fail("expected RERAISE opcode not found")
func_ir = ir_utils.get_ir_of_code({}, pyfunc.__code__)
found = False
for lbl, blk in func_ir.blocks.items():
for stmt in blk.find_insts(ir.StaticRaise):
# don't worry about guarding this strongly, if the exec_args[0]
# is a string it'll either be "ERROR" or the guard message
# saying unreachable has been reached
msg = "Unreachable condition reached (op code RERAISE executed)"
if stmt.exc_args and msg in stmt.exc_args[0]:
found = True
if not found:
self.fail("expected RERAISE unreachable message not found")
@skip_parfors_unsupported
class TestTryExceptParfors(TestCase):
def test_try_in_prange_reduction(self):
# The try-except is transformed basically into chains of if-else
def udt(n):
c = 0
for i in prange(n):
try:
c += 1
except Exception:
c += 1
return c
args = [10]
expect = udt(*args)
self.assertEqual(njit(parallel=False)(udt)(*args), expect)
self.assertEqual(njit(parallel=True)(udt)(*args), expect)
def test_try_outside_prange_reduction(self):
# The try-except is transformed basically into chains of if-else
def udt(n):
c = 0
try:
for i in prange(n):
c += 1
except Exception:
return 0xdead
else:
return c
args = [10]
expect = udt(*args)
self.assertEqual(njit(parallel=False)(udt)(*args), expect)
# Parfors transformation didn't happen
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always', NumbaPerformanceWarning)
self.assertEqual(njit(parallel=True)(udt)(*args), expect)
self.assertEqual(len(w), 1)
self.assertIn("no transformation for parallel execution was possible",
str(w[0]))
def test_try_in_prange_map(self):
def udt(arr, x):
out = arr.copy()
for i in prange(arr.size):
try:
if i == x:
raise ValueError
out[i] = arr[i] + i
except Exception:
out[i] = -1
return out
args = [np.arange(10), 6]
expect = udt(*args)
self.assertPreciseEqual(njit(parallel=False)(udt)(*args), expect)
self.assertPreciseEqual(njit(parallel=True)(udt)(*args), expect)
def test_try_outside_prange_map(self):
def udt(arr, x):
out = arr.copy()
try:
for i in prange(arr.size):
if i == x:
raise ValueError
out[i] = arr[i] + i
except Exception:
out[i] = -1
return out
args = [np.arange(10), 6]
expect = udt(*args)
self.assertPreciseEqual(njit(parallel=False)(udt)(*args), expect)
self.assertPreciseEqual(njit(parallel=True)(udt)(*args), expect)
if __name__ == '__main__':
unittest.main()