"""Tests for the ``_get_output`` helper in ``_ni_support``."""
import pytest

import numpy as np

from .._ni_support import _get_output


def _assert_shape_and_dtype(arr, shape, dtype):
    """Check that ``arr`` has exactly the given ``shape`` and ``dtype``."""
    assert arr.shape == shape
    assert arr.dtype == np.dtype(dtype)


@pytest.mark.parametrize(
    'dtype',
    [
        # String specifiers
        'f4', 'float32', 'complex64', 'complex128',
        # Type and dtype specifiers
        np.float32, float, np.dtype('f4'),
        # Derive from input
        None,
    ],
)
def test_get_output_basic(dtype):
    """Dtype specifiers allocate new arrays; array outputs pass through."""
    in_shape = (2, 3)
    src = np.zeros(in_shape, 'float32')
    # With no dtype given, the output dtype is taken from the input array.
    want = dtype if dtype is not None else 'float32'

    # Dtype specifier alone: the output shape comes from the input.
    out = _get_output(dtype, src)
    _assert_shape_and_dtype(out, in_shape, want)

    # Dtype specifier plus an explicit shape, which overrides the input's.
    out = _get_output(dtype, src, shape=(3, 2))
    _assert_shape_and_dtype(out, (3, 2), want)

    # A pre-allocated array is handed back as-is (the very same object).
    preallocated = np.zeros(in_shape, dtype)
    assert _get_output(preallocated, src) is preallocated


def test_get_output_complex():
    """``complex_output=True`` yields a complex output dtype."""
    in_shape = (2, 3)
    src = np.zeros(in_shape)

    # No dtype: the (float64) input type is promoted to complex128.
    out = _get_output(None, src, complex_output=True)
    _assert_shape_and_dtype(out, in_shape, 'complex128')

    # A real type is promoted to complex, and a warning is emitted.
    with pytest.warns(UserWarning,
                      match='promoting specified output dtype to complex'):
        out = _get_output(float, src, complex_output=True)
    _assert_shape_and_dtype(out, in_shape, 'complex128')

    # A complex string specifier is honoured as given.
    out = _get_output('complex64', src, complex_output=True)
    _assert_shape_and_dtype(out, in_shape, 'complex64')


def test_get_output_error_cases():
    """Invalid output requests raise ``RuntimeError`` with specific messages."""
    src = np.zeros((2, 3), 'float32')

    # A real dtype specifier and a real pre-allocated array take two separate
    # code paths, but both must fail with the same error for complex output.
    for bad_output in ('float32', np.zeros((2, 3))):
        with pytest.raises(RuntimeError,
                           match='output must have complex dtype'):
            _get_output(bad_output, src, complex_output=True)

    with pytest.raises(RuntimeError, match='output must have numeric dtype'):
        _get_output('void', src)

    with pytest.raises(RuntimeError, match='shape not correct'):
        _get_output(np.zeros((3, 2)), src)