""" Tests for the birch clustering algorithm. """ import numpy as np import pytest from sklearn.cluster import AgglomerativeClustering, Birch from sklearn.cluster.tests.common import generate_clustered_data from sklearn.datasets import make_blobs from sklearn.exceptions import ConvergenceWarning from sklearn.metrics import pairwise_distances_argmin, v_measure_score from sklearn.utils._testing import assert_allclose, assert_array_equal from sklearn.utils.fixes import CSR_CONTAINERS def test_n_samples_leaves_roots(global_random_seed, global_dtype): # Sanity check for the number of samples in leaves and roots X, y = make_blobs(n_samples=10, random_state=global_random_seed) X = X.astype(global_dtype, copy=False) brc = Birch() brc.fit(X) n_samples_root = sum([sc.n_samples_ for sc in brc.root_.subclusters_]) n_samples_leaves = sum( [sc.n_samples_ for leaf in brc._get_leaves() for sc in leaf.subclusters_] ) assert n_samples_leaves == X.shape[0] assert n_samples_root == X.shape[0] def test_partial_fit(global_random_seed, global_dtype): # Test that fit is equivalent to calling partial_fit multiple times X, y = make_blobs(n_samples=100, random_state=global_random_seed) X = X.astype(global_dtype, copy=False) brc = Birch(n_clusters=3) brc.fit(X) brc_partial = Birch(n_clusters=None) brc_partial.partial_fit(X[:50]) brc_partial.partial_fit(X[50:]) assert_allclose(brc_partial.subcluster_centers_, brc.subcluster_centers_) # Test that same global labels are obtained after calling partial_fit # with None brc_partial.set_params(n_clusters=3) brc_partial.partial_fit(None) assert_array_equal(brc_partial.subcluster_labels_, brc.subcluster_labels_) def test_birch_predict(global_random_seed, global_dtype): # Test the predict method predicts the nearest centroid. rng = np.random.RandomState(global_random_seed) X = generate_clustered_data(n_clusters=3, n_features=3, n_samples_per_cluster=10) X = X.astype(global_dtype, copy=False) # n_samples * n_samples_per_cluster shuffle_indices = np.arange(30) rng.shuffle(shuffle_indices) X_shuffle = X[shuffle_indices, :] brc = Birch(n_clusters=4, threshold=1.0) brc.fit(X_shuffle) # Birch must preserve inputs' dtype assert brc.subcluster_centers_.dtype == global_dtype assert_array_equal(brc.labels_, brc.predict(X_shuffle)) centroids = brc.subcluster_centers_ nearest_centroid = brc.subcluster_labels_[ pairwise_distances_argmin(X_shuffle, centroids) ] assert_allclose(v_measure_score(nearest_centroid, brc.labels_), 1.0) def test_n_clusters(global_random_seed, global_dtype): # Test that n_clusters param works properly X, y = make_blobs(n_samples=100, centers=10, random_state=global_random_seed) X = X.astype(global_dtype, copy=False) brc1 = Birch(n_clusters=10) brc1.fit(X) assert len(brc1.subcluster_centers_) > 10 assert len(np.unique(brc1.labels_)) == 10 # Test that n_clusters = Agglomerative Clustering gives # the same results. gc = AgglomerativeClustering(n_clusters=10) brc2 = Birch(n_clusters=gc) brc2.fit(X) assert_array_equal(brc1.subcluster_labels_, brc2.subcluster_labels_) assert_array_equal(brc1.labels_, brc2.labels_) # Test that a small number of clusters raises a warning. 


@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
def test_sparse_X(global_random_seed, global_dtype, csr_container):
    # Test that sparse and dense data give same results
    X, y = make_blobs(n_samples=100, centers=10, random_state=global_random_seed)
    X = X.astype(global_dtype, copy=False)
    brc = Birch(n_clusters=10)
    brc.fit(X)

    csr = csr_container(X)
    brc_sparse = Birch(n_clusters=10)
    brc_sparse.fit(csr)

    # Birch must preserve inputs' dtype
    assert brc_sparse.subcluster_centers_.dtype == global_dtype

    assert_array_equal(brc.labels_, brc_sparse.labels_)
    assert_allclose(brc.subcluster_centers_, brc_sparse.subcluster_centers_)


def test_partial_fit_second_call_error_checks():
    # Subsequent partial_fit calls must error when n_features is not
    # consistent with the first call
    X, y = make_blobs(n_samples=100)
    brc = Birch(n_clusters=3)
    brc.partial_fit(X, y)

    msg = "X has 1 features, but Birch is expecting 2 features"
    with pytest.raises(ValueError, match=msg):
        brc.partial_fit(X[:, [0]], y)


def check_branching_factor(node, branching_factor):
    subclusters = node.subclusters_
    assert branching_factor >= len(subclusters)
    for cluster in subclusters:
        if cluster.child_:
            check_branching_factor(cluster.child_, branching_factor)


def test_branching_factor(global_random_seed, global_dtype):
    # Test that nodes have at max branching_factor number of subclusters
    X, y = make_blobs(random_state=global_random_seed)
    X = X.astype(global_dtype, copy=False)
    branching_factor = 9

    # Purposefully set a low threshold to maximize the subclusters.
    brc = Birch(n_clusters=None, branching_factor=branching_factor, threshold=0.01)
    brc.fit(X)
    check_branching_factor(brc.root_, branching_factor)

    brc = Birch(n_clusters=3, branching_factor=branching_factor, threshold=0.01)
    brc.fit(X)
    check_branching_factor(brc.root_, branching_factor)


def check_threshold(birch_instance, threshold):
    """Use the leaf linked list for traversal"""
    current_leaf = birch_instance.dummy_leaf_.next_leaf_
    while current_leaf:
        subclusters = current_leaf.subclusters_
        for sc in subclusters:
            assert threshold >= sc.radius
        current_leaf = current_leaf.next_leaf_


def test_threshold(global_random_seed, global_dtype):
    # Test that the leaf subclusters have a radius no larger than the threshold
    X, y = make_blobs(n_samples=80, centers=4, random_state=global_random_seed)
    X = X.astype(global_dtype, copy=False)
    brc = Birch(threshold=0.5, n_clusters=None)
    brc.fit(X)
    check_threshold(brc, 0.5)

    brc = Birch(threshold=5.0, n_clusters=None)
    brc.fit(X)
    check_threshold(brc, 5.0)
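

# Illustrative sketch, not part of the original suite: check_threshold above
# walks the leaf linked list by hand via dummy_leaf_.next_leaf_. This sketch
# assumes Birch._get_leaves() (already used in test_n_samples_leaves_roots)
# returns the same leaf nodes, in the same order, as that manual traversal.
def test_leaf_linked_list_matches_get_leaves_sketch(global_random_seed):
    X, _ = make_blobs(n_samples=80, centers=4, random_state=global_random_seed)
    brc = Birch(threshold=0.5, n_clusters=None)
    brc.fit(X)
    leaves_from_linked_list = []
    current_leaf = brc.dummy_leaf_.next_leaf_
    while current_leaf:
        leaves_from_linked_list.append(current_leaf)
        current_leaf = current_leaf.next_leaf_
    assert leaves_from_linked_list == brc._get_leaves()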


def test_birch_n_clusters_long_int():
    # Check that birch supports n_clusters with np.int64 dtype, for instance
    # coming from np.arange. #16484
    X, _ = make_blobs(random_state=0)
    n_clusters = np.int64(5)
    Birch(n_clusters=n_clusters).fit(X)


def test_feature_names_out():
    """Check `get_feature_names_out` for `Birch`."""
    X, _ = make_blobs(n_samples=80, n_features=4, random_state=0)
    brc = Birch(n_clusters=4)
    brc.fit(X)
    n_clusters = brc.subcluster_centers_.shape[0]

    names_out = brc.get_feature_names_out()
    assert_array_equal([f"birch{i}" for i in range(n_clusters)], names_out)


def test_transform_match_across_dtypes(global_random_seed):
    X, _ = make_blobs(n_samples=80, n_features=4, random_state=global_random_seed)
    brc = Birch(n_clusters=4, threshold=1.1)
    Y_64 = brc.fit_transform(X)
    Y_32 = brc.fit_transform(X.astype(np.float32))

    assert_allclose(Y_64, Y_32, atol=1e-6)


def test_subcluster_dtype(global_dtype):
    X = make_blobs(n_samples=80, n_features=4, random_state=0)[0].astype(
        global_dtype, copy=False
    )
    brc = Birch(n_clusters=4)
    assert brc.fit(X).subcluster_centers_.dtype == global_dtype


def test_both_subclusters_updated():
    """Check that both subclusters are updated when a node is split, even when
    there are duplicated data points. Non-regression test for #23269.
    """
    X = np.array(
        [
            [-2.6192791, -1.5053215],
            [-2.9993038, -1.6863596],
            [-2.3724914, -1.3438171],
            [-2.336792, -1.3417323],
            [-2.4089134, -1.3290224],
            [-2.3724914, -1.3438171],
            [-3.364009, -1.8846745],
            [-2.3724914, -1.3438171],
            [-2.617677, -1.5003285],
            [-2.2960556, -1.3260119],
            [-2.3724914, -1.3438171],
            [-2.5459878, -1.4533926],
            [-2.25979, -1.3003055],
            [-2.4089134, -1.3290224],
            [-2.3724914, -1.3438171],
            [-2.4089134, -1.3290224],
            [-2.5459878, -1.4533926],
            [-2.3724914, -1.3438171],
            [-2.9720619, -1.7058647],
            [-2.336792, -1.3417323],
            [-2.3724914, -1.3438171],
        ],
        dtype=np.float32,
    )

    # no error
    Birch(branching_factor=5, threshold=1e-5, n_clusters=None).fit(X)
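

# Illustrative sketch, not part of the original suite: it assumes that
# Birch.transform returns, for each sample, the Euclidean distance to every
# fitted subcluster center (one column per entry of subcluster_centers_).
def test_transform_distances_sketch():
    from sklearn.metrics.pairwise import euclidean_distances

    X, _ = make_blobs(n_samples=80, n_features=4, random_state=0)
    brc = Birch(n_clusters=4)
    brc.fit(X)
    assert_allclose(
        brc.transform(X), euclidean_distances(X, brc.subcluster_centers_)
    )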