# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.


class TransformerMixin:
    """Mixin class for all transformers in scikit-learn."""

    def fit_transform(self, X, y=None, **fit_params):
        """Fit to data, then transform it.

        Fits transformer to ``X`` and ``y`` with optional parameters
        ``fit_params``, and returns a transformed version of ``X``.

        Parameters
        ----------
        X : array, shape (n_samples, n_features)
            Training set.
        y : array, shape (n_samples,) | None
            Target values or class labels. If None, ``fit`` is called
            without ``y``, so transformers whose ``fit`` only accepts ``X``
            are supported.
        **fit_params : dict
            Additional fitting parameters passed to the ``fit`` method.

        Returns
        -------
        X_new : array, shape (n_samples, n_features_new)
            Transformed array.

        Notes
        -----
        ``fit`` must return an object exposing ``transform`` (normally
        ``self``), because ``transform(X)`` is called on its return value.
        """
        # Non-optimized default implementation; subclasses should override
        # this when fitting and transforming together can be done more
        # efficiently.
        if y is None:
            # fit method of arity 1 (unsupervised transformation)
            return self.fit(X, **fit_params).transform(X)
        # fit method of arity 2 (supervised transformation)
        return self.fit(X, y, **fit_params).transform(X)
|
|
|
|
|
|
class EstimatorMixin:
    """Mixin class for estimators."""

    def get_params(self, deep=True):
        """Get the estimator params.

        This base implementation is a stub that returns None. ``set_params``
        validates parameter names against the mapping returned here, so
        subclasses must override this method (returning a dict of parameter
        name -> value) for ``set_params`` to accept any parameter.

        Parameters
        ----------
        deep : bool
            Deep. ``set_params`` passes True; presumably this requests the
            parameters of nested sub-objects as in scikit-learn. Unused by
            this stub.

        Returns
        -------
        params : None
            Always None in this base implementation.
        """
        return

    def set_params(self, **params):
        """Set parameters (mimics sklearn API).

        Plain keys (``alpha=...``) are set with ``setattr`` on this instance.
        Keys of the form ``<component>__<parameter>`` are forwarded to
        ``<component>.set_params(<parameter>=...)``. Nested keys are applied
        only after all plain keys, so replacing a component and configuring
        it in the same call (``est=new, est__alpha=1``) configures the NEW
        component regardless of keyword order.

        Parameters
        ----------
        **params : dict
            Extra parameters.

        Returns
        -------
        inst : object
            The instance.

        Raises
        ------
        ValueError
            If a (component) name is not a key of ``self.get_params(deep=True)``.
        """
        if not params:
            return self
        # Copy so that recording replaced components below never mutates the
        # object returned by ``get_params``; the base stub returns None,
        # which is treated as "no known parameters".
        valid_params = dict(self.get_params(deep=True) or {})

        nested_params = {}  # component name -> {sub-parameter: value}
        for key, value in params.items():
            name, delim, sub_name = key.partition("__")
            if name not in valid_params:
                raise ValueError(
                    f"Invalid parameter {name} for estimator "
                    f"{self.__class__.__name__}. Check the list of available "
                    "parameters with `estimator.get_params().keys()`."
                )
            if delim:
                # nested objects case: defer until plain params are applied
                nested_params.setdefault(name, {})[sub_name] = value
            else:
                # simple objects case
                setattr(self, name, value)
                # Keep the lookup current so nested params reach a component
                # replaced in this same call.
                valid_params[name] = value

        for name, sub_params in nested_params.items():
            valid_params[name].set_params(**sub_params)
        return self
|