# File: Feature-Extraction/dist/client/mne/decoding/mixin.py
# (90 lines, 2.8 KiB, Python)
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
class TransformerMixin:
    """Mixin class for all transformers in scikit-learn.

    Provides a default ``fit_transform`` built on the subclass's ``fit``
    (which must return an object exposing ``transform``) and ``transform``.
    """

    def fit_transform(self, X, y=None, **fit_params):
        """Fit to data, then transform it.

        Fits transformer to ``X`` and ``y`` with optional parameters
        ``fit_params``, and returns a transformed version of ``X``.

        Parameters
        ----------
        X : array, shape (n_samples, n_features)
            Training set.
        y : array, shape (n_samples,)
            Target values or class labels. When ``None``, ``fit`` is called
            without a ``y`` argument.
        **fit_params : dict
            Additional fitting parameters passed to the ``fit`` method.

        Returns
        -------
        X_new : array, shape (n_samples, n_features_new)
            Transformed array.
        """
        # Non-optimized default implementation; subclasses may override it
        # when fitting and transforming in a single pass is cheaper.
        if y is None:
            # Unsupervised case: omit ``y`` entirely so that ``fit`` methods
            # of arity 1 (``fit(X)``) are supported.
            return self.fit(X, **fit_params).transform(X)
        # Supervised case: ``fit`` of arity 2 (``fit(X, y)``).
        return self.fit(X, y, **fit_params).transform(X)
class EstimatorMixin:
    """Mixin class for estimators.

    Provides a scikit-learn style ``set_params`` on top of ``get_params``,
    which subclasses are expected to override.
    """

    def get_params(self, deep=True):
        """Get the estimator params.

        This base implementation returns ``None``; a subclass must override
        it (returning a mapping of parameter names to values) for
        :meth:`set_params` to work.

        Parameters
        ----------
        deep : bool
            Deep. Whether nested ``component__param`` entries should be
            included; unused by this base implementation.

        Returns
        -------
        params : None
            Always ``None`` in this base implementation.
        """
        return

    def set_params(self, **params):
        """Set parameters (mimics sklearn API).

        Simple parameters (``name=value``) are assigned with ``setattr``.
        Nested parameters (``component__name=value``) are forwarded to the
        ``set_params`` method of ``component``. All keys are validated before
        anything is modified. Simple parameters are applied before nested
        ones, so ``set_params(est=new, est__alpha=1)`` configures ``new``
        regardless of keyword order.

        Parameters
        ----------
        **params : dict
            Extra parameters.

        Returns
        -------
        inst : object
            The instance.

        Raises
        ------
        ValueError
            If a parameter name (or the component prefix of a nested
            parameter) is not a key of ``self.get_params(deep=True)``.
        """
        if not params:
            return self
        valid_params = self.get_params(deep=True)

        # First pass: validate every key and partition into simple / nested,
        # so an invalid key leaves the instance untouched.
        simple_params = {}
        nested_params = {}  # component name -> {sub-parameter name: value}
        for key, value in params.items():
            name, delim, sub_name = key.partition("__")
            if delim:
                # nested objects case
                if name not in valid_params:
                    raise ValueError(
                        f"Invalid parameter {name} for estimator {self}. Check the list"
                        " of available parameters with `estimator.get_params().keys()`."
                    )
                nested_params.setdefault(name, {})[sub_name] = value
            else:
                # simple objects case
                if key not in valid_params:
                    raise ValueError(
                        f"Invalid parameter {key} for estimator "
                        f"{self.__class__.__name__}. Check the list of available "
                        "parameters with `estimator.get_params().keys()`."
                    )
                simple_params[key] = value

        for key, value in simple_params.items():
            setattr(self, key, value)

        for name, sub_params in nested_params.items():
            # If the component itself was replaced in this same call, the
            # nested values must go to the new object, not the stale one
            # captured in ``valid_params``.
            if name in simple_params:
                component = simple_params[name]
            else:
                component = valid_params[name]
            component.set_params(**sub_params)
        return self