import re
import textwrap
from io import BytesIO

import pytest

from sklearn.datasets._arff_parser import (
    _liac_arff_parser,
    _pandas_arff_parser,
    _post_process_frame,
    load_arff_from_gzip_file,
)


@pytest.mark.parametrize(
    "feature_names, target_names",
    [
        (
            [
                "col_int_as_integer",
                "col_int_as_numeric",
                "col_float_as_real",
                "col_float_as_numeric",
            ],
            ["col_categorical", "col_string"],
        ),
        (
            [
                "col_int_as_integer",
                "col_int_as_numeric",
                "col_float_as_real",
                "col_float_as_numeric",
            ],
            ["col_categorical"],
        ),
        (
            [
                "col_int_as_integer",
                "col_int_as_numeric",
                "col_float_as_real",
                "col_float_as_numeric",
            ],
            [],
        ),
    ],
)
def test_post_process_frame(feature_names, target_names):
    """Check the behaviour of the post-processing function for splitting a
    dataframe.

    The features are checked to be a DataFrame holding exactly
    ``feature_names``. The target is checked to be a DataFrame holding
    ``target_names`` when two or more targets are requested, a Series named
    after the single target when one is requested, and ``None`` otherwise.
    """
    pd = pytest.importorskip("pandas")

    X_original = pd.DataFrame(
        {
            "col_int_as_integer": [1, 2, 3],
            "col_int_as_numeric": [1, 2, 3],
            "col_float_as_real": [1.0, 2.0, 3.0],
            "col_float_as_numeric": [1.0, 2.0, 3.0],
            "col_categorical": ["a", "b", "c"],
            "col_string": ["a", "b", "c"],
        }
    )

    X, y = _post_process_frame(X_original, feature_names, target_names)

    assert isinstance(X, pd.DataFrame)
    # Checking only the container type would let a wrong split pass: also pin
    # the selected columns (and their order).
    assert X.columns.tolist() == feature_names
    if len(target_names) >= 2:
        assert isinstance(y, pd.DataFrame)
        assert y.columns.tolist() == target_names
    elif len(target_names) == 1:
        assert isinstance(y, pd.Series)
        assert y.name == target_names[0]
    else:
        assert y is None


def test_load_arff_from_gzip_file_error_parser():
    """An error will be raised if the parser is not known."""
    # None of the input parameters are required to be accurate since the check
    # of the parser will be carried out first.
    err_msg = "Unknown parser: 'xxx'. Should be 'liac-arff' or 'pandas'"
    # `match` is interpreted as a regular expression; escape the literal
    # message so that characters such as "." are not treated as wildcards.
    with pytest.raises(ValueError, match=re.escape(err_msg)):
        load_arff_from_gzip_file("xxx", "xxx", "xxx", "xxx", "xxx", "xxx")


@pytest.mark.parametrize("parser_func", [_liac_arff_parser, _pandas_arff_parser])
def test_pandas_arff_parser_strip_single_quotes(parser_func):
    """Check that we properly strip single quotes from the data.

    Both parsers strip the single quotes around nominal values. For string
    columns, only the LIAC-ARFF parser strips them; the pandas parser keeps
    the surrounding single quotes.
    """
    pd = pytest.importorskip("pandas")

    arff_file = BytesIO(
        textwrap.dedent(
            """
            @relation 'toy'
            @attribute 'cat_single_quote' {'A', 'B', 'C'}
            @attribute 'str_single_quote' string
            @attribute 'str_nested_quote' string
            @attribute 'class' numeric
            @data
            'A','some text','\"expect double quotes\"',0
            """
        ).encode("utf-8")
    )

    columns_info = {
        "cat_single_quote": {
            "data_type": "nominal",
            "name": "cat_single_quote",
        },
        "str_single_quote": {
            "data_type": "string",
            "name": "str_single_quote",
        },
        "str_nested_quote": {
            "data_type": "string",
            "name": "str_nested_quote",
        },
        "class": {
            "data_type": "numeric",
            "name": "class",
        },
    }

    feature_names = [
        "cat_single_quote",
        "str_single_quote",
        "str_nested_quote",
    ]
    target_names = ["class"]

    # We don't strip single quotes for string columns with the pandas parser.
    expected_values = {
        "cat_single_quote": "A",
        "str_single_quote": (
            "some text" if parser_func is _liac_arff_parser else "'some text'"
        ),
        "str_nested_quote": (
            '"expect double quotes"'
            if parser_func is _liac_arff_parser
            else "'\"expect double quotes\"'"
        ),
        "class": 0,
    }

    _, _, frame, _ = parser_func(
        arff_file,
        output_arrays_type="pandas",
        openml_columns_info=columns_info,
        feature_names_to_select=feature_names,
        target_names_to_select=target_names,
    )

    assert frame.columns.tolist() == feature_names + target_names
    pd.testing.assert_series_equal(frame.iloc[0], pd.Series(expected_values, name=0))


@pytest.mark.parametrize("parser_func", [_liac_arff_parser, _pandas_arff_parser])
def test_pandas_arff_parser_strip_double_quotes(parser_func):
    """Check that we properly strip double quotes from the data.

    Both parsers are expected to produce the same values here: the enclosing
    double quotes are removed, while single quotes nested inside a
    double-quoted value are kept.
    """
    pd = pytest.importorskip("pandas")

    arff_file = BytesIO(
        textwrap.dedent(
            """
            @relation 'toy'
            @attribute 'cat_double_quote' {"A", "B", "C"}
            @attribute 'str_double_quote' string
            @attribute 'str_nested_quote' string
            @attribute 'class' numeric
            @data
            "A","some text","\'expect double quotes\'",0
            """
        ).encode("utf-8")
    )

    columns_info = {
        "cat_double_quote": {
            "data_type": "nominal",
            "name": "cat_double_quote",
        },
        "str_double_quote": {
            "data_type": "string",
            "name": "str_double_quote",
        },
        "str_nested_quote": {
            "data_type": "string",
            "name": "str_nested_quote",
        },
        "class": {
            "data_type": "numeric",
            "name": "class",
        },
    }

    feature_names = [
        "cat_double_quote",
        "str_double_quote",
        "str_nested_quote",
    ]
    target_names = ["class"]

    expected_values = {
        "cat_double_quote": "A",
        "str_double_quote": "some text",
        "str_nested_quote": "'expect double quotes'",
        "class": 0,
    }

    _, _, frame, _ = parser_func(
        arff_file,
        output_arrays_type="pandas",
        openml_columns_info=columns_info,
        feature_names_to_select=feature_names,
        target_names_to_select=target_names,
    )

    assert frame.columns.tolist() == feature_names + target_names
    pd.testing.assert_series_equal(frame.iloc[0], pd.Series(expected_values, name=0))
@pytest.mark.parametrize(
    "parser_func",
    [
        # internal quotes are not considered to follow the ARFF spec in LIAC ARFF
        pytest.param(
            _liac_arff_parser,
            marks=pytest.mark.xfail(
                reason=(
                    "internal quotes are not considered to follow the ARFF spec "
                    "in LIAC ARFF"
                )
            ),
        ),
        _pandas_arff_parser,
    ],
)
def test_pandas_arff_parser_strip_no_quotes(parser_func):
    """Check that we properly parse with no quotes characters.

    Unquoted nominal and string values should be read as-is. A string value
    containing quotes that do not enclose the whole value (``'internal'
    quote``) should keep those quotes verbatim. The LIAC-ARFF parser is
    expected to fail on that line (see the ``xfail`` marker).
    """
    pd = pytest.importorskip("pandas")

    arff_file = BytesIO(
        textwrap.dedent(
            """
            @relation 'toy'
            @attribute 'cat_without_quote' {A, B, C}
            @attribute 'str_without_quote' string
            @attribute 'str_internal_quote' string
            @attribute 'class' numeric
            @data
            A,some text,'internal' quote,0
            """
        ).encode("utf-8")
    )

    columns_info = {
        "cat_without_quote": {
            "data_type": "nominal",
            "name": "cat_without_quote",
        },
        "str_without_quote": {
            "data_type": "string",
            "name": "str_without_quote",
        },
        "str_internal_quote": {
            "data_type": "string",
            "name": "str_internal_quote",
        },
        "class": {
            "data_type": "numeric",
            "name": "class",
        },
    }

    feature_names = [
        "cat_without_quote",
        "str_without_quote",
        "str_internal_quote",
    ]
    target_names = ["class"]

    expected_values = {
        "cat_without_quote": "A",
        "str_without_quote": "some text",
        "str_internal_quote": "'internal' quote",
        "class": 0,
    }

    _, _, frame, _ = parser_func(
        arff_file,
        output_arrays_type="pandas",
        openml_columns_info=columns_info,
        feature_names_to_select=feature_names,
        target_names_to_select=target_names,
    )

    assert frame.columns.tolist() == feature_names + target_names
    pd.testing.assert_series_equal(frame.iloc[0], pd.Series(expected_values, name=0))