"""Base class for ensemble-based estimators."""

# Authors: Gilles Louppe
# License: BSD 3 clause

from abc import ABCMeta, abstractmethod
from typing import List

import numpy as np
from joblib import effective_n_jobs

from ..base import BaseEstimator, MetaEstimatorMixin, clone, is_classifier, is_regressor
from ..utils import Bunch, _print_elapsed_time, check_random_state
from ..utils._tags import _safe_tags
from ..utils.metaestimators import _BaseComposition


def _fit_single_estimator(
    estimator, X, y, sample_weight=None, message_clsname=None, message=None
):
    """Private function used to fit an estimator within a job.

    Parameters
    ----------
    estimator : estimator object
        The estimator to fit; it is fitted in place.

    X, y
        Training data and targets, forwarded unchanged to ``estimator.fit``.

    sample_weight : object, default=None
        If not None, forwarded to ``estimator.fit`` as the ``sample_weight``
        keyword argument.

    message_clsname, message : default=None
        Forwarded to ``_print_elapsed_time``, which wraps the call to ``fit``.

    Returns
    -------
    estimator : estimator object
        The same object that was passed in, after ``fit`` has been called.

    Raises
    ------
    TypeError
        If ``sample_weight`` is given but ``estimator.fit`` does not accept a
        ``sample_weight`` keyword argument.
    """
    if sample_weight is not None:
        try:
            with _print_elapsed_time(message_clsname, message):
                estimator.fit(X, y, sample_weight=sample_weight)
        except TypeError as exc:
            # Only rewrite the error caused by the unsupported keyword; any
            # other TypeError raised from within `fit` is propagated untouched.
            if "unexpected keyword argument 'sample_weight'" in str(exc):
                raise TypeError(
                    "Underlying estimator {} does not support sample weights.".format(
                        estimator.__class__.__name__
                    )
                ) from exc
            raise
    else:
        with _print_elapsed_time(message_clsname, message):
            estimator.fit(X, y)
    return estimator


def _set_random_states(estimator, random_state=None):
    """Set fixed random_state parameters for an estimator.

    Finds all parameters ending ``random_state`` and sets them to integers
    derived from ``random_state``.

    Parameters
    ----------
    estimator : estimator supporting get/set_params
        Estimator with potential randomness managed by random_state
        parameters.

    random_state : int, RandomState instance or None, default=None
        Pseudo-random number generator to control the generation of the random
        integers. Pass an int for reproducible output across multiple function
        calls.
        See :term:`Glossary <random_state>`.

    Notes
    -----
    This does not necessarily set *all* ``random_state`` attributes that
    control an estimator's randomness, only those accessible through
    ``estimator.get_params()``.  ``random_state``s not controlled include
    those belonging to:

        * cross-validation splitters
        * ``scipy.stats`` rvs
    """
    random_state = check_random_state(random_state)
    to_set = {}
    # Iterate in sorted order so that the seed drawn for each parameter does
    # not depend on the iteration order of the parameter dict.
    for key in sorted(estimator.get_params(deep=True)):
        if key == "random_state" or key.endswith("__random_state"):
            to_set[key] = random_state.randint(np.iinfo(np.int32).max)

    if to_set:
        estimator.set_params(**to_set)


class BaseEnsemble(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta):
    """Base class for all ensemble classes.

    Warning: This class should not be used directly. Use derived classes
    instead.

    Parameters
    ----------
    estimator : object
        The base estimator from which the ensemble is built.

    n_estimators : int, default=10
        The number of estimators in the ensemble.

    estimator_params : list of str, default=tuple()
        The list of attributes to use as parameters when instantiating a
        new base estimator. If none are given, default parameters are used.

    Attributes
    ----------
    estimator_ : estimator
        The base estimator from which the ensemble is grown.

    estimators_ : list of estimators
        The collection of fitted base estimators.
    """

    # overwrite _required_parameters from MetaEstimatorMixin
    _required_parameters: List[str] = []

    @abstractmethod
    def __init__(
        self,
        estimator=None,
        *,
        n_estimators=10,
        estimator_params=tuple(),
    ):
        # Set parameters
        self.estimator = estimator
        self.n_estimators = n_estimators
        self.estimator_params = estimator_params

        # Don't instantiate estimators now! Parameters of estimator might
        # still change. Eg., when grid-searching with the nested object syntax.
        # self.estimators_ needs to be filled by the derived classes in fit.

    def _validate_estimator(self, default=None):
        """Check the base estimator.

        Sets the `estimator_` attribute to `self.estimator`, or to `default`
        when `self.estimator` is None.

        Parameters
        ----------
        default : object, default=None
            Fallback stored in `estimator_` when no `estimator` was given.
        """
        if self.estimator is not None:
            self.estimator_ = self.estimator
        else:
            self.estimator_ = default

    def _make_estimator(self, append=True, random_state=None):
        """Make and configure a copy of the `estimator_` attribute.

        Warning: This method should be used to properly instantiate new
        sub-estimators.

        Parameters
        ----------
        append : bool, default=True
            Whether to append the new estimator to `self.estimators_`.

        random_state : int, RandomState instance or None, default=None
            If not None, used to set the ``random_state`` parameters of the
            new estimator through `_set_random_states`.

        Returns
        -------
        estimator : estimator object
            The cloned and configured (unfitted) estimator.
        """
        estimator = clone(self.estimator_)
        # Copy the ensemble-level attributes listed in `estimator_params`
        # onto the clone.
        estimator.set_params(**{p: getattr(self, p) for p in self.estimator_params})

        if random_state is not None:
            _set_random_states(estimator, random_state)

        if append:
            self.estimators_.append(estimator)

        return estimator

    def __len__(self):
        """Return the number of estimators in the ensemble."""
        return len(self.estimators_)

    def __getitem__(self, index):
        """Return the index'th estimator in the ensemble."""
        return self.estimators_[index]

    def __iter__(self):
        """Return iterator over estimators in the ensemble."""
        return iter(self.estimators_)


def _partition_estimators(n_estimators, n_jobs):
    """Private function used to partition estimators between jobs.

    Parameters
    ----------
    n_estimators : int
        Number of estimators to distribute. Expected to be at least 1: with 0
        the number of jobs becomes 0 and the integer division below fails.

    n_jobs : int or None
        Requested number of jobs, resolved with `joblib.effective_n_jobs`.

    Returns
    -------
    n_jobs : int
        Number of jobs actually used, never more than `n_estimators`.

    n_estimators_per_job : list of int
        Number of estimators assigned to each job.

    starts : list of int
        Cumulative offsets with a leading 0 (length ``n_jobs + 1``); job ``i``
        handles the estimators in ``starts[i]:starts[i + 1]``.
    """
    # Compute the number of jobs
    n_jobs = min(effective_n_jobs(n_jobs), n_estimators)

    # Partition estimators between jobs: spread them evenly and hand the
    # remainder out one by one to the first jobs.
    n_estimators_per_job = np.full(n_jobs, n_estimators // n_jobs, dtype=int)
    n_estimators_per_job[: n_estimators % n_jobs] += 1
    starts = np.cumsum(n_estimators_per_job)

    return n_jobs, n_estimators_per_job.tolist(), [0] + starts.tolist()


class _BaseHeterogeneousEnsemble(
    MetaEstimatorMixin, _BaseComposition, metaclass=ABCMeta
):
    """Base class for heterogeneous ensemble of learners.

    Parameters
    ----------
    estimators : list of (str, estimator) tuples
        The ensemble of estimators to use in the ensemble. Each element of the
        list is defined as a tuple of string (i.e. name of the estimator) and
        an estimator instance. An estimator can be set to `'drop'` using
        `set_params`.

    Attributes
    ----------
    estimators_ : list of estimators
        The elements of the estimators parameter, having been fitted on the
        training data. If an estimator has been set to `'drop'`, it will not
        appear in `estimators_`.
    """

    _required_parameters = ["estimators"]

    @property
    def named_estimators(self):
        """Dictionary to access the sub-estimators given in `estimators` by name.

        Returns
        -------
        :class:`~sklearn.utils.Bunch`
        """
        return Bunch(**dict(self.estimators))

    @abstractmethod
    def __init__(self, estimators):
        self.estimators = estimators

    def _validate_estimators(self):
        """Validate `estimators` and split it into names and estimators.

        Returns
        -------
        names : tuple of str
            The names given in `estimators`.

        estimators : tuple
            The matching estimators (or the string `'drop'`), in the same
            order as `names`.

        Raises
        ------
        ValueError
            If `estimators` is empty or is not made of (string, estimator)
            pairs, if every estimator is `'drop'`, or if an estimator is not
            of the same kind (classifier / regressor) as the ensemble.
        """
        # Check the structure of every item up front so that malformed input
        # (e.g. a bare list of estimators) is reported with the message below
        # instead of failing in the `zip` unpacking with an unrelated error.
        if len(self.estimators) == 0 or not all(
            isinstance(item, (tuple, list))
            and len(item) == 2
            and isinstance(item[0], str)
            for item in self.estimators
        ):
            raise ValueError(
                "Invalid 'estimators' attribute, 'estimators' should be a "
                "non-empty list of (string, estimator) tuples."
            )
        names, estimators = zip(*self.estimators)
        # `_validate_names` is provided by `_BaseComposition`
        self._validate_names(names)

        has_estimator = any(est != "drop" for est in estimators)
        if not has_estimator:
            raise ValueError(
                "All estimators are dropped. At least one is required "
                "to be an estimator."
            )

        is_estimator_type = is_classifier if is_classifier(self) else is_regressor

        for est in estimators:
            if est != "drop" and not is_estimator_type(est):
                raise ValueError(
                    "The estimator {} should be a {}.".format(
                        # Strip the leading "is_" from the checker's name to
                        # get "classifier" / "regressor".
                        est.__class__.__name__, is_estimator_type.__name__[3:]
                    )
                )

        return names, estimators

    def set_params(self, **params):
        """
        Set the parameters of an estimator from the ensemble.

        Valid parameter keys can be listed with `get_params()`. Note that you
        can directly set the parameters of the estimators contained in
        `estimators`.

        Parameters
        ----------
        **params : keyword arguments
            Specific parameters using e.g.
            `set_params(parameter_name=new_value)`. In addition to setting the
            parameters of the ensemble, the individual estimators contained in
            `estimators` can also be set, or can be removed by setting them to
            'drop'.

        Returns
        -------
        self : object
            Estimator instance.
        """
        super()._set_params("estimators", **params)
        return self

    def get_params(self, deep=True):
        """
        Get the parameters of an estimator from the ensemble.

        Returns the parameters given in the constructor as well as the
        estimators contained within the `estimators` parameter.

        Parameters
        ----------
        deep : bool, default=True
            Setting it to True gets the various estimators and the parameters
            of the estimators as well.

        Returns
        -------
        params : dict
            Parameter and estimator names mapped to their values or parameter
            names mapped to their values.
        """
        return super()._get_params("estimators", deep=deep)

    def _more_tags(self):
        """Return the estimator tags specific to heterogeneous ensembles.

        `allow_nan` is True only if every sub-estimator that is not `'drop'`
        reports `allow_nan` in its own tags.
        """
        try:
            allow_nan = all(
                _safe_tags(est[1])["allow_nan"] if est[1] != "drop" else True
                for est in self.estimators
            )
        except Exception:
            # If `estimators` does not comply with our API (list of tuples) then it will
            # fail. In this case, we assume that `allow_nan` is False but the parameter
            # validation will raise an error during `fit`.
            allow_nan = False
        return {"preserves_dtype": [], "allow_nan": allow_nan}