"""
A base class for macro nodes, which are composite like workflows but have a static
interface and are not intended to be internally modified after instantiation.
"""
from __future__ import annotations
import re
from abc import ABC, abstractmethod
from collections.abc import Callable
from inspect import getsource
from typing import TYPE_CHECKING, Self, TypeAlias
from pyiron_snippets.factory import classfactory
from pyiron_workflow.channels import OutputData
from pyiron_workflow.io import Inputs
from pyiron_workflow.mixin.injection import OutputsWithInjection
from pyiron_workflow.mixin.preview import ScrapesIO
from pyiron_workflow.nodes.composite import Composite
from pyiron_workflow.nodes.multiple_distpatch import dispatch_output_labels
from pyiron_workflow.nodes.static_io import StaticNode
if TYPE_CHECKING:
from pyiron_workflow.channels import Channel
GraphOutput: TypeAlias = StaticNode | OutputData
[docs]
class Macro(Composite, StaticNode, ScrapesIO, ABC):
"""
A macro is a composite node that holds a graph with a fixed interface, like a
pre-populated workflow that is the same every time you instantiate it.
At instantiation, the macro uses a provided callable to build and wire the graph,
then builds a static IO interface for this graph.
This callable must use the macro object itself as the first argument (e.g. adding
nodes to it).
The provided callable may optionally specify further args and kwargs; these are
used to pre-populate the macro with :class:`UserInput` nodes, although they may
later be trimmed if the IO can be connected directly to child node IO without any
loss of functionality.
This can be especially helpful when more than one child node needs access to the
same input value.
Similarly, the callable may return any number of child nodes' output channels (or
the node itself in the case of single-output nodes) as long as a commensurate
number of labels for these outputs were provided to the class constructor.
These function-like definitions of the graph creator callable can be used
to build only input xor output, or both together.
Macro input channel labels are scraped from the signature of the graph creator;
for output, output labels can be provided explicitly as a class attribute or, as a
fallback, they are scraped from the graph creator code return statement (stripping
off the "{first argument}.", where {first argument} is whatever the name of the
first argument is.
Macro IO is _value linked_ to the child IO, so that their values stay synchronized,
but the child nodes of a macro form an isolated sub-graph.
As with function nodes, subclasses of :class:`Macro` may define a method for
creating the graph.
As with :class:`Workflow``, all DAG macros can determine their execution flow
automatically, if you have cycles in your data flow, or otherwise want more control
over the execution, all you need to do is specify the `node.signals.input.run`
connections and :attr:`starting_nodes` list yourself.
If only _one_ of these is specified, you'll get an error, but if you've provided
both then no further checks of their validity/reasonableness are performed, so be
careful.
Unlike :class:`Workflow`, this execution flow automation is set up once at
instantiation;
If the macro is modified post-facto, you may need to manually re-invoke
:meth:`configure_graph_execution`.
Examples:
Let's consider the simplest case of macros that just consecutively add 1 to
their input:
>>> from pyiron_workflow import as_macro_node, macro_node
>>> def add_one(x):
... result = x + 1
... return result
>>>
>>> def add_three_macro(self, one__x):
... self.one = self.create.function_node(add_one, x=one__x)
... self.two = self.create.function_node(add_one, self.one)
... self.three = self.create.function_node(add_one, self.two)
... self.one >> self.two >> self.three
... self.starting_nodes = [self.one]
... return self.three
In this case we had _no need_ to specify the execution order and starting nodes
--it's just an extremely simple DAG after all! -- but it's done here to
demonstrate the syntax.
We can make a macro by passing this graph-building function (that takes a macro
as its first argument, i.e. `self` from the macro's perspective) to the :class:`Macro`
class. Then, we can use it like a regular node! Just like a workflow, the
io is constructed from unconnected owned-node IO by combining node and channel
labels.
>>> macro = macro_node(add_three_macro, output_labels="three__result")
>>> out = macro(one__x=3)
>>> out.three__result
6
We can also nest macros, rename their IO, and provide access to
internally-connected IO by inputs and outputs maps:
>>> def nested_macro(self, inp):
... self.a = self.create.function_node(add_one, x=inp)
... self.b = self.create.macro_node(
... add_three_macro, one__x=self.a, output_labels="three__result"
... )
... self.c = self.create.function_node(add_one, x=self.b)
... return self.c, self.b
>>>
>>> macro = macro_node(
... nested_macro, output_labels=("out", "intermediate")
... )
>>> macro(inp=1)
{'out': 6, 'intermediate': 5}
Macros and workflows automatically generate execution flows when their data
is acyclic.
Let's build a simple macro with two independent tracks:
>>> def modified_flow_macro(self, a__x=0, b__x=0):
... self.a = self.create.function_node(add_one, x=a__x)
... self.b = self.create.function_node(add_one, x=b__x)
... self.c = self.create.function_node(add_one, x=self.b)
... return self.a, self.c
>>>
>>> m = macro_node(modified_flow_macro, output_labels=("a", "c"))
>>> m(a__x=1, b__x=2)
{'a': 2, 'c': 4}
We can override which nodes get used to start by specifying the
:attr:`starting_nodes` property and (if necessary) reconfiguring the execution
signals.
Care should be taken here, as macro nodes may be creating extra input
nodes that need to be considered.
It's advisable to use :meth:`draw()` or to otherwise inspect the macro's
children and their connections before manually updating execution flows.
Let's use this and then observe how the `a` sub-node no longer gets run:
>>> _ = m.disconnect_run()
>>> m.starting_nodes = [m.b]
>>> _ = m.b >> m.c
>>> m(a__x=1000, b__x=2000)
{'a': 2, 'c': 2002}
(The `_` is just to catch and ignore output for the doctest, you don't
typically need this.)
Note how the `a` node is no longer getting run, so the output is not updated!
Manually controlling execution flow is necessary for cyclic graphs (cf. the
while loop meta-node), but best to avoid when possible as it's easy to miss
intended connections in complex graphs.
If there's a particular macro we're going to use again and again, we might want
to consider making a new class for it using the decorator, just like we do for
function nodes. If no output labels are explicitly provided as arguments to the
decorator itself, these are scraped from the function return value, just like
for function nodes (except the initial `macro` (or `self` or whatever the first
argument is named) on any return values is ignored):
>>> from pyiron_workflow.api import Macro
>>> @Macro.wrap.as_macro_node
... def AddThreeMacro(self, x):
... add_three_macro(self, one__x=x)
... # We could also simply have decorated that function to begin with
... return self.three
>>>
>>> macro = AddThreeMacro()
>>> macro(x=0).three
3
Alternatively (and not recommended) is to make a new child class of
:class:`Macro` that overrides the :meth:`graph_creator` arg such that
the same graph is always created.
>>> class AddThreeMacro(Macro):
... _output_labels = ["three"]
...
... def graph_creator(self, x):
... add_three_macro(self, one__x=x)
... return self.three
>>>
>>> macro = AddThreeMacro()
>>> macro(x=0).three
3
We can also modify an existing macro at runtime by replacing nodes within it, as
long as the replacement has fully compatible IO. There are three syntacic ways
to do this. Let's explore these by going back to our `add_three_macro` and
replacing each of its children with a node that adds 2 instead of 1.
It's possible for the macro to hold nodes which are not publicly exposed for
data and signal connections, but which will still internally execute and store
data, e.g.:
>>> @Macro.wrap.as_macro_node("lout", "n_plus_2")
... def LikeAFunction(self, lin: list, n: int = 1):
... self.plus_two = n + 2
... self.sliced_list = lin[n:self.plus_two]
... self.double_fork = 2 * n
... return self.sliced_list, self.plus_two.channel
>>>
>>> like_functions = LikeAFunction(lin=[1,2,3,4,5,6], n=3)
>>> sorted(like_functions().items())
[('lout', [4, 5]), ('n_plus_2', 5)]
>>> like_functions.double_fork.value
6
"""
def _setup_node(self) -> None:
super()._setup_node()
ui_nodes = self._prepopulate_ui_nodes_from_graph_creator_signature()
returned_has_channel_objects = self.graph_creator(*ui_nodes)
if returned_has_channel_objects is None:
returned_has_channel_objects = ()
elif isinstance(returned_has_channel_objects, GraphOutput):
returned_has_channel_objects = (returned_has_channel_objects,)
for node in ui_nodes:
self.inputs[node.label].value_receiver = node.inputs.user_input
for graph_output, output_channel_label in zip(
returned_has_channel_objects,
() if self._output_labels is None else self._output_labels,
strict=False,
):
graph_output.channel.value_receiver = self.outputs[output_channel_label]
remaining_ui_nodes = self._purge_single_use_ui_nodes(ui_nodes)
self._configure_graph_execution(remaining_ui_nodes)
[docs]
@abstractmethod
def graph_creator(
self: Self, *args, **kwargs
) -> GraphOutput | tuple[GraphOutput, ...] | None:
"""Build the graph the node will run."""
@classmethod
def _io_defining_function(cls) -> Callable:
return cls.graph_creator
_io_defining_function_uses_self = True
@classmethod
def _scrape_output_labels(cls):
scraped_labels = super()._scrape_output_labels()
if scraped_labels is not None:
# Strip off the first argument, e.g. self.foo just becomes foo
self_argument = list(cls._get_input_args().keys())[0]
cleaned_labels = [
re.sub(r"^" + re.escape(f"{self_argument}."), "", label)
for label in scraped_labels
]
if any("." in label for label in cleaned_labels):
raise ValueError(
f"Tried to scrape cleaned labels for {cls.__name__}, but at least "
f"one of {cleaned_labels} still contains a '.' -- please provide "
f"explicit labels"
)
return cleaned_labels
else:
return scraped_labels
def _prepopulate_ui_nodes_from_graph_creator_signature(self):
ui_nodes = []
for label, (type_hint, default) in self.preview_inputs().items():
n = self.create.std.UserInput(
default,
label=label,
parent=self,
)
n.inputs.user_input.type_hint = type_hint
ui_nodes.append(n)
return tuple(ui_nodes)
def _purge_single_use_ui_nodes(self, ui_nodes):
"""
We (may) create UI nodes based on the :meth:`graph_creator` signature;
If these are connected to only a single node actually defined in the creator,
they are superfluous, and we can remove them -- linking the macro input
directly to the child node input.
"""
remaining_ui_nodes = list(ui_nodes)
for macro_input in self.inputs:
target_node = macro_input.value_receiver.owner
if (
target_node in ui_nodes # Value link is a UI node
and target_node.channel.value_receiver is None # That doesn't forward
# its value directly to the output
and len(target_node.channel.connections) <= 1 # And isn't forked to
# multiple children
):
receiver = (
target_node.channel.connections[0]
if len(target_node.channel.connections) == 1
else None
)
self.remove_child(target_node)
remaining_ui_nodes.remove(target_node)
if receiver is not None:
macro_input.value_receiver = receiver
return tuple(remaining_ui_nodes)
@property
def inputs(self) -> Inputs:
return self._inputs
@property
def outputs(self) -> OutputsWithInjection:
return self._outputs
def _parse_remotely_executed_self(self, other_self):
local_connection_data = [
[(c, c.label, c.connections) for c in io_panel]
for io_panel in [
self.inputs,
self.outputs,
self.signals.input,
self.signals.output,
]
]
super()._parse_remotely_executed_self(other_self)
for old_data, io_panel in zip(
local_connection_data,
[self.inputs, self.outputs, self.signals.input, self.signals.output],
strict=False,
# Get fresh copies of the IO panels post-update
):
for original_channel, label, connections in old_data:
new_channel = io_panel[label] # Fetch it from the fresh IO panel
new_channel.connections = connections
for other_channel in connections:
self._replace_connection(
other_channel, original_channel, new_channel
)
@staticmethod
def _replace_connection(
channel: Channel, old_connection: Channel, new_connection: Channel
):
"""Brute-force replace an old connection in a channel with a new one"""
channel.connections = [
c if c is not old_connection else new_connection for c in channel
]
def _configure_graph_execution(self, ui_nodes):
run_signals = self.disconnect_run()
has_signals = len(run_signals) > 0
has_starters = len(self.starting_nodes) > 0
if has_signals and has_starters:
# Assume the user knows what they're doing
self._reconnect_run(run_signals)
# Then put the UI upstream of the original starting nodes
for n in self.starting_nodes:
n << ui_nodes
self.starting_nodes = ui_nodes if len(ui_nodes) > 0 else self.starting_nodes
elif not has_signals and not has_starters:
# Automate construction of the execution graph
self.set_run_signals_to_dag_execution()
else:
raise ValueError(
f"The macro {self.full_label} has {len(run_signals)} run signals "
f"internally and {len(self.starting_nodes)} starting nodes. Either "
f"the entire execution graph must be specified manually, or both run "
f"signals and starting nodes must be left entirely unspecified for "
f"automatic construction of the execution graph."
)
def _reconnect_run(self, run_signal_pairs_to_restore):
self.disconnect_run()
for pairs in run_signal_pairs_to_restore:
pairs[0].connect(pairs[1])
@property
def _input_value_links(self):
"""
Value connections between child output and macro in string representation based
on labels.
The string representation helps storage, and having it as a property ensures
the name is protected.
"""
return [
(c.label, (c.value_receiver.owner.label, c.value_receiver.label))
for c in self.inputs
]
@property
def _output_value_links(self):
"""
Value connections between macro and child input in string representation based
on labels.
The string representation helps storage, and having it as a property ensures
the name is protected.
"""
return [
((c.owner.label, c.label), c.value_receiver.label)
for child in self
for c in child.outputs
if c.value_receiver is not None
]
def __getstate__(self):
state = super().__getstate__()
state["_input_value_links"] = self._input_value_links
state["_output_value_links"] = self._output_value_links
return state
def __setstate__(self, state):
# Purge value links from the state
input_links = state.pop("_input_value_links")
output_links = state.pop("_output_value_links")
super().__setstate__(state)
# Re-forge value links
for inp, (child, child_inp) in input_links:
self.inputs[inp].value_receiver = self.children[child].inputs[child_inp]
for (child, child_out), out in output_links:
self.children[child].outputs[child_out].value_receiver = self.outputs[out]
@classmethod
def _extra_info(cls) -> str:
return getsource(cls.graph_creator)
@classfactory
def macro_node_factory(
graph_creator: Callable,
validate_output_labels: bool,
use_cache: bool = True,
/,
*output_labels: str,
) -> type[Macro]:
"""
Create a new :class:`Macro` subclass using the given graph creator function.
Args:
graph_creator (callable): Function to create the graph for this subclass of
:class:`Macro`.
validate_output_labels (bool): Whether to validate the output labels against
the return values of the wrapped function.
use_cache (bool): Whether nodes of this type should default to caching their
values.
output_labels (tuple[str, ...]): Optional labels for the :class:`Macro`'s
outputs.
Returns:
type[Macro]: A new :class:`Macro` subclass.
"""
return ( # type: ignore[return-value]
graph_creator.__name__,
(Macro,), # Define parentage
{
"graph_creator": graph_creator,
"__module__": graph_creator.__module__,
"__qualname__": graph_creator.__qualname__,
"_output_labels": None if len(output_labels) == 0 else output_labels,
"_validate_output_labels": validate_output_labels,
"__doc__": Macro._io_defining_documentation(graph_creator, "graph_creator"),
"use_cache": use_cache,
},
{},
)
@dispatch_output_labels
def as_macro_node(
*output_labels: str, validate_output_labels: bool = True, use_cache: bool = True
):
"""
Decorator to convert a function into a :class:`Macro` node.
Args:
*output_labels (str): Optional labels for the :class:`Macro`'s outputs.
validate_output_labels (bool): Whether to validate the output labels.
use_cache (bool): Whether nodes of this type should default to caching their
values. (Default is True.)
Returns:
callable: A decorator that converts a function into a Macro node.
"""
def decorator(graph_creator):
macro_node_factory.clear(graph_creator.__name__) # Force a fresh class
factory_made = macro_node_factory(
graph_creator, validate_output_labels, use_cache, *output_labels
)
factory_made._reduce_imports_as = (
graph_creator.__module__,
graph_creator.__qualname__,
)
factory_made.preview_io()
return factory_made
return decorator
[docs]
def macro_node(
graph_creator: Callable[..., GraphOutput | tuple[GraphOutput, ...] | None],
*node_args,
output_labels: str | tuple[str, ...] | None = None,
validate_output_labels: bool = True,
use_cache: bool = True,
**node_kwargs,
):
"""
Create and return a :class:`Macro` node instance using the given node function.
Args:
graph_creator (callable): Function to create the graph for the :class:`Macro`.
node_args: Positional arguments for the :class:`Macro` initialization --
parsed as node input data.
output_labels (str | tuple[str, ...] | None): Labels for the :class:`Macro`'s
outputs. Default is None, which scrapes these from the return statement in
the decorated function's source code.
validate_output_labels (bool): Whether to validate the output labels. Defaults
to True.
use_cache (bool): Whether this node should default to caching its values.
(Default is True.)
node_kwargs: Keyword arguments for the :class:`Macro` initialization --
parsed as node input data when the keyword matches an input channel.
Returns:
Macro: An instance of the :class:`Macro` subclass.
"""
if output_labels is None:
output_labels = ()
elif isinstance(output_labels, str):
output_labels = (output_labels,)
macro_node_factory.clear(graph_creator.__name__) # Force a fresh class
factory_made = macro_node_factory(
graph_creator, validate_output_labels, use_cache, *output_labels
)
factory_made.preview_io()
return factory_made(*node_args, **node_kwargs)