"""
A base class for nodal objects that have internal structure -- i.e. they hold a
sub-graph
"""
from __future__ import annotations
from abc import ABC
from collections.abc import Callable
from time import sleep
from typing import TYPE_CHECKING
import typeguard
from pyiron_snippets.colors import SeabornColors
from pyiron_snippets.dotdict import DotDict
from pyiron_workflow.create import HasCreator
from pyiron_workflow.mixin.lexical import LexicalParent
from pyiron_workflow.node import Node
from pyiron_workflow.topology import (
get_nodes_in_data_tree,
set_run_connections_according_to_dag,
set_run_connections_according_to_linear_dag,
)
if TYPE_CHECKING:
from pyiron_workflow.channels import (
InputSignal,
OutputSignal,
)
from pyiron_workflow.storage import BackendIdentifier, StorageInterface
def _get_graph_as_dict(composite: Node) -> dict | Node:
    """
    Recursively build a nested-dictionary description of a (sub)graph.

    Args:
        composite (Node): The node to describe. Only :class:`Composite` instances are
            expanded; any other node is a leaf of the recursion.

    Returns:
        dict | Node: For a :class:`Composite`, a dictionary with the keys

        - ``"object"``: the composite itself,
        - ``"nodes"``: each child's full label mapped to that child's own recursive
          description,
        - ``"edges"``: a dictionary with ``"data"`` and ``"signal"`` entries, each
          mapping ``(output full label, input full label)`` pairs to the
          ``(output channel, input channel)`` objects of the children's outgoing
          connections.

        Any non-composite node is returned unchanged.
    """
    if not isinstance(composite, Composite):
        # Leaf of the recursion: plain nodes are stored directly
        return composite
    return {
        "object": composite,
        "nodes": {n.full_label: _get_graph_as_dict(n) for n in composite},
        "edges": {
            "data": {
                (out.full_label, inp.full_label): (out, inp)
                for n in composite
                for out in n.outputs
                for inp in out.connections
            },
            "signal": {
                (out.full_label, inp.full_label): (out, inp)
                for n in composite
                for out in n.signals.output
                for inp in out.connections
            },
        },
    }
[docs]
class FailedChildError(RuntimeError):
    """
    Raised by a composite node when at least one of its child nodes raised an
    exception while the composite was executing its subgraph.
    """
[docs]
class Composite(LexicalParent[Node], HasCreator, Node, ABC):
    """
    A base class for nodes that have internal graph structure -- i.e. they hold a
    collection of child nodes and their computation is to execute that graph.

    Attributes:
        strict_naming (bool): When true, repeated assignment of a new node to an
            existing node label will raise an error, otherwise the label gets
            appended with an index and the assignment proceeds. (Default is true:
            disallow assigning to existing labels.)
        create (Creator): A tool for adding new nodes to this subgraph.
        provenance_by_completion (list[str]): The child nodes (by label) in the order
            that they completed on the last :meth:`run` call.
        provenance_by_execution (list[str]): The child nodes (by label) in the order
            that they started executing on the last :meth:`run` call.
        running_children (list[str]): The labels of children who are currently
            running.
        signal_queue (list[tuple[OutputSignal, InputSignal]]): Pending signal event
            pairs from child execution flow connections.
        starting_nodes (list[pyiron_workflow.node.Node]): A subset of the owned
            nodes to be used on running. Only necessary if the execution graph has
            been manually specified with `run` signals. (Default is an empty list.)
        wrap (Wrappers): A tool for accessing node-creating decorators
    """

    def __init__(
        self,
        *args,
        label: str | None = None,
        parent: Composite | None = None,
        delete_existing_savefiles: bool = False,
        autoload: BackendIdentifier | StorageInterface | None = None,
        autorun: bool = False,
        checkpoint: BackendIdentifier | StorageInterface | None = None,
        strict_naming: bool = True,
        **kwargs,
    ):
        # Run-tracking state is set before the parent initializer runs. None of these
        # values is a `Node`, so the custom `__setattr__` stores them as plain
        # attributes rather than treating them as children.
        self.starting_nodes: list[Node] = []
        self.provenance_by_execution: list[str] = []
        self.provenance_by_completion: list[str] = []
        self.running_children: list[str] = []
        self.signal_queue: list[tuple[OutputSignal, InputSignal]] = []
        # How long to wait when the signal_queue is empty but the running_children
        # list is not
        self._child_sleep_interval = 0.01
        super().__init__(
            *args,
            label=label,
            parent=parent,
            delete_existing_savefiles=delete_existing_savefiles,
            autoload=autoload,
            autorun=autorun,
            checkpoint=checkpoint,
            strict_naming=strict_naming,
            **kwargs,
        )

    @classmethod
    def child_type(cls) -> type[Node]:
        """Return the type of object held as children by this composite: `Node`."""
        return Node

    def activate_strict_hints(self):
        """Recursively activate strict type hints."""
        super().activate_strict_hints()
        for node in self:
            node.activate_strict_hints()

    def deactivate_strict_hints(self):
        """Recursively de-activate strict type hints."""
        super().deactivate_strict_hints()
        for node in self:
            node.deactivate_strict_hints()

    @property
    def use_cache(self) -> bool:
        """
        Composite nodes determine the cache usage by the cache usage of all children
        recursively. (A composite with no children reports `True`.)

        Setting this property at the composite level sets it for all children
        recursively.
        """
        return all(c.use_cache for c in self.children.values())

    @use_cache.setter
    def use_cache(self, value: bool):
        for c in self.children.values():
            c.use_cache = value

    @property
    def cache_hit(self) -> bool:
        """A cache hit additionally requires that no child is currently running."""
        return not any(c.running for c in self.children.values()) and super().cache_hit

    def _on_cache_miss(self) -> None:
        super()._on_cache_miss()
        # Reset provenance and run status trackers; children that are still flagged
        # as running stay registered so a broken run can be resumed in `_on_run`
        self.provenance_by_execution = []
        self.provenance_by_completion = []
        self.running_children = [n.label for n in self if n.running]
        self.signal_queue = []

    def _on_run(self):
        """
        Execute the subgraph: either resume children left running by a previous
        (broken) run, or start fresh from :attr:`starting_nodes`. Then process
        signals until nothing is running or pending.
        """
        if len(self.running_children) > 0:  # Start from a broken process
            # Iterate over a snapshot: the loop body removes labels from
            # `running_children`, and re-running a child re-registers its label
            # there, so iterating the live list would skip or revisit entries.
            for label in list(self.running_children):
                if self.children[label]._is_using_wrapped_excutorlib_executor():
                    self.running_children.remove(label)
                    self.children[label].run()
                    # Running children will find serialized result and proceed,
                    # or raise an error because they're already running
        else:  # Start fresh
            for node in self.starting_nodes:
                node.run()
        self._run_while_children_or_signals_exist()
        return self

    def _run_while_children_or_signals_exist(self):
        """
        Fire queued signals until no child is running and no signal is pending.

        Exceptions raised by receiving channels are collected (keyed by the
        receiving channel's full label) so the remaining signals are still
        processed; they are re-raised together at the end.

        Raises:
            FailedChildError: If any receiving channel raised an exception.
        """
        errors = {}
        while self.running_children or self.signal_queue:
            try:
                firing, receiving = self.signal_queue.pop(0)
                try:
                    receiving(firing)
                except Exception as e:
                    errors[receiving.full_label] = e
            except IndexError:
                # The signal queue is empty, but there is still someone running...
                sleep(self._child_sleep_interval)
        if len(errors) == 1:
            raise FailedChildError(
                f"{self.full_label} encountered error in child: {errors}"
            ) from next(iter(errors.values()))
        elif len(errors) > 1:
            raise FailedChildError(
                f"{self.full_label} encountered multiple errors in children: {errors}"
            ) from None

    def register_child_starting(self, child: Node) -> None:
        """
        To be called by children when they start their run cycle.

        Args:
            child [Node]: The child that is starting its run. Should always be a
                child of `self`, but this is not explicitly verified at runtime.
        """
        self.provenance_by_execution.append(child.label)
        self.running_children.append(child.label)

    def register_child_finished(self, child: Node) -> None:
        """
        To be called by children when they are finished their run.

        Args:
            child [Node]: The child that is finished and would like to fire its `ran`
                signal. Should always be a child of `self`, but this is not explicitly
                verified at runtime.

        Raises:
            KeyError: If the child's label is not among :attr:`running_children`.
        """
        try:
            self.running_children.remove(child.label)
            self.provenance_by_completion.append(child.label)
        except ValueError as e:
            raise KeyError(
                f"No element {child.label} to remove while {self.running_children}, "
                f"{self.provenance_by_execution}, {self.provenance_by_completion}"
            ) from e

    def register_child_emitting(self, child: Node) -> None:
        """
        To be called by children when they want to emit their signals.

        Every (emitting output signal, connected input signal) pair is appended to
        :attr:`signal_queue`, to be fired by the run loop.

        Args:
            child [Node]: The child that is finished and would like to fire its `ran`
                signal (and possibly others). Should always be a child of `self`, but
                this is not explicitly verified at runtime.
        """
        for firing in child.emitting_channels:
            for receiving in firing.connections:
                self.signal_queue.append((firing, receiving))

    @property
    def run_args(self) -> tuple[tuple, dict]:
        """Composites are run without any positional or keyword arguments."""
        return (), {}

    def process_run_result(self, run_output):
        """
        Adopt the state of a remotely executed copy of `self` (if the run returned
        a different object), then return the composite's run output.
        """
        if run_output is not self:
            self._parse_remotely_executed_self(run_output)
        return self._outputs_to_run_return()

    def _parse_remotely_executed_self(self, other_self):
        # Un-parent existing nodes before ditching them
        for node in self:
            node._parent = None
            node._detached_parent_path = None
        other_self.running = False  # It's done now
        state = self._get_state_from_remote_other(other_self)
        self.__setstate__(state)

    def _get_state_from_remote_other(self, other_self):
        # Take the remote copy's state, but drop the entries that must stay local
        state = other_self.__getstate__()
        state.pop("_executor")  # Got overridden to None for __getstate__, so keep local
        state.pop("_parent")  # Got overridden to None for __getstate__, so keep local
        state.pop("_detached_parent_path")
        return state

    def disconnect_run(self) -> list[tuple[InputSignal, OutputSignal]]:
        """
        Disconnect all `signals.input.run` connections on all child nodes.

        Returns:
            list[tuple[InputSignal, OutputSignal]]: Any disconnected pairs.
        """
        disconnected_pairs = []
        for node in self.children.values():
            disconnected_pairs.extend(node.signals.disconnect_run())
        return disconnected_pairs

    def set_run_signals_to_dag_execution(self):
        """
        Replace the children's `signals.input.run` connections with ones derived
        from the DAG flow of the data (delegated to
        :func:`set_run_connections_according_to_dag`). On success, sets the starting
        nodes to the upstream-most node(s) of that DAG. Does nothing if there are no
        children.
        """
        if len(self.children) > 0:
            _, upstream_most_nodes = set_run_connections_according_to_dag(self.children)
            self.starting_nodes = upstream_most_nodes

    def add_child(
        self,
        child: Node,
        label: str | None = None,
        strict_naming: bool | None = None,
    ) -> Node:
        """Add the node instance to this subgraph."""
        self.clear_cache()  # Reset cache after graph change
        return super().add_child(child, label=label, strict_naming=strict_naming)

    def push_child(self, child: Node | str, *args, **kwargs):
        """
        Run a child node in a "push" configuration.

        Args:
            child (Node|str): The child node (or its label) to push.
            *args: Additional positional arguments passed to the child node.
            **kwargs: Additional keyword arguments passed to the child node.

        Returns:
            (Any | Future): The result of running the node, or a futures object (if
                running on an executor).

        Raises:
            ValueError: If `child` is not one of this composite's children.
        """
        typeguard.check_type(child, Node | str)
        child_node: Node | None = None
        problem: str | None = None
        if isinstance(child, Node):
            if child.parent is self:
                child_node = child
            else:
                problem = child.full_label
        elif child in self.child_labels:
            child_node = self.children[child]
        else:
            problem = child
        if problem is not None:
            raise ValueError(
                f"Child {problem} not found among {self.full_label}'s children: "
                f"{self.child_labels}"
            )
        return child_node.run(*args, **kwargs)

    def remove_child(self, child: Node | str) -> Node:
        """
        Remove a child from the :attr:`children` collection, disconnecting it and
        setting its :attr:`parent` to None.

        Args:
            child (Node|str): The child (or its label) to remove.

        Returns:
            (Node): The (now disconnected and de-parented) (former) child node.
        """
        child = super().remove_child(child)
        child.disconnect()
        if child in self.starting_nodes:
            self.starting_nodes.remove(child)
        self.clear_cache()  # Reset cache after graph change
        return child

    def executor_shutdown(self, wait=True, *, cancel_futures=False):
        """
        Invoke shutdown on the executor (if present), and recursively invoke shutdown
        for children.
        """
        super().executor_shutdown(wait=wait, cancel_futures=cancel_futures)
        for node in self:
            node.executor_shutdown(wait=wait, cancel_futures=cancel_futures)

    def __setattr__(self, key: str, node: object):
        # Assigning a `Node` to an attribute adds it as a child under that label.
        # Exception: a composite assigned to `_parent`/`parent` is this object's own
        # parent, not a new child, so it is stored as a regular attribute.
        if isinstance(node, Composite) and key in ("_parent", "parent"):
            super().__setattr__(key, node)
        elif isinstance(node, Node):
            self.add_child(node, label=key)
        else:
            super().__setattr__(key, node)

    def __getitem__(self, item):
        # Delegates to the `__getattr__` provided by a parent class (presumably the
        # child-label lookup of `LexicalParent` -- not visible in this file)
        return self.__getattr__(item)

    def __setitem__(self, key, value):
        # Same semantics as attribute assignment, including adding `Node` children
        self.__setattr__(key, value)

    @property
    def color(self) -> str:
        """For drawing the graph"""
        return SeabornColors.brown

    @property
    def graph_as_dict(self) -> dict:
        """
        A nested dictionary representation of the computation graph using full labels
        as keys and objects as values.
        """
        return _get_graph_as_dict(self)

    def _get_connections_as_strings(
        self, panel_getter: Callable
    ) -> list[tuple[tuple[str, str], tuple[str, str]]]:
        """
        Connections between children in string representation based on labels.

        Each entry is ((input node label, input channel label), (output node label,
        output channel label)) for every connection of every input channel returned
        by `panel_getter(child)`. The string representation helps storage.
        """
        return [
            ((inp.owner.label, inp.label), (out.owner.label, out.label))
            for child in self
            for inp in panel_getter(child)
            for out in inp.connections
        ]

    @staticmethod
    def _get_data_inputs(node: Node):
        return node.inputs

    @staticmethod
    def _get_signals_input(node: Node):
        return node.signals.input

    @property
    def _child_data_connections(self) -> list[tuple[tuple[str, str], tuple[str, str]]]:
        # A property, so the name appears in `__dir__` and is protected from
        # colliding with a child label in the state
        return self._get_connections_as_strings(self._get_data_inputs)

    @property
    def _child_signal_connections(
        self,
    ) -> list[tuple[tuple[str, str], tuple[str, str]]]:
        # A property, so the name appears in `__dir__` and is protected from
        # colliding with a child label in the state
        return self._get_connections_as_strings(self._get_signals_input)

    @property
    def _starting_node_labels(self):
        # As a property so it appears in `__dir__` and thus is guaranteed to not
        # conflict with a child node name in the state
        return tuple(n.label for n in self.starting_nodes)

    def __getstate__(self):
        state = super().__getstate__()
        # Store connections as strings
        state["_child_data_connections"] = self._child_data_connections
        state["_child_signal_connections"] = self._child_signal_connections
        # Also remove the starting node instances
        del state["starting_nodes"]
        state["_starting_node_labels"] = self._starting_node_labels
        return state

    def __setstate__(self, state):
        # Purge child connection info from the state
        child_data_connections = state.pop("_child_data_connections")
        child_signal_connections = state.pop("_child_signal_connections")
        # Restore starting nodes; children are looked up directly in the state by
        # their labels
        state["starting_nodes"] = [
            state[label] for label in state.pop("_starting_node_labels")
        ]
        super().__setstate__(state)
        # Nodes don't store connection information, so restore it to them
        self._restore_data_connections_from_strings(child_data_connections)
        self._restore_signal_connections_from_strings(child_signal_connections)

    @staticmethod
    def _restore_connections_from_strings(
        nodes: dict[str, Node] | DotDict[str, Node],
        connections: list[tuple[tuple[str, str], tuple[str, str]]],
        input_panel_getter: Callable,
        output_panel_getter: Callable,
    ) -> None:
        """
        Set connections among a dictionary of nodes.

        This is useful for recreating node connections after (de)serialization of the
        individual nodes, which don't know about their connections (that information
        is the responsibility of their parent `Composite`).

        Args:
            nodes (dict[Node]): The nodes to connect, keyed by label.
            connections (list[tuple[tuple[str, str], tuple[str, str]]]): Connections
                among these nodes in the format ((input node label, input channel
                label), (output node label, output channel label)).
            input_panel_getter (Callable): Maps a node to the panel holding the input
                side of each connection.
            output_panel_getter (Callable): Maps a node to the panel holding the
                output side of each connection.
        """
        for (inp_node, inp), (out_node, out) in connections:
            input_panel_getter(nodes[inp_node])[inp].connect(
                output_panel_getter(nodes[out_node])[out]
            )

    @staticmethod
    def _get_data_outputs(node: Node):
        return node.outputs

    @staticmethod
    def _get_signals_output(node: Node):
        return node.signals.output

    def _restore_data_connections_from_strings(
        self, connections: list[tuple[tuple[str, str], tuple[str, str]]]
    ) -> None:
        self._restore_connections_from_strings(
            self.children,
            connections,
            self._get_data_inputs,
            self._get_data_outputs,
        )

    def _restore_signal_connections_from_strings(
        self, connections: list[tuple[tuple[str, str], tuple[str, str]]]
    ) -> None:
        self._restore_connections_from_strings(
            self.children,
            connections,
            self._get_signals_input,
            self._get_signals_output,
        )

    @property
    def import_ready(self) -> bool:
        """Import-ready only if this node and every child are import-ready."""
        return super().import_ready and all(node.import_ready for node in self)

    def report_import_readiness(self, tabs=0, report_so_far=""):
        """Extend the parent report with each child's report, one tab deeper."""
        report = super().report_import_readiness(tabs=tabs, report_so_far=report_so_far)
        for node in self:
            report = node.report_import_readiness(tabs=tabs + 1, report_so_far=report)
        return report

    def run_data_tree_for_child(self, node: Node) -> None:
        """
        Use topological analysis to build a tree of all upstream dependencies and run
        them.

        This method is called by a child node when it needs to run its data tree and
        has a parent. The parent (this composite) handles the execution of the data
        tree. The requested node itself is not triggered. Run connections and
        :attr:`starting_nodes` are restored afterwards, even on failure.

        Args:
            node (Node): The child node that initiated the data tree run.

        Raises:
            ValueError: If any node in the data tree has an executor set.
        """
        data_tree_nodes = get_nodes_in_data_tree(node)
        for n in data_tree_nodes:
            if n.executor is not None:
                raise ValueError(
                    f"Running the data tree is pull-paradigm action, and is "
                    f"incompatible with using executors. While running "
                    f"{node.full_label}, an executor request was found on "
                    f"{n.full_label}"
                )
        nodes = {n.label: n for n in data_tree_nodes}
        disconnected_pairs, starters = set_run_connections_according_to_linear_dag(
            nodes
        )
        data_tree_starters = list(set(starters).intersection(data_tree_nodes))
        # We need these for state recovery later, even if we crash
        original_starting_nodes = self.starting_nodes
        try:
            # If the requested node is the only starter, there's nothing upstream to
            # run and we skip straight to restoring the original state
            if len(data_tree_starters) > 1 or data_tree_starters[0] is not node:
                node.signals.disconnect_run()
                # Don't let anything upstream trigger _this_ node
                self.starting_nodes = data_tree_starters
                self.run()
        finally:
            # No matter what, restore the original connections and labels afterwards
            for n in nodes.values():
                n.signals.disconnect_run()
            for c1, c2 in disconnected_pairs:
                c1.connect(c2)
            self.starting_nodes = original_starting_nodes