import uuid from .plugin import SchedulerPlugin class GraphLayout(SchedulerPlugin): """Dynamic graph layout during computation This assigns (x, y) locations to all tasks quickly and dynamically as new tasks are added. This scales to a few thousand nodes. It is commonly used with distributed/dashboard/components/scheduler.py::TaskGraph, which is rendered at /graph on the diagnostic dashboard. """ def __init__(self, scheduler): self.name = f"graph-layout-{uuid.uuid4()}" self.x = {} self.y = {} self.collision = {} self.scheduler = scheduler self.index = {} self.index_edge = {} self.next_y = 0 self.next_index = 0 self.next_edge_index = 0 self.new = [] self.new_edges = [] self.state_updates = [] self.visible_updates = [] self.visible_edge_updates = [] if self.scheduler.tasks: dependencies = { k: [ds.key for ds in ts.dependencies] for k, ts in scheduler.tasks.items() } priority = {k: ts.priority for k, ts in scheduler.tasks.items()} self.update_graph( self.scheduler, tasks=self.scheduler.tasks, dependencies=dependencies, priority=priority, ) def update_graph( self, scheduler, dependencies=None, priority=None, tasks=None, **kwargs ): stack = sorted(tasks, key=lambda k: priority.get(k, 0), reverse=True) while stack: key = stack.pop() if key in self.x or key not in scheduler.tasks: continue deps = dependencies.get(key, ()) if deps: if not all(dep in self.y for dep in deps): stack.append(key) stack.extend( sorted(deps, key=lambda k: priority.get(k, 0), reverse=True) ) continue else: total_deps = sum( len(scheduler.tasks[dep].dependents) for dep in deps ) y = sum( self.y[dep] * len(scheduler.tasks[dep].dependents) / total_deps for dep in deps ) x = max(self.x[dep] for dep in deps) + 1 else: x = 0 y = self.next_y self.next_y += 1 if (x, y) in self.collision: old_x, old_y = x, y x, y = self.collision[(x, y)] y += 0.1 self.collision[old_x, old_y] = (x, y) else: self.collision[(x, y)] = (x, y) self.x[key] = x self.y[key] = y self.index[key] = self.next_index self.next_index = self.next_index + 1 self.new.append(key) for dep in deps: edge = (dep, key) self.index_edge[edge] = self.next_edge_index self.next_edge_index += 1 self.new_edges.append(edge) def transition(self, key, start, finish, *args, **kwargs): if finish != "forgotten": self.state_updates.append((self.index[key], finish)) else: self.visible_updates.append((self.index[key], "False")) task = self.scheduler.tasks[key] for dep in task.dependents: edge = (key, dep.key) self.visible_edge_updates.append((self.index_edge.pop(edge), "False")) for dep in task.dependencies: self.visible_edge_updates.append( (self.index_edge.pop((dep.key, key)), "False") ) try: del self.collision[(self.x[key], self.y[key])] except KeyError: pass for collection in [self.x, self.y, self.index]: del collection[key] def reset_index(self): """Reset the index and refill new and new_edges From time to time TaskGraph wants to remove invisible nodes and reset all of its indices. This helps. """ self.new = [] self.new_edges = [] self.visible_updates = [] self.state_updates = [] self.visible_edge_updates = [] self.index = {} self.next_index = 0 self.index_edge = {} self.next_edge_index = 0 for key in self.x: self.index[key] = self.next_index self.next_index += 1 self.new.append(key) for dep in self.scheduler.tasks[key].dependencies: edge = (dep.key, key) self.index_edge[edge] = self.next_edge_index self.next_edge_index += 1 self.new_edges.append(edge)