242 lines
9 KiB
Python
242 lines
9 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""Streaming helper utilities for chat message processing and normalization."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Dict, List, Literal, Mapping, Optional
|
|
|
|
from langchain_core.messages import (
|
|
AIMessage,
|
|
BaseMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
ToolMessage,
|
|
)
|
|
|
|
# Canonical chat role names produced by ChatStreamingHelper's message normalization.
Role = Literal["user", "assistant", "system", "tool"]
|
|
|
|
|
|
class ChatStreamingHelper:
    """Pure helper methods for streaming and message normalization.

    This class provides static utility methods for converting between different
    message formats, extracting content, and normalizing message structures
    for streaming chat applications.
    """

    # Role / type tags accepted from message objects and dict-shaped messages,
    # mapped to the canonical role names emitted by this helper. "function"
    # messages are reported as "tool".
    _ROLE_ALIASES: Mapping[str, str] = {
        "human": "user",
        "user": "user",
        "ai": "assistant",
        "assistant": "assistant",
        "system": "system",
        "tool": "tool",
        "function": "tool",
    }

    @staticmethod
    def _resolve_role(*, role: Any, typ: Any) -> Optional[str]:
        """Resolve a role from an explicit role value or a type tag.

        A non-empty string ``role`` takes precedence: known aliases (e.g.
        "human", "ai") are canonicalized and any other string is returned
        unchanged. Otherwise a string ``typ`` is looked up in the alias table.

        Returns:
            The resolved role, or ``None`` when neither argument yields one.
        """
        aliases = ChatStreamingHelper._ROLE_ALIASES
        if isinstance(role, str) and role:
            return aliases.get(role, role)
        if isinstance(typ, str):
            return aliases.get(typ)
        return None

    @staticmethod
    def role_from_message(*, msg: BaseMessage) -> Role:
        """Extract the role from a BaseMessage instance.

        Args:
            msg: The BaseMessage instance to extract the role from.

        Returns:
            "user", "assistant", "system" or "tool" for HumanMessage, AIMessage,
            SystemMessage and ToolMessage (including subclasses). For any other
            message, a non-empty string ``role`` attribute is used (known
            aliases such as "human"/"ai" are canonicalized; other strings are
            returned unchanged), then the ``type`` attribute is mapped through
            the same alias table. Defaults to "assistant".

        Examples:
            >>> from langchain_core.messages import HumanMessage
            >>> msg = HumanMessage(content="Hello")
            >>> ChatStreamingHelper.role_from_message(msg=msg)
            'user'
        """
        if isinstance(msg, HumanMessage):
            return "user"
        if isinstance(msg, AIMessage):
            return "assistant"
        if isinstance(msg, SystemMessage):
            return "system"
        if isinstance(msg, ToolMessage):
            return "tool"
        # Other message classes (e.g. one carrying a free-form ``role``
        # attribute, or one whose ``type`` tag is "function") are resolved
        # from their attributes, consistently with dict_message_to_dict.
        resolved = ChatStreamingHelper._resolve_role(
            role=getattr(msg, "role", None),
            typ=getattr(msg, "type", None),
        )
        # Pass-through roles may fall outside the ``Role`` literal.
        return resolved or "assistant"  # type: ignore[return-value]

    @staticmethod
    def _part_to_text(*, part: Any) -> str:
        """Return the raw (unstripped) text of a single content part.

        Strings are returned as-is. Mappings yield the first string found among
        their "text", "content" and "value" fields, falling back to
        ``str(part)``. Any other value is converted with ``str``.
        """
        if isinstance(part, str):
            return part
        if isinstance(part, Mapping):
            for key in ("text", "content", "value"):
                value = part.get(key)
                if isinstance(value, str):
                    return value
        return str(part)

    @staticmethod
    def flatten_content(*, content: Any) -> str:
        """Convert complex content structures to plain text.

        This method handles various content formats including strings, lists of
        content parts, and dictionaries with text fields. It's designed to
        normalize content from different message sources into a consistent
        plain text format.

        Args:
            content: The content to flatten. Can be:
                - str: Returned after stripping whitespace
                - list/tuple: Each item is converted (strings kept as-is,
                  mappings reduced to their "text", "content" or "value"
                  string, ``None`` skipped, anything else passed to ``str``),
                  stripped, and the non-empty results are joined with newlines
                - dict/mapping: Text extracted from the "text", "content" or
                  "value" field, else ``str(content)``
                - None: Returns empty string
                - Any other type: Converted to string

        Returns:
            The flattened content as a plain text string with whitespace stripped.

        Examples:
            >>> content = [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}]
            >>> ChatStreamingHelper.flatten_content(content=content)
            'Hello\\nworld'

            >>> content = {"text": "Simple message"}
            >>> ChatStreamingHelper.flatten_content(content=content)
            'Simple message'
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content.strip()
        if isinstance(content, (list, tuple)):
            segments = (
                ChatStreamingHelper._part_to_text(part=part).strip()
                for part in content
                if part is not None
            )
            # Drop parts that are empty after stripping (e.g. empty text
            # blocks) so they do not leave stray blank lines in the output.
            return "\n".join(segment for segment in segments if segment)
        # Mappings go through the same field lookup as list items; any other
        # value is stringified.
        return ChatStreamingHelper._part_to_text(part=content).strip()

    @staticmethod
    def message_to_dict(*, msg: BaseMessage) -> Dict[str, Any]:
        """Convert a BaseMessage instance to a dictionary for streaming output.

        This method normalizes BaseMessage instances into a consistent dictionary
        format suitable for JSON serialization and streaming to clients.

        Args:
            msg: The BaseMessage instance to convert.

        Returns:
            A dictionary containing:
                - "role": The message role (see ``role_from_message``)
                - "content": The flattened message content as plain text
                - "tool_calls": Tool calls, only when present and non-empty
                - "name": Message name, only when present and non-empty

        Examples:
            >>> from langchain_core.messages import HumanMessage
            >>> msg = HumanMessage(content="Hello there")
            >>> result = ChatStreamingHelper.message_to_dict(msg=msg)
            >>> result["role"]
            'user'
            >>> result["content"]
            'Hello there'
        """
        payload: Dict[str, Any] = {
            "role": ChatStreamingHelper.role_from_message(msg=msg),
            "content": ChatStreamingHelper.flatten_content(
                content=getattr(msg, "content", "")
            ),
        }
        tool_calls = getattr(msg, "tool_calls", None)
        if tool_calls:
            payload["tool_calls"] = tool_calls
        name = getattr(msg, "name", None)
        if name:
            payload["name"] = name
        return payload

    @staticmethod
    def dict_message_to_dict(*, obj: Mapping[str, Any]) -> Dict[str, Any]:
        """Convert a dictionary-shaped message to a normalized dictionary.

        This method handles messages that come from serialized state and are
        represented as dictionaries rather than BaseMessage instances. It
        normalizes various dictionary formats into a consistent structure,
        including the nested ``{"type": ..., "data": {...}}`` form, which is
        used when the top level carries neither "content" nor "text".

        Args:
            obj: The dictionary-shaped message to convert. Expected to contain
                fields like "role", "type", "content", "text", etc.

        Returns:
            A normalized dictionary containing:
                - "role": A non-empty "role" field (aliases such as "human" or
                  "ai" canonicalized, other strings kept), else the role mapped
                  from "type", else "assistant"
                - "content": The flattened message content as plain text
                - "tool_calls": Tool calls, only when present and non-empty
                - "name": Message name, only when present and non-empty

        Examples:
            >>> obj = {"type": "human", "content": "Hello"}
            >>> result = ChatStreamingHelper.dict_message_to_dict(obj=obj)
            >>> result["role"]
            'user'
            >>> result["content"]
            'Hello'
        """
        # Read message fields from a nested "data" payload when the top level
        # has no content of its own.
        fields: Mapping[str, Any] = obj
        nested = obj.get("data")
        if "content" not in obj and "text" not in obj and isinstance(nested, Mapping):
            fields = nested

        role = ChatStreamingHelper._resolve_role(
            role=fields.get("role") or obj.get("role"),
            typ=fields.get("type") or obj.get("type"),
        )

        content = fields.get("content")
        if content is None and "text" in fields:
            content = fields["text"]

        out: Dict[str, Any] = {
            "role": role or "assistant",
            "content": ChatStreamingHelper.flatten_content(content=content),
        }
        # Same optional-key policy as message_to_dict: omit empty/None values.
        tool_calls = fields.get("tool_calls")
        if tool_calls:
            out["tool_calls"] = tool_calls
        name = fields.get("name")
        if name:
            out["name"] = name
        return out

    @staticmethod
    def extract_messages_from_output(*, output_obj: Any) -> List[Any]:
        """Extract messages from LangGraph output objects.

        This method handles various output formats from LangGraph execution,
        extracting the messages sequence from different possible structures.

        Args:
            output_obj: The output object from LangGraph execution. Can be:
                - A mapping with a "messages" key
                - An object with a "messages" attribute
                - Any other object (returns empty list)

        Returns:
            The extracted messages: a list value is returned as-is, a tuple is
            converted to a list. Returns an empty list when no messages are
            found or when the output object is None.

        Examples:
            >>> output = {"messages": [{"role": "user", "content": "Hello"}]}
            >>> messages = ChatStreamingHelper.extract_messages_from_output(output_obj=output)
            >>> len(messages)
            1
        """
        if output_obj is None:
            return []

        # Mappings are looked up by key; everything else by attribute.
        if isinstance(output_obj, Mapping):
            msgs = output_obj.get("messages")
        else:
            msgs = getattr(output_obj, "messages", None)

        if isinstance(msgs, list):
            return msgs
        if isinstance(msgs, tuple):
            return list(msgs)
        return []
|