Source code for pyiron_workflow.storage

"""
A bit of abstraction connecting generic storage routines to nodes.
"""

from __future__ import annotations

import pickle
from abc import ABC, abstractmethod
from collections.abc import Generator
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypeAlias

import bagofholding
import cloudpickle

if TYPE_CHECKING:
    from pyiron_workflow.node import Node


class TypeNotFoundError(ImportError):
    """
    Signals that a node cannot be saved because importing its module and class
    does not give back the node's own type.
    """
class FileTypeError(TypeError):
    """
    Signals that something was found at the expected save location, but the file
    does not appear to be of the type the storage back-end expects.
    """
def delete_files_and_directories_recursively(path: Path):
    """
    Recursively delete a directory together with everything inside it.

    Does nothing if the path does not exist. Symbolic links found inside the tree
    are removed as links and are never followed, so data living outside of
    `path` is left untouched even when a link points at a directory.

    Args:
        path (Path): The path to the directory to delete.
    """
    if not path.exists():
        return

    # Walk one level at a time and let the recursion handle the depth, rather
    # than deleting entries out from under a lazily evaluated recursive glob.
    for item in path.iterdir():
        if item.is_dir() and not item.is_symlink():
            delete_files_and_directories_recursively(item)
        else:
            # Regular files, links (including broken ones and links to
            # directories) and any other non-directory entries are unlinked.
            item.unlink()
    path.rmdir()
[docs] class StorageInterface(ABC): """ Abstract base class defining the interface for saving, loading, and managing node storage. Public methods accept `**kwargs` so that specific implementations can have extra behaviour. In general, whatever can be set for the kwargs should also be able to be specified at initialization as a default, so that users can configure a storage back-end the way they want and then run it from there -- kwargs are just exposed to make special instances of use easier. """ @abstractmethod def _save(self, node: Node, filename: Path, /, *args, **kwargs): """ Save a node to file. Args: node (Node): The node to save filename (Path | None): The path to the save location (WITHOUT file extension.) **kwargs: Additional keyword arguments. """ @abstractmethod def _load(self, filename: Path, /, *args, **kwargs) -> Node: """ Instantiate a node from file. Args: filename (Path): The path to the file to load (WITHOUT file extension). **kwargs: Additional keyword arguments. Returns: Node: The node stored there. """ @abstractmethod def _has_saved_content(self, filename: Path, /, *args, **kwargs) -> bool: """ Check for a save file matching this storage interface. Args: filename (Path): The path to the file to look for (WITHOUT file extension). **kwargs: Additional keyword arguments. Returns: bool: Whether a commensurate file was found. """ @abstractmethod def _delete(self, filename: Path, /, *args, **kwargs): """ Remove an existing save-file for this backend. Args: filename (Path): The path to the file to delete (WITHOUT file extension). **kwargs: Additional keyword arguments. """
[docs] def save(self, node: Node, filename: str | Path | None = None, **kwargs): """ Save a node to file. Args: node (Node): The node to save filename (Path | None): The path to the save location (WITHOUT file extension.) **kwargs: Additional keyword arguments. """ filename = self._parse_filename( node=node if filename is None else None, filename=filename, ) filename.parent.mkdir(parents=True, exist_ok=True) try: self._save(node, filename, **kwargs) except Exception as e: raise e finally: # If nothing got written due to the exception, clean up the directory # (as long as there's nothing else in it) if not any(filename.parent.iterdir()): filename.parent.rmdir()
[docs] def load( self, node: Node | None = None, filename: str | Path | None = None, **kwargs ) -> Node: """ Load a node from a file. Args: node (Node | None): The node to load. Optional if filename is provided. filename (str | Path | None): The path to the file to load (without file extension). Uses the canonical filename based on the node's lexical path instead if this is None. **kwargs: Additional keyword arguments. Returns: Node: The loaded node. """ return self._load(self._parse_filename(node=node, filename=filename), **kwargs)
[docs] def has_saved_content( self, node: Node | None = None, filename: str | Path | None = None, **kwargs, ) -> bool: """ Check if a file has contents related to a node. Args: node (Node | None): The node to check. Optional if filename is provided. filename (str | Path | None): The path to the file to check (without file extension). Optional if the node is provided. **kwargs: Additional keyword arguments. Returns: bool: True if contents exist, False otherwise. """ return self._has_saved_content( self._parse_filename(node=node, filename=filename), **kwargs )
[docs] def delete( self, node: Node | None = None, filename: str | Path | None = None, delete_even_if_not_empty: bool = False, **kwargs, ): """ Delete a file associated with a node. Args: node (Node | None): The node whose associated file is to be deleted. Optional if filename is provided. filename (str | Path | None): The path to the file to delete (without file extension). Optional if the node is provided. **kwargs: Additional keyword arguments. """ filename = self._parse_filename(node=node, filename=filename) if self._has_saved_content(filename, **kwargs): self._delete(filename, **kwargs) if delete_even_if_not_empty: delete_files_and_directories_recursively(filename.parent) elif filename.parent.exists() and not any(filename.parent.iterdir()): filename.parent.rmdir()
def _parse_filename( self, node: Node | None, filename: str | Path | None = None ) -> Path: """ Make sure the node xor filename was provided, and if it's the node, convert it into a canonical filename by exploiting the node's lexical path. """ if node is None and filename is None: raise ValueError( "At least one of node or filename must be specified, or we can't know " "where to load from." ) elif node is None and filename is not None: if isinstance(filename, Path): return filename elif isinstance(filename, str): return Path(filename) else: raise TypeError( f"Expected filename to be str, pathlib.Path, or None, but got " f"{filename}" ) elif node is not None and filename is None: return node.as_path() / self.__class__.__name__.lower() elif node is not None and filename is not None: raise ValueError( f"Both the node ({node.full_label}) and filename ({filename}) were " f"specified for loading -- please only specify one or the other." ) else: raise AssertionError( "This is an unreachable state -- we have covered all four cases of the " "boolean `is (not) None` square." )
[docs] class PickleStorage(StorageInterface): _PICKLE = ".pckl" _CLOUDPICKLE = ".cpckl" def __init__(self, cloudpickle_fallback: bool = True): self.cloudpickle_fallback = cloudpickle_fallback def _fallback(self, cpf: bool | None) -> bool: return self.cloudpickle_fallback if cpf is None else cpf def _save( self, node: Node, filename: Path, /, cloudpickle_fallback: bool | None = None ): if not self._fallback(cloudpickle_fallback) and not node.import_ready: raise TypeNotFoundError( f"{node.label} cannot be saved with the storage interface " f"{self.__class__.__name__} because it (or one of its children) has " f"a type that cannot be imported. Is this node defined inside <locals>? " f"\n" f"Import readiness report: \n" f"{node.report_import_readiness()}" ) attacks = [(self._PICKLE, pickle.dump)] if self._fallback(cloudpickle_fallback): attacks += [(self._CLOUDPICKLE, cloudpickle.dump)] e: Exception | None = None for suffix, save_method in attacks: e = None p = filename.with_suffix(suffix) try: with open(p, "wb") as filehandle: save_method(node, filehandle) return except Exception as ee: e = ee p.unlink(missing_ok=True) if e is not None: raise e def _load( self, filename: Path, /, cloudpickle_fallback: bool | None = None ) -> Node: attacks = [(self._PICKLE, pickle.load)] if self._fallback(cloudpickle_fallback): attacks += [(self._CLOUDPICKLE, cloudpickle.load)] for suffix, load_method in attacks: p = filename.with_suffix(suffix) if p.is_file(): with open(p, "rb") as filehandle: inst = load_method(filehandle) return inst raise FileNotFoundError(f"Could not load {filename}, no such file found.") def _delete(self, filename: Path, /, cloudpickle_fallback: bool | None = None): suffixes = ( [self._PICKLE, self._CLOUDPICKLE] if self._fallback(cloudpickle_fallback) else [self._PICKLE] ) for suffix in suffixes: filename.with_suffix(suffix).unlink(missing_ok=True) def _has_saved_content( self, filename: Path, /, cloudpickle_fallback: bool | None = None ) -> bool: suffixes = ( 
[self._PICKLE, self._CLOUDPICKLE] if self._fallback(cloudpickle_fallback) else [self._PICKLE] ) return any(filename.with_suffix(suffix).exists() for suffix in suffixes)
[docs] class H5BagStorage(StorageInterface): """ Storage using :class:`bagofholding.H5Bag`. """ @staticmethod def _ensure_file_extension(filename: Path | str) -> str: suffix = ".h5" fname = str(filename) return fname if fname.endswith(suffix) else fname + suffix def _save(self, node: Node, filename: Path, /, *args, **kwargs): if not node.import_ready: raise TypeNotFoundError( f"{node.label} cannot be saved with the storage interface " f"{self.__class__.__name__} because it (or one of its children) has " f"a type that cannot be imported. Is this node defined inside <locals>? " f"\n" f"Import readiness report: \n" f"{node.report_import_readiness()}" ) bagofholding.H5Bag.save( node, self._ensure_file_extension(filename), *args, **kwargs ) def _load(self, filename: Path, /, *args, **kwargs) -> Node: bag = bagofholding.H5Bag(self._ensure_file_extension(filename)) return bag.load(*args, **kwargs) def _has_saved_content(self, filename: Path) -> bool: alleged_location = Path(self._ensure_file_extension(filename)) if not alleged_location.exists(): return False try: bagofholding.H5Bag(alleged_location) except Exception as e: raise FileTypeError( f"Found a file at {alleged_location}, but it could not be resolved as " f"a {bagofholding.H5Bag.__name__} bag." ) from e return True def _delete(self, filename: Path): Path(self._ensure_file_extension(filename)).unlink(missing_ok=True)
# The string identifiers by which one of the standard storage back-ends can be
# requested.
BackendIdentifier: TypeAlias = Literal["h5bag", "pickle"]

# Maps each identifier onto the storage class it stands for. The mapping is
# iterated when listing available back-ends, so entries are visited in the order
# they are written here.
_standard_backends: dict[BackendIdentifier, type[StorageInterface]] = {
    "h5bag": H5BagStorage,
    "pickle": PickleStorage,
}
[docs] def available_backends( backend: BackendIdentifier | StorageInterface | None = None, only_requested: bool = False, ) -> Generator[StorageInterface, None, None]: """ A generator for accessing available :class:`StorageInterface` instances, starting with the one requested. Args: backend (BackendIdentifier | StorageInterface | None): The interface to yield first. only_requested (bool): Stop after yielding whatever was specified by :param:`backend`. Yields: StorageInterface: An interface for serializing :class:`Node`. """ backend_instance = ( _standard_backends.get(backend, PickleStorage)() if isinstance(backend, str) else backend ) if backend_instance is not None: yield backend_instance if only_requested: return yield from (v() for k, v in _standard_backends.items() if k != backend)