import copy import itertools import pickle import numpy as np import pytest from scipy.spatial.distance import cdist from sklearn.metrics import DistanceMetric from sklearn.metrics._dist_metrics import ( BOOL_METRICS, DistanceMetric32, DistanceMetric64, ) from sklearn.utils import check_random_state from sklearn.utils._testing import assert_allclose, create_memmap_backed_data from sklearn.utils.fixes import CSR_CONTAINERS, parse_version, sp_version def dist_func(x1, x2, p): return np.sum((x1 - x2) ** p) ** (1.0 / p) rng = check_random_state(0) d = 4 n1 = 20 n2 = 25 X64 = rng.random_sample((n1, d)) Y64 = rng.random_sample((n2, d)) X32 = X64.astype("float32") Y32 = Y64.astype("float32") [X_mmap, Y_mmap] = create_memmap_backed_data([X64, Y64]) # make boolean arrays: ones and zeros X_bool = (X64 < 0.3).astype(np.float64) # quite sparse Y_bool = (Y64 < 0.7).astype(np.float64) # not too sparse [X_bool_mmap, Y_bool_mmap] = create_memmap_backed_data([X_bool, Y_bool]) V = rng.random_sample((d, d)) VI = np.dot(V, V.T) METRICS_DEFAULT_PARAMS = [ ("euclidean", {}), ("cityblock", {}), ("minkowski", dict(p=(0.5, 1, 1.5, 2, 3))), ("chebyshev", {}), ("seuclidean", dict(V=(rng.random_sample(d),))), ("mahalanobis", dict(VI=(VI,))), ("hamming", {}), ("canberra", {}), ("braycurtis", {}), ("minkowski", dict(p=(0.5, 1, 1.5, 3), w=(rng.random_sample(d),))), ] @pytest.mark.parametrize( "metric_param_grid", METRICS_DEFAULT_PARAMS, ids=lambda params: params[0] ) @pytest.mark.parametrize("X, Y", [(X64, Y64), (X32, Y32), (X_mmap, Y_mmap)]) @pytest.mark.parametrize("csr_container", CSR_CONTAINERS) def test_cdist(metric_param_grid, X, Y, csr_container): metric, param_grid = metric_param_grid keys = param_grid.keys() X_csr, Y_csr = csr_container(X), csr_container(Y) for vals in itertools.product(*param_grid.values()): kwargs = dict(zip(keys, vals)) rtol_dict = {} if metric == "mahalanobis" and X.dtype == np.float32: # Computation of mahalanobis differs between # the scipy and scikit-learn implementation. # Hence, we increase the relative tolerance. # TODO: Inspect slight numerical discrepancy # with scipy rtol_dict = {"rtol": 1e-6} # TODO: Remove when scipy minimum version >= 1.7.0 # scipy supports 0= 1.7.0 if metric == "minkowski": p = kwargs["p"] if sp_version < parse_version("1.7.0") and p < 1: pytest.skip("scipy does not support 0= 1.7.0 # scipy supports 0= 1.7.0 if metric == "minkowski": p = kwargs["p"] if sp_version < parse_version("1.7.0") and p < 1: pytest.skip("scipy does not support 0