"""Helpers for various transformations."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import glob
import os
from copy import deepcopy
from pathlib import Path

import numpy as np
from scipy import linalg
from scipy.spatial.distance import cdist
from scipy.special import sph_harm

from ._fiff.constants import FIFF
from ._fiff.open import fiff_open
from ._fiff.tag import read_tag
from ._fiff.write import start_and_end_file, write_coord_trans
from .defaults import _handle_default
from .fixes import _get_img_fdata, jit
from .utils import (
    _check_fname,
    _check_option,
    _ensure_int,
    _import_nibabel,
    _path_like,
    _require_version,
    _validate_type,
    check_fname,
    fill_doc,
    get_subjects_dir,
    logger,
    verbose,
    wrapped_stdout,
)

# transformation from anterior/left/superior coordinate system to
# right/anterior/superior:
als_ras_trans = np.array([[0, -1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])


# Short string names accepted by the public API -> FIFF integer constants.
_str_to_frame = dict(
    meg=FIFF.FIFFV_COORD_DEVICE,
    mri=FIFF.FIFFV_COORD_MRI,
    mri_voxel=FIFF.FIFFV_MNE_COORD_MRI_VOXEL,
    head=FIFF.FIFFV_COORD_HEAD,
    mni_tal=FIFF.FIFFV_MNE_COORD_MNI_TAL,
    ras=FIFF.FIFFV_MNE_COORD_RAS,
    fs_tal=FIFF.FIFFV_MNE_COORD_FS_TAL,
    ctf_head=FIFF.FIFFV_MNE_COORD_CTF_HEAD,
    ctf_meg=FIFF.FIFFV_MNE_COORD_CTF_DEVICE,
    unknown=FIFF.FIFFV_COORD_UNKNOWN,
)
# Reverse lookup; note it only covers the frames listed above.
_frame_to_str = {val: key for key, val in _str_to_frame.items()}

# Human-readable names used in reprs and log messages (covers more frames
# than _str_to_frame).
_verbose_frames = {
    FIFF.FIFFV_COORD_UNKNOWN: "unknown",
    FIFF.FIFFV_COORD_DEVICE: "MEG device",
    FIFF.FIFFV_COORD_ISOTRAK: "isotrak",
    FIFF.FIFFV_COORD_HPI: "hpi",
    FIFF.FIFFV_COORD_HEAD: "head",
    FIFF.FIFFV_COORD_MRI: "MRI (surface RAS)",
    FIFF.FIFFV_MNE_COORD_MRI_VOXEL: "MRI voxel",
    FIFF.FIFFV_COORD_MRI_SLICE: "MRI slice",
    FIFF.FIFFV_COORD_MRI_DISPLAY: "MRI display",
    FIFF.FIFFV_MNE_COORD_CTF_DEVICE: "CTF MEG device",
    FIFF.FIFFV_MNE_COORD_CTF_HEAD: "CTF/4D/KIT head",
    FIFF.FIFFV_MNE_COORD_RAS: "RAS (non-zero origin)",
    FIFF.FIFFV_MNE_COORD_MNI_TAL: "MNI Talairach",
    FIFF.FIFFV_MNE_COORD_FS_TAL_GTZ: "Talairach (MNI z > 0)",
    FIFF.FIFFV_MNE_COORD_FS_TAL_LTZ: "Talairach (MNI z < 0)",
    -1: "unknown",
}


def _to_const(cf):
    """Convert string or int coord frame into int.

    Parameters
    ----------
    cf : str | int
        A key of ``_str_to_frame`` or an integer FIFF frame constant.

    Returns
    -------
    cf : int
        The integer coordinate frame.

    Raises
    ------
    ValueError
        If ``cf`` is a string that is not a known frame name.
    """
    if isinstance(cf, str):
        if cf not in _str_to_frame:
            raise ValueError(
                f"Unknown coordinate frame {cf}, "
                'expected "' + '", "'.join(_str_to_frame.keys()) + '"'
            )
        cf = _str_to_frame[cf]
    else:
        cf = _ensure_int(cf, "coordinate frame", "a str or int")
    return int(cf)


class Transform(dict):
    """A transform.

    Parameters
    ----------
    fro : str | int
        The starting coordinate frame. See notes for valid coordinate frames.
    to : str | int
        The ending coordinate frame. See notes for valid coordinate frames.
    trans : array of shape (4, 4) | None
        The transformation matrix. If None, an identity matrix will be
        used.

    Notes
    -----
    Valid coordinate frames are ``'meg'``, ``'mri'``, ``'mri_voxel'``,
    ``'head'``, ``'mni_tal'``, ``'ras'``, ``'fs_tal'``, ``'ctf_head'``,
    ``'ctf_meg'``, ``'unknown'``.
    """

    def __init__(self, fro, to, trans=None):
        super().__init__()
        # we could add some better sanity checks here
        fro = _to_const(fro)
        to = _to_const(to)
        trans = np.eye(4) if trans is None else np.asarray(trans, np.float64)
        if trans.shape != (4, 4):
            raise ValueError(f"Transformation must be shape (4, 4) not {trans.shape}")
        self["from"] = fro
        self["to"] = to
        self["trans"] = trans

    def __repr__(self):  # noqa: D105
        with np.printoptions(suppress=True):  # suppress scientific notation
            # Include both endpoints; previously the computed "from" name was
            # passed to format() but never used.
            return "<Transform | {fro}->{to}>\n{trans}".format(
                fro=_coord_frame_name(self["from"]),
                to=_coord_frame_name(self["to"]),
                trans=self["trans"],
            )

    def __eq__(self, other, rtol=0.0, atol=0.0):
        """Check for equality.

        Parameters
        ----------
        other : instance of Transform
            The other transform.
        rtol : float
            Relative tolerance.
        atol : float
            Absolute tolerance.

        Returns
        -------
        eq : bool
            True if the transforms are equal.
        """
        return (
            isinstance(other, Transform)
            and self["from"] == other["from"]
            and self["to"] == other["to"]
            and np.allclose(self["trans"], other["trans"], rtol=rtol, atol=atol)
        )

    def __ne__(self, other, rtol=0.0, atol=0.0):
        """Check for inequality.

        Parameters
        ----------
        other : instance of Transform
            The other transform.
        rtol : float
            Relative tolerance.
        atol : float
            Absolute tolerance.

        Returns
        -------
        ne : bool
            True if the transforms are not equal.
        """
        # Forward the tolerances; ``not self == other`` silently dropped them.
        return not self.__eq__(other, rtol=rtol, atol=atol)

    @property
    def from_str(self):
        """The "from" frame as a string."""
        return _coord_frame_name(self["from"])

    @property
    def to_str(self):
        """The "to" frame as a string."""
        return _coord_frame_name(self["to"])

    @fill_doc
    @verbose
    def save(self, fname, *, overwrite=False, verbose=None):
        """Save the transform as -trans.fif file.

        Parameters
        ----------
        fname : path-like
            The name of the file, which should end in ``-trans.fif``.
        %(overwrite)s
        %(verbose)s
        """
        write_trans(fname, self, overwrite=overwrite, verbose=verbose)

    def copy(self):
        """Make a copy of the transform."""
        return deepcopy(self)


def _coord_frame_name(cframe):
    """Map integers to human-readable (verbose) names."""
    return _verbose_frames.get(int(cframe), "unknown")


def _print_coord_trans(
    t, prefix="Coordinate transformation: ", units="m", level="info"
):
    """Log a transform as a header line plus its four matrix rows.

    Parameters
    ----------
    t : Transform
        The transform to log.
    prefix : str
        Text prepended to the "from -> to" header.
    units : str
        Units of the translation column of ``t["trans"]``; anything other
        than ``"mm"`` is scaled by 1000 for display.
    level : str
        Name of the ``logger`` method to use (e.g. ``"info"``, ``"debug"``).
    """
    # Units gives the units of the transformation. This always prints in mm.
    log_func = getattr(logger, level)
    log_func(
        prefix
        + "{fro} -> {to}".format(
            fro=_coord_frame_name(t["from"]), to=_coord_frame_name(t["to"])
        )
    )
    for ti, tt in enumerate(t["trans"]):
        # The last row is the homogeneous [0, 0, 0, 1] row: no scaling/units.
        scale = 1000.0 if (ti != 3 and units != "mm") else 1.0
        text = " mm" if ti != 3 else ""
        log_func(
            f"    {tt[0]:8.6f} {tt[1]:8.6f} {tt[2]:8.6f}    {scale * tt[3]:7.2f}{text}"
        )


def _find_trans(subject, subjects_dir=None):
    """Find the single ``*-trans.fif`` file in a subject's directory.

    Parameters
    ----------
    subject : str | None
        Subject name. If None, the ``SUBJECT`` environment variable is used.
    subjects_dir : path-like | None
        The subjects directory, resolved through ``get_subjects_dir``.

    Returns
    -------
    fname : Path
        Path to the unique matching file.

    Raises
    ------
    ValueError
        If ``subject`` is None and ``SUBJECT`` is not set.
    RuntimeError
        If zero or more than one matching file is found.
    """
    if subject is None:
        if "SUBJECT" in os.environ:
            subject = os.environ["SUBJECT"]
        else:
            raise ValueError("SUBJECT environment variable not set")

    # Resolve through get_subjects_dir so the None default (and str input)
    # work; ``None / subject`` would raise TypeError.
    subjects_dir = Path(get_subjects_dir(subjects_dir, raise_error=True))
    trans_fnames = glob.glob(str(subjects_dir / subject / "*-trans.fif"))
    if len(trans_fnames) < 1:
        raise RuntimeError(f"Could not find the transformation for {subject}")
    elif len(trans_fnames) > 1:
        raise RuntimeError(f"Found multiple transformations for {subject}")
    return Path(trans_fnames[0])


def apply_trans(trans, pts, move=True):
    """Apply a transform matrix to an array of points.

    Parameters
    ----------
    trans : array, shape = (4, 4) | instance of Transform
        Transform matrix.
    pts : array, shape = (3,) | (n, 3)
        Array with coordinates for one or n points.
    move : bool
        If True (default), apply translation.

    Returns
    -------
    transformed_pts : shape = (3,) | (n, 3)
        Transformed point(s).
    """
    if isinstance(trans, dict):
        trans = trans["trans"]
    pts = np.asarray(pts)
    if pts.size == 0:
        return pts.copy()

    # apply rotation & scale
    out_pts = np.dot(pts, trans[:3, :3].T)
    # apply translation
    if move:
        out_pts += trans[:3, 3]

    return out_pts


def rotation(x=0, y=0, z=0):
    """Create an array with a 4 dimensional rotation matrix.

    Parameters
    ----------
    x, y, z : scalar
        Rotation around the origin (in rad).

    Returns
    -------
    r : array, shape = (4, 4)
        The rotation matrix.
    """
    r = np.eye(4)
    r[:3, :3] = rotation3d(x=x, y=y, z=z)
    return r


def rotation3d(x=0, y=0, z=0):
    """Create an array with a 3 dimensional rotation matrix.

    Parameters
    ----------
    x, y, z : scalar
        Rotation around the origin (in rad).

    Returns
    -------
    r : array, shape = (3, 3)
        The rotation matrix.
    """
    cos_x = np.cos(x)
    cos_y = np.cos(y)
    cos_z = np.cos(z)
    sin_x = np.sin(x)
    sin_y = np.sin(y)
    sin_z = np.sin(z)
    r = np.array(
        [
            [
                cos_y * cos_z,
                -cos_x * sin_z + sin_x * sin_y * cos_z,
                sin_x * sin_z + cos_x * sin_y * cos_z,
            ],
            [
                cos_y * sin_z,
                cos_x * cos_z + sin_x * sin_y * sin_z,
                -sin_x * cos_z + cos_x * sin_y * sin_z,
            ],
            [-sin_y, sin_x * cos_y, cos_x * cos_y],
        ],
        dtype=float,
    )
    return r


def rotation3d_align_z_axis(target_z_axis):
    """Compute a rotation matrix to align [ 0 0 1] with supplied target z axis.

    Parameters
    ----------
    target_z_axis : array, shape (3,)
        z axis. computed matrix (r) will map [0 0 1] to target_z_axis
        (after normalization to unit length).

    Returns
    -------
    r : array, shape (3, 3)
        The rotation matrix.
    """
    target_z_axis = target_z_axis / np.linalg.norm(target_z_axis)
    r = np.zeros((3, 3))
    if (1.0 + target_z_axis[2]) < 1e-12:
        # Target is (numerically) -z: use a 180 degree rotation about x.
        r[0, 0] = 1.0
        r[1, 1] = -1.0
        r[2, 2] = -1.0
    else:
        f = 1.0 / (1.0 + target_z_axis[2])
        r[0, 0] = 1.0 - 1.0 * f * target_z_axis[0] * target_z_axis[0]
        r[0, 1] = -1.0 * f * target_z_axis[0] * target_z_axis[1]
        r[0, 2] = target_z_axis[0]
        r[1, 0] = -1.0 * f * target_z_axis[0] * target_z_axis[1]
        r[1, 1] = 1.0 - 1.0 * f * target_z_axis[1] * target_z_axis[1]
        r[1, 2] = target_z_axis[1]
        r[2, 0] = -target_z_axis[0]
        r[2, 1] = -target_z_axis[1]
        r[2, 2] = 1.0 - f * (
            target_z_axis[0] * target_z_axis[0] + target_z_axis[1] * target_z_axis[1]
        )

    # assert that r is a rotation matrix r^t * r = I and det(r) = 1.
    # Compare absolute deviations: the signed "any(diff < tol)" form could
    # never fail and so checked nothing.
    assert np.all(np.abs(r.dot(r.T) - np.identity(3)) < 1e-10)
    assert abs(np.linalg.det(r) - 1.0) < 1e-10
    # assert that r maps [0 0 1] on the device z axis (target_z_axis)
    assert np.linalg.norm(target_z_axis - r.dot([0, 0, 1])) < 1e-12

    return r


def rotation_angles(m):
    """Find rotation angles from a transformation matrix.

    Parameters
    ----------
    m : array, shape >= (3, 3)
        Rotation matrix. Only the top left 3 x 3 partition is accessed.

    Returns
    -------
    x, y, z : float
        Rotation around x, y and z axes.
    """
    x = np.arctan2(m[2, 1], m[2, 2])
    c2 = np.sqrt(m[0, 0] ** 2 + m[1, 0] ** 2)
    y = np.arctan2(-m[2, 0], c2)
    s1 = np.sin(x)
    c1 = np.cos(x)
    z = np.arctan2(s1 * m[0, 2] - c1 * m[0, 1], c1 * m[1, 1] - s1 * m[1, 2])
    return x, y, z


def scaling(x=1, y=1, z=1):
    """Create an array with a scaling matrix.

    Parameters
    ----------
    x, y, z : scalar
        Scaling factors.

    Returns
    -------
    s : array, shape = (4, 4)
        The scaling matrix.
    """
    s = np.array([[x, 0, 0, 0], [0, y, 0, 0], [0, 0, z, 0], [0, 0, 0, 1]], dtype=float)
    return s


def translation(x=0, y=0, z=0):
    """Create an array with a translation matrix.

    Parameters
    ----------
    x, y, z : scalar
        Translation parameters.

    Returns
    -------
    m : array, shape = (4, 4)
        The translation matrix.
    """
    m = np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]], dtype=float)
    return m


def _ensure_trans(trans, fro="mri", to="head"):
    """Ensure we have the proper transform.

    Parameters
    ----------
    trans : Transform | list of Transform | tuple of Transform
        Candidate transform(s). Exactly one must connect ``fro`` and ``to``
        (in either direction).
    fro, to : str | int
        The desired starting and ending coordinate frames.

    Returns
    -------
    trans : Transform
        The matching transform, inverted if needed so it maps fro -> to.

    Raises
    ------
    ValueError
        If an entry is not a Transform, or if not exactly one entry matches.
    """
    if isinstance(fro, str):
        from_str = fro
        from_const = _str_to_frame[fro]
    else:
        from_str = _frame_to_str[fro]
        from_const = fro
    del fro
    if isinstance(to, str):
        to_str = to
        to_const = _str_to_frame[to]
    else:
        to_str = _frame_to_str[to]
        to_const = to
    del to
    err_str = f"trans must be a Transform between {from_str}<->{to_str}, got"
    if not isinstance(trans, (list, tuple)):
        trans = [trans]
    # Ensure that we have exactly one match
    idx = list()
    misses = list()
    for ti, this_trans in enumerate(trans):
        if not isinstance(this_trans, Transform):
            # Report what was actually passed (message unchanged for None).
            got = "None" if this_trans is None else type(this_trans).__name__
            raise ValueError(f"{err_str} {got}")
        if {this_trans["from"], this_trans["to"]} == {from_const, to_const}:
            idx.append(ti)
        else:
            # Frames such as isotrak/hpi have no short name in _frame_to_str;
            # fall back to the integer so building this message cannot raise
            # KeyError and hide the real ValueError.
            misses += [
                "{fro}->{to}".format(
                    fro=_frame_to_str.get(this_trans["from"], this_trans["from"]),
                    to=_frame_to_str.get(this_trans["to"], this_trans["to"]),
                )
            ]
    if len(idx) != 1:
        raise ValueError(f"{err_str} " + ", ".join(misses))
    trans = trans[idx[0]]
    if trans["from"] != from_const:
        trans = invert_transform(trans)
    return trans


def _get_trans(trans, fro="mri", to="head", allow_none=True):
    """Get a ``fro`` -> ``to`` Transform from a file, Transform, or None.

    Parameters
    ----------
    trans : path-like | Transform | None
        A ``.fif``/``.gz`` trans file, a text file holding a 4x4 matrix,
        the special string ``"fsaverage"``, a Transform, or None (identity;
        only allowed when ``allow_none`` is True).
    fro, to : str | int
        The desired starting and ending coordinate frames.
    allow_none : bool
        Whether None is accepted.

    Returns
    -------
    fro_to_t : Transform
        The transform mapping ``fro`` -> ``to``.
    trans : Path | str
        The file path used, or a description ("instance of Transform",
        "identity").
    """
    types = (Transform, "path-like")
    if allow_none:
        types += (None,)
    _validate_type(trans, types, "trans")
    if _path_like(trans):
        if trans == "fsaverage":
            trans = Path(__file__).parent / "data" / "fsaverage" / "fsaverage-trans.fif"
        trans = Path(trans)
        if not trans.is_file():
            raise OSError(f'trans file "{trans}" not found')
        if trans.suffix in [".fif", ".gz"]:
            fro_to_t = read_trans(trans)
        else:
            # convert "-trans.txt" to "-trans.fif" mri-type equivalent
            # these are usually actually in to_fro form
            t = np.genfromtxt(trans)
            if t.ndim != 2 or t.shape != (4, 4):
                raise RuntimeError(f'File "{trans}" did not have 4x4 entries')
            fro_to_t = Transform(to, fro, t)
    elif isinstance(trans, Transform):
        fro_to_t = trans
        trans = "instance of Transform"
    else:
        assert trans is None
        fro_to_t = Transform(fro, to)
        trans = "identity"
    # it's usually a head->MRI transform, so we probably need to invert it
    fro_to_t = _ensure_trans(fro_to_t, fro, to)
    return fro_to_t, trans


def combine_transforms(t_first, t_second, fro, to):
    """Combine two transforms.

    Parameters
    ----------
    t_first : dict
        First transform.
    t_second : dict
        Second transform.
    fro : int
        From coordinate frame.
    to : int
        To coordinate frame.

    Returns
    -------
    trans : dict
        Combined transformation.
    """
    fro = _to_const(fro)
    to = _to_const(to)
    if t_first["from"] != fro:
        raise RuntimeError(
            'From mismatch: {fro1} ("{cf1}") != {fro2} ("{cf2}")'.format(
                fro1=t_first["from"],
                cf1=_coord_frame_name(t_first["from"]),
                fro2=fro,
                cf2=_coord_frame_name(fro),
            )
        )
    if t_first["to"] != t_second["from"]:
        raise RuntimeError(
            'Transform mismatch: t1["to"] = {to1} ("{cf1}"), '
            't2["from"] = {fro2} ("{cf2}")'.format(
                to1=t_first["to"],
                cf1=_coord_frame_name(t_first["to"]),
                fro2=t_second["from"],
                cf2=_coord_frame_name(t_second["from"]),
            )
        )
    if t_second["to"] != to:
        raise RuntimeError(
            'To mismatch: {to1} ("{cf1}") != {to2} ("{cf2}")'.format(
                to1=t_second["to"],
                cf1=_coord_frame_name(t_second["to"]),
                to2=to,
                cf2=_coord_frame_name(to),
            )
        )
    # Column-vector convention: apply t_first, then t_second.
    return Transform(fro, to, np.dot(t_second["trans"], t_first["trans"]))


@verbose
def read_trans(fname, return_all=False, verbose=None):
    """Read a ``-trans.fif`` file.

    Parameters
    ----------
    fname : path-like
        The name of the file.
    return_all : bool
        If True, return all transformations in the file.
        False (default) will only return the first.

        .. versionadded:: 0.15
    %(verbose)s

    Returns
    -------
    trans : dict | list of dict
        The transformation dictionary from the fif file.

    See Also
    --------
    write_trans
    mne.transforms.Transform
    """
    fname = _check_fname(fname, overwrite="read", must_exist=True)
    fid, tree, directory = fiff_open(fname)

    trans = list()
    with fid:
        for t in directory:
            if t.kind == FIFF.FIFF_COORD_TRANS:
                trans.append(read_tag(fid, t.pos).data)
                if not return_all:
                    break
    if len(trans) == 0:
        raise OSError("This does not seem to be a -trans.fif file.")
    return trans if return_all else trans[0]


@verbose
def write_trans(fname, trans, *, overwrite=False, verbose=None):
    """Write a transformation FIF file.

    Parameters
    ----------
    fname : path-like
        The name of the file, which should end in ``-trans.fif``.
    trans : dict
        Trans file data, as returned by `~mne.read_trans`.
    %(overwrite)s
    %(verbose)s

    See Also
    --------
    read_trans
    """
    check_fname(
        fname, "trans", ("-trans.fif", "-trans.fif.gz", "_trans.fif", "_trans.fif.gz")
    )
    fname = _check_fname(fname=fname, overwrite=overwrite)
    with start_and_end_file(fname) as fid:
        write_coord_trans(fid, trans)


def invert_transform(trans):
    """Invert a transformation between coordinate systems.

    Parameters
    ----------
    trans : dict
        Transform to invert.

    Returns
    -------
    inv_trans : dict
        Inverse transform.
    """
    return Transform(trans["to"], trans["from"], np.linalg.inv(trans["trans"]))


def transform_surface_to(surf, dest, trans, copy=False):
    """Transform surface to the desired coordinate system.

    Parameters
    ----------
    surf : dict
        Surface.
    dest : 'meg' | 'mri' | 'head' | int
        Destination coordinate system. Can be an integer for using
        FIFF types.
    trans : dict | list of dict
        Transformation to use (or a list of possible transformations to
        check).
    copy : bool
        If False (default), operate in-place.

    Returns
    -------
    res : dict
        Transformed source space.
    """
    surf = deepcopy(surf) if copy else surf
    if isinstance(dest, str):
        if dest not in _str_to_frame:
            raise KeyError(
                f'dest must be one of {list(_str_to_frame.keys())}, not "{dest}"'
            )
        dest = _str_to_frame[dest]  # convert to integer
    if surf["coord_frame"] == dest:
        return surf

    trans = _ensure_trans(trans, int(surf["coord_frame"]), dest)
    surf["coord_frame"] = dest
    surf["rr"] = apply_trans(trans, surf["rr"])
    if "nn" in surf:
        # Normals are directions: rotate only, no translation.
        surf["nn"] = apply_trans(trans, surf["nn"], move=False)
    return surf


def get_ras_to_neuromag_trans(nasion, lpa, rpa):
    """Construct a transformation matrix to the MNE head coordinate system.

    Construct a transformation matrix from an arbitrary RAS coordinate system
    to the MNE head coordinate system, in which the x axis passes through the
    two preauricular points, and the y axis passes through the nasion and is
    normal to the x axis. (see mne manual, pg. 97)

    Parameters
    ----------
    nasion : array_like, shape (3,)
        Nasion point coordinate.
    lpa : array_like, shape (3,)
        Left peri-auricular point coordinate.
    rpa : array_like, shape (3,)
        Right peri-auricular point coordinate.

    Returns
    -------
    trans : numpy.array, shape = (4, 4)
        Transformation matrix to MNE head space.
    """
    # check input args
    nasion = np.asarray(nasion)
    lpa = np.asarray(lpa)
    rpa = np.asarray(rpa)
    for pt in (nasion, lpa, rpa):
        if pt.ndim != 1 or len(pt) != 3:
            raise ValueError(
                "Points have to be provided as one dimensional arrays of length 3."
            )

    right = rpa - lpa
    right_unit = right / np.linalg.norm(right)

    # origin is the projection of the nasion onto the LPA-RPA line
    origin = lpa + np.dot(nasion - lpa, right_unit) * right_unit

    anterior = nasion - origin
    anterior_unit = anterior / np.linalg.norm(anterior)

    superior_unit = np.cross(right_unit, anterior_unit)

    x, y, z = -origin
    origin_trans = translation(x, y, z)

    trans_l = np.vstack((right_unit, anterior_unit, superior_unit, [0, 0, 0]))
    trans_r = np.reshape([0, 0, 0, 1], (4, 1))
    rot_trans = np.hstack((trans_l, trans_r))

    trans = np.dot(rot_trans, origin_trans)
    return trans


def _get_transforms_to_coord_frame(info, trans, coord_frame="mri"):
    """Get the transforms to a coordinate frame from device, head and mri.

    Parameters
    ----------
    info : dict-like
        Measurement info; its ``"dev_head_t"`` entry is used.
    trans : path-like | Transform | None
        The head<->MRI transform, as accepted by ``_get_trans``.
    coord_frame : str
        Target frame: ``"meg"``, ``"head"`` or ``"mri"``.

    Returns
    -------
    to_cf_t : dict
        Keys ``"meg"``, ``"head"``, ``"mri"``, each mapping to a Transform
        from that frame to ``coord_frame`` (identity if equal).
    """
    head_mri_t = _get_trans(trans, "head", "mri")[0]
    dev_head_t = _get_trans(info["dev_head_t"], "meg", "head")[0]
    mri_dev_t = invert_transform(
        combine_transforms(dev_head_t, head_mri_t, "meg", "mri")
    )
    # Each candidate list includes an identity so coord_frame itself resolves.
    to_cf_t = dict(
        meg=_ensure_trans(
            [dev_head_t, mri_dev_t, Transform("meg", "meg")], fro="meg", to=coord_frame
        ),
        head=_ensure_trans(
            [dev_head_t, head_mri_t, Transform("head", "head")],
            fro="head",
            to=coord_frame,
        ),
        mri=_ensure_trans(
            [head_mri_t, mri_dev_t, Transform("mri", "mri")], fro="mri", to=coord_frame
        ),
    )
    return to_cf_t


###############################################################################
# Spherical coordinates and harmonics


def _cart_to_sph(cart):
    """Convert Cartesian coordinates to spherical coordinates.

    Parameters
    ----------
    cart : ndarray, shape (n_points, 3) | (3,)
        Array containing points in Cartesian coordinates (x, y, z)

    Returns
    -------
    sph_pts : ndarray, shape (n_points, 3)
        Array containing points in spherical coordinates (rad, azimuth, polar)
    """
    cart = np.atleast_2d(cart)
    assert cart.ndim == 2 and cart.shape[1] == 3, cart.shape
    out = np.empty((len(cart), 3))
    out[:, 0] = np.sqrt(np.sum(cart * cart, axis=1))
    norm = np.where(out[:, 0] > 0, out[:, 0], 1)  # protect against / 0
    out[:, 1] = np.arctan2(cart[:, 1], cart[:, 0])
    out[:, 2] = np.arccos(cart[:, 2] / norm)
    out = np.nan_to_num(out)
    return out


def _sph_to_cart(sph_pts):
    """Convert spherical coordinates to Cartesian coordinates.

    Parameters
    ----------
    sph_pts : ndarray, shape (n_points, 3)
        Array containing points in spherical coordinates (rad, azimuth, polar)

    Returns
    -------
    cart_pts : ndarray, shape (n_points, 3)
        Array containing points in Cartesian coordinates (x, y, z)

    """
    sph_pts = np.atleast_2d(sph_pts)
    assert sph_pts.ndim == 2 and sph_pts.shape[1] == 3
    cart_pts = np.empty((len(sph_pts), 3))
    cart_pts[:, 2] = sph_pts[:, 0] * np.cos(sph_pts[:, 2])
    xy = sph_pts[:, 0] * np.sin(sph_pts[:, 2])
    cart_pts[:, 0] = xy * np.cos(sph_pts[:, 1])
    cart_pts[:, 1] = xy * np.sin(sph_pts[:, 1])
    return cart_pts


def _get_n_moments(order):
    """Compute the number of multipolar moments (spherical harmonics).

    Equivalent to :footcite:`DarvasEtAl2006` Eq. 32.

    .. note:: This count excludes ``degree=0`` (for ``order=0``).

    Parameters
    ----------
    order : array-like
        Expansion orders, often ``[int_order, ext_order]``.

    Returns
    -------
    M : ndarray
        Number of moments due to each order.
    """
    order = np.asarray(order, int)
    return (order + 2) * order


def _sph_to_cart_partials(az, pol, g_rad, g_az, g_pol):
    """Convert spherical partial derivatives to cartesian coords.

    Note: Because we are dealing with partial derivatives, this calculation is
    not a static transformation. The transformation matrix itself is dependent
    on azimuth and polar coord.

    See the 'Spherical coordinate system' section here:
    wikipedia.org/wiki/Vector_fields_in_cylindrical_and_spherical_coordinates

    Parameters
    ----------
    az : ndarray, shape (n_points,)
        Array containing spherical coordinates points (azimuth).
    pol : ndarray, shape (n_points,)
        Array containing spherical coordinates points (polar).
    g_rad, g_az, g_pol : ndarray, shape (n_points,)
        Partial derivatives at each point along the radius, azimuth and
        polar directions, respectively.

    Returns
    -------
    cart_grads : ndarray, shape (n_points, 3)
        Array containing partial derivatives in Cartesian coordinates (x, y, z)
    """
    sph_grads = np.c_[g_rad, g_az, g_pol]
    c_as, s_as = np.cos(az), np.sin(az)
    c_ps, s_ps = np.cos(pol), np.sin(pol)
    # trans has shape (3, 3, n_points): one 3x3 matrix per point
    trans = np.array(
        [
            [c_as * s_ps, -s_as, c_as * c_ps],
            [s_as * s_ps, c_as, c_ps * s_as],
            [c_ps, np.zeros_like(c_as), -s_ps],
        ]
    )
    cart_grads = np.einsum("ijk,kj->ki", trans, sph_grads)
    return cart_grads


def _deg_ord_idx(deg, order):
    """Get the index into S_in or S_out given a degree and order."""
    # The -1 here is because we typically exclude the degree=0 term
    return deg * deg + deg + order - 1


def _sh_negate(sh, order):
    """Get the negative spherical harmonic from a positive one."""
    assert order >= 0
    return sh.conj() * (-1.0 if order % 2 else 1.0)  # == (-1) ** order


def _sh_complex_to_real(sh, order):
    """Convert complex to real basis functions.

    Parameters
    ----------
    sh : array-like
        Spherical harmonics. Must be from order >=0 even if negative orders
        are used.
    order : int
        Order (usually 'm') of multipolar moment.

    Returns
    -------
    real_sh : array-like
        The real version of the spherical harmonics.

    Notes
    -----
    This does not include the Condon-Shortley phase.
    """
    if order == 0:
        return np.real(sh)
    else:
        return np.sqrt(2.0) * (np.real if order > 0 else np.imag)(sh)


def _sh_real_to_complex(shs, order):
    """Convert real spherical harmonic pair to complex.

    Parameters
    ----------
    shs : ndarray, shape (2, ...)
        The real spherical harmonics at ``[order, -order]``.
    order : int
        Order (usually 'm') of multipolar moment.

    Returns
    -------
    sh : array-like, shape (...)
        The complex version of the spherical harmonics.
    """
    if order == 0:
        return shs[0]
    else:
        return (shs[0] + 1j * np.sign(order) * shs[1]) / np.sqrt(2.0)


def _compute_sph_harm(order, az, pol):
    """Compute real-basis spherical harmonics of spherical coordinates.

    Complex harmonics from ``sph_harm`` are converted with
    ``_sh_complex_to_real``.

    Parameters
    ----------
    order : int
        Maximum degree of the expansion.
    az, pol : ndarray, shape (n_points,)
        Azimuth and polar angles.

    Returns
    -------
    out : ndarray, shape (n_points, _get_n_moments(order) + 1)
        Columns are indexed by ``_deg_ord_idx``; the degree-0 term lands in
        the last column because its index is -1.
    """
    out = np.empty((len(az), _get_n_moments(order) + 1))
    # _deg_ord_idx(0, 0) = -1 so we're actually okay to use it here
    for degree in range(order + 1):
        for order_ in range(degree + 1):
            sph = sph_harm(order_, degree, az, pol)
            out[:, _deg_ord_idx(degree, order_)] = _sh_complex_to_real(sph, order_)
            if order_ > 0:
                out[:, _deg_ord_idx(degree, -order_)] = _sh_complex_to_real(
                    _sh_negate(sph, order_), -order_
                )
    return out


###############################################################################
# Thin-plate spline transformations

# Adapted from code from the MATLAB file exchange:
#    https://www.mathworks.com/matlabcentral/fileexchange/
#            53867-3d-point-set-warping-by-thin-plate-rbf-function
#    https://www.mathworks.com/matlabcentral/fileexchange/
#            53828-rbf-or-thin-plate-splines-image-warping
# Associated (BSD 2-clause) license:
#
# Copyright (c) 2015, Wang Lin
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in
#       the documentation and/or other materials provided with the distribution
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. class _TPSWarp: """Transform points using thin-plate spline (TPS) warping. Notes ----- Based on the method by :footcite:`Bookstein1989` and adapted from code by Wang Lin (wanglin193@hotmail.com>). References ---------- .. footbibliography:: """ def fit(self, source, destination, reg=1e-3): assert source.shape[1] == destination.shape[1] == 3 assert source.shape[0] == destination.shape[0] # Forward warping, different from image warping, use |dist|**2 dists = _tps(cdist(source, destination, "sqeuclidean")) # Y = L * w # L: RBF matrix about source # Y: Points matrix about destination P = np.concatenate((np.ones((source.shape[0], 1)), source), axis=-1) L = np.vstack([np.hstack([dists, P]), np.hstack([P.T, np.zeros((4, 4))])]) Y = np.concatenate((destination, np.zeros((4, 3))), axis=0) # Regularize it a bit L += reg * np.eye(L.shape[0]) self._destination = destination.copy() self._weights = linalg.lstsq(L, Y)[0] return self @verbose def transform(self, pts, verbose=None): """Apply the warp. Parameters ---------- pts : shape (n_transform, 3) Source points to warp to the destination. Returns ------- dest : shape (n_transform, 3) The transformed points. 
""" logger.info(f"Transforming {len(pts)} points") assert pts.shape[1] == 3 # for memory reasons, we should do this in ~100 MB chunks out = np.zeros_like(pts) n_splits = max( int((pts.shape[0] * self._destination.shape[0]) / (100e6 / 8.0)), 1 ) for this_out, this_pts in zip( np.array_split(out, n_splits), np.array_split(pts, n_splits) ): dists = _tps(cdist(this_pts, self._destination, "sqeuclidean")) L = np.hstack((dists, np.ones((dists.shape[0], 1)), this_pts)) this_out[:] = np.dot(L, self._weights) assert not (out == 0).any() return out def _tps(distsq): """Thin-plate function (r ** 2) * np.log(r).""" # NOTE: For our warping functions, a radial basis like # exp(-distsq / radius ** 2) could also be used out = np.zeros_like(distsq) mask = distsq > 0 # avoid log(0) valid = distsq[mask] out[mask] = valid * np.log(valid) return out ############################################################################### # Spherical harmonic approximation + TPS warp class _SphericalSurfaceWarp: """Warp surfaces via spherical harmonic smoothing and thin-plate splines. Notes ----- This class can be used to warp data from a source subject to a destination subject, as described in :footcite:`DarvasEtAl2006`. The procedure is: 1. Perform a spherical harmonic approximation to the source and destination surfaces, which smooths them and allows arbitrary interpolation. 2. Choose a set of matched points on the two surfaces. 3. Use thin-plate spline warping (common in 2D image manipulation) to generate transformation coefficients. 4. Warp points from the source subject (which should be inside the original surface) to the destination subject. .. versionadded:: 0.14 References ---------- .. 
footbibliography:: """ def __repr__(self): rep = " 1e-6)]) src_center = _fit_sphere(hsp, disp=False)[1] source = source - src_center hsp = np.array([p for p in destination if not (p[2] < 0 and p[1] > 0)]) dest_center = _fit_sphere(hsp, disp=False)[1] destination = destination - dest_center logger.info( " Using centers {np.array_str(src_center, None, 3)} -> " "{np.array_str(dest_center, None, 3)}" ) self._fit_params = dict( n_src=len(source), n_dest=len(destination), match=match, n_match=len(match_rr), order=order, reg=reg, ) assert source.shape[1] == destination.shape[1] == 3 self._destination = destination.copy() # 1. Compute spherical coordinates of source and destination points logger.info(" Converting to spherical coordinates") src_rad_az_pol = _cart_to_sph(source).T dest_rad_az_pol = _cart_to_sph(destination).T match_rad_az_pol = _cart_to_sph(match_rr).T del match_rr # 2. Compute spherical harmonic coefficients for all points logger.info( f" Computing spherical harmonic approximation with order {order}" ) src_sph = _compute_sph_harm(order, *src_rad_az_pol[1:]) dest_sph = _compute_sph_harm(order, *dest_rad_az_pol[1:]) match_sph = _compute_sph_harm(order, *match_rad_az_pol[1:]) # 3. Fit spherical harmonics to both surfaces to smooth them src_coeffs = linalg.lstsq(src_sph, src_rad_az_pol[0])[0] dest_coeffs = linalg.lstsq(dest_sph, dest_rad_az_pol[0])[0] # 4. Smooth both surfaces using these coefficients, and evaluate at # the "shape" points logger.info( f" Matching {len(match_sph)} points ({match}) on smoothed surfaces" ) src_rad_az_pol = match_rad_az_pol.copy() src_rad_az_pol[0] = np.abs(np.dot(match_sph, src_coeffs)) dest_rad_az_pol = match_rad_az_pol.copy() dest_rad_az_pol[0] = np.abs(np.dot(match_sph, dest_coeffs)) # 5. Convert matched points to Cartesian coordinates and put back source = _sph_to_cart(src_rad_az_pol.T) source += src_center destination = _sph_to_cart(dest_rad_az_pol.T) destination += dest_center # 6. 
Compute TPS warp of matched points from smoothed surfaces self._warp = _TPSWarp().fit(source, destination, reg) logger.info("[done]") return self @verbose def transform(self, source, verbose=None): """Transform arbitrary source points to the destination. Parameters ---------- source : ndarray, shape (n_pts, 3) Source points to transform. They do not need to be the same points that were used to generate the model, although ideally they will be inside the convex hull formed by the original source points. %(verbose)s Returns ------- destination : ndarray, shape (n_pts, 3) The points transformed to the destination space. """ return self._warp.transform(source) ############################################################################### # Other transforms def _pol_to_cart(pol): """Transform polar coordinates to cartesian.""" out = np.empty((len(pol), 2)) if pol.shape[1] == 2: # phi, theta out[:, 0] = pol[:, 0] * np.cos(pol[:, 1]) out[:, 1] = pol[:, 0] * np.sin(pol[:, 1]) else: # radial distance, theta, phi d = pol[:, 0] * np.sin(pol[:, 2]) out[:, 0] = d * np.cos(pol[:, 1]) out[:, 1] = d * np.sin(pol[:, 1]) return out def _topo_to_sph(topo): """Convert 2D topo coordinates to spherical coordinates.""" assert topo.ndim == 2 and topo.shape[1] == 2 sph = np.ones((len(topo), 3)) sph[:, 1] = -np.deg2rad(topo[:, 0]) sph[:, 2] = np.pi * topo[:, 1] return sph ############################################################################### # Quaternions @jit() def quat_to_rot(quat): """Convert a set of quaternions to rotations. Parameters ---------- quat : array, shape (..., 3) The q1, q2, and q3 (x, y, z) parameters of a unit quaternion. Returns ------- rot : array, shape (..., 3, 3) The corresponding rotation matrices. 
See Also -------- rot_to_quat """ # z = a + bi + cj + dk b, c, d = quat[..., 0], quat[..., 1], quat[..., 2] bb, cc, dd = b * b, c * c, d * d # use max() here to be safe in case roundoff errs put us over aa = np.maximum(1.0 - bb - cc - dd, 0.0) a = np.sqrt(aa) ab_2 = 2 * a * b ac_2 = 2 * a * c ad_2 = 2 * a * d bc_2 = 2 * b * c bd_2 = 2 * b * d cd_2 = 2 * c * d rotation = np.empty(quat.shape[:-1] + (3, 3)) rotation[..., 0, 0] = aa + bb - cc - dd rotation[..., 0, 1] = bc_2 - ad_2 rotation[..., 0, 2] = bd_2 + ac_2 rotation[..., 1, 0] = bc_2 + ad_2 rotation[..., 1, 1] = aa + cc - bb - dd rotation[..., 1, 2] = cd_2 - ab_2 rotation[..., 2, 0] = bd_2 - ac_2 rotation[..., 2, 1] = cd_2 + ab_2 rotation[..., 2, 2] = aa + dd - bb - cc return rotation @jit() def _one_rot_to_quat(rot): """Convert a rotation matrix to quaternions.""" # see e.g. http://www.euclideanspace.com/maths/geometry/rotations/ # conversions/matrixToQuaternion/ det = np.linalg.det(np.reshape(rot, (3, 3))) if np.abs(det - 1.0) > 1e-3: raise ValueError("Matrix is not a pure rotation, got determinant != 1") t = 1.0 + rot[0] + rot[4] + rot[8] if t > np.finfo(rot.dtype).eps: s = np.sqrt(t) * 2.0 # qw = 0.25 * s qx = (rot[7] - rot[5]) / s qy = (rot[2] - rot[6]) / s qz = (rot[3] - rot[1]) / s elif rot[0] > rot[4] and rot[0] > rot[8]: s = np.sqrt(1.0 + rot[0] - rot[4] - rot[8]) * 2.0 # qw = (rot[7] - rot[5]) / s qx = 0.25 * s qy = (rot[1] + rot[3]) / s qz = (rot[2] + rot[6]) / s elif rot[4] > rot[8]: s = np.sqrt(1.0 - rot[0] + rot[4] - rot[8]) * 2 # qw = (rot[2] - rot[6]) / s qx = (rot[1] + rot[3]) / s qy = 0.25 * s qz = (rot[5] + rot[7]) / s else: s = np.sqrt(1.0 - rot[0] - rot[4] + rot[8]) * 2.0 # qw = (rot[3] - rot[1]) / s qx = (rot[2] + rot[6]) / s qy = (rot[5] + rot[7]) / s qz = 0.25 * s return np.array((qx, qy, qz)) def rot_to_quat(rot): """Convert a set of rotations to quaternions. Parameters ---------- rot : array, shape (..., 3, 3) The rotation matrices to convert. 
Returns ------- quat : array, shape (..., 3) The q1, q2, and q3 (x, y, z) parameters of the corresponding unit quaternions. See Also -------- quat_to_rot """ rot = rot.reshape(rot.shape[:-2] + (9,)) return np.apply_along_axis(_one_rot_to_quat, -1, rot) def _quat_to_affine(quat): assert quat.shape == (6,) affine = np.eye(4) affine[:3, :3] = quat_to_rot(quat[:3]) affine[:3, 3] = quat[3:] return affine def _affine_to_quat(affine): assert affine.shape[-2:] == (4, 4) return np.concatenate( [rot_to_quat(affine[..., :3, :3]), affine[..., :3, 3]], axis=-1, ) def _angle_dist_between_rigid(a, b=None, *, angle_units="rad", distance_units="m"): a = _affine_to_quat(a) b = np.zeros(6) if b is None else _affine_to_quat(b) ang = _angle_between_quats(a[..., :3], b[..., :3]) dist = np.linalg.norm(a[..., 3:] - b[..., 3:], axis=-1) assert isinstance(angle_units, str) and angle_units in ("rad", "deg") if angle_units == "deg": ang = np.rad2deg(ang) assert isinstance(distance_units, str) and distance_units in ("m", "mm") if distance_units == "mm": dist *= 1e3 return ang, dist def _angle_between_quats(x, y=None): """Compute the ang between two quaternions w/3-element representations.""" # z = conj(x) * y # conjugate just negates all but the first element in a 4-element quat, # so it's just a negative for us y = np.zeros(3) if y is None else y z = _quat_mult(-x, y) z0 = _quat_real(z) return 2 * np.arctan2(np.linalg.norm(z, axis=-1), z0) def _quat_real(quat): """Get the real part of our 3-element quat.""" assert quat.shape[-1] == 3, quat.shape[-1] return np.sqrt( np.maximum( 1.0 - quat[..., 0] * quat[..., 0] - quat[..., 1] * quat[..., 1] - quat[..., 2] * quat[..., 2], 0.0, ) ) def _quat_mult(one, two): assert one.shape[-1] == two.shape[-1] == 3 w1 = _quat_real(one) w2 = _quat_real(two) out = np.empty(np.broadcast(one, two).shape) # Most mathematical expressions use this sort of notation x1, x2 = one[..., 0], two[..., 0] y1, y2 = one[..., 1], two[..., 1] z1, z2 = one[..., 2], two[..., 2] 
out[..., 0] = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 out[..., 1] = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 out[..., 2] = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 # only need to compute w because we need signs from it w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 signs = np.sign(w) signs = np.where(signs, signs, 1) out *= signs[..., np.newaxis] return out def _skew_symmetric_cross(a): """Compute the skew-symmetric cross product of a vector.""" return np.array([[0.0, -a[2], a[1]], [a[2], 0.0, -a[0]], [-a[1], a[0], 0.0]]) def _find_vector_rotation(a, b): """Find the rotation matrix that maps unit vector a to b.""" # Rodrigues' rotation formula: # https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula # http://math.stackexchange.com/a/476311 R = np.eye(3) v = np.cross(a, b) if np.allclose(v, 0.0): # identical return R s = np.dot(v, v) # sine of the angle between them c = np.dot(a, b) # cosine of the angle between them vx = _skew_symmetric_cross(v) R += vx + np.dot(vx, vx) * (1 - c) / s return R @jit() def _fit_matched_points(p, x, weights=None, scale=False): """Fit matched points using an analytical formula.""" # Follow notation of P.J. Besl and N.D. McKay, A Method for # Registration of 3-D Shapes, IEEE Trans. Patt. Anal. Machine Intell., 14, # 239 - 255, 1992. # # The original method is actually by Horn, Closed-form solution of absolute # orientation using unit quaternions, J Opt. Soc. Amer. A vol 4 no 4 # pp 629-642, Apr. 1987. This paper describes how weights can be # easily incorporated, and a uniform scale factor can be computed. # # Caution: This can be dangerous if there are 3 points, or 4 points in # a symmetric layout, as the geometry can be explained # equivalently under 180 degree rotations. # # Eventually this can be extended to also handle a uniform scale factor, # as well. 
assert p.shape == x.shape assert p.ndim == 2 assert p.shape[1] == 3 # (weighted) centroids weights_ = np.full((p.shape[0], 1), 1.0 / max(p.shape[0], 1)) if weights is not None: weights_[:] = np.reshape(weights / weights.sum(), (weights.size, 1)) mu_p = np.dot(weights_.T, p)[0] mu_x = np.dot(weights_.T, x)[0] dots = np.dot(p.T, weights_ * x) Sigma_px = dots - np.outer(mu_p, mu_x) # eq 24 # x and p should no longer be used A_ij = Sigma_px - Sigma_px.T Delta = np.array([A_ij[1, 2], A_ij[2, 0], A_ij[0, 1]]) tr_Sigma_px = np.trace(Sigma_px) # "N" in Horn: Q = np.empty((4, 4)) Q[0, 0] = tr_Sigma_px Q[0, 1:] = Delta Q[1:, 0] = Delta Q[1:, 1:] = Sigma_px + Sigma_px.T - tr_Sigma_px * np.eye(3) _, v = np.linalg.eigh(Q) # sorted ascending quat = np.empty(6) quat[:3] = v[1:, -1] if v[0, -1] != 0: quat[:3] *= np.sign(v[0, -1]) rot = quat_to_rot(quat[:3]) # scale factor is easy once we know the rotation if scale: # p is "right" (from), x is "left" (to) in Horn 1987 dev_x = x - mu_x dev_p = p - mu_p dev_x *= dev_x dev_p *= dev_p if weights is not None: dev_x *= weights_ dev_p *= weights_ s = np.sqrt(np.sum(dev_x) / np.sum(dev_p)) else: s = 1.0 # translation is easy once rotation and scale are known quat[3:] = mu_x - s * np.dot(rot, mu_p) return quat, s def _average_quats(quats, weights=None): """Average unit quaternions properly.""" assert quats.ndim == 2 and quats.shape[1] in (3, 4) if weights is None: weights = np.ones(quats.shape[0]) assert (weights >= 0).all() norm = weights.sum() if weights.sum() == 0: return np.zeros(3) weights = weights / norm # The naive step here would be: # # avg_quat = np.dot(weights, quats[:, :3]) # # But this is not robust to quaternions having sign ambiguity, # i.e., q == -q. 
Thus we instead use the rank 1 update method: # # https://arc.aiaa.org/doi/abs/10.2514/1.28949?journalCode=jgcd # https://github.com/tolgabirdal/averaging_quaternions/blob/master/wavg_quaternion_markley.m # noqa: E501 # # We use unit quats and don't store the last element, so reconstruct it # to get our 4-element quaternions: quats = np.concatenate((_quat_real(quats)[..., np.newaxis], quats), -1) quats *= weights[:, np.newaxis] A = np.einsum("ij,ik->jk", quats, quats) # sum of outer product of each q avg_quat = linalg.eigh(A)[1][:, -1] # largest eigenvector is the avg # Same as the largest eigenvector from the concatenation of all as # svd(quats, full_matrices=False)[-1][0], but faster. # # By local convention we take the real term (which we remove from our # representation) as positive. Since it can be zero, let's just ensure # that the first non-zero element is positive. This shouldn't matter once # we go to a rotation matrix, but it's nice for testing to have # consistency. avg_quat *= np.sign(avg_quat[avg_quat != 0][0]) avg_quat = avg_quat[1:] return avg_quat @fill_doc def read_ras_mni_t(subject, subjects_dir=None): """Read a subject's RAS to MNI transform. Parameters ---------- subject : str The subject. %(subjects_dir)s Returns ------- ras_mni_t : instance of Transform The transform from RAS to MNI (in mm). 
""" subjects_dir = Path(get_subjects_dir(subjects_dir=subjects_dir, raise_error=True)) _validate_type(subject, "str", "subject") fname = subjects_dir / subject / "mri" / "transforms" / "talairach.xfm" fname = str( _check_fname( fname, "read", True, "FreeSurfer Talairach transformation file", ) ) return Transform("ras", "mni_tal", _read_fs_xfm(fname)[0]) def _read_fs_xfm(fname): """Read a Freesurfer transform from a .xfm file.""" assert fname.endswith(".xfm") with open(fname) as fid: logger.debug(f"Reading FreeSurfer talairach.xfm file:\n{fname}") # read lines until we get the string 'Linear_Transform', which precedes # the data transformation matrix comp = "Linear_Transform" for li, line in enumerate(fid): if li == 0: kind = line.strip() logger.debug(f"Found: {repr(kind)}") if line[: len(comp)] == comp: # we have the right line, so don't read any more break else: raise ValueError( f'Failed to find "Linear_Transform" string in xfm file:\n{fname}' ) xfm = list() # read the transformation matrix (3x4) for ii, line in enumerate(fid): digs = [float(s) for s in line.strip("\n;").split()] xfm.append(digs) if ii == 2: break else: raise ValueError("Could not find enough linear transform lines") xfm.append([0.0, 0.0, 0.0, 1.0]) xfm = np.array(xfm, dtype=float) return xfm, kind def _write_fs_xfm(fname, xfm, kind): """Write a Freesurfer transform to a .xfm file.""" with open(fname, "wb") as fid: fid.write((kind + "\n\nTtransform_Type = Linear;\n").encode("ascii")) fid.write("Linear_Transform =\n".encode("ascii")) for li, line in enumerate(xfm[:-1]): line = " ".join([f"{part:0.6f}" for part in line]) line += "\n" if li < 2 else ";\n" fid.write(line.encode("ascii")) def _quat_to_euler(quat): euler = np.empty(quat.shape) x, y, z = quat[..., 0], quat[..., 1], quat[..., 2] w = _quat_real(quat) np.arctan2(2 * (w * x + y * z), 1 - 2 * (x * x + y * y), out=euler[..., 0]) np.arcsin(2 * (w * y - x * z), out=euler[..., 1]) np.arctan2(2 * (w * z + x * y), 1 - 2 * (y * y + z * z), 
out=euler[..., 2]) return euler def _euler_to_quat(euler): quat = np.empty(euler.shape) phi, theta, psi = euler[..., 0] / 2, euler[..., 1] / 2, euler[..., 2] / 2 cphi, sphi = np.cos(phi), np.sin(phi) del phi ctheta, stheta = np.cos(theta), np.sin(theta) del theta cpsi, spsi = np.cos(psi), np.sin(psi) del psi mult = np.sign(cphi * ctheta * cpsi + sphi * stheta * spsi) if np.isscalar(mult): mult = 1.0 if mult == 0 else mult else: mult[mult == 0] = 1.0 mult = mult[..., np.newaxis] quat[..., 0] = sphi * ctheta * cpsi - cphi * stheta * spsi quat[..., 1] = cphi * stheta * cpsi + sphi * ctheta * spsi quat[..., 2] = cphi * ctheta * spsi - sphi * stheta * cpsi quat *= mult return quat ############################################################################### # Affine Registration and SDR _ORDERED_STEPS = ("translation", "rigid", "affine", "sdr") def _validate_zooms(zooms): _validate_type(zooms, (dict, list, tuple, "numeric", None), "zooms") zooms = _handle_default("transform_zooms", zooms) for key, val in zooms.items(): _check_option("zooms key", key, _ORDERED_STEPS) if val is not None: val = tuple(float(x) for x in np.array(val, dtype=float).ravel()) _check_option(f"len(zooms[{repr(key)})", len(val), (1, 3)) if len(val) == 1: val = val * 3 for this_zoom in val: if this_zoom <= 1: raise ValueError(f"Zooms must be > 1, got {this_zoom}") zooms[key] = val return zooms def _validate_niter(niter): _validate_type(niter, (dict, list, tuple, None), "niter") niter = _handle_default("transform_niter", niter) for key, value in niter.items(): _check_option("niter key", key, _ORDERED_STEPS) _check_option(f"len(niter[{repr(key)}])", len(value), (1, 2, 3)) return niter def _validate_pipeline(pipeline): _validate_type(pipeline, (str, list, tuple), "pipeline") pipeline_defaults = dict( all=_ORDERED_STEPS, rigids=_ORDERED_STEPS[: _ORDERED_STEPS.index("rigid") + 1], affines=_ORDERED_STEPS[: _ORDERED_STEPS.index("affine") + 1], ) if isinstance(pipeline, str): # use defaults _check_option( 
"pipeline", pipeline, ("all", "rigids", "affines"), extra="when str" ) pipeline = pipeline_defaults[pipeline] for ii, step in enumerate(pipeline): name = f"pipeline[{ii}]" _validate_type(step, str, name) _check_option(name, step, _ORDERED_STEPS) ordered_pipeline = tuple(sorted(pipeline, key=lambda x: _ORDERED_STEPS.index(x))) if tuple(pipeline) != ordered_pipeline: raise ValueError( f"Steps in pipeline are out of order, expected {ordered_pipeline} " f"but got {pipeline} instead" ) if len(set(pipeline)) != len(pipeline): raise ValueError("Steps in pipeline should not be repeated") return tuple(pipeline) def _compute_r2(a, b): return 100 * (a.ravel() @ b.ravel()) / (np.linalg.norm(a) * np.linalg.norm(b)) def _reslice_normalize(img, zooms): from dipy.align.reslice import reslice img_zooms = img.header.get_zooms()[:3] img_affine = img.affine img = _get_img_fdata(img) if zooms is not None: img, img_affine = reslice(img, img_affine, img_zooms, zooms) img /= img.max() # normalize return img, img_affine @verbose def compute_volume_registration( moving, static, pipeline="all", zooms=None, niter=None, *, starting_affine=None, verbose=None, ): """Align two volumes using an affine and, optionally, SDR. Parameters ---------- %(moving)s %(static)s %(pipeline)s zooms : float | tuple | dict | None The voxel size of volume for each spatial dimension in mm. If None (default), MRIs won't be resliced (slow, but most accurate). Can be a tuple to provide separate zooms for each dimension (X/Y/Z), or a dict with keys ``['translation', 'rigid', 'affine', 'sdr']`` (each with values that are float`, tuple, or None) to provide separate reslicing/accuracy for the steps. %(niter)s starting_affine : ndarray The affine to initialize the registration with. .. versionadded:: 1.2 %(verbose)s Returns ------- %(reg_affine)s %(sdr_morph)s Notes ----- This function is heavily inspired by and extends :func:`dipy.align.affine_registration `. .. 
versionadded:: 0.24 """ return _compute_volume_registration( moving, static, pipeline, zooms, niter, starting_affine=starting_affine )[:2] def _compute_volume_registration( moving, static, pipeline, zooms, niter, *, starting_affine=None ): nib = _import_nibabel("SDR morph") _require_version("dipy", "SDR morph", "0.10.1") with np.testing.suppress_warnings(): from dipy.align import ( affine, affine_registration, center_of_mass, imwarp, metrics, rigid, translation, ) from dipy.align.imaffine import AffineMap # input validation _validate_type(moving, nib.spatialimages.SpatialImage, "moving") _validate_type(static, nib.spatialimages.SpatialImage, "static") original_zoom = np.mean(moving.header.get_zooms()[:3]) zooms = _validate_zooms(zooms) niter = _validate_niter(niter) pipeline = _validate_pipeline(pipeline) logger.info("Computing registration...") # affine optimizations reg_affine = starting_affine sdr_morph = None pipeline_options = dict( translation=[center_of_mass, translation], rigid=[rigid], affine=[affine] ) sigmas_mm = np.array([3.0, 1.0, 0.0]) # default for affine_registration sigma_diff_mm = 2.0 factors = [4, 2, 1] current_zoom = None for i, step in enumerate(pipeline): # reslice image with zooms if i == 0 or zooms[step] != zooms[pipeline[i - 1]]: if zooms[step] is not None: logger.info(f"Reslicing to zooms={zooms[step]} for {step} ...") current_zoom = np.mean(zooms[step]) else: logger.info(f"Using original zooms for {step} ...") current_zoom = original_zoom static_zoomed, static_affine = _reslice_normalize(static, zooms[step]) moving_zoomed, moving_affine = _reslice_normalize(moving, zooms[step]) logger.info(f"Optimizing {step}:") if step == "sdr": # happens last sigma_diff_vox = sigma_diff_mm / current_zoom affine_map = AffineMap( reg_affine, # apply registration here static_zoomed.shape, static_affine, moving_zoomed.shape, moving_affine, ) moving_zoomed = affine_map.transform(moving_zoomed) metric = metrics.CCMetric( dim=3, sigma_diff=sigma_diff_vox, 
radius=max(int(np.ceil(2 * sigma_diff_vox)), 1), ) sdr = imwarp.SymmetricDiffeomorphicRegistration(metric, niter[step]) with wrapped_stdout(indent=" ", cull_newlines=True): sdr_morph = sdr.optimize( static_zoomed, moving_zoomed, static_affine, static_affine ) moved_zoomed = sdr_morph.transform(moving_zoomed) else: sigmas_vox = list(sigmas_mm / current_zoom) with wrapped_stdout(indent=" ", cull_newlines=True): moved_zoomed, reg_affine = affine_registration( moving_zoomed, static_zoomed, moving_affine, static_affine, nbins=32, metric="MI", pipeline=pipeline_options[step], level_iters=niter[step], sigmas=sigmas_vox, factors=factors, starting_affine=reg_affine, ) # report some useful information if step in ("translation", "rigid"): angle, dist = _angle_dist_between_rigid(reg_affine, angle_units="deg") logger.info(f" Translation: {dist:6.1f} mm") if step == "rigid": logger.info(f" Rotation: {angle:6.1f}°") assert moved_zoomed.shape == static_zoomed.shape, step r2 = _compute_r2(static_zoomed, moved_zoomed) logger.info(f" R²: {r2:6.1f}%") return ( reg_affine, sdr_morph, static_zoomed.shape, static_affine, moving_zoomed.shape, moving_affine, ) @verbose def apply_volume_registration( moving, static, reg_affine, sdr_morph=None, interpolation="linear", cval=0.0, verbose=None, ): """Apply volume registration. Uses registration parameters computed by :func:`~mne.transforms.compute_volume_registration`. Parameters ---------- %(moving)s %(static)s %(reg_affine)s %(sdr_morph)s interpolation : str Interpolation to be used during the interpolation. Can be ``"linear"`` (default) or ``"nearest"``. cval : float | str The constant value to assume exists outside the bounds of the ``moving`` image domain. Can be a string percentage like ``'1%%'`` to use the given percentile of image data as the constant value. %(verbose)s Returns ------- reg_img : instance of SpatialImage The image after affine (and SDR, if provided) registration. Notes ----- .. 
versionadded:: 0.24 """ _require_version("dipy", "SDR morph", "0.10.1") _import_nibabel("SDR morph") from dipy.align.imaffine import AffineMap from dipy.align.imwarp import DiffeomorphicMap from nibabel.spatialimages import SpatialImage _validate_type(moving, SpatialImage, "moving") _validate_type(static, SpatialImage, "static") _validate_type(reg_affine, np.ndarray, "reg_affine") _check_option("reg_affine.shape", reg_affine.shape, ((4, 4),)) _validate_type(sdr_morph, (DiffeomorphicMap, None), "sdr_morph") _validate_type(cval, ("numeric", str), "cval") perc = None if isinstance(cval, str): if not cval.endswith("%"): raise ValueError(f"cval must end with % if str, got {cval}") perc = float(cval[:-1]) logger.info("Applying affine registration ...") moving_affine = moving.affine moving = np.asarray(moving.dataobj, dtype=float) if perc is not None: cval = np.percentile(moving, perc) logger.info(f"Using a lower bound at the {perc} percentile: {cval}") moving -= cval static, static_affine = np.asarray(static.dataobj), static.affine affine_map = AffineMap( reg_affine, static.shape, static_affine, moving.shape, moving_affine ) reg_data = affine_map.transform(moving, interpolation=interpolation) if sdr_morph is not None: logger.info("Applying SDR warp ...") reg_data = sdr_morph.transform( reg_data, interpolation=interpolation, image_world2grid=np.linalg.inv(static_affine), out_shape=static.shape, out_grid2world=static_affine, ) reg_data += cval reg_img = SpatialImage(reg_data, static_affine) logger.info("[done]") return reg_img @verbose def apply_volume_registration_points( info, trans, moving, static, reg_affine, sdr_morph=None, verbose=None ): """Apply volume registration. Uses registration parameters computed by :func:`~mne.transforms.compute_volume_registration`. 
Parameters ---------- %(info_not_none)s %(trans_not_none)s %(moving)s %(static)s %(reg_affine)s %(sdr_morph)s %(verbose)s Returns ------- %(info_not_none)s trans2 : instance of Transform The head->mri (surface RAS) transform for the static image. Notes ----- .. versionadded:: 1.4.0 """ from .channels import compute_native_head_t, make_dig_montage _require_version("nibabel", "volume registration", "2.1.0") from dipy.align.imwarp import DiffeomorphicMap from nibabel import MGHImage from nibabel.spatialimages import SpatialImage _validate_type(moving, SpatialImage, "moving") _validate_type(static, SpatialImage, "static") _validate_type(reg_affine, np.ndarray, "reg_affine") _check_option("reg_affine.shape", reg_affine.shape, ((4, 4),)) _validate_type(sdr_morph, (DiffeomorphicMap, None), "sdr_morph") moving_mgh = MGHImage(np.array(moving.dataobj).astype(np.float32), moving.affine) static_mgh = MGHImage(np.array(static.dataobj).astype(np.float32), static.affine) montage = info.get_montage() montage_kwargs = montage.get_positions() trans = _ensure_trans(trans, "head", "mri") montage.apply_trans(trans) # to moving surface RAS locs = np.array(list(montage.get_positions()["ch_pos"].values())) locs = apply_trans( Transform( # to moving voxels fro="mri", to="mri_voxel", trans=np.linalg.inv(moving_mgh.header.get_vox2ras_tkr()), ), locs * 1000, ) locs = apply_trans( Transform( # to moving ras fro="mri_voxel", to="ras", trans=moving_mgh.header.get_vox2ras() ), locs, ) locs = apply_trans( Transform( # to static ras fro="ras", to="ras", trans=np.linalg.inv(reg_affine) ), locs, ) if sdr_morph is not None: _require_version("dipy", "SDR morph", "1.6.0") locs = sdr_morph.transform_points( locs, sdr_morph.domain_grid2world, sdr_morph.domain_world2grid ) locs = apply_trans( Transform( # to static voxels fro="ras", to="mri_voxel", trans=np.linalg.inv(static_mgh.header.get_vox2ras()), ), locs, ) locs = ( apply_trans( Transform( # to static surface RAS fro="mri_voxel", to="mri", 
trans=static_mgh.header.get_vox2ras_tkr() ), locs, ) / 1000 ) montage_kwargs["coord_frame"] = "mri" montage_kwargs["ch_pos"] = {ch: loc for ch, loc in zip(montage.ch_names, locs)} montage2 = make_dig_montage(**montage_kwargs) trans2 = compute_native_head_t(montage2) info2 = info.copy() info2.set_montage(montage2) # converts to head coordinates return info2, trans2 class _MatchedDisplacementFieldInterpolator: """Interpolate from matched points using a displacement field in ND. For a demo, see https://gist.github.com/larsoner/fbe32d57996848395854d5e59dff1e10 and related tests. """ def __init__(self, fro, to, *, extrema=None): from scipy.interpolate import LinearNDInterpolator fro = np.array(fro, float) to = np.array(to, float) assert fro.shape == to.shape assert fro.ndim == 2 # this restriction is only necessary because it's what # _fit_matched_points requires assert fro.shape[1] == 3 # Prealign using affine + uniform scaling self._quat, self._scale = _fit_matched_points(fro, to, scale=True) trans = _quat_to_affine(self._quat) trans[:3, :3] *= self._scale self._affine = trans fro = apply_trans(trans, fro) # Add points at extrema if extrema is None: delta = (to.max(axis=0) - to.min(axis=0)) / 2.0 assert (delta > 0).all() extrema = np.array([fro.min(axis=0) - delta, fro.max(axis=0) + delta]) assert extrema.shape == (2, 3) # min, max self._extrema = np.array(np.meshgrid(*extrema.T)).T.reshape(-1, fro.shape[-1]) fro_concat = np.concatenate((fro, self._extrema)) to_concat = np.concatenate((to, self._extrema)) # Compute the interpolator (which internally uses Delaunay) self._interp = LinearNDInterpolator(fro_concat, to_concat) def __call__(self, x): assert x.ndim in (1, 2) and x.shape[-1] == 3 assert np.isfinite(x).all() singleton = x.ndim == 1 x = apply_trans(self._affine, x) assert np.isfinite(x).all() out = self._interp(x) assert np.isfinite(out).all() self._last_deltas = np.linalg.norm(x - out, axis=1) out = out[0] if singleton else out return out