123 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			123 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from typing import (
 | 
						|
    TYPE_CHECKING,
 | 
						|
    Dict,
 | 
						|
    Optional,
 | 
						|
)
 | 
						|
 | 
						|
import numpy as np
 | 
						|
 | 
						|
from pandas.compat._optional import import_optional_dependency
 | 
						|
 | 
						|
from pandas.core.util.numba_ import (
 | 
						|
    NUMBA_FUNC_CACHE,
 | 
						|
    get_jit_arguments,
 | 
						|
)
 | 
						|
 | 
						|
 | 
						|
def generate_online_numba_ewma_func(engine_kwargs: Optional[Dict[str, bool]]):
    """
    Generate a numba jitted online (streaming) ewma function specified by
    values from engine_kwargs.

    The compiled function is stored in ``NUMBA_FUNC_CACHE`` under a key that
    includes the resolved jit options, so later calls with the same options
    reuse it instead of re-jitting.

    Parameters
    ----------
    engine_kwargs : dict
        dictionary of arguments to be passed into numba.jit

    Returns
    -------
    Numba function
    """
    nopython, nogil, parallel = get_jit_arguments(engine_kwargs)

    # The key must be stable across calls (a freshly created ``lambda`` never
    # compares equal to a previous one, so the cache could never hit) and must
    # encode the jit options, otherwise a function compiled with one set of
    # options could be returned for a request with different options.
    cache_key = (
        generate_online_numba_ewma_func,
        f"online_ewma_{nopython}_{nogil}_{parallel}",
    )
    if cache_key in NUMBA_FUNC_CACHE:
        return NUMBA_FUNC_CACHE[cache_key]

    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def online_ewma(
        values: np.ndarray,
        deltas: np.ndarray,
        minimum_periods: int,
        old_wt_factor: float,
        new_wt: float,
        old_wt: np.ndarray,
        adjust: bool,
        ignore_na: bool,
    ):
        """
        Compute online exponentially weighted mean per column over 2D values.

        Takes the first observation as is, then computes the subsequent
        exponentially weighted mean accounting minimum periods.

        ``old_wt`` is updated in place and also returned, so the caller can
        carry the weights into the next batch. ``values`` is not modified.

        Returns
        -------
        result : np.ndarray
            Array shaped like ``values``. A position is set to NaN while its
            column has seen fewer than ``minimum_periods`` non-NaN
            observations.
        old_wt : np.ndarray
            The updated weights (the same array object that was passed in).
        """
        result = np.empty(values.shape)
        # ``values[0]`` is a view into the input; copy it because
        # ``weighted_avg`` is updated in place below and would otherwise
        # overwrite the caller's first row.
        weighted_avg = values[0].copy()
        # Running count of non-NaN observations per column.
        nobs = (~np.isnan(weighted_avg)).astype(np.int64)
        result[0] = np.where(nobs >= minimum_periods, weighted_avg, np.nan)

        for i in range(1, len(values)):
            cur = values[i]
            is_observations = ~np.isnan(cur)
            nobs += is_observations.astype(np.int64)
            for j in numba.prange(len(cur)):
                if not np.isnan(weighted_avg[j]):
                    if is_observations[j] or not ignore_na:

                        # note that len(deltas) = len(vals) - 1 and deltas[i] is to be
                        # used in conjunction with vals[i+1]
                        # NOTE(review): per the note above this would presumably be
                        # deltas[i - 1] (indexed by row), but it is indexed by the
                        # column ``j`` (and j == 0 reads deltas[-1]). Left as-is
                        # because the caller's sizing of ``deltas`` is not visible
                        # here and numba does not bounds-check by default -- confirm
                        # against the caller before changing.
                        old_wt[j] *= old_wt_factor ** deltas[j - 1]
                        if is_observations[j]:
                            # avoid numerical errors on constant series
                            if weighted_avg[j] != cur[j]:
                                weighted_avg[j] = (
                                    (old_wt[j] * weighted_avg[j]) + (new_wt * cur[j])
                                ) / (old_wt[j] + new_wt)
                            if adjust:
                                old_wt[j] += new_wt
                            else:
                                # With adjust=False the old weight is reset to 1
                                # after each incorporated observation.
                                old_wt[j] = 1.0
                elif is_observations[j]:
                    # First non-NaN value seen in this column seeds the average.
                    weighted_avg[j] = cur[j]

            result[i] = np.where(nobs >= minimum_periods, weighted_avg, np.nan)

        return result, old_wt

    NUMBA_FUNC_CACHE[cache_key] = online_ewma
    return online_ewma
 | 
						|
 | 
						|
 | 
						|
class EWMMeanState:
    """
    State shared across successive online exponentially weighted mean calls.

    Holds the decay parameters derived from ``com``, the running weights
    (``old_wt``), and the last output row (``last_ewm``). A later batch can
    therefore continue from where the previous one stopped.
    """

    def __init__(self, com, adjust, ignore_na, axis, shape):
        """
        Parameters
        ----------
        com : float
            Center of mass; the smoothing factor is ``1 / (1 + com)``.
        adjust : bool
            Selects the weight given to each new observation (1.0 when true,
            the smoothing factor otherwise); also forwarded to the ewm function.
        ignore_na : bool
            Forwarded unchanged to the ewm function.
        axis : int
            Together with ``shape``, used only to size ``old_wt``.
        shape : tuple of int
            Shape of the data being smoothed.
        """
        smoothing = 1.0 / (1.0 + com)
        self.axis = axis
        self.shape = shape
        self.adjust = adjust
        self.ignore_na = ignore_na
        if adjust:
            self.new_wt = 1.0
        else:
            self.new_wt = smoothing
        self.old_wt_factor = 1.0 - smoothing
        self.old_wt = self._fresh_weights()
        self.last_ewm = None

    def _fresh_weights(self):
        """Return an all-ones weight vector of length ``shape[axis - 1]``."""
        # For 2-D data, ``shape[axis - 1]`` is the length of the axis other
        # than ``axis``.
        return np.ones(self.shape[self.axis - 1])

    def run_ewm(self, weighted_avg, deltas, min_periods, ewm_func):
        """
        Apply ``ewm_func`` to a batch and remember the state it leaves behind.

        ``ewm_func`` is called positionally with the batch, ``deltas``,
        ``min_periods``, and then this object's decay factor, new-observation
        weight, current weights, ``adjust`` and ``ignore_na``. It must return
        a ``(result, weights)`` pair. The weights replace ``old_wt``, and the
        final row of ``result`` is kept as ``last_ewm``.

        Returns
        -------
        The ``result`` array produced by ``ewm_func``.
        """
        decay_params = (
            self.old_wt_factor,
            self.new_wt,
            self.old_wt,
            self.adjust,
            self.ignore_na,
        )
        output, self.old_wt = ewm_func(
            weighted_avg, deltas, min_periods, *decay_params
        )
        self.last_ewm = output[-1]
        return output

    def reset(self):
        """Restore the initial all-ones weights and forget the last output row."""
        self.old_wt = self._fresh_weights()
        self.last_ewm = None
 |