from typing import ( TYPE_CHECKING, Dict, Optional, ) import numpy as np from pandas.compat._optional import import_optional_dependency from pandas.core.util.numba_ import ( NUMBA_FUNC_CACHE, get_jit_arguments, ) def generate_online_numba_ewma_func(engine_kwargs: Optional[Dict[str, bool]]): """ Generate a numba jitted groupby ewma function specified by values from engine_kwargs. Parameters ---------- engine_kwargs : dict dictionary of arguments to be passed into numba.jit Returns ------- Numba function """ nopython, nogil, parallel = get_jit_arguments(engine_kwargs) cache_key = (lambda x: x, "online_ewma") if cache_key in NUMBA_FUNC_CACHE: return NUMBA_FUNC_CACHE[cache_key] if TYPE_CHECKING: import numba else: numba = import_optional_dependency("numba") @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def online_ewma( values: np.ndarray, deltas: np.ndarray, minimum_periods: int, old_wt_factor: float, new_wt: float, old_wt: np.ndarray, adjust: bool, ignore_na: bool, ): """ Compute online exponentially weighted mean per column over 2D values. Takes the first observation as is, then computes the subsequent exponentially weighted mean accounting minimum periods. """ result = np.empty(values.shape) weighted_avg = values[0] nobs = (~np.isnan(weighted_avg)).astype(np.int64) result[0] = np.where(nobs >= minimum_periods, weighted_avg, np.nan) for i in range(1, len(values)): cur = values[i] is_observations = ~np.isnan(cur) nobs += is_observations.astype(np.int64) for j in numba.prange(len(cur)): if not np.isnan(weighted_avg[j]): if is_observations[j] or not ignore_na: # note that len(deltas) = len(vals) - 1 and deltas[i] is to be # used in conjunction with vals[i+1] old_wt[j] *= old_wt_factor ** deltas[j - 1] if is_observations[j]: # avoid numerical errors on constant series if weighted_avg[j] != cur[j]: weighted_avg[j] = ( (old_wt[j] * weighted_avg[j]) + (new_wt * cur[j]) ) / (old_wt[j] + new_wt) if adjust: old_wt[j] += new_wt else: old_wt[j] = 1.0 elif is_observations[j]: weighted_avg[j] = cur[j] result[i] = np.where(nobs >= minimum_periods, weighted_avg, np.nan) return result, old_wt return online_ewma class EWMMeanState: def __init__(self, com, adjust, ignore_na, axis, shape): alpha = 1.0 / (1.0 + com) self.axis = axis self.shape = shape self.adjust = adjust self.ignore_na = ignore_na self.new_wt = 1.0 if adjust else alpha self.old_wt_factor = 1.0 - alpha self.old_wt = np.ones(self.shape[self.axis - 1]) self.last_ewm = None def run_ewm(self, weighted_avg, deltas, min_periods, ewm_func): result, old_wt = ewm_func( weighted_avg, deltas, min_periods, self.old_wt_factor, self.new_wt, self.old_wt, self.adjust, self.ignore_na, ) self.old_wt = old_wt self.last_ewm = result[-1] return result def reset(self): self.old_wt = np.ones(self.shape[self.axis - 1]) self.last_ewm = None