# Licensed under a 3-clause BSD style license - see LICENSE.rst import pytest import numpy as np from numpy.testing import assert_equal from astropy.visualization.stretch import (LinearStretch, SqrtStretch, PowerStretch, PowerDistStretch, InvertedPowerDistStretch, SquaredStretch, LogStretch, InvertedLogStretch, AsinhStretch, SinhStretch, HistEqStretch, InvertedHistEqStretch, ContrastBiasStretch) DATA = np.array([0.00, 0.25, 0.50, 0.75, 1.00]) RESULTS = {} RESULTS[LinearStretch()] = np.array([0.00, 0.25, 0.50, 0.75, 1.00]) RESULTS[LinearStretch(intercept=0.5) + LinearStretch(slope=0.5)] = \ np.array([0.5, 0.625, 0.75, 0.875, 1.]) RESULTS[SqrtStretch()] = np.array([0., 0.5, 0.70710678, 0.8660254, 1.]) RESULTS[SquaredStretch()] = np.array([0., 0.0625, 0.25, 0.5625, 1.]) RESULTS[PowerStretch(0.5)] = np.array([0., 0.5, 0.70710678, 0.8660254, 1.]) RESULTS[PowerDistStretch()] = np.array([0., 0.004628, 0.030653, 0.177005, 1.]) RESULTS[LogStretch()] = np.array([0., 0.799776, 0.899816, 0.958408, 1.]) RESULTS[AsinhStretch()] = np.array([0., 0.549402, 0.77127, 0.904691, 1.]) RESULTS[SinhStretch()] = np.array([0., 0.082085, 0.212548, 0.46828, 1.]) RESULTS[ContrastBiasStretch(contrast=2., bias=0.4)] = np.array([-0.3, 0.2, 0.7, 1.2, 1.7]) RESULTS[HistEqStretch(DATA)] = DATA RESULTS[HistEqStretch(DATA[::-1])] = DATA RESULTS[HistEqStretch(DATA ** 0.5)] = np.array([0., 0.125, 0.25, 0.5674767, 1.]) class TestStretch: @pytest.mark.parametrize('stretch', RESULTS.keys()) def test_no_clip(self, stretch): np.testing.assert_allclose(stretch(DATA, clip=False), RESULTS[stretch], atol=1.e-6) @pytest.mark.parametrize('ndim', [2, 3]) @pytest.mark.parametrize('stretch', RESULTS.keys()) def test_clip_ndimensional(self, stretch, ndim): new_shape = DATA.shape + (1,) * ndim np.testing.assert_allclose(stretch(DATA.reshape(new_shape), clip=True).ravel(), np.clip(RESULTS[stretch], 0., 1), atol=1.e-6) @pytest.mark.parametrize('stretch', RESULTS.keys()) def test_clip(self, stretch): np.testing.assert_allclose(stretch(DATA, clip=True), np.clip(RESULTS[stretch], 0., 1), atol=1.e-6) @pytest.mark.parametrize('stretch', RESULTS.keys()) def test_inplace(self, stretch): data_in = DATA.copy() result = np.zeros(DATA.shape) stretch(data_in, out=result, clip=False) np.testing.assert_allclose(result, RESULTS[stretch], atol=1.e-6) np.testing.assert_allclose(data_in, DATA) @pytest.mark.parametrize('stretch', RESULTS.keys()) def test_round_trip(self, stretch): np.testing.assert_allclose(stretch.inverse(stretch(DATA, clip=False), clip=False), DATA) @pytest.mark.parametrize('stretch', RESULTS.keys()) def test_inplace_roundtrip(self, stretch): result = np.zeros(DATA.shape) stretch(DATA, out=result, clip=False) stretch.inverse(result, out=result, clip=False) np.testing.assert_allclose(result, DATA) @pytest.mark.parametrize('stretch', RESULTS.keys()) def test_double_inverse(self, stretch): np.testing.assert_allclose(stretch.inverse.inverse(DATA), stretch(DATA), atol=1.e-6) def test_inverted(self): stretch_1 = SqrtStretch().inverse stretch_2 = PowerStretch(2) np.testing.assert_allclose(stretch_1(DATA), stretch_2(DATA)) def test_chaining(self): stretch_1 = SqrtStretch() + SqrtStretch() stretch_2 = PowerStretch(0.25) stretch_3 = PowerStretch(4.) np.testing.assert_allclose(stretch_1(DATA), stretch_2(DATA)) np.testing.assert_allclose(stretch_1.inverse(DATA), stretch_3(DATA)) def test_clip_invalid(): stretch = SqrtStretch() values = stretch([-1., 0., 0.5, 1., 1.5]) np.testing.assert_allclose(values, [0., 0., 0.70710678, 1., 1.]) values = stretch([-1., 0., 0.5, 1., 1.5], clip=False) np.testing.assert_allclose(values, [np.nan, 0., 0.70710678, 1., 1.2247448]) @pytest.mark.parametrize('a', [-2., -1, 1.]) def test_invalid_powerdist_a(a): match = 'a must be >= 0, but cannot be set to 1' with pytest.raises(ValueError, match=match): PowerDistStretch(a=a) with pytest.raises(ValueError, match=match): InvertedPowerDistStretch(a=a) @pytest.mark.parametrize('a', [-2., -1, 0.]) def test_invalid_power_log_a(a): match = 'a must be > 0' with pytest.raises(ValueError, match=match): PowerStretch(a=a) with pytest.raises(ValueError, match=match): LogStretch(a=a) with pytest.raises(ValueError, match=match): InvertedLogStretch(a=a) @pytest.mark.parametrize('a', [-2., -1, 0., 1.5]) def test_invalid_sinh_a(a): match = 'a must be > 0 and <= 1' with pytest.raises(ValueError, match=match): AsinhStretch(a=a) with pytest.raises(ValueError, match=match): SinhStretch(a=a) def test_histeqstretch_invalid(): data = np.array([-np.inf, 0.00, 0.25, 0.50, 0.75, 1.00, np.inf]) result = np.array([0.0, 0.0, 0.25, 0.5, 0.75, 1.0, 1.0]) assert_equal(HistEqStretch(data)(data), result) assert_equal(InvertedHistEqStretch(data)(data), result)