How to create subgraphs
For more complex systems, subgraphs are a useful design principle. Subgraphs allow you to create and manage different states in different parts of your graph. This allows you to build things like multi-agent teams, where each team can track its own separate state.
Setup
First, let's install the required packages:
%%capture --no-stderr
%pip install -U langgraph
Simple example
Let's consider a toy example: a system that accepts logs and performs two separate sub-tasks. First, it will summarize them. Second, it will summarize any failure modes captured in the logs. These two operations will be performed by two different subgraphs.
The most important thing to recognize is the information transfer between the graphs. `Entry Graph` is the parent, and each of the two subgraphs is defined as a node in `Entry Graph`. Both subgraphs inherit state from the parent `Entry Graph`; I can access `logs` in each of the subgraphs simply by specifying it in the subgraph state (see diagram). Each subgraph can have its own private state. And any values that I want propagated back to the parent `Entry Graph` (for final reporting) simply need to be defined in my `Entry Graph` state (e.g., `summary_report` and `failure_report`).
Define subgraphs
from typing import Optional, Annotated
from typing_extensions import TypedDict
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
# The structure of the logs
class _LogsRequired(TypedDict):
    # Keys that every log entry must provide.
    id: str
    question: str
    answer: str


class Logs(_LogsRequired, total=False):
    """A single question/answer log entry.

    ``id``, ``question`` and ``answer`` are required. ``grade`` and
    ``feedback`` may be omitted entirely (ungraded logs, such as the third
    dummy log, leave them out and are dropped by ``select_logs``, which checks
    ``"grade" in log``); when present they may also be ``None``.
    """

    grade: Optional[int]
    feedback: Optional[str]
# Define custom reducer (see more on this in the "Custom reducer" section below)
def add_logs(left: list[Logs], right: list[Logs]) -> list[Logs]:
    """Merge two lists of logs, matching entries by their ``id`` key.

    Intended as a state reducer (``Annotated[list[Logs], add_logs]``). A log
    in ``right`` whose ``id`` already exists replaces that entry in place;
    any other log is appended. Falsy inputs (``None`` or ``[]``) are treated
    as empty lists. ``left`` is not mutated; a new list is returned.

    Args:
        left: The current value of the state key.
        right: The update to merge into it.

    Returns:
        The merged list of logs.
    """
    logs = list(left or [])
    id_to_idx = {log["id"]: idx for idx, log in enumerate(logs)}
    # Update if the log is already in the state, otherwise append. The index
    # map is kept current on append so that an id repeated within ``right``
    # replaces its earlier copy instead of being appended twice.
    for log in right or []:
        idx = id_to_idx.get(log["id"])
        if idx is not None:
            logs[idx] = log
        else:
            id_to_idx[log["id"]] = len(logs)
            logs.append(log)
    return logs
# Failure Analysis Subgraph
class FailureAnalysisState(TypedDict):
    """State schema for the failure-analysis subgraph."""

    # keys shared with the parent graph (EntryGraphState)
    logs: Annotated[list[Logs], add_logs]  # merged by ``id`` via the add_logs reducer
    failure_report: str  # written by this subgraph; also declared in EntryGraphState
    # subgraph key (not in EntryGraphState, so it stays private to this subgraph)
    failures: list[Logs]  # the logs with grade == 0, filled by get_failures
def get_failures(state: FailureAnalysisState):
    """Collect the logs graded as failures (``grade == 0``).

    Logs that carry no ``grade`` key are skipped instead of raising
    ``KeyError``, so the subgraph also works when it is invoked directly on
    logs that were not pre-filtered by the parent graph.

    Returns:
        A state update ``{"failures": [...]}``.
    """
    failures = [log for log in state["logs"] if log.get("grade") == 0]
    return {"failures": failures}
def generate_summary(state: FailureAnalysisState):
    """Turn the detected failures into a one-line ``failure_report``.

    The report lists the ``id`` of every entry in ``state["failures"]``,
    separated by ", ".
    """
    # NOTE: you can implement custom summarization logic here
    joined_ids = ", ".join(entry["id"] for entry in state["failures"])
    return {"failure_report": f"Poor quality of retrieval for document IDs: {joined_ids}"}
# Wire the failure-analysis subgraph as a linear pipeline:
# START -> get_failures -> generate_summary -> END
fa_builder = StateGraph(FailureAnalysisState)
fa_builder.add_node("get_failures", get_failures)
# Here `generate_summary` is still the failure-report function defined just
# above. The name is re-bound further down for the summarization subgraph, but
# the function object registered on this line is not affected by that.
fa_builder.add_node("generate_summary", generate_summary)
fa_builder.add_edge(START, "get_failures")
fa_builder.add_edge("get_failures", "generate_summary")
fa_builder.add_edge("generate_summary", END)
# Summarization subgraph
class QuestionSummarizationState(TypedDict):
    """State schema for the question-summarization subgraph."""

    # keys that are shared with the parent graph (EntryGraphState)
    summary_report: str  # written by send_to_slack; also declared in EntryGraphState
    logs: Annotated[list[Logs], add_logs]  # merged by ``id`` via the add_logs reducer
    # subgraph keys (not in EntryGraphState, so private to this subgraph)
    summary: str  # intermediate result passed from generate_summary to send_to_slack
def generate_summary(state: QuestionSummarizationState):
    """Produce the question ``summary`` for the summarization subgraph.

    Placeholder implementation: ``state["logs"]`` is read (so a missing key
    still raises ``KeyError``) but not used; the summary is fixed text.
    """
    _logs = state["logs"]  # input for a real summarizer; unused by this placeholder
    # NOTE: you can implement custom summarization logic here
    return {"summary": "Questions focused on usage of ChatOllama and Chroma vector store."}
def send_to_slack(state: QuestionSummarizationState):
    """Copy the generated ``summary`` into the shared ``summary_report`` key.

    This is the hand-off point back to the parent graph; a real implementation
    could also post the summary somewhere before returning it.
    """
    # NOTE: you can implement custom logic here, for example sending the summary generated in the previous step to Slack
    return {"summary_report": state["summary"]}
# Wire the summarization subgraph as a linear pipeline:
# START -> generate_summary -> send_to_slack -> END
qs_builder = StateGraph(QuestionSummarizationState)
# `generate_summary` now refers to the summarization version defined just above.
qs_builder.add_node("generate_summary", generate_summary)
qs_builder.add_node("send_to_slack", send_to_slack)
qs_builder.add_edge(START, "generate_summary")
qs_builder.add_edge("generate_summary", "send_to_slack")
qs_builder.add_edge("send_to_slack", END)
Note that each subgraph has its own state: `QuestionSummarizationState` and `FailureAnalysisState`.
After defining each subgraph, we put everything together.
Define parent graph
# Dummy logs
dummy_logs = [
    # Graded 1: kept by select_logs, not counted as a failure.
    Logs(
        id="1",
        question="How can I import ChatOllama?",
        grade=1,
        answer="To import ChatOllama, use: 'from langchain_community.chat_models import ChatOllama.'",
    ),
    # Graded 0: picked up by get_failures and named in the failure report.
    Logs(
        id="2",
        question="How can I use Chroma vector store?",
        answer="To use Chroma, define: rag_chain = create_retrieval_chain(retriever, question_answer_chain).",
        grade=0,
        feedback="The retrieved documents discuss vector stores in general, but not Chroma specifically",
    ),
    # No "grade" key: dropped by select_logs before the subgraphs run.
    Logs(
        id="3",
        question="How do I create react agent in langgraph?",
        answer="from langgraph.prebuilt import create_react_agent",
    )
]
# Entry Graph
class EntryGraphState(TypedDict):
    """State schema for the parent (entry) graph."""

    raw_logs: Annotated[list[Logs], add_logs]  # input passed to graph.invoke
    # Filtered logs used in subgraphs. Both subgraph schemas also declare this
    # key, so it needs the add_logs reducer to merge their updates.
    logs: Annotated[list[Logs], add_logs]
    failure_report: str  # This will be generated in the FA subgraph
    summary_report: str  # This will be generated in the QS subgraph
def select_logs(state):
    """Keep only the raw logs that carry a ``grade`` key and expose them as ``logs``."""
    graded = []
    for entry in state["raw_logs"]:
        if "grade" in entry:
            graded.append(entry)
    return {"logs": graded}
entry_builder = StateGraph(EntryGraphState)
entry_builder.add_node("select_logs", select_logs)
# Compiled subgraphs are added directly as nodes; they exchange data with the
# parent through the state keys their schemas share with EntryGraphState.
entry_builder.add_node("question_summarization", qs_builder.compile())
entry_builder.add_node("failure_analysis", fa_builder.compile())
entry_builder.add_edge(START, "select_logs")
# Fan-out: both subgraphs follow select_logs. Because both declare `logs`,
# their updates to it are merged by the add_logs reducer.
entry_builder.add_edge("select_logs", "failure_analysis")
entry_builder.add_edge("select_logs", "question_summarization")
entry_builder.add_edge("failure_analysis", END)
entry_builder.add_edge("question_summarization", END)
graph = entry_builder.compile()
from IPython.display import Image, display
# Setting xray to 1 will show the internal structure of the nested graph
display(Image(graph.get_graph(xray=1).draw_mermaid_png()))
# Run the whole pipeline on the dummy logs; the call returns the final graph state.
graph.invoke({"raw_logs": dummy_logs}, debug=False)
{'raw_logs': [{'id': '1', 'question': 'How can I import ChatOllama?', 'grade': 1, 'answer': "To import ChatOllama, use: 'from langchain_community.chat_models import ChatOllama.'"}, {'id': '2', 'question': 'How can I use Chroma vector store?', 'answer': 'To use Chroma, define: rag_chain = create_retrieval_chain(retriever, question_answer_chain).', 'grade': 0, 'feedback': 'The retrieved documents discuss vector stores in general, but not Chroma specifically'}, {'id': '3', 'question': 'How do I create react agent in langgraph?', 'answer': 'from langgraph.prebuilt import create_react_agent'}], 'logs': [{'id': '1', 'question': 'How can I import ChatOllama?', 'grade': 1, 'answer': "To import ChatOllama, use: 'from langchain_community.chat_models import ChatOllama.'"}, {'id': '2', 'question': 'How can I use Chroma vector store?', 'answer': 'To use Chroma, define: rag_chain = create_retrieval_chain(retriever, question_answer_chain).', 'grade': 0, 'feedback': 'The retrieved documents discuss vector stores in general, but not Chroma specifically'}], 'failure_report': 'Poor quality of retrieval for document IDs: 2', 'summary_report': 'Questions focused on usage of ChatOllama and Chroma vector store.'}
Custom reducer functions to manage state
You might have noticed that we defined a custom reducer function (`add_logs`) for the `logs` key in `EntryGraphState`. It is necessary to provide a reducer when a shared state key can be updated by multiple subgraphs.
Let's take a look at implementing a custom reducer. We will create two graphs: a parent graph with a few nodes and a child graph that is added as a node in the parent. We'll also define a custom reducer function (`reduce_list`) for our state. This is functionally equivalent to simply using `operator.add`, except that it also treats a missing (`None`) value as an empty list.
from typing import Annotated
from typing_extensions import TypedDict
# define a simple reducer
def reduce_list(left: list, right: list) -> list:
    """Concatenate two lists, treating a falsy side (e.g. ``None``) as empty.

    Always returns a new list; neither argument is modified.
    """
    return (left or []) + (right or [])
# define parent and child state
class ChildState(TypedDict):
    """State schema for the child subgraph built by make_graph."""

    name: str
    path: Annotated[list[str], reduce_list]  # each node's {"path": [...]} update is concatenated here


class ParentState(TypedDict):
    """State schema for the parent graph built by make_graph."""

    # Same keys as ChildState, so `name` and `path` are shared with the child.
    name: str
    path: Annotated[list[str], reduce_list]  # each node's {"path": [...]} update is concatenated here
# define a helper to build the graph
def make_graph(parent_schema, child_schema):
    """Build and compile a demo parent graph that embeds a child subgraph.

    Every plain node returns ``{"path": [<its own name>]}``. Topology:
    parent: START -> grandparent -> parent -> {child, sibling} -> fin -> END
    child (compiled subgraph): START -> child_start -> child_middle -> child_end -> END
    """

    def _record(step):
        # A factory gives each lambda its own `step`, avoiding the
        # late-binding-closure pitfall of creating lambdas in a loop.
        return lambda state: {"path": [step]}

    child_builder = StateGraph(child_schema)
    child_steps = ("child_start", "child_middle", "child_end")
    for step in child_steps:
        child_builder.add_node(step, _record(step))
    child_route = (START, *child_steps, END)
    for src, dst in zip(child_route, child_route[1:]):
        child_builder.add_edge(src, dst)

    builder = StateGraph(parent_schema)
    builder.add_node("grandparent", _record("grandparent"))
    builder.add_node("parent", _record("parent"))
    builder.add_node("child", child_builder.compile())
    builder.add_node("sibling", _record("sibling"))
    builder.add_node("fin", _record("fin"))
    # Add connections
    parent_edges = (
        (START, "grandparent"),
        ("grandparent", "parent"),
        ("parent", "child"),
        ("parent", "sibling"),
        ("child", "fin"),
        ("sibling", "fin"),
        ("fin", END),
    )
    for src, dst in parent_edges:
        builder.add_edge(src, dst)
    return builder.compile()
# Build the demo graph; the parent and child schemas declare the same keys.
graph = make_graph(ParentState, ChildState)
from IPython.display import Image, display
# Setting xray to 1 will show the internal structure of the nested graph
display(Image(graph.get_graph(xray=1).draw_mermaid_png()))