"""Graph-based pipeline configuration and orchestration.

This module provides a graph abstraction for defining pipelines as nodes and
connections, replacing the verbose XYZStage naming convention.

Usage:
    # Declarative (TOML-like)
    graph = Graph.from_dict({
        "nodes": {
            "source": "headlines",
            "camera": {"type": "camera", "mode": "scroll"},
            "display": {"type": "display", "backend": "terminal",
                        "positioning": "mixed"},
        },
        "connections": ["source -> camera -> display"],
    })

    # Imperative
    graph = Graph()
    graph.node("source", "source", source="headlines")
    graph.node("camera", "camera", mode="scroll")
    graph.node("display", "display", backend="terminal")
    graph.chain("source", "camera", "display")
"""

from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Union


class NodeType(Enum):
    """Types of pipeline nodes."""

    SOURCE = "source"
    RENDER = "render"
    CAMERA = "camera"
    EFFECT = "effect"
    OVERLAY = "overlay"
    POSITION = "position"
    DISPLAY = "display"
    CUSTOM = "custom"


def _coerce_node_type(node_type: Union[NodeType, str]) -> NodeType:
    """Map a NodeType or its string value to a NodeType.

    The string match is case-sensitive; unrecognised values fall back to
    ``NodeType.CUSTOM`` instead of raising.
    """
    if isinstance(node_type, NodeType):
        return node_type
    try:
        return NodeType(node_type)
    except ValueError:
        return NodeType.CUSTOM


@dataclass
class Node:
    """A node in the pipeline graph."""

    name: str
    type: NodeType
    config: Dict[str, Any] = field(default_factory=dict)
    # ``enabled`` and ``optional`` are stored but not consulted anywhere in
    # this module (nor serialised by Graph.to_dict); their meaning is up to
    # whatever consumes the graph.
    enabled: bool = True
    optional: bool = False

    def __repr__(self) -> str:
        return f"Node({self.name}, type={self.type.value})"


@dataclass
class Connection:
    """A directed edge from node ``source`` to node ``target`` (by name)."""

    source: str
    target: str
    # Optional data-type label; recorded and serialised, not enforced here.
    data_type: Optional[str] = None


class _HybridMethod:
    """Descriptor letting a method be called on an instance or on the class.

    Accessed on an instance it binds normally.  Accessed on the class it binds
    to a freshly constructed instance, so ``Graph.from_dict(data)`` returns a
    new graph while ``graph.from_dict(data)`` still loads into ``graph``.
    """

    def __init__(self, func):
        self.func = func
        self.__doc__ = func.__doc__

    def __get__(self, obj, objtype=None):
        target = objtype() if obj is None else obj
        return self.func.__get__(target, objtype)


@dataclass
class Graph:
    """Pipeline graph: named nodes plus directed connections between them.

    Mutating methods return ``self`` so calls can be chained.
    """

    nodes: Dict[str, Node] = field(default_factory=dict)
    connections: List[Connection] = field(default_factory=list)

    def node(self, name: str, node_type: Union[NodeType, str], **config) -> "Graph":
        """Add (or replace) the node called *name*.

        Args:
            name: Node name, used as the key in ``self.nodes``.
            node_type: A NodeType or its string value; unknown strings become
                ``NodeType.CUSTOM``.
            **config: Stored as the node's config dict.
        """
        self.nodes[name] = Node(name=name, type=_coerce_node_type(node_type), config=config)
        return self

    def add_node(self, node: Node) -> "Graph":
        """Add (or replace) a pre-built Node, e.g. one made by the factory
        functions at the bottom of this module."""
        self.nodes[node.name] = node
        return self

    def _require_endpoints(self, source: str, target: str) -> None:
        """Raise ValueError unless both names are registered nodes."""
        if source not in self.nodes:
            raise ValueError(f"Source node '{source}' not found")
        if target not in self.nodes:
            raise ValueError(f"Target node '{target}' not found")

    def connect(
        self, source: str, target: str, data_type: Optional[str] = None
    ) -> "Graph":
        """Add a connection from *source* to *target*.

        Raises:
            ValueError: if either node name is not registered.
        """
        self._require_endpoints(source, target)
        self.connections.append(Connection(source, target, data_type))
        return self

    def chain(self, *names: str) -> "Graph":
        """Connect consecutive names: chain(a, b, c) adds a->b and b->c.

        All names are checked before anything is added, so a ValueError for a
        missing node leaves the graph unchanged.
        """
        pairs = list(zip(names, names[1:]))
        for src, dst in pairs:
            self._require_endpoints(src, dst)
        for src, dst in pairs:
            self.connect(src, dst)
        return self

    @_HybridMethod
    def from_dict(self, data: Dict[str, Any]) -> "Graph":
        """Load nodes and connections from a dictionary (TOML-compatible).

        May be called on the class (``Graph.from_dict(data)`` returns a new
        graph) or on an instance (loads into that instance and returns it).

        ``data["nodes"]`` maps names to either a string (a SOURCE node whose
        config is ``{"source": <string>}``) or a table whose optional
        ``"type"`` key selects the NodeType (default ``"custom"``) and whose
        other keys become the config.

        ``data["connections"]`` is a list of ``"a -> b [-> c ...]"`` strings or
        ``{"source": ..., "target": ..., "data_type": ...}`` tables
        (``data_type`` optional).

        Raises:
            TypeError: for a node or connection entry of an unsupported type.
            ValueError: for a connection string without ``->`` or one that
                names an unknown node.
        """
        for name, node_info in data.get("nodes", {}).items():
            if isinstance(node_info, str):
                # Simple format: "source": "headlines"
                self.nodes[name] = Node(
                    name=name, type=NodeType.SOURCE, config={"source": node_info}
                )
            elif isinstance(node_info, dict):
                # Full format: {"type": "camera", "mode": "scroll"}.  Build the
                # Node directly so config keys such as "name" cannot collide
                # with node()'s own parameters.
                config = {k: v for k, v in node_info.items() if k != "type"}
                node_type = _coerce_node_type(node_info.get("type", "custom"))
                self.nodes[name] = Node(name=name, type=node_type, config=config)
            else:
                raise TypeError(
                    f"Node '{name}' must be a string or a table, "
                    f"got {type(node_info).__name__}"
                )

        for conn in data.get("connections", []):
            if isinstance(conn, str):
                names = [part.strip() for part in conn.split("->")]
                if len(names) < 2:
                    raise ValueError(f"Connection {conn!r} must use 'a -> b' syntax")
                self.chain(*names)
            elif isinstance(conn, dict):
                self.connect(conn["source"], conn["target"], conn.get("data_type"))
            else:
                raise TypeError(
                    f"Connection must be a string or a table, got {type(conn).__name__}"
                )
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Convert the graph to a dictionary that from_dict() can read back.

        ``data_type`` is emitted only when set.  Node ``enabled``/``optional``
        flags are not included, and nodes whose type fell back to CUSTOM are
        written as ``"custom"``.
        """
        connections = []
        for conn in self.connections:
            entry = {"source": conn.source, "target": conn.target}
            if conn.data_type is not None:
                entry["data_type"] = conn.data_type
            connections.append(entry)
        return {
            "nodes": {
                name: {"type": node.type.value, **node.config}
                for name, node in self.nodes.items()
            },
            "connections": connections,
        }

    def validate(self) -> List[str]:
        """Return a list of structural errors (empty if none were found).

        Reports every node that appears in no connection, then at most one
        cycle, naming a node that lies on it.
        """
        errors: List[str] = []

        connected_nodes = {conn.source for conn in self.connections}
        connected_nodes.update(conn.target for conn in self.connections)
        for node_name in self.nodes:
            if node_name not in connected_nodes:
                errors.append(f"Node '{node_name}' is not connected")

        cycle_node = self._find_cycle_node()
        if cycle_node is not None:
            errors.append(f"Cycle detected involving node '{cycle_node}'")
        return errors

    def _find_cycle_node(self) -> Optional[str]:
        """Return a node lying on some cycle, or None if there is no cycle.

        Iterative depth-first search (so long chains cannot hit the recursion
        limit) over an adjacency map built once from ``self.connections``.
        Roots are tried in ``self.nodes`` order.
        """
        adjacency: Dict[str, List[str]] = {}
        for conn in self.connections:
            adjacency.setdefault(conn.source, []).append(conn.target)

        on_path = set()  # nodes on the current DFS path
        finished = set()  # nodes whose descendants are fully explored
        for root in self.nodes:
            if root in finished:
                continue
            on_path.add(root)
            stack = [(root, iter(adjacency.get(root, ())))]
            while stack:
                current, children = stack[-1]
                for child in children:
                    if child in on_path:
                        # Back edge to an ancestor: child sits on a cycle.
                        return child
                    if child not in finished:
                        on_path.add(child)
                        stack.append((child, iter(adjacency.get(child, ()))))
                        break
                else:
                    # All children explored; retire this node.
                    stack.pop()
                    on_path.discard(current)
                    finished.add(current)
        return None

    def __repr__(self) -> str:
        nodes_str = ", ".join(str(n) for n in self.nodes.values())
        return f"Graph(nodes=[{nodes_str}])"


# Factory functions for common node types.  They return standalone Nodes;
# register them with Graph.add_node().


def source(name: str, source_type: str, **config) -> Node:
    """Create a SOURCE node with config ``{"source": source_type, **config}``."""
    return Node(name, NodeType.SOURCE, {"source": source_type, **config})


def camera(name: str, mode: str = "scroll", **config) -> Node:
    """Create a CAMERA node with config ``{"mode": mode, **config}``."""
    return Node(name, NodeType.CAMERA, {"mode": mode, **config})


def display(name: str, backend: str = "terminal", **config) -> Node:
    """Create a DISPLAY node with config ``{"backend": backend, **config}``."""
    return Node(name, NodeType.DISPLAY, {"backend": backend, **config})


def effect(name: str, effect_name: str, **config) -> Node:
    """Create an EFFECT node with config ``{"effect": effect_name, **config}``."""
    return Node(name, NodeType.EFFECT, {"effect": effect_name, **config})