"""Common utilities for Numba operations with groupby ops"""
from __future__ import annotations

import inspect
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
)

import numpy as np

from pandas._typing import Scalar
from pandas.compat._optional import import_optional_dependency

from pandas.core.util.numba_ import (
    NUMBA_FUNC_CACHE,
    NumbaUtilError,
    get_jit_arguments,
    jit_user_function,
)


def validate_udf(func: Callable) -> None:
    """
    Validate user defined function for ops when using Numba with groupby ops.

    The first signature arguments should include:

    def f(values, index, ...):
        ...

    Parameters
    ----------
    func : function
        user defined function

    Returns
    -------
    None

    Raises
    ------
    NumbaUtilError
        If the leading parameters of ``func`` are not exactly
        ``values`` and ``index`` (in that order).
    """
    expected_args = ["values", "index"]
    min_number_args = len(expected_args)
    # Slicing a too-short signature yields a shorter list, so this single
    # comparison also rejects callables with fewer than the required arguments.
    udf_signature = list(inspect.signature(func).parameters)
    if udf_signature[:min_number_args] != expected_args:
        # Not every callable has ``__name__`` (e.g. functools.partial objects
        # or callable instances); fall back to repr so that the documented
        # NumbaUtilError is raised instead of an AttributeError.
        func_name = getattr(func, "__name__", repr(func))
        raise NumbaUtilError(
            f"The first {min_number_args} arguments to {func_name} must be "
            f"{expected_args}"
        )


def generate_numba_agg_func(
    kwargs: dict[str, Any],
    func: Callable[..., Scalar],
    engine_kwargs: dict[str, bool] | None,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
    """
    Generate a numba jitted agg function specified by values from engine_kwargs.

    1. jit the user's function
    2. Return a groupby agg function with the jitted function inline

    Configurations specified in engine_kwargs apply to both the user's
    function _AND_ the groupby evaluation loop.

    Parameters
    ----------
    kwargs : dict
        **kwargs to be passed into the function
    func : function
        function to be applied to each group and will be JITed
    engine_kwargs : dict
        dictionary of arguments to be passed into numba.jit

    Returns
    -------
    Numba function
        Called as ``group_agg(values, index, begin, end, num_columns, *args)``.
        It returns an array of shape ``(len(begin), num_columns)`` holding one
        value per group and column.

    Raises
    ------
    NumbaUtilError
        If ``func`` does not start with the ``values, index`` parameters.
    """
    nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs)

    validate_udf(func)
    # The cache is only read here; this function never stores the newly built
    # function in it.
    # NOTE(review): presumably the caller populates NUMBA_FUNC_CACHE - confirm.
    cache_key = (func, "groupby_agg")
    if cache_key in NUMBA_FUNC_CACHE:
        return NUMBA_FUNC_CACHE[cache_key]

    numba_func = jit_user_function(func, nopython, nogil, parallel)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def group_agg(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_columns: int,
        *args: Any,
    ) -> np.ndarray:

        assert len(begin) == len(end)
        num_groups = len(begin)

        # One output row per group, one output column per input column.
        result = np.empty((num_groups, num_columns))
        for i in numba.prange(num_groups):
            # Rows begin[i]:end[i] of ``values``/``index`` make up group i.
            group_index = index[begin[i] : end[i]]
            for j in numba.prange(num_columns):
                group = values[begin[i] : end[i], j]
                result[i, j] = numba_func(group, group_index, *args)
        return result

    return group_agg


def generate_numba_transform_func(
    kwargs: dict[str, Any],
    func: Callable[..., np.ndarray],
    engine_kwargs: dict[str, bool] | None,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
    """
    Generate a numba jitted transform function specified by values from engine_kwargs.

    1. jit the user's function
    2. Return a groupby transform function with the jitted function inline

    Configurations specified in engine_kwargs apply to both the user's
    function _AND_ the groupby evaluation loop.

    Parameters
    ----------
    kwargs : dict
        **kwargs to be passed into the function
    func : function
        function to be applied to each group and will be JITed
    engine_kwargs : dict
        dictionary of arguments to be passed into numba.jit

    Returns
    -------
    Numba function
        Called as ``group_transform(values, index, begin, end, num_columns,
        *args)``. It returns an array of shape ``(len(values), num_columns)``
        in which rows ``begin[i]:end[i]`` hold the transformed group ``i``.

    Raises
    ------
    NumbaUtilError
        If ``func`` does not start with the ``values, index`` parameters.
    """
    nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs)

    validate_udf(func)
    # The cache is only read here; this function never stores the newly built
    # function in it.
    # NOTE(review): presumably the caller populates NUMBA_FUNC_CACHE - confirm.
    cache_key = (func, "groupby_transform")
    if cache_key in NUMBA_FUNC_CACHE:
        return NUMBA_FUNC_CACHE[cache_key]

    numba_func = jit_user_function(func, nopython, nogil, parallel)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def group_transform(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_columns: int,
        *args: Any,
    ) -> np.ndarray:

        assert len(begin) == len(end)
        num_groups = len(begin)

        # Output has one row per input row; rows not covered by any
        # begin/end range are left as allocated by np.empty (uninitialised).
        result = np.empty((len(values), num_columns))
        for i in numba.prange(num_groups):
            # Rows begin[i]:end[i] of ``values``/``index`` make up group i.
            group_index = index[begin[i] : end[i]]
            for j in numba.prange(num_columns):
                group = values[begin[i] : end[i], j]
                # The UDF result is written back over the group's own rows.
                result[begin[i] : end[i], j] = numba_func(group, group_index, *args)
        return result

    return group_transform