{ "cells": [ { "cell_type": "markdown", "id": "5fa317ef-b9a7-4432-ba85-ce71b8dfbdc6", "metadata": {}, "source": [ "# How to handle large numbers of tools\n", "\n", "Which of the available tools to call is generally at the discretion of the model (although many providers also let the user [specify or constrain the choice of tool](https://python.langchain.com/docs/how_to/tool_choice/)). As the number of available tools grows, you may want to limit the scope of the LLM's selection, both to decrease token consumption and to reduce sources of error in LLM reasoning.\n", "\n", "Here we will demonstrate how to dynamically adjust the tools available to a model. Bottom line up front: as in [RAG](https://python.langchain.com/docs/concepts/#retrieval) and similar methods, we precede the model invocation with a retrieval step over the available tools. Although we demonstrate one implementation that searches over tool descriptions, the details of the tool selection can be customized as needed.\n", "\n", "## Setup\n", "\n", "First, let's install the required packages and set our API keys." ] }, { "cell_type": "code", "execution_count": null, "id": "9b6c62bd", "metadata": {}, "outputs": [], "source": [ "%%capture --no-stderr\n", "%pip install --quiet -U langgraph langchain_openai numpy" ] }, { "cell_type": "code", "execution_count": null, "id": "360d7ff6", "metadata": {}, "outputs": [], "source": [ "import getpass\n", "import os\n", "\n", "\n", "def _set_env(var: str):\n", "    # Prompt for the value only if the variable is not already set\n", "    if not os.environ.get(var):\n", "        os.environ[var] = getpass.getpass(f\"{var}: \")\n", "\n", "\n", "_set_env(\"OPENAI_API_KEY\")" ] }, { "cell_type": "markdown", "id": "25f9f6a0", "metadata": {}, "source": [ "
**Set up LangSmith for LangGraph development**\n", "\n", "Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph. See the LangSmith documentation to get started." ] }, { "cell_type": "markdown", "id": "1a417013-ddc4-463b-8ea0-0904bd232827", "metadata": {}, "source": [ "## Define the tools" ] }, { "cell_type": "markdown", "id": "24708f3b-18b1-4b42-9f6a-0d4827222918", "metadata": {}, "source": [ "Let's consider a toy example in which we have one tool for each company in the S&P 500 index. Each tool fetches information about its company and is parameterized by a single integer representing the year.\n", "\n", "We first construct a registry that associates a unique identifier with each tool. We represent each tool as a `StructuredTool`, which can be bound directly to chat models that support tool calling." ] }, { "cell_type": "code", "execution_count": 1, "id": "da30c3f1-127f-4828-8609-94e16719f0be", "metadata": {}, "outputs": [], "source": [ "import re\n", "import uuid\n", "\n", "from langchain_core.tools import StructuredTool\n", "\n", "\n", "def create_tool(company: str) -> StructuredTool:\n", "    \"\"\"Create a placeholder tool for a company.\"\"\"\n", "    # Strip punctuation and replace spaces to form a valid tool name\n", "    formatted_company = re.sub(r\"[^\\w\\s]\", \"\", company).replace(\" \", \"_\")\n", "\n", "    def company_tool(year: int) -> str:\n", "        return f\"{company} had revenues of $100 in {year}.\"\n", "\n", "    return StructuredTool.from_function(\n", "        company_tool,\n", "        name=formatted_company,\n", "        description=f\"Information about {company}\",\n", "    )\n", "\n", "\n", "s_and_p_500_companies = [  # Abbreviated list for demonstration purposes\n", "    \"3M\",\n", "    \"A.O. Smith\",\n", "    \"Abbott\",\n", "    \"Accenture\",\n", "    \"Advanced Micro Devices\",\n", "    \"Yum! 
Brands\",\n", "    \"Zebra Technologies\",\n", "    \"Zimmer Biomet\",\n", "    \"Zoetis\",\n", "]\n", "\n", "tool_registry = {\n", "    str(uuid.uuid4()): create_tool(company) for company in s_and_p_500_companies\n", "}" ] }, { "cell_type": "markdown", "id": "ba17b047-73ed-4385-adc2-f02012db2206", "metadata": {}, "source": [ "## Define the graph" ] }, { "cell_type": "markdown", "id": "2055548d-3d14-4aaf-9588-abf70f28b5d6", "metadata": {}, "source": [ "### Tool selection" ] }, { "cell_type": "markdown", "id": "8798a0d2-ea93-45bc-ab55-071ab975f2c2", "metadata": {}, "source": [ "We will construct a node that retrieves a subset of available tools given the information in the state, such as a recent user message. In general, the full scope of [retrieval solutions](https://python.langchain.com/docs/concepts/#retrieval) is available for this step. As a simple solution, we index embeddings of tool descriptions in a vector store and match user queries to tools via semantic search." ] }, { "cell_type": "code", "execution_count": 2, "id": "435b0201-7296-4617-abf8-2c757a71f6b5", "metadata": {}, "outputs": [], "source": [ "from langchain_core.documents import Document\n", "from langchain_core.vectorstores import InMemoryVectorStore\n", "from langchain_openai import OpenAIEmbeddings\n", "\n", "# One document per tool: the description is embedded, the registry ID is the document ID\n", "tool_documents = [\n", "    Document(\n", "        page_content=tool.description,\n", "        id=id,\n", "        metadata={\"tool_name\": tool.name},\n", "    )\n", "    for id, tool in tool_registry.items()\n", "]\n", "\n", "vector_store = InMemoryVectorStore(embedding=OpenAIEmbeddings())\n", "document_ids = vector_store.add_documents(tool_documents)" ] }, { "cell_type": "markdown", "id": "e9ce366b-b5e7-41e9-b4a9-d775b9be0d09", "metadata": {}, "source": [ "### Incorporating with an agent\n", "\n", "We will use a typical ReAct agent graph (e.g., as used in the [quickstart](https://langchain-ai.github.io/langgraph/tutorials/introduction/#part-2-enhancing-the-chatbot-with-tools)), with some modifications:\n", "\n", "- We 
add a `selected_tools` key to the state, which stores our selected subset of tools;\n", "- We set the entry point of the graph to be a `select_tools` node, which populates this element of the state;\n", "- We bind the selected subset of tools to the chat model within the `agent` node." ] }, { "cell_type": "code", "execution_count": 3, "id": "d319fea9-e8ae-4763-a785-b2bf72239ae4", "metadata": {}, "outputs": [], "source": [ "from typing import Annotated\n", "\n", "from langchain_openai import ChatOpenAI\n", "from typing_extensions import TypedDict\n", "\n", "from langgraph.graph import StateGraph, START\n", "from langgraph.graph.message import add_messages\n", "from langgraph.prebuilt import ToolNode, tools_condition\n", "\n", "\n", "class State(TypedDict):\n", "    messages: Annotated[list, add_messages]\n", "    selected_tools: list[str]\n", "\n", "\n", "graph_builder = StateGraph(State)\n", "\n", "tools = list(tool_registry.values())\n", "llm = ChatOpenAI()\n", "\n", "\n", "def agent(state: State):\n", "    # Bind only the tools chosen by select_tools\n", "    selected_tools = [tool_registry[id] for id in state[\"selected_tools\"]]\n", "    llm_with_tools = llm.bind_tools(selected_tools)\n", "    return {\"messages\": [llm_with_tools.invoke(state[\"messages\"])]}\n", "\n", "\n", "def select_tools(state: State):\n", "    # Use the most recent message as the retrieval query\n", "    last_user_message = state[\"messages\"][-1]\n", "    query = last_user_message.content\n", "    tool_documents = vector_store.similarity_search(query)\n", "    return {\"selected_tools\": [document.id for document in tool_documents]}\n", "\n", "\n", "graph_builder.add_node(\"agent\", agent)\n", "graph_builder.add_node(\"select_tools\", select_tools)\n", "\n", "# The tool node holds every tool so that any selected tool can be executed\n", "tool_node = ToolNode(tools=tools)\n", "graph_builder.add_node(\"tools\", tool_node)\n", "\n", "graph_builder.add_conditional_edges(\n", "    \"agent\",\n", "    tools_condition,\n", ")\n", "graph_builder.add_edge(\"tools\", \"agent\")\n", "graph_builder.add_edge(\"select_tools\", \"agent\")\n", "graph_builder.add_edge(START, \"select_tools\")\n", "graph = 
graph_builder.compile()" ] }, { "cell_type": "code", "execution_count": 4, "id": "35cab3b2-4d03-4cb5-ba10-f7d3a5ad5244", "metadata": {}, "outputs": [ { "data": { "image/jpeg": "", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from IPython.display import Image, display\n", "\n", "try:\n", "    display(Image(graph.get_graph().draw_mermaid_png()))\n", "except Exception:\n", "    # This requires some extra dependencies and is optional\n", "    pass" ] }, { "cell_type": "code", "execution_count": 5, "id": "66f62a69-989b-46ce-80b3-97a867e36782", "metadata": {}, "outputs": [], "source": [ "user_input = \"Can you give me some information about AMD in 2022?\"\n", "\n", "result = graph.invoke({\"messages\": [(\"user\", user_input)]})" ] }, { "cell_type": "code", "execution_count": 6, "id": "479a459d-6896-4960-aae9-9f1259fb47d1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['3b7d1528-6007-4473-a92f-b9b3341c3bfe', '8d77b753-c58a-41bf-9649-ad1a7326bc27', '514a6fc3-03d1-4e73-b410-c39309ad7b2f', '83c5cc8f-5111-46ed-874a-e0b883265ff6']\n" ] } ], "source": [ "print(result[\"selected_tools\"])" ] }, { "cell_type": "code", "execution_count": 7, "id": "376f28fd-3f7f-4ae5-a34c-baef1778e82b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================================\u001b[1m Human Message \u001b[0m=================================\n", "\n", "Can you give me some information about AMD in 2022?\n", "==================================\u001b[1m Ai Message \u001b[0m==================================\n", "Tool Calls:\n", " Advanced_Micro_Devices (call_Htbv7Imx4BwSsYWhZvSSs6yW)\n", " Call ID: call_Htbv7Imx4BwSsYWhZvSSs6yW\n", " Args:\n", " year: 2022\n", "=================================\u001b[1m Tool Message \u001b[0m=================================\n", "Name: Advanced_Micro_Devices\n", "\n", "Advanced Micro Devices had revenues of $100 in 2022.\n", 
"==================================\u001b[1m Ai Message \u001b[0m==================================\n", "\n", "In 2022, Advanced Micro Devices had revenues of $100.\n" ] } ], "source": [ "for message in result[\"messages\"]:\n", " message.pretty_print()" ] }, { "cell_type": "markdown", "id": "3bd847ef-4627-4fc2-99f9-c2b17cf83f95", "metadata": {}, "source": [ "## Repeating tool selection\n", "\n", "To manage errors from incorrect tool selection, we could revisit the `select_tools` node. One option for implementing this is to modify `select_tools` to generate the vector store query using all messages in the state (e.g., with a chat model) and add an edge routing from `tools` to `select_tools`.\n", "\n", "We implement this change below. For demonstration purposes, we simulate an error in the initial tool selection by adding a `hack_remove_tool_condition` to the `select_tools` node, which removes the correct tool on the first iteration of the node. Note that on the second iteration, the agent finishes the run as it has access to the correct tool." ] }, { "cell_type": "markdown", "id": "985a5388", "metadata": {}, "source": [ "
**Using Pydantic with LangChain**\n", "\n", "This notebook uses Pydantic v2 `BaseModel`, which requires `langchain-core >= 0.3`. Using `langchain-core < 0.3` will result in errors due to mixing of Pydantic v1 and v2 `BaseModels`." ] }, { "cell_type": "code", "execution_count": 5, "id": "1954a5f1-91e4-4b32-9be9-c8bc1cc43cb5", "metadata": {}, "outputs": [], "source": [ "from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage\n", "from langgraph.pregel.retry import RetryPolicy\n", "\n", "from pydantic import BaseModel, Field\n", "\n", "\n", "class QueryForTools(BaseModel):\n", "    \"\"\"Generate a query for additional tools.\"\"\"\n", "\n", "    query: str = Field(..., description=\"Query for additional tools.\")\n", "\n", "\n", "def select_tools(state: State):\n", "    last_message = state[\"messages\"][-1]\n", "    hack_remove_tool_condition = False\n", "    if isinstance(last_message, HumanMessage):\n", "        # First pass: query with the user message and simulate a bad selection\n", "        query = last_message.content\n", "        hack_remove_tool_condition = True\n", "    else:\n", "        # Later passes: ask the chat model for a query based on the full conversation\n", "        assert isinstance(last_message, ToolMessage)\n", "        system = SystemMessage(\n", "            \"Given this conversation, generate a query for additional tools. \"\n", "            \"The query should be a short string containing what type of information \"\n", "            \"is needed. If no further information is needed, \"\n", "            \"return an empty string for the query.\"\n", "        )\n", "        input_messages = [system] + state[\"messages\"]\n", "        response = llm.bind_tools([QueryForTools], tool_choice=True).invoke(\n", "            input_messages\n", "        )\n", "        query = response.tool_calls[0][\"args\"][\"query\"]\n", "    tool_documents = vector_store.similarity_search(query)\n", "    if hack_remove_tool_condition:\n", "        # Remove needed tool\n", "        selected_tools = [\n", "            document.id\n", "            for document in tool_documents\n", "            if document.metadata[\"tool_name\"] != \"Advanced_Micro_Devices\"\n", "        ]\n", "    else:\n", "        selected_tools = [document.id for document in tool_documents]\n", "    return {\"selected_tools\": selected_tools}\n", "\n", "\n", "graph_builder = StateGraph(State)\n", "graph_builder.add_node(\"agent\", agent)\n", "graph_builder.add_node(\"select_tools\", select_tools, retry=RetryPolicy(max_attempts=3))\n", "\n", "tool_node = ToolNode(tools=tools)\n", "graph_builder.add_node(\"tools\", tool_node)\n", "\n", "graph_builder.add_conditional_edges(\n", "    \"agent\",\n", "    tools_condition,\n", ")\n", "# Route tool results back through tool selection instead of straight to the agent\n", "graph_builder.add_edge(\"tools\", \"select_tools\")\n", "graph_builder.add_edge(\"select_tools\", \"agent\")\n", "graph_builder.add_edge(START, \"select_tools\")\n", "graph = graph_builder.compile()" ] }, { "cell_type": "code", "execution_count": 6, "id": "9110789a-843a-4c21-aeff-8841b24f7674", "metadata": {}, "outputs": [ { "data": { "image/jpeg": "", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from IPython.display import Image, display\n", "\n", "try:\n", "    display(Image(graph.get_graph().draw_mermaid_png()))\n", "except Exception:\n", "    # This requires some extra dependencies and is optional\n", "    pass" ] }, { "cell_type": "code", "execution_count": 18, "id": "bee04c3d-0e36-4443-b0c8-10986a5f6e39", "metadata": {}, "outputs": [], "source": [ "user_input = \"Can you give me some information 
about AMD in 2022?\"\n", "\n", "result = graph.invoke({\"messages\": [(\"user\", user_input)]})" ] }, { "cell_type": "code", "execution_count": 19, "id": "6906fb50-435c-4473-bbb6-5353433b9199", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================================\u001b[1m Human Message \u001b[0m=================================\n", "\n", "Can you give me some information about AMD in 2022?\n", "==================================\u001b[1m Ai Message \u001b[0m==================================\n", "Tool Calls:\n", " Accenture (call_L82JRUyIFilhzeTmPnNbPeVD)\n", " Call ID: call_L82JRUyIFilhzeTmPnNbPeVD\n", " Args:\n", " year: 2022\n", "=================================\u001b[1m Tool Message \u001b[0m=================================\n", "Name: Accenture\n", "\n", "Accenture had revenues of $100 in 2022.\n", "==================================\u001b[1m Ai Message \u001b[0m==================================\n", "Tool Calls:\n", " Advanced_Micro_Devices (call_k3zR9zS98gjiejmNgq6aVsXL)\n", " Call ID: call_k3zR9zS98gjiejmNgq6aVsXL\n", " Args:\n", " year: 2022\n", "=================================\u001b[1m Tool Message \u001b[0m=================================\n", "Name: Advanced_Micro_Devices\n", "\n", "Advanced Micro Devices had revenues of $100 in 2022.\n", "==================================\u001b[1m Ai Message \u001b[0m==================================\n", "\n", "In 2022, Advanced Micro Devices (AMD) had revenues of $100.\n" ] } ], "source": [ "for message in result[\"messages\"]:\n", " message.pretty_print()" ] }, { "cell_type": "markdown", "id": "177aedfa-cec5-45d0-82ad-efc0233aa6b4", "metadata": {}, "source": [ "## Next steps\n", "\n", "This guide provides a minimal implementation for dynamically selecting tools. There is a host of possible improvements and optimizations:\n", "\n", "- **Repeating tool selection**: Here, we repeated tool selection by modifying the `select_tools` node. 
Another option is to equip the agent with a `reselect_tools` tool, allowing it to re-select tools at its discretion.\n", "- **Optimizing tool selection**: In general, the full scope of [retrieval solutions](https://python.langchain.com/docs/concepts/#retrieval) is available for tool selection. Additional options include:\n", "    - Group tools and retrieve over groups;\n", "    - Use a chat model to select tools or groups of tools." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }