"""Utilities for meta-estimators""" # Author: Joel Nothman # Andreas Mueller # License: BSD from typing import List, Any from abc import ABCMeta, abstractmethod from operator import attrgetter from functools import update_wrapper import numpy as np from ..utils import _safe_indexing from ..base import BaseEstimator from ..base import _is_pairwise __all__ = ["available_if", "if_delegate_has_method"] class _BaseComposition(BaseEstimator, metaclass=ABCMeta): """Handles parameter management for classifiers composed of named estimators.""" steps: List[Any] @abstractmethod def __init__(self): pass def _get_params(self, attr, deep=True): out = super().get_params(deep=deep) if not deep: return out estimators = getattr(self, attr) out.update(estimators) for name, estimator in estimators: if hasattr(estimator, "get_params"): for key, value in estimator.get_params(deep=True).items(): out["%s__%s" % (name, key)] = value return out def _set_params(self, attr, **params): # Ensure strict ordering of parameter setting: # 1. All steps if attr in params: setattr(self, attr, params.pop(attr)) # 2. Step replacement items = getattr(self, attr) names = [] if items: names, _ = zip(*items) for name in list(params.keys()): if "__" not in name and name in names: self._replace_estimator(attr, name, params.pop(name)) # 3. Step parameters and other initialisation arguments super().set_params(**params) return self def _replace_estimator(self, attr, name, new_val): # assumes `name` is a valid estimator name new_estimators = list(getattr(self, attr)) for i, (estimator_name, _) in enumerate(new_estimators): if estimator_name == name: new_estimators[i] = (name, new_val) break setattr(self, attr, new_estimators) def _validate_names(self, names): if len(set(names)) != len(names): raise ValueError("Names provided are not unique: {0!r}".format(list(names))) invalid_names = set(names).intersection(self.get_params(deep=False)) if invalid_names: raise ValueError( "Estimator names conflict with constructor arguments: {0!r}".format( sorted(invalid_names) ) ) invalid_names = [name for name in names if "__" in name] if invalid_names: raise ValueError( "Estimator names must not contain __: got {0!r}".format(invalid_names) ) class _AvailableIfDescriptor: """Implements a conditional property using the descriptor protocol. Using this class to create a decorator will raise an ``AttributeError`` if check(self) returns a falsey value. Note that if check raises an error this will also result in hasattr returning false. See https://docs.python.org/3/howto/descriptor.html for an explanation of descriptors. """ def __init__(self, fn, check, attribute_name): self.fn = fn self.check = check self.attribute_name = attribute_name # update the docstring of the descriptor update_wrapper(self, fn) def __get__(self, obj, owner=None): attr_err = AttributeError( f"This {repr(owner.__name__)} has no attribute {repr(self.attribute_name)}" ) if obj is not None: # delegate only on instances, not the classes. # this is to allow access to the docstrings. if not self.check(obj): raise attr_err # lambda, but not partial, allows help() to work with update_wrapper out = lambda *args, **kwargs: self.fn(obj, *args, **kwargs) # noqa else: def fn(*args, **kwargs): if not self.check(args[0]): raise attr_err return self.fn(*args, **kwargs) # This makes it possible to use the decorated method as an unbound method, # for instance when monkeypatching. out = lambda *args, **kwargs: fn(*args, **kwargs) # noqa # update the docstring of the returned function update_wrapper(out, self.fn) return out def available_if(check): """An attribute that is available only if check returns a truthy value Parameters ---------- check : callable When passed the object with the decorated method, this should return a truthy value if the attribute is available, and either return False or raise an AttributeError if not available. Examples -------- >>> from sklearn.utils.metaestimators import available_if >>> class HelloIfEven: ... def __init__(self, x): ... self.x = x ... ... def _x_is_even(self): ... return self.x % 2 == 0 ... ... @available_if(_x_is_even) ... def say_hello(self): ... print("Hello") ... >>> obj = HelloIfEven(1) >>> hasattr(obj, "say_hello") False >>> obj.x = 2 >>> hasattr(obj, "say_hello") True >>> obj.say_hello() Hello """ return lambda fn: _AvailableIfDescriptor(fn, check, attribute_name=fn.__name__) class _IffHasAttrDescriptor(_AvailableIfDescriptor): """Implements a conditional property using the descriptor protocol. Using this class to create a decorator will raise an ``AttributeError`` if none of the delegates (specified in ``delegate_names``) is an attribute of the base object or the first found delegate does not have an attribute ``attribute_name``. This allows ducktyping of the decorated method based on ``delegate.attribute_name``. Here ``delegate`` is the first item in ``delegate_names`` for which ``hasattr(object, delegate) is True``. See https://docs.python.org/3/howto/descriptor.html for an explanation of descriptors. """ def __init__(self, fn, delegate_names, attribute_name): super().__init__(fn, self._check, attribute_name) self.delegate_names = delegate_names def _check(self, obj): delegate = None for delegate_name in self.delegate_names: try: delegate = attrgetter(delegate_name)(obj) break except AttributeError: continue if delegate is None: return False # raise original AttributeError getattr(delegate, self.attribute_name) return True def if_delegate_has_method(delegate): """Create a decorator for methods that are delegated to a sub-estimator This enables ducktyping by hasattr returning True according to the sub-estimator. Parameters ---------- delegate : str, list of str or tuple of str Name of the sub-estimator that can be accessed as an attribute of the base object. If a list or a tuple of names are provided, the first sub-estimator that is an attribute of the base object will be used. """ if isinstance(delegate, list): delegate = tuple(delegate) if not isinstance(delegate, tuple): delegate = (delegate,) return lambda fn: _IffHasAttrDescriptor(fn, delegate, attribute_name=fn.__name__) def _safe_split(estimator, X, y, indices, train_indices=None): """Create subset of dataset and properly handle kernels. Slice X, y according to indices for cross-validation, but take care of precomputed kernel-matrices or pairwise affinities / distances. If ``estimator._pairwise is True``, X needs to be square and we slice rows and columns. If ``train_indices`` is not None, we slice rows using ``indices`` (assumed the test set) and columns using ``train_indices``, indicating the training set. .. deprecated:: 0.24 The _pairwise attribute is deprecated in 0.24. From 1.1 (renaming of 0.26) and onward, this function will check for the pairwise estimator tag. Labels y will always be indexed only along the first axis. Parameters ---------- estimator : object Estimator to determine whether we should slice only rows or rows and columns. X : array-like, sparse matrix or iterable Data to be indexed. If ``estimator._pairwise is True``, this needs to be a square array-like or sparse matrix. y : array-like, sparse matrix or iterable Targets to be indexed. indices : array of int Rows to select from X and y. If ``estimator._pairwise is True`` and ``train_indices is None`` then ``indices`` will also be used to slice columns. train_indices : array of int or None, default=None If ``estimator._pairwise is True`` and ``train_indices is not None``, then ``train_indices`` will be use to slice the columns of X. Returns ------- X_subset : array-like, sparse matrix or list Indexed data. y_subset : array-like, sparse matrix or list Indexed targets. """ if _is_pairwise(estimator): if not hasattr(X, "shape"): raise ValueError( "Precomputed kernels or affinity matrices have " "to be passed as arrays or sparse matrices." ) # X is a precomputed square kernel matrix if X.shape[0] != X.shape[1]: raise ValueError("X should be a square kernel matrix") if train_indices is None: X_subset = X[np.ix_(indices, indices)] else: X_subset = X[np.ix_(indices, train_indices)] else: X_subset = _safe_indexing(X, indices) if y is not None: y_subset = _safe_indexing(y, indices) else: y_subset = None return X_subset, y_subset