import pytest
import numpy as np
from numpy.testing import assert_equal

from astropy.table import Table

da = pytest.importorskip('dask.array')


class TestDaskHandler:

    def setup_method(self, method):
        self.t = Table()
        self.t['a'] = da.arange(10)

    def test_add_row(self):
        self.t.add_row(self.t[0])
        assert_equal(self.t['a'].compute(), np.hstack([np.arange(10), 0]))

    def test_get_column(self):
        assert isinstance(self.t['a'], da.Array)
        assert_equal(self.t['a'].compute(), np.arange(10))

    def test_slicing_row_single(self):
        sub = self.t[5]
        assert isinstance(sub['a'], da.Array)
        assert not hasattr(sub['a'], 'info')  # should be a plain dask array
        assert sub['a'].compute() == 5

    def test_slicing_row_range(self):
        sub = self.t[5:]
        assert isinstance(sub['a'], da.Array)
        assert hasattr(sub['a'], 'info')  # should be a mixin column
        assert_equal(sub['a'].compute(), np.arange(5, 10))

    def test_slicing_column_range(self):
        sub = self.t[('a',)]
        assert isinstance(sub['a'], da.Array)
        assert hasattr(sub['a'], 'info')  # should be a mixin column
        assert_equal(sub['a'].compute(), np.arange(10))

    def test_pformat(self):
        assert self.t.pformat_all() == [' a ', '---', '  0', '  1', '  2',
                                        '  3', '  4', '  5', '  6', '  7',
                                        '  8', '  9']

    def test_info_preserved(self):

        self.t['a'].info.description = 'A dask column'

        sub = self.t[1:3]
        assert sub['a'].info.name == 'a'
        assert sub['a'].info.description == 'A dask column'

        col = self.t['a'].copy()
        assert col.info.name == 'a'
        assert col.info.description == 'A dask column'

        self.t.add_row(self.t[0])
        assert self.t['a'].info.name == 'a'
        assert self.t['a'].info.description == 'A dask column'