146 lines
5.1 KiB
Python
146 lines
5.1 KiB
Python
"""Service for chat."""
|
|
|
|
import json
|
|
import logging
|
|
from typing import AsyncIterator, Any, List
|
|
|
|
from src.chat.schemas import ChatMessageItem, PostChatMessageResponse
|
|
from src.chat.domain.chatbot import Chatbot
|
|
from src.chat import constants as chat_constants
|
|
|
|
from langchain_core.messages import HumanMessage, AIMessage
|
|
|
|
# Logger for this module; its name is the module's dotted import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
async def post_message(
    thread_id: str,
    message: str,
    chatmodel: Any,
    chatmemory: Any,
    embeddings: Any,
) -> PostChatMessageResponse:
    """Post a chat message to the chatbot and return the resulting conversation.

    Args:
        thread_id: The unique identifier for the chat thread; passed to the
            chatbot as ``chat_id`` and echoed back in the response.
        message: The content of the chat message.
        chatmodel: The chat model, passed to ``Chatbot.create`` as ``model``.
        chatmemory: The chat memory, passed to ``Chatbot.create`` as ``memory``.
        embeddings: Accepted for interface compatibility; not used by this
            function.

    Returns:
        The thread ID together with the conversation returned by the chatbot.
        The conversation is reduced to user/assistant messages with non-empty
        plain-text content (stripped of surrounding whitespace).
    """
    logger.info("Received message: %s for thread %s", message, thread_id)

    chatbot = await Chatbot.create(
        model=chatmodel,
        memory=chatmemory,
        system_prompt=chat_constants.SYSTEM_PROMPT,
    )

    # ``chat`` returns the conversation, iterated below as LangChain messages.
    conversation = await chatbot.chat(message=message, chat_id=thread_id)

    chat_history: List[ChatMessageItem] = []
    # Use a distinct loop name so the ``message`` parameter is not shadowed.
    for msg in conversation:
        if isinstance(msg, HumanMessage):
            role = "user"
        elif isinstance(msg, AIMessage):
            role = "assistant"
        else:
            continue  # Any other message type is not part of the visible chat.

        # Skip structured (non-string) content.
        if not isinstance(msg.content, str):
            continue

        content = msg.content.strip()
        # Drop messages with no visible text. These are presumably AI turns
        # that only carry tool calls. This matches the streaming path, which
        # also discards empty content.
        if not content:
            continue

        chat_history.append(ChatMessageItem(role=role, content=content))

    return PostChatMessageResponse(thread=thread_id, chat_history=chat_history)
|
|
|
|
|
|
def _sse(payload: dict) -> str:
    """Serialize *payload* as one Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def _final_response(thread_id: str, chat_history_payload: Any) -> "PostChatMessageResponse":
    """Build the final response from the ``chat_history`` of a ``final`` event.

    Args:
        thread_id: Thread identifier echoed back in the response.
        chat_history_payload: The event's ``chat_history`` value. It is
            expected to be a list of dicts with ``role`` and ``content`` keys.

    Returns:
        A response containing only entries whose role is ``user`` or
        ``assistant`` and whose content is non-empty. If the payload is not a
        list, a warning is logged and an empty history is returned.
    """
    if not isinstance(chat_history_payload, list):
        logger.warning(
            "Unexpected chat_history format in final event: %s",
            type(chat_history_payload),
        )
        return PostChatMessageResponse(thread=thread_id, chat_history=[])

    items: List[ChatMessageItem] = []
    for entry in chat_history_payload:
        # A malformed entry is skipped rather than failing the whole response.
        if not isinstance(entry, dict):
            logger.warning("Skipping non-dict chat_history entry: %s", type(entry))
            continue
        role = entry.get("role")
        content = entry.get("content", "")
        if role in ("user", "assistant") and content:
            items.append(ChatMessageItem(role=role, content=content))

    return PostChatMessageResponse(thread=thread_id, chat_history=items)


async def post_message_stream(
    thread_id: str,
    message: str,
    chatmodel: Any,
    chatmemory: Any,
    embeddings: Any,
) -> AsyncIterator[str]:
    """Post a chat message to the chatbot and stream progress updates (SSE).

    Each yielded string is one Server-Sent Events frame: ``data: <json>``
    followed by a blank line. The JSON object has a ``type`` field:

    * ``"status"``: a progress ``label`` forwarded from the chatbot's events.
    * ``"final"``: the serialized ``PostChatMessageResponse``. The generator
      stops right after this frame.
    * ``"error"``: a generic user-facing message. It is emitted once if any
      exception is raised while creating the chatbot or consuming its events.

    Args:
        thread_id: The unique identifier for the chat thread.
        message: The content of the chat message.
        chatmodel: The chat model, passed to ``Chatbot.create`` as ``model``.
        chatmemory: The chat memory, passed to ``Chatbot.create`` as ``memory``.
        embeddings: Accepted for interface compatibility; not used here.

    Yields:
        SSE-formatted strings as described above.
    """
    logger.info("Received streaming message: %s for thread %s", message, thread_id)

    try:
        chatbot = await Chatbot.create(
            model=chatmodel,
            memory=chatmemory,
            system_prompt=chat_constants.SYSTEM_PROMPT,
        )

        async for event in chatbot.stream_events(message=message, chat_id=thread_id):
            etype = event.get("type")

            if etype == "status":
                # Transient status updates are forwarded as-is.
                yield _sse({"type": "status", "label": event.get("label")})
            elif etype == "final":
                response_from_event = event.get("response") or {}
                response = _final_response(
                    thread_id, response_from_event.get("chat_history", [])
                )
                yield _sse({"type": "final", "response": response.model_dump()})
                return
            # Any other event type is ignored.

        # Reaching here means the event stream ended without a final answer.
        # The client receives no final frame, so leave a trace for debugging.
        logger.warning(
            "Chat stream for thread %s ended without a final event", thread_id
        )
    except Exception as e:
        logger.exception("Error in streaming chat: %s", e)
        yield _sse(
            {
                "type": "error",
                "message": "Ein Fehler ist aufgetreten. Bitte versuchen Sie es erneut.",
            }
        )
|