"""Graph-based pipeline configuration and orchestration.

This module provides a graph abstraction for defining pipelines as nodes
and connections, replacing the verbose XYZStage naming convention.

Usage:
    # Declarative (TOML-like)
    graph = Graph().from_dict({
        "nodes": {
            "source": "headlines",
            "camera": {"type": "camera", "mode": "scroll"},
            "display": {"type": "display", "backend": "terminal"}
        },
        "connections": ["source -> camera -> display"]
    })

    # Imperative
    graph = Graph()
    graph.node("source", "source", source="headlines")
    graph.node("camera", "camera", mode="scroll")
    graph.node("display", "display", backend="terminal")
    graph.chain("source", "camera", "display")
"""

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from enum import Enum


class NodeType(Enum):
    """Types of pipeline nodes."""

    SOURCE = "source"
    RENDER = "render"
    CAMERA = "camera"
    EFFECT = "effect"
    OVERLAY = "overlay"
    POSITION = "position"
    DISPLAY = "display"
    CUSTOM = "custom"


@dataclass
class Node:
    """A node in the pipeline graph.

    The explicit ``__repr__`` below is kept by ``@dataclass`` (it only
    generates one when the class does not define its own).
    """

    name: str
    type: NodeType
    config: Dict[str, Any] = field(default_factory=dict)
    enabled: bool = True
    optional: bool = False

    def __repr__(self) -> str:
        return f"Node({self.name}, type={self.type.value})"


@dataclass
class Connection:
    """A directed connection between two nodes, referenced by name."""

    source: str
    target: str
    data_type: Optional[str] = None  # Optional data type constraint


@dataclass
class Graph:
    """Pipeline graph representation.

    Nodes are stored by name; connections are stored in insertion order.
    The mutating methods return ``self`` so calls can be chained.
    """

    nodes: Dict[str, Node] = field(default_factory=dict)
    connections: List[Connection] = field(default_factory=list)

    def node(self, name: str, node_type: Union[NodeType, str], **config) -> "Graph":
        """Add a node to the graph, replacing any node with the same name.

        Args:
            name: Key under which the node is stored.
            node_type: A ``NodeType`` or its string value. A string that is
                not a valid ``NodeType`` value falls back to
                ``NodeType.CUSTOM``.
            **config: Stored as the node's ``config`` mapping.

        Returns:
            This graph.
        """
        if isinstance(node_type, str):
            try:
                node_type = NodeType(node_type)
            except ValueError:
                # Unknown type names are accepted as custom nodes.
                node_type = NodeType.CUSTOM

        self.nodes[name] = Node(name=name, type=node_type, config=config)
        return self

    def connect(
        self, source: str, target: str, data_type: Optional[str] = None
    ) -> "Graph":
        """Add a connection between two existing nodes.

        Args:
            source: Name of the upstream node.
            target: Name of the downstream node.
            data_type: Optional data type constraint stored on the connection.

        Returns:
            This graph.

        Raises:
            ValueError: If either node name is not present in the graph.
        """
        if source not in self.nodes:
            raise ValueError(f"Source node '{source}' not found")
        if target not in self.nodes:
            raise ValueError(f"Target node '{target}' not found")

        self.connections.append(Connection(source, target, data_type))
        return self

    def chain(self, *names: str) -> "Graph":
        """Connect nodes in a chain: ``names[0] -> names[1] -> ...``.

        Fewer than two names adds no connections.

        Raises:
            ValueError: If any name is not present in the graph.
        """
        for upstream, downstream in zip(names, names[1:]):
            self.connect(upstream, downstream)
        return self

    def from_dict(self, data: Dict[str, Any]) -> "Graph":
        """Load nodes and connections from a dictionary (TOML-compatible).

        The data is merged into this graph. ``data["nodes"]`` maps a node
        name to either a plain string (a source node whose ``source`` config
        is that string) or a dict with an optional ``"type"`` key (default
        ``"custom"``) whose remaining keys become the node config.

        ``data["connections"]`` is a list whose items are either strings of
        the form ``"a -> b"`` / ``"a -> b -> c"`` or dicts with ``"source"``,
        ``"target"`` and optionally ``"data_type"`` keys.

        Returns:
            This graph.

        Raises:
            ValueError: If a connection string contains no ``->`` or a
                connection references a node that does not exist.
            KeyError: If a connection dict lacks ``"source"`` or ``"target"``.
        """
        # Parse nodes
        for name, node_info in data.get("nodes", {}).items():
            if isinstance(node_info, str):
                # Simple format: "source": "headlines"
                self.node(name, NodeType.SOURCE, source=node_info)
            elif isinstance(node_info, dict):
                # Full format: {"type": "camera", "mode": "scroll"}
                node_type = node_info.get("type", "custom")
                config = {k: v for k, v in node_info.items() if k != "type"}
                self.node(name, node_type, **config)

        # Parse connections
        for conn in data.get("connections", []):
            if isinstance(conn, str):
                # "a -> b" or a multi-hop chain such as "a -> b -> c".
                names = [part.strip() for part in conn.split("->")]
                if len(names) < 2:
                    raise ValueError(
                        f"Invalid connection '{conn}': expected 'source -> target'"
                    )
                self.chain(*names)
            elif isinstance(conn, dict):
                # Dict format: {"source": "a", "target": "b"}
                self.connect(conn["source"], conn["target"], conn.get("data_type"))

        return self

    def to_dict(self) -> Dict[str, Any]:
        """Convert the graph to a dictionary accepted by ``from_dict``.

        Each node becomes ``{"type": <NodeType value>, **config}``. Each
        connection becomes ``{"source": ..., "target": ...}``, plus a
        ``"data_type"`` key when the connection has one, so that typed
        connections survive a round trip.
        """
        connections = []
        for conn in self.connections:
            entry = {"source": conn.source, "target": conn.target}
            if conn.data_type is not None:
                entry["data_type"] = conn.data_type
            connections.append(entry)

        return {
            "nodes": {
                name: {"type": node.type.value, **node.config}
                for name, node in self.nodes.items()
            },
            "connections": connections,
        }

    def validate(self) -> List[str]:
        """Validate graph structure and return a list of error messages.

        Reports every node that appears in no connection, then at most one
        cycle error, named after the first node (in insertion order) from
        which a cycle is reachable. An empty list means no problems found.
        """
        errors = []

        # Check for disconnected nodes
        connected_nodes = set()
        for conn in self.connections:
            connected_nodes.add(conn.source)
            connected_nodes.add(conn.target)

        for node_name in self.nodes:
            if node_name not in connected_nodes:
                errors.append(f"Node '{node_name}' is not connected")

        # Build the adjacency map once instead of rescanning all connections
        # for every visited node.
        successors: Dict[str, List[str]] = {}
        for conn in self.connections:
            successors.setdefault(conn.source, []).append(conn.target)

        # Depth-first search: `in_progress` holds the current path, `done`
        # holds nodes already proven to lead to no cycle.
        done = set()
        in_progress = set()

        def has_cycle(node_name: str) -> bool:
            if node_name in in_progress:
                return True
            if node_name in done:
                return False

            in_progress.add(node_name)
            for target in successors.get(node_name, ()):
                if has_cycle(target):
                    return True
            in_progress.remove(node_name)
            done.add(node_name)
            return False

        for node_name in self.nodes:
            if has_cycle(node_name):
                errors.append(f"Cycle detected involving node '{node_name}'")
                break

        return errors

    def __repr__(self) -> str:
        nodes_str = ", ".join(str(n) for n in self.nodes.values())
        return f"Graph(nodes=[{nodes_str}])"


# Factory functions for common node types
def source(name: str, source_type: str, **config) -> Node:
    """Create a source node whose config stores ``source_type`` under ``"source"``."""
    return Node(name, NodeType.SOURCE, {"source": source_type, **config})


def camera(name: str, mode: str = "scroll", **config) -> Node:
    """Create a camera node whose config stores ``mode`` under ``"mode"``."""
    return Node(name, NodeType.CAMERA, {"mode": mode, **config})


def display(name: str, backend: str = "terminal", **config) -> Node:
    """Create a display node whose config stores ``backend`` under ``"backend"``."""
    return Node(name, NodeType.DISPLAY, {"backend": backend, **config})


def effect(name: str, effect_name: str, **config) -> Node:
    """Create an effect node whose config stores ``effect_name`` under ``"effect"``."""
    return Node(name, NodeType.EFFECT, {"effect": effect_name, **config})