Source code for pyiron_workflow.mixin.injection

"""
We want to be able to operate freely on (single-output) nodes and channels as though
they were regular objects, but still retain our graph paradigm.

To accomplish this, we extend :class:`pyiron_workflow.channels.OutputData` (overloading
its operators) so that it can inject new nodes into the graph dynamically.
"""

from __future__ import annotations

import abc
from typing import TYPE_CHECKING, Any

from pyiron_workflow.channels import OutputData
from pyiron_workflow.data import NOT_DATA
from pyiron_workflow.io import GenericOutputs
from pyiron_workflow.mixin.has_interface_mixins import HasChannel

if TYPE_CHECKING:
    from pyiron_workflow.node import Node


class OutputDataWithInjection(OutputData):
    """
    Output data that must have a :class:`pyiron_workflow.node.Node` for its
    :attr:`owner`, and which is able to inject new nodes into that owner's graph,
    e.g. to accomplish operations on the channel.

    This class facilitates many (but not all) python operators by injecting a new
    node to perform that operation. Where the operator is not supported, we try to
    support using the operator's dunder name as a method, e.g. `==` gives us trouble
    with hashing, but this exploits the dunder method `.__eq__(other)`, so you can
    call `.eq(other)` on output data. The same pattern provides `.bool()`, `.len()`,
    `.contains(other)`, `.int()`, and `.float()`.

    These new nodes are instructed to run at the end of instantiation, but this
    fails cleanly in case they are not ready. This is intended to accommodate two
    likely scenarios: if you're injecting a node on top of an existing result you
    probably want the injection result to also be immediately available, but if
    you're injecting it at the end of something that hasn't run yet you don't want
    to see an error.

    Injected nodes are re-used: performing the same operation with the same
    arguments on the same channel returns the node already present among the
    owner's parent's children instead of creating a new one. The label used for
    that lookup is a stable digest (independent of Python's per-process string-hash
    salting). Non-channel arguments are labelled by their ``repr``, so that e.g.
    ``channel[1]`` and ``channel["1"]`` do not map onto the same injected node.
    """

    def __init__(
        self,
        label: str,
        owner: Node,
        default: Any | None = NOT_DATA,
        type_hint: Any | None = None,
        strict_hints: bool = True,
        value_receiver: OutputData | None = None,
    ):
        # Override parent method to give the new owner type hint
        super().__init__(
            label=label,
            owner=owner,
            default=default,
            type_hint=type_hint,
            strict_hints=strict_hints,
            value_receiver=value_receiver,
        )

    @staticmethod
    def _other_label(other) -> str:
        """
        Label a single injection argument for use in an injected node's label.

        Channel-like arguments are identified by their channel's scoped label.
        Anything else is identified by its ``repr``. Unlike ``str``, this keeps
        values such as ``1`` and ``"1"`` (or ``True`` and ``"True"``) distinct.

        Args:
            other: An argument being passed to an injected node.

        Returns:
            (str): A label fragment for that argument.
        """
        if isinstance(other, HasChannel):
            return other.channel.scoped_label
        return repr(other)

    def _get_injection_label(self, injection_class, *args) -> str:
        """
        Build the child label under which an injected node is stored and looked up.

        Args:
            injection_class (type[Node]): The class of the node being injected.
            *args: The arguments (other than this channel) passed to that node.

        Returns:
            (str): ``"injected_<ClassName>_<digest>"``. The digest is a stable hash
                of this channel's scoped label, the class name, and the argument
                labels.
        """
        import hashlib  # local import, matching this module's lazy-import style

        other_labels = "_".join(self._other_label(other) for other in args)
        suffix = f"_{other_labels}" if args else ""
        nominal_label = f"{self.scoped_label}_{injection_class.__name__}{suffix}"
        # The builtin `hash` of a str is salted per interpreter process
        # (PYTHONHASHSEED). That would give the same operation a different label
        # in a new process and defeat the re-use lookup in `_node_injection`, so
        # use a deterministic digest instead. Its hex output is alphanumeric, so
        # the label stays identifier-friendly.
        digest = hashlib.blake2b(
            nominal_label.encode("utf-8", "surrogatepass"), digest_size=8
        ).hexdigest()
        return f"injected_{injection_class.__name__}_{digest}"

    def _node_injection(self, injection_class, *args, inject_self=True):
        """
        Create a new node with the same parent as this channel's node, and feed it
        arguments. Alternatively, load such a node if it already exists on the
        parent; the lookup uses a label generated from the injected node class and
        the arguments.

        Args:
            injection_class (type[Node]): The new node class to instantiate
            *args: Any arguments for that function node
            inject_self (bool): Whether to pre-pend the args with self. (Default is
                True.)

        Returns:
            (Node): The instantiated or loaded node.
        """
        label = self._get_injection_label(injection_class, *args)
        try:
            # First check if the node already exists
            return self.owner.parent.children[label]
        except (AttributeError, KeyError):
            # Fall back on creating a new node in case parent is None or the node
            # doesn't exist yet
            node_args = (self, *args) if inject_self else args
            return injection_class(
                *node_args, parent=self.owner.parent, label=label, autorun=True
            )

    # We don't wrap *all* the operators, because you might really want the string
    # or hash or whatever of the actual channel. But we do wrap all the dunder
    # methods that should be unambiguously referring to an operation on values

    def __getattr__(self, name):
        from pyiron_workflow.nodes.standard import GetAttr

        if name == "to_hdf":
            raise AttributeError(
                "This is just a failsafe to protect us against other elements of the "
                "pyiron ecosystem (pyiron_base's DataContainer) running a "
                "`hasattr('to_hdf')` check on us and accidentally injecting a new "
                "getattr node."
            )
        if name.startswith("_"):
            raise AttributeError(
                f"{self.full_label} ({OutputDataWithInjection.__name__}) tried to "
                f"inject on the attribute {name}, but injecting on private attributes "
                f"is forbidden -- if you really need it create a {GetAttr.__name__} "
                f"node manually."
            )
        if name == "shape":
            # The two sentences were previously concatenated without a space
            raise AttributeError(
                "This is a hack to stop jupyter notebook cells from asking for a "
                "`shape`. If you are _actually_ trying to get delayed access to a "
                "`shape` field on your output, you'll need to manually add an "
                "attribute access node to do it."
            )
        return self._node_injection(GetAttr, name)

    def __getitem__(self, item):
        # Break slices into deeper injections, if any slice arguments are channel-like
        if isinstance(item, slice) and any(
            isinstance(slice_input, HasChannel)
            for slice_input in [item.start, item.stop, item.step]
        ):
            from pyiron_workflow.nodes.standard import Slice

            item = self._node_injection(
                Slice, item.start, item.stop, item.step, inject_self=False
            )

        from pyiron_workflow.nodes.standard import GetItem

        return self._node_injection(GetItem, item)

    def __lt__(self, other):
        from pyiron_workflow.nodes.standard import LessThan

        return self._node_injection(LessThan, other)

    def __le__(self, other):
        from pyiron_workflow.nodes.standard import LessThanEquals

        return self._node_injection(LessThanEquals, other)

    def eq(self, other):
        from pyiron_workflow.nodes.standard import Equals

        return self._node_injection(Equals, other)

    def __ne__(self, other):
        from pyiron_workflow.nodes.standard import NotEquals

        return self._node_injection(NotEquals, other)

    def __gt__(self, other):
        from pyiron_workflow.nodes.standard import GreaterThan

        return self._node_injection(GreaterThan, other)

    def __ge__(self, other):
        from pyiron_workflow.nodes.standard import GreaterThanEquals

        return self._node_injection(GreaterThanEquals, other)

    def bool(self):
        from pyiron_workflow.nodes.standard import Bool

        return self._node_injection(Bool)

    def len(self):
        from pyiron_workflow.nodes.standard import Length

        return self._node_injection(Length)

    def contains(self, other):
        from pyiron_workflow.nodes.standard import Contains

        return self._node_injection(Contains, other)

    def __add__(self, other):
        from pyiron_workflow.nodes.standard import Add

        return self._node_injection(Add, other)

    def __sub__(self, other):
        from pyiron_workflow.nodes.standard import Subtract

        return self._node_injection(Subtract, other)

    def __mul__(self, other):
        from pyiron_workflow.nodes.standard import Multiply

        return self._node_injection(Multiply, other)

    def __rmul__(self, other):
        from pyiron_workflow.nodes.standard import RightMultiply

        return self._node_injection(RightMultiply, other)

    def __matmul__(self, other):
        from pyiron_workflow.nodes.standard import MatrixMultiply

        return self._node_injection(MatrixMultiply, other)

    def __truediv__(self, other):
        from pyiron_workflow.nodes.standard import Divide

        return self._node_injection(Divide, other)

    def __floordiv__(self, other):
        from pyiron_workflow.nodes.standard import FloorDivide

        return self._node_injection(FloorDivide, other)

    def __mod__(self, other):
        from pyiron_workflow.nodes.standard import Modulo

        return self._node_injection(Modulo, other)

    def __pow__(self, other):
        from pyiron_workflow.nodes.standard import Power

        return self._node_injection(Power, other)

    def __and__(self, other):
        from pyiron_workflow.nodes.standard import And

        return self._node_injection(And, other)

    def __xor__(self, other):
        from pyiron_workflow.nodes.standard import XOr

        return self._node_injection(XOr, other)

    def __or__(self, other):
        from pyiron_workflow.nodes.standard import Or

        return self._node_injection(Or, other)

    def __neg__(self):
        from pyiron_workflow.nodes.standard import Negative

        return self._node_injection(Negative)

    def __pos__(self):
        from pyiron_workflow.nodes.standard import Positive

        return self._node_injection(Positive)

    def __abs__(self):
        from pyiron_workflow.nodes.standard import Absolute

        return self._node_injection(Absolute)

    def __invert__(self):
        from pyiron_workflow.nodes.standard import Invert

        return self._node_injection(Invert)

    def int(self):
        from pyiron_workflow.nodes.standard import Int

        return self._node_injection(Int)

    def float(self):
        from pyiron_workflow.nodes.standard import Float

        return self._node_injection(Float)

    def __round__(self):
        from pyiron_workflow.nodes.standard import Round

        return self._node_injection(Round)
class OutputsWithInjection(GenericOutputs[OutputDataWithInjection]):
    """
    Output panel whose channels are :class:`OutputDataWithInjection` instances.
    """

    @property
    def _channel_class(self) -> type[OutputDataWithInjection]:
        """The concrete channel class reported by this panel."""
        channel_cls = OutputDataWithInjection
        return channel_cls
class InjectsOnChannel(HasChannel, abc.ABC):
    """
    Mixin that forwards attribute access and operators to an injecting channel.

    Implementers provide :attr:`channel` (an :class:`OutputDataWithInjection`).
    Attribute lookups that fail normally are first offered to any cooperative
    ``__getattr__`` later in the MRO, and then retried on the channel. Every
    operator and operator-like method below hands its arguments straight to the
    channel method of the same name and returns whatever that produces.
    """

    @property
    @abc.abstractmethod
    def channel(self) -> OutputDataWithInjection: ...

    def __getattr__(self, item):
        try:
            return super().__getattr__(item)
        except AttributeError:
            # Nothing else in the MRO resolved the name: defer to the channel
            return getattr(self.channel, item)

    # --- item access -------------------------------------------------------------

    def __getitem__(self, item):
        return self.channel.__getitem__(item)

    # --- comparisons -------------------------------------------------------------

    def __lt__(self, other):
        return self.channel.__lt__(other)

    def __le__(self, other):
        return self.channel.__le__(other)

    def eq(self, other):
        return self.channel.eq(other)

    def __ne__(self, other):
        return self.channel.__ne__(other)

    def __gt__(self, other):
        return self.channel.__gt__(other)

    def __ge__(self, other):
        return self.channel.__ge__(other)

    # --- builtin-like helpers ----------------------------------------------------

    def bool(self):
        return self.channel.bool()

    def len(self):
        return self.channel.len()

    def contains(self, other):
        return self.channel.contains(other)

    # --- binary arithmetic and bitwise operators ---------------------------------

    def __add__(self, other):
        return self.channel.__add__(other)

    def __sub__(self, other):
        return self.channel.__sub__(other)

    def __mul__(self, other):
        return self.channel.__mul__(other)

    def __rmul__(self, other):
        return self.channel.__rmul__(other)

    def __matmul__(self, other):
        return self.channel.__matmul__(other)

    def __truediv__(self, other):
        return self.channel.__truediv__(other)

    def __floordiv__(self, other):
        return self.channel.__floordiv__(other)

    def __mod__(self, other):
        return self.channel.__mod__(other)

    def __pow__(self, other):
        return self.channel.__pow__(other)

    def __and__(self, other):
        return self.channel.__and__(other)

    def __xor__(self, other):
        return self.channel.__xor__(other)

    def __or__(self, other):
        return self.channel.__or__(other)

    # --- unary operators ---------------------------------------------------------

    def __neg__(self):
        return self.channel.__neg__()

    def __pos__(self):
        return self.channel.__pos__()

    def __abs__(self):
        return self.channel.__abs__()

    def __invert__(self):
        return self.channel.__invert__()

    # --- numeric conversions -----------------------------------------------------

    def int(self):
        return self.channel.int()

    def float(self):
        return self.channel.float()

    def __round__(self):
        return self.channel.__round__()