# Licensed under a 3-clause BSD style license - see LICENSE.rst

import warnings

import pytest
import numpy as np

from .test_table import SetupData
from astropy.table.bst import BST
from astropy.table.sorted_array import SortedArray
from astropy.table.soco import SCEngine
from astropy.table import QTable, Row, Table, Column, hstack
from astropy import units as u
from astropy.time import Time
from astropy.table.column import BaseColumn
from astropy.table.index import get_index, SlicedIndex
from astropy.utils.compat.optional_deps import HAS_SORTEDCONTAINERS

available_engines = [BST, SortedArray]

if HAS_SORTEDCONTAINERS:
    available_engines.append(SCEngine)


@pytest.fixture(params=available_engines)
def engine(request):
    """Parametrize a test over every available index engine class."""
    return request.param


_col = [1, 2, 3, 4, 5]


@pytest.fixture(params=[
    _col,
    u.Quantity(_col),
    Time(_col, format='jyear'),
])
def main_col(request):
    """Parametrize a test over the values 1..5 as a list, Quantity and Time."""
    return request.param


def assert_col_equal(col, array):
    """Assert that ``col`` equals ``array`` converted to ``col``'s type.

    For ``Time`` columns the values in ``array`` are interpreted with
    ``format='jyear'``, matching how ``main_col`` builds its Time values.
    """
    if isinstance(col, Time):
        assert np.all(col == Time(array, format='jyear'))
    else:
        assert np.all(col == col.__class__(array))


@pytest.mark.usefixtures('table_types')
class TestIndex(SetupData):
    """Tests of table indices, parametrized over the index engine and over
    the kind of column being indexed (see the ``main_col`` fixture)."""

    def _setup(self, main_col, table_types):
        """Configure table/column types for ``main_col``.

        Quantity columns require a QTable; non-list (mixin) columns are used
        as-is rather than converted. Only list- and Quantity-backed columns
        are treated as mutable by the tests below.
        """
        super()._setup(table_types)
        self.main_col = main_col
        if isinstance(main_col, u.Quantity):
            self._table_type = QTable
        if not isinstance(main_col, list):
            # don't change mixin type.
            # NOTE(review): ``make_col`` calls ``_column_type(lst, name=...)``,
            # which this lambda does not accept; confirm make_col is never
            # used with mixin columns.
            self._column_type = lambda x: x
        self.mutable = isinstance(main_col, (list, u.Quantity))

    def make_col(self, name, lst):
        """Build a named column of the configured column type."""
        return self._column_type(lst, name=name)

    def make_val(self, val):
        """Wrap ``val`` like ``main_col`` so it can be used as an index key."""
        if isinstance(self.main_col, Time):
            return Time(val, format='jyear')
        return val

    @property
    def t(self):
        """Lazily built test table with columns 'b', 'c' and 'a'."""
        if not hasattr(self, '_t'):
            # Note that order of columns is important, and the 'a' column is
            # last to ensure that the index column does not need to be the first
            # column (as was discovered in #10025). Most testing uses 'a' and
            # ('a', 'b') for the columns.
            self._t = self._table_type()
            self._t['b'] = self._column_type([4.0, 5.1, 6.2, 7.0, 1.1])
            self._t['c'] = self._column_type(['7', '8', '9', '10', '11'])
            self._t['a'] = self._column_type(self.main_col)

        return self._t

    @pytest.mark.parametrize("composite", [False, True])
    def test_table_index(self, main_col, table_types, composite, engine):
        self._setup(main_col, table_types)
        t = self.t
        t.add_index(('a', 'b') if composite else 'a', engine=engine)
        assert np.all(t.indices[0].sorted_data() == [0, 1, 2, 3, 4])

        if not self.mutable:
            return

        # test altering table columns
        t['a'][0] = 4
        t.add_row((6.0, '7', 6))
        t['a'][3] = 10
        t.remove_row(2)
        t.add_row((5.0, '9', 4))

        assert_col_equal(t['a'], np.array([4, 2, 10, 5, 6, 4]))
        assert np.allclose(t['b'], np.array([4.0, 5.1, 7.0, 1.1, 6.0, 5.0]))
        assert np.all(t['c'].data == np.array(['7', '8', '10', '11', '7', '9']))
        index = t.indices[0]
        ll = list(index.data.items())

        if composite:
            assert np.all(ll == [((2, 5.1), [1]),
                                 ((4, 4.0), [0]),
                                 ((4, 5.0), [5]),
                                 ((5, 1.1), [3]),
                                 ((6, 6.0), [4]),
                                 ((10, 7.0), [2])])
        else:
            assert np.all(ll == [((2,), [1]),
                                 ((4,), [0, 5]),
                                 ((5,), [3]),
                                 ((6,), [4]),
                                 ((10,), [2])])
        t.remove_indices('a')
        assert len(t.indices) == 0

    def test_table_slicing(self, main_col, table_types, engine):
        self._setup(main_col, table_types)
        t = self.t
        t.add_index('a', engine=engine)
        assert np.all(t.indices[0].sorted_data() == [0, 1, 2, 3, 4])

        for slice_ in ([0, 2], np.array([0, 2])):
            t2 = t[slice_]
            # t2 should retain an index on column 'a'
            assert len(t2.indices) == 1
            assert_col_equal(t2['a'], [1, 3])

            # the index in t2 should reorder row numbers after slicing
            assert np.all(t2.indices[0].sorted_data() == [0, 1])
            # however, this index should be a deep copy of t1's index
            assert np.all(t.indices[0].sorted_data() == [0, 1, 2, 3, 4])

    def test_remove_rows(self, main_col, table_types, engine):
        self._setup(main_col, table_types)
        if not self.mutable:
            return
        t = self.t
        t.add_index('a', engine=engine)

        # remove individual row
        t2 = t.copy()
        t2.remove_rows(2)
        assert_col_equal(t2['a'], [1, 2, 4, 5])
        assert np.all(t2.indices[0].sorted_data() == [0, 1, 2, 3])

        # remove by list, ndarray, or slice
        for cut in ([0, 2, 4], np.array([0, 2, 4]), slice(0, 5, 2)):
            t2 = t.copy()
            t2.remove_rows(cut)
            assert_col_equal(t2['a'], [2, 4])
            assert np.all(t2.indices[0].sorted_data() == [0, 1])

        with pytest.raises(ValueError):
            t.remove_rows((0, 2, 4))

    def test_col_get_slice(self, main_col, table_types, engine):
        self._setup(main_col, table_types)
        t = self.t
        t.add_index('a', engine=engine)

        # get slice
        t2 = t[1:3]  # table slice
        assert_col_equal(t2['a'], [2, 3])
        assert np.all(t2.indices[0].sorted_data() == [0, 1])

        col_slice = t['a'][1:3]
        assert_col_equal(col_slice, [2, 3])
        # true column slices discard indices
        if isinstance(t['a'], BaseColumn):
            assert len(col_slice.info.indices) == 0

        # take slice of slice
        t2 = t[::2]
        assert_col_equal(t2['a'], np.array([1, 3, 5]))
        t3 = t2[::-1]
        assert_col_equal(t3['a'], np.array([5, 3, 1]))
        assert np.all(t3.indices[0].sorted_data() == [2, 1, 0])
        t3 = t2[:2]
        assert_col_equal(t3['a'], np.array([1, 3]))
        assert np.all(t3.indices[0].sorted_data() == [0, 1])
        # out-of-bound slices
        for t_empty in (t2[3:], t2[2:1], t3[2:]):
            assert len(t_empty['a']) == 0
            assert np.all(t_empty.indices[0].sorted_data() == [])

        if self.mutable:
            # get boolean mask
            mask = t['a'] % 2 == 1
            t2 = t[mask]
            assert_col_equal(t2['a'], [1, 3, 5])
            assert np.all(t2.indices[0].sorted_data() == [0, 1, 2])

    def test_col_set_slice(self, main_col, table_types, engine):
        self._setup(main_col, table_types)
        if not self.mutable:
            return
        t = self.t
        t.add_index('a', engine=engine)

        # set slice
        t2 = t.copy()
        t2['a'][1:3] = np.array([6, 7])
        assert_col_equal(t2['a'], np.array([1, 6, 7, 4, 5]))
        assert np.all(t2.indices[0].sorted_data() == [0, 3, 4, 1, 2])

        # change original table via slice reference
        t2 = t.copy()
        t3 = t2[1:3]
        assert_col_equal(t3['a'], np.array([2, 3]))
        assert np.all(t3.indices[0].sorted_data() == [0, 1])
        t3['a'][0] = 5
        assert_col_equal(t3['a'], np.array([5, 3]))
        assert_col_equal(t2['a'], np.array([1, 5, 3, 4, 5]))
        assert np.all(t3.indices[0].sorted_data() == [1, 0])
        assert np.all(t2.indices[0].sorted_data() == [0, 2, 3, 1, 4])

        # set boolean mask
        t2 = t.copy()
        mask = t['a'] % 2 == 1
        t2['a'][mask] = 0.
        assert_col_equal(t2['a'], [0, 2, 0, 4, 0])
        assert np.all(t2.indices[0].sorted_data() == [0, 2, 4, 1, 3])

    def test_multiple_slices(self, main_col, table_types, engine):
        self._setup(main_col, table_types)

        if not self.mutable:
            return

        t = self.t
        t.add_index('a', engine=engine)

        for i in range(6, 51):
            t.add_row((1.0, 'A', i))

        assert_col_equal(t['a'], list(range(1, 51)))
        assert np.all(t.indices[0].sorted_data() == list(range(50)))

        evens = t[::2]
        assert np.all(evens.indices[0].sorted_data() == list(range(25)))
        reverse = evens[::-1]
        index = reverse.indices[0]
        assert (index.start, index.stop, index.step) == (48, -2, -2)
        assert np.all(index.sorted_data() == list(range(24, -1, -1)))

        # modify slice of slice
        reverse[-10:] = 0
        expected = np.array(list(range(1, 51)))
        expected[:20][expected[:20] % 2 == 1] = 0
        assert_col_equal(t['a'], expected)
        assert_col_equal(evens['a'], expected[::2])
        assert_col_equal(reverse['a'], expected[::2][::-1])
        # first ten evens are now zero
        assert np.all(t.indices[0].sorted_data() ==
                      ([0, 2, 4, 6, 8, 10, 12, 14, 16, 18,
                        1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
                       + list(range(20, 50))))
        assert np.all(evens.indices[0].sorted_data() == list(range(25)))
        assert np.all(reverse.indices[0].sorted_data() ==
                      list(range(24, -1, -1)))

        # try different step sizes of slice
        t2 = t[1:20:2]
        assert_col_equal(t2['a'], [2, 4, 6, 8, 10, 12, 14, 16, 18, 20])
        assert np.all(t2.indices[0].sorted_data() == list(range(10)))
        t3 = t2[::3]
        assert_col_equal(t3['a'], [2, 8, 14, 20])
        assert np.all(t3.indices[0].sorted_data() == [0, 1, 2, 3])
        t4 = t3[2::-1]
        assert_col_equal(t4['a'], [14, 8, 2])
        assert np.all(t4.indices[0].sorted_data() == [2, 1, 0])

    def test_sort(self, main_col, table_types, engine):
        self._setup(main_col, table_types)
        t = self.t[::-1]  # reverse table
        assert_col_equal(t['a'], [5, 4, 3, 2, 1])
        t.add_index('a', engine=engine)
        assert np.all(t.indices[0].sorted_data() == [4, 3, 2, 1, 0])

        if not self.mutable:
            return

        # sort table by column a
        t2 = t.copy()
        t2.sort('a')
        assert_col_equal(t2['a'], [1, 2, 3, 4, 5])
        assert np.all(t2.indices[0].sorted_data() == [0, 1, 2, 3, 4])

        # sort table by primary key
        t2 = t.copy()
        t2.sort()
        assert_col_equal(t2['a'], [1, 2, 3, 4, 5])
        assert np.all(t2.indices[0].sorted_data() == [0, 1, 2, 3, 4])

    def test_insert_row(self, main_col, table_types, engine):
        self._setup(main_col, table_types)

        if not self.mutable:
            return

        t = self.t
        t.add_index('a', engine=engine)
        t.insert_row(2, (1.0, '12', 6))
        assert_col_equal(t['a'], [1, 2, 6, 3, 4, 5])
        assert np.all(t.indices[0].sorted_data() == [0, 1, 3, 4, 5, 2])
        t.insert_row(1, (4.0, '13', 0))
        assert_col_equal(t['a'], [1, 0, 2, 6, 3, 4, 5])
        assert np.all(t.indices[0].sorted_data() == [1, 0, 2, 4, 5, 6, 3])

    def test_index_modes(self, main_col, table_types, engine):
        self._setup(main_col, table_types)
        t = self.t
        t.add_index('a', engine=engine)

        # first, no special mode
        assert len(t[[1, 3]].indices) == 1
        assert len(t[::-1].indices) == 1
        assert len(self._table_type(t).indices) == 1
        assert np.all(t.indices[0].sorted_data() == [0, 1, 2, 3, 4])
        t2 = t.copy()

        # non-copy mode
        with t.index_mode('discard_on_copy'):
            assert len(t[[1, 3]].indices) == 0
            assert len(t[::-1].indices) == 0
            assert len(self._table_type(t).indices) == 0
            assert len(t2.copy().indices) == 1  # mode should only affect t

        # make sure non-copy mode is exited correctly
        assert len(t[[1, 3]].indices) == 1

        if not self.mutable:
            return

        # non-modify mode
        with t.index_mode('freeze'):
            assert np.all(t.indices[0].sorted_data() == [0, 1, 2, 3, 4])
            t['a'][0] = 6
            assert np.all(t.indices[0].sorted_data() == [0, 1, 2, 3, 4])
            t.add_row((1.5, '12', 2))
            assert np.all(t.indices[0].sorted_data() == [0, 1, 2, 3, 4])
            t.remove_rows([1, 3])
            assert np.all(t.indices[0].sorted_data() == [0, 1, 2, 3, 4])
            assert_col_equal(t['a'], [6, 3, 5, 2])
            # mode should only affect t
            assert np.all(t2.indices[0].sorted_data() == [0, 1, 2, 3, 4])
            t2['a'][0] = 6
            assert np.all(t2.indices[0].sorted_data() == [1, 2, 3, 4, 0])

        # make sure non-modify mode is exited correctly
        assert np.all(t.indices[0].sorted_data() == [3, 1, 2, 0])

        if isinstance(t['a'], BaseColumn):
            assert len(t['a'][::-1].info.indices) == 0
            with t.index_mode('copy_on_getitem'):
                assert len(t['a'][[1, 2]].info.indices) == 1
                # mode should only affect t
                assert len(t2['a'][[1, 2]].info.indices) == 0

            assert len(t['a'][::-1].info.indices) == 0
            assert len(t2['a'][::-1].info.indices) == 0

    def test_index_retrieval(self, main_col, table_types, engine):
        self._setup(main_col, table_types)
        t = self.t
        t.add_index('a', engine=engine)
        t.add_index(['a', 'c'], engine=engine)
        assert len(t.indices) == 2
        assert len(t.indices['a'].columns) == 1
        assert len(t.indices['a', 'c'].columns) == 2

        with pytest.raises(IndexError):
            t.indices['b']

    def test_col_rename(self, main_col, table_types, engine):
        '''
        Checks for a previous bug in which copying a Table
        with different column names raised an exception.
        '''
        self._setup(main_col, table_types)
        t = self.t
        t.add_index('a', engine=engine)
        t2 = self._table_type(self.t, names=['d', 'e', 'f'])
        assert len(t2.indices) == 1

    def test_table_loc(self, main_col, table_types, engine):
        self._setup(main_col, table_types)
        t = self.t

        t.add_index('a', engine=engine)
        t.add_index('b', engine=engine)

        t2 = t.loc[self.make_val(3)]  # single label, with primary key 'a'
        assert_col_equal(t2['a'], [3])
        assert isinstance(t2, Row)

        # list search
        t2 = t.loc[[self.make_val(1), self.make_val(4), self.make_val(2)]]
        assert_col_equal(t2['a'], [1, 4, 2])  # same order as input list
        if not isinstance(main_col, Time):
            # ndarray search
            t2 = t.loc[np.array([1, 4, 2])]
            assert_col_equal(t2['a'], [1, 4, 2])

        t2 = t.loc[self.make_val(3): self.make_val(5)]  # range search
        assert_col_equal(t2['a'], [3, 4, 5])
        t2 = t.loc['b', 5.0:7.0]
        assert_col_equal(t2['b'], [5.1, 6.2, 7.0])
        # search by sorted index
        t2 = t.iloc[0:2]  # two smallest rows by column 'a'
        assert_col_equal(t2['a'], [1, 2])
        t2 = t.iloc['b', 2:]  # exclude two smallest rows in column 'b'
        assert_col_equal(t2['b'], [5.1, 6.2, 7.0])

        for t2 in (t.loc[:], t.iloc[:]):
            assert_col_equal(t2['a'], [1, 2, 3, 4, 5])

    def test_table_loc_indices(self, main_col, table_types, engine):
        self._setup(main_col, table_types)
        t = self.t

        t.add_index('a', engine=engine)
        t.add_index('b', engine=engine)

        t2 = t.loc_indices[self.make_val(3)]  # single label, with primary key 'a'
        assert t2 == 2

        # list search
        t2 = t.loc_indices[[self.make_val(1), self.make_val(4), self.make_val(2)]]
        for i, p in zip(t2, [1, 4, 2]):  # same order as input list
            assert i == p - 1

    def test_invalid_search(self, main_col, table_types, engine):
        # using .loc and .loc_indices with a value not present should raise an exception
        self._setup(main_col, table_types)
        t = self.t

        # use the parametrized engine so every engine's lookup is exercised
        t.add_index('a', engine=engine)
        with pytest.raises(KeyError):
            t.loc[self.make_val(6)]
        with pytest.raises(KeyError):
            t.loc_indices[self.make_val(6)]

    def test_copy_index_references(self, main_col, table_types, engine):
        # check against a bug in which indices were given an incorrect
        # column reference when copied
        self._setup(main_col, table_types)
        t = self.t

        t.add_index('a', engine=engine)
        t.add_index('b', engine=engine)
        t2 = t.copy()
        assert t2.indices['a'].columns[0] is t2['a']
        assert t2.indices['b'].columns[0] is t2['b']

    def test_unique_index(self, main_col, table_types, engine):
        self._setup(main_col, table_types)
        t = self.t

        t.add_index('a', engine=engine, unique=True)
        assert np.all(t.indices['a'].sorted_data() == [0, 1, 2, 3, 4])

        if self.mutable:
            with pytest.raises(ValueError):
                t.add_row((5.0, '9', 5))

    def test_copy_indexed_table(self, table_types):
        # Uses the default engine: the comparison below reaches into
        # ``index.data.data`` and its ``colnames``, which depends on that
        # engine's internal storage.
        self._setup(_col, table_types)
        t = self.t
        t.add_index('a')
        t.add_index(['a', 'b'])
        for tp in (self._table_type(t), t.copy()):
            assert len(t.indices) == len(tp.indices)
            for index, indexp in zip(t.indices, tp.indices):
                assert np.all(index.data.data == indexp.data.data)
                assert index.data.data.colnames == indexp.data.data.colnames

    def test_updating_row_byindex(self, main_col, table_types, engine):
        self._setup(main_col, table_types)
        t = Table([['a', 'b', 'c', 'd'], [2, 3, 4, 5], [3, 4, 5, 6]],
                  names=('a', 'b', 'c'), meta={'name': 'first table'})

        t.add_index('a', engine=engine)
        t.add_index('b', engine=engine)

        t.loc['c'] = ['g', 40, 50]  # single label, with primary key 'a'
        t2 = t[2]
        assert list(t2) == ['g', 40, 50]

        # list search
        t.loc[['a', 'd', 'b']] = [['a', 20, 30], ['d', 50, 60], ['b', 30, 40]]
        t2 = [['a', 20, 30], ['d', 50, 60], ['b', 30, 40]]
        for i, p in zip(t2, [1, 4, 2]):  # same order as input list
            assert list(t[p - 1]) == i

    def test_invalid_updates(self, main_col, table_types, engine):
        # assigning values through .loc whose number of rows or row length
        # does not match the selection should raise a ValueError
        self._setup(main_col, table_types)
        t = Table([[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6]],
                  names=('a', 'b', 'c'), meta={'name': 'first table'})

        t.add_index('a', engine=engine)
        with pytest.raises(ValueError):
            t.loc[3] = [[1, 2, 3]]
        with pytest.raises(ValueError):
            t.loc[[1, 4, 2]] = [[1, 2, 3], [4, 5, 6]]
        with pytest.raises(ValueError):
            t.loc[[1, 4, 2]] = [[1, 2, 3], [4, 5, 6], [2, 3]]
        with pytest.raises(ValueError):
            t.loc[[1, 4, 2]] = [[1, 2, 3], [4, 5], [2, 3]]


def test_get_index():
    a = [1, 4, 5, 2, 7, 4, 45]
    b = [2.0, 5.0, 8.2, 3.7, 4.3, 6.5, 3.3]
    t = Table([a, b], names=('a', 'b'), meta={'name': 'first table'})
    t.add_index(['a'])
    # Getting the values of index using names
    x1 = get_index(t, names=['a'])

    assert isinstance(x1, SlicedIndex)
    assert len(x1.columns) == 1
    assert len(x1.columns[0]) == 7
    assert x1.columns[0].info.name == 'a'
    # Getting the values of index using table_copy
    x2 = get_index(t, table_copy=t[['a']])

    assert isinstance(x2, SlicedIndex)
    assert len(x2.columns) == 1
    assert len(x2.columns[0]) == 7
    assert x2.columns[0].info.name == 'a'

    with pytest.raises(ValueError):
        get_index(t, names=['a'], table_copy=t[['a']])
    with pytest.raises(ValueError):
        get_index(t, names=None, table_copy=None)


def test_table_index_time_warning(engine):
    # Make sure that no ERFA warnings are emitted when indexing a table by
    # a Time column with a non-default time scale
    tab = Table()
    tab['a'] = Time([1, 2, 3], format='jyear', scale='tai')
    tab['b'] = [4, 3, 2]
    with warnings.catch_warnings(record=True) as wlist:
        # Record every warning regardless of the ambient filter configuration,
        # otherwise a filtered-out warning would let this test pass silently.
        warnings.simplefilter('always')
        tab.add_index(('a', 'b'), engine=engine)
    assert len(wlist) == 0


@pytest.mark.parametrize('col', [
    Column(np.arange(50000, 50005)),
    np.arange(50000, 50005) * u.m,
    Time(np.arange(50000, 50005), format='mjd')])
def test_table_index_does_not_propagate_to_column_slices(col):
    # Column slices lose their connection to the parent table, so they should
    # not carry information on the indices either; this helps prevent large
    # memory usage if, e.g., a large time column is turned into an object
    # array; see gh-10688.
    tab = QTable()
    tab['t'] = col
    tab.add_index('t')
    t = tab['t']
    assert t.info.indices
    tx = t[1:]
    assert not tx.info.indices
    tabx = tab[1:]
    t = tabx['t']
    assert t.info.indices


def test_hstack_qtable_table():
    # Check in particular that indices are initialized or copied correctly
    # for a Column that is being converted to a Quantity.
    qtab = QTable([np.arange(5.)*u.m], names=['s'])
    qtab.add_index('s')
    tab = Table([Column(np.arange(5.), unit=u.s)], names=['t'])
    qstack = hstack([qtab, tab])
    assert qstack['t'].info.indices == []
    assert qstack.indices == []


def test_index_slice_exception():
    with pytest.raises(TypeError, match='index_slice must be tuple or slice'):
        SlicedIndex(None, None)