"""Helper functions for graphics with Matplotlib.""" from statsmodels.compat.python import lrange __all__ = ['create_mpl_ax', 'create_mpl_fig'] def _import_mpl(): """This function is not needed outside this utils module.""" try: import matplotlib.pyplot as plt except: raise ImportError("Matplotlib is not found.") return plt def create_mpl_ax(ax=None): """Helper function for when a single plot axis is needed. Parameters ---------- ax : AxesSubplot, optional If given, this subplot is used to plot in instead of a new figure being created. Returns ------- fig : Figure If `ax` is None, the created figure. Otherwise the figure to which `ax` is connected. ax : AxesSubplot The created axis if `ax` is None, otherwise the axis that was passed in. Notes ----- This function imports `matplotlib.pyplot`, which should only be done to create (a) figure(s) with ``plt.figure``. All other functionality exposed by the pyplot module can and should be imported directly from its Matplotlib module. See Also -------- create_mpl_fig Examples -------- A plotting function has a keyword ``ax=None``. Then calls: >>> from statsmodels.graphics import utils >>> fig, ax = utils.create_mpl_ax(ax) """ if ax is None: plt = _import_mpl() fig = plt.figure() ax = fig.add_subplot(111) else: fig = ax.figure return fig, ax def create_mpl_fig(fig=None, figsize=None): """Helper function for when multiple plot axes are needed. Those axes should be created in the functions they are used in, with ``fig.add_subplot()``. Parameters ---------- fig : Figure, optional If given, this figure is simply returned. Otherwise a new figure is created. Returns ------- Figure If `fig` is None, the created figure. Otherwise the input `fig` is returned. See Also -------- create_mpl_ax """ if fig is None: plt = _import_mpl() fig = plt.figure(figsize=figsize) return fig def maybe_name_or_idx(idx, model): """ Give a name or an integer and return the name and integer location of the column in a design matrix. """ if idx is None: idx = lrange(model.exog.shape[1]) if isinstance(idx, int): exog_name = model.exog_names[idx] exog_idx = idx # anticipate index as list and recurse elif isinstance(idx, (tuple, list)): exog_name = [] exog_idx = [] for item in idx: exog_name_item, exog_idx_item = maybe_name_or_idx(item, model) exog_name.append(exog_name_item) exog_idx.append(exog_idx_item) else: # assume we've got a string variable exog_name = idx exog_idx = model.exog_names.index(idx) return exog_name, exog_idx def get_data_names(series_or_dataframe): """ Input can be an array or pandas-like. Will handle 1d array-like but not 2d. Returns a str for 1d data or a list of strings for 2d data. """ names = getattr(series_or_dataframe, 'name', None) if not names: names = getattr(series_or_dataframe, 'columns', None) if not names: shape = getattr(series_or_dataframe, 'shape', [1]) nvars = 1 if len(shape) == 1 else series_or_dataframe.shape[1] names = ["X%d" for _ in range(nvars)] if nvars == 1: names = names[0] else: names = names.tolist() return names def annotate_axes(index, labels, points, offset_points, size, ax, **kwargs): """ Annotate Axes with labels, points, offset_points according to the given index. """ for i in index: label = labels[i] point = points[i] offset = offset_points[i] ax.annotate(label, point, xytext=offset, textcoords="offset points", size=size, **kwargs) return ax