Extracting high-cardinality categoricals#

Suppose we built a book recommendation chatbot, and as part of it we want to extract and filter on author name if that’s part of the user input. A user might ask a question like:

“what are books about aliens by Steven King”

If we’re not careful, our extraction system would most likely extract the author name “Steven King” from this input. This might cause us to miss all the most relevant book results, since the user was almost certainly looking for books by Stephen King.

This is a case of having to extract a high-cardinality categorical value. Given a dataset of books and their respective authors, there’s a large but finite number of valid author names, and we need some way of making sure our extraction system outputs valid and relevant author names even if the user input refers to invalid names.

We’ve built a dataset to help benchmark different approaches for dealing with this challenge. The dataset is simple: it is a collection of 23 pairs of misspelled and corrected human names. To use it for high-cardinality categorical testing, we’ll generate a large set of valid names (~10,000) that includes the correct spellings of all the names in the dataset. Using this, we’ll test the ability of various extraction systems to extract a corrected name from the user question:

“what are books about aliens by {misspelled_name}”

where for each datapoint in our dataset, we’ll use the misspelled name as the input and expect the corrected name as the extracted output.

Setup#

We need to install a few packages and set some env vars first:

%pip install -qU langchain-benchmarks langchain-openai faker chromadb numpy scikit-learn
import getpass
import os

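os.environ["LANGCHAIN_API_KEY"] = getpass.getpass("LangSmith API key: ")
# ChatOpenAI and OpenAIEmbeddings read their key from OPENAI_API_KEY
if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")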
from operator import attrgetter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from langsmith import Client

from langchain_benchmarks import registry

This is the Name Correction benchmark in langchain-benchmarks:

client = Client()
task = registry["Name Correction"]
task.dataset_url
'https://smith.langchain.com/public/78df83ee-ba7f-41c6-832c-2b23327d4cf7/d'

NOTE: If you are running this notebook for the first time, clone the public dataset into your LangSmith organization by uncommenting the line below:

# client.clone_public_dataset(task.dataset_url)
examples = list(client.list_examples(dataset_name=task.dataset_name))
for example in examples[:5]:
    print(example.inputs, example.outputs)
{'name': 'Tracy Cook'} {'name': 'Traci Cook'}
{'name': 'Dan Klein'} {'name': 'Daniel Klein'}
{'name': 'Jen Mcintosh'} {'name': 'Jennifer Mcintosh'}
{'name': 'Cassie Hull'} {'name': 'Cassandra Hull'}
{'name': 'Andy Williams'} {'name': 'Andrew Williams'}
def run_on_dataset(chain, run_name):
    client.run_on_dataset(
        dataset_name=task.dataset_name,
        llm_or_chain_factory=chain,
        evaluation=task.eval_config,
        project_name=run_name,
    )
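To make the "Correct rate" reported for each chain concrete, here is a minimal sketch of how such a score could be computed. It assumes correctness means exact string equality between the extracted and expected names; the benchmark's actual evaluator is whatever `task.eval_config` defines and may differ. The `correct_rate` helper and the echo extractor are hypothetical names used only for illustration.

```python
from typing import Callable, Iterable, Tuple


def correct_rate(
    extractor: Callable[[str], str], pairs: Iterable[Tuple[str, str]]
) -> float:
    """Fraction of (input_name, expected_name) pairs the extractor gets exactly right."""
    pairs = list(pairs)
    if not pairs:
        return 0.0
    hits = sum(1 for given, expected in pairs if extractor(given) == expected)
    return hits / len(pairs)


# A few pairs copied from the dataset preview above.
sample_pairs = [
    ("Tracy Cook", "Traci Cook"),
    ("Dan Klein", "Daniel Klein"),
    ("Andy Williams", "Andrew Williams"),
]

# An extractor that simply echoes its input never matches the corrected name.
echo_rate = correct_rate(lambda name: name, sample_pairs)
print(f"Correct rate: {echo_rate:.0%}")
```

An extractor with no knowledge of the valid names behaves much like the echo function here, which is why a baseline without that knowledge can be expected to score poorly.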

Augmenting with more fake names#

For our tests we’ll generate 10,000 fake names and combine them with the target names from the dataset to represent all the possible values for this category. Deduplicating the result leaves somewhat fewer than 10,000 unique names.

from faker import Faker

Faker.seed(42)
fake = Faker()
fake.seed_instance(0)

incorrect_names = [example.inputs["name"] for example in examples]
correct_names = [example.outputs["name"] for example in examples]

# We'll make sure that our list of valid names contains the correct spellings
# and not the incorrect spellings from our dataset
valid_names = list(
    set([fake.name() for _ in range(10_000)] + correct_names).difference(
        incorrect_names
    )
)
len(valid_names)
9382
valid_names[:3]
['Debra Lee', 'Kevin Harper', 'Donald Anderson']
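The set operations above are the reason the final list is shorter than 10,000 entries: duplicates collapse, and any generated name that matches an incorrect spelling is dropped. A toy sketch of the same pattern, with a hard-coded list standing in for the Faker output (the names are illustrative, not taken from the real run):

```python
# Stand-in for Faker output: note the duplicate and the "incorrect" spelling.
generated = ["Debra Lee", "Kevin Harper", "Debra Lee", "Dan Klein"]
correct = ["Daniel Klein"]
incorrect = ["Dan Klein"]

# Union with the correct spellings, then remove the incorrect ones.
valid = set(generated + correct).difference(incorrect)

# Sorting gives a stable order for display.
print(sorted(valid))
# → ['Daniel Klein', 'Debra Lee', 'Kevin Harper']
```

Because Python randomizes string hashes per process by default, `list(set(...))` can come back in a different order on each run; sorting the result is one way to keep the list order reproducible.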

Chain 1: Baseline#

As a baseline we’ll create a function-calling chain that has no information about the set of valid names.

class Search(BaseModel):
    query: str
    author: str


system = """Generate a relevant search query for a library system"""
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "{system}"),
        ("human", "what are books about aliens by {name}"),
    ]
)
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = llm.with_structured_output(Search)

query_analyzer_1 = (
    prompt.partial(system=system) | structured_llm | {"name": attrgetter("author")}
)
run_on_dataset(query_analyzer_1, "GPT-3.5")
View the evaluation results for project 'GPT-3.5' at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6/compare?selectedSessions=f429ec84-b879-4e66-b7fb-ef7be69d1acd

View all tests for Dataset Extracting Corrected Names at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6
[------------------------------------------------->] 23/23

As we might have expected, this gives us a Correct rate: 0%. Let’s see if we can do better :)

See the test run in LangSmith here.

Chain 2: All candidates in prompt#

Next, let’s put the full list of valid names into the system prompt. That list won’t fit in the 16k-token context window of gpt-3.5-turbo-0125, so we’ll use gpt-4-0125-preview, which has a longer one.

valid_names_str = "\n".join(valid_names)

system_2 = """Generate a relevant search query for a library system.

`author` attribute MUST be one of:

{valid_names_str}

Do NOT hallucinate author name!"""

formatted_system = system_2.format(valid_names_str=valid_names_str)
structured_llm_2 = ChatOpenAI(
    model="gpt-4-0125-preview", temperature=0
).with_structured_output(Search)
query_analyzer_2 = (
    prompt.partial(system=formatted_system)
    | structured_llm_2
    | {"name": attrgetter("author")}
)
run_on_dataset(query_analyzer_2, "GPT-4, all names in prompt")
View the evaluation results for project 'GPT-4, all names in prompt' at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6/compare?selectedSessions=8c4cfdfc-3646-438e-be47-43a40d66292a

View all tests for Dataset Extracting Corrected Names at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6
[------------------------------------------------->] 23/23

This gets us up to Correct rate: 26%.

See the test run in LangSmith here.
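As a back-of-the-envelope check on why the full list does not fit in a 16k-token window, the sketch below estimates the prompt size. It assumes a rule of thumb of about 4 characters per token, which is only a rough heuristic; real tokenizers will give different counts, especially for names. The `rough_token_estimate` helper and the toy list are hypothetical.

```python
def rough_token_estimate(names, chars_per_token=4.0):
    """Very rough token count for names joined one per line."""
    text = "\n".join(names)
    return round(len(text) / chars_per_token)


# Toy stand-in for the real list: three names from the output above, repeated
# until the list is about as long as the ~9,400 valid names.
toy_names = ["Debra Lee", "Kevin Harper", "Donald Anderson"] * 3127

estimate = rough_token_estimate(toy_names)
print(f"{len(toy_names)} names -> roughly {estimate} tokens")
print("Fits in a 16k window:", estimate < 16_000)
```

Even under this generous heuristic the list alone is well over 16k tokens, before counting the instructions and the user question.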

Chain 3: Top k candidates from vectorstore in prompt#

10,000 names is a lot to have in the prompt. Perhaps we could get better performance by shortening the list: we first use vector search to keep only the k names most similar to the user question. Because the prompt is now short, we can return to using GPT-3.5:

from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings

k = 10
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
vectorstore = Chroma.from_texts(valid_names, embeddings, collection_name="author_names")
retriever = vectorstore.as_retriever(search_kwargs={"k": k})
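system_chain = (
    # Note: under RunnablePassthrough.assign below, this lambda receives the whole
    # input dict (e.g. {"name": "Tracy Cook"}), so the dict's repr ends up in the query.
    (lambda name: f"what are books about aliens by {name}")
    | retriever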
    | (
        lambda docs: system_2.format(
            valid_names_str="\n".join(d.page_content for d in docs)
        )
    )
)
query_analyzer_3 = (
    RunnablePassthrough.assign(system=system_chain)
    | prompt
    | structured_llm
    | {"name": attrgetter("author")}
)
run_on_dataset(query_analyzer_3, f"GPT-3.5, top {k} names in prompt, vecstore")
View the evaluation results for project 'GPT-3.5, top 10 names in prompt, vecstore' at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6/compare?selectedSessions=af93ec50-ccbb-4b3c-908a-70c75e5516ea

View all tests for Dataset Extracting Corrected Names at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6
[------------------------------------------------->] 23/23

This gets us up to Correct rate: 57%.

See the test run in LangSmith here.

Chain 4: Top k candidates by ngram overlap in prompt#

Instead of using vector search, which requires embeddings and vector stores, a cheaper and faster approach would be to compare ngram overlap between the user question and the list of valid names:

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


# Generate padded character n-grams for a string
def ngrams(string, n=3):
    string = "START" + string.replace(" ", "").lower() + "END"
    grams = zip(*[string[i:] for i in range(n)])
    return ["".join(gram) for gram in grams]


# Vectorize documents using TfidfVectorizer with the custom n-grams function
vectorizer = TfidfVectorizer(analyzer=ngrams)
tfidf_matrix = vectorizer.fit_transform(valid_names)


def get_names(query):
    # Vectorize query
    query_tfidf = vectorizer.transform([query])

    # Compute cosine similarity
    cosine_similarities = cosine_similarity(query_tfidf, tfidf_matrix).flatten()

    # Sort document indexes from most to least similar
    most_similar_document_indexes = np.argsort(-cosine_similarities)

    # Return the top k names, one per line
    return "\n".join([valid_names[i] for i in most_similar_document_indexes[:k]])


def get_system_prompt(input):
    name = input["name"]
    valid_names_str = get_names(f"what are books about aliens by {name}")
    return system_2.format(valid_names_str=valid_names_str)


query_analyzer_4 = (
    RunnablePassthrough.assign(system=get_system_prompt)
    | prompt
    | structured_llm
    | {"name": attrgetter("author")}
)
run_on_dataset(query_analyzer_4, f"GPT-3.5, top {k} names in prompt, ngram")
View the evaluation results for project 'GPT-3.5, top 10 names in prompt, ngram' at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6/compare?selectedSessions=bc28b761-2ac9-4391-8df1-758f0a4d5100

View all tests for Dataset Extracting Corrected Names at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6
[------------------------------------------------->] 23/23

This gets us up to Correct rate: 65%.

See the test run in LangSmith here.
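To see why character n-grams are forgiving of spelling variants, the sketch below breaks two names into padded trigrams and measures how many they share. It uses plain Jaccard overlap rather than the TF-IDF cosine similarity used in this section, purely to keep the example dependency-free; `char_ngrams` and `jaccard` are hypothetical helper names, and the candidate list is a toy one.

```python
def char_ngrams(text, n=3):
    """Padded, lowercased character n-grams with spaces removed."""
    padded = "START" + text.replace(" ", "").lower() + "END"
    return {padded[i : i + n] for i in range(len(padded) - n + 1)}


def jaccard(a, b):
    """Shared n-grams divided by all distinct n-grams of both strings."""
    grams_a, grams_b = char_ngrams(a), char_ngrams(b)
    return len(grams_a & grams_b) / len(grams_a | grams_b)


candidates = ["Debra Lee", "Kevin Harper", "Stephen King"]
query = "Steven King"

for candidate in candidates:
    print(f"{candidate}: {jaccard(query, candidate):.2f}")

best = max(candidates, key=lambda c: jaccard(query, c))
print(best)
# → Stephen King
```

Most trigrams of "Steven King" also appear in "Stephen King", while the unrelated names share little beyond the padding markers. Note that Chain 4 vectorizes the full question rather than just the name, so the question's other words also contribute n-grams to the comparison.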

Chain 5: Replace with top candidate from vectorstore#

Instead of (or in addition to) searching for similar candidates before extraction, we can also correct the extracted value after the fact by searching over the valid names and replacing it with the closest match. With Pydantic classes this is easy using a validator:

from langchain_core.pydantic_v1 import validator


class Search(BaseModel):
    query: str
    author: str

    @validator("author")
    def correct_author(cls, v: str) -> str:
        # Replace the extracted name with its nearest neighbor in the vector store
        return vectorstore.similarity_search(v, k=1)[0].page_content


structured_llm_3 = llm.with_structured_output(Search)
query_analyzer_5 = (
    prompt.partial(system=system) | structured_llm_3 | {"name": attrgetter("author")}
)
run_on_dataset(query_analyzer_5, "GPT-3.5, correct name, vecstore")
View the evaluation results for project 'GPT-3.5, correct name, vecstore' at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6/compare?selectedSessions=e3eda1e1-bc25-46e8-a4fb-db324cefd1c9

View all tests for Dataset Extracting Corrected Names at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6
[------------------------------------------------->] 23/23

This gets us up to Correct rate: 83%

See the test run in LangSmith here.
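The idea behind the validator is independent of Pydantic and of the vector store: whatever the model extracts is snapped to the nearest valid value. Here is a minimal standalone sketch of that post-correction step. It swaps in the standard library's `difflib.get_close_matches` for the similarity search, which is an assumption made for illustration and not what the chain above uses; `snap_to_valid` and the toy name list are hypothetical.

```python
import difflib
from typing import Sequence


def snap_to_valid(extracted: str, valid: Sequence[str]) -> str:
    """Return the valid name most similar to `extracted`."""
    # cutoff=0.0 means a candidate is always returned when the list is
    # non-empty, mirroring a k=1 nearest-neighbor search.
    matches = difflib.get_close_matches(extracted, valid, n=1, cutoff=0.0)
    return matches[0] if matches else extracted


toy_valid_names = ["Stephen King", "Debra Lee", "Kevin Harper"]
print(snap_to_valid("Steven King", toy_valid_names))
# → Stephen King
```

One consequence of always returning the nearest candidate is that an author who is genuinely absent from the list still gets mapped to some valid name. Depending on the application, a similarity threshold may be worth adding so that poor matches can be rejected instead.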

Chain 6: Replace with top candidate by ngram overlap#

We can do the same with ngram overlap search instead of vector search:

class Search(BaseModel):
    query: str
    author: str

    @validator("author")
    def correct_author(cls, v: str) -> str:
        # get_names returns the top k names, one per line; keep the best match
        return get_names(v).split("\n")[0]


structured_llm_4 = llm.with_structured_output(Search)
query_analyzer_6 = (
    prompt.partial(system=system) | structured_llm_4 | {"name": attrgetter("author")}
)
run_on_dataset(query_analyzer_6, "GPT-3.5, correct name, ngram")
View the evaluation results for project 'GPT-3.5, correct name, ngram' at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6/compare?selectedSessions=8f8846c8-2ada-41bc-8d2c-e1d56e7c92ce

View all tests for Dataset Extracting Corrected Names at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6
[------------------------------------------------->] 23/23

This gets us to Correct rate: 74%, slightly worse than Chain 5 (the same approach using vector search instead of ngram overlap).

See the test run in LangSmith here.

See all results in LangSmith#

To see the full dataset and all the test results, head to LangSmith: https://smith.langchain.com/public/8c0a4c25-426d-4582-96fc-d7def170be76/d