# Licensed under a 3-clause BSD style license - see LICENSE.rst import itertools from contextlib import nullcontext import pytest import numpy as np from numpy.testing import assert_allclose, assert_array_equal, assert_array_almost_equal_nulp from astropy.convolution.convolve import convolve_fft, convolve from astropy.utils.exceptions import AstropyUserWarning from astropy import units as u VALID_DTYPES = ('>f4', 'f8', ' exception # if nan_treatment and not normalize_kernel: # with pytest.raises(ValueError): # z = convolve_fft(x, y, boundary=boundary, # nan_treatment=nan_treatment, # normalize_kernel=normalize_kernel, # ignore_edge_zeros=ignore_edge_zeros, # ) # return with expected_boundary_warning(boundary=boundary): z = convolve_fft(x, y, boundary=boundary, nan_treatment=nan_treatment, # you cannot fill w/nan, you can only interpolate over it fill_value=np.nan if normalize_kernel and nan_treatment=='interpolate' else 0, normalize_kernel=normalize_kernel, preserve_nan=preserve_nan) if preserve_nan: assert np.isnan(z[1, 1]) # weights w_n = np.array([[3., 5., 3.], [5., 8., 5.], [3., 5., 3.]], dtype='float64') w_z = np.array([[4., 6., 4.], [6., 9., 6.], [4., 6., 4.]], dtype='float64') answer_dict = { 'sum': np.array([[1., 4., 3.], [3., 6., 5.], [3., 3., 2.]], dtype='float64'), 'sum_wrap': np.array([[6., 6., 6.], [6., 6., 6.], [6., 6., 6.]], dtype='float64'), } answer_dict['average'] = answer_dict['sum'] / w_z answer_dict['average_interpnan'] = answer_dict['sum'] / w_n answer_dict['average_wrap_interpnan'] = answer_dict['sum_wrap'] / 8. answer_dict['average_wrap'] = answer_dict['sum_wrap'] / 9. answer_dict['average_withzeros'] = answer_dict['sum'] / 9. answer_dict['average_withzeros_interpnan'] = answer_dict['sum'] / 8. answer_dict['sum_withzeros'] = answer_dict['sum'] answer_dict['sum_interpnan'] = answer_dict['sum'] * 9/8. answer_dict['sum_withzeros_interpnan'] = answer_dict['sum'] answer_dict['sum_wrap_interpnan'] = answer_dict['sum_wrap'] * 9/8. if normalize_kernel: answer_key = 'average' else: answer_key = 'sum' if boundary == 'wrap': answer_key += '_wrap' elif nan_treatment == 'fill': answer_key += '_withzeros' if nan_treatment == 'interpolate': answer_key += '_interpnan' answer_dict[answer_key] # Skip the NaN at [1, 1] when preserve_nan=True posns = np.where(np.isfinite(z)) # for reasons unknown, the Windows FFT returns an answer for the [0, 0] # component that is EXACTLY 10*np.spacing assert_floatclose(z[posns], z[posns]) def test_big_fail(self): """ Test that convolve_fft raises an exception if a too-large array is passed in.""" with pytest.raises((ValueError, MemoryError)): # while a good idea, this approach did not work; it actually writes to disk # arr = np.memmap('file.np', mode='w+', shape=(512, 512, 512), dtype=complex) # this just allocates the memory but never touches it; it's better: arr = np.empty([512, 512, 512], dtype=complex) # note 512**3 * 16 bytes = 2.0 GB convolve_fft(arr, arr) def test_padding(self): """ Test that convolve_fft pads to _next_fast_lengths and does not expand all dimensions to length of longest side (#11242/#10047). """ # old implementation expanded this to up to 2048**3 shape = (1, 1226, 518) img = np.zeros(shape, dtype='float64') img[0, 600:610, 300:304] = 1.0 kernel = np.zeros((1, 7, 7), dtype='float64') kernel[0, 3, 3] = 1.0 with pytest.warns(AstropyUserWarning, match="psf_pad was set to False, which overrides the boundary='fill'"): img_fft = convolve_fft(img, kernel, return_fft=True, psf_pad=False, fft_pad=False) assert_array_equal(img_fft.shape, shape) img_fft = convolve_fft(img, kernel, return_fft=True, psf_pad=False, fft_pad=True) # should be from either hardcoded _good_sizes[] or scipy.fft.next_fast_len() assert img_fft.shape in ((1, 1250, 540), (1, 1232, 525)) img_fft = convolve_fft(img, kernel, return_fft=True, psf_pad=True, fft_pad=False) assert_array_equal(img_fft.shape, np.array(shape) + np.array(kernel.shape)) img_fft = convolve_fft(img, kernel, return_fft=True, psf_pad=True, fft_pad=True) assert img_fft.shape in ((2, 1250, 540), (2, 1250, 525)) @pytest.mark.parametrize(('boundary'), BOUNDARY_OPTIONS) def test_non_normalized_kernel(self, boundary): x = np.array([[0., 0., 4.], [1., 2., 0.], [0., 3., 0.]], dtype='float') y = np.array([[1., -1., 1.], [-1., 0., -1.], [1., -1., 1.]], dtype='float') with expected_boundary_warning(boundary=boundary): z = convolve_fft(x, y, boundary=boundary, nan_treatment='fill', normalize_kernel=False) if boundary in (None, 'fill'): assert_floatclose(z, np.array([[1., -5., 2.], [1., 0., -3.], [-2., -1., -1.]], dtype='float')) elif boundary == 'wrap': assert_floatclose(z, np.array([[0., -8., 6.], [5., 0., -4.], [2., 3., -4.]], dtype='float')) else: raise ValueError("Invalid boundary specification") @pytest.mark.parametrize(('boundary'), BOUNDARY_OPTIONS) def test_asymmetric_kernel(boundary): ''' Make sure that asymmetric convolution functions go the right direction ''' x = np.array([3., 0., 1.], dtype='>f8') y = np.array([1, 2, 3], dtype='>f8') with expected_boundary_warning(boundary=boundary): z = convolve_fft(x, y, boundary=boundary, normalize_kernel=False) if boundary in (None, 'fill'): assert_array_almost_equal_nulp(z, np.array([6., 10., 2.], dtype='float'), 10) elif boundary == 'wrap': assert_array_almost_equal_nulp(z, np.array([9., 10., 5.], dtype='float'), 10) @pytest.mark.parametrize(('boundary', 'nan_treatment', 'normalize_kernel', 'preserve_nan', 'dtype'), itertools.product(BOUNDARY_OPTIONS, NANTREATMENT_OPTIONS, NORMALIZE_OPTIONS, PRESERVE_NAN_OPTIONS, VALID_DTYPES)) def test_input_unmodified(boundary, nan_treatment, normalize_kernel, preserve_nan, dtype): """ Test that convolve_fft works correctly when inputs are lists """ array = [1., 4., 5., 6., 5., 7., 8.] kernel = [0.2, 0.6, 0.2] x = np.array(array, dtype=dtype) y = np.array(kernel, dtype=dtype) # Make pseudoimmutable x.flags.writeable = False y.flags.writeable = False with expected_boundary_warning(boundary=boundary): z = convolve_fft(x, y, boundary=boundary, nan_treatment=nan_treatment, normalize_kernel=normalize_kernel, preserve_nan=preserve_nan) assert np.all(np.array(array, dtype=dtype) == x) assert np.all(np.array(kernel, dtype=dtype) == y) @pytest.mark.parametrize(('boundary', 'nan_treatment', 'normalize_kernel', 'preserve_nan', 'dtype'), itertools.product(BOUNDARY_OPTIONS, NANTREATMENT_OPTIONS, NORMALIZE_OPTIONS, PRESERVE_NAN_OPTIONS, VALID_DTYPES)) def test_input_unmodified_with_nan(boundary, nan_treatment, normalize_kernel, preserve_nan, dtype): """ Test that convolve_fft doesn't modify the input data """ array = [1., 4., 5., np.nan, 5., 7., 8.] kernel = [0.2, 0.6, 0.2] x = np.array(array, dtype=dtype) y = np.array(kernel, dtype=dtype) # Make pseudoimmutable x.flags.writeable = False y.flags.writeable = False # make copies for post call comparison x_copy = x.copy() y_copy = y.copy() with expected_boundary_warning(boundary=boundary): z = convolve_fft(x, y, boundary=boundary, nan_treatment=nan_treatment, normalize_kernel=normalize_kernel, preserve_nan=preserve_nan) # ( NaN == NaN ) = False # Only compare non NaN values for canonical equivalence # and then check NaN explicitly with np.isnan() array_is_nan = np.isnan(array) kernel_is_nan = np.isnan(kernel) array_not_nan = ~array_is_nan kernel_not_nan = ~kernel_is_nan assert np.all(x_copy[array_not_nan] == x[array_not_nan]) assert np.all(y_copy[kernel_not_nan] == y[kernel_not_nan]) assert np.all(np.isnan(x[array_is_nan])) assert np.all(np.isnan(y[kernel_is_nan])) @pytest.mark.parametrize('error_kwarg', [{'psf_pad': True}, {'fft_pad': True}, {'dealias': True}]) def test_convolve_fft_boundary_wrap_error(error_kwarg): x = np.array([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], dtype='>f8') y = np.array([[1.]], dtype='>f8') assert (convolve_fft(x, y, boundary='wrap') == x).all() with pytest.raises(ValueError) as err: convolve_fft(x, y, boundary='wrap', **error_kwarg) assert str(err.value) == \ f"With boundary='wrap', {list(error_kwarg.keys())[0]} cannot be enabled." def test_convolve_fft_boundary_extend_error(): x = np.array([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], dtype='>f8') y = np.array([[1.]], dtype='>f8') with pytest.raises(NotImplementedError) as err: convolve_fft(x, y, boundary='extend') assert str(err.value) == \ "The 'extend' option is not implemented for fft-based convolution"