Skip to content

graph

Graph

Directed graph that may contain normal NodeDefinition objects and NodeGroupDefinition objects.

Source code in src/lmflux/graphs/base/graph.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
class Graph:
    """Directed graph that may contain normal ``NodeDefinition`` objects **and** ``NodeGroupDefinition`` objects.

    Nodes are stored in an internal ``networkx.DiGraph`` keyed by each object's
    ``id``. The object itself is kept under the ``"obj"`` node attribute, so
    nested groups can be rendered recursively.
    """

    # ------------------------------------------------------------------
    #  Construction / mutation API
    # ------------------------------------------------------------------
    def __init__(self, draw_labels_around=False) -> None:
        """
        :param draw_labels_around: when true, the Mermaid rendering adds a
            styled start gate wired to the entry node and an end gate wired
            from the exit node.
        """
        self.G = nx.DiGraph()  # holds the actual objects
        self.draw_labels_around = draw_labels_around

    def __add_node__(self, obj: NodeDefinition, _metadata=None) -> None:
        """Add a ``NodeDefinition`` or a ``NodeGroupDefinition`` to the graph.

        :param obj: the node to add; its ``id`` becomes the graph key and its
            ``name`` the display label.
        :param _metadata: optional dict stored on the node. When omitted, a
            fresh empty dict is used, so nodes never share one mutable default.
        """
        self.G.add_node(
            obj.id,
            label=obj.name,
            obj=obj,
            shape="box" if isinstance(obj, NodeGroupDefinition) else "ellipse",
            _metadata={} if _metadata is None else _metadata,
        )

    def __find_object_in_graph_by_name__(self, name: str) -> "NodeDefinition | None":
        """Return the first node object whose label equals ``name``, or ``None``."""
        # Graph keys are id strings; the display name lives in the "label" attribute.
        for _nid, data in self.G.nodes(data=True):
            if data.get("label") == name:
                return data["obj"]
        return None

    def __add_edge__(self, src: NodeDefinition, dst: NodeDefinition, _metadata=None) -> None:
        """Create a directed edge ``src → dst``.

        :param _metadata: optional dict stored on the edge. A ``"label"`` key is
            rendered as the Mermaid edge text. When omitted, a fresh empty dict
            is used.
        """
        self.G.add_edge(src.id, dst.id, _metadata={} if _metadata is None else _metadata)

    # ------------------------------------------------------------------
    #  Mermaid rendering – a **single** recursive implementation
    # ------------------------------------------------------------------
    def to_mermaid(self, direction: str = "TB") -> str:
        """
        Return a Mermaid flow-chart string that represents the whole graph,
        including any nested ``NodeGroupDefinition`` sub-graphs.
        """
        lines = [f"graph {direction}"]
        lines.extend(self._to_mermaid_lines(indent=0))
        return "\n".join(lines)

    def _entry_and_exit_ids(self):
        """Return ``(entry_id, exit_id)`` used to attach the start/end gates.

        The entry is the first node (insertion order) with no incoming edge.
        The exit is the last node with no outgoing edge. When no such node
        exists (every node sits on a cycle), the first/last inserted nodes are
        used instead. Both are ``None`` for an empty graph.
        """
        nodes = list(self.G.nodes)
        if not nodes:
            return None, None
        roots = [n for n in nodes if self.G.in_degree(n) == 0]
        sinks = [n for n in nodes if self.G.out_degree(n) == 0]
        entry_id = roots[0] if roots else nodes[0]
        exit_id = sinks[-1] if sinks else nodes[-1]
        return entry_id, exit_id

    def _to_mermaid_lines(self, indent: int, label_start="start", label_end="end", has_loopback=False, loopback_label="") -> List[str]:
        """
        Private helper that builds *only* the body lines (no ``graph …`` header).

        :param indent: level of visual indentation (4 spaces per level).
        :param label_start: text of the start gate.
        :param label_end: text of the end gate.
        :param has_loopback: when true, draw an edge from the end gate back to
            the start gate. The gates are drawn whenever a loopback is
            requested, even if ``draw_labels_around`` is false, because the
            loopback needs both endpoints.
        :param loopback_label: optional text for the loopback edge.
        :returns: the Mermaid body lines, including nested sub-graphs.
        """
        ind = "    " * indent
        body: List[str] = []
        gate_style = "fill:#ffcc00,stroke:#333,stroke-width:2px,color:#000"
        draw_gates = self.draw_labels_around or has_loopback

        # ---- control gates --------------------------------------------
        nid_start = nid_end = None
        if draw_gates:
            nid_start = str(uuid.uuid4())
            body.append(f"{ind}{nid_start}({label_start})")
            nid_end = str(uuid.uuid4())
            body.append(f"{ind}{nid_end}({label_end})")
            body.append(f"{ind}style {nid_start} {gate_style}")
            body.append(f"{ind}style {nid_end} {gate_style}")

        # ---- nodes ----------------------------------------------------
        for nid, data in self.G.nodes(data=True):
            body.append(f"{ind}{nid}({data['label']})")

        # ---- edges ----------------------------------------------------
        for src, dst, data in self.G.edges(data=True):
            # Edges created outside __add_edge__ may carry no metadata dict.
            label = (data.get("_metadata") or {}).get("label")
            if label:
                body.append(f"{ind}{src} --{label}--> {dst}")
            else:
                body.append(f"{ind}{src} --> {dst}")

        # >> Wire the gates to the entry/exit nodes (skipped for an empty graph)
        if draw_gates:
            first_id, last_id = self._entry_and_exit_ids()
            if first_id is not None:
                body.append(f"{ind}{nid_start} --> {first_id}")
                body.append(f"{ind}{last_id} --> {nid_end}")
        # >> Optional loopback from the end gate to the start gate
        if has_loopback:
            if loopback_label:
                body.append(f"{ind}{nid_end} --{loopback_label}--> {nid_start}")
            else:
                body.append(f"{ind}{nid_end} --> {nid_start}")

        # ---- recurse into sub-graphs ---------------------------------
        for nid, data in self.G.nodes(data=True):
            obj: NodeDefinition = data["obj"]
            if isinstance(obj, NodeGroupDefinition):
                # Render the inner graph (one level deeper)
                body.extend(obj.to_mermaid(indent + 1))
        return body

    # ------------------------------------------------------------------
    #  Visualization helpers
    # ------------------------------------------------------------------
    def show_mermaid(self, direction: str = "TB") -> None:
        """Display the graph as a Mermaid code fence via ``show_markdown``."""
        show_markdown(f"```mermaid\n{self.to_mermaid(direction)}\n```")

__add_edge__(src, dst, _metadata={})

Create a directed edge src → dst.

Source code in src/lmflux/graphs/base/graph.py
51
52
53
def __add_edge__(self, src: NodeDefinition, dst: NodeDefinition, _metadata={}) -> None:
    """Create a directed edge ``src → dst``."""
    self.G.add_edge(src.id, dst.id, _metadata=_metadata)

__add_node__(obj, _metadata={})

Add a NodeDefinition or a NodeGroupDefinition to the graph.

Source code in src/lmflux/graphs/base/graph.py
36
37
38
39
40
41
42
43
44
def __add_node__(self, obj: NodeDefinition, _metadata=None) -> None:
    """Add a ``NodeDefinition`` or a ``NodeGroupDefinition`` to the graph.

    :param obj: the node to add; its ``id`` becomes the graph key and its
        ``name`` the display label.
    :param _metadata: optional dict stored on the node. When omitted, a fresh
        empty dict is used, so nodes never share one mutable default.
    """
    self.G.add_node(
        obj.id,
        label=obj.name,
        obj=obj,
        shape="box" if isinstance(obj, NodeGroupDefinition) else "ellipse",
        _metadata={} if _metadata is None else _metadata,
    )

to_mermaid(direction='TB')

Return a Mermaid flow-chart string that represents the whole graph, including any nested NodeGroupDefinition sub-graphs.

Source code in src/lmflux/graphs/base/graph.py
58
59
60
61
62
63
64
65
def to_mermaid(self, direction: str = "TB") -> str:
    """
    Build the Mermaid flow-chart source for this graph.

    The text is a ``graph <direction>`` header line followed by the body
    lines from ``_to_mermaid_lines`` (nested group sub-graphs included),
    all joined with newlines.
    """
    header = f"graph {direction}"
    body = self._to_mermaid_lines(indent=0)
    return "\n".join([header, *body])

NodeDefinition

A simple task - the building block of a graph.

Source code in src/lmflux/graphs/base/graph.py
16
17
18
19
20
21
22
23
24
class NodeDefinition:
    """A single task: the basic unit every graph is assembled from."""

    def __init__(self, name: str):
        # A UUID4 rendered as text doubles as the Mermaid node identifier.
        fresh_uuid = uuid.uuid4()
        self.id: str = str(fresh_uuid)
        self.name = name

    def __repr__(self) -> str:
        short_id = self.id[:8]
        return f"<Node {self.name!r} ({short_id})>"

NodeGroupDefinition

Bases: NodeDefinition

A node that hides a sub-graph (its own inner Graph).

Source code in src/lmflux/graphs/base/graph.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
class NodeGroupDefinition(NodeDefinition):
    """A node that hides a sub-graph (its own inner ``Graph``)."""

    def __init__(self, name: str):
        super().__init__(name)
        self.graph = Graph()

    # ------------------------------------------------------------------
    #  Convenience API that proxies to the inner graph
    # ------------------------------------------------------------------
    def add_task(self, task: NodeDefinition, _metadata=None) -> None:
        """Add ``task`` to the inner graph.

        ``Graph`` exposes its mutation API as ``__add_node__`` (it has no
        ``add_node`` method), so the call is forwarded there.

        :param _metadata: optional dict stored on the node; a fresh empty dict
            is used when omitted.
        """
        self.graph.__add_node__(task, _metadata={} if _metadata is None else _metadata)

    def add_edge(self, src: NodeDefinition, dst: NodeDefinition, _metadata=None) -> None:
        """Add the directed edge ``src → dst`` to the inner graph.

        Forwarded to ``Graph.__add_edge__`` (``Graph`` has no ``add_edge``).

        :param _metadata: optional dict stored on the edge (a ``"label"`` key
            is rendered as edge text); a fresh empty dict is used when omitted.
        """
        self.graph.__add_edge__(src, dst, _metadata={} if _metadata is None else _metadata)

    # ------------------------------------------------------------------
    #  By default a group just forwards the rendering request to its inner
    #  graph. Sub-classes can override this method to add extra visual
    #  elements.
    # ------------------------------------------------------------------
    def to_mermaid(self, indent=1) -> list[str]:
        """
        Render the inner graph as a Mermaid ``subgraph`` block.

        :param indent: indentation level of this group; the inner graph's lines
            are rendered at ``indent + 1``.
        :returns: lines starting with ``subgraph <id>["<name>"]`` and ending
            with ``end``.
        """
        body = [f"subgraph {self.id}[\"{self.name}\"]"]
        inner_body = self.graph._to_mermaid_lines(indent=indent+1)
        body.extend(inner_body)
        body.append("end")
        return body

    def show_mermaid(self, direction: str = "TB") -> None:
        """Display only the inner graph as Mermaid."""
        self.graph.show_mermaid(direction)

    def defines_sub_graph(self) -> bool:
        """Always ``True``: this node wraps an inner graph."""
        return True

to_mermaid(indent=1)

Render the inner graph

Source code in src/lmflux/graphs/base/graph.py
149
150
151
152
153
154
155
156
157
def to_mermaid(self, indent=1) -> list[str]:
    """
    Wrap the inner graph's Mermaid lines in a ``subgraph … end`` block.

    The block is titled with the group's name and keyed by its id; the inner
    graph is rendered one indentation level deeper than ``indent``.
    """
    opening = f"subgraph {self.id}[\"{self.name}\"]"
    return [opening, *self.graph._to_mermaid_lines(indent=indent + 1), "end"]