"""
The `trendline_functions` module contains functions which are called by Plotly Express
when the `trendline` argument is used. Valid values for `trendline` are the names of the
functions in this module, and the value of the `trendline_options` argument to PX
functions is passed in as the first argument to these functions when called.

Every trendline function takes the same positional arguments:

- `trendline_options`: the options dict described in each function's docstring
  (`None` is treated as an empty dict).
- `x_raw`: the raw x values of the trace; used as the index of the series that the
  `rolling`, `ewm` and `expanding` functions aggregate over.
- `x`, `y`: the values to fit (`ols` and `lowess` pass them to `statsmodels`).
- `x_label`, `y_label`: axis labels, used in the OLS hover text.
- `non_missing`: boolean mask of the rows kept in the output of the `rolling`,
  `ewm` and `expanding` functions.

and returns a `(y_out, hover_header, fit_results)` tuple, where `fit_results` is
`None` for every function except `ols`.

Note that the functions in this module are not meant to be called directly, and are
exposed as part of the public API for documentation purposes.
"""

import pandas as pd
import numpy as np

__all__ = ["ols", "lowess", "rolling", "ewm", "expanding"]


def _validate_options(label, trendline_options, valid_options):
    """Check the keys of `trendline_options` and return it as a mapping.

    `None` is accepted and treated as an empty dict.

    Raises `ValueError` naming the first key that is not in `valid_options`;
    `label` (e.g. "OLS") prefixes the error message.
    """
    options = {} if trendline_options is None else trendline_options
    for k in options.keys():
        if k not in valid_options:
            raise ValueError(
                "%s trendline_options keys must be one of [%s] but got '%s'"
                % (label, ", ".join(valid_options), k)
            )
    return options


def ols(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
    """Ordinary Least Squares (OLS) trendline function

    Requires `statsmodels` to be installed.

    This trendline function causes fit results to be stored within the figure,
    accessible via the `plotly.express.get_trendline_results` function. The fit results
    are the output of the `statsmodels.api.OLS` function.

    Valid keys for the `trendline_options` dict are:

    - `add_constant` (`bool`, default `True`): if `False`, the trendline passes through
    the origin but if `True` a y-intercept is fitted.

    - `log_x` and `log_y` (`bool`, default `False`): if `True` the OLS is computed with
    respect to the base 10 logarithm of the input. Note that this means no zero or
    negative values can be present in the input.
    """
    trendline_options = _validate_options(
        "OLS", trendline_options, ["add_constant", "log_x", "log_y"]
    )

    import statsmodels.api as sm

    add_constant = trendline_options.get("add_constant", True)
    log_x = trendline_options.get("log_x", False)
    log_y = trendline_options.get("log_y", False)

    if log_y:
        if np.any(y <= 0):
            raise ValueError(
                "Can't do OLS trendline with `log_y=True` when `y` contains non-positive values."
            )
        y = np.log10(y)
        y_label = "log10(%s)" % y_label
    if log_x:
        if np.any(x <= 0):
            raise ValueError(
                "Can't do OLS trendline with `log_x=True` when `x` contains non-positive values."
            )
        x = np.log10(x)
        x_label = "log10(%s)" % x_label
    if add_constant:
        # With its default has_constant="skip", statsmodels returns `x` unchanged
        # (no intercept column) when `x` is already a non-zero constant.
        x = sm.add_constant(x)
    # missing="drop" lets statsmodels discard rows with NaNs itself.
    fit_results = sm.OLS(y, x, missing="drop").fit()
    y_out = fit_results.predict()
    # Fitted values on the scale the fit was done in (log10 scale when log_y).
    fitted = np.asarray(y_out)
    if log_y:
        y_out = np.power(10, y_out)

    # Read parameters positionally: with pandas inputs statsmodels returns a
    # label-indexed Series, on which `params[1]` is deprecated positional access.
    params = np.asarray(fit_results.params)
    hover_header = "<b>OLS trendline</b><br>"
    if len(params) == 2:
        hover_header += "%s = %g * %s + %g<br>" % (
            y_label,
            params[1],
            x_label,
            params[0],
        )
    elif not add_constant:
        hover_header += "%s = %g * %s<br>" % (y_label, params[0], x_label)
    else:
        # An intercept was requested but skipped because `x` is constant, so the
        # fit is the horizontal line y = params[0] * x. Report that fitted value;
        # params[0] on its own is only correct when x == 1.
        hover_header += "%s = %g<br>" % (y_label, fitted[0])
    hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
    return y_out, hover_header, fit_results


def lowess(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
    """LOcally WEighted Scatterplot Smoothing (LOWESS) trendline function

    Requires `statsmodels` to be installed.

    Valid keys for the `trendline_options` dict are:

    - `frac` (`float`, default `0.6666666`): the `frac` parameter from the
    `statsmodels.api.nonparametric.lowess` function
    """
    trendline_options = _validate_options("LOWESS", trendline_options, ["frac"])

    import statsmodels.api as sm

    frac = trendline_options.get("frac", 0.6666666)
    # Column 1 of the lowess output holds the smoothed y values.
    y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1]
    hover_header = "<b>LOWESS trendline</b><br><br>"
    return y_out, hover_header, None


def _pandas(mode, trendline_options, x_raw, y, non_missing):
    """Shared implementation of the `rolling`, `ewm` and `expanding` trendlines.

    Builds a pandas Series of `y` indexed by `x_raw`, then evaluates
    ``series.<mode>(**options).<function>(**function_args)``. Here `function`
    (default "mean") and `function_args` (default `{}`) are popped from a copy of
    `trendline_options`, and the remaining options go to ``series.<mode>``.
    The result is filtered by the `non_missing` mask.
    """
    modes = dict(rolling="Rolling", ewm="Exponentially Weighted", expanding="Expanding")
    # Copy so popping "function"/"function_args" never mutates the caller's dict.
    trendline_options = dict(trendline_options or {})
    function_name = trendline_options.pop("function", "mean")
    function_args = trendline_options.pop("function_args", dict())

    # np.asarray: if `y` were itself a pandas Series, pd.Series(y, index=...)
    # would reindex it by label (yielding NaNs) instead of relabelling it.
    series = pd.Series(np.asarray(y), index=x_raw)
    agg = getattr(series, mode)  # e.g. series.rolling
    agg_obj = agg(**trendline_options)  # e.g. series.rolling(**opts)
    function = getattr(agg_obj, function_name)  # e.g. series.rolling(**opts).mean
    y_out = function(**function_args)  # e.g. ...mean(**function_args)
    # Apply the mask positionally, independent of the series' x-valued index.
    y_out = y_out[np.asarray(non_missing)]
    hover_header = "<b>%s %s trendline</b><br><br>" % (modes[mode], function_name)
    return y_out, hover_header, None


def rolling(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
    """Rolling trendline function

    The value of the `function` key of the `trendline_options` dict is the function to
    use (defaults to `mean`) and the value of the `function_args` key are taken to be
    its arguments as a dict. The remainder of the `trendline_options` dict is passed as
    keyword arguments into the `pandas.Series.rolling` function.
    """
    return _pandas("rolling", trendline_options, x_raw, y, non_missing)


def expanding(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
    """Expanding trendline function

    The value of the `function` key of the `trendline_options` dict is the function to
    use (defaults to `mean`) and the value of the `function_args` key are taken to be
    its arguments as a dict. The remainder of the `trendline_options` dict is passed as
    keyword arguments into the `pandas.Series.expanding` function.
    """
    return _pandas("expanding", trendline_options, x_raw, y, non_missing)


def ewm(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
    """Exponentially Weighted Moment (EWM) trendline function

    The value of the `function` key of the `trendline_options` dict is the function to
    use (defaults to `mean`) and the value of the `function_args` key are taken to be
    its arguments as a dict. The remainder of the `trendline_options` dict is passed as
    keyword arguments into the `pandas.Series.ewm` function.
    """
    return _pandas("ewm", trendline_options, x_raw, y, non_missing)