How to disable streaming for models that don't support it¶
Certain LLMs do not support streaming; perhaps the most notable examples (depending on when you're reading this) are OpenAI's o1 models. This can break LangGraph when using astream_events,
because under the hood LangGraph calls all of the models in streaming mode. This guide shows how to avoid that error by disabling streaming on the model, so that it returns its final message as a single chunk instead of a streaming output.
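For context, streaming mode means the model yields its answer as incremental message chunks rather than one final message. A minimal sketch of that behavior, assuming a streaming-capable model (gpt-4o-mini is only an illustrative choice, not part of this guide's setup):

from langchain_openai import ChatOpenAI

# Illustrative only: a streaming-capable model emits incremental chunks.
streaming_llm = ChatOpenAI(model="gpt-4o-mini")

async for chunk in streaming_llm.astream("Write one sentence about strawberries."):
    print(chunk.content, end="", flush=True)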
Let's see how this would cause an error by defining the simplest graph we can:
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState
from langgraph.graph import StateGraph, START, END

llm = ChatOpenAI(model="o1-preview", temperature=1)

graph_builder = StateGraph(MessagesState)


# A single node that calls the model on the accumulated messages.
def chatbot(state: MessagesState):
    return {"messages": [llm.invoke(state["messages"])]}


graph_builder.add_node("chatbot", chatbot)
graph_builder.add_edge(START, "chatbot")
graph_builder.add_edge("chatbot", END)
graph = graph_builder.compile()

from IPython.display import Image, display

display(Image(graph.get_graph().draw_mermaid_png()))
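Note that the graph itself is fine: a plain, non-streaming invocation never asks the model to stream, so it succeeds. A quick sketch (assuming you have o1-preview access and an OPENAI_API_KEY configured):

result = graph.invoke(
    {"messages": [{"role": "user", "content": "how many r's are in strawberry?"}]}
)
# Print the model's final answer.
print(result["messages"][-1].content)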
Without disabling streaming¶
Now that we've defined our graph, let's try to call astream_events
without disabling streaming. This should throw an error because the o1
model does not support streaming natively:
input = {"messages": {"role": "user", "content": "how many r's are in strawberry?"}}

try:
    async for event in graph.astream_events(input, version="v2"):
        if event["event"] == "on_chat_model_end":
            print(event["data"]["output"].content, end="", flush=True)
except:
    print("Streaming not supported!")
---------------------------------------------------------------------------
BadRequestError                           Traceback (most recent call last)
Cell In[2], line 2
      1 input = {"messages": {"role":"user", "content":"how many r's are in strawberry?"}}
----> 2 async for event in graph.astream_events(input, version="v2"):
      3     if event["event"] == "on_chat_model_end":
      4         print(event["data"]["output"].content, end="", flush=True)

[... intermediate frames through langchain_core, langgraph, and openai omitted ...]

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/langchain_core/language_models/chat_models.py:843, in BaseChatModel._generate_with_cache(self, messages, stop, run_manager, **kwargs)
    837 if self._should_stream(
    838     async_api=False,
    839     run_manager=run_manager,
    840     **kwargs,
    841 ):
    842     chunks: List[ChatGenerationChunk] = []
--> 843     for chunk in self._stream(messages, stop=stop, **kwargs):

[... intermediate frames omitted ...]

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/openai/_base_client.py:1041, in SyncAPIClient._request(self, cast_to, options, remaining_retries, stream, stream_cls)
   1038     err.response.read()
   1040     log.debug("Re-raising status error")
-> 1041     raise self._make_status_error_from_response(err.response) from None

BadRequestError: Error code: 400 - {'error': {'message': "Unsupported value: 'stream' does not support true with this model. Only the default (false) value is supported.", 'type': 'invalid_request_error', 'param': 'stream', 'code': 'unsupported_value'}}
An error occurred, as we expected. Luckily, there is an easy fix!
Disabling streaming¶
Now, without making any changes to our graph, let's set the disable_streaming
parameter on our model to True,
which solves the issue:
llm = ChatOpenAI(model="o1-preview", temperature=1, disable_streaming=True)

graph_builder = StateGraph(MessagesState)


def chatbot(state: MessagesState):
    return {"messages": [llm.invoke(state["messages"])]}


graph_builder.add_node("chatbot", chatbot)
graph_builder.add_edge(START, "chatbot")
graph_builder.add_edge("chatbot", END)
graph = graph_builder.compile()
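With disable_streaming=True, the chat model falls back to a single non-streaming API request and yields its whole answer at once, which is why astream_events no longer fails. A small sketch of that behavior (again assuming o1-preview access; the single-chunk result is our reading of how disabled streaming degrades to a plain invocation):

# With streaming disabled, .stream() degrades to one non-streaming call
# and yields the complete answer as a single chunk.
chunks = list(llm.stream("how many r's are in strawberry?"))
print(len(chunks))  # expected: 1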
And now, rerunning with the same input, we should see no errors:
input = {"messages": {"role": "user", "content": "how many r's are in strawberry?"}}

async for event in graph.astream_events(input, version="v2"):
    if event["event"] == "on_chat_model_end":
        print(event["data"]["output"].content, end="", flush=True)
There are **three** letter "r"s in the word "strawberry." Here's the breakdown:

- **S**
- **T**
- **R** (1st "r")
- **A**
- **W**
- **B**
- **E**
- **R** (2nd "r")
- **R** (3rd "r")
- **Y**

So, the letters "r" appear in the 3rd, 8th, and 9th positions of the word "strawberry."
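Finally, disable_streaming is defined on LangChain's base chat model class, so the same approach should work for most other chat model integrations. Depending on your installed langchain-core version, it may also accept the literal "tool_calling" to disable streaming only when the model is invoked with tools bound; verify this against your version before relying on it. A hedged sketch:

# Hypothetical variation: only disable streaming for runs that bind tools.
# The "tool_calling" literal is version-dependent; check your langchain-core docs.
llm = ChatOpenAI(
    model="o1-preview",
    temperature=1,
    disable_streaming="tool_calling",
)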