"""Plotting functions for linear models (broadly construed).""" import copy from textwrap import dedent import warnings import numpy as np import pandas as pd from scipy.spatial import distance import matplotlib as mpl import matplotlib.pyplot as plt try: import statsmodels assert statsmodels _has_statsmodels = True except ImportError: _has_statsmodels = False from . import utils from . import algorithms as algo from .axisgrid import FacetGrid, _facet_docs from ._decorators import _deprecate_positional_args __all__ = ["lmplot", "regplot", "residplot"] class _LinearPlotter(object): """Base class for plotting relational data in tidy format. To get anything useful done you'll have to inherit from this, but setup code that can be abstracted out should be put here. """ def establish_variables(self, data, **kws): """Extract variables from data or use directly.""" self.data = data # Validate the inputs any_strings = any([isinstance(v, str) for v in kws.values()]) if any_strings and data is None: raise ValueError("Must pass `data` if using named variables.") # Set the variables for var, val in kws.items(): if isinstance(val, str): vector = data[val] elif isinstance(val, list): vector = np.asarray(val) else: vector = val if vector is not None and vector.shape != (1,): vector = np.squeeze(vector) if np.ndim(vector) > 1: err = "regplot inputs must be 1d" raise ValueError(err) setattr(self, var, vector) def dropna(self, *vars): """Remove observations with missing data.""" vals = [getattr(self, var) for var in vars] vals = [v for v in vals if v is not None] not_na = np.all(np.column_stack([pd.notnull(v) for v in vals]), axis=1) for var in vars: val = getattr(self, var) if val is not None: setattr(self, var, val[not_na]) def plot(self, ax): raise NotImplementedError class _RegressionPlotter(_LinearPlotter): """Plotter for numeric independent variables with regression model. This does the computations and drawing for the `regplot` function, and is thus also used indirectly by `lmplot`. """ def __init__(self, x, y, data=None, x_estimator=None, x_bins=None, x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000, units=None, seed=None, order=1, logistic=False, lowess=False, robust=False, logx=False, x_partial=None, y_partial=None, truncate=False, dropna=True, x_jitter=None, y_jitter=None, color=None, label=None): # Set member attributes self.x_estimator = x_estimator self.ci = ci self.x_ci = ci if x_ci == "ci" else x_ci self.n_boot = n_boot self.seed = seed self.scatter = scatter self.fit_reg = fit_reg self.order = order self.logistic = logistic self.lowess = lowess self.robust = robust self.logx = logx self.truncate = truncate self.x_jitter = x_jitter self.y_jitter = y_jitter self.color = color self.label = label # Validate the regression options: if sum((order > 1, logistic, robust, lowess, logx)) > 1: raise ValueError("Mutually exclusive regression options.") # Extract the data vals from the arguments or passed dataframe self.establish_variables(data, x=x, y=y, units=units, x_partial=x_partial, y_partial=y_partial) # Drop null observations if dropna: self.dropna("x", "y", "units", "x_partial", "y_partial") # Regress nuisance variables out of the data if self.x_partial is not None: self.x = self.regress_out(self.x, self.x_partial) if self.y_partial is not None: self.y = self.regress_out(self.y, self.y_partial) # Possibly bin the predictor variable, which implies a point estimate if x_bins is not None: self.x_estimator = np.mean if x_estimator is None else x_estimator x_discrete, x_bins = self.bin_predictor(x_bins) self.x_discrete = x_discrete else: self.x_discrete = self.x # Disable regression in case of singleton inputs if len(self.x) <= 1: self.fit_reg = False # Save the range of the x variable for the grid later if self.fit_reg: self.x_range = self.x.min(), self.x.max() @property def scatter_data(self): """Data where each observation is a point.""" x_j = self.x_jitter if x_j is None: x = self.x else: x = self.x + np.random.uniform(-x_j, x_j, len(self.x)) y_j = self.y_jitter if y_j is None: y = self.y else: y = self.y + np.random.uniform(-y_j, y_j, len(self.y)) return x, y @property def estimate_data(self): """Data with a point estimate and CI for each discrete x value.""" x, y = self.x_discrete, self.y vals = sorted(np.unique(x)) points, cis = [], [] for val in vals: # Get the point estimate of the y variable _y = y[x == val] est = self.x_estimator(_y) points.append(est) # Compute the confidence interval for this estimate if self.x_ci is None: cis.append(None) else: units = None if self.x_ci == "sd": sd = np.std(_y) _ci = est - sd, est + sd else: if self.units is not None: units = self.units[x == val] boots = algo.bootstrap(_y, func=self.x_estimator, n_boot=self.n_boot, units=units, seed=self.seed) _ci = utils.ci(boots, self.x_ci) cis.append(_ci) return vals, points, cis def fit_regression(self, ax=None, x_range=None, grid=None): """Fit the regression model.""" # Create the grid for the regression if grid is None: if self.truncate: x_min, x_max = self.x_range else: if ax is None: x_min, x_max = x_range else: x_min, x_max = ax.get_xlim() grid = np.linspace(x_min, x_max, 100) ci = self.ci # Fit the regression if self.order > 1: yhat, yhat_boots = self.fit_poly(grid, self.order) elif self.logistic: from statsmodels.genmod.generalized_linear_model import GLM from statsmodels.genmod.families import Binomial yhat, yhat_boots = self.fit_statsmodels(grid, GLM, family=Binomial()) elif self.lowess: ci = None grid, yhat = self.fit_lowess() elif self.robust: from statsmodels.robust.robust_linear_model import RLM yhat, yhat_boots = self.fit_statsmodels(grid, RLM) elif self.logx: yhat, yhat_boots = self.fit_logx(grid) else: yhat, yhat_boots = self.fit_fast(grid) # Compute the confidence interval at each grid point if ci is None: err_bands = None else: err_bands = utils.ci(yhat_boots, ci, axis=0) return grid, yhat, err_bands def fit_fast(self, grid): """Low-level regression and prediction using linear algebra.""" def reg_func(_x, _y): return np.linalg.pinv(_x).dot(_y) X, y = np.c_[np.ones(len(self.x)), self.x], self.y grid = np.c_[np.ones(len(grid)), grid] yhat = grid.dot(reg_func(X, y)) if self.ci is None: return yhat, None beta_boots = algo.bootstrap(X, y, func=reg_func, n_boot=self.n_boot, units=self.units, seed=self.seed).T yhat_boots = grid.dot(beta_boots).T return yhat, yhat_boots def fit_poly(self, grid, order): """Regression using numpy polyfit for higher-order trends.""" def reg_func(_x, _y): return np.polyval(np.polyfit(_x, _y, order), grid) x, y = self.x, self.y yhat = reg_func(x, y) if self.ci is None: return yhat, None yhat_boots = algo.bootstrap(x, y, func=reg_func, n_boot=self.n_boot, units=self.units, seed=self.seed) return yhat, yhat_boots def fit_statsmodels(self, grid, model, **kwargs): """More general regression function using statsmodels objects.""" import statsmodels.genmod.generalized_linear_model as glm X, y = np.c_[np.ones(len(self.x)), self.x], self.y grid = np.c_[np.ones(len(grid)), grid] def reg_func(_x, _y): try: yhat = model(_y, _x, **kwargs).fit().predict(grid) except glm.PerfectSeparationError: yhat = np.empty(len(grid)) yhat.fill(np.nan) return yhat yhat = reg_func(X, y) if self.ci is None: return yhat, None yhat_boots = algo.bootstrap(X, y, func=reg_func, n_boot=self.n_boot, units=self.units, seed=self.seed) return yhat, yhat_boots def fit_lowess(self): """Fit a locally-weighted regression, which returns its own grid.""" from statsmodels.nonparametric.smoothers_lowess import lowess grid, yhat = lowess(self.y, self.x).T return grid, yhat def fit_logx(self, grid): """Fit the model in log-space.""" X, y = np.c_[np.ones(len(self.x)), self.x], self.y grid = np.c_[np.ones(len(grid)), np.log(grid)] def reg_func(_x, _y): _x = np.c_[_x[:, 0], np.log(_x[:, 1])] return np.linalg.pinv(_x).dot(_y) yhat = grid.dot(reg_func(X, y)) if self.ci is None: return yhat, None beta_boots = algo.bootstrap(X, y, func=reg_func, n_boot=self.n_boot, units=self.units, seed=self.seed).T yhat_boots = grid.dot(beta_boots).T return yhat, yhat_boots def bin_predictor(self, bins): """Discretize a predictor by assigning value to closest bin.""" x = self.x if np.isscalar(bins): percentiles = np.linspace(0, 100, bins + 2)[1:-1] bins = np.c_[np.percentile(x, percentiles)] else: bins = np.c_[np.ravel(bins)] dist = distance.cdist(np.c_[x], bins) x_binned = bins[np.argmin(dist, axis=1)].ravel() return x_binned, bins.ravel() def regress_out(self, a, b): """Regress b from a keeping a's original mean.""" a_mean = a.mean() a = a - a_mean b = b - b.mean() b = np.c_[b] a_prime = a - b.dot(np.linalg.pinv(b).dot(a)) return np.asarray(a_prime + a_mean).reshape(a.shape) def plot(self, ax, scatter_kws, line_kws): """Draw the full plot.""" # Insert the plot label into the correct set of keyword arguments if self.scatter: scatter_kws["label"] = self.label else: line_kws["label"] = self.label # Use the current color cycle state as a default if self.color is None: lines, = ax.plot([], []) color = lines.get_color() lines.remove() else: color = self.color # Ensure that color is hex to avoid matplotlib weirdness color = mpl.colors.rgb2hex(mpl.colors.colorConverter.to_rgb(color)) # Let color in keyword arguments override overall plot color scatter_kws.setdefault("color", color) line_kws.setdefault("color", color) # Draw the constituent plots if self.scatter: self.scatterplot(ax, scatter_kws) if self.fit_reg: self.lineplot(ax, line_kws) # Label the axes if hasattr(self.x, "name"): ax.set_xlabel(self.x.name) if hasattr(self.y, "name"): ax.set_ylabel(self.y.name) def scatterplot(self, ax, kws): """Draw the data.""" # Treat the line-based markers specially, explicitly setting larger # linewidth than is provided by the seaborn style defaults. # This would ideally be handled better in matplotlib (i.e., distinguish # between edgewidth for solid glyphs and linewidth for line glyphs # but this should do for now. line_markers = ["1", "2", "3", "4", "+", "x", "|", "_"] if self.x_estimator is None: if "marker" in kws and kws["marker"] in line_markers: lw = mpl.rcParams["lines.linewidth"] else: lw = mpl.rcParams["lines.markeredgewidth"] kws.setdefault("linewidths", lw) if not hasattr(kws['color'], 'shape') or kws['color'].shape[1] < 4: kws.setdefault("alpha", .8) x, y = self.scatter_data ax.scatter(x, y, **kws) else: # TODO abstraction ci_kws = {"color": kws["color"]} ci_kws["linewidth"] = mpl.rcParams["lines.linewidth"] * 1.75 kws.setdefault("s", 50) xs, ys, cis = self.estimate_data if [ci for ci in cis if ci is not None]: for x, ci in zip(xs, cis): ax.plot([x, x], ci, **ci_kws) ax.scatter(xs, ys, **kws) def lineplot(self, ax, kws): """Draw the model.""" # Fit the regression model grid, yhat, err_bands = self.fit_regression(ax) edges = grid[0], grid[-1] # Get set default aesthetics fill_color = kws["color"] lw = kws.pop("lw", mpl.rcParams["lines.linewidth"] * 1.5) kws.setdefault("linewidth", lw) # Draw the regression line and confidence interval line, = ax.plot(grid, yhat, **kws) if not self.truncate: line.sticky_edges.x[:] = edges # Prevent mpl from adding margin if err_bands is not None: ax.fill_between(grid, *err_bands, facecolor=fill_color, alpha=.15) _regression_docs = dict( model_api=dedent("""\ There are a number of mutually exclusive options for estimating the regression model. See the :ref:`tutorial ` for more information.\ """), regplot_vs_lmplot=dedent("""\ The :func:`regplot` and :func:`lmplot` functions are closely related, but the former is an axes-level function while the latter is a figure-level function that combines :func:`regplot` and :class:`FacetGrid`.\ """), x_estimator=dedent("""\ x_estimator : callable that maps vector -> scalar, optional Apply this function to each unique value of ``x`` and plot the resulting estimate. This is useful when ``x`` is a discrete variable. If ``x_ci`` is given, this estimate will be bootstrapped and a confidence interval will be drawn.\ """), x_bins=dedent("""\ x_bins : int or vector, optional Bin the ``x`` variable into discrete bins and then estimate the central tendency and a confidence interval. This binning only influences how the scatterplot is drawn; the regression is still fit to the original data. This parameter is interpreted either as the number of evenly-sized (not necessary spaced) bins or the positions of the bin centers. When this parameter is used, it implies that the default of ``x_estimator`` is ``numpy.mean``.\ """), x_ci=dedent("""\ x_ci : "ci", "sd", int in [0, 100] or None, optional Size of the confidence interval used when plotting a central tendency for discrete values of ``x``. If ``"ci"``, defer to the value of the ``ci`` parameter. If ``"sd"``, skip bootstrapping and show the standard deviation of the observations in each bin.\ """), scatter=dedent("""\ scatter : bool, optional If ``True``, draw a scatterplot with the underlying observations (or the ``x_estimator`` values).\ """), fit_reg=dedent("""\ fit_reg : bool, optional If ``True``, estimate and plot a regression model relating the ``x`` and ``y`` variables.\ """), ci=dedent("""\ ci : int in [0, 100] or None, optional Size of the confidence interval for the regression estimate. This will be drawn using translucent bands around the regression line. The confidence interval is estimated using a bootstrap; for large datasets, it may be advisable to avoid that computation by setting this parameter to None.\ """), n_boot=dedent("""\ n_boot : int, optional Number of bootstrap resamples used to estimate the ``ci``. The default value attempts to balance time and stability; you may want to increase this value for "final" versions of plots.\ """), units=dedent("""\ units : variable name in ``data``, optional If the ``x`` and ``y`` observations are nested within sampling units, those can be specified here. This will be taken into account when computing the confidence intervals by performing a multilevel bootstrap that resamples both units and observations (within unit). This does not otherwise influence how the regression is estimated or drawn.\ """), seed=dedent("""\ seed : int, numpy.random.Generator, or numpy.random.RandomState, optional Seed or random number generator for reproducible bootstrapping.\ """), order=dedent("""\ order : int, optional If ``order`` is greater than 1, use ``numpy.polyfit`` to estimate a polynomial regression.\ """), logistic=dedent("""\ logistic : bool, optional If ``True``, assume that ``y`` is a binary variable and use ``statsmodels`` to estimate a logistic regression model. Note that this is substantially more computationally intensive than linear regression, so you may wish to decrease the number of bootstrap resamples (``n_boot``) or set ``ci`` to None.\ """), lowess=dedent("""\ lowess : bool, optional If ``True``, use ``statsmodels`` to estimate a nonparametric lowess model (locally weighted linear regression). Note that confidence intervals cannot currently be drawn for this kind of model.\ """), robust=dedent("""\ robust : bool, optional If ``True``, use ``statsmodels`` to estimate a robust regression. This will de-weight outliers. Note that this is substantially more computationally intensive than standard linear regression, so you may wish to decrease the number of bootstrap resamples (``n_boot``) or set ``ci`` to None.\ """), logx=dedent("""\ logx : bool, optional If ``True``, estimate a linear regression of the form y ~ log(x), but plot the scatterplot and regression model in the input space. Note that ``x`` must be positive for this to work.\ """), xy_partial=dedent("""\ {x,y}_partial : strings in ``data`` or matrices Confounding variables to regress out of the ``x`` or ``y`` variables before plotting.\ """), truncate=dedent("""\ truncate : bool, optional If ``True``, the regression line is bounded by the data limits. If ``False``, it extends to the ``x`` axis limits. """), xy_jitter=dedent("""\ {x,y}_jitter : floats, optional Add uniform random noise of this size to either the ``x`` or ``y`` variables. The noise is added to a copy of the data after fitting the regression, and only influences the look of the scatterplot. This can be helpful when plotting variables that take discrete values.\ """), scatter_line_kws=dedent("""\ {scatter,line}_kws : dictionaries Additional keyword arguments to pass to ``plt.scatter`` and ``plt.plot``.\ """), ) _regression_docs.update(_facet_docs) @_deprecate_positional_args def lmplot( *, x=None, y=None, data=None, hue=None, col=None, row=None, # TODO move before data once * is enforced palette=None, col_wrap=None, height=5, aspect=1, markers="o", sharex=None, sharey=None, hue_order=None, col_order=None, row_order=None, legend=True, legend_out=None, x_estimator=None, x_bins=None, x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000, units=None, seed=None, order=1, logistic=False, lowess=False, robust=False, logx=False, x_partial=None, y_partial=None, truncate=True, x_jitter=None, y_jitter=None, scatter_kws=None, line_kws=None, facet_kws=None, size=None, ): # Handle deprecations if size is not None: height = size msg = ("The `size` parameter has been renamed to `height`; " "please update your code.") warnings.warn(msg, UserWarning) if facet_kws is None: facet_kws = {} def facet_kw_deprecation(key, val): msg = ( f"{key} is deprecated from the `lmplot` function signature. " "Please update your code to pass it using `facet_kws`." ) if val is not None: warnings.warn(msg, UserWarning) facet_kws[key] = val facet_kw_deprecation("sharex", sharex) facet_kw_deprecation("sharey", sharey) facet_kw_deprecation("legend_out", legend_out) if data is None: raise TypeError("Missing required keyword argument `data`.") # Reduce the dataframe to only needed columns need_cols = [x, y, hue, col, row, units, x_partial, y_partial] cols = np.unique([a for a in need_cols if a is not None]).tolist() data = data[cols] # Initialize the grid facets = FacetGrid( data, row=row, col=col, hue=hue, palette=palette, row_order=row_order, col_order=col_order, hue_order=hue_order, height=height, aspect=aspect, col_wrap=col_wrap, **facet_kws, ) # Add the markers here as FacetGrid has figured out how many levels of the # hue variable are needed and we don't want to duplicate that process if facets.hue_names is None: n_markers = 1 else: n_markers = len(facets.hue_names) if not isinstance(markers, list): markers = [markers] * n_markers if len(markers) != n_markers: raise ValueError(("markers must be a singeton or a list of markers " "for each level of the hue variable")) facets.hue_kws = {"marker": markers} def update_datalim(data, x, y, ax, **kws): xys = np.asarray(data[[x, y]]).astype(float) ax.update_datalim(xys, updatey=False) ax.autoscale_view(scaley=False) facets.map_dataframe(update_datalim, x=x, y=y) # Draw the regression plot on each facet regplot_kws = dict( x_estimator=x_estimator, x_bins=x_bins, x_ci=x_ci, scatter=scatter, fit_reg=fit_reg, ci=ci, n_boot=n_boot, units=units, seed=seed, order=order, logistic=logistic, lowess=lowess, robust=robust, logx=logx, x_partial=x_partial, y_partial=y_partial, truncate=truncate, x_jitter=x_jitter, y_jitter=y_jitter, scatter_kws=scatter_kws, line_kws=line_kws, ) facets.map_dataframe(regplot, x=x, y=y, **regplot_kws) facets.set_axis_labels(x, y) # Add a legend if legend and (hue is not None) and (hue not in [col, row]): facets.add_legend() return facets lmplot.__doc__ = dedent("""\ Plot data and regression model fits across a FacetGrid. This function combines :func:`regplot` and :class:`FacetGrid`. It is intended as a convenient interface to fit regression models across conditional subsets of a dataset. When thinking about how to assign variables to different facets, a general rule is that it makes sense to use ``hue`` for the most important comparison, followed by ``col`` and ``row``. However, always think about your particular dataset and the goals of the visualization you are creating. {model_api} The parameters to this function span most of the options in :class:`FacetGrid`, although there may be occasional cases where you will want to use that class and :func:`regplot` directly. Parameters ---------- x, y : strings, optional Input variables; these should be column names in ``data``. {data} hue, col, row : strings Variables that define subsets of the data, which will be drawn on separate facets in the grid. See the ``*_order`` parameters to control the order of levels of this variable. {palette} {col_wrap} {height} {aspect} markers : matplotlib marker code or list of marker codes, optional Markers for the scatterplot. If a list, each marker in the list will be used for each level of the ``hue`` variable. {share_xy} .. deprecated:: 0.12.0 Pass using the `facet_kws` dictionary. {{hue,col,row}}_order : lists, optional Order for the levels of the faceting variables. By default, this will be the order that the levels appear in ``data`` or, if the variables are pandas categoricals, the category order. legend : bool, optional If ``True`` and there is a ``hue`` variable, add a legend. {legend_out} .. deprecated:: 0.12.0 Pass using the `facet_kws` dictionary. {x_estimator} {x_bins} {x_ci} {scatter} {fit_reg} {ci} {n_boot} {units} {seed} {order} {logistic} {lowess} {robust} {logx} {xy_partial} {truncate} {xy_jitter} {scatter_line_kws} facet_kws : dict Dictionary of keyword arguments for :class:`FacetGrid`. See Also -------- regplot : Plot data and a conditional model fit. FacetGrid : Subplot grid for plotting conditional relationships. pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with ``kind="reg"``). Notes ----- {regplot_vs_lmplot} Examples -------- These examples focus on basic regression model plots to exhibit the various faceting options; see the :func:`regplot` docs for demonstrations of the other options for plotting the data and models. There are also other examples for how to manipulate plot using the returned object on the :class:`FacetGrid` docs. Plot a simple linear relationship between two variables: .. plot:: :context: close-figs >>> import seaborn as sns; sns.set_theme(color_codes=True) >>> tips = sns.load_dataset("tips") >>> g = sns.lmplot(x="total_bill", y="tip", data=tips) Condition on a third variable and plot the levels in different colors: .. plot:: :context: close-figs >>> g = sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips) Use different markers as well as colors so the plot will reproduce to black-and-white more easily: .. plot:: :context: close-figs >>> g = sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips, ... markers=["o", "x"]) Use a different color palette: .. plot:: :context: close-figs >>> g = sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips, ... palette="Set1") Map ``hue`` levels to colors with a dictionary: .. plot:: :context: close-figs >>> g = sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips, ... palette=dict(Yes="g", No="m")) Plot the levels of the third variable across different columns: .. plot:: :context: close-figs >>> g = sns.lmplot(x="total_bill", y="tip", col="smoker", data=tips) Change the height and aspect ratio of the facets: .. plot:: :context: close-figs >>> g = sns.lmplot(x="size", y="total_bill", hue="day", col="day", ... data=tips, height=6, aspect=.4, x_jitter=.1) Wrap the levels of the column variable into multiple rows: .. plot:: :context: close-figs >>> g = sns.lmplot(x="total_bill", y="tip", col="day", hue="day", ... data=tips, col_wrap=2, height=3) Condition on two variables to make a full grid: .. plot:: :context: close-figs >>> g = sns.lmplot(x="total_bill", y="tip", row="sex", col="time", ... data=tips, height=3) Use methods on the returned :class:`FacetGrid` instance to further tweak the plot: .. plot:: :context: close-figs >>> g = sns.lmplot(x="total_bill", y="tip", row="sex", col="time", ... data=tips, height=3) >>> g = (g.set_axis_labels("Total bill (US Dollars)", "Tip") ... .set(xlim=(0, 60), ylim=(0, 12), ... xticks=[10, 30, 50], yticks=[2, 6, 10]) ... .fig.subplots_adjust(wspace=.02)) """).format(**_regression_docs) @_deprecate_positional_args def regplot( *, x=None, y=None, data=None, x_estimator=None, x_bins=None, x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000, units=None, seed=None, order=1, logistic=False, lowess=False, robust=False, logx=False, x_partial=None, y_partial=None, truncate=True, dropna=True, x_jitter=None, y_jitter=None, label=None, color=None, marker="o", scatter_kws=None, line_kws=None, ax=None ): plotter = _RegressionPlotter(x, y, data, x_estimator, x_bins, x_ci, scatter, fit_reg, ci, n_boot, units, seed, order, logistic, lowess, robust, logx, x_partial, y_partial, truncate, dropna, x_jitter, y_jitter, color, label) if ax is None: ax = plt.gca() scatter_kws = {} if scatter_kws is None else copy.copy(scatter_kws) scatter_kws["marker"] = marker line_kws = {} if line_kws is None else copy.copy(line_kws) plotter.plot(ax, scatter_kws, line_kws) return ax regplot.__doc__ = dedent("""\ Plot data and a linear regression model fit. {model_api} Parameters ---------- x, y: string, series, or vector array Input variables. If strings, these should correspond with column names in ``data``. When pandas objects are used, axes will be labeled with the series name. {data} {x_estimator} {x_bins} {x_ci} {scatter} {fit_reg} {ci} {n_boot} {units} {seed} {order} {logistic} {lowess} {robust} {logx} {xy_partial} {truncate} {xy_jitter} label : string Label to apply to either the scatterplot or regression line (if ``scatter`` is ``False``) for use in a legend. color : matplotlib color Color to apply to all plot elements; will be superseded by colors passed in ``scatter_kws`` or ``line_kws``. marker : matplotlib marker code Marker to use for the scatterplot glyphs. {scatter_line_kws} ax : matplotlib Axes, optional Axes object to draw the plot onto, otherwise uses the current Axes. Returns ------- ax : matplotlib Axes The Axes object containing the plot. See Also -------- lmplot : Combine :func:`regplot` and :class:`FacetGrid` to plot multiple linear relationships in a dataset. jointplot : Combine :func:`regplot` and :class:`JointGrid` (when used with ``kind="reg"``). pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with ``kind="reg"``). residplot : Plot the residuals of a linear regression model. Notes ----- {regplot_vs_lmplot} It's also easy to combine combine :func:`regplot` and :class:`JointGrid` or :class:`PairGrid` through the :func:`jointplot` and :func:`pairplot` functions, although these do not directly accept all of :func:`regplot`'s parameters. Examples -------- Plot the relationship between two variables in a DataFrame: .. plot:: :context: close-figs >>> import seaborn as sns; sns.set_theme(color_codes=True) >>> tips = sns.load_dataset("tips") >>> ax = sns.regplot(x="total_bill", y="tip", data=tips) Plot with two variables defined as numpy arrays; use a different color: .. plot:: :context: close-figs >>> import numpy as np; np.random.seed(8) >>> mean, cov = [4, 6], [(1.5, .7), (.7, 1)] >>> x, y = np.random.multivariate_normal(mean, cov, 80).T >>> ax = sns.regplot(x=x, y=y, color="g") Plot with two variables defined as pandas Series; use a different marker: .. plot:: :context: close-figs >>> import pandas as pd >>> x, y = pd.Series(x, name="x_var"), pd.Series(y, name="y_var") >>> ax = sns.regplot(x=x, y=y, marker="+") Use a 68% confidence interval, which corresponds with the standard error of the estimate, and extend the regression line to the axis limits: .. plot:: :context: close-figs >>> ax = sns.regplot(x=x, y=y, ci=68, truncate=False) Plot with a discrete ``x`` variable and add some jitter: .. plot:: :context: close-figs >>> ax = sns.regplot(x="size", y="total_bill", data=tips, x_jitter=.1) Plot with a discrete ``x`` variable showing means and confidence intervals for unique values: .. plot:: :context: close-figs >>> ax = sns.regplot(x="size", y="total_bill", data=tips, ... x_estimator=np.mean) Plot with a continuous variable divided into discrete bins: .. plot:: :context: close-figs >>> ax = sns.regplot(x=x, y=y, x_bins=4) Fit a higher-order polynomial regression: .. plot:: :context: close-figs >>> ans = sns.load_dataset("anscombe") >>> ax = sns.regplot(x="x", y="y", data=ans.loc[ans.dataset == "II"], ... scatter_kws={{"s": 80}}, ... order=2, ci=None) Fit a robust regression and don't plot a confidence interval: .. plot:: :context: close-figs >>> ax = sns.regplot(x="x", y="y", data=ans.loc[ans.dataset == "III"], ... scatter_kws={{"s": 80}}, ... robust=True, ci=None) Fit a logistic regression; jitter the y variable and use fewer bootstrap iterations: .. plot:: :context: close-figs >>> tips["big_tip"] = (tips.tip / tips.total_bill) > .175 >>> ax = sns.regplot(x="total_bill", y="big_tip", data=tips, ... logistic=True, n_boot=500, y_jitter=.03) Fit the regression model using log(x): .. plot:: :context: close-figs >>> ax = sns.regplot(x="size", y="total_bill", data=tips, ... x_estimator=np.mean, logx=True) """).format(**_regression_docs) @_deprecate_positional_args def residplot( *, x=None, y=None, data=None, lowess=False, x_partial=None, y_partial=None, order=1, robust=False, dropna=True, label=None, color=None, scatter_kws=None, line_kws=None, ax=None ): """Plot the residuals of a linear regression. This function will regress y on x (possibly as a robust or polynomial regression) and then draw a scatterplot of the residuals. You can optionally fit a lowess smoother to the residual plot, which can help in determining if there is structure to the residuals. Parameters ---------- x : vector or string Data or column name in `data` for the predictor variable. y : vector or string Data or column name in `data` for the response variable. data : DataFrame, optional DataFrame to use if `x` and `y` are column names. lowess : boolean, optional Fit a lowess smoother to the residual scatterplot. {x, y}_partial : matrix or string(s) , optional Matrix with same first dimension as `x`, or column name(s) in `data`. These variables are treated as confounding and are removed from the `x` or `y` variables before plotting. order : int, optional Order of the polynomial to fit when calculating the residuals. robust : boolean, optional Fit a robust linear regression when calculating the residuals. dropna : boolean, optional If True, ignore observations with missing data when fitting and plotting. label : string, optional Label that will be used in any plot legends. color : matplotlib color, optional Color to use for all elements of the plot. {scatter, line}_kws : dictionaries, optional Additional keyword arguments passed to scatter() and plot() for drawing the components of the plot. ax : matplotlib axis, optional Plot into this axis, otherwise grab the current axis or make a new one if not existing. Returns ------- ax: matplotlib axes Axes with the regression plot. See Also -------- regplot : Plot a simple linear regression model. jointplot : Draw a :func:`residplot` with univariate marginal distributions (when used with ``kind="resid"``). """ plotter = _RegressionPlotter(x, y, data, ci=None, order=order, robust=robust, x_partial=x_partial, y_partial=y_partial, dropna=dropna, color=color, label=label) if ax is None: ax = plt.gca() # Calculate the residual from a linear regression _, yhat, _ = plotter.fit_regression(grid=plotter.x) plotter.y = plotter.y - yhat # Set the regression option on the plotter if lowess: plotter.lowess = True else: plotter.fit_reg = False # Plot a horizontal line at 0 ax.axhline(0, ls=":", c=".2") # Draw the scatterplot scatter_kws = {} if scatter_kws is None else scatter_kws.copy() line_kws = {} if line_kws is None else line_kws.copy() plotter.plot(ax, scatter_kws, line_kws) return ax