Files
sideline/engine/pipeline/graph.py
David Gwilliam 6646ed78b3 Add REPL effect detection and input handling to pipeline runner
- Detect REPL effect in pipeline and enable interactive mode
- Enable raw terminal mode for REPL input capture
- Add keyboard input loop for REPL commands (return, up/down arrows, backspace)
- Process commands and handle pipeline mutations from REPL
- Fix lint issues in graph and REPL modules (type annotations, imports)
2026-03-21 21:19:30 -07:00

206 lines
6.4 KiB
Python

"""Graph-based pipeline configuration and orchestration.

This module provides a graph abstraction for defining pipelines as nodes
and connections, replacing the verbose XYZStage naming convention.

Usage:
    # Declarative (TOML-like); from_dict fills an existing Graph and returns it
    graph = Graph().from_dict({
        "nodes": {
            "source": "headlines",  # shorthand for a source node
            "camera": {"type": "camera", "mode": "scroll"},
            "display": {"type": "display", "backend": "terminal",
                        "positioning": "mixed"},
        },
        "connections": ["source -> camera", "camera -> display"],
    })

    # Imperative
    graph = Graph()
    graph.node("source", "source", source="headlines")
    graph.node("camera", "camera", mode="scroll")
    graph.node("display", "display", backend="terminal", positioning="mixed")
    graph.chain("source", "camera", "display")
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
class NodeType(Enum):
    """Types of pipeline nodes.

    The string values are what ``Graph.node`` accepts in place of a member
    and what ``Graph.to_dict`` writes under each node's ``"type"`` key.
    """

    SOURCE = "source"
    RENDER = "render"
    CAMERA = "camera"
    EFFECT = "effect"
    OVERLAY = "overlay"
    POSITION = "position"
    DISPLAY = "display"
    CUSTOM = "custom"  # fallback used by Graph.node for unrecognised type strings
@dataclass
class Node:
    """A node in the pipeline graph.

    ``config`` holds arbitrary per-node settings, e.g. ``{"mode": "scroll"}``
    for a camera node.
    """

    name: str  # used as the key in Graph.nodes
    type: NodeType  # serialized as its .value by Graph.to_dict
    config: dict[str, Any] = field(default_factory=dict)  # fresh dict per instance
    # NOTE(review): enabled/optional are not read anywhere in this module;
    # presumably consumed by the pipeline builder — confirm.
    enabled: bool = True
    optional: bool = False

    def __repr__(self) -> str:
        # Explicitly defined, so @dataclass keeps it instead of generating one.
        # Compact form (config and flags omitted); Graph.__repr__ relies on it.
        return f"Node({self.name}, type={self.type.value})"
@dataclass
class Connection:
    """A directed connection from node ``source`` to node ``target``."""

    source: str  # name of the upstream node
    target: str  # name of the downstream node
    # Optional data type constraint; stored only — not enforced by this module.
    data_type: str | None = None
@dataclass
class Graph:
    """Pipeline graph representation.

    Nodes are stored by name in ``nodes``; directed edges are kept in
    insertion order in ``connections``. The builder methods (``node``,
    ``connect``, ``chain``, ``from_dict``) mutate the graph in place and
    return ``self`` so calls can be chained.
    """

    nodes: dict[str, Node] = field(default_factory=dict)
    connections: list[Connection] = field(default_factory=list)

    def node(self, name: str, node_type: NodeType | str, **config) -> "Graph":
        """Add (or replace) a node and return ``self``.

        Args:
            name: Node name; an existing node with the same name is replaced.
            node_type: A ``NodeType`` or its string value (e.g. ``"camera"``).
                Strings matching no ``NodeType`` value become ``CUSTOM``.
            **config: Arbitrary settings stored on ``Node.config``.
        """
        if isinstance(node_type, str):
            try:
                node_type = NodeType(node_type)
            except ValueError:
                # Unknown type names are accepted rather than rejected.
                node_type = NodeType.CUSTOM
        self.nodes[name] = Node(name=name, type=node_type, config=config)
        return self

    def connect(
        self, source: str, target: str, data_type: str | None = None
    ) -> "Graph":
        """Add a directed connection ``source -> target`` and return ``self``.

        Raises:
            ValueError: If either endpoint is not in ``self.nodes``.
        """
        if source not in self.nodes:
            raise ValueError(f"Source node '{source}' not found")
        if target not in self.nodes:
            raise ValueError(f"Target node '{target}' not found")
        self.connections.append(Connection(source, target, data_type))
        return self

    def chain(self, *names: str) -> "Graph":
        """Connect ``names`` pairwise in order (a -> b -> c ...); return ``self``.

        Fewer than two names is a no-op. An unknown name raises ``ValueError``
        from ``connect``; connections added before the failure are kept.
        """
        for upstream, downstream in zip(names, names[1:]):
            self.connect(upstream, downstream)
        return self

    def from_dict(self, data: dict[str, Any]) -> "Graph":
        """Load nodes and connections from a dict (TOML-compatible) into ``self``.

        This is an instance method: it adds to the current graph and returns
        it, e.g. ``Graph().from_dict(data)``.

        ``data["nodes"]`` maps node names to either a bare string (shorthand
        for a ``source`` node whose ``"source"`` config is that string) or a
        dict with an optional ``"type"`` key (default ``"custom"``); the other
        keys of that dict become the node's config.

        ``data["connections"]`` is a list of either strings (``"a -> b"`` or
        chains such as ``"a -> b -> c"``) or dicts with ``"source"``,
        ``"target"`` and an optional ``"data_type"`` key.

        Raises:
            ValueError: If a connection string contains no ``->`` or a
                connection names an unknown node.
            KeyError: If a dict connection lacks ``"source"`` or ``"target"``.
        """
        for name, node_info in data.get("nodes", {}).items():
            if isinstance(node_info, str):
                # Simple format: "source": "headlines"
                self.node(name, NodeType.SOURCE, source=node_info)
            elif isinstance(node_info, dict):
                # Full format: {"type": "camera", "mode": "scroll"}
                node_type = node_info.get("type", "custom")
                config = {k: v for k, v in node_info.items() if k != "type"}
                self.node(name, node_type, **config)

        for conn in data.get("connections", []):
            if isinstance(conn, str):
                names = [part.strip() for part in conn.split("->")]
                if len(names) < 2:
                    raise ValueError(
                        f"Connection '{conn}' must have the form 'a -> b'"
                    )
                # "a -> b -> c" connects each consecutive pair.
                self.chain(*names)
            elif isinstance(conn, dict):
                self.connect(
                    conn["source"], conn["target"], conn.get("data_type")
                )
        return self

    def to_dict(self) -> dict[str, Any]:
        """Convert the graph to a dict accepted by ``from_dict``.

        Node ``enabled``/``optional`` flags are not included. A connection's
        ``data_type`` is emitted only when it is set.
        """
        connections: list[dict[str, Any]] = []
        for conn in self.connections:
            entry: dict[str, Any] = {"source": conn.source, "target": conn.target}
            if conn.data_type is not None:
                entry["data_type"] = conn.data_type
            connections.append(entry)
        return {
            "nodes": {
                name: {"type": node.type.value, **node.config}
                for name, node in self.nodes.items()
            },
            "connections": connections,
        }

    def validate(self) -> list[str]:
        """Return human-readable structural problems (empty list if none).

        Reports every node that appears in no connection, then at most one
        cycle, naming a node that lies on that cycle.
        """
        errors: list[str] = []

        connected = {c.source for c in self.connections} | {
            c.target for c in self.connections
        }
        for node_name in self.nodes:
            if node_name not in connected:
                errors.append(f"Node '{node_name}' is not connected")

        cycle_node = self._find_cycle_node()
        if cycle_node is not None:
            errors.append(f"Cycle detected involving node '{cycle_node}'")
        return errors

    def _find_cycle_node(self) -> str | None:
        """Return a node on some directed cycle, or ``None`` if acyclic."""
        successors: dict[str, list[str]] = {}
        for conn in self.connections:
            successors.setdefault(conn.source, []).append(conn.target)

        done: set[str] = set()
        for start in self.nodes:
            if start in done:
                continue
            # Iterative DFS: each stack entry keeps its own child iterator so
            # long pipelines cannot hit the recursion limit.
            on_path = {start}
            stack = [(start, iter(successors.get(start, ())))]
            while stack:
                current, children = stack[-1]
                for child in children:
                    if child in on_path:
                        # Back edge: child is both an ancestor and a
                        # descendant of current, so it lies on the cycle.
                        return child
                    if child not in done:
                        on_path.add(child)
                        stack.append((child, iter(successors.get(child, ()))))
                        break
                else:
                    stack.pop()
                    on_path.discard(current)
                    done.add(current)
        return None

    def __repr__(self) -> str:
        nodes_str = ", ".join(str(n) for n in self.nodes.values())
        return f"Graph(nodes=[{nodes_str}])"
# Factory functions for common node types
def source(name: str, source_type: str, **config) -> Node:
    """Return a new ``NodeType.SOURCE`` node.

    ``source_type`` is stored under the ``"source"`` config key. Extra keyword
    arguments are merged after it, so an explicit ``source=`` kwarg wins.
    """
    settings = {"source": source_type} | config
    return Node(name=name, type=NodeType.SOURCE, config=settings)
def camera(name: str, mode: str = "scroll", **config) -> Node:
    """Return a new ``NodeType.CAMERA`` node with ``mode`` in its config."""
    settings = {"mode": mode}
    settings.update(config)
    return Node(name=name, type=NodeType.CAMERA, config=settings)
def display(name: str, backend: str = "terminal", **config) -> Node:
    """Return a new ``NodeType.DISPLAY`` node with ``backend`` in its config."""
    # ``backend`` is a named parameter, so ``config`` can never repeat it.
    settings = dict(backend=backend, **config)
    return Node(name=name, type=NodeType.DISPLAY, config=settings)
def effect(name: str, effect_name: str, **config) -> Node:
    """Return a new ``NodeType.EFFECT`` node.

    ``effect_name`` is stored under the ``"effect"`` config key; an explicit
    ``effect=`` keyword argument overrides it.
    """
    settings = {"effect": effect_name}
    settings.update(config)
    return Node(name=name, type=NodeType.EFFECT, config=settings)