forked from genewildish/Mainline
feat(graph): Add core graph abstraction for pipeline configuration
Introduce Node, Connection, and Graph classes for defining pipelines as graphs instead of verbose XYZStage naming convention. - Add NodeType enum (SOURCE, CAMERA, EFFECT, DISPLAY, etc.) - Add Node, Connection, and Graph dataclasses with type hints - Add validation for cycles and disconnected nodes using DFS - Add factory methods: node(), connect(), chain() for easy graph building - Support for both imperative and declarative graph construction This provides the foundation for the graph-based DSL that replaces the verbose XYZStage naming convention with intuitive node-and-connection syntax.
This commit is contained in:
206
engine/pipeline/graph.py
Normal file
206
engine/pipeline/graph.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""Graph-based pipeline configuration and orchestration.
|
||||
|
||||
This module provides a graph abstraction for defining pipelines as nodes
|
||||
and connections, replacing the verbose XYZStage naming convention.
|
||||
|
||||
Usage:
|
||||
# Declarative (TOML-like)
|
||||
graph = Graph.from_dict({
|
||||
"nodes": {
|
||||
"source": "headlines",
|
||||
"camera": {"type": "camera", "mode": "scroll"},
|
||||
"display": {"type": "terminal", "positioning": "mixed"}
|
||||
},
|
||||
"connections": ["source -> camera -> display"]
|
||||
})
|
||||
|
||||
# Imperative
|
||||
graph = Graph()
|
||||
graph.node("source", "headlines")
|
||||
graph.node("camera", type="camera", mode="scroll")
|
||||
graph.connect("source", "camera", "display")
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class NodeType(Enum):
|
||||
"""Types of pipeline nodes."""
|
||||
|
||||
SOURCE = "source"
|
||||
RENDER = "render"
|
||||
CAMERA = "camera"
|
||||
EFFECT = "effect"
|
||||
OVERLAY = "overlay"
|
||||
POSITION = "position"
|
||||
DISPLAY = "display"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Node:
|
||||
"""A node in the pipeline graph."""
|
||||
|
||||
name: str
|
||||
type: NodeType
|
||||
config: Dict[str, Any] = field(default_factory=dict)
|
||||
enabled: bool = True
|
||||
optional: bool = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Node({self.name}, type={self.type.value})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Connection:
|
||||
"""A connection between two nodes."""
|
||||
|
||||
source: str
|
||||
target: str
|
||||
data_type: Optional[str] = None # Optional data type constraint
|
||||
|
||||
|
||||
@dataclass
|
||||
class Graph:
|
||||
"""Pipeline graph representation."""
|
||||
|
||||
nodes: Dict[str, Node] = field(default_factory=dict)
|
||||
connections: List[Connection] = field(default_factory=list)
|
||||
|
||||
def node(self, name: str, node_type: Union[NodeType, str], **config) -> "Graph":
|
||||
"""Add a node to the graph."""
|
||||
if isinstance(node_type, str):
|
||||
# Try to parse as NodeType
|
||||
try:
|
||||
node_type = NodeType(node_type)
|
||||
except ValueError:
|
||||
node_type = NodeType.CUSTOM
|
||||
|
||||
self.nodes[name] = Node(name=name, type=node_type, config=config)
|
||||
return self
|
||||
|
||||
def connect(
|
||||
self, source: str, target: str, data_type: Optional[str] = None
|
||||
) -> "Graph":
|
||||
"""Add a connection between nodes."""
|
||||
if source not in self.nodes:
|
||||
raise ValueError(f"Source node '{source}' not found")
|
||||
if target not in self.nodes:
|
||||
raise ValueError(f"Target node '{target}' not found")
|
||||
|
||||
self.connections.append(Connection(source, target, data_type))
|
||||
return self
|
||||
|
||||
def chain(self, *names: str) -> "Graph":
|
||||
"""Connect nodes in a chain."""
|
||||
for i in range(len(names) - 1):
|
||||
self.connect(names[i], names[i + 1])
|
||||
return self
|
||||
|
||||
def from_dict(self, data: Dict[str, Any]) -> "Graph":
|
||||
"""Load graph from dictionary (TOML-compatible)."""
|
||||
# Parse nodes
|
||||
nodes_data = data.get("nodes", {})
|
||||
for name, node_info in nodes_data.items():
|
||||
if isinstance(node_info, str):
|
||||
# Simple format: "source": "headlines"
|
||||
self.node(name, NodeType.SOURCE, source=node_info)
|
||||
elif isinstance(node_info, dict):
|
||||
# Full format: {"type": "camera", "mode": "scroll"}
|
||||
node_type = node_info.get("type", "custom")
|
||||
config = {k: v for k, v in node_info.items() if k != "type"}
|
||||
self.node(name, node_type, **config)
|
||||
|
||||
# Parse connections
|
||||
connections_data = data.get("connections", [])
|
||||
for conn in connections_data:
|
||||
if isinstance(conn, str):
|
||||
# Parse "source -> target" format
|
||||
parts = conn.split("->")
|
||||
if len(parts) == 2:
|
||||
self.connect(parts[0].strip(), parts[1].strip())
|
||||
elif isinstance(conn, dict):
|
||||
# Parse dict format: {"source": "a", "target": "b"}
|
||||
self.connect(conn["source"], conn["target"])
|
||||
|
||||
return self
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert graph to dictionary."""
|
||||
return {
|
||||
"nodes": {
|
||||
name: {"type": node.type.value, **node.config}
|
||||
for name, node in self.nodes.items()
|
||||
},
|
||||
"connections": [
|
||||
{"source": conn.source, "target": conn.target}
|
||||
for conn in self.connections
|
||||
],
|
||||
}
|
||||
|
||||
def validate(self) -> List[str]:
|
||||
"""Validate graph structure and return list of errors."""
|
||||
errors = []
|
||||
|
||||
# Check for disconnected nodes
|
||||
connected_nodes = set()
|
||||
for conn in self.connections:
|
||||
connected_nodes.add(conn.source)
|
||||
connected_nodes.add(conn.target)
|
||||
|
||||
for node_name in self.nodes:
|
||||
if node_name not in connected_nodes:
|
||||
errors.append(f"Node '{node_name}' is not connected")
|
||||
|
||||
# Check for cycles (simplified)
|
||||
visited = set()
|
||||
temp = set()
|
||||
|
||||
def has_cycle(node_name: str) -> bool:
|
||||
if node_name in temp:
|
||||
return True
|
||||
if node_name in visited:
|
||||
return False
|
||||
|
||||
temp.add(node_name)
|
||||
for conn in self.connections:
|
||||
if conn.source == node_name:
|
||||
if has_cycle(conn.target):
|
||||
return True
|
||||
temp.remove(node_name)
|
||||
visited.add(node_name)
|
||||
return False
|
||||
|
||||
for node_name in self.nodes:
|
||||
if has_cycle(node_name):
|
||||
errors.append(f"Cycle detected involving node '{node_name}'")
|
||||
break
|
||||
|
||||
return errors
|
||||
|
||||
def __repr__(self) -> str:
|
||||
nodes_str = ", ".join(str(n) for n in self.nodes.values())
|
||||
return f"Graph(nodes=[{nodes_str}])"
|
||||
|
||||
|
||||
# Factory functions for common node types
|
||||
def source(name: str, source_type: str, **config) -> Node:
|
||||
"""Create a source node."""
|
||||
return Node(name, NodeType.SOURCE, {"source": source_type, **config})
|
||||
|
||||
|
||||
def camera(name: str, mode: str = "scroll", **config) -> Node:
|
||||
"""Create a camera node."""
|
||||
return Node(name, NodeType.CAMERA, {"mode": mode, **config})
|
||||
|
||||
|
||||
def display(name: str, backend: str = "terminal", **config) -> Node:
|
||||
"""Create a display node."""
|
||||
return Node(name, NodeType.DISPLAY, {"backend": backend, **config})
|
||||
|
||||
|
||||
def effect(name: str, effect_name: str, **config) -> Node:
|
||||
"""Create an effect node."""
|
||||
return Node(name, NodeType.EFFECT, {"effect": effect_name, **config})
|
||||
Reference in New Issue
Block a user