Skip to content

Graph Definitions

Graph

Source code in libs/langgraph/langgraph/graph/graph.py
class Graph:
    """Builder for a graph of named nodes joined by edges and branches.

    Register nodes with ``add_node``, unconditional transitions with
    ``add_edge`` and conditional ones with ``add_conditional_edges``, then
    call ``compile`` to validate the structure and obtain a
    ``CompiledGraph``. Mutating methods return ``self`` so calls can be
    chained.
    """

    def __init__(self) -> None:
        # node name -> NodeSpec (coerced runnable + optional metadata)
        self.nodes: dict[str, NodeSpec] = {}
        # unconditional (start, end) transitions
        self.edges = set[tuple[str, str]]()
        # source node -> {branch name -> Branch}
        self.branches: defaultdict[str, dict[str, Branch]] = defaultdict(dict)
        self.support_multiple_edges = False
        # set to True by validate(); later mutations only log a warning
        self.compiled = False

    @property
    def _all_edges(self) -> set[tuple[str, str]]:
        """Unconditional edges used by validation and graph drawing."""
        return self.edges

    @overload
    def add_node(
        self,
        node: RunnableLike,
        *,
        metadata: Optional[dict[str, Any]] = None,
    ) -> Self: ...

    @overload
    def add_node(
        self,
        node: str,
        action: RunnableLike,
        *,
        metadata: Optional[dict[str, Any]] = None,
    ) -> Self: ...

    def add_node(
        self,
        node: Union[str, RunnableLike],
        action: Optional[RunnableLike] = None,
        *,
        metadata: Optional[dict[str, Any]] = None,
    ) -> Self:
        """Add a node to the graph.

        Args:
            node: Either the node name (then ``action`` is required) or the
                action itself, in which case the node name is taken from the
                action's ``name`` attribute, falling back to ``__name__``.
            action: The function or Runnable executed for this node.
            metadata: Optional metadata stored with the node.

        Returns:
            Self: This graph, to allow method chaining.

        Raises:
            ValueError: If an explicit name contains a reserved character,
                no name can be determined from the action, the name is
                already registered, or the name is ``START``/``END``.
            RuntimeError: If no action was provided.
        """
        if isinstance(node, str):
            for character in (NS_SEP, NS_END):
                if character in node:
                    raise ValueError(
                        f"'{character}' is a reserved character and is not allowed in the node names."
                    )

        if self.compiled:
            logger.warning(
                "Adding a node to a graph that has already been compiled. This will "
                "not be reflected in the compiled graph."
            )
        if not isinstance(node, str):
            action = node
            # The fallback needs its own default: getattr's default argument
            # is evaluated eagerly, so without it an action lacking
            # ``__name__`` would raise AttributeError instead of the
            # ValueError below.
            node = getattr(action, "name", getattr(action, "__name__", None))
            if node is None:
                raise ValueError(
                    "Node name must be provided if action is not a function"
                )
        if action is None:
            raise RuntimeError(
                "Expected a function or Runnable action in add_node. Received None."
            )
        if node in self.nodes:
            raise ValueError(f"Node `{node}` already present.")
        if node == END or node == START:
            raise ValueError(f"Node `{node}` is reserved.")

        self.nodes[cast(str, node)] = NodeSpec(
            coerce_to_runnable(action, name=cast(str, node), trace=False), metadata
        )
        return self

    def add_edge(self, start_key: str, end_key: str) -> Self:
        """Add an unconditional edge from ``start_key`` to ``end_key``.

        Args:
            start_key: Name of the source node (or ``START``).
            end_key: Name of the destination node (or ``END``).

        Returns:
            Self: This graph, to allow method chaining.

        Raises:
            ValueError: If ``start_key`` is ``END``, ``end_key`` is
                ``START``, or (for graphs without a ``channels`` attribute)
                ``start_key`` already has an outgoing edge.
        """
        if self.compiled:
            logger.warning(
                "Adding an edge to a graph that has already been compiled. This will "
                "not be reflected in the compiled graph."
            )
        if start_key == END:
            raise ValueError("END cannot be a start node")
        if end_key == START:
            raise ValueError("START cannot be an end node")

        # run this validation only for non-StateGraph graphs
        if not hasattr(self, "channels") and any(
            start == start_key for start, _ in self.edges
        ):
            raise ValueError(
                f"Already found path for node '{start_key}'.\n"
                "For multiple edges, use StateGraph with an Annotated state key."
            )

        self.edges.add((start_key, end_key))
        return self

    def add_conditional_edges(
        self,
        source: str,
        path: Union[
            Callable[..., Union[Hashable, list[Hashable]]],
            Callable[..., Awaitable[Union[Hashable, list[Hashable]]]],
            Runnable[Any, Union[Hashable, list[Hashable]]],
        ],
        path_map: Optional[Union[dict[Hashable, str], list[str]]] = None,
        then: Optional[str] = None,
    ) -> Self:
        """Add a conditional edge from the starting node to any number of destination nodes.

        Args:
            source (str): The starting node. This conditional edge will run when
                exiting this node.
            path (Union[Callable, Runnable]): The callable that determines the next
                node or nodes. If not specifying `path_map` it should return one or
                more nodes. If it returns END, the graph will stop execution.
            path_map (Optional[Union[dict[Hashable, str], list[str]]]): Optional
                mapping of paths to node names. A list of node names maps each
                name to itself. If omitted the paths returned by `path` should be
                node names.
            then (Optional[str]): The name of a node to execute after the nodes
                selected by `path`.

        Returns:
            Self: This graph, to allow method chaining.

        Raises:
            ValueError: If a branch with the same name already exists for
                `source`.

        Note: Without typehints on the `path` function's return value (e.g., `-> Literal["foo", "__end__"]:`)
            or a path_map, the graph visualization assumes the edge could transition to any node in the graph.

        """  # noqa: E501
        if self.compiled:
            logger.warning(
                "Adding an edge to a graph that has already been compiled. This will "
                "not be reflected in the compiled graph."
            )
        # coerce path_map to a dictionary
        try:
            if isinstance(path_map, dict):
                path_map_ = path_map.copy()
            elif isinstance(path_map, list):
                path_map_ = {name: name for name in path_map}
            elif isinstance(path, Runnable):
                path_map_ = None
            elif rtn_type := get_type_hints(path.__call__).get(  # type: ignore[operator]
                "return"
            ) or get_type_hints(path).get("return"):
                # infer the possible destinations from a Literal return hint
                if get_origin(rtn_type) is Literal:
                    path_map_ = {name: name for name in get_args(rtn_type)}
                else:
                    path_map_ = None
            else:
                path_map_ = None
        except Exception:
            # best effort: unresolvable type hints just mean "unknown ends"
            path_map_ = None
        # find a name for the condition
        path = coerce_to_runnable(path, name=None, trace=True)
        name = path.name or "condition"
        # validate the condition
        if name in self.branches[source]:
            # report the key actually checked (may be the "condition" fallback)
            raise ValueError(
                f"Branch with name `{name}` already exists for node `{source}`"
            )
        # save it
        self.branches[source][name] = Branch(path, path_map_, then)
        return self

    def set_entry_point(self, key: str) -> Self:
        """Specifies the first node to be called in the graph.

        Equivalent to calling `add_edge(START, key)`.

        Parameters:
            key (str): The key of the node to set as the entry point.

        Returns:
            Self: This graph, to allow method chaining.
        """
        return self.add_edge(START, key)

    def set_conditional_entry_point(
        self,
        path: Union[
            Callable[..., Union[Hashable, list[Hashable]]],
            Callable[..., Awaitable[Union[Hashable, list[Hashable]]]],
            Runnable[Any, Union[Hashable, list[Hashable]]],
        ],
        path_map: Optional[Union[dict[Hashable, str], list[str]]] = None,
        then: Optional[str] = None,
    ) -> Self:
        """Sets a conditional entry point in the graph.

        Equivalent to calling `add_conditional_edges(START, path, path_map, then)`.

        Args:
            path (Union[Callable, Runnable]): The callable that determines the next
                node or nodes. If not specifying `path_map` it should return one or
                more nodes. If it returns END, the graph will stop execution.
            path_map (Optional[Union[dict[Hashable, str], list[str]]]): Optional
                mapping of paths to node names. If omitted the paths returned by
                `path` should be node names.
            then (Optional[str]): The name of a node to execute after the nodes
                selected by `path`.

        Returns:
            Self: This graph, to allow method chaining.
        """
        return self.add_conditional_edges(START, path, path_map, then)

    def set_finish_point(self, key: str) -> Self:
        """Marks a node as a finish point of the graph.

        If the graph reaches this node, it will cease execution.
        Equivalent to calling `add_edge(key, END)`.

        Parameters:
            key (str): The key of the node to set as the finish point.

        Returns:
            Self: This graph, to allow method chaining.
        """
        return self.add_edge(key, END)

    def validate(self, interrupt: Optional[Sequence[str]] = None) -> Self:
        """Check the graph structure and mark the graph as compiled.

        Every edge/branch source must be a registered node or ``START``;
        every target must be a registered node or ``END``; every node must
        be the target of some edge or branch; and every interrupt name must
        be a registered node.

        Args:
            interrupt: Node names that will be used as interrupt points.

        Returns:
            Self: This graph.

        Raises:
            ValueError: On the first structural problem found.
        """
        # assemble sources
        all_sources = {src for src, _ in self._all_edges}
        for start, branches in self.branches.items():
            all_sources.add(start)
            for branch in branches.values():
                if branch.then is not None:
                    # the nodes a branch can route to also act as sources,
                    # because they lead into `then`
                    if branch.ends is not None:
                        for end in branch.ends.values():
                            if end != END:
                                all_sources.add(end)
                    else:
                        for node in self.nodes:
                            if node != start and node != branch.then:
                                all_sources.add(node)
        for name, spec in self.nodes.items():
            if spec.ends:
                all_sources.add(name)
        # validate sources
        for source in all_sources:
            if source not in self.nodes and source != START:
                raise ValueError(f"Found edge starting at unknown node '{source}'")

        # assemble targets
        all_targets = {end for _, end in self._all_edges}
        for start, branches in self.branches.items():
            for cond, branch in branches.items():
                if branch.then is not None:
                    all_targets.add(branch.then)
                if branch.ends is not None:
                    for end in branch.ends.values():
                        if end not in self.nodes and end != END:
                            raise ValueError(
                                f"At '{start}' node, '{cond}' branch found unknown target '{end}'"
                            )
                        all_targets.add(end)
                else:
                    # unknown destinations: assume any other node (or END)
                    all_targets.add(END)
                    for node in self.nodes:
                        if node != start and node != branch.then:
                            all_targets.add(node)
        for spec in self.nodes.values():
            if spec.ends:
                all_targets.update(spec.ends)
        # validate targets
        for node in self.nodes:
            if node not in all_targets:
                raise ValueError(f"Node `{node}` is not reachable")
        for target in all_targets:
            if target not in self.nodes and target != END:
                raise ValueError(f"Found edge ending at unknown node `{target}`")
        # validate interrupts
        if interrupt:
            for node in interrupt:
                if node not in self.nodes:
                    raise ValueError(f"Interrupt node `{node}` not found")

        self.compiled = True
        return self

    def compile(
        self,
        checkpointer: Checkpointer = None,
        interrupt_before: Optional[Union[All, list[str]]] = None,
        interrupt_after: Optional[Union[All, list[str]]] = None,
        debug: bool = False,
    ) -> "CompiledGraph":
        """Validate the graph and build an executable ``CompiledGraph``.

        Args:
            checkpointer: Checkpointer passed to the compiled graph.
            interrupt_before: Node names (or ``"*"``) to interrupt before.
            interrupt_after: Node names (or ``"*"``) to interrupt after.
            debug: Whether the compiled graph runs in debug mode.

        Returns:
            CompiledGraph: The validated, executable graph.

        Raises:
            ValueError: If validation of the builder or compiled graph fails,
                including unknown interrupt node names.
        """
        # assign default values
        interrupt_before = interrupt_before or []
        interrupt_after = interrupt_after or []

        # validate the graph; "*" means "all nodes", so there are no explicit
        # names to check for that argument. Each side is resolved on its own
        # so that interrupt_after="*" no longer skips checking
        # interrupt_before.
        before_nodes = interrupt_before if interrupt_before != "*" else []
        after_nodes = interrupt_after if interrupt_after != "*" else []
        self.validate(interrupt=before_nodes + after_nodes)

        # create empty compiled graph
        compiled = CompiledGraph(
            builder=self,
            nodes={},
            channels={START: EphemeralValue(Any), END: EphemeralValue(Any)},
            input_channels=START,
            output_channels=END,
            stream_mode="values",
            stream_channels=[],
            checkpointer=checkpointer,
            interrupt_before_nodes=interrupt_before,
            interrupt_after_nodes=interrupt_after,
            auto_validate=False,
            debug=debug,
        )

        # attach nodes, edges, and branches
        for key, node in self.nodes.items():
            compiled.attach_node(key, node)

        for start, end in self.edges:
            compiled.attach_edge(start, end)

        for start, branches in self.branches.items():
            for name, branch in branches.items():
                compiled.attach_branch(start, name, branch)

        # validate the compiled graph
        return compiled.validate()

add_conditional_edges(source: str, path: Union[Callable[..., Union[Hashable, list[Hashable]]], Callable[..., Awaitable[Union[Hashable, list[Hashable]]]], Runnable[Any, Union[Hashable, list[Hashable]]]], path_map: Optional[Union[dict[Hashable, str], list[str]]] = None, then: Optional[str] = None) -> Self

Add a conditional edge from the starting node to any number of destination nodes.

Parameters:

  • source (str) –

    The starting node. This conditional edge will run when exiting this node.

  • path (Union[Callable, Runnable]) –

    The callable that determines the next node or nodes. If not specifying path_map it should return one or more nodes. If it returns END, the graph will stop execution.

  • path_map (Optional[Union[dict[Hashable, str], list[str]]], default: None ) –

    Optional mapping of paths to node names. If omitted the paths returned by path should be node names.

  • then (Optional[str], default: None ) –

    The name of a node to execute after the nodes selected by path.

Returns:

  • Self

    The graph instance itself, so calls can be chained.

Note: without type hints on the path function's return value (e.g., -> Literal["foo", "__end__"]:) or a path_map, the graph visualization assumes the edge could transition to any node in the graph.

Source code in libs/langgraph/langgraph/graph/graph.py
def add_conditional_edges(
    self,
    source: str,
    path: Union[
        Callable[..., Union[Hashable, list[Hashable]]],
        Callable[..., Awaitable[Union[Hashable, list[Hashable]]]],
        Runnable[Any, Union[Hashable, list[Hashable]]],
    ],
    path_map: Optional[Union[dict[Hashable, str], list[str]]] = None,
    then: Optional[str] = None,
) -> Self:
    """Add a conditional edge from the starting node to any number of destination nodes.

    Args:
        source (str): The starting node. This conditional edge will run when
            exiting this node.
        path (Union[Callable, Runnable]): The callable that determines the next
            node or nodes. If not specifying `path_map` it should return one or
            more nodes. If it returns END, the graph will stop execution.
        path_map (Optional[Union[dict[Hashable, str], list[str]]]): Optional
            mapping of paths to node names. A list of node names maps each name
            to itself. If omitted the paths returned by `path` should be node
            names.
        then (Optional[str]): The name of a node to execute after the nodes
            selected by `path`.

    Returns:
        Self: This graph, to allow method chaining.

    Raises:
        ValueError: If a branch with the same name already exists for `source`.

    Note: Without typehints on the `path` function's return value (e.g., `-> Literal["foo", "__end__"]:`)
        or a path_map, the graph visualization assumes the edge could transition to any node in the graph.

    """  # noqa: E501
    if self.compiled:
        logger.warning(
            "Adding an edge to a graph that has already been compiled. This will "
            "not be reflected in the compiled graph."
        )
    # coerce path_map to a dictionary
    try:
        if isinstance(path_map, dict):
            path_map_ = path_map.copy()
        elif isinstance(path_map, list):
            path_map_ = {name: name for name in path_map}
        elif isinstance(path, Runnable):
            path_map_ = None
        elif rtn_type := get_type_hints(path.__call__).get(  # type: ignore[operator]
            "return"
        ) or get_type_hints(path).get("return"):
            # infer the possible destinations from a Literal return hint
            if get_origin(rtn_type) is Literal:
                path_map_ = {name: name for name in get_args(rtn_type)}
            else:
                path_map_ = None
        else:
            path_map_ = None
    except Exception:
        # best effort: unresolvable type hints just mean "unknown ends"
        path_map_ = None
    # find a name for the condition
    path = coerce_to_runnable(path, name=None, trace=True)
    name = path.name or "condition"
    # validate the condition
    if name in self.branches[source]:
        # report the key actually checked (may be the "condition" fallback)
        raise ValueError(
            f"Branch with name `{name}` already exists for node `{source}`"
        )
    # save it
    self.branches[source][name] = Branch(path, path_map_, then)
    return self

set_entry_point(key: str) -> Self

Specifies the first node to be called in the graph.

Equivalent to calling add_edge(START, key).

Parameters:

  • key (str) –

    The key of the node to set as the entry point.

Returns:

  • Self

    The graph instance itself, so calls can be chained.

Source code in libs/langgraph/langgraph/graph/graph.py
def set_entry_point(self, key: str) -> Self:
    """Specifies the first node to be called in the graph.

    Equivalent to calling `add_edge(START, key)`.

    Parameters:
        key (str): The key of the node to set as the entry point.

    Returns:
        Self: This graph (as returned by `add_edge`), to allow method chaining.
    """
    return self.add_edge(START, key)

set_conditional_entry_point(path: Union[Callable[..., Union[Hashable, list[Hashable]]], Callable[..., Awaitable[Union[Hashable, list[Hashable]]]], Runnable[Any, Union[Hashable, list[Hashable]]]], path_map: Optional[Union[dict[Hashable, str], list[str]]] = None, then: Optional[str] = None) -> Self

Sets a conditional entry point in the graph.

Parameters:

  • path (Union[Callable, Runnable]) –

    The callable that determines the next node or nodes. If not specifying path_map it should return one or more nodes. If it returns END, the graph will stop execution.

  • path_map (Optional[Union[dict[Hashable, str], list[str]]], default: None ) –

    Optional mapping of paths to node names. If omitted the paths returned by path should be node names.

  • then (Optional[str], default: None ) –

    The name of a node to execute after the nodes selected by path.

Returns:

  • Self

    The graph instance itself, so calls can be chained.

Source code in libs/langgraph/langgraph/graph/graph.py
def set_conditional_entry_point(
    self,
    path: Union[
        Callable[..., Union[Hashable, list[Hashable]]],
        Callable[..., Awaitable[Union[Hashable, list[Hashable]]]],
        Runnable[Any, Union[Hashable, list[Hashable]]],
    ],
    path_map: Optional[Union[dict[Hashable, str], list[str]]] = None,
    then: Optional[str] = None,
) -> Self:
    """Sets a conditional entry point in the graph.

    Equivalent to calling `add_conditional_edges(START, path, path_map, then)`.

    Args:
        path (Union[Callable, Runnable]): The callable that determines the next
            node or nodes. If not specifying `path_map` it should return one or
            more nodes. If it returns END, the graph will stop execution.
        path_map (Optional[Union[dict[Hashable, str], list[str]]]): Optional
            mapping of paths to node names. If omitted the paths returned by
            `path` should be node names.
        then (Optional[str]): The name of a node to execute after the nodes
            selected by `path`.

    Returns:
        Self: This graph (as returned by `add_conditional_edges`), to allow
            method chaining.
    """
    return self.add_conditional_edges(START, path, path_map, then)

set_finish_point(key: str) -> Self

Marks a node as a finish point of the graph.

If the graph reaches this node, it will cease execution.

Parameters:

  • key (str) –

    The key of the node to set as the finish point.

Returns:

  • Self

    The graph instance itself, so calls can be chained.

Source code in libs/langgraph/langgraph/graph/graph.py
def set_finish_point(self, key: str) -> Self:
    """Marks a node as a finish point of the graph.

    If the graph reaches this node, it will cease execution.
    Equivalent to calling `add_edge(key, END)`.

    Parameters:
        key (str): The key of the node to set as the finish point.

    Returns:
        Self: This graph (as returned by `add_edge`), to allow method chaining.
    """
    return self.add_edge(key, END)

CompiledGraph

Bases: Pregel

Source code in libs/langgraph/langgraph/graph/graph.py
class CompiledGraph(Pregel):
    """Executable form of a ``Graph``, built on ``Pregel``.

    Each builder node becomes a Pregel node that writes its output to a
    channel named after itself. Edges and branches are wired up by
    subscribing downstream nodes to those channels.
    """

    builder: Graph

    def __init__(self, *, builder: Graph, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # kept so get_graph() can draw from the original builder structure
        self.builder = builder

    def attach_node(self, key: str, node: NodeSpec) -> None:
        """Register ``node`` under ``key`` with its own output channel.

        The node starts with no input channels or triggers (edges attached
        later fill those in). Its output is written to channel ``key``,
        which is also appended to the streamed channels.
        """
        self.channels[key] = EphemeralValue(Any)
        self.nodes[key] = (
            PregelNode(channels=[], triggers=[], metadata=node.metadata)
            | node.runnable
            | ChannelWrite([ChannelWriteEntry(key)], tags=[TAG_HIDDEN])
        )
        cast(list[str], self.stream_channels).append(key)

    def attach_edge(self, start: str, end: str) -> None:
        """Wire an unconditional edge ``start -> end``.

        An edge into ``END`` makes ``start`` also write to the ``END``
        channel; any other edge makes ``end`` read from, and be triggered
        by, ``start``'s channel.
        """
        if end == END:
            # publish to end channel
            self.nodes[start].writers.append(
                ChannelWrite([ChannelWriteEntry(END)], tags=[TAG_HIDDEN])
            )
        else:
            # subscribe to start channel
            self.nodes[end].triggers.append(start)
            cast(list[str], self.nodes[end].channels).append(start)

    def attach_branch(self, start: str, name: str, branch: Branch) -> None:
        """Wire the conditional branch ``name`` leaving node ``start``.

        The branch result is written to per-destination channels named
        ``branch:{start}:{name}:{dest}``; ``END`` results go straight to the
        ``END`` channel and ``Send`` packets are passed through unchanged.
        """

        def branch_writer(
            packets: Sequence[Union[str, Send]], config: RunnableConfig
        ) -> Optional[ChannelWrite]:
            writes = [
                (
                    ChannelWriteEntry(f"branch:{start}:{name}:{p}" if p != END else END)
                    if not isinstance(p, Send)
                    else p
                )
                for p in packets
            ]
            return ChannelWrite(
                cast(Sequence[Union[ChannelWriteEntry, Send]], writes),
                tags=[TAG_HIDDEN],
            )

        # add hidden start node
        if start == START and start not in self.nodes:
            self.nodes[start] = Channel.subscribe_to(START, tags=[TAG_HIDDEN])

        # attach branch writer
        self.nodes[start] |= branch.run(branch_writer)

        # attach branch readers; without declared ends, every node is a
        # potential destination
        ends = branch.ends.values() if branch.ends else list(self.nodes)
        for end in ends:
            if end != END:
                channel_name = f"branch:{start}:{name}:{end}"
                self.channels[channel_name] = EphemeralValue(Any)
                self.nodes[end].triggers.append(channel_name)
                cast(list[str], self.nodes[end].channels).append(channel_name)

    async def aget_graph(
        self,
        config: Optional[RunnableConfig] = None,
        *,
        xray: Union[int, bool] = False,
    ) -> DrawableGraph:
        """Async variant of ``get_graph``; delegates to the sync version."""
        return self.get_graph(config, xray=xray)

    def get_graph(
        self,
        config: Optional[RunnableConfig] = None,
        *,
        xray: Union[int, bool] = False,
    ) -> DrawableGraph:
        """Returns a drawable representation of the computation graph.

        Args:
            config: Optional config used to resolve input/output schemas.
            xray: If truthy, expand compiled subgraphs inline; an int limits
                the expansion depth (decremented per nesting level).

        Returns:
            DrawableGraph: The drawable graph.

        Raises:
            ValueError: If an expanded subgraph has no entrypoint.
        """
        graph = DrawableGraph()
        start_nodes: dict[str, DrawableNode] = {
            START: graph.add_node(self.get_input_schema(config), START)
        }
        end_nodes: dict[str, DrawableNode] = {}
        if xray:
            subgraphs = {
                k: v for k, v in self.get_subgraphs() if isinstance(v, CompiledGraph)
            }
        else:
            subgraphs = {}

        def add_edge(
            start: str,
            end: str,
            label: Optional[Hashable] = None,
            conditional: bool = False,
        ) -> None:
            # the END node is only drawn once something actually points to it
            if end == END and END not in end_nodes:
                end_nodes[END] = graph.add_node(self.get_output_schema(config), END)
            graph.add_edge(
                start_nodes[start],
                end_nodes[end],
                str(label) if label is not None else None,
                conditional,
            )

        for key, n in self.builder.nodes.items():
            node = n.runnable
            # Copy before annotating: writing "__interrupt" into n.metadata
            # would mutate the builder's NodeSpec in place, leaking stale
            # interrupt markers into later get_graph() calls and into other
            # graphs compiled from the same builder.
            metadata = dict(n.metadata) if n.metadata else {}
            if key in self.interrupt_before_nodes and key in self.interrupt_after_nodes:
                metadata["__interrupt"] = "before,after"
            elif key in self.interrupt_before_nodes:
                metadata["__interrupt"] = "before"
            elif key in self.interrupt_after_nodes:
                metadata["__interrupt"] = "after"
            if xray and key in subgraphs:
                # integer xray counts down per level; True expands fully
                child_xray = (
                    xray - 1
                    if isinstance(xray, int) and not isinstance(xray, bool) and xray > 0
                    else xray
                )
                subgraph = subgraphs[key].get_graph(config=config, xray=child_xray)
                subgraph.trim_first_node()
                subgraph.trim_last_node()
                if len(subgraph.nodes) > 1:
                    e, s = graph.extend(subgraph, prefix=key)
                    if e is None:
                        raise ValueError(
                            f"Could not extend subgraph '{key}' due to missing entrypoint"
                        )
                    if s is not None:
                        start_nodes[key] = s
                    end_nodes[key] = e
                else:
                    nn = graph.add_node(node, key, metadata=metadata or None)
                    start_nodes[key] = nn
                    end_nodes[key] = nn
            else:
                nn = graph.add_node(node, key, metadata=metadata or None)
                start_nodes[key] = nn
                end_nodes[key] = nn
        for start, end in sorted(self.builder._all_edges):
            add_edge(start, end)
        for start, branches in self.builder.branches.items():
            default_ends = {
                **{k: k for k in self.builder.nodes if k != start},
                END: END,
            }
            for branch in branches.values():
                if branch.ends is not None:
                    ends = branch.ends
                elif branch.then is not None:
                    ends = {k: k for k in default_ends if k not in (END, branch.then)}
                else:
                    ends = cast(dict[Hashable, str], default_ends)
                for label, end in ends.items():
                    add_edge(
                        start,
                        end,
                        label if label != end else None,
                        conditional=True,
                    )
                    if branch.then is not None:
                        add_edge(end, branch.then)
        for key, n in self.builder.nodes.items():
            if n.ends:
                for end in n.ends:
                    add_edge(key, end, conditional=True)

        return graph

stream_mode: StreamMode = stream_mode class-attribute instance-attribute

Mode to stream output, defaults to 'values'.

stream_channels: Optional[Union[str, Sequence[str]]] = stream_channels class-attribute instance-attribute

Channels to stream. Defaults to all channels that are not reserved channels.

step_timeout: Optional[float] = step_timeout class-attribute instance-attribute

Maximum time to wait for a step to complete, in seconds. Defaults to None.

debug: bool = debug if debug is not None else get_debug() instance-attribute

Whether to print debug information during execution. Defaults to False.

checkpointer: Checkpointer = checkpointer class-attribute instance-attribute

Checkpointer used to save and load graph state. Defaults to None.

store: Optional[BaseStore] = store class-attribute instance-attribute

Memory store to use for SharedValues. Defaults to None.

retry_policy: Optional[RetryPolicy] = retry_policy class-attribute instance-attribute

Retry policy to use when running tasks. Set to None to disable.

get_state(config: RunnableConfig, *, subgraphs: bool = False) -> StateSnapshot

Get the current state of the graph.

Source code in libs/langgraph/langgraph/pregel/__init__.py
def get_state(
    self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot:
    """Get the current state of the graph.

    When the config names a checkpoint namespace and does not already carry
    a checkpointer, the request is forwarded to the first matching subgraph
    with the resolved checkpointer injected into its config.

    Args:
        config: Runnable config whose ``CONF`` section supplies the
            checkpointer, checkpoint namespace and checkpoint id.
        subgraphs: When True, the checkpointer is passed as ``recurse`` to
            the snapshot preparation.

    Raises:
        ValueError: If no checkpointer is available or the requested
            subgraph namespace cannot be found.
    """
    configurable = config[CONF]
    checkpointer: Optional[BaseCheckpointSaver] = configurable.get(
        CONFIG_KEY_CHECKPOINTER, self.checkpointer
    )
    if not checkpointer:
        raise ValueError("No checkpointer set")

    checkpoint_ns = configurable.get(CONFIG_KEY_CHECKPOINT_NS, "")
    if checkpoint_ns and CONFIG_KEY_CHECKPOINTER not in configurable:
        # drop the task-id suffix (text after NS_END) from every segment
        recast_checkpoint_ns = NS_SEP.join(
            segment.split(NS_END)[0] for segment in checkpoint_ns.split(NS_SEP)
        )
        # hand off to the first subgraph registered under that namespace
        candidates = self.get_subgraphs(namespace=recast_checkpoint_ns, recurse=True)
        found = next(iter(candidates), None)
        if found is None:
            raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")
        _, pregel = found
        return pregel.get_state(
            patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
            subgraphs=subgraphs,
        )

    if self.config:
        config = merge_configs(self.config, config)
    saved = checkpointer.get_tuple(config)
    recurse_with = checkpointer if subgraphs else None
    return self._prepare_state_snapshot(
        config,
        saved,
        recurse=recurse_with,
        apply_pending_writes=CONFIG_KEY_CHECKPOINT_ID not in config[CONF],
    )

aget_state(config: RunnableConfig, *, subgraphs: bool = False) -> StateSnapshot async

Get the current state of the graph.

Source code in libs/langgraph/langgraph/pregel/__init__.py
async def aget_state(
    self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot:
    """Get the current state of the graph (async).

    Mirrors ``get_state``: a config naming a checkpoint namespace (and not
    already carrying a checkpointer) is forwarded to the first matching
    subgraph with the resolved checkpointer injected.

    Args:
        config: Runnable config whose ``CONF`` section supplies the
            checkpointer, checkpoint namespace and checkpoint id.
        subgraphs: When True, the checkpointer is passed as ``recurse`` to
            the snapshot preparation.

    Raises:
        ValueError: If no checkpointer is available or the requested
            subgraph namespace cannot be found.
    """
    configurable = config[CONF]
    checkpointer: Optional[BaseCheckpointSaver] = configurable.get(
        CONFIG_KEY_CHECKPOINTER, self.checkpointer
    )
    if not checkpointer:
        raise ValueError("No checkpointer set")

    checkpoint_ns = configurable.get(CONFIG_KEY_CHECKPOINT_NS, "")
    if checkpoint_ns and CONFIG_KEY_CHECKPOINTER not in configurable:
        # drop the task-id suffix (text after NS_END) from every segment
        recast_checkpoint_ns = NS_SEP.join(
            segment.split(NS_END)[0] for segment in checkpoint_ns.split(NS_SEP)
        )
        # the first subgraph registered under that namespace answers
        async for _, pregel in self.aget_subgraphs(
            namespace=recast_checkpoint_ns, recurse=True
        ):
            return await pregel.aget_state(
                patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
                subgraphs=subgraphs,
            )
        raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")

    if self.config:
        config = merge_configs(self.config, config)
    saved = await checkpointer.aget_tuple(config)
    recurse_with = checkpointer if subgraphs else None
    return await self._aprepare_state_snapshot(
        config,
        saved,
        recurse=recurse_with,
        apply_pending_writes=CONFIG_KEY_CHECKPOINT_ID not in config[CONF],
    )

get_state_history(config: RunnableConfig, *, filter: Optional[Dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None) -> Iterator[StateSnapshot]

Get the history of the state of the graph.

Source code in libs/langgraph/langgraph/pregel/__init__.py
def get_state_history(
    self,
    config: RunnableConfig,
    *,
    filter: Optional[Dict[str, Any]] = None,
    before: Optional[RunnableConfig] = None,
    limit: Optional[int] = None,
) -> Iterator[StateSnapshot]:
    """Get the history of the state of the graph.

    Yields one snapshot per checkpoint returned by the checkpointer's
    ``list``. A config naming a checkpoint namespace (and not already
    carrying a checkpointer) is forwarded to the first matching subgraph.

    Args:
        config: Runnable config whose ``CONF`` section supplies the
            checkpointer and checkpoint namespace.
        filter: Forwarded to the checkpointer's ``list``.
        before: Forwarded to the checkpointer's ``list``.
        limit: Forwarded to the checkpointer's ``list``.

    Raises:
        ValueError: If no checkpointer is available or the requested
            subgraph namespace cannot be found (raised on first iteration).
    """
    configurable = config[CONF]
    checkpointer: Optional[BaseCheckpointSaver] = configurable.get(
        CONFIG_KEY_CHECKPOINTER, self.checkpointer
    )
    if not checkpointer:
        raise ValueError("No checkpointer set")

    checkpoint_ns = configurable.get(CONFIG_KEY_CHECKPOINT_NS, "")
    if checkpoint_ns and CONFIG_KEY_CHECKPOINTER not in configurable:
        # drop the task-id suffix (text after NS_END) from every segment
        recast_checkpoint_ns = NS_SEP.join(
            segment.split(NS_END)[0] for segment in checkpoint_ns.split(NS_SEP)
        )
        # delegate to the first subgraph registered under that namespace;
        # `candidates` stays referenced until the delegation finishes
        candidates = self.get_subgraphs(namespace=recast_checkpoint_ns, recurse=True)
        found = next(iter(candidates), None)
        if found is None:
            raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")
        _, pregel = found
        yield from pregel.get_state_history(
            patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
            filter=filter,
            before=before,
            limit=limit,
        )
        return

    merged = merge_configs(
        self.config,
        config,
        {CONF: {CONFIG_KEY_CHECKPOINT_NS: checkpoint_ns}},
    )
    # drain the listing up front so the checkpointer's cursor is released
    # before snapshots are prepared and yielded
    saved_tuples = list(
        checkpointer.list(merged, before=before, limit=limit, filter=filter)
    )
    for saved in saved_tuples:
        yield self._prepare_state_snapshot(saved.config, saved)

aget_state_history(config: RunnableConfig, *, filter: Optional[Dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None) -> AsyncIterator[StateSnapshot] async

Get the history of the state of the graph.

Source code in libs/langgraph/langgraph/pregel/__init__.py
async def aget_state_history(
    self,
    config: RunnableConfig,
    *,
    filter: Optional[Dict[str, Any]] = None,
    before: Optional[RunnableConfig] = None,
    limit: Optional[int] = None,
) -> AsyncIterator[StateSnapshot]:
    """Get the history of the state of the graph.

    Async counterpart of ``get_state_history``. Yields one ``StateSnapshot``
    per checkpoint returned by the checkpointer's ``alist`` call.

    When ``config`` carries a checkpoint namespace and no checkpointer has been
    injected into its configurable section, the request is forwarded to the
    first subgraph matching that namespace. The checkpointer is injected into
    the forwarded config.

    Args:
        config: Run config; its configurable section may hold the checkpointer
            to use and the checkpoint namespace.
        filter: Forwarded to the checkpointer's ``alist``.
        before: Forwarded to the checkpointer's ``alist``.
        limit: Forwarded to the checkpointer's ``alist``.

    Raises:
        ValueError: If no checkpointer is available, or no subgraph matches
            the requested namespace.
    """
    configurable = config[CONF]
    checkpointer: Optional[BaseCheckpointSaver] = configurable.get(
        CONFIG_KEY_CHECKPOINTER, self.checkpointer
    )
    if not checkpointer:
        raise ValueError("No checkpointer set")

    checkpoint_ns = configurable.get(CONFIG_KEY_CHECKPOINT_NS, "")
    if checkpoint_ns and CONFIG_KEY_CHECKPOINTER not in configurable:
        # Drop the task-id suffix (everything after NS_END) from each segment
        # so the namespace can be matched against subgraph names.
        subgraph_ns = NS_SEP.join(
            segment.split(NS_END)[0] for segment in checkpoint_ns.split(NS_SEP)
        )
        async for _, subgraph in self.aget_subgraphs(
            namespace=subgraph_ns, recurse=True
        ):
            async for snapshot in subgraph.aget_state_history(
                patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
                filter=filter,
                before=before,
                limit=limit,
            ):
                yield snapshot
            return
        # Only reached when the search above produced no match.
        raise ValueError(f"Subgraph {subgraph_ns} not found")

    merged_config = merge_configs(
        self.config,
        config,
        {CONF: {CONFIG_KEY_CHECKPOINT_NS: checkpoint_ns}},
    )
    # Materialise the whole listing first so the db cursor is released
    # before the caller starts consuming snapshots.
    history = [
        entry
        async for entry in checkpointer.alist(
            merged_config, before=before, limit=limit, filter=filter
        )
    ]
    for entry in history:
        yield await self._aprepare_state_snapshot(entry.config, entry)

update_state(config: RunnableConfig, values: Optional[Union[dict[str, Any], Any]], as_node: Optional[str] = None) -> RunnableConfig

Update the state of the graph with the given values, as if they came from node as_node. If as_node is not provided, it will be set to the last node that updated the state, if not ambiguous.

Source code in libs/langgraph/langgraph/pregel/__init__.py
def update_state(
    self,
    config: RunnableConfig,
    values: Optional[Union[dict[str, Any], Any]],
    as_node: Optional[str] = None,
) -> RunnableConfig:
    """Update the state of the graph with the given values, as if they came from
    node `as_node`. If `as_node` is not provided, it will be set to the last node
    that updated the state, if not ambiguous.

    If `config` carries a checkpoint namespace and no checkpointer has been
    injected into it, the update is delegated to the matching subgraph.

    Special cases, checked before any node writers run:

    - `values is None` and `as_node == END`: the writes of tasks pending on the
      latest checkpoint are applied (clearing those tasks), and a new checkpoint
      is saved with source "update".
    - `values is None` and `as_node is None`: the latest checkpoint is copied
      into a new checkpoint with source "update".
    - `values is None` and `as_node == "__copy__"`: the latest checkpoint is
      saved again under its parent config (when available) with source "fork".

    Otherwise, the writers of `as_node` are invoked with `values`. Their writes
    are applied to a new checkpoint, which is saved with source "update".

    Args:
        config: Run config identifying the thread/checkpoint to update.
        values: Values to write, or None for the special cases above.
        as_node: Node the update is attributed to.

    Returns:
        The config of the newly saved checkpoint.

    Raises:
        ValueError: If no checkpointer is set, or the requested subgraph
            namespace is not found.
        InvalidUpdateError: If `as_node` cannot be inferred unambiguously, does
            not exist, or has no writers.
    """
    # a checkpointer placed in config (e.g. injected when a parent graph
    # delegates to this subgraph, see below) takes precedence over our own
    checkpointer: Optional[BaseCheckpointSaver] = config[CONF].get(
        CONFIG_KEY_CHECKPOINTER, self.checkpointer
    )
    if not checkpointer:
        raise ValueError("No checkpointer set")

    # delegate to subgraph
    # (only when no checkpointer was injected; the delegated call injects one,
    # so the subgraph handles the update itself instead of delegating again)
    if (
        checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
    ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
        # remove task_ids from checkpoint_ns
        recast_checkpoint_ns = NS_SEP.join(
            part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP)
        )
        # find the subgraph with the matching name
        for _, pregel in self.get_subgraphs(
            namespace=recast_checkpoint_ns, recurse=True
        ):
            return pregel.update_state(
                patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
                values,
                as_node,
            )
        else:
            # the loop body always returns, so this runs only when nothing matched
            raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")

    # get last checkpoint
    config = ensure_config(self.config, config)
    saved = checkpointer.get_tuple(config)
    # working copy of the latest checkpoint (or a fresh empty one)
    checkpoint = copy_checkpoint(saved.checkpoint) if saved else empty_checkpoint()
    # versions before this update; compared against the final versions to
    # compute the channel versions passed to checkpointer.put() at the end
    checkpoint_previous_versions = (
        saved.checkpoint["channel_versions"].copy() if saved else {}
    )
    step = saved.metadata.get("step", -1) if saved else -1
    # merge configurable fields with previous checkpoint config
    checkpoint_config = patch_configurable(
        config,
        {CONFIG_KEY_CHECKPOINT_NS: config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")},
    )
    checkpoint_metadata = config["metadata"]
    if saved:
        checkpoint_config = patch_configurable(config, saved.config[CONF])
        # new metadata keys from config override the saved checkpoint's metadata
        checkpoint_metadata = {**saved.metadata, **checkpoint_metadata}
    with ChannelsManager(
        self.channels,
        checkpoint,
        LoopProtocol(config=config, step=step + 1, stop=step + 2),
    ) as (channels, managed):
        # no values as END, just clear all tasks
        if values is None and as_node == END:
            if saved is not None:
                # tasks for this checkpoint
                next_tasks = prepare_next_tasks(
                    checkpoint,
                    saved.pending_writes or [],
                    self.nodes,
                    channels,
                    managed,
                    saved.config,
                    saved.metadata.get("step", -1) + 1,
                    for_execution=True,
                    store=self.store,
                    checkpointer=self.checkpointer or None,
                    manager=None,
                )
                # apply null writes
                # NOTE(review): these are applied against `saved.checkpoint`,
                # not the working copy `checkpoint` used for the final save —
                # confirm this is intended.
                if null_writes := [
                    w[1:]
                    for w in saved.pending_writes or []
                    if w[0] == NULL_TASK_ID
                ]:
                    apply_writes(
                        saved.checkpoint,
                        channels,
                        [PregelTaskWrites((), INPUT, null_writes, [])],
                        None,
                    )
                # apply writes from tasks that already ran
                # (errors, interrupts and scheduled markers are not channel writes)
                for tid, k, v in saved.pending_writes or []:
                    if k in (ERROR, INTERRUPT, SCHEDULED):
                        continue
                    if tid not in next_tasks:
                        continue
                    next_tasks[tid].writes.append((k, v))
                # clear all current tasks
                apply_writes(checkpoint, channels, next_tasks.values(), None)
            # save checkpoint
            next_config = checkpointer.put(
                checkpoint_config,
                create_checkpoint(checkpoint, None, step),
                {
                    **checkpoint_metadata,
                    "source": "update",
                    "step": step + 1,
                    "writes": {},
                    "parents": saved.metadata.get("parents", {}) if saved else {},
                },
                {},
            )
            return patch_checkpoint_map(
                next_config, saved.metadata if saved else None
            )
        # no values, copy checkpoint
        if values is None and as_node is None:
            next_checkpoint = create_checkpoint(checkpoint, None, step)
            # copy checkpoint
            next_config = checkpointer.put(
                checkpoint_config,
                next_checkpoint,
                {
                    **checkpoint_metadata,
                    "source": "update",
                    "step": step + 1,
                    "writes": {},
                    "parents": saved.metadata.get("parents", {}) if saved else {},
                },
                {},
            )
            return patch_checkpoint_map(
                next_config, saved.metadata if saved else None
            )
        # no values, fork: save a copy under the parent config (if any) so it
        # becomes a sibling of the current checkpoint
        if values is None and as_node == "__copy__":
            next_checkpoint = create_checkpoint(checkpoint, None, step)
            # copy checkpoint
            next_config = checkpointer.put(
                saved.parent_config or saved.config if saved else checkpoint_config,
                next_checkpoint,
                {
                    **checkpoint_metadata,
                    "source": "fork",
                    "step": step + 1,
                    "parents": saved.metadata.get("parents", {}) if saved else {},
                },
                {},
            )
            return patch_checkpoint_map(
                next_config, saved.metadata if saved else None
            )
        # apply pending writes, if not on specific checkpoint
        if (
            CONFIG_KEY_CHECKPOINT_ID not in config[CONF]
            and saved is not None
            and saved.pending_writes
        ):
            # tasks for this checkpoint
            next_tasks = prepare_next_tasks(
                checkpoint,
                saved.pending_writes,
                self.nodes,
                channels,
                managed,
                saved.config,
                saved.metadata.get("step", -1) + 1,
                for_execution=True,
                store=self.store,
                checkpointer=self.checkpointer or None,
                manager=None,
            )
            # apply null writes
            # NOTE(review): applied against `saved.checkpoint` rather than the
            # working copy `checkpoint` — confirm this is intended.
            if null_writes := [
                w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID
            ]:
                apply_writes(
                    saved.checkpoint,
                    channels,
                    [PregelTaskWrites((), INPUT, null_writes, [])],
                    None,
                )
            # apply writes
            for tid, k, v in saved.pending_writes:
                if k in (ERROR, INTERRUPT, SCHEDULED):
                    continue
                if tid not in next_tasks:
                    continue
                next_tasks[tid].writes.append((k, v))
            # only tasks that actually produced writes are applied
            if tasks := [t for t in next_tasks.values() if t.writes]:
                apply_writes(checkpoint, channels, tasks, None)
        # find last node that updated the state, if not provided
        if as_node is None and not any(
            v for vv in checkpoint["versions_seen"].values() for v in vv.values()
        ):
            # nothing seen yet: attribute the update to the input node, when
            # the single input channel is itself the name of a node
            if (
                isinstance(self.input_channels, str)
                and self.input_channels in self.nodes
            ):
                as_node = self.input_channels
        elif as_node is None:
            # (version, node) pairs sorted ascending; the last entry is the
            # most recent node to have seen a channel
            last_seen_by_node = sorted(
                (v, n)
                for n, seen in checkpoint["versions_seen"].items()
                if n in self.nodes
                for v in seen.values()
            )
            # if two nodes updated the state at the same time, it's ambiguous
            if last_seen_by_node:
                if len(last_seen_by_node) == 1:
                    as_node = last_seen_by_node[0][1]
                elif last_seen_by_node[-1][0] != last_seen_by_node[-2][0]:
                    as_node = last_seen_by_node[-1][1]
        if as_node is None:
            raise InvalidUpdateError("Ambiguous update, specify as_node")
        if as_node not in self.nodes:
            raise InvalidUpdateError(f"Node {as_node} does not exist")
        # create task to run all writers of the chosen node
        writers = self.nodes[as_node].flat_writers
        if not writers:
            raise InvalidUpdateError(f"Node {as_node} has no writers")
        # writes emitted by the writers are collected here via CONFIG_KEY_SEND
        writes: deque[tuple[str, Any]] = deque()
        task = PregelTaskWrites((), as_node, writes, [INTERRUPT])
        # deterministic id derived from the checkpoint id (uuid5)
        task_id = str(uuid5(UUID(checkpoint["id"]), INTERRUPT))
        run = RunnableSequence(*writers) if len(writers) > 1 else writers[0]
        # execute task
        run.invoke(
            values,
            patch_config(
                config,
                run_name=self.name + "UpdateState",
                configurable={
                    # deque.extend is thread-safe
                    CONFIG_KEY_SEND: partial(
                        local_write,
                        writes.extend,
                        self.nodes.keys(),
                    ),
                    CONFIG_KEY_READ: partial(
                        local_read,
                        step + 1,
                        checkpoint,
                        channels,
                        managed,
                        task,
                        config,
                    ),
                },
            ),
        )
        # save task writes
        # channel writes are saved to current checkpoint
        # push writes are saved to next checkpoint
        channel_writes, push_writes = (
            [w for w in task.writes if w[0] != PUSH],
            [w for w in task.writes if w[0] == PUSH],
        )
        if saved and channel_writes:
            checkpointer.put_writes(checkpoint_config, channel_writes, task_id)
        # apply to checkpoint and save
        mv_writes = apply_writes(
            checkpoint, channels, [task], checkpointer.get_next_version
        )
        assert not mv_writes, "Can't write to SharedValues from update_state"
        checkpoint = create_checkpoint(checkpoint, channels, step + 1)
        next_config = checkpointer.put(
            checkpoint_config,
            checkpoint,
            {
                **checkpoint_metadata,
                "source": "update",
                "step": step + 1,
                "writes": {as_node: values},
                "parents": saved.metadata.get("parents", {}) if saved else {},
            },
            get_new_channel_versions(
                checkpoint_previous_versions, checkpoint["channel_versions"]
            ),
        )
        if push_writes:
            checkpointer.put_writes(next_config, push_writes, task_id)
        return patch_checkpoint_map(next_config, saved.metadata if saved else None)

stream(input: Union[dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None, output_keys: Optional[Union[str, Sequence[str]]] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, debug: Optional[bool] = None, subgraphs: bool = False) -> Iterator[Union[dict[str, Any], Any]]

Stream graph steps for a single input.

Parameters:

  • input (Union[dict[str, Any], Any]) –

    The input to the graph.

  • config (Optional[RunnableConfig], default: None ) –

    The configuration to use for the run.

  • stream_mode (Optional[Union[StreamMode, list[StreamMode]]], default: None ) –

    The mode to stream output, defaults to self.stream_mode. Options include 'values', 'updates', and 'debug'; 'messages' and 'custom' are also recognized. values: Emit the current values of the state for each step. updates: Emit only the updates to the state for each step; the output is a dict with the node name as key and the updated values as value. debug: Emit debug events for each step. When a list of modes is passed, each output is a (mode, payload) tuple.

  • output_keys (Optional[Union[str, Sequence[str]]], default: None ) –

    The keys to stream, defaults to all non-context channels.

  • interrupt_before (Optional[Union[All, Sequence[str]]], default: None ) –

    Nodes to interrupt before, defaults to all nodes in the graph.

  • interrupt_after (Optional[Union[All, Sequence[str]]], default: None ) –

    Nodes to interrupt after, defaults to all nodes in the graph.

  • debug (Optional[bool], default: None ) –

    Whether to print debug information during execution, defaults to False.

  • subgraphs (bool, default: False ) –

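    Whether to stream subgraphs, defaults to False. When True, each output is prefixed with the namespace of the graph that emitted it: (namespace, payload), or (namespace, mode, payload) when a list of modes is passed.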

Yields:

  • Union[dict[str, Any], Any]

    The output of each step in the graph. The output shape depends on the stream_mode.

Examples:

Using different stream modes with a graph:

>>> import operator
>>> from typing_extensions import Annotated, TypedDict
>>> from langgraph.graph import StateGraph
>>> from langgraph.constants import START
...
>>> class State(TypedDict):
...     alist: Annotated[list, operator.add]
...     another_list: Annotated[list, operator.add]
...
>>> builder = StateGraph(State)
>>> builder.add_node("a", lambda _state: {"another_list": ["hi"]})
>>> builder.add_node("b", lambda _state: {"alist": ["there"]})
>>> builder.add_edge("a", "b")
>>> builder.add_edge(START, "a")
>>> graph = builder.compile()
With stream_mode="values":

>>> for event in graph.stream({"alist": ['Ex for stream_mode="values"']}, stream_mode="values"):
...     print(event)
{'alist': ['Ex for stream_mode="values"'], 'another_list': []}
{'alist': ['Ex for stream_mode="values"'], 'another_list': ['hi']}
{'alist': ['Ex for stream_mode="values"', 'there'], 'another_list': ['hi']}
With stream_mode="updates":

>>> for event in graph.stream({"alist": ['Ex for stream_mode="updates"']}, stream_mode="updates"):
...     print(event)
{'a': {'another_list': ['hi']}}
{'b': {'alist': ['there']}}
With stream_mode="debug":

>>> for event in graph.stream({"alist": ['Ex for stream_mode="debug"']}, stream_mode="debug"):
...     print(event)
{'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': []}, 'triggers': ['start:a']}}
{'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'result': [('another_list', ['hi'])]}}
{'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': ['hi']}, 'triggers': ['a']}}
{'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'result': [('alist', ['there'])]}}
Source code in libs/langgraph/langgraph/pregel/__init__.py
def stream(
    self,
    input: Union[dict[str, Any], Any],
    config: Optional[RunnableConfig] = None,
    *,
    stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None,
    output_keys: Optional[Union[str, Sequence[str]]] = None,
    interrupt_before: Optional[Union[All, Sequence[str]]] = None,
    interrupt_after: Optional[Union[All, Sequence[str]]] = None,
    debug: Optional[bool] = None,
    subgraphs: bool = False,
) -> Iterator[Union[dict[str, Any], Any]]:
    """Stream graph steps for a single input.

    Args:
        input: The input to the graph.
        config: The configuration to use for the run.
        stream_mode: The mode to stream output, defaults to self.stream_mode.
            Options include 'values', 'updates', and 'debug'; 'messages' and
            'custom' are also recognized.
            values: Emit the current values of the state for each step.
            updates: Emit only the updates to the state for each step.
                Output is a dict with the node name as key and the updated values as value.
            debug: Emit debug events for each step.
            When a list of modes is passed, each output is a `(mode, payload)` tuple.
        output_keys: The keys to stream, defaults to all non-context channels.
        interrupt_before: Nodes to interrupt before, defaults to all nodes in the graph.
        interrupt_after: Nodes to interrupt after, defaults to all nodes in the graph.
        debug: Whether to print debug information during execution, defaults to False.
        subgraphs: Whether to stream subgraphs, defaults to False.
            When True, each output is prefixed with the namespace of the graph
            that emitted it: `(namespace, payload)`, or
            `(namespace, mode, payload)` when a list of modes is passed.

    Yields:
        The output of each step in the graph. The output shape depends on the stream_mode.

    Raises:
        GraphRecursionError: If the run stops because the recursion limit was
            reached without hitting a stop condition.

    Examples:
        Using different stream modes with a graph:
        ```pycon
        >>> import operator
        >>> from typing_extensions import Annotated, TypedDict
        >>> from langgraph.graph import StateGraph
        >>> from langgraph.constants import START
        ...
        >>> class State(TypedDict):
        ...     alist: Annotated[list, operator.add]
        ...     another_list: Annotated[list, operator.add]
        ...
        >>> builder = StateGraph(State)
        >>> builder.add_node("a", lambda _state: {"another_list": ["hi"]})
        >>> builder.add_node("b", lambda _state: {"alist": ["there"]})
        >>> builder.add_edge("a", "b")
        >>> builder.add_edge(START, "a")
        >>> graph = builder.compile()
        ```
        With stream_mode="values":

        ```pycon
        >>> for event in graph.stream({"alist": ['Ex for stream_mode="values"']}, stream_mode="values"):
        ...     print(event)
        {'alist': ['Ex for stream_mode="values"'], 'another_list': []}
        {'alist': ['Ex for stream_mode="values"'], 'another_list': ['hi']}
        {'alist': ['Ex for stream_mode="values"', 'there'], 'another_list': ['hi']}
        ```
        With stream_mode="updates":

        ```pycon
        >>> for event in graph.stream({"alist": ['Ex for stream_mode="updates"']}, stream_mode="updates"):
        ...     print(event)
        {'a': {'another_list': ['hi']}}
        {'b': {'alist': ['there']}}
        ```
        With stream_mode="debug":

        ```pycon
        >>> for event in graph.stream({"alist": ['Ex for stream_mode="debug"']}, stream_mode="debug"):
        ...     print(event)
        {'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': []}, 'triggers': ['start:a']}}
        {'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'result': [('another_list', ['hi'])]}}
        {'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': ['hi']}, 'triggers': ['a']}}
        {'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'result': [('alist', ['there'])]}}
        ```
    """

    # single queue fed by the loop, the messages handler and the custom
    # writer; every item is a (namespace, mode, payload) tuple
    stream = SyncQueue()

    def output() -> Iterator:
        # drain whatever is queued right now without blocking, shaping each
        # item by `subgraphs` and by whether `stream_mode` is a list
        # NOTE(review): this inspects the raw `stream_mode` argument, not the
        # `stream_modes` resolved by `_defaults`; when `stream_mode` is None the
        # single-mode shape is used regardless of the default — confirm intended.
        while True:
            try:
                ns, mode, payload = stream.get(block=False)
            except queue.Empty:
                break
            if subgraphs and isinstance(stream_mode, list):
                yield (ns, mode, payload)
            elif isinstance(stream_mode, list):
                yield (mode, payload)
            elif subgraphs:
                yield (ns, payload)
            else:
                yield payload

    config = ensure_config(self.config, config)
    callback_manager = get_callback_manager_for_config(config)
    run_manager = callback_manager.on_chain_start(
        None,
        input,
        name=config.get("run_name", self.get_name()),
        run_id=config.get("run_id"),
    )
    try:
        # assign defaults
        (
            debug,
            stream_modes,
            output_keys,
            interrupt_before_,
            interrupt_after_,
            checkpointer,
            store,
        ) = self._defaults(
            config,
            stream_mode=stream_mode,
            output_keys=output_keys,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            debug=debug,
        )
        # set up messages stream mode
        if "messages" in stream_modes:
            run_manager.inheritable_handlers.append(
                StreamMessagesHandler(stream.put)
            )
        # set up custom stream mode
        # (custom writes are tagged with the root namespace `()`)
        if "custom" in stream_modes:
            config[CONF][CONFIG_KEY_STREAM_WRITER] = lambda c: stream.put(
                ((), "custom", c)
            )
        with SyncPregelLoop(
            input,
            stream=StreamProtocol(stream.put, stream_modes),
            config=config,
            store=store,
            checkpointer=checkpointer,
            nodes=self.nodes,
            specs=self.channels,
            output_keys=output_keys,
            stream_keys=self.stream_channels_asis,
            interrupt_before=interrupt_before_,
            interrupt_after=interrupt_after_,
            manager=run_manager,
            debug=debug,
        ) as loop:
            # create runner
            runner = PregelRunner(
                submit=loop.submit,
                put_writes=loop.put_writes,
                schedule_task=loop.accept_push,
                node_finished=config[CONF].get(CONFIG_KEY_NODE_FINISHED),
            )
            # enable subgraph streaming
            if subgraphs:
                loop.config[CONF][CONFIG_KEY_STREAM] = loop.stream
            # enable concurrent streaming
            if subgraphs or "messages" in stream_modes or "custom" in stream_modes:
                # we are careful to have a single waiter live at any one time
                # because on exit we increment semaphore count by exactly 1
                waiter: Optional[concurrent.futures.Future] = None
                # because sync futures cannot be cancelled, we instead
                # release the stream semaphore on exit, which will cause
                # a pending waiter to return immediately
                loop.stack.callback(stream._count.release)

                def get_waiter() -> concurrent.futures.Future[None]:
                    # reuse the pending waiter, or submit a new one once the
                    # previous one has completed
                    nonlocal waiter
                    if waiter is None or waiter.done():
                        waiter = loop.submit(stream.wait)
                        return waiter
                    else:
                        return waiter

            else:
                get_waiter = None  # type: ignore[assignment]
            # Similarly to Bulk Synchronous Parallel / Pregel model
            # computation proceeds in steps, while there are channel updates
            # channel updates from step N are only visible in step N+1
            # channels are guaranteed to be immutable for the duration of the step,
            # with channel updates applied only at the transition between steps
            while loop.tick(input_keys=self.input_channels):
                for _ in runner.tick(
                    loop.tasks.values(),
                    timeout=self.step_timeout,
                    retry_policy=self.retry_policy,
                    get_waiter=get_waiter,
                ):
                    # emit output
                    yield from output()
        # emit output
        # (anything queued while the loop was exiting)
        yield from output()
        # handle exit
        if loop.status == "out_of_steps":
            msg = create_error_message(
                message=(
                    f"Recursion limit of {config['recursion_limit']} reached "
                    "without hitting a stop condition. You can increase the "
                    "limit by setting the `recursion_limit` config key."
                ),
                error_code=ErrorCode.GRAPH_RECURSION_LIMIT,
            )
            raise GraphRecursionError(msg)
        # set final channel values as run output
        run_manager.on_chain_end(loop.output)
    except BaseException as e:
        # BaseException also covers GeneratorExit (consumer stopped early), so
        # callbacks are notified of every abnormal exit before re-raising
        run_manager.on_chain_error(e)
        raise

astream(input: Union[dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None, output_keys: Optional[Union[str, Sequence[str]]] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, debug: Optional[bool] = None, subgraphs: bool = False) -> AsyncIterator[Union[dict[str, Any], Any]] async

Stream graph steps for a single input.

Parameters:

  • input (Union[dict[str, Any], Any]) –

    The input to the graph.

  • config (Optional[RunnableConfig], default: None ) –

    The configuration to use for the run.

  • stream_mode (Optional[Union[StreamMode, list[StreamMode]]], default: None ) –

    The mode to stream output, defaults to self.stream_mode. Options are 'values', 'updates', and 'debug'. values: Emit the current values of the state for each step. updates: Emit only the updates to the state for each step. Output is a dict with the node name as key and the updated values as value. debug: Emit debug events for each step.

  • output_keys (Optional[Union[str, Sequence[str]]], default: None ) –

    The keys to stream, defaults to all non-context channels.

  • interrupt_before (Optional[Union[All, Sequence[str]]], default: None ) –

    Nodes to interrupt before, defaults to all nodes in the graph.

  • interrupt_after (Optional[Union[All, Sequence[str]]], default: None ) –

    Nodes to interrupt after, defaults to all nodes in the graph.

  • debug (Optional[bool], default: None ) –

    Whether to print debug information during execution, defaults to False.

  • subgraphs (bool, default: False ) –

    Whether to stream subgraphs, defaults to False.

Yields:

  • AsyncIterator[Union[dict[str, Any], Any]]

    The output of each step in the graph. The output shape depends on the stream_mode.

Examples:

Using different stream modes with a graph:

>>> import operator
>>> from typing_extensions import Annotated, TypedDict
>>> from langgraph.graph import StateGraph
>>> from langgraph.constants import START
...
>>> class State(TypedDict):
...     alist: Annotated[list, operator.add]
...     another_list: Annotated[list, operator.add]
...
>>> builder = StateGraph(State)
>>> builder.add_node("a", lambda _state: {"another_list": ["hi"]})
>>> builder.add_node("b", lambda _state: {"alist": ["there"]})
>>> builder.add_edge("a", "b")
>>> builder.add_edge(START, "a")
>>> graph = builder.compile()
With stream_mode="values":

>>> async for event in graph.astream({"alist": ['Ex for stream_mode="values"']}, stream_mode="values"):
...     print(event)
{'alist': ['Ex for stream_mode="values"'], 'another_list': []}
{'alist': ['Ex for stream_mode="values"'], 'another_list': ['hi']}
{'alist': ['Ex for stream_mode="values"', 'there'], 'another_list': ['hi']}
With stream_mode="updates":

>>> async for event in graph.astream({"alist": ['Ex for stream_mode="updates"']}, stream_mode="updates"):
...     print(event)
{'a': {'another_list': ['hi']}}
{'b': {'alist': ['there']}}
With stream_mode="debug":

>>> async for event in graph.astream({"alist": ['Ex for stream_mode="debug"']}, stream_mode="debug"):
...     print(event)
{'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': []}, 'triggers': ['start:a']}}
{'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'result': [('another_list', ['hi'])]}}
{'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': ['hi']}, 'triggers': ['a']}}
{'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'result': [('alist', ['there'])]}}
Source code in libs/langgraph/langgraph/pregel/__init__.py
async def astream(
    self,
    input: Union[dict[str, Any], Any],
    config: Optional[RunnableConfig] = None,
    *,
    stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None,
    output_keys: Optional[Union[str, Sequence[str]]] = None,
    interrupt_before: Optional[Union[All, Sequence[str]]] = None,
    interrupt_after: Optional[Union[All, Sequence[str]]] = None,
    debug: Optional[bool] = None,
    subgraphs: bool = False,
) -> AsyncIterator[Union[dict[str, Any], Any]]:
    """Stream graph steps for a single input.

    Args:
        input: The input to the graph.
        config: The configuration to use for the run.
        stream_mode: The mode to stream output, defaults to self.stream_mode.
            Options are 'values', 'updates', and 'debug'.
            values: Emit the current values of the state for each step.
            updates: Emit only the updates to the state for each step.
                Output is a dict with the node name as key and the updated values as value.
            debug: Emit debug events for each step.
            The body also handles 'messages' (chunks emitted through a
            StreamMessagesHandler added to the run's inheritable callbacks)
            and 'custom' (values passed to the stream writer placed in the
            config under CONFIG_KEY_STREAM_WRITER). A list of modes may be
            passed to stream several modes at once.
        output_keys: The keys to stream, defaults to all non-context channels.
        interrupt_before: Nodes to interrupt before. When None, the value is
            resolved by ``self._defaults`` (NOTE(review): presumably the
            graph's configured interrupt nodes rather than every node —
            confirm against ``_defaults``).
        interrupt_after: Nodes to interrupt after. Resolved the same way as
            ``interrupt_before``.
        debug: Whether to print debug information during execution, defaults to False.
        subgraphs: Whether to stream subgraphs, defaults to False.

    Yields:
        The output of each step in the graph. The shape of each item depends
        on ``stream_mode`` and ``subgraphs``:

        - ``stream_mode`` is a list and ``subgraphs`` is True: ``(namespace, mode, payload)``
        - ``stream_mode`` is a list: ``(mode, payload)``
        - ``subgraphs`` is True: ``(namespace, payload)``
        - otherwise: ``payload``

    Raises:
        GraphRecursionError: If the run finishes with status "out_of_steps",
            i.e. the configured ``recursion_limit`` was reached.
        Any exception raised while running is first reported through
        ``run_manager.on_chain_error`` and then re-raised.

    Examples:
        Using different stream modes with a graph:
        ```pycon
        >>> import operator
        >>> from typing_extensions import Annotated, TypedDict
        >>> from langgraph.graph import StateGraph
        >>> from langgraph.constants import START
        ...
        >>> class State(TypedDict):
        ...     alist: Annotated[list, operator.add]
        ...     another_list: Annotated[list, operator.add]
        ...
        >>> builder = StateGraph(State)
        >>> builder.add_node("a", lambda _state: {"another_list": ["hi"]})
        >>> builder.add_node("b", lambda _state: {"alist": ["there"]})
        >>> builder.add_edge("a", "b")
        >>> builder.add_edge(START, "a")
        >>> graph = builder.compile()
        ```
        With stream_mode="values":

        ```pycon
        >>> async for event in graph.astream({"alist": ['Ex for stream_mode="values"']}, stream_mode="values"):
        ...     print(event)
        {'alist': ['Ex for stream_mode="values"'], 'another_list': []}
        {'alist': ['Ex for stream_mode="values"'], 'another_list': ['hi']}
        {'alist': ['Ex for stream_mode="values"', 'there'], 'another_list': ['hi']}
        ```
        With stream_mode="updates":

        ```pycon
        >>> async for event in graph.astream({"alist": ['Ex for stream_mode="updates"']}, stream_mode="updates"):
        ...     print(event)
        {'a': {'another_list': ['hi']}}
        {'b': {'alist': ['there']}}
        ```
        With stream_mode="debug":

        ```pycon
        >>> async for event in graph.astream({"alist": ['Ex for stream_mode="debug"']}, stream_mode="debug"):
        ...     print(event)
        {'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': []}, 'triggers': ['start:a']}}
        {'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'result': [('another_list', ['hi'])]}}
        {'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': ['hi']}, 'triggers': ['a']}}
        {'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'result': [('alist', ['there'])]}}
        ```
    """

    # Every stream chunk is funneled through this one queue as a
    # (namespace, mode, payload) tuple (see the unpacking in output()).
    stream = AsyncQueue()
    aioloop = asyncio.get_running_loop()
    # Producers that may run outside this event loop's thread (subgraph
    # streams, the messages handler) enqueue via call_soon_threadsafe.
    stream_put = cast(
        Callable[[StreamChunk], None],
        partial(aioloop.call_soon_threadsafe, stream.put_nowait),
    )

    def output() -> Iterator:
        # Drain every chunk queued so far without waiting, shaping each one
        # according to whether stream_mode is a list and whether subgraph
        # namespaces were requested.
        while True:
            try:
                ns, mode, payload = stream.get_nowait()
            except asyncio.QueueEmpty:
                break
            if subgraphs and isinstance(stream_mode, list):
                yield (ns, mode, payload)
            elif isinstance(stream_mode, list):
                yield (mode, payload)
            elif subgraphs:
                yield (ns, payload)
            else:
                yield payload

    config = ensure_config(self.config, config)
    callback_manager = get_async_callback_manager_for_config(config)
    run_manager = await callback_manager.on_chain_start(
        None,
        input,
        name=config.get("run_name", self.get_name()),
        run_id=config.get("run_id"),
    )
    # if running from astream_log() run each proc with streaming
    do_stream = next(
        (
            cast(_StreamingCallbackHandler, h)
            for h in run_manager.handlers
            if isinstance(h, _StreamingCallbackHandler)
        ),
        None,
    )
    try:
        # assign defaults
        (
            debug,
            stream_modes,
            output_keys,
            interrupt_before_,
            interrupt_after_,
            checkpointer,
            store,
        ) = self._defaults(
            config,
            stream_mode=stream_mode,
            output_keys=output_keys,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            debug=debug,
        )
        # set up messages stream mode
        if "messages" in stream_modes:
            run_manager.inheritable_handlers.append(
                StreamMessagesHandler(stream_put)
            )
        # set up custom stream mode
        # (values written through this callable surface as top-level "custom"
        # chunks with an empty namespace)
        if "custom" in stream_modes:
            config[CONF][CONFIG_KEY_STREAM_WRITER] = (
                lambda c: aioloop.call_soon_threadsafe(
                    stream.put_nowait, ((), "custom", c)
                )
            )
        async with AsyncPregelLoop(
            input,
            stream=StreamProtocol(stream.put_nowait, stream_modes),
            config=config,
            store=store,
            checkpointer=checkpointer,
            nodes=self.nodes,
            specs=self.channels,
            output_keys=output_keys,
            stream_keys=self.stream_channels_asis,
            interrupt_before=interrupt_before_,
            interrupt_after=interrupt_after_,
            manager=run_manager,
            debug=debug,
        ) as loop:
            # create runner
            runner = PregelRunner(
                submit=loop.submit,
                put_writes=loop.put_writes,
                schedule_task=loop.accept_push,
                use_astream=do_stream is not None,
                node_finished=config[CONF].get(CONFIG_KEY_NODE_FINISHED),
            )
            # enable subgraph streaming
            # (nested graphs find this StreamProtocol in their config and
            # write into the same queue through the thread-safe stream_put)
            if subgraphs:
                loop.config[CONF][CONFIG_KEY_STREAM] = StreamProtocol(
                    stream_put, stream_modes
                )
            # enable concurrent streaming
            # When chunks can arrive while tasks are still running (subgraphs,
            # messages, custom), give the runner a factory for a task awaiting
            # stream.wait(). NOTE(review): presumably this lets runner.atick
            # yield mid-step so output() can flush queued chunks — confirm
            # against PregelRunner.atick.
            if subgraphs or "messages" in stream_modes or "custom" in stream_modes:

                def get_waiter() -> asyncio.Task[None]:
                    return aioloop.create_task(stream.wait())

            else:
                get_waiter = None  # type: ignore[assignment]
            # Similarly to Bulk Synchronous Parallel / Pregel model
            # computation proceeds in steps, while there are channel updates
            # channel updates from step N are only visible in step N+1
            # channels are guaranteed to be immutable for the duration of the step,
            # with channel updates applied only at the transition between steps
            while loop.tick(input_keys=self.input_channels):
                async for _ in runner.atick(
                    loop.tasks.values(),
                    timeout=self.step_timeout,
                    retry_policy=self.retry_policy,
                    get_waiter=get_waiter,
                ):
                    # emit output
                    for o in output():
                        yield o
        # emit output
        # (flush anything queued after the last tick, including chunks
        # written while the loop context was exiting)
        for o in output():
            yield o
        # handle exit
        if loop.status == "out_of_steps":
            msg = create_error_message(
                message=(
                    f"Recursion limit of {config['recursion_limit']} reached "
                    "without hitting a stop condition. You can increase the "
                    "limit by setting the `recursion_limit` config key."
                ),
                error_code=ErrorCode.GRAPH_RECURSION_LIMIT,
            )
            raise GraphRecursionError(msg)
        # set final channel values as run output
        await run_manager.on_chain_end(loop.output)
    except BaseException as e:
        # shield() keeps the error callback running to completion even if
        # this task is itself being cancelled.
        await asyncio.shield(run_manager.on_chain_error(e))
        raise

invoke(input: Union[dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, stream_mode: StreamMode = 'values', output_keys: Optional[Union[str, Sequence[str]]] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, debug: Optional[bool] = None, **kwargs: Any) -> Union[dict[str, Any], Any]

Run the graph with a single input and config.

Parameters:

  • input (Union[dict[str, Any], Any]) –

    The input data for the graph. It can be a dictionary or any other type.

  • config (Optional[RunnableConfig], default: None ) –

    Optional. The configuration for the graph run.

  • stream_mode (StreamMode, default: 'values' ) –

    Optional[str]. The stream mode for the graph run. Default is "values".

  • output_keys (Optional[Union[str, Sequence[str]]], default: None ) –

    Optional. The output keys to retrieve from the graph run.

  • interrupt_before (Optional[Union[All, Sequence[str]]], default: None ) –

    Optional. The nodes to interrupt the graph run before.

  • interrupt_after (Optional[Union[All, Sequence[str]]], default: None ) –

    Optional. The nodes to interrupt the graph run after.

  • debug (Optional[bool], default: None ) –

    Optional. Enable debug mode for the graph run.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments to pass to the graph run.

Returns:

  • Union[dict[str, Any], Any]

    The output of the graph run. If stream_mode is "values", it returns the latest output.

  • Union[dict[str, Any], Any]

    If stream_mode is not "values", it returns a list of output chunks.

Source code in libs/langgraph/langgraph/pregel/__init__.py
def invoke(
    self,
    input: Union[dict[str, Any], Any],
    config: Optional[RunnableConfig] = None,
    *,
    stream_mode: StreamMode = "values",
    output_keys: Optional[Union[str, Sequence[str]]] = None,
    interrupt_before: Optional[Union[All, Sequence[str]]] = None,
    interrupt_after: Optional[Union[All, Sequence[str]]] = None,
    debug: Optional[bool] = None,
    **kwargs: Any,
) -> Union[dict[str, Any], Any]:
    """Run the graph on a single input and return its output.

    The run is driven to completion by consuming ``self.stream``.

    Args:
        input: The input data for the graph. It can be a dictionary or any other type.
        config: Optional. The configuration for the graph run.
        stream_mode: The stream mode for the graph run. Default is "values".
        output_keys: Optional. The output keys to retrieve from the graph run.
            Falls back to ``self.output_channels`` when None.
        interrupt_before: Optional. The nodes to interrupt the graph run before.
        interrupt_after: Optional. The nodes to interrupt the graph run after.
        debug: Optional. Enable debug mode for the graph run.
        **kwargs: Additional keyword arguments to pass to the graph run.

    Returns:
        With stream_mode "values", the last chunk streamed (None if nothing
        was streamed). With any other stream_mode, a list of every chunk
        streamed, in order.
    """
    if output_keys is None:
        output_keys = self.output_channels
    keep_last_only = stream_mode == "values"
    chunk_iter = self.stream(
        input,
        config,
        stream_mode=stream_mode,
        output_keys=output_keys,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        debug=debug,
        **kwargs,
    )
    if not keep_last_only:
        return list(chunk_iter)
    # Exhaust the stream; the loop variable ends up holding the final chunk.
    last: Union[dict[str, Any], Any] = None
    for last in chunk_iter:
        pass
    return last

ainvoke(input: Union[dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, stream_mode: StreamMode = 'values', output_keys: Optional[Union[str, Sequence[str]]] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, debug: Optional[bool] = None, **kwargs: Any) -> Union[dict[str, Any], Any] async

Asynchronously invoke the graph on a single input.

Parameters:

  • input (Union[dict[str, Any], Any]) –

    The input data for the computation. It can be a dictionary or any other type.

  • config (Optional[RunnableConfig], default: None ) –

    Optional. The configuration for the computation.

  • stream_mode (StreamMode, default: 'values' ) –

    Optional. The stream mode for the computation. Default is "values".

  • output_keys (Optional[Union[str, Sequence[str]]], default: None ) –

    Optional. The output keys to include in the result. Default is None.

  • interrupt_before (Optional[Union[All, Sequence[str]]], default: None ) –

    Optional. The nodes to interrupt before. Default is None.

  • interrupt_after (Optional[Union[All, Sequence[str]]], default: None ) –

    Optional. The nodes to interrupt after. Default is None.

  • debug (Optional[bool], default: None ) –

    Optional. Whether to enable debug mode. Default is None.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.

Returns:

  • Union[dict[str, Any], Any]

    The result of the computation. If stream_mode is "values", it returns the latest value.

  • Union[dict[str, Any], Any]

    If stream_mode is not "values", it returns a list of chunks.

Source code in libs/langgraph/langgraph/pregel/__init__.py
async def ainvoke(
    self,
    input: Union[dict[str, Any], Any],
    config: Optional[RunnableConfig] = None,
    *,
    stream_mode: StreamMode = "values",
    output_keys: Optional[Union[str, Sequence[str]]] = None,
    interrupt_before: Optional[Union[All, Sequence[str]]] = None,
    interrupt_after: Optional[Union[All, Sequence[str]]] = None,
    debug: Optional[bool] = None,
    **kwargs: Any,
) -> Union[dict[str, Any], Any]:
    """Asynchronously invoke the graph on a single input.

    The run is driven to completion by consuming ``self.astream``.

    Args:
        input: The input data for the computation. It can be a dictionary or any other type.
        config: Optional. The configuration for the computation.
        stream_mode: Optional. The stream mode for the computation. Default is "values".
        output_keys: Optional. The output keys to include in the result. When
            None, ``self.output_channels`` is used.
        interrupt_before: Optional. The nodes to interrupt before. Default is None.
        interrupt_after: Optional. The nodes to interrupt after. Default is None.
        debug: Optional. Whether to enable debug mode. Default is None.
        **kwargs: Additional keyword arguments forwarded to ``astream``.

    Returns:
        If stream_mode is "values", the last chunk streamed (None if nothing
        was streamed). For any other stream_mode, a list of every chunk
        streamed, in order.
    """
    output_keys = output_keys if output_keys is not None else self.output_channels
    chunks = self.astream(
        input,
        config,
        stream_mode=stream_mode,
        output_keys=output_keys,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        debug=debug,
        **kwargs,
    )
    if stream_mode != "values":
        return [chunk async for chunk in chunks]
    # Only the final state snapshot is returned in "values" mode.
    latest: Union[dict[str, Any], Any] = None
    async for latest in chunks:
        pass
    return latest

get_graph(config: Optional[RunnableConfig] = None, *, xray: Union[int, bool] = False) -> DrawableGraph

Returns a drawable representation of the computation graph.

Source code in libs/langgraph/langgraph/graph/graph.py
def get_graph(
    self,
    config: Optional[RunnableConfig] = None,
    *,
    xray: Union[int, bool] = False,
) -> DrawableGraph:
    """Returns a drawable representation of the computation graph.

    Args:
        config: Optional run configuration, forwarded to the input/output
            schema lookups and to nested ``get_graph`` calls.
        xray: When truthy, nodes that are ``CompiledGraph`` subgraphs are
            expanded inline. A positive ``int`` is decremented at each level of
            nesting, so it bounds the expansion depth; ``True`` is passed down
            unchanged and expands every level.

    Returns:
        A ``DrawableGraph`` containing a START node, one node per builder node
        (or the expanded subgraph), an END node if any edge targets it, and
        edges for the builder's ``_all_edges``, its conditional branches, and
        any node-declared ``ends``.

    Raises:
        ValueError: If an expanded subgraph cannot be attached because
            ``graph.extend`` returned no node.
    """
    graph = DrawableGraph()
    # Builder node name -> drawable node that edges leave from (start_nodes)
    # or arrive at (end_nodes). They differ only for subgraphs expanded inline.
    start_nodes: dict[str, DrawableNode] = {
        START: graph.add_node(self.get_input_schema(config), START)
    }
    end_nodes: dict[str, DrawableNode] = {}
    if xray:
        subgraphs = {
            k: v for k, v in self.get_subgraphs() if isinstance(v, CompiledGraph)
        }
    else:
        subgraphs = {}

    def add_edge(
        start: str,
        end: str,
        label: Optional[Hashable] = None,
        conditional: bool = False,
    ) -> None:
        # The END node is only drawn once some edge actually targets it.
        if end == END and END not in end_nodes:
            end_nodes[END] = graph.add_node(self.get_output_schema(config), END)
        graph.add_edge(
            start_nodes[start],
            end_nodes[end],
            str(label) if label is not None else None,
            conditional,
        )

    for key, n in self.builder.nodes.items():
        node = n.runnable
        # Work on a copy: n.metadata belongs to the builder and outlives this
        # call, so writing "__interrupt" into it directly would leak into later
        # get_graph() calls and into other graphs built from the same builder.
        metadata = dict(n.metadata) if n.metadata else {}
        if key in self.interrupt_before_nodes and key in self.interrupt_after_nodes:
            metadata["__interrupt"] = "before,after"
        elif key in self.interrupt_before_nodes:
            metadata["__interrupt"] = "before"
        elif key in self.interrupt_after_nodes:
            metadata["__interrupt"] = "after"
        expanded = False
        if xray and key in subgraphs:
            # A positive int loses one level of depth per nesting; bools
            # (notably True) are forwarded unchanged.
            sub_xray = (
                xray - 1
                if isinstance(xray, int) and not isinstance(xray, bool) and xray > 0
                else xray
            )
            subgraph = subgraphs[key].get_graph(config=config, xray=sub_xray)
            # NOTE(review): presumably strips the subgraph's own START/END
            # placeholders so the parent's edges attach to its real nodes.
            subgraph.trim_first_node()
            subgraph.trim_last_node()
            if len(subgraph.nodes) > 1:
                e, s = graph.extend(subgraph, prefix=key)
                if e is None:
                    raise ValueError(
                        f"Could not extend subgraph '{key}' due to missing entrypoint"
                    )
                if s is not None:
                    start_nodes[key] = s
                end_nodes[key] = e
                expanded = True
        if not expanded:
            # Plain node, or a subgraph too small to be worth expanding.
            nn = graph.add_node(node, key, metadata=metadata or None)
            start_nodes[key] = nn
            end_nodes[key] = nn
    for start, end in sorted(self.builder._all_edges):
        add_edge(start, end)
    for start, branches in self.builder.branches.items():
        # Without explicit ends, a branch may route to any other node or END.
        default_ends = {
            **{k: k for k in self.builder.nodes if k != start},
            END: END,
        }
        for branch in branches.values():
            if branch.ends is not None:
                ends = branch.ends
            elif branch.then is not None:
                ends = {k: k for k in default_ends if k not in (END, branch.then)}
            else:
                ends = cast(dict[Hashable, str], default_ends)
            for label, end in ends.items():
                add_edge(
                    start,
                    end,
                    label if label != end else None,
                    conditional=True,
                )
                if branch.then is not None:
                    add_edge(end, branch.then)
    # Nodes that declare their possible destinations get conditional edges.
    for key, n in self.builder.nodes.items():
        if n.ends:
            for end in n.ends:
                add_edge(key, end, conditional=True)

    return graph

StateGraph

Bases: Graph

A graph whose nodes communicate by reading and writing to a shared state. The signature of each node is State -> Partial<State>.

Each state key can optionally be annotated with a reducer function that will be used to aggregate the values of that key received from multiple nodes. The signature of a reducer function is (Value, Value) -> Value.

Parameters:

  • state_schema (Type[Any], default: None ) –

    The schema class that defines the state.

  • config_schema (Optional[Type[Any]], default: None ) –

    The schema class that defines the configuration. Use this to expose configurable parameters in your API.

Examples:

>>> from langchain_core.runnables import RunnableConfig
>>> from typing_extensions import Annotated, TypedDict
>>> from langgraph.checkpoint.memory import MemorySaver
>>> from langgraph.graph import StateGraph
>>>
>>> def reducer(a: list, b: int | None) -> list:
...     if b is not None:
...         return a + [b]
...     return a
>>>
>>> class State(TypedDict):
...     x: Annotated[list, reducer]
>>>
>>> class ConfigSchema(TypedDict):
...     r: float
>>>
>>> graph = StateGraph(State, config_schema=ConfigSchema)
>>>
>>> def node(state: State, config: RunnableConfig) -> dict:
...     r = config["configurable"].get("r", 1.0)
...     x = state["x"][-1]
...     next_value = x * r * (1 - x)
...     return {"x": next_value}
>>>
>>> graph.add_node("A", node)
>>> graph.set_entry_point("A")
>>> graph.set_finish_point("A")
>>> compiled = graph.compile()
>>>
>>> print(compiled.config_specs)
[ConfigurableFieldSpec(id='r', annotation=<class 'float'>, name=None, description=None, default=None, is_shared=False, dependencies=None)]
>>>
>>> step1 = compiled.invoke({"x": 0.5}, {"configurable": {"r": 3.0}})
>>> print(step1)
{'x': [0.5, 0.75]}
Source code in libs/langgraph/langgraph/graph/state.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
class StateGraph(Graph):
    """A graph whose nodes communicate by reading and writing to a shared state.
    The signature of each node is State -> Partial<State>.

    Each state key can optionally be annotated with a reducer function that
    will be used to aggregate the values of that key received from multiple nodes.
    The signature of a reducer function is (Value, Value) -> Value.

    Args:
        state_schema (Type[Any]): The schema class that defines the state.
        config_schema (Optional[Type[Any]]): The schema class that defines the configuration.
            Use this to expose configurable parameters in your API.


    Examples:
        >>> from langchain_core.runnables import RunnableConfig
        >>> from typing_extensions import Annotated, TypedDict
        >>> from langgraph.checkpoint.memory import MemorySaver
        >>> from langgraph.graph import StateGraph
        >>>
        >>> def reducer(a: list, b: int | None) -> list:
        ...     if b is not None:
        ...         return a + [b]
        ...     return a
        >>>
        >>> class State(TypedDict):
        ...     x: Annotated[list, reducer]
        >>>
        >>> class ConfigSchema(TypedDict):
        ...     r: float
        >>>
        >>> graph = StateGraph(State, config_schema=ConfigSchema)
        >>>
        >>> def node(state: State, config: RunnableConfig) -> dict:
        ...     r = config["configurable"].get("r", 1.0)
        ...     x = state["x"][-1]
        ...     next_value = x * r * (1 - x)
        ...     return {"x": next_value}
        >>>
        >>> graph.add_node("A", node)
        >>> graph.set_entry_point("A")
        >>> graph.set_finish_point("A")
        >>> compiled = graph.compile()
        >>>
        >>> print(compiled.config_specs)
        [ConfigurableFieldSpec(id='r', annotation=<class 'float'>, name=None, description=None, default=None, is_shared=False, dependencies=None)]
        >>>
        >>> step1 = compiled.invoke({"x": 0.5}, {"configurable": {"r": 3.0}})
        >>> print(step1)
        {'x': [0.5, 0.75]}"""

    nodes: dict[str, StateNodeSpec]  # type: ignore[assignment]
    channels: dict[str, BaseChannel]
    managed: dict[str, ManagedValueSpec]
    schemas: dict[Type[Any], dict[str, Union[BaseChannel, ManagedValueSpec]]]

    def __init__(
        self,
        state_schema: Optional[Type[Any]] = None,
        config_schema: Optional[Type[Any]] = None,
        *,
        input: Optional[Type[Any]] = None,
        output: Optional[Type[Any]] = None,
    ) -> None:
        """Create a state graph from a state schema.

        Args:
            state_schema: Schema defining the shared state. If omitted, both
                ``input`` and ``output`` must be given; ``input`` is then used
                as the state schema and a deprecation warning is emitted.
            config_schema: Optional schema describing configurable parameters.
            input: Input schema; defaults to ``state_schema``.
            output: Output schema; defaults to ``state_schema``.

        Raises:
            ValueError: If neither ``state_schema`` nor both ``input`` and
                ``output`` are given, or if ``_add_schema`` rejects a schema.
        """
        super().__init__()
        if state_schema is not None:
            input = state_schema if input is None else input
            output = state_schema if output is None else output
        elif input is not None and output is not None:
            state_schema = input
            warnings.warn(
                "Initializing StateGraph without state_schema is deprecated. "
                "Please pass in an explicit state_schema instead of just an input and output schema.",
                LangGraphDeprecationWarning,
                stacklevel=2,
            )
        else:
            raise ValueError("Must provide state_schema or input and output")
        # These registries are populated by _add_schema, so they must exist first.
        self.schemas = {}
        self.channels = {}
        self.managed = {}
        self.schema = state_schema
        self.input = input
        self.output = output
        # Only the state schema may declare managed values.
        self._add_schema(state_schema)
        self._add_schema(input, allow_managed=False)
        self._add_schema(output, allow_managed=False)
        self.config_schema = config_schema
        self.waiting_edges: set[tuple[tuple[str, ...], str]] = set()

    @property
    def _all_edges(self) -> set[tuple[str, str]]:
        """All plain edges plus waiting edges flattened to (source, target) pairs.

        Each waiting edge ``(sources, target)`` contributes one pair per source.
        """
        fanned_in = {
            (source, target)
            for sources, target in self.waiting_edges
            for source in sources
        }
        return self.edges | fanned_in

    def _add_schema(self, schema: Type[Any], /, allow_managed: bool = True) -> None:
        """Register the channels and managed values declared by ``schema``.

        A schema already present in ``self.schemas`` is skipped. Otherwise its
        channels and managed values are merged into the graph-wide
        ``self.channels`` / ``self.managed`` registries, and their combined
        mapping is stored in ``self.schemas[schema]``.

        Args:
            schema: The schema type to inspect.
            allow_managed: When False (as for input/output schemas), any
                managed value declared by ``schema`` is rejected.

        Raises:
            ValueError: If managed values appear where they are not allowed,
                if a non-``LastValue`` channel differs from an existing channel
                of the same name, or if a managed value differs from an
                existing one of the same name.
        """
        if schema in self.schemas:
            return
        _warn_invalid_state_schema(schema)
        channels, managed = _get_channels(schema)
        if managed and not allow_managed:
            names = ", ".join(managed)
            schema_name = getattr(schema, "__name__", "")
            raise ValueError(
                f"Invalid managed channels detected in {schema_name}: {names}."
                " Managed channels are not permitted in Input/Output schema."
            )
        self.schemas[schema] = {**channels, **managed}
        for key, channel in channels.items():
            if key not in self.channels:
                self.channels[key] = channel
            elif self.channels[key] != channel and not isinstance(channel, LastValue):
                # A differing LastValue redeclaration is tolerated (the existing
                # channel is kept); any other mismatch is an error.
                raise ValueError(
                    f"Channel '{key}' already exists with a different type"
                )
        # Use a distinct loop name so the `managed` mapping is not rebound.
        for key, spec in managed.items():
            if key not in self.managed:
                self.managed[key] = spec
            elif self.managed[key] != spec:
                raise ValueError(
                    f"Managed value '{key}' already exists with a different type"
                )

    @overload
    def add_node(
        self,
        node: RunnableLike,
        *,
        metadata: Optional[dict[str, Any]] = None,
        input: Optional[Type[Any]] = None,
        retry: Optional[RetryPolicy] = None,
    ) -> Self:
        """Adds a new node to the state graph.
        Will take the name of the function/runnable as the node name: a
        Runnable's ``get_name()``, otherwise the callable's ``__name__``
        (falling back to its class name).

        Args:
            node (RunnableLike): The function or runnable this node will run.
            metadata (Optional[dict[str, Any]]): The metadata associated with the node.
            input (Optional[Type[Any]]): The input schema for the node.
            retry (Optional[RetryPolicy]): The policy for retrying the node.

        Raises:
            ValueError: If the key is already being used as a state key, if no
                name can be derived, if a node with that name already exists,
                if the name is START or END, or if it contains a reserved
                namespace character.

        Returns:
            StateGraph
        """
        ...

    @overload
    def add_node(
        self,
        node: str,
        action: RunnableLike,
        *,
        metadata: Optional[dict[str, Any]] = None,
        input: Optional[Type[Any]] = None,
        retry: Optional[RetryPolicy] = None,
    ) -> Self:
        """Adds a new node to the state graph under an explicit name.

        Args:
            node (str): The key of the node.
            action (RunnableLike): The action associated with the node.
            metadata (Optional[dict[str, Any]]): The metadata associated with the node.
            input (Optional[Type[Any]]): The input schema for the node.
            retry (Optional[RetryPolicy]): The policy for retrying the node.

        Raises:
            ValueError: If the key is already being used as a state key, if a
                node with that name already exists, if the name is START or
                END, or if it contains a reserved namespace character.

        Returns:
            StateGraph
        """
        ...

    def add_node(
        self,
        node: Union[str, RunnableLike],
        action: Optional[RunnableLike] = None,
        *,
        metadata: Optional[dict[str, Any]] = None,
        input: Optional[Type[Any]] = None,
        retry: Optional[RetryPolicy] = None,
    ) -> Self:
        """Adds a new node to the state graph.

        When ``node`` is not a string it is used as the action, and the node name
        is taken from it: ``get_name()`` for a ``Runnable``, otherwise the
        callable's ``__name__`` (falling back to its class name).

        Args:
            node (Union[str, RunnableLike]): The node name, or the function or
                runnable this node will run.
            action (Optional[RunnableLike]): The action associated with the node.
                Required when ``node`` is a name. (default: None)
            metadata (Optional[dict[str, Any]]): The metadata associated with the node. (default: None)
            input (Optional[Type[Any]]): The input schema for the node. If omitted,
                it may be inferred from an annotated class on the action's first
                parameter; otherwise the graph's schema is used. (default: None)
            retry (Optional[RetryPolicy]): The policy for retrying the node. (default: None)

        Raises:
            ValueError: If the name is already being used as a state key, is
                already a node, is reserved (``START``/``END``), or contains a
                reserved character.
            RuntimeError: If no action was provided.

        Examples:
            ```pycon
            >>> from langgraph.graph import START, StateGraph
            ...
            >>> def my_node(state, config):
            ...    return {"x": state["x"] + 1}
            ...
            >>> builder = StateGraph(dict)
            >>> builder.add_node(my_node)  # node name will be 'my_node'
            >>> builder.add_edge(START, "my_node")
            >>> graph = builder.compile()
            >>> graph.invoke({"x": 1})
            {'x': 2}
            ```
            Customize the name:

            ```pycon
            >>> builder = StateGraph(dict)
            >>> builder.add_node("my_fair_node", my_node)
            >>> builder.add_edge(START, "my_fair_node")
            >>> graph = builder.compile()
            >>> graph.invoke({"x": 1})
            {'x': 2}
            ```

        Returns:
            StateGraph: The graph itself, for method chaining.
        """
        if not isinstance(node, str):
            action = node
            if isinstance(action, Runnable):
                node = action.get_name()
            else:
                node = getattr(action, "__name__", action.__class__.__name__)
            if node is None:
                raise ValueError(
                    "Node name must be provided if action is not a function"
                )
        if node in self.channels:
            raise ValueError(f"'{node}' is already being used as a state key")
        if self.compiled:
            logger.warning(
                "Adding a node to a graph that has already been compiled. This will "
                "not be reflected in the compiled graph."
            )
        if action is None:
            raise RuntimeError(
                "Expected a function or Runnable action in add_node. Received None."
            )
        if node in self.nodes:
            raise ValueError(f"Node `{node}` already present.")
        if node == END or node == START:
            raise ValueError(f"Node `{node}` is reserved.")

        for character in (NS_SEP, NS_END):
            if character in cast(str, node):
                raise ValueError(
                    f"'{character}' is a reserved character and is not allowed in the node names."
                )

        # Best-effort introspection of the action's type hints: infer the input
        # schema from the first parameter's annotation (when `input` is not
        # given) and the possible destinations from a `Command[Literal[...]]`
        # return annotation. Failures here are deliberately non-fatal.
        ends = EMPTY_SEQ
        try:
            if (isfunction(action) or ismethod(getattr(action, "__call__", None))) and (
                hints := get_type_hints(getattr(action, "__call__"))
                or get_type_hints(action)
            ):
                if input is None:
                    first_parameter_name = next(
                        iter(
                            inspect.signature(
                                cast(FunctionType, action)
                            ).parameters.keys()
                        )
                    )
                    if input_hint := hints.get(first_parameter_name):
                        # only adopt classes that themselves declare annotated fields
                        if isinstance(input_hint, type) and get_type_hints(input_hint):
                            input = input_hint
                if (
                    (rtn := hints.get("return"))
                    and get_origin(rtn) in (Command, GraphCommand)
                    and (rargs := get_args(rtn))
                    and get_origin(rargs[0]) is Literal
                    and (vals := get_args(rargs[0]))
                ):
                    ends = vals
        except (NameError, TypeError, StopIteration):
            # NameError: annotations with unresolvable forward references;
            # TypeError: objects get_type_hints cannot inspect;
            # StopIteration: an action that takes no parameters.
            pass
        if input is not None:
            self._add_schema(input)
        self.nodes[cast(str, node)] = StateNodeSpec(
            coerce_to_runnable(action, name=cast(str, node), trace=False),
            metadata,
            input=input or self.schema,
            retry_policy=retry,
            ends=ends,
        )
        return self

    def add_edge(self, start_key: Union[str, list[str]], end_key: str) -> Self:
        """Adds a directed edge from the start node(s) to the end node.

        If the graph transitions to the start_key node, it will always transition to the end_key node next.
        A list of start keys is recorded as a single waiting edge from all of them.

        Args:
            start_key (Union[str, list[str]]): The key(s) of the start node(s) of the edge.
            end_key (str): The key of the end node of the edge.

        Raises:
            ValueError: If ``start_key`` is an empty list, if a start key is 'END',
                if the end key is 'START', or if the start key or end key is not
                present in the graph.

        Returns:
            StateGraph: The graph itself, for method chaining.
        """
        if isinstance(start_key, str):
            # a single start node is handled by the base Graph implementation
            return super().add_edge(start_key, end_key)

        if self.compiled:
            logger.warning(
                "Adding an edge to a graph that has already been compiled. This will "
                "not be reflected in the compiled graph."
            )
        # an empty list would record an edge with no start nodes at all
        if not start_key:
            raise ValueError("start_key must contain at least one node")
        for start in start_key:
            if start == END:
                raise ValueError("END cannot be a start node")
            if start not in self.nodes:
                raise ValueError(f"Need to add_node `{start}` first")
        if end_key == START:
            raise ValueError("START cannot be an end node")
        if end_key != END and end_key not in self.nodes:
            raise ValueError(f"Need to add_node `{end_key}` first")

        self.waiting_edges.add((tuple(start_key), end_key))
        return self

    def add_sequence(
        self,
        nodes: Sequence[Union[RunnableLike, tuple[str, RunnableLike]]],
    ) -> Self:
        """Add nodes that run one after another, in the given order.

        Each entry is either a RunnableLike (its name is inferred from the object)
        or an explicit ``(name, RunnableLike)`` pair. Consecutive entries are
        connected with an edge.

        Args:
            nodes: A sequence of RunnableLike objects (e.g. a LangChain Runnable or a callable) or (name, RunnableLike) tuples.
                If no names are provided, the name will be inferred from the node object (e.g. a runnable or a callable name).
                Each node will be executed in the order provided.

        Raises:
            ValueError: if the sequence is empty.
            ValueError: if the sequence contains duplicate node names.

        Returns:
            StateGraph: The graph itself, for method chaining.
        """
        if len(nodes) == 0:
            raise ValueError("Sequence requires at least one node.")

        prior: Optional[str] = None
        for item in nodes:
            # explicit (name, runnable) pairs take precedence over inferred names
            if isinstance(item, tuple) and len(item) == 2:
                node_name, runnable = item
            else:
                node_name, runnable = _get_node_name(item), item

            if node_name in self.nodes:
                raise ValueError(
                    f"Node names must be unique: node with the name '{node_name}' already exists. "
                    "If you need to use two different runnables/callables with the same name (for example, using `lambda`), please provide them as tuples (name, runnable/callable)."
                )

            self.add_node(node_name, runnable)
            if prior is not None:
                self.add_edge(prior, node_name)
            prior = node_name

        return self

    def compile(
        self,
        checkpointer: Checkpointer = None,
        *,
        store: Optional[BaseStore] = None,
        interrupt_before: Optional[Union[All, list[str]]] = None,
        interrupt_after: Optional[Union[All, list[str]]] = None,
        debug: bool = False,
    ) -> "CompiledStateGraph":
        """Compiles the state graph into a `CompiledGraph` object.

        The compiled graph implements the `Runnable` interface and can be invoked,
        streamed, batched, and run asynchronously.

        Args:
            checkpointer (Optional[Union[Checkpointer, Literal[False]]]): A checkpoint saver object or flag.
                If provided, this Checkpointer serves as a fully versioned "short-term memory" for the graph,
                allowing it to be paused, resumed, and replayed from any point.
                If None, it may inherit the parent graph's checkpointer when used as a subgraph.
                If False, it will not use or inherit any checkpointer.
            store (Optional[BaseStore]): An optional store, passed through to the
                compiled graph. (default: None)
            interrupt_before (Optional[Union[All, list[str]]]): An optional list of node
                names to interrupt before, or the ``"*"`` wildcard.
            interrupt_after (Optional[Union[All, list[str]]]): An optional list of node
                names to interrupt after, or the ``"*"`` wildcard.
            debug (bool): A flag indicating whether to enable debug mode.

        Returns:
            CompiledStateGraph: The compiled state graph.
        """
        # assign default values
        interrupt_before = interrupt_before or []
        interrupt_after = interrupt_after or []

        # validate the graph; only explicit node lists are passed on (the "*"
        # wildcard is skipped). Each argument is filtered independently so that
        # interrupt_after="*" does not suppress validation of an explicit
        # interrupt_before list.
        self.validate(
            interrupt=(
                (interrupt_before if interrupt_before != "*" else [])
                + (interrupt_after if interrupt_after != "*" else [])
            )
        )

        # prepare output channels
        output_channels = (
            "__root__"
            if len(self.schemas[self.output]) == 1
            and "__root__" in self.schemas[self.output]
            else [
                key
                for key, val in self.schemas[self.output].items()
                if not is_managed_value(val)
            ]
        )
        stream_channels = (
            "__root__"
            if len(self.channels) == 1 and "__root__" in self.channels
            else [
                key for key, val in self.channels.items() if not is_managed_value(val)
            ]
        )

        compiled = CompiledStateGraph(
            builder=self,
            config_type=self.config_schema,
            nodes={},
            channels={
                **self.channels,
                **self.managed,
                START: EphemeralValue(self.input),
            },
            input_channels=START,
            stream_mode="updates",
            output_channels=output_channels,
            stream_channels=stream_channels,
            checkpointer=checkpointer,
            interrupt_before_nodes=interrupt_before,
            interrupt_after_nodes=interrupt_after,
            auto_validate=False,
            debug=debug,
            store=store,
        )

        # START has no node spec of its own
        compiled.attach_node(START, None)
        for key, node in self.nodes.items():
            compiled.attach_node(key, node)

        # every node also gets a SELF/CONTROL_BRANCH branch — presumably for
        # routing driven by returned Command objects; TODO confirm
        for key in self.nodes:
            compiled.attach_branch(key, SELF, CONTROL_BRANCH, with_reader=False)

        for start, end in self.edges:
            compiled.attach_edge(start, end)

        # edges with several start nodes (added via add_edge with a list)
        for starts, end in self.waiting_edges:
            compiled.attach_edge(starts, end)

        for start, branches in self.branches.items():
            for name, branch in branches.items():
                compiled.attach_branch(start, name, branch)

        return compiled.validate()

add_conditional_edges(source: str, path: Union[Callable[..., Union[Hashable, list[Hashable]]], Callable[..., Awaitable[Union[Hashable, list[Hashable]]]], Runnable[Any, Union[Hashable, list[Hashable]]]], path_map: Optional[Union[dict[Hashable, str], list[str]]] = None, then: Optional[str] = None) -> Self

Add a conditional edge from the starting node to any number of destination nodes.

Parameters:

  • source (str) –

    The starting node. This conditional edge will run when exiting this node.

  • path (Union[Callable, Runnable]) –

    The callable that determines the next node or nodes. If not specifying path_map it should return one or more nodes. If it returns END, the graph will stop execution.

  • path_map (Optional[Union[dict[Hashable, str], list[str]]], default: None ) –

    Optional mapping of paths to node names, or a list of node names (each mapped to itself). If omitted, the paths returned by path should be node names.

  • then (Optional[str], default: None ) –

    The name of a node to execute after the nodes selected by path.

Returns:

  • Self

    The graph instance itself, so calls can be chained.

Note: Without type hints on the path function's return value (e.g., -> Literal["foo", "__end__"]:) or a path_map, the graph visualization assumes the edge could transition to any node in the graph.

Source code in libs/langgraph/langgraph/graph/graph.py
def add_conditional_edges(
    self,
    source: str,
    path: Union[
        Callable[..., Union[Hashable, list[Hashable]]],
        Callable[..., Awaitable[Union[Hashable, list[Hashable]]]],
        Runnable[Any, Union[Hashable, list[Hashable]]],
    ],
    path_map: Optional[Union[dict[Hashable, str], list[str]]] = None,
    then: Optional[str] = None,
) -> Self:
    """Add a conditional edge from the starting node to any number of destination nodes.

    Args:
        source (str): The starting node. This conditional edge will run when
            exiting this node.
        path (Union[Callable, Runnable]): The callable that determines the next
            node or nodes. If not specifying `path_map` it should return one or
            more nodes. If it returns END, the graph will stop execution.
        path_map (Optional[Union[dict[Hashable, str], list[str]]]): Optional mapping
            of paths to node names, or a list of node names (each mapped to
            itself). If omitted the paths returned by `path` should be node names.
        then (Optional[str]): The name of a node to execute after the nodes
            selected by `path`.

    Returns:
        Self: The graph itself, for method chaining.

    Raises:
        ValueError: If a branch with the same name already exists for `source`.

    Note: Without typehints on the `path` function's return value (e.g., `-> Literal["foo", "__end__"]:`)
        or a path_map, the graph visualization assumes the edge could transition to any node in the graph.

    """  # noqa: E501
    if self.compiled:
        logger.warning(
            "Adding an edge to a graph that has already been compiled. This will "
            "not be reflected in the compiled graph."
        )
    # coerce path_map to a dictionary; an explicit path_map is always honored
    path_map_: Optional[dict[Hashable, str]] = None
    if isinstance(path_map, dict):
        path_map_ = path_map.copy()
    elif isinstance(path_map, list):
        path_map_ = {name: name for name in path_map}
    elif not isinstance(path, Runnable):
        # Best-effort: infer the mapping from a `Literal[...]` return annotation.
        # Any failure while inspecting hints simply leaves the mapping unset.
        try:
            rtn_type = get_type_hints(path.__call__).get(  # type: ignore[operator]
                "return"
            ) or get_type_hints(path).get("return")
        except Exception:
            rtn_type = None
        if rtn_type and get_origin(rtn_type) is Literal:
            path_map_ = {name: name for name in get_args(rtn_type)}
    # find a name for the condition
    path = coerce_to_runnable(path, name=None, trace=True)
    name = path.name or "condition"
    # validate the condition
    if name in self.branches[source]:
        raise ValueError(
            f"Branch with name `{name}` already exists for node `{source}`"
        )
    # save it
    self.branches[source][name] = Branch(path, path_map_, then)
    return self

set_entry_point(key: str) -> Self

Specifies the first node to be called in the graph.

Equivalent to calling add_edge(START, key).

Parameters:

  • key (str) –

    The key of the node to set as the entry point.

Returns:

  • Self

    The graph instance itself, so calls can be chained.

Source code in libs/langgraph/langgraph/graph/graph.py
def set_entry_point(self, key: str) -> Self:
    """Specifies the first node to be called in the graph.

    Equivalent to calling `add_edge(START, key)`.

    Args:
        key (str): The key of the node to set as the entry point.

    Returns:
        Self: The graph itself (as returned by `add_edge`), for method chaining.
    """
    return self.add_edge(START, key)

set_conditional_entry_point(path: Union[Callable[..., Union[Hashable, list[Hashable]]], Callable[..., Awaitable[Union[Hashable, list[Hashable]]]], Runnable[Any, Union[Hashable, list[Hashable]]]], path_map: Optional[Union[dict[Hashable, str], list[str]]] = None, then: Optional[str] = None) -> Self

Sets a conditional entry point in the graph.

Parameters:

  • path (Union[Callable, Runnable]) –

    The callable that determines the next node or nodes. If not specifying path_map it should return one or more nodes. If it returns END, the graph will stop execution.

  • path_map (Optional[Union[dict[Hashable, str], list[str]]], default: None ) –

    Optional mapping of paths to node names, or a list of node names (each mapped to itself). If omitted, the paths returned by path should be node names.

  • then (Optional[str], default: None ) –

    The name of a node to execute after the nodes selected by path.

Returns:

  • Self

    The graph instance itself, so calls can be chained.

Source code in libs/langgraph/langgraph/graph/graph.py
def set_conditional_entry_point(
    self,
    path: Union[
        Callable[..., Union[Hashable, list[Hashable]]],
        Callable[..., Awaitable[Union[Hashable, list[Hashable]]]],
        Runnable[Any, Union[Hashable, list[Hashable]]],
    ],
    path_map: Optional[Union[dict[Hashable, str], list[str]]] = None,
    then: Optional[str] = None,
) -> Self:
    """Sets a conditional entry point in the graph.

    Equivalent to calling `add_conditional_edges(START, path, path_map, then)`.

    Args:
        path (Union[Callable, Runnable]): The callable that determines the next
            node or nodes. If not specifying `path_map` it should return one or
            more nodes. If it returns END, the graph will stop execution.
        path_map (Optional[Union[dict[Hashable, str], list[str]]]): Optional mapping
            of paths to node names, or a list of node names. If omitted the paths
            returned by `path` should be node names.
        then (Optional[str]): The name of a node to execute after the nodes
            selected by `path`.

    Returns:
        Self: The graph itself (as returned by `add_conditional_edges`), for
            method chaining.
    """
    return self.add_conditional_edges(START, path, path_map, then)

set_finish_point(key: str) -> Self

Marks a node as a finish point of the graph.

If the graph reaches this node, it will cease execution.

Parameters:

  • key (str) –

    The key of the node to set as the finish point.

Returns:

  • Self

    The graph instance itself, so calls can be chained.

Source code in libs/langgraph/langgraph/graph/graph.py
def set_finish_point(self, key: str) -> Self:
    """Marks a node as a finish point of the graph.

    If the graph reaches this node, it will cease execution.
    Equivalent to calling `add_edge(key, END)`.

    Args:
        key (str): The key of the node to set as the finish point.

    Returns:
        Self: The graph itself (as returned by `add_edge`), for method chaining.
    """
    return self.add_edge(key, END)

add_node(node: Union[str, RunnableLike], action: Optional[RunnableLike] = None, *, metadata: Optional[dict[str, Any]] = None, input: Optional[Type[Any]] = None, retry: Optional[RetryPolicy] = None) -> Self

Adds a new node to the state graph.

Will take the name of the function/runnable as the node name.

Parameters:

  • node (Union[str, RunnableLike]) –

    The node name, or the function or runnable this node will run (in which case its name is used as the node name).

  • action (Optional[RunnableLike], default: None ) –

    The action associated with the node. (default: None)

  • metadata (Optional[dict[str, Any]], default: None ) –

    The metadata associated with the node. (default: None)

  • input (Optional[Type[Any]], default: None ) –

    The input schema for the node. (default: the graph's input schema)

  • retry (Optional[RetryPolicy], default: None ) –

    The policy for retrying the node. (default: None)

Raises:

  • ValueError

    If the key is already being used as a state key.

Examples:

>>> from langgraph.graph import START, StateGraph
...
>>> def my_node(state, config):
...    return {"x": state["x"] + 1}
...
>>> builder = StateGraph(dict)
>>> builder.add_node(my_node)  # node name will be 'my_node'
>>> builder.add_edge(START, "my_node")
>>> graph = builder.compile()
>>> graph.invoke({"x": 1})
{'x': 2}
Customize the name:

>>> builder = StateGraph(dict)
>>> builder.add_node("my_fair_node", my_node)
>>> builder.add_edge(START, "my_fair_node")
>>> graph = builder.compile()
>>> graph.invoke({"x": 1})
{'x': 2}

Returns:

  • Self

    StateGraph

Source code in libs/langgraph/langgraph/graph/state.py
def add_node(
    self,
    node: Union[str, RunnableLike],
    action: Optional[RunnableLike] = None,
    *,
    metadata: Optional[dict[str, Any]] = None,
    input: Optional[Type[Any]] = None,
    retry: Optional[RetryPolicy] = None,
) -> Self:
    """Adds a new node to the state graph.

    When ``node`` is not a string it is used as the action, and the node name
    is taken from it: ``get_name()`` for a ``Runnable``, otherwise the
    callable's ``__name__`` (falling back to its class name).

    Args:
        node (Union[str, RunnableLike]): The node name, or the function or
            runnable this node will run.
        action (Optional[RunnableLike]): The action associated with the node.
            Required when ``node`` is a name. (default: None)
        metadata (Optional[dict[str, Any]]): The metadata associated with the node. (default: None)
        input (Optional[Type[Any]]): The input schema for the node. If omitted,
            it may be inferred from an annotated class on the action's first
            parameter; otherwise the graph's schema is used. (default: None)
        retry (Optional[RetryPolicy]): The policy for retrying the node. (default: None)

    Raises:
        ValueError: If the name is already being used as a state key, is
            already a node, is reserved (``START``/``END``), or contains a
            reserved character.
        RuntimeError: If no action was provided.

    Examples:
        ```pycon
        >>> from langgraph.graph import START, StateGraph
        ...
        >>> def my_node(state, config):
        ...    return {"x": state["x"] + 1}
        ...
        >>> builder = StateGraph(dict)
        >>> builder.add_node(my_node)  # node name will be 'my_node'
        >>> builder.add_edge(START, "my_node")
        >>> graph = builder.compile()
        >>> graph.invoke({"x": 1})
        {'x': 2}
        ```
        Customize the name:

        ```pycon
        >>> builder = StateGraph(dict)
        >>> builder.add_node("my_fair_node", my_node)
        >>> builder.add_edge(START, "my_fair_node")
        >>> graph = builder.compile()
        >>> graph.invoke({"x": 1})
        {'x': 2}
        ```

    Returns:
        StateGraph: The graph itself, for method chaining.
    """
    if not isinstance(node, str):
        action = node
        if isinstance(action, Runnable):
            node = action.get_name()
        else:
            node = getattr(action, "__name__", action.__class__.__name__)
        if node is None:
            raise ValueError(
                "Node name must be provided if action is not a function"
            )
    if node in self.channels:
        raise ValueError(f"'{node}' is already being used as a state key")
    if self.compiled:
        logger.warning(
            "Adding a node to a graph that has already been compiled. This will "
            "not be reflected in the compiled graph."
        )
    if action is None:
        raise RuntimeError(
            "Expected a function or Runnable action in add_node. Received None."
        )
    if node in self.nodes:
        raise ValueError(f"Node `{node}` already present.")
    if node == END or node == START:
        raise ValueError(f"Node `{node}` is reserved.")

    for character in (NS_SEP, NS_END):
        if character in cast(str, node):
            raise ValueError(
                f"'{character}' is a reserved character and is not allowed in the node names."
            )

    # Best-effort introspection of the action's type hints: infer the input
    # schema from the first parameter's annotation (when `input` is not
    # given) and the possible destinations from a `Command[Literal[...]]`
    # return annotation. Failures here are deliberately non-fatal.
    ends = EMPTY_SEQ
    try:
        if (isfunction(action) or ismethod(getattr(action, "__call__", None))) and (
            hints := get_type_hints(getattr(action, "__call__"))
            or get_type_hints(action)
        ):
            if input is None:
                first_parameter_name = next(
                    iter(
                        inspect.signature(
                            cast(FunctionType, action)
                        ).parameters.keys()
                    )
                )
                if input_hint := hints.get(first_parameter_name):
                    # only adopt classes that themselves declare annotated fields
                    if isinstance(input_hint, type) and get_type_hints(input_hint):
                        input = input_hint
            if (
                (rtn := hints.get("return"))
                and get_origin(rtn) in (Command, GraphCommand)
                and (rargs := get_args(rtn))
                and get_origin(rargs[0]) is Literal
                and (vals := get_args(rargs[0]))
            ):
                ends = vals
    except (NameError, TypeError, StopIteration):
        # NameError: annotations with unresolvable forward references;
        # TypeError: objects get_type_hints cannot inspect;
        # StopIteration: an action that takes no parameters.
        pass
    if input is not None:
        self._add_schema(input)
    self.nodes[cast(str, node)] = StateNodeSpec(
        coerce_to_runnable(action, name=cast(str, node), trace=False),
        metadata,
        input=input or self.schema,
        retry_policy=retry,
        ends=ends,
    )
    return self

add_edge(start_key: Union[str, list[str]], end_key: str) -> Self

Adds a directed edge from the start node to the end node.

If the graph transitions to the start_key node, it will always transition to the end_key node next.

Parameters:

  • start_key (Union[str, list[str]]) –

    The key(s) of the start node(s) of the edge.

  • end_key (str) –

    The key of the end node of the edge.

Raises:

  • ValueError

    If the start key is 'END' or if the start key or end key is not present in the graph.

Returns:

  • Self

    StateGraph

Source code in libs/langgraph/langgraph/graph/state.py
def add_edge(self, start_key: Union[str, list[str]], end_key: str) -> Self:
    """Adds a directed edge from the start node(s) to the end node.

    If the graph transitions to the start_key node, it will always transition to the end_key node next.
    A list of start keys is recorded as a single waiting edge from all of them.

    Args:
        start_key (Union[str, list[str]]): The key(s) of the start node(s) of the edge.
        end_key (str): The key of the end node of the edge.

    Raises:
        ValueError: If ``start_key`` is an empty list, if a start key is 'END',
            if the end key is 'START', or if the start key or end key is not
            present in the graph.

    Returns:
        StateGraph: The graph itself, for method chaining.
    """
    if isinstance(start_key, str):
        # a single start node is handled by the base Graph implementation
        return super().add_edge(start_key, end_key)

    if self.compiled:
        logger.warning(
            "Adding an edge to a graph that has already been compiled. This will "
            "not be reflected in the compiled graph."
        )
    # an empty list would record an edge with no start nodes at all
    if not start_key:
        raise ValueError("start_key must contain at least one node")
    for start in start_key:
        if start == END:
            raise ValueError("END cannot be a start node")
        if start not in self.nodes:
            raise ValueError(f"Need to add_node `{start}` first")
    if end_key == START:
        raise ValueError("START cannot be an end node")
    if end_key != END and end_key not in self.nodes:
        raise ValueError(f"Need to add_node `{end_key}` first")

    self.waiting_edges.add((tuple(start_key), end_key))
    return self

add_sequence(nodes: Sequence[Union[RunnableLike, tuple[str, RunnableLike]]]) -> Self

Add a sequence of nodes that will be executed in the provided order.

Parameters:

  • nodes (Sequence[Union[RunnableLike, tuple[str, RunnableLike]]]) –

    A sequence of RunnableLike objects (e.g. a LangChain Runnable or a callable) or (name, RunnableLike) tuples. If no names are provided, the name will be inferred from the node object (e.g. a runnable or a callable name). Each node will be executed in the order provided.

Raises:

  • ValueError

    if the sequence is empty.

  • ValueError

    if the sequence contains duplicate node names.

Returns:

  • Self

    StateGraph

Source code in libs/langgraph/langgraph/graph/state.py
def add_sequence(
    self,
    nodes: Sequence[Union[RunnableLike, tuple[str, RunnableLike]]],
) -> Self:
    """Add nodes that run one after another, in the given order.

    Each entry is either a RunnableLike (its name is inferred from the object)
    or an explicit ``(name, RunnableLike)`` pair. Consecutive entries are
    connected with an edge.

    Args:
        nodes: A sequence of RunnableLike objects (e.g. a LangChain Runnable or a callable) or (name, RunnableLike) tuples.
            If no names are provided, the name will be inferred from the node object (e.g. a runnable or a callable name).
            Each node will be executed in the order provided.

    Raises:
        ValueError: if the sequence is empty.
        ValueError: if the sequence contains duplicate node names.

    Returns:
        StateGraph: The graph itself, for method chaining.
    """
    if len(nodes) == 0:
        raise ValueError("Sequence requires at least one node.")

    prior: Optional[str] = None
    for item in nodes:
        # explicit (name, runnable) pairs take precedence over inferred names
        if isinstance(item, tuple) and len(item) == 2:
            node_name, runnable = item
        else:
            node_name, runnable = _get_node_name(item), item

        if node_name in self.nodes:
            raise ValueError(
                f"Node names must be unique: node with the name '{node_name}' already exists. "
                "If you need to use two different runnables/callables with the same name (for example, using `lambda`), please provide them as tuples (name, runnable/callable)."
            )

        self.add_node(node_name, runnable)
        if prior is not None:
            self.add_edge(prior, node_name)
        prior = node_name

    return self

compile(checkpointer: Checkpointer = None, *, store: Optional[BaseStore] = None, interrupt_before: Optional[Union[All, list[str]]] = None, interrupt_after: Optional[Union[All, list[str]]] = None, debug: bool = False) -> CompiledStateGraph

Compiles the state graph into a CompiledGraph object.

The compiled graph implements the Runnable interface and can be invoked, streamed, batched, and run asynchronously.

Parameters:

  • checkpointer (Optional[Union[Checkpointer, Literal[False]]], default: None ) –

    A checkpoint saver object or flag. If provided, this Checkpointer serves as a fully versioned "short-term memory" for the graph, allowing it to be paused, resumed, and replayed from any point. If None, it may inherit the parent graph's checkpointer when used as a subgraph. If False, it will not use or inherit any checkpointer.

  • interrupt_before (Optional[Union[All, list[str]]], default: None ) –

    An optional list of node names to interrupt before, or the "*" wildcard.

  • interrupt_after (Optional[Union[All, list[str]]], default: None ) –

    An optional list of node names to interrupt after, or the "*" wildcard.

  • debug (bool, default: False ) –

    A flag indicating whether to enable debug mode.

Returns:

  • CompiledStateGraph ( CompiledStateGraph ) –

    The compiled state graph.

Source code in libs/langgraph/langgraph/graph/state.py
def compile(
    self,
    checkpointer: Checkpointer = None,
    *,
    store: Optional[BaseStore] = None,
    interrupt_before: Optional[Union[All, list[str]]] = None,
    interrupt_after: Optional[Union[All, list[str]]] = None,
    debug: bool = False,
) -> "CompiledStateGraph":
    """Compiles the state graph into a `CompiledGraph` object.

    The compiled graph implements the `Runnable` interface and can be invoked,
    streamed, batched, and run asynchronously.

    Args:
        checkpointer (Optional[Union[Checkpointer, Literal[False]]]): A checkpoint saver object or flag.
            If provided, this Checkpointer serves as a fully versioned "short-term memory" for the graph,
            allowing it to be paused, resumed, and replayed from any point.
            If None, it may inherit the parent graph's checkpointer when used as a subgraph.
            If False, it will not use or inherit any checkpointer.
        store (Optional[BaseStore]): An optional store, passed through to the
            compiled graph. (default: None)
        interrupt_before (Optional[Union[All, list[str]]]): An optional list of node
            names to interrupt before, or the ``"*"`` wildcard.
        interrupt_after (Optional[Union[All, list[str]]]): An optional list of node
            names to interrupt after, or the ``"*"`` wildcard.
        debug (bool): A flag indicating whether to enable debug mode.

    Returns:
        CompiledStateGraph: The compiled state graph.
    """
    # assign default values
    interrupt_before = interrupt_before or []
    interrupt_after = interrupt_after or []

    # validate the graph; only explicit node lists are passed on (the "*"
    # wildcard is skipped). Each argument is filtered independently so that
    # interrupt_after="*" does not suppress validation of an explicit
    # interrupt_before list.
    self.validate(
        interrupt=(
            (interrupt_before if interrupt_before != "*" else [])
            + (interrupt_after if interrupt_after != "*" else [])
        )
    )

    # prepare output channels
    output_channels = (
        "__root__"
        if len(self.schemas[self.output]) == 1
        and "__root__" in self.schemas[self.output]
        else [
            key
            for key, val in self.schemas[self.output].items()
            if not is_managed_value(val)
        ]
    )
    stream_channels = (
        "__root__"
        if len(self.channels) == 1 and "__root__" in self.channels
        else [
            key for key, val in self.channels.items() if not is_managed_value(val)
        ]
    )

    compiled = CompiledStateGraph(
        builder=self,
        config_type=self.config_schema,
        nodes={},
        channels={
            **self.channels,
            **self.managed,
            START: EphemeralValue(self.input),
        },
        input_channels=START,
        stream_mode="updates",
        output_channels=output_channels,
        stream_channels=stream_channels,
        checkpointer=checkpointer,
        interrupt_before_nodes=interrupt_before,
        interrupt_after_nodes=interrupt_after,
        auto_validate=False,
        debug=debug,
        store=store,
    )

    # START has no node spec of its own
    compiled.attach_node(START, None)
    for key, node in self.nodes.items():
        compiled.attach_node(key, node)

    # every node also gets a SELF/CONTROL_BRANCH branch — presumably for
    # routing driven by returned Command objects; TODO confirm
    for key in self.nodes:
        compiled.attach_branch(key, SELF, CONTROL_BRANCH, with_reader=False)

    for start, end in self.edges:
        compiled.attach_edge(start, end)

    # edges with several start nodes (added via add_edge with a list)
    for starts, end in self.waiting_edges:
        compiled.attach_edge(starts, end)

    for start, branches in self.branches.items():
        for name, branch in branches.items():
            compiled.attach_branch(start, name, branch)

    return compiled.validate()

CompiledStateGraph

Bases: CompiledGraph

Source code in libs/langgraph/langgraph/graph/state.py
class CompiledStateGraph(CompiledGraph):
    """Compiled form of a ``StateGraph``, built on ``CompiledGraph``.

    The ``attach_*`` methods below translate the builder's nodes, edges and
    branches into Pregel nodes and channels on this instance; ``builder``
    holds the originating ``StateGraph`` whose schemas and channels they read.
    """

    builder: StateGraph

    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> type[BaseModel]:
        """Return a model for the graph input, derived from ``builder.input``.

        ``config`` is accepted for interface compatibility and is not used here.
        """
        return _get_schema(
            typ=self.builder.input,
            schemas=self.builder.schemas,
            channels=self.builder.channels,
            name=self.get_name("Input"),
        )

    def get_output_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> type[BaseModel]:
        """Return a model for the graph output, derived from ``builder.output``.

        ``config`` is accepted for interface compatibility and is not used here.
        """
        return _get_schema(
            typ=self.builder.output,
            schemas=self.builder.schemas,
            channels=self.builder.channels,
            name=self.get_name("Output"),
        )

    def attach_node(self, key: str, node: Optional[StateNodeSpec]) -> None:
        """Register the Pregel node (and its output channel) for builder node ``key``.

        For ``START`` the node forwards the graph input into the state channels.
        For any other key a node is built around ``node.runnable``; it reads the
        state described by ``node.input`` and writes its return value back to
        the state channels.

        Raises:
            RuntimeError: If ``node`` is None for a key other than ``START``.
            InvalidUpdateError: At run time (from the write mappers), if the
                node returns a dict touching none of the writable keys, or a
                value that is not None, a dict, a ``Command``, or an object
                with type hints.
        """
        # Keys this node may write. START may only write the non-managed keys of
        # the input schema; regular nodes may write every state channel plus any
        # writable managed value.
        if key == START:
            output_keys = [
                k
                for k, v in self.builder.schemas[self.builder.input].items()
                if not is_managed_value(v)
            ]
        else:
            output_keys = list(self.builder.channels) + [
                k
                for k, v in self.builder.managed.items()
                if is_writable_managed_value(v)
            ]

        def _get_root(input: Any) -> Any:
            # Root-valued state: a Command contributes its `update`; any other
            # return value is written as-is.
            if isinstance(input, Command):
                return input.update
            else:
                return input

        # `_get_state_key` below takes a `key` keyword argument, so keep the node
        # name under a different name for its error message.
        node_key = key

        def _get_state_key(input: Union[None, dict, Any], *, key: str) -> Any:
            """Extract the value for state key ``key`` from a node's return value.

            None -> SKIP_WRITE; dict -> ``input.get(key, SKIP_WRITE)`` (after
            checking at least one writable key is present); Command -> recurse
            into ``input.update``; object with type hints -> attribute ``key``
            (None or missing becomes SKIP_WRITE). Anything else raises
            InvalidUpdateError.
            """
            if input is None:
                return SKIP_WRITE
            elif isinstance(input, dict):
                if all(k not in output_keys for k in input):
                    raise InvalidUpdateError(
                        f"Expected node {node_key} to update at least one of {output_keys}, got {input}"
                    )
                return input.get(key, SKIP_WRITE)
            elif isinstance(input, Command):
                return _get_state_key(input.update, key=key)
            elif get_type_hints(type(input)):
                value = getattr(input, key, SKIP_WRITE)
                return value if value is not None else SKIP_WRITE
            else:
                msg = create_error_message(
                    message=f"Expected dict, got {input}",
                    error_code=ErrorCode.INVALID_GRAPH_NODE_RETURN_VALUE,
                )
                raise InvalidUpdateError(msg)

        # state updaters: a single `__root__` entry (skipping None) when the state
        # is root-valued, otherwise one entry per writable key.
        write_entries = (
            [ChannelWriteEntry("__root__", skip_none=True, mapper=_get_root)]
            if output_keys == ["__root__"]
            else [
                ChannelWriteEntry(key, mapper=partial(_get_state_key, key=key))
                for key in output_keys
            ]
        )

        # add node and output channel
        if key == START:
            # Hidden node triggered by, and reading, the START channel; it copies
            # the graph input into the state channels.
            self.nodes[key] = PregelNode(
                tags=[TAG_HIDDEN],
                triggers=[START],
                channels=[START],
                writers=[
                    ChannelWrite(
                        write_entries,
                        tags=[TAG_HIDDEN],
                        require_at_least_one_of=output_keys,
                    ),
                ],
            )
        elif node is not None:
            # NOTE(review): the `else self.builder.schema` fallback looks
            # unreachable because this branch already requires `node is not None`
            # (unless a StateNodeSpec can be falsy) -- confirm.
            input_schema = node.input if node else self.builder.schema
            # Read every key of the node's input schema from the same-named channel.
            input_values = {k: k for k in self.builder.schemas[input_schema]}
            is_single_input = len(input_values) == 1 and "__root__" in input_values

            # Per-node channel; plain edges out of this node subscribe to it
            # (see attach_edge).
            self.channels[key] = EphemeralValue(Any, guard=False)
            self.nodes[key] = PregelNode(
                # filled in later by attach_edge / attach_branch
                triggers=[],
                # read state keys and managed values
                channels=(list(input_values) if is_single_input else input_values),
                # coerce state dict to schema class (eg. pydantic model)
                mapper=(
                    None
                    if is_single_input or issubclass(input_schema, dict)
                    else partial(_coerce_state, input_schema)
                ),
                writers=[
                    # publish to this channel and state keys
                    ChannelWrite(
                        [ChannelWriteEntry(key, key)] + write_entries,
                        tags=[TAG_HIDDEN],
                    ),
                ],
                metadata=node.metadata,
                retry_policy=node.retry_policy,
                bound=node.runnable,
            )
        else:
            # A non-START key must come with a node spec.
            raise RuntimeError

    def attach_edge(self, starts: Union[str, Sequence[str]], end: str) -> None:
        """Wire an edge from ``starts`` to ``end``.

        - ``START -> end``: a dedicated ``start:<end>`` channel is written by the
          START node and added to ``end``'s triggers.
        - ``node -> end``: ``node``'s own channel is added to ``end``'s triggers.
        - ``[a, b, ...] -> end``: a ``NamedBarrierValue`` join channel is added to
          ``end``'s triggers, and each start node writes its own name to it.

        For a single non-START start, or for multiple starts, an ``end`` of
        ``END`` attaches nothing.
        """
        if isinstance(starts, str):
            if starts == START:
                # NOTE(review): START -> END is not special-cased here and would
                # index self.nodes[END]; presumably never produced -- confirm.
                channel_name = f"start:{end}"
                # register channel
                self.channels[channel_name] = EphemeralValue(Any)
                # subscribe to channel
                self.nodes[end].triggers.append(channel_name)
                # publish to channel
                self.nodes[START] |= ChannelWrite(
                    [ChannelWriteEntry(channel_name, START)], tags=[TAG_HIDDEN]
                )
            elif end != END:
                # subscribe to start channel
                self.nodes[end].triggers.append(starts)
        elif end != END:
            channel_name = f"join:{'+'.join(starts)}:{end}"
            # register channel; the barrier is keyed on the set of start names
            self.channels[channel_name] = NamedBarrierValue(str, set(starts))
            # subscribe to channel
            self.nodes[end].triggers.append(channel_name)
            # publish to channel
            for start in starts:
                self.nodes[start] |= ChannelWrite(
                    [ChannelWriteEntry(channel_name, start)], tags=[TAG_HIDDEN]
                )

    def attach_branch(
        self, start: str, name: str, branch: Branch, *, with_reader: bool = True
    ) -> None:
        """Attach conditional branch ``name`` leaving node ``start``.

        Args:
            start: Node the branch leaves from.
            name: Branch name, used to build the branch channel names.
            branch: The branch spec; its ``run``, ``ends`` and ``then``
                attributes are used.
            with_reader: When True, ``branch.run`` receives a state reader for
                ``start``'s input schema; when False it receives None.
        """

        def branch_writer(
            packets: Sequence[Union[str, Send]], config: RunnableConfig
        ) -> None:
            # Turn the branch result into writes: END is dropped, node names
            # become writes to `branch:<start>:<name>:<node>`, Send packets pass
            # through unchanged.
            if filtered := [p for p in packets if p != END]:
                writes = [
                    (
                        ChannelWriteEntry(f"branch:{start}:{name}:{p}", start)
                        if not isinstance(p, Send)
                        else p
                    )
                    for p in filtered
                ]
                if branch.then and branch.then != END:
                    # Record the set of destinations `then` has to wait for.
                    writes.append(
                        ChannelWriteEntry(
                            f"branch:{start}:{name}::then",
                            WaitForNames(
                                {p.node if isinstance(p, Send) else p for p in filtered}
                            ),
                        )
                    )
                ChannelWrite.do_write(
                    config, cast(Sequence[Union[Send, ChannelWriteEntry]], writes)
                )

        # attach branch publisher; use the node's own input schema when `start`
        # is a builder node, otherwise (e.g. START) the graph's state schema
        schema = (
            self.builder.nodes[start].input
            if start in self.builder.nodes
            else self.builder.schema
        )
        self.nodes[start] |= branch.run(
            branch_writer,
            _get_state_reader(self.builder, schema) if with_reader else None,
        )

        # attach branch subscribers: the values of `branch.ends`, or, when that is
        # empty/None, every builder node except `then`
        ends = (
            branch.ends.values()
            if branch.ends
            else [node for node in self.builder.nodes if node != branch.then]
        )
        for end in ends:
            if end != END:
                channel_name = f"branch:{start}:{name}:{end}"
                self.channels[channel_name] = EphemeralValue(Any, guard=False)
                self.nodes[end].triggers.append(channel_name)

        # attach then subscriber: `then` is triggered by a DynamicBarrierValue
        # that every (non-END) destination writes its own name to
        if branch.then and branch.then != END:
            channel_name = f"branch:{start}:{name}::then"
            self.channels[channel_name] = DynamicBarrierValue(str)
            self.nodes[branch.then].triggers.append(channel_name)
            for end in ends:
                if end != END:
                    self.nodes[end] |= ChannelWrite(
                        [ChannelWriteEntry(channel_name, end)], tags=[TAG_HIDDEN]
                    )

stream_mode: StreamMode = stream_mode class-attribute instance-attribute

Mode to stream output, defaults to 'values'.

stream_channels: Optional[Union[str, Sequence[str]]] = stream_channels class-attribute instance-attribute

Channels to stream; defaults to all channels that are not reserved.

step_timeout: Optional[float] = step_timeout class-attribute instance-attribute

Maximum time to wait for a step to complete, in seconds. Defaults to None.

debug: bool = debug if debug is not None else get_debug() instance-attribute

Whether to print debug information during execution. Defaults to False.

checkpointer: Checkpointer = checkpointer class-attribute instance-attribute

Checkpointer used to save and load graph state. Defaults to None.

store: Optional[BaseStore] = store class-attribute instance-attribute

Memory store to use for SharedValues. Defaults to None.

retry_policy: Optional[RetryPolicy] = retry_policy class-attribute instance-attribute

Retry policy to use when running tasks. Set to None to disable.

get_graph(config: Optional[RunnableConfig] = None, *, xray: Union[int, bool] = False) -> DrawableGraph

Returns a drawable representation of the computation graph.

Source code in libs/langgraph/langgraph/graph/graph.py
def get_graph(
    self,
    config: Optional[RunnableConfig] = None,
    *,
    xray: Union[int, bool] = False,
) -> DrawableGraph:
    """Returns a drawable representation of the computation graph.

    Args:
        config: Optional runnable config, forwarded to the input/output schema
            lookups and to nested ``get_graph`` calls on subgraphs.
        xray: If truthy, compiled subgraphs are expanded inline. A positive
            ``int`` is decremented at each nesting level; ``True`` is passed
            through unchanged to every level.

    Returns:
        A ``DrawableGraph`` with the START node, one node per builder node (or
        the expanded subgraph), the END node if any edge targets it, and the
        direct, conditional and node-declared (``ends``) edges.

    Raises:
        ValueError: If an expanded subgraph has no entry node to connect to.
    """
    graph = DrawableGraph()
    # start_nodes[k]: drawable node used when k is the *source* of an edge.
    # end_nodes[k]: drawable node used when k is the *target* of an edge.
    # They differ only for expanded subgraphs.
    start_nodes: dict[str, DrawableNode] = {
        START: graph.add_node(self.get_input_schema(config), START)
    }
    end_nodes: dict[str, DrawableNode] = {}
    if xray:
        subgraphs = {
            k: v for k, v in self.get_subgraphs() if isinstance(v, CompiledGraph)
        }
    else:
        subgraphs = {}

    def add_edge(
        start: str,
        end: str,
        label: Optional[Hashable] = None,
        conditional: bool = False,
    ) -> None:
        # The END node is only materialized once some edge actually targets it.
        if end == END and END not in end_nodes:
            end_nodes[END] = graph.add_node(self.get_output_schema(config), END)
        graph.add_edge(
            start_nodes[start],
            end_nodes[end],
            str(label) if label is not None else None,
            conditional,
        )

    for key, n in self.builder.nodes.items():
        node = n.runnable
        # Copy the metadata: `n.metadata` belongs to the builder's node spec, and
        # the interrupt markers below must not leak into it (or into later
        # get_graph calls on differently-configured compiled graphs).
        metadata = dict(n.metadata) if n.metadata else {}
        if key in self.interrupt_before_nodes and key in self.interrupt_after_nodes:
            metadata["__interrupt"] = "before,after"
        elif key in self.interrupt_before_nodes:
            metadata["__interrupt"] = "before"
        elif key in self.interrupt_after_nodes:
            metadata["__interrupt"] = "after"
        if xray and key in subgraphs:
            subgraph = subgraphs[key].get_graph(
                config=config,
                xray=xray - 1
                if isinstance(xray, int) and not isinstance(xray, bool) and xray > 0
                else xray,
            )
            subgraph.trim_first_node()
            subgraph.trim_last_node()
            if len(subgraph.nodes) > 1:
                e, s = graph.extend(subgraph, prefix=key)
                if e is None:
                    raise ValueError(
                        f"Could not extend subgraph '{key}' due to missing entrypoint"
                    )
                if s is not None:
                    start_nodes[key] = s
                end_nodes[key] = e
            else:
                nn = graph.add_node(node, key, metadata=metadata or None)
                start_nodes[key] = nn
                end_nodes[key] = nn
        else:
            nn = graph.add_node(node, key, metadata=metadata or None)
            start_nodes[key] = nn
            end_nodes[key] = nn
    for start, end in sorted(self.builder._all_edges):
        add_edge(start, end)
    for start, branches in self.builder.branches.items():
        default_ends = {
            **{k: k for k in self.builder.nodes if k != start},
            END: END,
        }
        for branch in branches.values():
            if branch.ends is not None:
                ends = branch.ends
            elif branch.then is not None:
                ends = {k: k for k in default_ends if k not in (END, branch.then)}
            else:
                ends = cast(dict[Hashable, str], default_ends)
            for label, end in ends.items():
                add_edge(
                    start,
                    end,
                    label if label != end else None,
                    conditional=True,
                )
                # END has no outgoing edges (and no entry in start_nodes), so it
                # never feeds `then`; this mirrors the compiled wiring, which also
                # skips END when connecting destinations to the `then` barrier.
                if branch.then is not None and end != END:
                    add_edge(end, branch.then)
    for key, n in self.builder.nodes.items():
        if n.ends:
            for end in n.ends:
                add_edge(key, end, conditional=True)

    return graph

get_state(config: RunnableConfig, *, subgraphs: bool = False) -> StateSnapshot

Get the current state of the graph.

Source code in libs/langgraph/langgraph/pregel/__init__.py
def get_state(
    self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot:
    """Get the current state of the graph.

    Args:
        config: Config identifying the checkpoint to read. A checkpointer stored
            under ``CONFIG_KEY_CHECKPOINTER`` takes precedence over
            ``self.checkpointer``. A non-empty checkpoint namespace (without an
            injected checkpointer) delegates the call to the matching subgraph.
        subgraphs: If True, the checkpointer is forwarded as ``recurse`` to
            ``_prepare_state_snapshot``.

    Returns:
        The snapshot built from the checkpointer's saved tuple.

    Raises:
        ValueError: If no checkpointer is available or the namespaced subgraph
            cannot be found.
    """
    configurable = config[CONF]
    checkpointer: Optional[BaseCheckpointSaver] = configurable.get(
        CONFIG_KEY_CHECKPOINTER, self.checkpointer
    )
    if not checkpointer:
        raise ValueError("No checkpointer set")

    checkpoint_ns = configurable.get(CONFIG_KEY_CHECKPOINT_NS, "")
    if checkpoint_ns and CONFIG_KEY_CHECKPOINTER not in configurable:
        # Drop the task-id suffix (after NS_END) from every namespace segment so
        # the path lines up with subgraph names.
        target_ns = NS_SEP.join(
            segment.split(NS_END)[0] for segment in checkpoint_ns.split(NS_SEP)
        )
        first_match = next(
            iter(self.get_subgraphs(namespace=target_ns, recurse=True)), None
        )
        if first_match is None:
            raise ValueError(f"Subgraph {target_ns} not found")
        _, subgraph = first_match
        return subgraph.get_state(
            patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
            subgraphs=subgraphs,
        )

    if self.config:
        config = merge_configs(self.config, config)
    saved = checkpointer.get_tuple(config)
    # Pending writes are only replayed when reading the latest checkpoint, i.e.
    # when no explicit checkpoint id was requested.
    replay_pending = CONFIG_KEY_CHECKPOINT_ID not in config[CONF]
    return self._prepare_state_snapshot(
        config,
        saved,
        recurse=checkpointer if subgraphs else None,
        apply_pending_writes=replay_pending,
    )

aget_state(config: RunnableConfig, *, subgraphs: bool = False) -> StateSnapshot async

Get the current state of the graph.

Source code in libs/langgraph/langgraph/pregel/__init__.py
async def aget_state(
    self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot:
    """Get the current state of the graph.

    Async counterpart of ``get_state``.

    Args:
        config: Config identifying the checkpoint to read. A checkpointer stored
            under ``CONFIG_KEY_CHECKPOINTER`` takes precedence over
            ``self.checkpointer``. A non-empty checkpoint namespace (without an
            injected checkpointer) delegates the call to the matching subgraph.
        subgraphs: If True, the checkpointer is forwarded as ``recurse`` to
            ``_aprepare_state_snapshot``.

    Returns:
        The snapshot built from the checkpointer's saved tuple.

    Raises:
        ValueError: If no checkpointer is available or the namespaced subgraph
            cannot be found.
    """
    configurable = config[CONF]
    checkpointer: Optional[BaseCheckpointSaver] = configurable.get(
        CONFIG_KEY_CHECKPOINTER, self.checkpointer
    )
    if not checkpointer:
        raise ValueError("No checkpointer set")

    checkpoint_ns = configurable.get(CONFIG_KEY_CHECKPOINT_NS, "")
    if checkpoint_ns and CONFIG_KEY_CHECKPOINTER not in configurable:
        # Drop the task-id suffix (after NS_END) from every namespace segment so
        # the path lines up with subgraph names.
        target_ns = NS_SEP.join(
            segment.split(NS_END)[0] for segment in checkpoint_ns.split(NS_SEP)
        )
        # Only the first matching subgraph is consulted.
        async for _, subgraph in self.aget_subgraphs(
            namespace=target_ns, recurse=True
        ):
            return await subgraph.aget_state(
                patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
                subgraphs=subgraphs,
            )
        raise ValueError(f"Subgraph {target_ns} not found")

    if self.config:
        config = merge_configs(self.config, config)
    saved = await checkpointer.aget_tuple(config)
    # Pending writes are only replayed when no explicit checkpoint id was requested.
    replay_pending = CONFIG_KEY_CHECKPOINT_ID not in config[CONF]
    return await self._aprepare_state_snapshot(
        config,
        saved,
        recurse=checkpointer if subgraphs else None,
        apply_pending_writes=replay_pending,
    )

get_state_history(config: RunnableConfig, *, filter: Optional[Dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None) -> Iterator[StateSnapshot]

Get the history of the state of the graph.

Source code in libs/langgraph/langgraph/pregel/__init__.py
def get_state_history(
    self,
    config: RunnableConfig,
    *,
    filter: Optional[Dict[str, Any]] = None,
    before: Optional[RunnableConfig] = None,
    limit: Optional[int] = None,
) -> Iterator[StateSnapshot]:
    """Get the history of the state of the graph.

    Args:
        config: Config identifying the thread. A non-empty checkpoint namespace
            (without an injected checkpointer) delegates to the matching subgraph.
        filter: Forwarded to ``checkpointer.list``.
        before: Forwarded to ``checkpointer.list``.
        limit: Forwarded to ``checkpointer.list``.

    Yields:
        One ``StateSnapshot`` per checkpoint tuple, in the order the
        checkpointer lists them.

    Raises:
        ValueError: If no checkpointer is available or the namespaced subgraph
            cannot be found (raised on first iteration, as this is a generator).
    """
    configurable = config[CONF]
    checkpointer: Optional[BaseCheckpointSaver] = configurable.get(
        CONFIG_KEY_CHECKPOINTER, self.checkpointer
    )
    if not checkpointer:
        raise ValueError("No checkpointer set")

    checkpoint_ns = configurable.get(CONFIG_KEY_CHECKPOINT_NS, "")
    if checkpoint_ns and CONFIG_KEY_CHECKPOINTER not in configurable:
        # Drop the task-id suffix (after NS_END) from every namespace segment.
        target_ns = NS_SEP.join(
            segment.split(NS_END)[0] for segment in checkpoint_ns.split(NS_SEP)
        )
        # Only the first matching subgraph is consulted.
        for _, subgraph in self.get_subgraphs(namespace=target_ns, recurse=True):
            yield from subgraph.get_state_history(
                patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
                filter=filter,
                before=before,
                limit=limit,
            )
            return
        raise ValueError(f"Subgraph {target_ns} not found")

    config = merge_configs(
        self.config,
        config,
        {CONF: {CONFIG_KEY_CHECKPOINT_NS: checkpoint_ns}},
    )
    # Drain the listing up front so a database cursor is not held open while
    # the caller consumes snapshots.
    checkpoint_tuples = list(
        checkpointer.list(config, before=before, limit=limit, filter=filter)
    )
    yield from (
        self._prepare_state_snapshot(saved.config, saved)
        for saved in checkpoint_tuples
    )

aget_state_history(config: RunnableConfig, *, filter: Optional[Dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None) -> AsyncIterator[StateSnapshot] async

Get the history of the state of the graph.

Source code in libs/langgraph/langgraph/pregel/__init__.py
async def aget_state_history(
    self,
    config: RunnableConfig,
    *,
    filter: Optional[Dict[str, Any]] = None,
    before: Optional[RunnableConfig] = None,
    limit: Optional[int] = None,
) -> AsyncIterator[StateSnapshot]:
    """Get the history of the state of the graph.

    Async counterpart of ``get_state_history``.

    Args:
        config: Config identifying the thread. A non-empty checkpoint namespace
            (without an injected checkpointer) delegates to the matching subgraph.
        filter: Forwarded to ``checkpointer.alist``.
        before: Forwarded to ``checkpointer.alist``.
        limit: Forwarded to ``checkpointer.alist``.

    Yields:
        One ``StateSnapshot`` per checkpoint tuple, in the order the
        checkpointer lists them.

    Raises:
        ValueError: If no checkpointer is available or the namespaced subgraph
            cannot be found (raised on first iteration).
    """
    configurable = config[CONF]
    checkpointer: Optional[BaseCheckpointSaver] = configurable.get(
        CONFIG_KEY_CHECKPOINTER, self.checkpointer
    )
    if not checkpointer:
        raise ValueError("No checkpointer set")

    checkpoint_ns = configurable.get(CONFIG_KEY_CHECKPOINT_NS, "")
    if checkpoint_ns and CONFIG_KEY_CHECKPOINTER not in configurable:
        # Drop the task-id suffix (after NS_END) from every namespace segment.
        target_ns = NS_SEP.join(
            segment.split(NS_END)[0] for segment in checkpoint_ns.split(NS_SEP)
        )
        # Only the first matching subgraph is consulted.
        async for _, subgraph in self.aget_subgraphs(
            namespace=target_ns, recurse=True
        ):
            async for snapshot in subgraph.aget_state_history(
                patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
                filter=filter,
                before=before,
                limit=limit,
            ):
                yield snapshot
            return
        raise ValueError(f"Subgraph {target_ns} not found")

    config = merge_configs(
        self.config,
        config,
        {CONF: {CONFIG_KEY_CHECKPOINT_NS: checkpoint_ns}},
    )
    # Drain the listing up front so a database cursor is not held open while
    # the caller consumes snapshots.
    checkpoint_tuples = [
        saved
        async for saved in checkpointer.alist(
            config, before=before, limit=limit, filter=filter
        )
    ]
    for saved in checkpoint_tuples:
        yield await self._aprepare_state_snapshot(saved.config, saved)

update_state(config: RunnableConfig, values: Optional[Union[dict[str, Any], Any]], as_node: Optional[str] = None) -> RunnableConfig

Update the state of the graph with the given values, as if they came from node as_node. If as_node is not provided, it will be set to the last node that updated the state, if not ambiguous.

Source code in libs/langgraph/langgraph/pregel/__init__.py
def update_state(
    self,
    config: RunnableConfig,
    values: Optional[Union[dict[str, Any], Any]],
    as_node: Optional[str] = None,
) -> RunnableConfig:
    """Update the state of the graph with the given values, as if they came from
    node `as_node`. If `as_node` is not provided, it will be set to the last node
    that updated the state, if not ambiguous.

    Special cases when `values` is None:
        - `as_node == END`: if a checkpoint exists, writes from tasks that
          already ran are applied and all current tasks are cleared; a new
          checkpoint is then saved.
        - `as_node is None`: the current checkpoint is copied into a new
          checkpoint with metadata `source="update"`.
        - `as_node == "__copy__"`: the checkpoint is re-saved under the saved
          checkpoint's parent config (or its own config when there is no
          parent) with metadata `source="fork"`.

    Args:
        config: Config identifying the checkpoint to update. A checkpointer
            stored under `CONFIG_KEY_CHECKPOINTER` takes precedence over
            `self.checkpointer`. A non-empty checkpoint namespace (without an
            injected checkpointer) delegates the call to the matching subgraph.
        values: The update, run through the writers of node `as_node`.
        as_node: Node the update is attributed to.

    Returns:
        The config returned by `checkpointer.put` for the new checkpoint,
        passed through `patch_checkpoint_map`.

    Raises:
        ValueError: If no checkpointer is set or the namespaced subgraph is
            not found.
        InvalidUpdateError: If `as_node` cannot be inferred unambiguously, does
            not exist, or has no writers.
    """
    checkpointer: Optional[BaseCheckpointSaver] = config[CONF].get(
        CONFIG_KEY_CHECKPOINTER, self.checkpointer
    )
    if not checkpointer:
        raise ValueError("No checkpointer set")

    # delegate to subgraph (a delegated call carries CONFIG_KEY_CHECKPOINTER,
    # which stops it from delegating again)
    if (
        checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
    ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
        # remove task_ids from checkpoint_ns
        recast_checkpoint_ns = NS_SEP.join(
            part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP)
        )
        # find the subgraph with the matching name (first match only)
        for _, pregel in self.get_subgraphs(
            namespace=recast_checkpoint_ns, recurse=True
        ):
            return pregel.update_state(
                patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
                values,
                as_node,
            )
        else:
            raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")

    # get last checkpoint
    config = ensure_config(self.config, config)
    saved = checkpointer.get_tuple(config)
    checkpoint = copy_checkpoint(saved.checkpoint) if saved else empty_checkpoint()
    # versions before this update; used below to report only changed channels
    checkpoint_previous_versions = (
        saved.checkpoint["channel_versions"].copy() if saved else {}
    )
    step = saved.metadata.get("step", -1) if saved else -1
    # merge configurable fields with previous checkpoint config
    checkpoint_config = patch_configurable(
        config,
        {CONFIG_KEY_CHECKPOINT_NS: config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")},
    )
    checkpoint_metadata = config["metadata"]
    if saved:
        checkpoint_config = patch_configurable(config, saved.config[CONF])
        checkpoint_metadata = {**saved.metadata, **checkpoint_metadata}
    with ChannelsManager(
        self.channels,
        checkpoint,
        LoopProtocol(config=config, step=step + 1, stop=step + 2),
    ) as (channels, managed):
        # no values as END, just clear all tasks
        if values is None and as_node == END:
            if saved is not None:
                # tasks for this checkpoint
                next_tasks = prepare_next_tasks(
                    checkpoint,
                    saved.pending_writes or [],
                    self.nodes,
                    channels,
                    managed,
                    saved.config,
                    saved.metadata.get("step", -1) + 1,
                    for_execution=True,
                    store=self.store,
                    checkpointer=self.checkpointer or None,
                    manager=None,
                )
                # apply null writes
                # NOTE(review): applied to `saved.checkpoint`, not the copied
                # `checkpoint` that is saved below -- confirm this is intended.
                if null_writes := [
                    w[1:]
                    for w in saved.pending_writes or []
                    if w[0] == NULL_TASK_ID
                ]:
                    apply_writes(
                        saved.checkpoint,
                        channels,
                        [PregelTaskWrites((), INPUT, null_writes, [])],
                        None,
                    )
                # apply writes from tasks that already ran (errors, interrupts
                # and scheduled markers are not state writes)
                for tid, k, v in saved.pending_writes or []:
                    if k in (ERROR, INTERRUPT, SCHEDULED):
                        continue
                    if tid not in next_tasks:
                        continue
                    next_tasks[tid].writes.append((k, v))
                # clear all current tasks
                apply_writes(checkpoint, channels, next_tasks.values(), None)
            # save checkpoint
            next_config = checkpointer.put(
                checkpoint_config,
                create_checkpoint(checkpoint, None, step),
                {
                    **checkpoint_metadata,
                    "source": "update",
                    "step": step + 1,
                    "writes": {},
                    "parents": saved.metadata.get("parents", {}) if saved else {},
                },
                {},
            )
            return patch_checkpoint_map(
                next_config, saved.metadata if saved else None
            )
        # no values, copy checkpoint
        if values is None and as_node is None:
            next_checkpoint = create_checkpoint(checkpoint, None, step)
            # copy checkpoint
            next_config = checkpointer.put(
                checkpoint_config,
                next_checkpoint,
                {
                    **checkpoint_metadata,
                    "source": "update",
                    "step": step + 1,
                    "writes": {},
                    "parents": saved.metadata.get("parents", {}) if saved else {},
                },
                {},
            )
            return patch_checkpoint_map(
                next_config, saved.metadata if saved else None
            )
        # no values with "__copy__": fork the checkpoint under its parent config
        if values is None and as_node == "__copy__":
            next_checkpoint = create_checkpoint(checkpoint, None, step)
            # copy checkpoint
            next_config = checkpointer.put(
                saved.parent_config or saved.config if saved else checkpoint_config,
                next_checkpoint,
                {
                    **checkpoint_metadata,
                    "source": "fork",
                    "step": step + 1,
                    "parents": saved.metadata.get("parents", {}) if saved else {},
                },
                {},
            )
            return patch_checkpoint_map(
                next_config, saved.metadata if saved else None
            )
        # apply pending writes, if not on specific checkpoint
        if (
            CONFIG_KEY_CHECKPOINT_ID not in config[CONF]
            and saved is not None
            and saved.pending_writes
        ):
            # tasks for this checkpoint
            next_tasks = prepare_next_tasks(
                checkpoint,
                saved.pending_writes,
                self.nodes,
                channels,
                managed,
                saved.config,
                saved.metadata.get("step", -1) + 1,
                for_execution=True,
                store=self.store,
                checkpointer=self.checkpointer or None,
                manager=None,
            )
            # apply null writes
            # NOTE(review): as above, applied to `saved.checkpoint` rather than
            # the copied `checkpoint` -- confirm.
            if null_writes := [
                w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID
            ]:
                apply_writes(
                    saved.checkpoint,
                    channels,
                    [PregelTaskWrites((), INPUT, null_writes, [])],
                    None,
                )
            # apply writes (only to tasks that recorded at least one state write)
            for tid, k, v in saved.pending_writes:
                if k in (ERROR, INTERRUPT, SCHEDULED):
                    continue
                if tid not in next_tasks:
                    continue
                next_tasks[tid].writes.append((k, v))
            if tasks := [t for t in next_tasks.values() if t.writes]:
                apply_writes(checkpoint, channels, tasks, None)
        # find last node that updated the state, if not provided
        # - nothing has run yet: fall back to the input node, if it is a node
        # - otherwise: the node with the single highest seen version
        if as_node is None and not any(
            v for vv in checkpoint["versions_seen"].values() for v in vv.values()
        ):
            if (
                isinstance(self.input_channels, str)
                and self.input_channels in self.nodes
            ):
                as_node = self.input_channels
        elif as_node is None:
            last_seen_by_node = sorted(
                (v, n)
                for n, seen in checkpoint["versions_seen"].items()
                if n in self.nodes
                for v in seen.values()
            )
            # if two nodes updated the state at the same time, it's ambiguous
            if last_seen_by_node:
                if len(last_seen_by_node) == 1:
                    as_node = last_seen_by_node[0][1]
                elif last_seen_by_node[-1][0] != last_seen_by_node[-2][0]:
                    as_node = last_seen_by_node[-1][1]
        if as_node is None:
            raise InvalidUpdateError("Ambiguous update, specify as_node")
        if as_node not in self.nodes:
            raise InvalidUpdateError(f"Node {as_node} does not exist")
        # create task to run all writers of the chosen node
        writers = self.nodes[as_node].flat_writers
        if not writers:
            raise InvalidUpdateError(f"Node {as_node} has no writers")
        writes: deque[tuple[str, Any]] = deque()
        task = PregelTaskWrites((), as_node, writes, [INTERRUPT])
        # deterministic task id derived from the checkpoint id
        task_id = str(uuid5(UUID(checkpoint["id"]), INTERRUPT))
        run = RunnableSequence(*writers) if len(writers) > 1 else writers[0]
        # execute task: writers send into `writes` and read from the local
        # channel state of this checkpoint
        run.invoke(
            values,
            patch_config(
                config,
                run_name=self.name + "UpdateState",
                configurable={
                    # deque.extend is thread-safe
                    CONFIG_KEY_SEND: partial(
                        local_write,
                        writes.extend,
                        self.nodes.keys(),
                    ),
                    CONFIG_KEY_READ: partial(
                        local_read,
                        step + 1,
                        checkpoint,
                        channels,
                        managed,
                        task,
                        config,
                    ),
                },
            ),
        )
        # save task writes
        # channel writes are saved to current checkpoint
        # push writes are saved to next checkpoint
        channel_writes, push_writes = (
            [w for w in task.writes if w[0] != PUSH],
            [w for w in task.writes if w[0] == PUSH],
        )
        if saved and channel_writes:
            checkpointer.put_writes(checkpoint_config, channel_writes, task_id)
        # apply to checkpoint and save
        mv_writes = apply_writes(
            checkpoint, channels, [task], checkpointer.get_next_version
        )
        # NOTE(review): `assert` is stripped under `python -O`; consider raising
        # an explicit error instead.
        assert not mv_writes, "Can't write to SharedValues from update_state"
        checkpoint = create_checkpoint(checkpoint, channels, step + 1)
        next_config = checkpointer.put(
            checkpoint_config,
            checkpoint,
            {
                **checkpoint_metadata,
                "source": "update",
                "step": step + 1,
                "writes": {as_node: values},
                "parents": saved.metadata.get("parents", {}) if saved else {},
            },
            # only channels whose version changed relative to the saved checkpoint
            get_new_channel_versions(
                checkpoint_previous_versions, checkpoint["channel_versions"]
            ),
        )
        if push_writes:
            checkpointer.put_writes(next_config, push_writes, task_id)
        return patch_checkpoint_map(next_config, saved.metadata if saved else None)

stream(input: Union[dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None, output_keys: Optional[Union[str, Sequence[str]]] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, debug: Optional[bool] = None, subgraphs: bool = False) -> Iterator[Union[dict[str, Any], Any]]

Stream graph steps for a single input.

Parameters:

  • input (Union[dict[str, Any], Any]) –

    The input to the graph.

  • config (Optional[RunnableConfig], default: None ) –

    The configuration to use for the run.

  • stream_mode (Optional[Union[StreamMode, list[StreamMode]]], default: None ) –

    The mode to stream output; defaults to self.stream_mode. Options are 'values', 'updates', and 'debug'. 'values' emits the current values of the state for each step. 'updates' emits only the updates to the state for each step, as a dict keyed by node name whose values are that node's updates. 'debug' emits debug events for each step.

  • output_keys (Optional[Union[str, Sequence[str]]], default: None ) –

    The keys to stream, defaults to all non-context channels.

  • interrupt_before (Optional[Union[All, Sequence[str]]], default: None ) –

    Nodes to interrupt before, defaults to all nodes in the graph.

  • interrupt_after (Optional[Union[All, Sequence[str]]], default: None ) –

    Nodes to interrupt after, defaults to all nodes in the graph.

  • debug (Optional[bool], default: None ) –

    Whether to print debug information during execution, defaults to False.

  • subgraphs (bool, default: False ) –

    Whether to stream subgraphs, defaults to False.

Yields:

  • Union[dict[str, Any], Any]

    The output of each step in the graph. The output shape depends on the stream_mode.

Examples:

Using different stream modes with a graph:

>>> import operator
>>> from typing_extensions import Annotated, TypedDict
>>> from langgraph.graph import StateGraph
>>> from langgraph.constants import START
...
>>> class State(TypedDict):
...     alist: Annotated[list, operator.add]
...     another_list: Annotated[list, operator.add]
...
>>> builder = StateGraph(State)
>>> builder.add_node("a", lambda _state: {"another_list": ["hi"]})
>>> builder.add_node("b", lambda _state: {"alist": ["there"]})
>>> builder.add_edge("a", "b")
>>> builder.add_edge(START, "a")
>>> graph = builder.compile()
With stream_mode="values":

>>> for event in graph.stream({"alist": ['Ex for stream_mode="values"']}, stream_mode="values"):
...     print(event)
{'alist': ['Ex for stream_mode="values"'], 'another_list': []}
{'alist': ['Ex for stream_mode="values"'], 'another_list': ['hi']}
{'alist': ['Ex for stream_mode="values"', 'there'], 'another_list': ['hi']}
With stream_mode="updates":

>>> for event in graph.stream({"alist": ['Ex for stream_mode="updates"']}, stream_mode="updates"):
...     print(event)
{'a': {'another_list': ['hi']}}
{'b': {'alist': ['there']}}
With stream_mode="debug":

>>> for event in graph.stream({"alist": ['Ex for stream_mode="debug"']}, stream_mode="debug"):
...     print(event)
{'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': []}, 'triggers': ['start:a']}}
{'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'result': [('another_list', ['hi'])]}}
{'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': ['hi']}, 'triggers': ['a']}}
{'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'result': [('alist', ['there'])]}}
Source code in libs/langgraph/langgraph/pregel/__init__.py
def stream(
    self,
    input: Union[dict[str, Any], Any],
    config: Optional[RunnableConfig] = None,
    *,
    stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None,
    output_keys: Optional[Union[str, Sequence[str]]] = None,
    interrupt_before: Optional[Union[All, Sequence[str]]] = None,
    interrupt_after: Optional[Union[All, Sequence[str]]] = None,
    debug: Optional[bool] = None,
    subgraphs: bool = False,
) -> Iterator[Union[dict[str, Any], Any]]:
    """Stream graph steps for a single input.

    Args:
        input: The input to the graph.
        config: The configuration to use for the run.
        stream_mode: The mode to stream output, defaults to self.stream_mode.
            Options are 'values', 'updates', and 'debug'.
            values: Emit the current values of the state for each step.
            updates: Emit only the updates to the state for each step.
                Output is a dict with the node name as key and the updated values as value.
            debug: Emit debug events for each step.
            The modes 'messages' and 'custom' are also recognized: 'messages'
            registers a StreamMessagesHandler callback that feeds the output
            stream, and 'custom' installs a stream writer in the run config.
            A list of modes may be passed to stream several modes at once.
        output_keys: The keys to stream, defaults to all non-context channels.
        interrupt_before: Nodes to interrupt before, defaults to all nodes in the graph.
        interrupt_after: Nodes to interrupt after, defaults to all nodes in the graph.
        debug: Whether to print debug information during execution, defaults to False.
        subgraphs: Whether to stream subgraphs, defaults to False.

    Yields:
        The output of each step in the graph. The output shape depends on the stream_mode.
        If stream_mode is a list, each item is a ``(mode, payload)`` tuple; if
        subgraphs is True, the namespace is prepended, giving ``(ns, payload)``
        or ``(ns, mode, payload)``.

    Raises:
        GraphRecursionError: If the run ends with status "out_of_steps", i.e.
            the recursion limit was reached without hitting a stop condition.

    Examples:
        Using different stream modes with a graph:
        ```pycon
        >>> import operator
        >>> from typing_extensions import Annotated, TypedDict
        >>> from langgraph.graph import StateGraph
        >>> from langgraph.constants import START
        ...
        >>> class State(TypedDict):
        ...     alist: Annotated[list, operator.add]
        ...     another_list: Annotated[list, operator.add]
        ...
        >>> builder = StateGraph(State)
        >>> builder.add_node("a", lambda _state: {"another_list": ["hi"]})
        >>> builder.add_node("b", lambda _state: {"alist": ["there"]})
        >>> builder.add_edge("a", "b")
        >>> builder.add_edge(START, "a")
        >>> graph = builder.compile()
        ```
        With stream_mode="values":

        ```pycon
        >>> for event in graph.stream({"alist": ['Ex for stream_mode="values"']}, stream_mode="values"):
        ...     print(event)
        {'alist': ['Ex for stream_mode="values"'], 'another_list': []}
        {'alist': ['Ex for stream_mode="values"'], 'another_list': ['hi']}
        {'alist': ['Ex for stream_mode="values"', 'there'], 'another_list': ['hi']}
        ```
        With stream_mode="updates":

        ```pycon
        >>> for event in graph.stream({"alist": ['Ex for stream_mode="updates"']}, stream_mode="updates"):
        ...     print(event)
        {'a': {'another_list': ['hi']}}
        {'b': {'alist': ['there']}}
        ```
        With stream_mode="debug":

        ```pycon
        >>> for event in graph.stream({"alist": ['Ex for stream_mode="debug"']}, stream_mode="debug"):
        ...     print(event)
        {'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': []}, 'triggers': ['start:a']}}
        {'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'result': [('another_list', ['hi'])]}}
        {'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': ['hi']}, 'triggers': ['a']}}
        {'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'result': [('alist', ['there'])]}}
        ```
    """

    # Every producer (the Pregel loop, message callbacks, the custom stream
    # writer, subgraphs) enqueues (namespace, mode, payload) triples here;
    # output() drains them between steps.
    stream = SyncQueue()

    def output() -> Iterator:
        # Drain everything currently queued without blocking and shape each
        # chunk according to the stream_mode / subgraphs arguments.
        # NOTE(review): this inspects the raw `stream_mode` argument rather
        # than the resolved `stream_modes`; if stream_mode is None and the
        # resolved default is a list, chunks are yielded without their mode.
        # Confirm this is intended.
        while True:
            try:
                ns, mode, payload = stream.get(block=False)
            except queue.Empty:
                break
            if subgraphs and isinstance(stream_mode, list):
                yield (ns, mode, payload)
            elif isinstance(stream_mode, list):
                yield (mode, payload)
            elif subgraphs:
                yield (ns, payload)
            else:
                yield payload

    config = ensure_config(self.config, config)
    callback_manager = get_callback_manager_for_config(config)
    # Open a callback run for this invocation; it is closed below by either
    # on_chain_end (success) or on_chain_error (any exception).
    run_manager = callback_manager.on_chain_start(
        None,
        input,
        name=config.get("run_name", self.get_name()),
        run_id=config.get("run_id"),
    )
    try:
        # assign defaults
        (
            debug,
            stream_modes,
            output_keys,
            interrupt_before_,
            interrupt_after_,
            checkpointer,
            store,
        ) = self._defaults(
            config,
            stream_mode=stream_mode,
            output_keys=output_keys,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            debug=debug,
        )
        # set up messages stream mode
        # (the handler is inheritable, so child runs push into the same queue)
        if "messages" in stream_modes:
            run_manager.inheritable_handlers.append(
                StreamMessagesHandler(stream.put)
            )
        # set up custom stream mode
        # (a writer placed in the config; each call enqueues a root-namespace
        # "custom" chunk)
        if "custom" in stream_modes:
            config[CONF][CONFIG_KEY_STREAM_WRITER] = lambda c: stream.put(
                ((), "custom", c)
            )
        with SyncPregelLoop(
            input,
            stream=StreamProtocol(stream.put, stream_modes),
            config=config,
            store=store,
            checkpointer=checkpointer,
            nodes=self.nodes,
            specs=self.channels,
            output_keys=output_keys,
            stream_keys=self.stream_channels_asis,
            interrupt_before=interrupt_before_,
            interrupt_after=interrupt_after_,
            manager=run_manager,
            debug=debug,
        ) as loop:
            # create runner
            runner = PregelRunner(
                submit=loop.submit,
                put_writes=loop.put_writes,
                schedule_task=loop.accept_push,
                node_finished=config[CONF].get(CONFIG_KEY_NODE_FINISHED),
            )
            # enable subgraph streaming
            if subgraphs:
                loop.config[CONF][CONFIG_KEY_STREAM] = loop.stream
            # enable concurrent streaming
            # (chunks may arrive while tasks are still running, so the runner
            # is given a waiter it can use to wake up and let us emit them)
            if subgraphs or "messages" in stream_modes or "custom" in stream_modes:
                # we are careful to have a single waiter live at any one time
                # because on exit we increment semaphore count by exactly 1
                waiter: Optional[concurrent.futures.Future] = None
                # because sync futures cannot be cancelled, we instead
                # release the stream semaphore on exit, which will cause
                # a pending waiter to return immediately
                loop.stack.callback(stream._count.release)

                def get_waiter() -> concurrent.futures.Future[None]:
                    # Reuse the pending waiter; only submit a new one once the
                    # previous one has completed.
                    nonlocal waiter
                    if waiter is None or waiter.done():
                        waiter = loop.submit(stream.wait)
                        return waiter
                    else:
                        return waiter

            else:
                get_waiter = None  # type: ignore[assignment]
            # Similarly to Bulk Synchronous Parallel / Pregel model
            # computation proceeds in steps, while there are channel updates
            # channel updates from step N are only visible in step N+1
            # channels are guaranteed to be immutable for the duration of the step,
            # with channel updates applied only at the transition between steps
            while loop.tick(input_keys=self.input_channels):
                for _ in runner.tick(
                    loop.tasks.values(),
                    timeout=self.step_timeout,
                    retry_policy=self.retry_policy,
                    get_waiter=get_waiter,
                ):
                    # emit output
                    yield from output()
        # emit output
        # (anything still queued once the loop context has exited)
        yield from output()
        # handle exit
        if loop.status == "out_of_steps":
            msg = create_error_message(
                message=(
                    f"Recursion limit of {config['recursion_limit']} reached "
                    "without hitting a stop condition. You can increase the "
                    "limit by setting the `recursion_limit` config key."
                ),
                error_code=ErrorCode.GRAPH_RECURSION_LIMIT,
            )
            raise GraphRecursionError(msg)
        # set final channel values as run output
        run_manager.on_chain_end(loop.output)
    except BaseException as e:
        # Report the failure to callbacks, then propagate it unchanged.
        run_manager.on_chain_error(e)
        raise

astream(input: Union[dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None, output_keys: Optional[Union[str, Sequence[str]]] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, debug: Optional[bool] = None, subgraphs: bool = False) -> AsyncIterator[Union[dict[str, Any], Any]] async

Stream graph steps for a single input.

Parameters:

  • input (Union[dict[str, Any], Any]) –

    The input to the graph.

  • config (Optional[RunnableConfig], default: None ) –

    The configuration to use for the run.

  • stream_mode (Optional[Union[StreamMode, list[StreamMode]]], default: None ) –

    The mode to stream output; defaults to self.stream_mode. Options are 'values', 'updates', and 'debug'. 'values' emits the current values of the state for each step; 'updates' emits only the updates to the state for each step, as a dict with the node name as key and the updated values as value; 'debug' emits debug events for each step.

  • output_keys (Optional[Union[str, Sequence[str]]], default: None ) –

    The keys to stream, defaults to all non-context channels.

  • interrupt_before (Optional[Union[All, Sequence[str]]], default: None ) –

    Nodes to interrupt before, defaults to all nodes in the graph.

  • interrupt_after (Optional[Union[All, Sequence[str]]], default: None ) –

    Nodes to interrupt after, defaults to all nodes in the graph.

  • debug (Optional[bool], default: None ) –

    Whether to print debug information during execution, defaults to False.

  • subgraphs (bool, default: False ) –

    Whether to stream subgraphs, defaults to False.

Yields:

  • AsyncIterator[Union[dict[str, Any], Any]]

    The output of each step in the graph. The output shape depends on the stream_mode.

Examples:

Using different stream modes with a graph:

>>> import operator
>>> from typing_extensions import Annotated, TypedDict
>>> from langgraph.graph import StateGraph
>>> from langgraph.constants import START
...
>>> class State(TypedDict):
...     alist: Annotated[list, operator.add]
...     another_list: Annotated[list, operator.add]
...
>>> builder = StateGraph(State)
>>> builder.add_node("a", lambda _state: {"another_list": ["hi"]})
>>> builder.add_node("b", lambda _state: {"alist": ["there"]})
>>> builder.add_edge("a", "b")
>>> builder.add_edge(START, "a")
>>> graph = builder.compile()
With stream_mode="values":

>>> async for event in graph.astream({"alist": ['Ex for stream_mode="values"']}, stream_mode="values"):
...     print(event)
{'alist': ['Ex for stream_mode="values"'], 'another_list': []}
{'alist': ['Ex for stream_mode="values"'], 'another_list': ['hi']}
{'alist': ['Ex for stream_mode="values"', 'there'], 'another_list': ['hi']}
With stream_mode="updates":

>>> async for event in graph.astream({"alist": ['Ex for stream_mode="updates"']}, stream_mode="updates"):
...     print(event)
{'a': {'another_list': ['hi']}}
{'b': {'alist': ['there']}}
With stream_mode="debug":

>>> async for event in graph.astream({"alist": ['Ex for stream_mode="debug"']}, stream_mode="debug"):
...     print(event)
{'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': []}, 'triggers': ['start:a']}}
{'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'result': [('another_list', ['hi'])]}}
{'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': ['hi']}, 'triggers': ['a']}}
{'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'result': [('alist', ['there'])]}}
Source code in libs/langgraph/langgraph/pregel/__init__.py
async def astream(
    self,
    input: Union[dict[str, Any], Any],
    config: Optional[RunnableConfig] = None,
    *,
    stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None,
    output_keys: Optional[Union[str, Sequence[str]]] = None,
    interrupt_before: Optional[Union[All, Sequence[str]]] = None,
    interrupt_after: Optional[Union[All, Sequence[str]]] = None,
    debug: Optional[bool] = None,
    subgraphs: bool = False,
) -> AsyncIterator[Union[dict[str, Any], Any]]:
    """Stream graph steps for a single input.

    Args:
        input: The input to the graph.
        config: The configuration to use for the run.
        stream_mode: The mode to stream output, defaults to self.stream_mode.
            Options are 'values', 'updates', and 'debug'.
            values: Emit the current values of the state for each step.
            updates: Emit only the updates to the state for each step.
                Output is a dict with the node name as key and the updated values as value.
            debug: Emit debug events for each step.
            The modes 'messages' and 'custom' are also recognized: 'messages'
            registers a StreamMessagesHandler callback that feeds the output
            stream, and 'custom' installs a stream writer in the run config.
            A list of modes may be passed to stream several modes at once.
        output_keys: The keys to stream, defaults to all non-context channels.
        interrupt_before: Nodes to interrupt before, defaults to all nodes in the graph.
        interrupt_after: Nodes to interrupt after, defaults to all nodes in the graph.
        debug: Whether to print debug information during execution, defaults to False.
        subgraphs: Whether to stream subgraphs, defaults to False.

    Yields:
        The output of each step in the graph. The output shape depends on the stream_mode.
        If stream_mode is a list, each item is a ``(mode, payload)`` tuple; if
        subgraphs is True, the namespace is prepended, giving ``(ns, payload)``
        or ``(ns, mode, payload)``.

    Raises:
        GraphRecursionError: If the run ends with status "out_of_steps", i.e.
            the recursion limit was reached without hitting a stop condition.

    Examples:
        Using different stream modes with a graph:
        ```pycon
        >>> import operator
        >>> from typing_extensions import Annotated, TypedDict
        >>> from langgraph.graph import StateGraph
        >>> from langgraph.constants import START
        ...
        >>> class State(TypedDict):
        ...     alist: Annotated[list, operator.add]
        ...     another_list: Annotated[list, operator.add]
        ...
        >>> builder = StateGraph(State)
        >>> builder.add_node("a", lambda _state: {"another_list": ["hi"]})
        >>> builder.add_node("b", lambda _state: {"alist": ["there"]})
        >>> builder.add_edge("a", "b")
        >>> builder.add_edge(START, "a")
        >>> graph = builder.compile()
        ```
        With stream_mode="values":

        ```pycon
        >>> async for event in graph.astream({"alist": ['Ex for stream_mode="values"']}, stream_mode="values"):
        ...     print(event)
        {'alist': ['Ex for stream_mode="values"'], 'another_list': []}
        {'alist': ['Ex for stream_mode="values"'], 'another_list': ['hi']}
        {'alist': ['Ex for stream_mode="values"', 'there'], 'another_list': ['hi']}
        ```
        With stream_mode="updates":

        ```pycon
        >>> async for event in graph.astream({"alist": ['Ex for stream_mode="updates"']}, stream_mode="updates"):
        ...     print(event)
        {'a': {'another_list': ['hi']}}
        {'b': {'alist': ['there']}}
        ```
        With stream_mode="debug":

        ```pycon
        >>> async for event in graph.astream({"alist": ['Ex for stream_mode="debug"']}, stream_mode="debug"):
        ...     print(event)
        {'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': []}, 'triggers': ['start:a']}}
        {'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'result': [('another_list', ['hi'])]}}
        {'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': ['hi']}, 'triggers': ['a']}}
        {'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'result': [('alist', ['there'])]}}
        ```
    """

    # Every producer enqueues (namespace, mode, payload) triples here;
    # output() drains them between steps.
    stream = AsyncQueue()
    aioloop = asyncio.get_running_loop()
    # Thread-safe enqueue: schedules put_nowait on the running event loop so
    # producers that are not on the loop thread (callbacks, subgraphs) can
    # push chunks safely.
    stream_put = cast(
        Callable[[StreamChunk], None],
        partial(aioloop.call_soon_threadsafe, stream.put_nowait),
    )

    def output() -> Iterator:
        # Drain everything currently queued without waiting and shape each
        # chunk according to the stream_mode / subgraphs arguments.
        # NOTE(review): this inspects the raw `stream_mode` argument rather
        # than the resolved `stream_modes`; if stream_mode is None and the
        # resolved default is a list, chunks are yielded without their mode.
        # Confirm this is intended.
        while True:
            try:
                ns, mode, payload = stream.get_nowait()
            except asyncio.QueueEmpty:
                break
            if subgraphs and isinstance(stream_mode, list):
                yield (ns, mode, payload)
            elif isinstance(stream_mode, list):
                yield (mode, payload)
            elif subgraphs:
                yield (ns, payload)
            else:
                yield payload

    config = ensure_config(self.config, config)
    callback_manager = get_async_callback_manager_for_config(config)
    # Open a callback run for this invocation; it is closed below by either
    # on_chain_end (success) or on_chain_error (any exception).
    run_manager = await callback_manager.on_chain_start(
        None,
        input,
        name=config.get("run_name", self.get_name()),
        run_id=config.get("run_id"),
    )
    # if running from astream_log() run each proc with streaming
    do_stream = next(
        (
            cast(_StreamingCallbackHandler, h)
            for h in run_manager.handlers
            if isinstance(h, _StreamingCallbackHandler)
        ),
        None,
    )
    try:
        # assign defaults
        (
            debug,
            stream_modes,
            output_keys,
            interrupt_before_,
            interrupt_after_,
            checkpointer,
            store,
        ) = self._defaults(
            config,
            stream_mode=stream_mode,
            output_keys=output_keys,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            debug=debug,
        )
        # set up messages stream mode
        # (inheritable handler; it enqueues through the thread-safe stream_put)
        if "messages" in stream_modes:
            run_manager.inheritable_handlers.append(
                StreamMessagesHandler(stream_put)
            )
        # set up custom stream mode
        # (a writer placed in the config; each call enqueues a root-namespace
        # "custom" chunk via call_soon_threadsafe)
        if "custom" in stream_modes:
            config[CONF][CONFIG_KEY_STREAM_WRITER] = (
                lambda c: aioloop.call_soon_threadsafe(
                    stream.put_nowait, ((), "custom", c)
                )
            )
        # NOTE(review): the loop's own stream writes use stream.put_nowait
        # directly rather than stream_put, presumably because they happen on
        # the event-loop thread. Confirm.
        async with AsyncPregelLoop(
            input,
            stream=StreamProtocol(stream.put_nowait, stream_modes),
            config=config,
            store=store,
            checkpointer=checkpointer,
            nodes=self.nodes,
            specs=self.channels,
            output_keys=output_keys,
            stream_keys=self.stream_channels_asis,
            interrupt_before=interrupt_before_,
            interrupt_after=interrupt_after_,
            manager=run_manager,
            debug=debug,
        ) as loop:
            # create runner
            runner = PregelRunner(
                submit=loop.submit,
                put_writes=loop.put_writes,
                schedule_task=loop.accept_push,
                use_astream=do_stream is not None,
                node_finished=config[CONF].get(CONFIG_KEY_NODE_FINISHED),
            )
            # enable subgraph streaming
            if subgraphs:
                loop.config[CONF][CONFIG_KEY_STREAM] = StreamProtocol(
                    stream_put, stream_modes
                )
            # enable concurrent streaming
            # (chunks may arrive while tasks are still running, so the runner
            # is given a waiter task that completes when the queue is signalled)
            if subgraphs or "messages" in stream_modes or "custom" in stream_modes:

                def get_waiter() -> asyncio.Task[None]:
                    return aioloop.create_task(stream.wait())

            else:
                get_waiter = None  # type: ignore[assignment]
            # Similarly to Bulk Synchronous Parallel / Pregel model
            # computation proceeds in steps, while there are channel updates
            # channel updates from step N are only visible in step N+1
            # channels are guaranteed to be immutable for the duration of the step,
            # with channel updates applied only at the transition between steps
            while loop.tick(input_keys=self.input_channels):
                async for _ in runner.atick(
                    loop.tasks.values(),
                    timeout=self.step_timeout,
                    retry_policy=self.retry_policy,
                    get_waiter=get_waiter,
                ):
                    # emit output
                    for o in output():
                        yield o
        # emit output
        # (anything still queued once the loop context has exited)
        for o in output():
            yield o
        # handle exit
        if loop.status == "out_of_steps":
            msg = create_error_message(
                message=(
                    f"Recursion limit of {config['recursion_limit']} reached "
                    "without hitting a stop condition. You can increase the "
                    "limit by setting the `recursion_limit` config key."
                ),
                error_code=ErrorCode.GRAPH_RECURSION_LIMIT,
            )
            raise GraphRecursionError(msg)
        # set final channel values as run output
        await run_manager.on_chain_end(loop.output)
    except BaseException as e:
        # shield() keeps the error report running even if this task is
        # being cancelled, then the exception propagates unchanged.
        await asyncio.shield(run_manager.on_chain_error(e))
        raise

invoke(input: Union[dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, stream_mode: StreamMode = 'values', output_keys: Optional[Union[str, Sequence[str]]] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, debug: Optional[bool] = None, **kwargs: Any) -> Union[dict[str, Any], Any]

Run the graph with a single input and config.

Parameters:

  • input (Union[dict[str, Any], Any]) –

    The input data for the graph. It can be a dictionary or any other type.

  • config (Optional[RunnableConfig], default: None ) –

    Optional. The configuration for the graph run.

  • stream_mode (StreamMode, default: 'values' ) –

    The stream mode for the graph run. Default is "values".

  • output_keys (Optional[Union[str, Sequence[str]]], default: None ) –

    Optional. The output keys to retrieve from the graph run.

  • interrupt_before (Optional[Union[All, Sequence[str]]], default: None ) –

    Optional. The nodes to interrupt the graph run before.

  • interrupt_after (Optional[Union[All, Sequence[str]]], default: None ) –

    Optional. The nodes to interrupt the graph run after.

  • debug (Optional[bool], default: None ) –

    Optional. Enable debug mode for the graph run.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments to pass to the graph run.

Returns:

  • Union[dict[str, Any], Any]

    The output of the graph run. If stream_mode is "values", it returns the latest output.

  • Union[dict[str, Any], Any]

    If stream_mode is not "values", it returns a list of output chunks.

Source code in libs/langgraph/langgraph/pregel/__init__.py
def invoke(
    self,
    input: Union[dict[str, Any], Any],
    config: Optional[RunnableConfig] = None,
    *,
    stream_mode: StreamMode = "values",
    output_keys: Optional[Union[str, Sequence[str]]] = None,
    interrupt_before: Optional[Union[All, Sequence[str]]] = None,
    interrupt_after: Optional[Union[All, Sequence[str]]] = None,
    debug: Optional[bool] = None,
    **kwargs: Any,
) -> Union[dict[str, Any], Any]:
    """Run the graph with a single input and config.

    A thin wrapper that consumes ``self.stream`` to completion and packages
    what it produced.

    Args:
        input: The input data for the graph. It can be a dictionary or any other type.
        config: Optional. The configuration for the graph run.
        stream_mode: The stream mode for the graph run. Default is "values".
        output_keys: Optional. The output keys to retrieve from the graph run.
            When omitted, ``self.output_channels`` is used.
        interrupt_before: Optional. The nodes to interrupt the graph run before.
        interrupt_after: Optional. The nodes to interrupt the graph run after.
        debug: Optional. Enable debug mode for the graph run.
        **kwargs: Additional keyword arguments forwarded to ``self.stream``.

    Returns:
        If stream_mode is "values", the last chunk the stream emitted, or
        None if it emitted nothing. For any other stream_mode, a list of
        every chunk in the order it was emitted.
    """
    if output_keys is None:
        output_keys = self.output_channels

    chunk_iter = self.stream(
        input,
        config,
        stream_mode=stream_mode,
        output_keys=output_keys,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        debug=debug,
        **kwargs,
    )

    if stream_mode != "values":
        # Keep every chunk, in emission order.
        return list(chunk_iter)

    # Only the most recent chunk is returned; it stays None for an empty stream.
    final: Union[dict[str, Any], Any] = None
    for final in chunk_iter:
        pass
    return final

ainvoke(input: Union[dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, stream_mode: StreamMode = 'values', output_keys: Optional[Union[str, Sequence[str]]] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, debug: Optional[bool] = None, **kwargs: Any) -> Union[dict[str, Any], Any] async

Asynchronously invoke the graph on a single input.

Parameters:

  • input (Union[dict[str, Any], Any]) –

    The input data for the computation. It can be a dictionary or any other type.

  • config (Optional[RunnableConfig], default: None ) –

    Optional. The configuration for the computation.

  • stream_mode (StreamMode, default: 'values' ) –

    Optional. The stream mode for the computation. Default is "values".

  • output_keys (Optional[Union[str, Sequence[str]]], default: None ) –

    Optional. The output keys to include in the result. Default is None.

  • interrupt_before (Optional[Union[All, Sequence[str]]], default: None ) –

    Optional. The nodes to interrupt before. Default is None.

  • interrupt_after (Optional[Union[All, Sequence[str]]], default: None ) –

    Optional. The nodes to interrupt after. Default is None.

  • debug (Optional[bool], default: None ) –

    Optional. Whether to enable debug mode. Default is None.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.

Returns:

  • Union[dict[str, Any], Any]

    The result of the computation. If stream_mode is "values", it returns the latest value.

  • Union[dict[str, Any], Any]

    If stream_mode is not "values", it returns a list of all output chunks.

Source code in libs/langgraph/langgraph/pregel/__init__.py
async def ainvoke(
    self,
    input: Union[dict[str, Any], Any],
    config: Optional[RunnableConfig] = None,
    *,
    stream_mode: StreamMode = "values",
    output_keys: Optional[Union[str, Sequence[str]]] = None,
    interrupt_before: Optional[Union[All, Sequence[str]]] = None,
    interrupt_after: Optional[Union[All, Sequence[str]]] = None,
    debug: Optional[bool] = None,
    **kwargs: Any,
) -> Union[dict[str, Any], Any]:
    """Asynchronously invoke the graph on a single input.

    Async counterpart of ``invoke``: consumes ``self.astream`` to completion
    and packages what it produced.

    Args:
        input: The input data for the computation. It can be a dictionary or any other type.
        config: Optional. The configuration for the computation.
        stream_mode: Optional. The stream mode for the computation. Default is "values".
        output_keys: Optional. The output keys to include in the result.
            When None (the default), ``self.output_channels`` is used.
        interrupt_before: Optional. The nodes to interrupt before. Default is None.
        interrupt_after: Optional. The nodes to interrupt after. Default is None.
        debug: Optional. Whether to enable debug mode. Default is None.
        **kwargs: Additional keyword arguments forwarded to ``self.astream``.

    Returns:
        If stream_mode is "values", the latest value emitted by the stream,
        or None if it emitted nothing. For any other stream_mode, a list of
        every chunk in the order it was emitted.
    """

    output_keys = output_keys if output_keys is not None else self.output_channels
    # "values" mode only needs the final chunk; every other mode collects all.
    keep_latest_only = stream_mode == "values"
    latest: Union[dict[str, Any], Any] = None
    chunks: list = []
    async for chunk in self.astream(
        input,
        config,
        stream_mode=stream_mode,
        output_keys=output_keys,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        debug=debug,
        **kwargs,
    ):
        if keep_latest_only:
            latest = chunk
        else:
            chunks.append(chunk)
    return latest if keep_latest_only else chunks

add_messages(left: Messages, right: Messages) -> Messages

Merges two lists of messages, updating existing messages by ID.

By default, this ensures the state is "append-only", unless the new message has the same ID as an existing message.

Parameters:

  • left (Messages) –

    The base list of messages.

  • right (Messages) –

    The list of messages (or single message) to merge into the base list.

Returns:

  • Messages

    A new list of messages with the messages from right merged into left. If a message in right has the same ID as a message in left, the message from right will replace the message from left.

Examples:

>>> from langchain_core.messages import AIMessage, HumanMessage
>>> msgs1 = [HumanMessage(content="Hello", id="1")]
>>> msgs2 = [AIMessage(content="Hi there!", id="2")]
>>> add_messages(msgs1, msgs2)
[HumanMessage(content='Hello', id='1'), AIMessage(content='Hi there!', id='2')]

>>> msgs1 = [HumanMessage(content="Hello", id="1")]
>>> msgs2 = [HumanMessage(content="Hello again", id="1")]
>>> add_messages(msgs1, msgs2)
[HumanMessage(content='Hello again', id='1')]

>>> from typing import Annotated
>>> from typing_extensions import TypedDict
>>> from langgraph.graph import StateGraph
>>>
>>> class State(TypedDict):
...     messages: Annotated[list, add_messages]
...
>>> builder = StateGraph(State)
>>> builder.add_node("chatbot", lambda state: {"messages": [("assistant", "Hello")]})
>>> builder.set_entry_point("chatbot")
>>> builder.set_finish_point("chatbot")
>>> graph = builder.compile()
>>> graph.invoke({})
{'messages': [AIMessage(content='Hello', id=...)]}
Source code in libs/langgraph/langgraph/graph/message.py
def add_messages(left: Messages, right: Messages) -> Messages:
    """Merges two lists of messages, updating existing messages by ID.

    By default, this ensures the state is "append-only", unless the
    new message has the same ID as an existing message.

    Args:
        left: The base list of messages.
        right: The list of messages (or single message) to merge
            into the base list.

    Returns:
        A new list of messages with the messages from `right` merged into `left`.
        If a message in `right` has the same ID as a message in `left`, the
        message from `right` will replace the message from `left`. The same
        applies within `right` itself: a later message replaces an earlier one
        with the same ID instead of producing a duplicate entry.

        A `RemoveMessage` in `right` deletes the message with the same ID,
        whether that message came from `left` or from earlier in `right`.
        A message with that ID appearing later in `right` (after the
        `RemoveMessage`) is kept.

    Raises:
        ValueError: If `right` contains a `RemoveMessage` whose ID matches no
            message in `left` and no earlier message in `right`.

    Note:
        Messages without an ID are assigned a fresh UUID4 string. The ID is
        set directly on the message object, so it may be visible on objects
        the caller passed in.

    Examples:
        ```pycon
        >>> from langchain_core.messages import AIMessage, HumanMessage
        >>> msgs1 = [HumanMessage(content="Hello", id="1")]
        >>> msgs2 = [AIMessage(content="Hi there!", id="2")]
        >>> add_messages(msgs1, msgs2)
        [HumanMessage(content='Hello', id='1'), AIMessage(content='Hi there!', id='2')]

        >>> msgs1 = [HumanMessage(content="Hello", id="1")]
        >>> msgs2 = [HumanMessage(content="Hello again", id="1")]
        >>> add_messages(msgs1, msgs2)
        [HumanMessage(content='Hello again', id='1')]

        >>> msgs2 = [HumanMessage(content="a", id="3"), HumanMessage(content="b", id="3")]
        >>> add_messages([], msgs2)
        [HumanMessage(content='b', id='3')]

        >>> from typing import Annotated
        >>> from typing_extensions import TypedDict
        >>> from langgraph.graph import StateGraph
        >>>
        >>> class State(TypedDict):
        ...     messages: Annotated[list, add_messages]
        ...
        >>> builder = StateGraph(State)
        >>> builder.add_node("chatbot", lambda state: {"messages": [("assistant", "Hello")]})
        >>> builder.set_entry_point("chatbot")
        >>> builder.set_finish_point("chatbot")
        >>> graph = builder.compile()
        >>> graph.invoke({})
        {'messages': [AIMessage(content='Hello', id=...)]}
        ```

    """
    # coerce to list
    if not isinstance(left, list):
        left = [left]  # type: ignore[assignment]
    if not isinstance(right, list):
        right = [right]  # type: ignore[assignment]
    # coerce to message (chunks are converted to their full-message form)
    left = [
        message_chunk_to_message(cast(BaseMessageChunk, m))
        for m in convert_to_messages(left)
    ]
    right = [
        message_chunk_to_message(cast(BaseMessageChunk, m))
        for m in convert_to_messages(right)
    ]
    # assign missing ids so every message can be matched by id below
    for m in left:
        if m.id is None:
            m.id = str(uuid.uuid4())
    for m in right:
        if m.id is None:
            m.id = str(uuid.uuid4())
    # merge
    merged = left.copy()
    # id -> position in `merged`. This map is updated whenever a message from
    # `right` is appended, so later entries in `right` with the same id
    # replace or remove it instead of being appended as duplicates.
    merged_idx_by_id = {m.id: i for i, m in enumerate(merged)}
    ids_to_remove = set()
    for m in right:
        if (existing_idx := merged_idx_by_id.get(m.id)) is not None:
            if isinstance(m, RemoveMessage):
                ids_to_remove.add(m.id)
            else:
                # A message re-added after a removal in the same batch survives.
                ids_to_remove.discard(m.id)
                merged[existing_idx] = m
        else:
            if isinstance(m, RemoveMessage):
                raise ValueError(
                    f"Attempting to delete a message with an ID that doesn't exist ('{m.id}')"
                )

            merged_idx_by_id[m.id] = len(merged)
            merged.append(m)
    return [m for m in merged if m.id not in ids_to_remove]

Comments