"""Tests for the helpers imported from ``statsmodels.tsa.arima.tools``.

Covers ``standardize_lag_order`` (normalisation of a lag-order
specification) and ``validate_basic`` (basic validation of a parameter
vector against an expected length).
"""
import numpy as np

from numpy.testing import assert_equal, assert_raises

from statsmodels.tsa.arima.tools import (
    standardize_lag_order, validate_basic)


def test_standardize_lag_order_int():
    """A non-negative integer order is returned unchanged."""
    # Integer input
    assert_equal(standardize_lag_order(0, title='test'), 0)
    assert_equal(standardize_lag_order(3), 3)


def test_standardize_lag_order_list_int():
    """Lists of lags: consecutive lags collapse to an int, others stay lists.
    """
    # List input, lags
    assert_equal(standardize_lag_order([]), 0)
    assert_equal(standardize_lag_order([1, 2]), 2)
    assert_equal(standardize_lag_order([1, 3]), [1, 3])


def test_standardize_lag_order_tuple_int():
    """Tuples of lags give the same results as the equivalent lists."""
    # Non-list iterable input, lags
    assert_equal(standardize_lag_order((1, 2)), 2)
    assert_equal(standardize_lag_order((1, 3)), [1, 3])


def test_standardize_lag_order_ndarray_int():
    """1-dim ndarrays of lags give the same results as the equivalent lists.
    """
    # ndarray input, lags
    assert_equal(standardize_lag_order(np.array([1, 2])), 2)
    assert_equal(standardize_lag_order(np.array([1, 3])), [1, 3])


def test_standardize_lag_order_list_bool():
    """Lists of 0/1 indicators are converted to the lags they select."""
    # List input, booleans
    assert_equal(standardize_lag_order([0]), 0)
    assert_equal(standardize_lag_order([1]), 1)
    assert_equal(standardize_lag_order([0, 1]), [2])
    assert_equal(standardize_lag_order([0, 1, 0, 1]), [2, 4])


def test_standardize_lag_order_tuple_bool():
    """Tuples of 0/1 indicators are converted to the lags they select."""
    # Non-list iterable input, booleans
    # NOTE: the trailing comma is required; `(0)` and `(1)` are plain ints,
    # so without it the single-element cases would not test tuple input.
    assert_equal(standardize_lag_order((0,)), 0)
    assert_equal(standardize_lag_order((1,)), 1)
    assert_equal(standardize_lag_order((0, 1)), [2])
    assert_equal(standardize_lag_order((0, 1, 0, 1)), [2, 4])


def test_standardize_lag_order_ndarray_bool():
    """ndarrays of 0/1 indicators are converted to the lags they select."""
    # ndarray input, booleans
    assert_equal(standardize_lag_order(np.array([0])), 0)
    assert_equal(standardize_lag_order(np.array([1])), 1)
    assert_equal(standardize_lag_order(np.array([0, 1])), [2])
    assert_equal(standardize_lag_order(np.array([0, 1, 0, 1])), [2, 4])


def test_standardize_lag_order_misc():
    """A 2-dim column array (shape ``(n, 1)``) is accepted as a lag list."""
    # Misc.
    assert_equal(standardize_lag_order(np.array([[1], [3]])), [1, 3])


def test_standardize_lag_order_invalid():
    """Invalid lag-order specifications raise the expected exception."""
    # Invalid input
    assert_raises(TypeError, standardize_lag_order, None)
    assert_raises(ValueError, standardize_lag_order, 1.2)
    assert_raises(ValueError, standardize_lag_order, -1)
    assert_raises(ValueError, standardize_lag_order,
                  np.arange(4).reshape(2, 2))
    # Boolean list can't have 2, lag order list can't have 0
    assert_raises(ValueError, standardize_lag_order, [0, 2])
    # Can't have duplicates
    assert_raises(ValueError, standardize_lag_order, [1, 1, 2])


def test_validate_basic():
    """``validate_basic`` returns valid parameters and rejects invalid ones.
    """
    # Valid parameters
    assert_equal(validate_basic([], 0, title='test'), [])
    assert_equal(validate_basic(0, 1), [0])
    assert_equal(validate_basic([0], 1), [0])
    assert_equal(validate_basic(np.array([1.2, 0.5 + 1j]), 2),
                 np.array([1.2, 0.5 + 1j]))
    assert_equal(
        validate_basic([np.nan, -np.inf, np.inf], 3, allow_infnan=True),
        [np.nan, -np.inf, np.inf])

    # Invalid parameters
    assert_raises(ValueError, validate_basic, [], 1, title='test')
    assert_raises(ValueError, validate_basic, 0, 0)
    assert_raises(ValueError, validate_basic, 'a', 1)
    assert_raises(ValueError, validate_basic, None, 1)
    assert_raises(ValueError, validate_basic, np.nan, 1)
    assert_raises(ValueError, validate_basic, np.inf, 1)
    assert_raises(ValueError, validate_basic, -np.inf, 1)
    assert_raises(ValueError, validate_basic, [1, 2], 1)