Skip to content

mesh

Graph

Directed graph that may contain normal NodeDefinition objects and NodeGroupDefinition objects.

Source code in src/lmflux/graphs/base/graph.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
class Graph:
    """Directed graph that may contain normal ``NodeDefinition`` objects **and** ``NodeGroupDefinition`` objects.

    Nodes live in an internal ``networkx.DiGraph`` keyed by each object's
    ``id``; the object itself is stored in the ``obj`` node attribute and its
    name in the ``label`` attribute.
    """

    # ------------------------------------------------------------------
    #  Construction / mutation API
    # ------------------------------------------------------------------
    def __init__(self, draw_labels_around=False) -> None:
        """
        Create an empty graph.

        ``draw_labels_around`` – when true, the Mermaid renderer wraps the
        graph body between synthetic start/end gate nodes.
        """
        self.G = nx.DiGraph()  # holds the actual objects
        self.draw_labels_around = draw_labels_around

    def __add_node__(self, obj: NodeDefinition, _metadata=None) -> None:
        """Add a ``NodeDefinition`` or a ``NodeGroupDefinition`` to the graph.

        ``_metadata`` is stored as the node's ``_metadata`` attribute; when it
        is omitted each node receives its own empty dict.
        """
        self.G.add_node(
            obj.id,
            label=obj.name,
            obj=obj,
            shape="box" if isinstance(obj, NodeGroupDefinition) else "ellipse",
            # A fresh dict per call: a ``{}`` default would be one shared
            # object aliased by every node added without explicit metadata.
            _metadata={} if _metadata is None else _metadata,
        )

    def __find_object_in_graph_by_name__(self, name: str) -> NodeDefinition:
        """Return the first node object whose label equals ``name``, or ``None`` if absent."""
        # ``self.G.nodes`` yields node ids (strings); the label is kept in the
        # node's attribute dict, so iterate with ``data=True``.
        for _nid, data in self.G.nodes(data=True):
            if data.get("label") == name:
                return data["obj"]
        return None

    def __add_edge__(self, src: NodeDefinition, dst: NodeDefinition, _metadata=None) -> None:
        """Create a directed edge ``src → dst``.

        ``_metadata`` is stored as the edge's ``_metadata`` attribute (its
        ``"label"`` key, if any, is rendered on the Mermaid arrow); when it is
        omitted each edge receives its own empty dict.
        """
        self.G.add_edge(src.id, dst.id, _metadata={} if _metadata is None else _metadata)

    # ------------------------------------------------------------------
    #  Mermaid rendering – a **single** recursive implementation
    # ------------------------------------------------------------------
    def to_mermaid(self, direction: str = "TB") -> str:
        """
        Return a Mermaid flow-chart string that represents the whole graph,
        including any nested ``NodeGroupDefinition`` sub-graphs.
        """
        lines = [f"graph {direction}"]
        lines.extend(self._to_mermaid_lines(indent=0))
        return "\n".join(lines)

    def _to_mermaid_lines(self, indent: int, label_start="start", label_end="end", has_loopback=False, loopback_label="") -> List[str]:
        """
        Private helper that builds *only* the body lines (no ``graph …`` header).

        ``indent`` denotes the level of visual indentation (4 spaces per level).
        ``label_start`` / ``label_end`` are the texts of the synthetic gate
        nodes, which are drawn when ``draw_labels_around`` is set or a loopback
        is requested. ``has_loopback`` adds an ``end → start`` edge labelled
        ``loopback_label``.
        """
        ind = "    " * indent
        body: List[str] = []
        # The loopback edge connects the two gates, so it needs them to exist.
        draw_gates = self.draw_labels_around or has_loopback

        # ---- control gates --------------------------------------------
        nid_start = nid_end = None
        if draw_gates:
            nid_start = str(uuid.uuid4())
            body.append(f"{ind}{nid_start}({label_start})")

            nid_end = str(uuid.uuid4())
            body.append(f"{ind}{nid_end}({label_end})")
            body.append(f"{ind}style {nid_start} fill:#ffcc00,stroke:#333,stroke-width:2px,color:#000")
            body.append(f"{ind}style {nid_end} fill:#ffcc00,stroke:#333,stroke-width:2px,color:#000")

        # ---- nodes ----------------------------------------------------
        for nid, data in self.G.nodes(data=True):
            body.append(f"{ind}{nid}({data['label']})")

        # ---- edges ----------------------------------------------------
        # For a linear chain of n nodes there are n-1 edges: the source of the
        # first edge is the entry point and the destination of edge n-2 is the
        # exit point.
        last_edge_index = self.G.number_of_nodes() - 2
        first_id, last_id, last_dst = None, None, None
        for index, (src, dst, data) in enumerate(self.G.edges(data=True)):
            # Tolerate edges added without a ``_metadata`` dict.
            label = (data.get("_metadata") or {}).get("label")
            if label:
                body.append(f"{ind}{src} --{label}--> {dst}")
            else:
                body.append(f"{ind}{src} --> {dst}")
            if index == 0:
                first_id = src
            if index == last_edge_index:
                last_id = dst
            last_dst = dst

        # Fallbacks so the gate edges never reference ``None`` when the graph
        # is not a simple chain (e.g. no edges, or fewer edges than nodes - 1).
        node_ids = list(self.G.nodes)
        if first_id is None and node_ids:
            first_id = node_ids[0]
        if last_id is None:
            if last_dst is not None:
                last_id = last_dst
            elif node_ids:
                last_id = node_ids[-1]

        # >> Connect the gates to the graph body and add the optional loopback
        if draw_gates:
            if first_id is not None:
                body.append(f"{ind}{nid_start} --> {first_id}")
            if last_id is not None:
                body.append(f"{ind}{last_id} --> {nid_end}")
            if has_loopback:
                body.append(f"{ind}{nid_end} --{loopback_label}--> {nid_start}")

        # ---- recurse into sub-graphs ---------------------------------
        for nid, data in self.G.nodes(data=True):
            obj: NodeDefinition = data["obj"]
            if isinstance(obj, NodeGroupDefinition):
                # Render the inner graph (one level deeper)
                body.extend(obj.to_mermaid(indent + 1))
        return body

    # ------------------------------------------------------------------
    #  Visualization helpers (unchanged, but now use the new renderer)
    # ------------------------------------------------------------------
    def show_mermaid(self, direction: str = "TB") -> None:
        """Display the graph as a fenced ``mermaid`` Markdown block via ``show_markdown``."""
        show_markdown(f"```mermaid\n{self.to_mermaid(direction)}\n```")

__add_edge__(src, dst, _metadata={})

Create a directed edge src → dst.

Source code in src/lmflux/graphs/base/graph.py
51
52
53
def __add_edge__(self, src: NodeDefinition, dst: NodeDefinition, _metadata=None) -> None:
    """Create a directed edge ``src → dst``.

    ``_metadata`` is stored as the edge's ``_metadata`` attribute; when it is
    omitted each edge receives its own empty dict.
    """
    # A ``{}`` default would be one dict shared by every edge created without
    # explicit metadata, so mutating one edge's metadata would affect all.
    self.G.add_edge(src.id, dst.id, _metadata={} if _metadata is None else _metadata)

__add_node__(obj, _metadata={})

Add a NodeDefinition or a NodeGroupDefinition to the graph.

Source code in src/lmflux/graphs/base/graph.py
36
37
38
39
40
41
42
43
44
def __add_node__(self, obj: NodeDefinition, _metadata=None) -> None:
    """Add a ``NodeDefinition`` or a ``NodeGroupDefinition`` to the graph.

    The node is keyed by ``obj.id`` and stores ``obj.name`` as its label.
    ``_metadata`` is stored as the node's ``_metadata`` attribute; when it is
    omitted each node receives its own empty dict.
    """
    self.G.add_node(
        obj.id,
        label=obj.name,
        obj=obj,
        shape="box" if isinstance(obj, NodeGroupDefinition) else "ellipse",
        # A ``{}`` default would be one dict shared by every node.
        _metadata={} if _metadata is None else _metadata,
    )

to_mermaid(direction='TB')

Return a Mermaid flow-chart string that represents the whole graph, including any nested NodeGroupDefinition sub-graphs.

Source code in src/lmflux/graphs/base/graph.py
58
59
60
61
62
63
64
65
def to_mermaid(self, direction: str = "TB") -> str:
    """
    Build the complete Mermaid flow-chart source for this graph.

    The first line is the ``graph <direction>`` header; every following line
    comes from ``self._to_mermaid_lines`` rendered at indentation level 0.
    """
    header = f"graph {direction}"
    return "\n".join([header, *self._to_mermaid_lines(indent=0)])

NodeDefinition

A simple task - the building block of a graph.

Source code in src/lmflux/graphs/base/graph.py
16
17
18
19
20
21
22
23
24
class NodeDefinition:
    """Basic building block of a graph: a named task with a random UUID string id."""

    def __init__(self, name: str):
        self.name = name
        # The id is kept as text because it is written verbatim into Mermaid output.
        self.id: str = f"{uuid.uuid4()}"

    def __repr__(self) -> str:
        short_id = self.id[:8]
        return "<Node %r (%s)>" % (self.name, short_id)

NodeGroupDefinition

Bases: NodeDefinition

A node that hides a sub-graph (its own Graph).

Source code in src/lmflux/graphs/base/graph.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
class NodeGroupDefinition(NodeDefinition):
    """A node that hides a sub-graph (its own ``Graph``)."""

    def __init__(self, name: str):
        super().__init__(name)
        self.graph = Graph()  # inner graph, rendered as a Mermaid subgraph

    # ------------------------------------------------------------------
    #  Convenience API that proxies to the inner graph
    # ------------------------------------------------------------------
    def add_task(self, task: NodeDefinition) -> None:
        """Add ``task`` as a node of the inner graph."""
        # The base ``Graph`` only defines ``__add_node__``; calling a missing
        # ``add_node`` raised AttributeError. Use a public ``add_node`` if the
        # inner graph provides one, otherwise fall back to ``__add_node__``.
        add_node = getattr(self.graph, "add_node", None)
        if add_node is None:
            add_node = self.graph.__add_node__
        add_node(task)

    def add_edge(self, src: NodeDefinition, dst: NodeDefinition) -> None:
        """Add a directed edge ``src → dst`` to the inner graph."""
        # Same fallback as ``add_task``: the base ``Graph`` only has ``__add_edge__``.
        add_edge = getattr(self.graph, "add_edge", None)
        if add_edge is None:
            add_edge = self.graph.__add_edge__
        add_edge(src, dst)

    # ------------------------------------------------------------------
    #  By default the public group node just forwards the rendering
    #  request to its inner graph.  Sub-classes can override this method
    #  to add extra visual elements.
    # ------------------------------------------------------------------
    def to_mermaid(self, indent=1) -> list[str]:
        """
        Render the inner graph as a Mermaid ``subgraph … end`` block.

        Returns the block as a list of lines; the inner graph's lines are
        rendered at indentation level ``indent + 1``.
        """
        # A literal double quote would terminate the quoted title early;
        # Mermaid accepts the ``#quot;`` entity code in its place.
        title = str(self.name).replace('"', "#quot;")
        body = [f'subgraph {self.id}["{title}"]']
        body.extend(self.graph._to_mermaid_lines(indent=indent + 1))
        body.append("end")
        return body

    def show_mermaid(self, direction: str = "TB") -> None:
        """Display only the inner graph as a Mermaid diagram."""
        self.graph.show_mermaid(direction)

    def defines_sub_graph(self) -> bool:
        """Return ``True``: group nodes always contain a sub-graph."""
        return True

to_mermaid(indent=1)

Render the inner graph as a Mermaid subgraph block and return it as a list of lines.

Source code in src/lmflux/graphs/base/graph.py
149
150
151
152
153
154
155
156
157
def to_mermaid(self, indent=1) -> list[str]:
    """
    Render the inner graph as a Mermaid ``subgraph … end`` block.

    Returns the block as a list of lines; the inner graph's lines are
    rendered at indentation level ``indent + 1``.
    """
    # A literal double quote would terminate the quoted title early;
    # Mermaid accepts the ``#quot;`` entity code in its place.
    title = str(self.name).replace('"', "#quot;")
    body = [f'subgraph {self.id}["{title}"]']
    body.extend(self.graph._to_mermaid_lines(indent=indent + 1))
    body.append("end")
    return body

check_compatible(function, func_name, expected_params)

Check if a function is compatible with the expected signature.

Args:

- function (callable): The function to check.
- func_name (str): The name of the function.
- expected_params (list): A list of dictionaries containing the expected parameter name, type, and position.

Returns:

- function (callable): The original function if it's compatible.

Raises:

- AttributeError: If the function is not compatible with the expected signature.

Source code in src/lmflux/utils/signature_checker.py
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def check_compatible(function: callable, func_name: str, expected_params: list):
    """
    Check if a function is compatible with the expected signature.

    Each expected parameter must appear in ``function``'s signature with the
    same name, at the same zero-based position and, unless its expected
    ``type`` is ``None``, with an annotation equal to that type. ``function``
    must not declare any parameters beyond the expected ones.

    Args:
    - function (callable): The function to check.
    - func_name (str): The name of the function (used in the error message).
    - expected_params (list): A list of dictionaries with the keys ``name``,
      ``type`` (``None`` accepts any annotation) and ``position``.

    Returns:
    - function (callable): The original function if it's compatible.

    Raises:
    - AttributeError: If the function is not compatible with the expected signature.
    """
    params = list(inspect.signature(function).parameters.values())
    matched = [
        any(
            param.name == expected["name"]
            and (expected["type"] is None or param.annotation == expected["type"])
            and position == expected["position"]
            for position, param in enumerate(params)
        )
        for expected in expected_params
    ]

    if all(matched) and len(params) == len(expected_params):
        return function

    def _type_name(expected_type) -> str:
        # Falsy types render as "any"; forward references may be given as
        # strings, and some objects lack ``__name__`` entirely — neither case
        # may crash while building the error message.
        if not expected_type:
            return "any"
        if isinstance(expected_type, str):
            return expected_type
        return getattr(expected_type, "__name__", repr(expected_type))

    # Callables such as ``functools.partial`` objects have no ``__name__``.
    shown_name = getattr(function, "__name__", func_name)
    args = ", ".join(f"{param['name']}: {_type_name(param['type'])}" for param in expected_params)
    correct_sig = f"def {shown_name}({args}):..."
    raise AttributeError(f"{func_name} must be defined as {correct_sig}")

transformer_node(func)

Decorator for creating a TransformerMeshNode.

Usage

@transformer_node def my_task(session: Session): ...

Source code in src/lmflux/graphs/mesh/definitions.py
209
210
211
212
213
214
215
216
217
218
219
def transformer_node(func:callable):
    """
    Decorator that wraps ``func`` in a ``TransformerMeshNode``.

    ``func`` is first validated with ``check_compatible`` against
    ``EXPECTED_TRANSFORMER_CALLBACK``, which raises ``AttributeError`` when
    its signature does not match. The returned node is named after
    ``func.__name__``. Note that the decorated name is bound to the node
    object, not to the original function.

    Usage:
        @transformer_node
        def my_task(session: Session):
            ...
    """
    check_compatible(func, "run", EXPECTED_TRANSFORMER_CALLBACK)
    return TransformerMeshNode(func.__name__, func)