Source code for pyiron_workflow.draw

"""
Functions for drawing the graph.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Generic, Literal, TypeVar

import graphviz
from pyiron_snippets.colors import SeabornColors

from pyiron_workflow.data import NOT_DATA

if TYPE_CHECKING:
    from pyiron_workflow.channels import Channel as WorkflowChannel  # noqa: F401
    from pyiron_workflow.channels import (
        DataChannel as WorkflowDataChannel,  # noqa: F401
    )
    from pyiron_workflow.channels import (
        SignalChannel as WorkflowSignalChannel,  # noqa: F401
    )
    from pyiron_workflow.io import DataIO, SignalIO
    from pyiron_workflow.node import Node as WorkflowNode


[docs] def directed_graph( name, label, rankdir, color_start, color_end, gradient_angle, size=None, style="filled", fillcolor_start=None, fillcolor_end=None, ): """A shortcut method for instantiating the type of graphviz graph we want""" if fillcolor_start is None: fillcolor_start = color_start if fillcolor_end is None: fillcolor_end = color_end digraph = graphviz.graphs.Digraph(name=name) digraph.attr( label=label, compound="true", rankdir=rankdir, style=style, fillcolor=f"{fillcolor_start}:{fillcolor_end}", color=f"{color_start}:{color_end}", gradientangle=gradient_angle, fontname="helvetica", size=size, ) return digraph
[docs] def reverse_rankdir(rankdir: Literal["LR", "TB"]): if rankdir == "LR": return "TB" elif rankdir == "TB": return "LR" else: raise ValueError(f"Expected rankdir of 'LR' or 'TB' but got {rankdir}")
def _to_hex(rgb: tuple[int, int, int]) -> str: """RGB [0,255] to hex color codes; no alpha values.""" return "#{:02x}{:02x}{:02x}".format(*tuple(int(c) for c in rgb)) def _to_rgb(hex_: str) -> tuple[int, int, int]: """Hex to RGB color codes; no alpha values.""" hex_ = hex_.lstrip("#") return ( int(hex_[0 : 0 + 2], 16), int(hex_[2 : 2 + 2], 16), int(hex_[4 : 4 + 2], 16), ) # mypy isn't smart enough to parse this as a 3-tuple from an iterator
[docs] def blend_colours(color_a, color_b, fraction_a=0.5): """Blends two hex code colours together""" return _to_hex( tuple( fraction_a * a + (1 - fraction_a) * b for (a, b) in zip(_to_rgb(color_a), _to_rgb(color_b), strict=False) ) )
[docs] def lighten_hex_color(color, lightness=0.7): """Blends the given hex code color with pure white.""" return blend_colours(SeabornColors.white, color, fraction_a=lightness)
[docs] class WorkflowGraphvizMap(ABC): """ A parent class defining the interface for the graphviz representation of all our workflow objects. """ @property @abstractmethod def parent(self) -> WorkflowGraphvizMap | None: pass @property @abstractmethod def name(self) -> str: pass @property @abstractmethod def label(self) -> str: pass @property @abstractmethod def graph(self) -> graphviz.graphs.Digraph: pass @property @abstractmethod def color(self) -> str: pass
WorkflowChannelType = TypeVar("WorkflowChannelType", bound="WorkflowChannel") class _Channel(WorkflowGraphvizMap, Generic[WorkflowChannelType], ABC): """ An abstract representation for channel objects, which are "nodes" in graphviz parlance. """ def __init__(self, parent: _IO, channel: WorkflowChannelType, local_name: str): self.channel: WorkflowChannelType = channel self._parent = parent self._name = self.parent.name + local_name self._label = local_name + self._build_label_suffix() self.graph.node( name=self.name, label=self.label, shape=self.shape, color=self.color, style=self.style, fontname="helvetica", ) @property @abstractmethod def shape(self) -> str: pass def _build_label_suffix(self): suffix = "" try: if self.channel.type_hint is not None: suffix += ": " + self.channel.type_hint.__name__ except AttributeError: pass # Signals have no type return suffix @property def parent(self) -> _IO: return self._parent @property def name(self) -> str: return self._name @property def label(self) -> str: return self._label @property def graph(self) -> graphviz.graphs.Digraph: return self.parent.graph @property def style(self) -> str: return "filled"
[docs] class DataChannel(_Channel["WorkflowDataChannel"]): @property def color(self) -> str: orange = "#EDB22C" return orange @property def shape(self) -> str: return "oval" @property def style(self) -> str: if len(self.channel.connections) == 0 and self.channel.value is NOT_DATA: return "bold" return "filled"
[docs] class SignalChannel(_Channel["WorkflowSignalChannel"]): @property def color(self) -> str: blue = "#21BFD8" return blue @property def shape(self) -> str: return "cds"
class _IO(WorkflowGraphvizMap, ABC): """ An abstract class for IO panels, which are represented as a "subgraph" in graphviz parlance. """ def __init__(self, parent: Node): self._parent = parent self.node: WorkflowNode = self.parent.node self.data_io, self.signals_io = self._get_node_io() self._name = self.parent.name + self.data_io.__class__.__name__ self._label = self.data_io.__class__.__name__ self._graph = directed_graph( self.name, self.label, rankdir=reverse_rankdir(self.parent.rankdir), color_start=self.color, color_end=lighten_hex_color(self.color), gradient_angle=self.gradient_angle, ) self.channels = [ SignalChannel(self, channel, panel_label) for panel_label, channel in self.signals_io.items() ] + [ DataChannel(self, channel, panel_label) for panel_label, channel in self.data_io.items() ] self.parent.graph.subgraph(self.graph) @abstractmethod def _get_node_io(self) -> tuple[DataIO, SignalIO]: pass @property @abstractmethod def gradient_angle(self) -> str: """Background fill colour angle in degrees""" @property def parent(self) -> Node: return self._parent @property def name(self) -> str: return self._name @property def label(self) -> str: return self._label @property def graph(self) -> graphviz.graphs.Digraph: return self._graph @property def color(self) -> str: gray = "#A5A4A5" return gray def __len__(self): return len(self.channels)
[docs] class Inputs(_IO): def _get_node_io(self) -> tuple[DataIO, SignalIO]: return self.node.inputs, self.node.signals.input @property def gradient_angle(self) -> str: return "0"
[docs] class Outputs(_IO): def _get_node_io(self) -> tuple[DataIO, SignalIO]: return self.node.outputs, self.node.signals.output @property def gradient_angle(self) -> str: return "180"
[docs] class Node(WorkflowGraphvizMap): """ A wrapper class to connect graphviz to our workflow nodes. The nodes are represented by a "graph" or "subgraph" in graphviz parlance (depending on whether the node being visualized is the top-most node or not). Visualized nodes show their label and type, and IO panels with label and type. Colors and shapes are exploited to differentiate various node classes, input/output, and data/signal channels. If the node is composite in nature and the `depth` argument is at least `1`, owned children are also visualized (recursively with `depth = depth - 1`) inside the scope of this node. Args: node (pyiron_workflow.node.Node): The node to visualize. parent (Optional[pyiron_workflow.draw.Node]): The visualization that owns this visualization (if any). depth (int): How deeply to decompose any child nodes beyond showing their IO. rankdir ("LR" | "TB"): Use left-right or top-bottom graphviz `rankdir`. size (tuple[int | float, int | float] | None): The size of the diagram, in inches(?); respects ratio by scaling until at least one dimension matches the requested size. (Default is None, automatically size.) """ def __init__( self, node: WorkflowNode, parent: Node | None = None, depth: int = 1, rankdir: Literal["LR", "TB"] = "LR", size: str | None = None, ): self.node = node self._parent = parent self._name = self.build_node_name() self._label = self.node.label + ": " + self.node.__class__.__name__ self.rankdir: Literal["LR", "TB"] = rankdir self._graph = directed_graph( self.name, self.label, rankdir=self.rankdir, color_start=self.color, color_end=self.color, gradient_angle="0", size=size, style=self.style, fillcolor_start=self.fillcolor, fillcolor_end=self.fillcolor, ) self.inputs = Inputs(self) self.outputs = Outputs(self) self.graph.edge( self.inputs.channels[0].name, self.outputs.channels[0].name, style="invis" ) if depth > 0: from pyiron_workflow.nodes.composite import Composite # noqa: PLC0415 # Janky in-line import to avoid circular imports but only look for children # where they exist (since nodes sometimes now actually do something on # failed attribute access, i.e. use it as delayed access on their output) if isinstance(self.node, Composite): self._connect_owned_nodes(depth) if self.parent is not None: self.parent.graph.subgraph(self.graph) def _channel_bicolor(self, start_channel, end_channel): return f"{start_channel.color};0.5:{end_channel.color};0.5" def _connect_owned_nodes(self, depth): nodes = [Node(node, self, depth - 1) for node in self.node.children.values()] internal_inputs = [ channel for node in nodes for channel in node.inputs.channels ] internal_outputs = [ channel for node in nodes for channel in node.outputs.channels ] # Loop to check for internal node output --> internal node input connections for output_channel in internal_outputs: for input_channel in internal_inputs: if input_channel.channel in output_channel.channel.connections: self.graph.edge( output_channel.name, input_channel.name, color=self._channel_bicolor(output_channel, input_channel), ) # Connect channels that are by-reference copies of each other # i.e. for Workflow IO to child IO self._connect_matching(self.inputs.channels, internal_inputs) self._connect_matching(internal_outputs, self.outputs.channels) # Connect channels that are value-linked # i.e. for Macro IO to child IO self._connect_linked(self.inputs.channels, internal_inputs) self._connect_linked(internal_outputs, self.outputs.channels) def _connect_matching(self, sources: list[_Channel], destinations: list[_Channel]): """ Draw an edge between two graph channels whose workflow channels are the same """ for source in sources: for destination in destinations: if source.channel is destination.channel: self.graph.edge( source.name, destination.name, color=self._channel_bicolor(source, destination), ) def _connect_linked(self, sources: list[_Channel], destinations: list[_Channel]): """ Draw an edge between two graph channels values are linked """ for source in sources: for destination in destinations: if ( hasattr(source.channel, "value_receiver") and source.channel.value_receiver is destination.channel ): self.graph.edge( source.name, destination.name, color=self._channel_bicolor(source, destination), style="dashed", )
[docs] def build_node_name(self, suffix=""): if self.parent is not None: # Recursively prepend parent labels to get a totally unique label string # (inside the scope of this graph) return self.parent.build_node_name(suffix=suffix + self.node.label) else: return "cluster" + self.node.label + suffix
@property def parent(self) -> Node | None: return self._parent @property def name(self) -> str: return self._name @property def label(self) -> str: return self._label @property def graph(self) -> graphviz.graphs.Digraph: return self._graph @property def color(self) -> str: bright_red = "#FF0000" if self.node.failed: return bright_red elif self.node.running: return SeabornColors.orange elif self.node.cache_hit and self.node.outputs.ready: # Successfully ran return SeabornColors.green else: return self.fillcolor @property def fillcolor(self) -> str: return lighten_hex_color(self.node.color) @property def style(self) -> str: return "filled, bold"