forked from genewildish/Mainline
- Detect REPL effect in pipeline and enable interactive mode - Enable raw terminal mode for REPL input capture - Add keyboard input loop for REPL commands (return, up/down arrows, backspace) - Process commands and handle pipeline mutations from REPL - Fix lint issues in graph and REPL modules (type annotations, imports)
206 lines
6.4 KiB
Python
206 lines
6.4 KiB
Python
"""Graph-based pipeline configuration and orchestration.

This module provides a graph abstraction for defining pipelines as nodes
and connections, replacing the verbose XYZStage naming convention.

Usage:

    # Declarative (TOML-like)
    graph = Graph().from_dict({
        "nodes": {
            "source": "headlines",
            "camera": {"type": "camera", "mode": "scroll"},
            "display": {
                "type": "display",
                "backend": "terminal",
                "positioning": "mixed",
            },
        },
        "connections": ["source -> camera", "camera -> display"],
    })

    # Imperative
    graph = Graph()
    graph.node("source", "source", source="headlines")
    graph.node("camera", "camera", mode="scroll")
    graph.node("display", "display", backend="terminal")
    graph.chain("source", "camera", "display")
"""
|
|
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from typing import Any
|
|
|
|
|
|
class NodeType(Enum):
    """Categories of node that a pipeline ``Graph`` can hold.

    Each member's value is the lowercase string used in dict/TOML configs,
    so ``NodeType("camera")`` turns a config string back into its member;
    an unknown string raises ``ValueError``.
    """

    SOURCE = "source"
    RENDER = "render"
    CAMERA = "camera"
    EFFECT = "effect"
    OVERLAY = "overlay"
    POSITION = "position"
    DISPLAY = "display"
    # Catch-all: Graph.node maps unrecognised type strings to this member.
    CUSTOM = "custom"
|
|
|
|
|
|
@dataclass
class Node:
    """A single named node inside a pipeline ``Graph``.

    ``config`` carries the type-specific settings supplied through
    ``Graph.node`` or the module-level factory functions.
    """

    name: str
    type: NodeType
    # Per-instance settings dict (default_factory avoids a shared default).
    config: dict[str, Any] = field(default_factory=dict)
    # NOTE(review): enabled/optional appear unused within this module.
    enabled: bool = True
    optional: bool = False

    def __repr__(self) -> str:
        # Compact form: only the name and the enum's string value.
        kind = self.type.value
        return "Node({}, type={})".format(self.name, kind)
|
|
|
|
|
|
@dataclass
|
|
class Connection:
|
|
"""A connection between two nodes."""
|
|
|
|
source: str
|
|
target: str
|
|
data_type: str | None = None # Optional data type constraint
|
|
|
|
|
|
class _HybridMethod:
    """Descriptor for a method callable on an instance *or* on the class.

    Looked up through an instance, the wrapped function binds to that
    instance like an ordinary method.  Looked up through the class
    (``Graph.from_dict(data)``), it binds to a freshly constructed, empty
    instance of that class, so the call acts as an alternate constructor.
    """

    def __init__(self, func: Any) -> None:
        self._func = func
        self.__name__ = func.__name__
        self.__doc__ = func.__doc__

    def __get__(self, instance: Any, owner: Any = None) -> Any:
        # Accessed on the class: operate on a brand-new instance instead.
        target = owner() if instance is None else instance
        return self._func.__get__(target, type(target))


@dataclass
class Graph:
    """Pipeline graph: named nodes plus directed connections between them.

    The builder methods (``node``, ``connect``, ``chain``, ``from_dict``)
    mutate the graph and return it, so calls can be chained.
    """

    nodes: dict[str, Node] = field(default_factory=dict)
    connections: list[Connection] = field(default_factory=list)

    def _add_node(
        self, name: str, node_type: NodeType | str | None, config: dict[str, Any]
    ) -> "Graph":
        """Store a node, coercing a type string (or None) to a ``NodeType``."""
        if node_type is None:
            node_type = NodeType.CUSTOM
        elif isinstance(node_type, str):
            try:
                node_type = NodeType(node_type)
            except ValueError:
                # Unrecognised type names degrade to a custom node.
                node_type = NodeType.CUSTOM
        self.nodes[name] = Node(name=name, type=node_type, config=config)
        return self

    def node(
        self, name: str, node_type: NodeType | str | None = None, **config: Any
    ) -> "Graph":
        """Add a node, replacing any existing node with the same name.

        Args:
            name: Node name, used as the key in ``nodes``.
            node_type: A ``NodeType`` or its string value; unrecognised
                strings become ``NodeType.CUSTOM``.  When omitted, a
                ``type=`` keyword in ``config`` is used instead (custom if
                that is absent too).
            **config: Settings stored on the node's ``config``.

        Returns:
            This graph, for chaining.
        """
        if node_type is None:
            # Keyword form: graph.node("cam", type="camera", mode="scroll").
            node_type = config.pop("type", NodeType.CUSTOM)
        return self._add_node(name, node_type, config)

    def connect(
        self, source: str, target: str, data_type: str | None = None
    ) -> "Graph":
        """Add a directed connection ``source -> target``.

        Raises:
            ValueError: if either node has not been added yet.
        """
        if source not in self.nodes:
            raise ValueError(f"Source node '{source}' not found")
        if target not in self.nodes:
            raise ValueError(f"Target node '{target}' not found")

        self.connections.append(Connection(source, target, data_type))
        return self

    def chain(self, *names: str) -> "Graph":
        """Connect ``names`` pairwise in order (a -> b -> c ...).

        Fewer than two names is a no-op.
        """
        for upstream, downstream in zip(names, names[1:]):
            self.connect(upstream, downstream)
        return self

    @_HybridMethod
    def from_dict(self, data: dict[str, Any]) -> "Graph":
        """Load nodes and connections from a dictionary (TOML-compatible).

        Callable on an instance (loads into it) or on the class
        (``Graph.from_dict(data)`` loads into a new, empty graph).

        ``data["nodes"]`` maps each node name to either a string, which is
        shorthand for a source node (``"source": "headlines"``), or a dict
        whose optional ``"type"`` key selects the node type and whose
        remaining keys become the node config.

        ``data["connections"]`` is a list of either arrow strings
        (``"a -> b"`` or the chain ``"a -> b -> c"``) or dicts with
        ``"source"``, ``"target"`` and optional ``"data_type"`` keys.

        Returns:
            The populated graph.

        Raises:
            ValueError: on a malformed node/connection entry, or when a
                connection names a node that was not defined.
            KeyError: when a dict connection lacks ``"source"``/``"target"``.
        """
        for name, node_info in data.get("nodes", {}).items():
            if isinstance(node_info, str):
                self._add_node(name, NodeType.SOURCE, {"source": node_info})
            elif isinstance(node_info, dict):
                config = {k: v for k, v in node_info.items() if k != "type"}
                # Passed as a dict (not **kwargs) so config keys such as
                # "name" cannot collide with node()'s parameters.
                self._add_node(name, node_info.get("type", "custom"), config)
            else:
                raise ValueError(
                    f"Unsupported spec for node '{name}': {node_info!r}"
                )

        for conn in data.get("connections", []):
            if isinstance(conn, str):
                names = [part.strip() for part in conn.split("->")]
                if len(names) < 2:
                    raise ValueError(
                        f"Connection '{conn}' must use 'a -> b' syntax"
                    )
                # "a -> b -> c" expands to a -> b and b -> c.
                self.chain(*names)
            elif isinstance(conn, dict):
                self.connect(conn["source"], conn["target"], conn.get("data_type"))
            else:
                raise ValueError(f"Unsupported connection spec: {conn!r}")

        return self

    def to_dict(self) -> dict[str, Any]:
        """Serialise the graph into the dict format ``from_dict`` reads.

        ``data_type`` is emitted only for connections that set it, so
        unconstrained connections serialise as just source and target.
        """
        connections: list[dict[str, Any]] = []
        for conn in self.connections:
            entry: dict[str, Any] = {"source": conn.source, "target": conn.target}
            if conn.data_type is not None:
                entry["data_type"] = conn.data_type
            connections.append(entry)

        return {
            "nodes": {
                name: {"type": node.type.value, **node.config}
                for name, node in self.nodes.items()
            },
            "connections": connections,
        }

    def validate(self) -> list[str]:
        """Check the graph structure and return human-readable errors.

        Reports every node that takes part in no connection, then at most
        one cycle.  An empty list means no problems were found.
        """
        errors: list[str] = []

        connected = {
            end for conn in self.connections for end in (conn.source, conn.target)
        }
        for node_name in self.nodes:
            if node_name not in connected:
                errors.append(f"Node '{node_name}' is not connected")

        # Adjacency list built once instead of rescanning all connections
        # for every visited node.
        successors: dict[str, list[str]] = {}
        for conn in self.connections:
            successors.setdefault(conn.source, []).append(conn.target)

        done: set[str] = set()
        on_path: set[str] = set()

        def find_cycle(node_name: str) -> str | None:
            """Depth-first search; return a node lying on a cycle, or None."""
            if node_name in on_path:
                # Reached again while still on the current DFS path, so
                # this node is genuinely part of the cycle.
                return node_name
            if node_name in done:
                return None
            on_path.add(node_name)
            for target in successors.get(node_name, ()):
                hit = find_cycle(target)
                if hit is not None:
                    return hit
            on_path.discard(node_name)
            done.add(node_name)
            return None

        for node_name in self.nodes:
            cycle_node = find_cycle(node_name)
            if cycle_node is not None:
                errors.append(f"Cycle detected involving node '{cycle_node}'")
                break

        return errors

    def __repr__(self) -> str:
        nodes_str = ", ".join(str(n) for n in self.nodes.values())
        return f"Graph(nodes=[{nodes_str}])"
|
|
|
|
|
|
# Factory functions for common node types
|
|
def source(name: str, source_type: str, **config) -> Node:
|
|
"""Create a source node."""
|
|
return Node(name, NodeType.SOURCE, {"source": source_type, **config})
|
|
|
|
|
|
def camera(name: str, mode: str = "scroll", **config) -> Node:
    """Build a CAMERA node; ``mode`` defaults to ``"scroll"``.

    Extra keyword arguments are merged in after the ``"mode"`` key.  The
    node is returned on its own; it is not added to any ``Graph``.
    """
    settings = {"mode": mode}
    settings.update(config)
    return Node(name=name, type=NodeType.CAMERA, config=settings)
|
|
|
|
|
|
def display(name: str, backend: str = "terminal", **config) -> Node:
    """Build a DISPLAY node; ``backend`` defaults to ``"terminal"``.

    Extra keyword arguments are merged in after the ``"backend"`` key.  The
    node is returned on its own; it is not added to any ``Graph``.
    """
    settings = {"backend": backend}
    settings.update(config)
    return Node(name=name, type=NodeType.DISPLAY, config=settings)
|
|
|
|
|
|
def effect(name: str, effect_name: str, **config) -> Node:
    """Build an EFFECT node whose config records ``effect_name``.

    Extra keyword arguments are merged in after the ``"effect"`` key.  The
    node is returned on its own; it is not added to any ``Graph``.
    """
    settings = {"effect": effect_name}
    settings.update(config)
    return Node(name=name, type=NodeType.EFFECT, config=settings)
|