208 lines
		
	
	
		
			6.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			208 lines
		
	
	
		
			6.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
#
 | 
						|
# Authors: The MNE-Python contributors.
 | 
						|
# License: BSD-3-Clause
 | 
						|
# Copyright the MNE-Python contributors.
 | 
						|
 | 
						|
from datetime import datetime
 | 
						|
from glob import glob
 | 
						|
from os.path import basename, join, splitext
 | 
						|
 | 
						|
import numpy as np
 | 
						|
 | 
						|
from ...utils import _soft_import, _validate_type, logger, warn
 | 
						|
 | 
						|
 | 
						|
def _read_events(input_fname, info):
    """Build the per-code event matrix for the record.

    Parameters
    ----------
    input_fname : path-like
        The file path, forwarded to ``_read_mff_events``.
    info : dict
        Header info. The keys ``last_samps``, ``sfreq`` and ``n_segments``
        are read; ``n_events`` and ``event_codes`` are written in place.

    Returns
    -------
    events : ndarray, shape (n_events, n_segments * last_samps[-1])
        Row ``n`` contains ``n + 1`` at each in-range sample of the n-th
        event code and zero everywhere else.
    info : dict
        The same ``info`` object that was passed in, after the update.
    mff_events : dict
        Mapping from event code to the list of its sample indices, as
        returned by ``_read_mff_events``.
    """
    samples_per_segment = info["last_samps"][-1]
    mff_events, event_codes = _read_mff_events(input_fname, info["sfreq"])
    info["n_events"] = len(event_codes)
    info["event_codes"] = event_codes

    n_columns = info["n_segments"] * samples_per_segment
    events = np.zeros([info["n_events"], n_columns])
    for row, code in enumerate(event_codes):
        # Samples outside the recording are silently ignored.
        in_range = [s for s in mff_events[code] if 0 <= s < n_columns]
        for sample in in_range:
            events[row][sample] = row + 1
    return events, info, mff_events
 | 
						|
 | 
						|
 | 
						|
def _read_mff_events(filename, sfreq):
    """Extract the events.

    Every ``*.xml`` file directly inside ``filename`` is parsed with
    ``_parse_xml``. The recording start is taken from the first entry of
    ``info.xml`` that contains ``recordTime``; event onsets come from the
    ``beginTime`` of the entries of each ``Events_*.xml`` file (the first two
    entries of every such file are skipped).

    Parameters
    ----------
    filename : path-like
        File path (the directory holding the XML files).
    sfreq : float
        The sampling frequency

    Returns
    -------
    events_tims : dict
        Maps each event code to the list of its onset samples, i.e. the
        onset in seconds relative to the recording start multiplied by
        ``sfreq`` and truncated toward zero.
    code : list
        The event codes in order of first appearance.

    Raises
    ------
    KeyError
        If there is no ``info.xml`` in ``filename``.
    RuntimeError
        If events are present but ``info.xml`` has no ``recordTime`` entry.
    """
    orig = {}
    for xml_file in glob(join(filename, "*.xml")):
        xml_type = splitext(basename(xml_file))[0]
        orig[xml_type] = _parse_xml(xml_file)
    xml_events = [name for name in orig if name.startswith("Events_")]

    start_time = None
    for item in orig["info"]:
        if "recordTime" in item:
            start_time = _ns2py_time(item["recordTime"])
            break

    # Group onset samples by code in a single pass; dicts keep insertion
    # order, so the keys are the codes in order of first appearance.
    events_tims = {}
    for xml in xml_events:
        for event in orig[xml][2:]:
            if start_time is None:
                # Only an error once an onset actually has to be computed.
                raise RuntimeError(
                    "Could not find a 'recordTime' entry in info.xml; "
                    "cannot compute event onsets."
                )
            event_start = _ns2py_time(event["beginTime"])
            start = (event_start - start_time).total_seconds()
            start_sample = int(np.fix(start * sfreq))
            events_tims.setdefault(event["code"], []).append(start_sample)
    return events_tims, list(events_tims)
 | 
						|
 | 
						|
 | 
						|
def _parse_xml(xml_file):
    """Parse an XML file and convert its root element with ``_xml2list``.

    ``defusedxml`` is obtained through ``_soft_import`` and its
    ``ElementTree.parse`` is used to read ``xml_file``.
    """
    etree = _soft_import("defusedxml", "reading EGI MFF data").ElementTree
    return _xml2list(etree.parse(xml_file).getroot())
 | 
						|
 | 
						|
 | 
						|
def _xml2list(root):
 | 
						|
    """Parse XML item."""
 | 
						|
    output = []
 | 
						|
    for element in root:
 | 
						|
        if len(element) > 0:
 | 
						|
            if element[0].tag != element[-1].tag:
 | 
						|
                output.append(_xml2dict(element))
 | 
						|
            else:
 | 
						|
                output.append(_xml2list(element))
 | 
						|
 | 
						|
        elif element.text:
 | 
						|
            text = element.text.strip()
 | 
						|
            if text:
 | 
						|
                tag = _ns(element.tag)
 | 
						|
                output.append({tag: text})
 | 
						|
 | 
						|
    return output
 | 
						|
 | 
						|
 | 
						|
def _ns(s):
 | 
						|
    """Remove namespace, but only if there is a namespace to begin with."""
 | 
						|
    if "}" in s:
 | 
						|
        return "}".join(s.split("}")[1:])
 | 
						|
    else:
 | 
						|
        return s
 | 
						|
 | 
						|
 | 
						|
def _xml2dict(root):
 | 
						|
    """Use functions instead of Class.
 | 
						|
 | 
						|
    remove namespace based on
 | 
						|
    http://stackoverflow.com/questions/2148119
 | 
						|
    """
 | 
						|
    output = {}
 | 
						|
    if root.items():
 | 
						|
        output.update(dict(root.items()))
 | 
						|
 | 
						|
    for element in root:
 | 
						|
        if len(element) > 0:
 | 
						|
            if len(element) == 1 or element[0].tag != element[1].tag:
 | 
						|
                one_dict = _xml2dict(element)
 | 
						|
            else:
 | 
						|
                one_dict = {_ns(element[0].tag): _xml2list(element)}
 | 
						|
 | 
						|
            if element.items():
 | 
						|
                one_dict.update(dict(element.items()))
 | 
						|
            output.update({_ns(element.tag): one_dict})
 | 
						|
 | 
						|
        elif element.items():
 | 
						|
            output.update({_ns(element.tag): dict(element.items())})
 | 
						|
 | 
						|
        else:
 | 
						|
            output.update({_ns(element.tag): element.text})
 | 
						|
    return output
 | 
						|
 | 
						|
 | 
						|
def _ns2py_time(nstime):
 | 
						|
    """Parse times."""
 | 
						|
    nsdate = nstime[0:10]
 | 
						|
    nstime0 = nstime[11:26]
 | 
						|
    nstime00 = nsdate + " " + nstime0
 | 
						|
    pytime = datetime.strptime(nstime00, "%Y-%m-%d %H:%M:%S.%f")
 | 
						|
    return pytime
 | 
						|
 | 
						|
 | 
						|
def _combine_triggers(data, remapping=None):
 | 
						|
    """Combine binary triggers."""
 | 
						|
    new_trigger = np.zeros(data.shape[1])
 | 
						|
    if data.astype(bool).sum(axis=0).max() > 1:  # ensure no overlaps
 | 
						|
        logger.info(
 | 
						|
            "    Found multiple events at the same time "
 | 
						|
            "sample. Cannot create trigger channel."
 | 
						|
        )
 | 
						|
        return
 | 
						|
    if remapping is None:
 | 
						|
        remapping = np.arange(data) + 1
 | 
						|
    for d, event_id in zip(data, remapping):
 | 
						|
        idx = d.nonzero()
 | 
						|
        if np.any(idx):
 | 
						|
            new_trigger[idx] += event_id
 | 
						|
    return new_trigger
 | 
						|
 | 
						|
 | 
						|
def _triage_include_exclude(include, exclude, egi_events, egi_info):
    """Triage include and exclude.

    Parameters
    ----------
    include : list of str | None
        Event codes to keep. If given, it is returned as is and ``exclude``
        is only validated.
    exclude : list of str | None
        Event codes to drop when ``include`` is None. If both are None, the
        codes ``"sync"`` and ``"TREV"`` (when present) and every non-empty
        code whose row of ``egi_events`` sums to less than 1 are excluded.
    egi_events : iterable of ndarray
        One entry per event code, in the order of
        ``egi_info["event_codes"]``.
    egi_info : dict
        Must provide ``event_codes``; ``n_events`` is read when the default
        exclusion is built.

    Returns
    -------
    include : list of str
        The event codes to keep.

    Raises
    ------
    ValueError
        If an entry of ``include`` or ``exclude`` is not a known event code.

    Notes
    -----
    Type checks are delegated to ``_validate_type``. A ``RuntimeWarning`` is
    emitted when the default exclusion removes every event code and at least
    one of them is not a default-excluded code.
    """
    _validate_type(exclude, (list, None), "exclude")
    _validate_type(include, (list, None), "include")
    event_codes = list(egi_info["event_codes"])
    for name, lst in dict(exclude=exclude, include=include).items():
        for ii, item in enumerate(lst or []):
            what = f"{name}[{ii}]"
            _validate_type(item, str, what)
            if item not in event_codes:
                raise ValueError(
                    f"Could not find event channel named {what}={repr(item)}"
                )
    if include is None:
        if exclude is None:
            default_exclude = ["sync", "TREV"]
            exclude = [code for code in default_exclude if code in event_codes]
            for code, event in zip(event_codes, egi_events):
                # Skip codes that are already excluded so that a default
                # code without events is not listed twice.
                if event.sum() < 1 and code and code not in exclude:
                    exclude.append(code)
            # Compare as sets: "everything was excluded" must not depend on
            # how many times a code happens to be listed.
            if (
                set(exclude) == set(event_codes)
                and egi_info["n_events"]
                and set(exclude) - set(default_exclude)
            ):
                warn(
                    "Did not find any event code with at least one event.",
                    RuntimeWarning,
                )
        include = [k for k in event_codes if k not in exclude]
    del exclude
    excl_events = ", ".join(k for k in event_codes if k not in include)
    logger.info(f"    Excluding events {{{excl_events}}} ...")
    return include
 |