from __future__ import absolute_import, division, unicode_literals import param import numpy as np from matplotlib.collections import LineCollection, PolyCollection from ...core.data import Dataset from ...core.options import Cycle, abbreviated_exception from ...core.util import basestring, unique_array, search_indices, is_number, isscalar from ...util.transform import dim from ..mixins import ChordMixin from ..util import process_cmap, get_directed_graph_paths from .element import ColorbarPlot from .util import filter_styles class GraphPlot(ColorbarPlot): arrowhead_length = param.Number(default=0.025, doc=""" If directed option is enabled this determines the length of the arrows as fraction of the overall extent of the graph.""") directed = param.Boolean(default=False, doc=""" Whether to draw arrows on the graph edges to indicate the directionality of each edge.""") # Deprecated options color_index = param.ClassSelector(default=None, class_=(basestring, int), allow_None=True, doc=""" Deprecated in favor of color style mapping, e.g. `node_color=dim('color')`""") edge_color_index = param.ClassSelector(default=None, class_=(basestring, int), allow_None=True, doc=""" Deprecated in favor of color style mapping, e.g. `edge_color=dim('color')`""") style_opts = ['edge_alpha', 'edge_color', 'edge_linestyle', 'edge_linewidth', 'node_alpha', 'node_color', 'node_edgecolors', 'node_facecolors', 'node_linewidth', 'node_marker', 'node_size', 'visible', 'cmap', 'edge_cmap', 'node_cmap'] _style_groups = ['node', 'edge'] _nonvectorized_styles = [ 'edge_alpha', 'edge_linestyle', 'edge_cmap', 'cmap', 'visible', 'node_marker' ] filled = False def _compute_styles(self, element, ranges, style): elstyle = self.lookup_options(element, 'style') color = elstyle.kwargs.get('node_color') cdim = element.nodes.get_dimension(self.color_index) cmap = elstyle.kwargs.get('cmap', 'tab20') if cdim: cs = element.nodes.dimension_values(self.color_index) # Check if numeric otherwise treat as categorical if cs.dtype.kind == 'f': style['node_c'] = cs else: factors = unique_array(cs) cmap = color if isinstance(color, Cycle) else cmap if isinstance(cmap, dict): colors = [cmap.get(f, cmap.get('NaN', {'color': self._default_nan})['color']) for f in factors] else: colors = process_cmap(cmap, len(factors)) cs = search_indices(cs, factors) style['node_facecolors'] = [colors[v%len(colors)] for v in cs] style.pop('node_color', None) if 'node_c' in style: self._norm_kwargs(element.nodes, ranges, style, cdim) elif color and 'node_color' in style: style['node_facecolors'] = style.pop('node_color') style['node_edgecolors'] = style.pop('node_edgecolors', 'none') if is_number(style.get('node_size')): style['node_s'] = style.pop('node_size')**2 edge_cdim = element.get_dimension(self.edge_color_index) if not edge_cdim: if not isscalar(style.get('edge_color')): opt = 'edge_facecolors' if self.filled else 'edge_edgecolors' style[opt] = style.pop('edge_color') return style elstyle = self.lookup_options(element, 'style') cycle = elstyle.kwargs.get('edge_color') idx = element.get_dimension_index(edge_cdim) cvals = element.dimension_values(edge_cdim) if idx in [0, 1]: factors = element.nodes.dimension_values(2, expanded=False) elif idx == 2 and cvals.dtype.kind in 'uif': factors = None else: factors = unique_array(cvals) if factors is None or (factors.dtype.kind == 'f' and idx not in [0, 1]): style['edge_c'] = cvals else: cvals = search_indices(cvals, factors) factors = list(factors) cmap = elstyle.kwargs.get('edge_cmap', 'tab20') cmap = cycle if isinstance(cycle, Cycle) else cmap if isinstance(cmap, dict): colors = [cmap.get(f, cmap.get('NaN', {'color': self._default_nan})['color']) for f in factors] else: colors = process_cmap(cmap, len(factors)) style['edge_colors'] = [colors[v%len(colors)] for v in cvals] style.pop('edge_color', None) if 'edge_c' in style: self._norm_kwargs(element, ranges, style, edge_cdim, prefix='edge_') else: style.pop('edge_cmap', None) return style def get_data(self, element, ranges, style): with abbreviated_exception(): style = self._apply_transforms(element, ranges, style) xidx, yidx = (1, 0) if self.invert_axes else (0, 1) pxs, pys = (element.nodes.dimension_values(i) for i in range(2)) dims = element.nodes.dimensions() self._compute_styles(element, ranges, style) if 'edge_vmin' in style: style['edge_clim'] = (style.pop('edge_vmin'), style.pop('edge_vmax')) if 'edge_c' in style: style['edge_array'] = style.pop('edge_c') if self.directed: xdim, ydim = element.nodes.kdims[:2] x_range = ranges[xdim.name]['combined'] y_range = ranges[ydim.name]['combined'] arrow_len = np.hypot(y_range[1]-y_range[0], x_range[1]-x_range[0])*self.arrowhead_length paths = get_directed_graph_paths(element, arrow_len) else: paths = element._split_edgepaths.split(datatype='array', dimensions=element.edgepaths.kdims) if self.invert_axes: paths = [p[:, ::-1] for p in paths] return {'nodes': (pxs, pys), 'edges': paths}, style, {'dimensions': dims} def get_extents(self, element, ranges, range_type='combined'): return super(GraphPlot, self).get_extents(element.nodes, ranges, range_type) def init_artists(self, ax, plot_args, plot_kwargs): # Draw edges color_opts = ['c', 'cmap', 'vmin', 'vmax', 'norm'] groups = [g for g in self._style_groups if g != 'edge'] edge_opts = filter_styles(plot_kwargs, 'edge', groups, color_opts) if 'c' in edge_opts: edge_opts['array'] = edge_opts.pop('c') paths = plot_args['edges'] if self.filled: coll = PolyCollection if 'colors' in edge_opts: edge_opts['facecolors'] = edge_opts.pop('colors') else: coll = LineCollection edgecolors = edge_opts.pop('edgecolors', None) edges = coll(paths, **edge_opts) if edgecolors is not None: edges.set_edgecolors(edgecolors) ax.add_collection(edges) # Draw nodes xs, ys = plot_args['nodes'] groups = [g for g in self._style_groups if g != 'node'] node_opts = filter_styles(plot_kwargs, 'node', groups) nodes = ax.scatter(xs, ys, **node_opts) return {'nodes': nodes, 'edges': edges} def _update_nodes(self, element, data, style): nodes = self.handles['nodes'] xs, ys = data['nodes'] nodes.set_offsets(np.column_stack([xs, ys])) if 'node_facecolors' in style: nodes.set_facecolors(style['node_facecolors']) if 'node_edgecolors' in style: nodes.set_edgecolors(style['node_edgecolors']) if 'node_c' in style: nodes.set_array(style['node_c']) if 'node_vmin' in style: nodes.set_clim((style['node_vmin'], style['node_vmax'])) if 'node_norm' in style: nodes.norm = style['norm'] if 'node_linewidth' in style: nodes.set_linewidths(style['node_linewidth']) if 'node_s' in style: sizes = style['node_s'] if isscalar(sizes): nodes.set_sizes([sizes]) else: nodes.set_sizes(sizes) def _update_edges(self, element, data, style): edges = self.handles['edges'] paths = data['edges'] edges.set_paths(paths) edges.set_visible(style.get('visible', True)) if 'edge_facecolors' in style: edges.set_facecolors(style['edge_facecolors']) if 'edge_edgecolors' in style: edges.set_edgecolors(style['edge_edgecolors']) if 'edge_array' in style: edges.set_array(style['edge_array']) elif 'edge_colors' in style: if self.filled: edges.set_facecolors(style['edge_colors']) else: edges.set_edgecolors(style['edge_colors']) if 'edge_clim' in style: edges.set_clim(style['edge_clim']) if 'edge_norm' in style: edges.norm = style['edge_norm'] if 'edge_linewidth' in style: edges.set_linewidths(style['edge_linewidth']) def update_handles(self, key, axis, element, ranges, style): data, style, axis_kwargs = self.get_data(element, ranges, style) self._update_nodes(element, data, style) self._update_edges(element, data, style) return axis_kwargs class TriMeshPlot(GraphPlot): filled = param.Boolean(default=False, doc=""" Whether the triangles should be drawn as filled.""") style_opts = GraphPlot.style_opts + ['edge_facecolors'] def get_data(self, element, ranges, style): edge_color = style.get('edge_color') if edge_color not in element.nodes: edge_color = self.edge_color_index simplex_dim = element.get_dimension(edge_color) vertex_dim = element.nodes.get_dimension(edge_color) if not isinstance(self.edge_color_index, int) and vertex_dim and not simplex_dim: simplices = element.array([0, 1, 2]) z = element.nodes.dimension_values(vertex_dim) z = z[simplices].mean(axis=1) element = element.add_dimension(vertex_dim, len(element.vdims), z, vdim=True) # Ensure the edgepaths for the triangles are generated before plotting element.edgepaths return super(TriMeshPlot, self).get_data(element, ranges, style) class ChordPlot(ChordMixin, GraphPlot): labels = param.ClassSelector(class_=(basestring, dim), doc=""" The dimension or dimension value transform used to draw labels from.""") # Deprecated options label_index = param.ClassSelector(default=None, class_=(basestring, int), allow_None=True, doc=""" Index of the dimension from which the node labels will be drawn""") style_opts = GraphPlot.style_opts + ['text_font_size', 'label_offset'] _style_groups = ['edge', 'node', 'arc'] def get_data(self, element, ranges, style): data, style, plot_kwargs = super(ChordPlot, self).get_data(element, ranges, style) angles = element._angles paths = [] for i in range(len(element.nodes)): start, end = angles[i:i+2] vals = np.linspace(start, end, 20) paths.append(np.column_stack([np.cos(vals), np.sin(vals)])) data['arcs'] = paths if 'node_c' in style: style['arc_array'] = style['node_c'] style['arc_clim'] = style['node_vmin'], style['node_vmax'] style['arc_cmap'] = style['node_cmap'] elif 'node_facecolors' in style: style['arc_colors'] = style['node_facecolors'] style['arc_linewidth'] = 10 label_dim = element.nodes.get_dimension(self.label_index) labels = self.labels if label_dim and labels: self.param.warning( "Cannot declare style mapping for 'labels' option " "and declare a label_index; ignoring the label_index.") elif label_dim: labels = label_dim if isinstance(labels, basestring): labels = element.nodes.get_dimension(labels) if labels is None: return data, style, plot_kwargs nodes = element.nodes if element.vdims: values = element.dimension_values(element.vdims[0]) if values.dtype.kind in 'uif': edges = Dataset(element)[values>0] nodes = list(np.unique([edges.dimension_values(i) for i in range(2)])) nodes = element.nodes.select(**{element.nodes.kdims[2].name: nodes}) offset = style.get('label_offset', 1.05) xs, ys = (nodes.dimension_values(i)*offset for i in range(2)) if isinstance(labels, dim): text = labels.apply(element, flat=True) else: text = element.nodes.dimension_values(labels) text = [labels.pprint_value(v) for v in text] angles = np.rad2deg(np.arctan2(ys, xs)) data['text'] = (xs, ys, text, angles) return data, style, plot_kwargs def init_artists(self, ax, plot_args, plot_kwargs): artists = {} if 'arcs' in plot_args: color_opts = ['c', 'cmap', 'vmin', 'vmax', 'norm'] groups = [g for g in self._style_groups if g != 'arc'] edge_opts = filter_styles(plot_kwargs, 'arc', groups, color_opts) paths = plot_args['arcs'] edges = LineCollection(paths, **edge_opts) ax.add_collection(edges) artists['arcs'] = edges artists.update(super(ChordPlot, self).init_artists(ax, plot_args, plot_kwargs)) if 'text' in plot_args: fontsize = plot_kwargs.get('text_font_size', 8) labels = [] for (x, y, l, a) in zip(*plot_args['text']): label = ax.annotate(l, xy=(x, y), xycoords='data', rotation=a, horizontalalignment='left', fontsize=fontsize, verticalalignment='center', rotation_mode='anchor') labels.append(label) artists['labels'] = labels return artists def _update_arcs(self, element, data, style): edges = self.handles['arcs'] paths = data['arcs'] edges.set_paths(paths) edges.set_visible(style.get('visible', True)) if 'arc_array' in style: edges.set_array(style['arc_array']) if 'arc_clim' in style: edges.set_clim(style['arc_clim']) if 'arc_norm' in style: edges.set_norm(style['arc_norm']) if 'arc_colors' in style: edges.set_edgecolors(style['arc_colors']) elif 'arc_edgecolors' in style: edges.set_edgecolors(style['arc_edgecolors']) def _update_labels(self, ax, element, data, style): labels = self.handles.get('labels', []) for label in labels: try: label.remove() except: pass if 'text' not in data: self.handles['labels'] = [] return labels = [] fontsize = style.get('text_font_size', 8) for (x, y, l, a) in zip(*data['text']): label = ax.annotate(l, xy=(x, y), xycoords='data', rotation=a, horizontalalignment='left', fontsize=fontsize, verticalalignment='center', rotation_mode='anchor') labels.append(label) self.handles['labels'] = labels def update_handles(self, key, axis, element, ranges, style): data, style, axis_kwargs = self.get_data(element, ranges, style) self._update_nodes(element, data, style) self._update_edges(element, data, style) self._update_arcs(element, data, style) self._update_labels(axis, element, data, style) return axis_kwargs