# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.


class TransformerMixin:
    """Mixin class for all transformers in scikit-learn."""

    def fit_transform(self, X, y=None, **fit_params):
        """Fit to data, then transform it.

        Fits transformer to ``X`` and ``y`` with optional parameters
        ``fit_params``, and returns a transformed version of ``X``.

        Parameters
        ----------
        X : array, shape (n_samples, n_features)
            Training set.
        y : array, shape (n_samples,)
            Target values or class labels.
        **fit_params : dict
            Additional fitting parameters passed to the ``fit`` method.

        Returns
        -------
        X_new : array, shape (n_samples, n_features_new)
            Transformed array.
        """
        # Non-optimized default implementation; override it when a better
        # method is possible for a given transformer.
        if y is None:
            # fit method of arity 1 (unsupervised transformation): ``y`` is
            # not passed, so ``fit(X)`` signatures without ``y`` also work.
            return self.fit(X, **fit_params).transform(X)
        # fit method of arity 2 (supervised transformation)
        return self.fit(X, y, **fit_params).transform(X)


class EstimatorMixin:
    """Mixin class for estimators."""

    def get_params(self, deep=True):
        """Get the estimator params.

        This base implementation is a placeholder that returns ``None``.
        :meth:`set_params` looks parameter names up in whatever this method
        returns, so subclasses that want ``set_params`` to accept any key
        need to override it and return a mapping of names to values.

        Parameters
        ----------
        deep : bool
            Deep. Unused by this placeholder implementation.
        """
        return

    def set_params(self, **params):
        """Set parameters (mimics sklearn API).

        A plain key ``name`` is assigned as an attribute of this instance.
        A key of the form ``name__sub_name`` is forwarded to the
        ``set_params`` method of the object stored under ``name``. Plain keys
        are applied first, so ``set_params(est=new, est__alpha=1)`` configures
        ``new`` rather than the object it replaces.

        Parameters
        ----------
        **params : dict
            Extra parameters.

        Returns
        -------
        inst : object
            The instance.

        Raises
        ------
        ValueError
            If a key (or the ``name`` prefix of a nested key) is not among the
            parameters reported by :meth:`get_params`.
        """
        if not params:
            return self
        # Copy so that recording replaced sub-objects below never mutates what
        # ``get_params`` returned. The base-class placeholder returns ``None``;
        # treat it as "no valid parameters" so callers get a ValueError rather
        # than an opaque TypeError from the membership test.
        valid_params = dict(self.get_params(deep=True) or {})
        nested_params = {}  # name -> {sub_name: value}, applied after the loop
        for key, value in params.items():
            name, delim, sub_name = key.partition("__")
            if delim:
                # nested objects case: defer until every plain key is set
                if name not in valid_params:
                    raise ValueError(
                        f"Invalid parameter {name} for estimator {self}. Check the list"
                        " of available parameters with `estimator.get_params().keys()`."
                    )
                nested_params.setdefault(name, {})[sub_name] = value
            else:
                # simple objects case
                if key not in valid_params:
                    raise ValueError(
                        f"Invalid parameter {key} for estimator "
                        f"{self.__class__.__name__}. Check the list of available "
                        "parameters with `estimator.get_params().keys()`."
                    )
                setattr(self, key, value)
                # Later nested keys with this prefix must target the new value.
                valid_params[key] = value
        for name, sub_params in nested_params.items():
            valid_params[name].set_params(**sub_params)
        return self