fixed merge conflicts
This commit is contained in: commit 411a6a081a
21 changed files with 4236 additions and 15 deletions
17 config.ini
@@ -36,6 +36,12 @@ Web_Crawl_RETRY_DELAY = 2
# Web Research configuration
Web_Research_MAX_DEPTH = 2
Web_Research_MAX_LINKS_PER_DOMAIN = 4
<<<<<<< HEAD
=======
<<<<<<< Updated upstream
Web_Research_CRAWL_TIMEOUT_MINUTES = 10
=======
>>>>>>> feat/chatbot-althaus-integration
Web_Research_CRAWL_TIMEOUT_MINUTES = 10

# STAC API Connector configuration (Swiss Topo)
@@ -43,4 +49,13 @@ Connector_StacSwisstopo_BASE_URL = https://data.geo.admin.ch/api/stac/v1
Connector_StacSwisstopo_TIMEOUT = 30
Connector_StacSwisstopo_MAX_RETRIES = 3
Connector_StacSwisstopo_RETRY_DELAY = 1.0
Connector_StacSwisstopo_ENABLE_CACHE = True
<<<<<<< HEAD
Connector_StacSwisstopo_ENABLE_CACHE = True
=======
Connector_StacSwisstopo_ENABLE_CACHE = True

# Tavily AI Connector configuration (Web Search & Research)
# Get your API key from https://tavily.com
# Connector_AiTavily_API_SECRET = your_tavily_api_key_here
>>>>>>> Stashed changes
>>>>>>> feat/chatbot-althaus-integration
127 modules/features/chatbot/README.md Normal file
@@ -0,0 +1,127 @@
# Chatbot Feature Documentation

## Overview

The chatbot feature provides an intelligent conversational interface that processes user queries, executes database searches, performs web research, and generates contextual responses. The implementation leverages LangGraph to orchestrate complex multi-step workflows while integrating with the existing infrastructure, including the AI Center, database systems, and event streaming.

## Architecture

The chatbot feature follows a modular architecture centered around LangGraph's state graph pattern. The system processes user messages through a structured workflow that can dynamically invoke tools, query databases, search the web, and generate responses based on context.

### Core Components

**Workflow Management**: Each conversation is managed as a workflow with a unique identifier. Workflows track the conversation state, message history, and processing status. New conversations create fresh workflows, while existing conversations resume their workflows with incremented round numbers.

**LangGraph State Graph**: The heart of the chatbot is a LangGraph state graph that orchestrates the conversation flow. The graph maintains conversation state through a checkpointer system and routes between agent processing and tool execution nodes based on the model's decisions.

**Event Streaming**: Real-time updates are delivered to clients through an event-driven streaming system. Status updates, messages, logs, and completion events are emitted asynchronously and queued for delivery to connected clients.

## LangGraph Implementation

### State Management

LangGraph manages conversation state through a state graph that tracks the messages in the conversation. The state is persisted using a custom checkpointer that bridges LangGraph's checkpoint system with the existing database infrastructure. This allows conversations to be resumed, state to be maintained across sessions, and message history to be preserved.

### Graph Structure

The workflow graph consists of two primary nodes:

**Agent Node**: Processes user messages and conversation history using the AI model. The agent analyzes the input, determines what actions are needed, and decides whether to generate a response directly or invoke tools. The agent has access to the full conversation history, which is automatically trimmed to fit within the model's context window while preserving the most recent and relevant messages.

**Tools Node**: Executes tools when the agent determines they are needed. Tools can query databases, search the web, or send status updates. After tool execution, the workflow returns to the agent node to process the tool results and generate an appropriate response.

### Conditional Routing

The graph uses conditional edges to determine workflow progression. After the agent processes a message, the system checks whether the agent requested tool calls. If tools were requested, the workflow routes to the tools node. If no tools are needed, the workflow completes and returns the final response to the user.
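In LangGraph this check is typically wired as a conditional edge whose predicate inspects the last message for pending tool calls. A minimal sketch of such a predicate, using simplified dict-shaped messages rather than this module's actual state types:

```python
# Sketch of a routing predicate for LangGraph's add_conditional_edges.
# The state and message shapes here are illustrative stand-ins.

def route_after_agent(state: dict) -> str:
    """Return the next node: 'tools' if the agent requested tool calls, else end."""
    last_message = state["messages"][-1]
    # An AI message carries requested calls under `tool_calls`
    if last_message.get("tool_calls"):
        return "tools"
    return "end"

# With LangGraph this predicate would be registered roughly as:
#   graph.add_conditional_edges("agent", route_after_agent,
#                               {"tools": "tools", "end": END})
```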
### Message Window Management

To handle long conversations that exceed model context limits, the system implements intelligent message windowing. Messages are trimmed from the beginning while preserving system prompts and ensuring the conversation ends on a human or tool message. This maintains context continuity while respecting token limits.
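LangChain ships a `trim_messages` helper for this policy; a dependency-free stand-in illustrating the same idea (keep the system prompt, trim from the front, re-align the window on a human or tool turn) might look like:

```python
def trim_window(messages: list[dict], max_tokens: int, count_tokens) -> list[dict]:
    """Drop the oldest non-system messages until the window fits max_tokens.

    The system prompt always survives; trimming happens from the front so the
    most recent messages are kept, and the trimmed window is re-aligned to
    begin on a human or tool turn so the model never sees a dangling
    assistant message.
    """
    system = [m for m in messages if m["role"] == "system"]
    rest = [m for m in messages if m["role"] != "system"]

    def total(msgs):
        return sum(count_tokens(m["content"]) for m in system + msgs)

    while rest and total(rest) > max_tokens:
        rest.pop(0)
    # Re-align the start of the window on a human or tool message
    while rest and rest[0]["role"] not in ("human", "tool"):
        rest.pop(0)
    return system + rest
```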
## Integration with Existing Infrastructure

### AI Center Integration

The chatbot integrates with the AI Center through a custom bridge that implements LangChain's chat model interface. This bridge allows LangGraph to use the AI Center's model selection, routing, and calling infrastructure while maintaining compatibility with LangChain's expected interfaces.

**Model Selection**: When processing messages, the bridge converts LangChain message formats to AI Center's expected format and uses the model selector to choose the appropriate AI model based on operation type, processing mode, and available models. The selection respects role-based access control and considers model capabilities.

**Tool Calling Support**: The bridge handles tool calling by detecting when models support function calling and converting tool definitions between LangChain and AI Center formats. For OpenAI-compatible models, the bridge directly calls the API with tool definitions. For other models, it relies on connector-specific implementations.

**Operation Types**: The chatbot uses AI Center's operation type system to select models appropriate for different tasks. Database queries use data analysis operation types, while web searches use web search operation types, ensuring optimal model selection for each task.

### Database Integration

**Message Storage**: All conversation messages are stored in the existing chat database through the database interface. The custom checkpointer converts between LangGraph's message format and the database's message format, ensuring seamless persistence. Messages are stored with metadata including workflow identifiers, round numbers, sequence numbers, and timestamps.

**Workflow Persistence**: Workflow state is maintained in the database, allowing conversations to be resumed across sessions. The system tracks workflow status, current round numbers, and activity timestamps. When resuming a conversation, the workflow round number is incremented to maintain conversation continuity.

**Document Management**: User-uploaded files are tracked as document references within workflows. The system creates document records that link files to specific messages and rounds, enabling the chatbot to reference and process uploaded documents in its responses.

### Tool Integration

**SQL Query Tool**: The chatbot includes a tool that executes SQL queries against the preprocessor database. This tool uses the existing database connector infrastructure, ensuring proper connection management, query execution, and result formatting. The tool returns formatted results that the agent can use to answer user questions about products, inventory, prices, and other database-stored information.
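As a sketch of the "execute and format results" half of such a tool: the function below runs a query and renders rows as compact text the agent can quote. The function name and the use of a raw sqlite3 connection are illustrative; the real tool routes through the project's database connector infrastructure.

```python
import sqlite3

def run_sql_query(connection, query: str, max_rows: int = 20) -> str:
    """Execute a read-only query and format the rows for the agent.

    Returns a 'header | row | row' text block. A real implementation would
    go through the project's database connector instead of a raw connection.
    """
    cursor = connection.execute(query)
    columns = [d[0] for d in cursor.description]
    rows = cursor.fetchmany(max_rows)
    lines = [" | ".join(columns)]
    lines += [" | ".join(str(value) for value in row) for row in rows]
    return "\n".join(lines)
```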
**Web Search Tool**: Web research capabilities are provided through a Tavily search tool that integrates with AI Center's Tavily connector. The tool uses AI Center's model registry and selector to find and use Tavily models, ensuring consistent integration with the existing AI infrastructure. Search results include full content from multiple sources, allowing comprehensive research.

**Streaming Status Tool**: A special tool allows the agent to send status updates during processing. These updates are captured by the event streaming system and delivered to clients in real-time, providing users with visibility into what the chatbot is doing.

### Event Streaming System

The chatbot uses an event-driven streaming architecture to deliver real-time updates to clients. An event manager maintains queues for each workflow, allowing multiple clients to receive updates for the same conversation.

**Event Types**: The system emits several types of events including chat data events (messages and logs), completion events, and error events. Each event includes metadata about its type, timestamp, and associated workflow.

**Queue Management**: Event queues are created when workflows start and cleaned up after conversations complete. The cleanup system ensures resources are properly released while allowing sufficient time for clients to receive all events.
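The per-workflow queue pattern described above can be sketched with asyncio primitives. Class and method names here are illustrative, not taken from this codebase:

```python
import asyncio

class EventManager:
    """Fan-out of workflow events to any number of connected clients."""

    def __init__(self):
        self._queues: dict[str, list[asyncio.Queue]] = {}

    def subscribe(self, workflow_id: str) -> asyncio.Queue:
        """Register a client queue that will receive the workflow's events."""
        queue = asyncio.Queue()
        self._queues.setdefault(workflow_id, []).append(queue)
        return queue

    def emit(self, workflow_id: str, event: dict) -> None:
        """Deliver an event to every subscriber of the workflow."""
        for queue in self._queues.get(workflow_id, []):
            queue.put_nowait(event)

    def cleanup(self, workflow_id: str) -> None:
        """Release all queues once the conversation has completed."""
        self._queues.pop(workflow_id, None)
```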
**Event Bridging**: LangGraph's native event streaming is bridged to the custom event system. Status updates from tool calls are captured and converted to the appropriate event format. Final responses are extracted from LangGraph's output and emitted as message events.

## Configuration System

The chatbot supports multiple configuration profiles loaded from JSON files. Each configuration specifies:

**System Prompts**: Customizable instructions that define the chatbot's behavior, personality, and capabilities. Prompts can include placeholders for dynamic content like dates.

**Database Schema**: Information about available database tables and structures, enabling the agent to generate appropriate queries.

**Tool Configuration**: Settings for which tools are enabled and how they should behave. This includes SQL query settings, web search parameters, and streaming options.

**Model Configuration**: Operation types and processing modes that determine which AI models are selected for different tasks.
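A profile along these lines could be loaded and its prompt placeholders expanded as sketched below. The keys (`system_prompt`, `{current_date}`) are illustrative, not the module's actual schema:

```python
import json
from datetime import date

def load_profile(path: str) -> dict:
    """Load a chatbot configuration profile and expand prompt placeholders."""
    with open(path, encoding="utf-8") as handle:
        profile = json.load(handle)
    # System prompts may contain placeholders such as {current_date}
    profile["system_prompt"] = profile["system_prompt"].format(
        current_date=date.today().isoformat()
    )
    return profile
```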
## Conversation Flow

### Initial Request Processing

When a user submits a message, the system first creates or loads the workflow. For new conversations, a conversation name is generated using AI based on the user's initial prompt. The user's message is stored in the database and an event is emitted to notify connected clients.

### Background Processing

Message processing occurs asynchronously in the background, allowing the API to return immediately while processing continues. The system creates a LangGraph chatbot instance configured with the appropriate model, memory checkpointer, and tools.
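The accept-then-return pattern can be sketched with `asyncio.create_task`. Function names are illustrative; in the real system the background task drives the LangGraph run and the result reaches clients via events rather than a shared dict:

```python
import asyncio

async def handle_chat_request(workflow_id: str, message: str, results: dict) -> str:
    """Accept a message, schedule processing in the background, return at once."""
    asyncio.create_task(process_message(workflow_id, message, results))
    return "accepted"  # the API responds immediately; updates stream later

async def process_message(workflow_id: str, message: str, results: dict) -> None:
    """Stand-in for the LangGraph run that produces the final response."""
    await asyncio.sleep(0)  # real work: agent loop, tools, checkpointing
    results[workflow_id] = f"echo: {message}"
```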
### Tool Execution

When the agent determines that tools are needed, it requests tool calls. The tools node executes the requested tools, which may involve database queries, web searches, or status updates. Tool results are added to the conversation state and returned to the agent for processing.

### Response Generation

After tool execution or when no tools are needed, the agent generates a final response based on the conversation history and any tool results. The response is stored in the database through the checkpointer system and emitted as an event to connected clients.

### Completion

Once processing completes, a completion event is emitted and the workflow status is updated. The event queue remains available for a grace period to ensure all clients receive the final events before cleanup.

## Error Handling

The system includes comprehensive error handling at multiple levels. Workflow errors are caught and stored as error messages in the database. Error events are emitted to notify clients of failures. The system gracefully handles cases where workflows are stopped by users, preventing unnecessary error messages from being stored.

## Memory and Context Management

The custom checkpointer bridges LangGraph's checkpoint system with the database, ensuring conversation history is preserved. The system intelligently filters messages when storing checkpoints, skipping intermediate tool call requests and only storing final user and assistant messages. This prevents duplicate storage while maintaining complete conversation context.

## Multi-Language Support

The system supports multiple languages through configuration. Conversation names are generated in the user's preferred language, and the AI models can process and respond in various languages based on the system prompt and user input.

## Scalability Considerations

The asynchronous processing model allows the system to handle multiple concurrent conversations efficiently. Each workflow operates independently with its own event queue and processing task. The database checkpointer ensures state persistence without blocking processing, and the event streaming system efficiently manages multiple client connections per workflow.
@@ -1,7 +1,9 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Chatbot feature - LangGraph-based chatbot implementation.
"""

from .mainChatbot import chatProcess
from .service import chatProcess

__all__ = ['chatProcess']
3 modules/features/chatbot/bridges/__init__.py Normal file
@@ -0,0 +1,3 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Bridges to external systems (AI models, database, tools)."""
547 modules/features/chatbot/bridges/ai.py Normal file
@@ -0,0 +1,547 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
AI Center to LangChain bridge.
Implements LangChain BaseChatModel interface using AI center models.
"""

import logging
import asyncio
from typing import Any, AsyncIterator, Dict, List, Optional
from datetime import datetime

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    SystemMessage,
    AIMessage,
    ToolMessage,
    convert_to_openai_messages,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import RunnableConfig

from modules.aicore.aicoreModelRegistry import modelRegistry
from modules.aicore.aicoreModelSelector import modelSelector
from modules.datamodels.datamodelAi import (
    AiModel,
    AiModelCall,
    AiModelResponse,
    AiCallOptions,
    OperationTypeEnum,
    ProcessingModeEnum,
)
from modules.datamodels.datamodelUam import User

logger = logging.getLogger(__name__)


class AICenterChatModel(BaseChatModel):
    """
    LangChain-compatible chat model that uses AI center models.
    Bridges AI center model selection and calling to LangChain's BaseChatModel interface.
    """

    def __init__(
        self,
        user: User,
        operation_type: OperationTypeEnum = OperationTypeEnum.DATA_ANALYSE,
        processing_mode: ProcessingModeEnum = ProcessingModeEnum.DETAILED,
        **kwargs
    ):
        """
        Initialize the AI center chat model bridge.

        Args:
            user: Current user for RBAC and model selection
            operation_type: Operation type for model selection
            processing_mode: Processing mode for model selection
            **kwargs: Additional arguments passed to BaseChatModel
        """
        super().__init__(**kwargs)
        # Use object.__setattr__ to bypass Pydantic validation for custom attributes
        object.__setattr__(self, "user", user)
        object.__setattr__(self, "operation_type", operation_type)
        object.__setattr__(self, "processing_mode", processing_mode)
        object.__setattr__(self, "_selected_model", None)

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "aicenter"

    def _select_model(self, messages: List[BaseMessage]) -> AiModel:
        """
        Select the best AI center model for the given messages.

        Args:
            messages: List of LangChain messages

        Returns:
            Selected AI model
        """
        # Convert messages to prompt/context format for the model selector
        prompt_parts = []
        context_parts = []

        for msg in messages:
            if isinstance(msg, SystemMessage):
                prompt_parts.append(msg.content)
            elif isinstance(msg, HumanMessage):
                prompt_parts.append(msg.content)
            elif isinstance(msg, AIMessage):
                context_parts.append(msg.content)
            elif isinstance(msg, ToolMessage):
                context_parts.append(f"Tool {msg.name}: {msg.content}")

        prompt = "\n".join(prompt_parts)
        context = "\n".join(context_parts) if context_parts else ""

        # Get available models with RBAC filtering
        from modules.security.rbac import RbacClass
        from modules.security.rootAccess import getRootDbAppConnector
        from modules.connectors.connectorDbPostgre import DatabaseConnector
        from modules.shared.configuration import APP_CONFIG

        # Create a database connector instance for RBAC with proper configuration
        dbHost = APP_CONFIG.get("DB_MANAGEMENT_HOST")
        dbDatabase = APP_CONFIG.get("DB_MANAGEMENT_DATABASE", "management")
        dbUser = APP_CONFIG.get("DB_MANAGEMENT_USER")
        dbPassword = APP_CONFIG.get("DB_MANAGEMENT_PASSWORD_SECRET")
        dbPort = int(APP_CONFIG.get("DB_MANAGEMENT_PORT"))

        db = DatabaseConnector(
            dbHost=dbHost,
            dbDatabase=dbDatabase,
            dbUser=dbUser,
            dbPassword=dbPassword,
            dbPort=dbPort,
            userId=self.user.id if hasattr(self.user, 'id') else None
        )
        dbApp = getRootDbAppConnector()
        rbac_instance = RbacClass(db, dbApp=dbApp)

        available_models = modelRegistry.getAvailableModels(
            currentUser=self.user,
            rbacInstance=rbac_instance
        )

        # Create options for the model selector
        options = AiCallOptions(
            operationType=self.operation_type,
            processingMode=self.processing_mode
        )

        # Select model
        selected_model = modelSelector.selectModel(
            prompt=prompt,
            context=context,
            options=options,
            availableModels=available_models
        )

        if not selected_model:
            raise ValueError(f"No suitable model found for operation type {self.operation_type.value}")

        logger.info(f"Selected AI center model: {selected_model.displayName} ({selected_model.name})")
        object.__setattr__(self, "_selected_model", selected_model)
        return selected_model

    def _convert_messages_to_ai_format(self, messages: List[BaseMessage]) -> List[Dict[str, Any]]:
        """
        Convert LangChain messages to AI center format (OpenAI-style).

        Args:
            messages: List of LangChain messages

        Returns:
            List of messages in OpenAI format
        """
        # Use LangChain's built-in conversion
        return convert_to_openai_messages(messages)

    def _convert_ai_response_to_langchain(
        self,
        response: AiModelResponse,
        tool_calls: Optional[List[Dict[str, Any]]] = None
    ) -> AIMessage:
        """
        Convert AI center response to LangChain AIMessage.

        Args:
            response: AI center response
            tool_calls: Optional tool calls from the response (format: [{"id": "...", "name": "...", "args": {...}}])

        Returns:
            LangChain AIMessage with tool_calls if present
        """
        # LangChain expects tool_calls in format: [{"id": "...", "name": "...", "args": {...}}];
        # the tool_calls parameter should already be in this format
        kwargs = {}
        if tool_calls:
            kwargs["tool_calls"] = tool_calls

        return AIMessage(content=response.content or "", **kwargs)
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Synchronous generate method required by BaseChatModel.
        Wraps the async _agenerate method.

        Args:
            messages: List of LangChain messages
            stop: Optional stop sequences
            run_manager: Optional callback manager
            **kwargs: Additional arguments

        Returns:
            ChatResult with generations
        """
        # Refuse to run inside an active event loop: asyncio.run() would fail
        # there, and async callers should use _agenerate() directly.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running event loop: safe to drive the coroutine synchronously
            return asyncio.run(self._agenerate(messages, stop=stop, run_manager=run_manager, **kwargs))
        raise RuntimeError(
            "AICenterChatModel._generate() called from async context. "
            "Use _agenerate() instead."
        )
    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Async generate method required by BaseChatModel.

        Args:
            messages: List of LangChain messages
            stop: Optional stop sequences
            run_manager: Optional callback manager
            **kwargs: Additional arguments (may include tools for tool calling)

        Returns:
            ChatResult with generations
        """
        # Select model if not already selected
        if not self._selected_model:
            self._select_model(messages)

        # Check if tools are bound (for tool calling)
        tools = getattr(self, "_bound_tools", None)

        # Convert messages to AI center format
        ai_messages = self._convert_messages_to_ai_format(messages)

        # If tools are bound, add tool definitions to the system message so the
        # model knows about available tools; some models need explicit tool
        # definitions to enable tool calling.
        if tools:
            # Find the existing system message, if any
            system_message_idx = None
            for i, msg in enumerate(ai_messages):
                if msg.get("role") == "system":
                    system_message_idx = i
                    break

            # Build tool descriptions for the system message
            tool_descriptions = []
            for tool in tools:
                if hasattr(tool, "name") and hasattr(tool, "description"):
                    # Get tool parameters for a better description
                    args_schema = getattr(tool, "args_schema", None)
                    params_info = ""
                    if args_schema:
                        try:
                            if hasattr(args_schema, "model_json_schema"):
                                schema = args_schema.model_json_schema()
                                if "properties" in schema:
                                    params = list(schema["properties"].keys())
                                    params_info = f" (Parameter: {', '.join(params)})"
                        except Exception:
                            pass
                    tool_descriptions.append(f"- {tool.name}: {tool.description}{params_info}")

            if tool_descriptions:
                tools_text = "\n".join(tool_descriptions)
                tools_note = f"\n\n⚠️⚠️⚠️ KRITISCH - TOOL-NUTZUNG ⚠️⚠️⚠️\n\nVERFÜGBARE TOOLS:\n{tools_text}\n\nABSOLUT VERBINDLICH:\n- Du MUSST diese Tools verwenden, um Anfragen zu bearbeiten!\n- Für Status-Updates MUSST du IMMER das Tool 'send_streaming_message' verwenden!\n- VERBOTEN: Normale Text-Nachrichten für Status-Updates!\n- Du MUSST Tools aufrufen, nicht nur darüber sprechen!\n\nBeispiel FALSCH: \"Ich werde die Datenbank durchsuchen...\"\nBeispiel RICHTIG: Rufe das Tool 'send_streaming_message' mit \"Durchsuche Datenbank...\" auf!"

                if system_message_idx is not None:
                    # Append to the existing system message
                    ai_messages[system_message_idx]["content"] += tools_note
                else:
                    # Add a new system message at the beginning
                    ai_messages.insert(0, {
                        "role": "system",
                        "content": tools_note.strip()
                    })

        # Convert LangChain tools to OpenAI tool format. The actual tool
        # calling is handled by the connector if it supports it; for
        # OpenAI-compatible models the API is called directly below.
        openai_tools = None
        if tools and self._selected_model.connectorType == "openai":
            openai_tools = []
            for tool in tools:
                if hasattr(tool, "name") and hasattr(tool, "description"):
                    # Get the tool parameters schema
                    args_schema = getattr(tool, "args_schema", None)
                    parameters = {}
                    if args_schema:
                        from pydantic import BaseModel

                        if isinstance(args_schema, type) and issubclass(args_schema, BaseModel):
                            # Pydantic model class - get its JSON schema
                            if hasattr(args_schema, "model_json_schema"):
                                # Pydantic v2
                                parameters = args_schema.model_json_schema()
                            elif hasattr(args_schema, "schema"):
                                # Pydantic v1
                                parameters = args_schema.schema()
                        elif isinstance(args_schema, BaseModel):
                            # Pydantic model instance
                            if hasattr(args_schema, "model_dump"):
                                # Pydantic v2
                                parameters = args_schema.model_dump()
                            elif hasattr(args_schema, "dict"):
                                # Pydantic v1
                                parameters = args_schema.dict()
                        elif hasattr(args_schema, "schema"):
                            # Has a schema method (might be a class)
                            try:
                                parameters = args_schema.schema()
                            except TypeError:
                                # If schema() requires an instance, try model_json_schema
                                if hasattr(args_schema, "model_json_schema"):
                                    parameters = args_schema.model_json_schema()
                                else:
                                    parameters = {}
                        elif isinstance(args_schema, dict):
                            # Already a dict
                            parameters = args_schema

                    tool_schema = {
                        "type": "function",
                        "function": {
                            "name": tool.name,
                            "description": tool.description or "",
                            "parameters": parameters
                        }
                    }
                    openai_tools.append(tool_schema)

        # Create the model call (used by the connector path below; tools were
        # already added to the system message, since AiModelCall has no tools field)
        model_call = AiModelCall(
            messages=ai_messages,
            model=self._selected_model,
            options=AiCallOptions(
                operationType=self.operation_type,
                processingMode=self.processing_mode,
                temperature=self._selected_model.temperature
            )
        )

        # If tools are bound and this is an OpenAI model, call the API directly
        # with tools included, since the connector interface doesn't support tools.
        if openai_tools and self._selected_model.connectorType == "openai":
            import json

            import httpx
            from modules.shared.configuration import APP_CONFIG

            api_key = APP_CONFIG.get('Connector_AiOpenai_API_SECRET')
            if not api_key:
                raise ValueError("OpenAI API key not configured")

            payload = {
                "model": self._selected_model.name,
                "messages": ai_messages,
                "tools": openai_tools,
                "tool_choice": "auto",  # Let the model decide when to use tools
                "temperature": self._selected_model.temperature,
                "max_tokens": self._selected_model.maxTokens
            }

            async with httpx.AsyncClient(timeout=600.0) as client:
                response_obj = await client.post(
                    self._selected_model.apiUrl,
                    headers={
                        "Authorization": f"Bearer {api_key}",
                        "Content-Type": "application/json"
                    },
                    json=payload
                )

            if response_obj.status_code != 200:
                error_msg = f"OpenAI API error: {response_obj.status_code} - {response_obj.text}"
                logger.error(error_msg)
                raise ValueError(error_msg)

            response_json = response_obj.json()
            choice = response_json["choices"][0]
            message = choice["message"]

            # Extract content and tool calls
            content = message.get("content", "")
            tool_calls_raw = message.get("tool_calls")

            # Convert OpenAI tool_calls format to the LangChain format:
            # [{"id": "...", "name": "...", "args": {...}}]
            tool_calls = None
            if tool_calls_raw:
                tool_calls = []
                for tc in tool_calls_raw:
                    func_data = tc.get("function", {})
                    func_name = func_data.get("name")
                    func_args_str = func_data.get("arguments", "{}")

                    # Parse the JSON arguments string to a dict
                    try:
                        func_args = json.loads(func_args_str) if isinstance(func_args_str, str) else func_args_str
                    except (json.JSONDecodeError, TypeError):
                        func_args = {}

                    tool_calls.append({
                        "id": tc.get("id"),
                        "name": func_name,
                        "args": func_args
                    })

            # Create the response object
            response = AiModelResponse(
                content=content or "",
                success=True,
                modelId=self._selected_model.name,
                metadata={
                    "response_id": response_json.get("id", ""),
                    "tool_calls": tool_calls
                }
            )
        else:
            # No tools or not OpenAI - use the connector normally
            if not self._selected_model.functionCall:
                raise ValueError(f"Model {self._selected_model.displayName} has no functionCall defined")

            response: AiModelResponse = await self._selected_model.functionCall(model_call)

        if not response.success:
            raise ValueError(f"AI model call failed: {response.error or 'Unknown error'}")

        # Extract tool calls from response metadata if present
        tool_calls = None
        if response.metadata:
            # Check for tool calls in metadata (format may vary by connector)
            tool_calls = response.metadata.get("tool_calls") or response.metadata.get("function_calls")

        # Convert the response to LangChain format with tool calls
        ai_message = self._convert_ai_response_to_langchain(response, tool_calls=tool_calls)

        # Create generation and result
        generation = ChatGeneration(message=ai_message)
        return ChatResult(generations=[generation])
    def bind_tools(self, tools: List[Any], **kwargs: Any) -> "AICenterChatModel":
        """
        Bind tools to the model (required for LangGraph tool calling).

        Args:
            tools: List of LangChain tools
            **kwargs: Additional arguments

        Returns:
            New instance with tools bound
        """
        # Create a new instance with tools bound. The actual tool execution
        # happens in LangGraph's ToolNode; this method is called by LangGraph
        # to prepare the model.
        bound_model = AICenterChatModel(
            user=self.user,
            operation_type=self.operation_type,
            processing_mode=self.processing_mode
        )
        object.__setattr__(bound_model, "_selected_model", self._selected_model)
        # Store tools for use in message conversion
        object.__setattr__(bound_model, "_bound_tools", tools)
        return bound_model

    def invoke(
        self,
        input: List[BaseMessage],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        """
        Synchronous invoke method (required by BaseChatModel).
        Note: This is a wrapper around the async _agenerate.

        Args:
            input: List of LangChain messages
            config: Optional runnable config
            **kwargs: Additional arguments

        Returns:
            AIMessage response
        """
        # Refuse to run inside an active event loop; async callers must use
        # ainvoke() instead.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop: create one for the synchronous call
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        else:
            raise RuntimeError("Cannot use synchronous invoke in async context. Use ainvoke instead.")

        # Run async generation
        result = loop.run_until_complete(self._agenerate(input, **kwargs))
        return result.generations[0].message

    async def ainvoke(
        self,
        input: List[BaseMessage],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        """
        Async invoke method (required by BaseChatModel).

        Args:
            input: List of LangChain messages
            config: Optional runnable config
            **kwargs: Additional arguments

        Returns:
||||
AIMessage response
|
||||
"""
|
||||
result = await self._agenerate(input, **kwargs)
|
||||
return result.generations[0].message
|
||||
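The `invoke` wrapper above bridges LangGraph's synchronous entry point onto the async `_agenerate`. A minimal, self-contained sketch of that sync-over-async pattern, with a stand-in coroutine in place of the real model call (names here are illustrative, not the module's):

```python
import asyncio

async def _agenerate(prompt: str) -> str:
    # Stand-in for the real async model call
    await asyncio.sleep(0)
    return f"echo: {prompt}"

def invoke_sync(prompt: str) -> str:
    # Fail fast if we are already inside a running event loop,
    # where run_until_complete would deadlock
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        pass  # no running loop: safe to proceed synchronously
    else:
        raise RuntimeError("Use the async entry point inside an event loop.")

    # Create a fresh loop and run the coroutine to completion
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(_agenerate(prompt))
    finally:
        loop.close()
```

The explicit running-loop check is what keeps the synchronous path from silently deadlocking when called from async code.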
432 modules/features/chatbot/bridges/memory.py Normal file

@@ -0,0 +1,432 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Custom LangGraph checkpointer using the existing database interface.
Maps LangGraph state to the existing message storage format.
"""

import logging
import uuid
from typing import Any, Dict, List, Optional, Tuple, NamedTuple
from datetime import datetime

from langgraph.checkpoint.base import BaseCheckpointSaver, Checkpoint, CheckpointMetadata
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage

from modules.interfaces.interfaceDbChatObjects import getInterface
from modules.datamodels.datamodelChat import ChatMessage, ChatWorkflow
from modules.datamodels.datamodelUam import User
from modules.shared.timeUtils import getUtcTimestamp

logger = logging.getLogger(__name__)


# CheckpointTuple might not be directly importable, so define it as a NamedTuple.
# Based on LangGraph's usage it needs config, checkpoint, metadata, parent_config, and pending_writes.
class CheckpointTuple(NamedTuple):
    """Tuple containing config, checkpoint, metadata, parent_config, and pending_writes."""
    config: Dict[str, Any]
    checkpoint: Checkpoint
    metadata: CheckpointMetadata
    parent_config: Optional[Dict[str, Any]] = None
    pending_writes: Optional[List[Tuple[str, Any]]] = None


class DatabaseCheckpointer(BaseCheckpointSaver):
    """
    Custom LangGraph checkpointer that uses the existing database interface.
    Maps the LangGraph thread_id to workflow.id and stores messages in the existing format.
    """

    def __init__(self, user: User, workflow_id: str):
        """
        Initialize the database checkpointer.

        Args:
            user: Current user for database access
            workflow_id: Workflow ID (maps to LangGraph thread_id)
        """
        self.user = user
        self.workflow_id = workflow_id
        self.interface = getInterface(user)

    def _convert_langchain_to_db_message(
        self,
        msg: BaseMessage,
        sequence_nr: int,
        round_number: int
    ) -> Dict[str, Any]:
        """
        Convert a LangChain message to the database message format.

        Args:
            msg: LangChain message
            sequence_nr: Sequence number for ordering
            round_number: Round number in the workflow

        Returns:
            Dictionary in database message format
        """
        role = "user"
        content = ""

        if isinstance(msg, HumanMessage):
            role = "user"
            content = msg.content if isinstance(msg.content, str) else str(msg.content)
        elif isinstance(msg, AIMessage):
            role = "assistant"
            content = msg.content if isinstance(msg.content, str) else str(msg.content)
        elif isinstance(msg, SystemMessage):
            # System messages are stored but marked as system
            role = "system"
            content = msg.content if isinstance(msg.content, str) else str(msg.content)
        elif isinstance(msg, ToolMessage):
            # Tool messages are stored as assistant messages with tool info
            role = "assistant"
            content = f"Tool {msg.name}: {msg.content}"

        return {
            "id": f"msg_{uuid.uuid4()}",
            "workflowId": self.workflow_id,
            "message": content,
            "role": role,
            "status": "step" if sequence_nr > 1 else "first",
            "sequenceNr": sequence_nr,
            "publishedAt": getUtcTimestamp(),
            "roundNumber": round_number,
            "taskNumber": 0,
            "actionNumber": 0
        }

    def _convert_db_to_langchain_messages(
        self,
        messages: List[ChatMessage]
    ) -> List[BaseMessage]:
        """
        Convert database messages to LangChain messages.

        Args:
            messages: List of database ChatMessage objects

        Returns:
            List of LangChain BaseMessage objects
        """
        langchain_messages = []

        for msg in messages:
            if msg.role == "user":
                langchain_messages.append(HumanMessage(content=msg.message))
            elif msg.role == "assistant":
                langchain_messages.append(AIMessage(content=msg.message))
            elif msg.role == "system":
                langchain_messages.append(SystemMessage(content=msg.message))
            # Skip other roles for now

        return langchain_messages

    def put(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Dict[str, int],
    ) -> None:
        """
        Store a checkpoint in the database.

        Args:
            config: LangGraph config (contains thread_id)
            checkpoint: Checkpoint to store
            metadata: Checkpoint metadata
            new_versions: New version numbers
        """
        try:
            # Extract thread_id from config (maps to workflow_id)
            thread_id = config.get("configurable", {}).get("thread_id", self.workflow_id)

            # Get the current workflow to determine the round number
            workflow = self.interface.getWorkflow(thread_id)
            if not workflow:
                logger.warning(f"Workflow {thread_id} not found, cannot store checkpoint")
                return

            round_number = workflow.currentRound if workflow else 1

            # Extract messages from the checkpoint
            state = checkpoint.get("channel_values", {})
            messages = state.get("messages", [])

            if not messages:
                logger.debug(f"No messages in checkpoint for workflow {thread_id}")
                return

            # Get existing messages to determine what's already stored
            existing_messages = self.interface.getMessages(thread_id)
            existing_count = len(existing_messages) if existing_messages else 0

            # Create a set of existing (role, content) pairs for quick lookup
            existing_content_set = set()
            if existing_messages:
                for existing_msg in existing_messages:
                    existing_content_set.add((existing_msg.role, existing_msg.message))

            # Filter checkpoint messages to user/assistant only (skip system).
            # Also skip intermediate AIMessages with tool_calls: those are tool
            # call requests, not final answers.
            checkpoint_user_assistant_messages = []
            for msg in messages:
                if isinstance(msg, HumanMessage):
                    # Always store user messages
                    checkpoint_user_assistant_messages.append(msg)
                elif isinstance(msg, AIMessage):
                    tool_calls = getattr(msg, "tool_calls", None)
                    if tool_calls and len(tool_calls) > 0:
                        logger.debug(f"Skipping intermediate AIMessage with tool_calls for workflow {thread_id}")
                        continue
                    # Store all other AIMessages (final answers)
                    checkpoint_user_assistant_messages.append(msg)

            # Only keep messages that aren't already in the database
            new_messages_to_store = []
            for msg in checkpoint_user_assistant_messages:
                role = "user" if isinstance(msg, HumanMessage) else "assistant"
                content = msg.content if isinstance(msg.content, str) else str(msg.content)

                # Skip empty messages (they might be status updates)
                if not content or not content.strip():
                    continue

                content_key = (role, content)
                if content_key not in existing_content_set:
                    new_messages_to_store.append(msg)
                    existing_content_set.add(content_key)  # Avoid duplicates within this batch

            # Store only the new messages
            if new_messages_to_store:
                for i, msg in enumerate(new_messages_to_store, 1):
                    sequence_nr = existing_count + i

                    # Convert to database format
                    db_message_data = self._convert_langchain_to_db_message(
                        msg,
                        sequence_nr,
                        round_number
                    )

                    # Store the message
                    try:
                        self.interface.createMessage(db_message_data)
                        logger.debug(f"Stored message {db_message_data['id']} for workflow {thread_id}")
                    except Exception as e:
                        logger.error(f"Error storing message: {e}", exc_info=True)
            else:
                logger.debug(
                    f"No new messages to store for workflow {thread_id} "
                    f"(existing: {existing_count}, checkpoint: {len(checkpoint_user_assistant_messages)})"
                )

            # Update workflow last activity
            self.interface.updateWorkflow(thread_id, {
                "lastActivity": getUtcTimestamp()
            })

        except Exception as e:
            logger.error(f"Error storing checkpoint: {e}", exc_info=True)
            raise

    def get(
        self,
        config: Dict[str, Any],
    ) -> Optional[Checkpoint]:
        """
        Retrieve a checkpoint from the database.

        Args:
            config: LangGraph config (contains thread_id)

        Returns:
            Checkpoint if found, None otherwise
        """
        try:
            # Extract thread_id from config (maps to workflow_id)
            thread_id = config.get("configurable", {}).get("thread_id", self.workflow_id)

            # Get workflow
            workflow = self.interface.getWorkflow(thread_id)
            if not workflow:
                logger.debug(f"Workflow {thread_id} not found")
                return None

            # Get messages
            messages = self.interface.getMessages(thread_id)

            checkpoint_id = str(uuid.uuid4())

            if not messages:
                # Return an empty checkpoint for a new workflow
                return {
                    "id": checkpoint_id,
                    "v": 1,
                    "ts": getUtcTimestamp(),
                    "channel_values": {
                        "messages": []
                    },
                    "channel_versions": {},
                    "versions_seen": {}
                }

            # Convert to LangChain messages
            langchain_messages = self._convert_db_to_langchain_messages(messages)

            # Build the checkpoint
            checkpoint = {
                "id": checkpoint_id,
                "v": 1,
                "ts": getUtcTimestamp(),
                "channel_values": {
                    "messages": langchain_messages
                },
                "channel_versions": {},
                "versions_seen": {}
            }

            return checkpoint

        except Exception as e:
            logger.error(f"Error retrieving checkpoint: {e}", exc_info=True)
            return None

    def list(
        self,
        config: Dict[str, Any],
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> List[Checkpoint]:
        """
        List checkpoints (not fully implemented - returns the current checkpoint).

        Args:
            config: LangGraph config
            filter: Optional filter
            before: Optional timestamp before which to list
            limit: Optional limit on number of results

        Returns:
            List of checkpoints
        """
        checkpoint = self.get(config)
        if checkpoint:
            return [checkpoint]
        return []

    def put_writes(
        self,
        config: Dict[str, Any],
        writes: List[Tuple[str, Any]],
        task_id: str,
    ) -> None:
        """
        Store checkpoint writes (not used in the current implementation).

        Args:
            config: LangGraph config
            writes: List of write operations
            task_id: Task ID
        """
        # Not implemented - writes are handled through put()
        pass

    async def aget_tuple(
        self,
        config: Dict[str, Any],
    ) -> Optional[CheckpointTuple]:
        """
        Async version of get that returns a tuple of (config, checkpoint, metadata).

        Args:
            config: LangGraph config (contains thread_id)

        Returns:
            CheckpointTuple with config, checkpoint and metadata if found, None otherwise
        """
        checkpoint = self.get(config)
        if checkpoint:
            # CheckpointMetadata is typically a TypedDict;
            # LangGraph expects 'step' in the metadata.
            metadata: CheckpointMetadata = {
                "step": 0  # Start at step 0, LangGraph will increment
            }
            return CheckpointTuple(
                config=config,
                checkpoint=checkpoint,
                metadata=metadata,
                parent_config=None,   # No parent checkpoint in this implementation
                pending_writes=None   # No pending writes in this implementation
            )
        return None

    async def aput(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Dict[str, int],
    ) -> None:
        """
        Async version of put.

        Args:
            config: LangGraph config (contains thread_id)
            checkpoint: Checkpoint to store
            metadata: Checkpoint metadata
            new_versions: New version numbers
        """
        self.put(config, checkpoint, metadata, new_versions)

    async def alist(
        self,
        config: Dict[str, Any],
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> List[Checkpoint]:
        """
        Async version of list.

        Args:
            config: LangGraph config
            filter: Optional filter
            before: Optional timestamp before which to list
            limit: Optional limit on number of results

        Returns:
            List of checkpoints
        """
        return self.list(config, filter, before, limit)

    async def aput_writes(
        self,
        config: Dict[str, Any],
        writes: List[Tuple[str, Any]],
        task_id: str,
    ) -> None:
        """
        Async version of put_writes.
        Store checkpoint writes (not used in the current implementation).

        Args:
            config: LangGraph config
            writes: List of write operations
            task_id: Task ID
        """
        # Not implemented - LangGraph calls this, but writes are handled through aput()
        pass
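The deduplication in `put` keys each message on its `(role, content)` pair, so replaying a checkpoint never stores the same message twice, and empty status-update messages are skipped. The core of that check, sketched as a pure function over plain dicts (the function name and dict shape here are illustrative, not the interface's):

```python
def select_new_messages(existing, checkpoint_msgs):
    """Return only checkpoint messages not already stored, skipping empties."""
    # Seed the lookup set with what is already in the database
    seen = {(m["role"], m["content"]) for m in existing}
    new = []
    for m in checkpoint_msgs:
        content = (m.get("content") or "").strip()
        if not content:
            continue  # skip empty status-update messages
        key = (m["role"], m["content"])
        if key not in seen:
            new.append(m)
            seen.add(key)  # also dedupes within this batch
    return new
```

Because the key includes the role, an identical string from the user and the assistant still counts as two distinct messages.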
313 modules/features/chatbot/bridges/tools.py Normal file

@@ -0,0 +1,313 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Chatbot tools for LangGraph integration.
Includes the SQL query tool, the Tavily search tool, and the streaming status tool.
"""

import logging
from typing import Optional
from langchain_core.tools import tool

from modules.connectors.connectorPreprocessor import PreprocessorConnector
from modules.shared.configuration import APP_CONFIG

logger = logging.getLogger(__name__)


@tool
async def sqlite_query(query: str) -> str:
    """
    Execute a SQL SELECT query on the Althaus AG database.

    This tool queries the SQLite database to find articles, prices,
    inventory levels, and other product information.

    Args:
        query: A valid SQL SELECT query. Column names with spaces or special
            characters must use double quotes (e.g., "Artikelnummer",
            "S_IST_BESTAND"). Only SELECT queries are allowed.

    Returns:
        Query results as a formatted string, or an error message if the query fails.

    Examples:
        - Find articles by name:
            SELECT a."Artikelnummer", a."Artikelbezeichnung", a."Lieferant"
            FROM Artikel a
            WHERE a."Artikelbezeichnung" LIKE '%Motor%'
            LIMIT 20

        - Find articles with price and inventory:
            SELECT a."Artikelnummer", a."Artikelbezeichnung", e."EP_CHF",
                   lp."Lagerplatz" as "Lagerplatzname", l."S_IST_BESTAND",
                   l."S_RESERVIERTER__BESTAND",
                   CASE WHEN l."S_IST_BESTAND" != 'Unbekannt'
                        THEN CAST(l."S_IST_BESTAND" AS INTEGER) - COALESCE(l."S_RESERVIERTER__BESTAND", 0)
                        ELSE NULL END as "Verfügbarer Bestand"
            FROM Artikel a
            LEFT JOIN Einkaufspreis e ON a."I_ID" = e."m_Artikel"
            LEFT JOIN Lagerplatz_Artikel l ON a."I_ID" = l."R_ARTIKEL"
            LEFT JOIN Lagerplatz lp ON l."R_LAGERPLATZ" = lp."I_ID"
            WHERE a."Artikelbezeichnung" LIKE '%Netzgerät%'
            LIMIT 20
    """
    try:
        connector = PreprocessorConnector()
        try:
            result = await connector.executeQuery(query, return_json=True)

            if result.get("text", "").startswith(("Error:", "Query failed:")):
                error_msg = result.get("text", "Query failed")
                logger.error(f"SQL query failed: {error_msg}")
                return error_msg

            # Format results
            data = result.get("data", [])
            row_count = result.get("row_count", len(data))

            if not data:
                return f"Query executed successfully. Returned {row_count} rows (no data)."

            # Format as a readable string
            lines = [f"Query executed successfully. Returned {row_count} rows:"]

            if data and isinstance(data[0], dict):
                # Show column headers from the first row
                headers = list(data[0].keys())
                lines.append("\nColumns: " + ", ".join(headers))
                lines.append("\nResults:")

                # Show the first 50 rows
                for i, row in enumerate(data[:50], 1):
                    row_str = ", ".join([f"{k}: {v}" for k, v in row.items()])
                    lines.append(f"{i}. {row_str}")

                if row_count > 50:
                    lines.append(f"\n(Showing first 50 of {row_count} rows)")
            else:
                # Fallback for non-dict rows
                for i, row in enumerate(data[:50], 1):
                    lines.append(f"{i}. {row}")

            return "\n".join(lines)

        finally:
            await connector.close()

    except Exception as e:
        error_msg = f"Error executing SQL query: {str(e)}"
        logger.error(error_msg, exc_info=True)
        return error_msg


@tool
async def tavily_search(query: str) -> str:
    """
    Search the internet for comprehensive information using Tavily search via the AI Center.

    Use this tool to find detailed product information, datasheets,
    certifications, technical specifications, market trends, or other
    comprehensive information that is not in the database.

    IMPORTANT: This tool returns FULL content from search results (not truncated).
    Use all available information to provide comprehensive, detailed answers with
    specific facts, numbers, dates, and technical details.

    Args:
        query: Search query string. Be specific and include product names,
            model numbers, or other relevant keywords. For comprehensive
            research, use broad queries like "latest developments in LED technology 2026".

    Returns:
        Comprehensive search results with full content, titles, URLs, and sources.
        Results include up to 15 sources with complete content for detailed analysis.

    Examples:
        - Comprehensive product information:
            tavily_search("latest LED technology developments 2026")

        - Product datasheet:
            tavily_search("Siemens 6AV2 181-8XP00-0AX0 datasheet")

        - Market trends:
            tavily_search("LED market trends efficiency breakthroughs 2025")
    """
    try:
        # Use the AI Center Tavily plugin instead of direct langchain-tavily
        import json

        from modules.aicore.aicoreModelRegistry import modelRegistry
        from modules.aicore.aicoreModelSelector import modelSelector
        from modules.datamodels.datamodelAi import (
            AiModelCall,
            AiModelResponse,
            AiCallOptions,
            OperationTypeEnum,
            ProcessingModeEnum,
            AiCallPromptWebSearch
        )

        # Discover and register connectors if not already registered
        if not modelRegistry._connectors:
            discovered_connectors = modelRegistry.discoverConnectors()
            for connector in discovered_connectors:
                modelRegistry.registerConnector(connector)

        # Refresh models to ensure Tavily is available
        modelRegistry.refreshModels()

        # Get available Tavily models (without RBAC filtering, since tools have no user context)
        available_models = modelRegistry.getAvailableModels()
        tavily_models = [m for m in available_models if m.connectorType == "tavily"]

        if not tavily_models:
            return "Error: Tavily model not available in AI Center. Please check configuration."

        # Select the best Tavily model for web search
        options = AiCallOptions(
            operationType=OperationTypeEnum.WEB_SEARCH_DATA,
            processingMode=ProcessingModeEnum.BASIC
        )

        # Usually only one Tavily model exists, so pick it directly;
        # otherwise let the model selector choose between them.
        if len(tavily_models) == 1:
            selected_model = tavily_models[0]
        else:
            selected_model = modelSelector.selectModel(
                prompt=query,
                context="",
                options=options,
                availableModels=tavily_models
            )

        if not selected_model:
            return "Error: Could not select Tavily model for web search."

        # Create a web search prompt requesting more results and deeper research
        web_search_prompt = AiCallPromptWebSearch(
            instruction=query,
            maxNumberPages=15,     # Request more results for comprehensive information
            country=None,          # No country filter by default
            language=None,         # No language filter by default
            researchDepth="deep"   # Deep research for comprehensive results
        )

        # Create the model call with a JSON prompt
        model_call = AiModelCall(
            messages=[
                {
                    "role": "user",
                    "content": json.dumps(web_search_prompt.model_dump())
                }
            ],
            model=selected_model,
            options=options
        )

        # Call the model's functionCall (which routes to _routeWebOperation)
        if not selected_model.functionCall:
            return "Error: Tavily model has no functionCall defined."

        response: AiModelResponse = await selected_model.functionCall(model_call)

        if not response.success:
            error_msg = response.error or "Unknown error"
            logger.error(f"Tavily search failed: {error_msg}")
            return f"Error performing Tavily search: {error_msg}"

        # Parse the response content (should be JSON with URLs and content)
        try:
            result_data = json.loads(response.content) if response.content else {}

            # Handle different response formats
            if isinstance(result_data, list):
                # List of URLs or results
                results = result_data
            elif isinstance(result_data, dict):
                # Dictionary with a urls or results key
                results = result_data.get("urls", []) or result_data.get("results", []) or []
            else:
                results = []

            if not results:
                return f"No results found for query: {query}"

            # Format results with full content (not truncated)
            lines = [f"Internet search results for: {query}\n"]

            # Return all results with full content (up to 15 results)
            for i, result in enumerate(results[:15], 1):
                if isinstance(result, str):
                    # Simple URL string
                    lines.append(f"{i}. {result}")
                    lines.append(f"   URL: {result}")
                elif isinstance(result, dict):
                    # Dictionary with url, title, content
                    url = result.get("url", "")
                    title = result.get("title", url)
                    content = result.get("content", "")

                    lines.append(f"{i}. {title}")
                    lines.append(f"   URL: {url}")
                    if content:
                        # Return FULL content, not truncated - let the LLM decide what to use
                        lines.append(f"   Content: {content}")
                else:
                    # Fallback
                    lines.append(f"{i}. {str(result)}")
                lines.append("")

            return "\n".join(lines)

        except json.JSONDecodeError:
            # If the response is not JSON, fall back to plain text
            if response.content:
                return f"Internet search results for: {query}\n\n{response.content}"
            return f"No results found for query: {query}"

    except Exception as e:
        error_msg = f"Error performing Tavily search via AI Center: {str(e)}"
        logger.error(error_msg, exc_info=True)
        return error_msg


# Note: send_streaming_message is created in the LangGraph integration,
# where it has access to the event manager.

def create_send_streaming_message_tool(event_manager=None):
    """
    Create the send_streaming_message tool with access to the event manager.

    Args:
        event_manager: Event manager instance for emitting events (not used directly;
            events are captured via LangGraph tool events)

    Returns:
        LangChain tool for sending streaming messages
    """
    @tool
    async def send_streaming_message(message: str) -> str:
        """
        Send a streaming status update to the user.

        Use this tool frequently to keep the user informed about what you are doing.
        This provides a better user experience by showing progress updates.

        Args:
            message: A short message describing what you are currently doing.
                Examples:
                - "Durchsuche Datenbank nach Lampen, LED, Leuchten, und Ähnlichem."
                - "Suche im Internet nach Produktinformationen."
                - "Analysiere Suchergebnisse und bereite Antwort vor."

        Returns:
            Confirmation that the message was sent.
        """
        # The tool body intentionally does nothing: the actual event emission
        # happens in the streaming bridge. The tool exists so LangGraph
        # recognizes the status update as a tool call.
        return f"Status-Update gesendet: {message}"

    return send_streaming_message
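The result rendering in `sqlite_query` (header line, numbered rows, 50-row cap) can be isolated as a small pure function. This sketch mirrors that formatting logic without the connector; the function name is illustrative, not part of the module:

```python
def format_rows(data, row_count=None, limit=50):
    """Format a list of dict rows the way sqlite_query renders results."""
    row_count = row_count if row_count is not None else len(data)
    if not data:
        return f"Query executed successfully. Returned {row_count} rows (no data)."
    lines = [f"Query executed successfully. Returned {row_count} rows:"]
    # Column headers come from the first row's keys
    lines.append("\nColumns: " + ", ".join(data[0].keys()))
    lines.append("\nResults:")
    for i, row in enumerate(data[:limit], 1):
        lines.append(f"{i}. " + ", ".join(f"{k}: {v}" for k, v in row.items()))
    if row_count > limit:
        lines.append(f"\n(Showing first {limit} of {row_count} rows)")
    return "\n".join(lines)
```

Capping the rendered rows while reporting the true `row_count` keeps tool output small enough for the model's context while still telling it how much data the query matched.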
348 modules/features/chatbot/chatbot.py Normal file

@@ -0,0 +1,348 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Chatbot domain logic."""

import logging
from dataclasses import dataclass
from typing import Annotated, AsyncIterator, Any, List
from pydantic import BaseModel

from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    SystemMessage,
    trim_messages,
)
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode

from modules.features.chatbot.bridges.ai import AICenterChatModel
from modules.features.chatbot.bridges.memory import DatabaseCheckpointer
from modules.features.chatbot.bridges.tools import (
    sqlite_query,
    tavily_search,
    create_send_streaming_message_tool,
)
from modules.features.chatbot.streaming.helpers import ChatStreamingHelper
from modules.features.chatbot.streaming.events import get_event_manager
from modules.datamodels.datamodelUam import User

logger = logging.getLogger(__name__)


class ChatState(BaseModel):
    """Represents the state of a chat session."""

    messages: Annotated[List[BaseMessage], add_messages]


@dataclass
class Chatbot:
    """Represents a chatbot."""

    model: AICenterChatModel
    memory: DatabaseCheckpointer
    app: CompiledStateGraph = None
    system_prompt: str = "You are a helpful assistant."
    workflow_id: str = "default"

    @classmethod
    async def create(
        cls,
        model: AICenterChatModel,
        memory: DatabaseCheckpointer,
        system_prompt: str,
        workflow_id: str = "default",
    ) -> "Chatbot":
        """Factory method to create and configure a Chatbot instance.

        Args:
            model: The chat model to use (AICenterChatModel).
            memory: The chat memory to use (DatabaseCheckpointer).
            system_prompt: The system prompt to initialize the chatbot.
            workflow_id: The workflow ID (maps to thread_id).

        Returns:
            A configured Chatbot instance.
        """
        instance = cls(
            model=model,
            memory=memory,
            system_prompt=system_prompt,
            workflow_id=workflow_id,
        )
        configured_tools = await instance._configure_tools()
        instance.app = instance._build_app(memory, configured_tools)
        return instance

    async def _configure_tools(self) -> List[Any]:
        """Configure tools for the chatbot.

        Returns:
            List of configured tools.
        """
        tools = []

        # SQL query tool
        tools.append(sqlite_query)

        # Tavily search tool
        tools.append(tavily_search)

        # Streaming status tool (needs the event manager)
        event_manager = get_event_manager()
        send_streaming_message = create_send_streaming_message_tool(event_manager)
        tools.append(send_streaming_message)

        return tools

    def _build_app(
        self, memory: DatabaseCheckpointer, tools: List[Any]
    ) -> CompiledStateGraph[ChatState, None, ChatState, ChatState]:
        """Builds the chatbot application workflow using LangGraph.

        Args:
            memory: The chat memory to use.
            tools: The list of tools the chatbot can use.

        Returns:
            A compiled state graph representing the chatbot application.
        """
        llm_with_tools = self.model.bind_tools(tools=tools)

        def select_window(msgs: List[BaseMessage]) -> List[BaseMessage]:
            """Selects a window of messages that fits within the context window.

            Args:
                msgs: The list of messages to select from.

            Returns:
                A list of messages that fits within the context window size.
            """

            def approx_counter(items: List[BaseMessage]) -> int:
                """Approximate token counter for messages (character-count heuristic).

                Args:
                    items: List of messages to count tokens for.

                Returns:
                    Approximate number of tokens in the messages.
                """
                return sum(len(getattr(m, "content", "") or "") for m in items)

            # Use the model's context length if available, otherwise a default
            selected = getattr(self.model, "_selected_model", None)
            max_tokens = getattr(selected, "contextLength", 128000) if selected else 128000

            return trim_messages(
                msgs,
                strategy="last",
                token_counter=approx_counter,
                max_tokens=int(max_tokens * 0.8),  # Use 80% of the context window
                start_on="human",
                end_on=("human", "tool"),
                include_system=True,
            )

        async def agent_node(state: ChatState) -> dict:
            """Agent node for the chatbot workflow.

            Args:
                state: The current chat state.

            Returns:
                The updated chat state after processing.
            """
            # Select the message window that fits in context (trim if needed)
            window = select_window(state.messages)

            # Ensure the system prompt is present at the start
|
||||
if not window or not isinstance(window[0], SystemMessage):
|
||||
window = [SystemMessage(content=self.system_prompt)] + window
|
||||
|
||||
# Call the LLM with tools (use ainvoke for async)
|
||||
response = await llm_with_tools.ainvoke(window)
|
||||
|
||||
# Return the new state
|
||||
return {"messages": [response]}
|
||||
|
||||
def should_continue(state: ChatState) -> str:
|
||||
"""Determines whether to continue the workflow or end it.
|
||||
|
||||
This conditional edge is called after the agent node to decide
|
||||
whether to continue to the tools node (if the last message contains
|
||||
tool calls) or to end the workflow (if no tool calls are present).
|
||||
|
||||
Args:
|
||||
state: The current chat state.
|
||||
|
||||
Returns:
|
||||
The next node to transition to ("tools" or END).
|
||||
"""
|
||||
# Get the last message
|
||||
last_message = state.messages[-1]
|
||||
|
||||
# Check if the last message contains tool calls
|
||||
# If so, continue to the tools node; otherwise, end the workflow
|
||||
return "tools" if getattr(last_message, "tool_calls", None) else END
|
||||
|
||||
async def tools_with_retry(state: ChatState) -> dict:
|
||||
"""Tools node with retry logic.
|
||||
|
||||
Args:
|
||||
state: The current chat state.
|
||||
|
||||
Returns:
|
||||
The updated chat state after tool execution.
|
||||
"""
|
||||
# Execute tools normally
|
||||
tool_node = ToolNode(tools=tools)
|
||||
result = await tool_node.ainvoke(state)
|
||||
|
||||
# Check if we got no results and should retry
|
||||
no_results_keywords = [
|
||||
"returned 0 rows",
|
||||
"no data",
|
||||
"keine artikel gefunden",
|
||||
"keine ergebnisse"
|
||||
]
|
||||
|
||||
# Check tool results for no data
|
||||
for msg in result.get("messages", []):
|
||||
content = getattr(msg, "content", "")
|
||||
if isinstance(content, str):
|
||||
content_lower = content.lower()
|
||||
if any(keyword in content_lower for keyword in no_results_keywords):
|
||||
# Check if we haven't retried yet (avoid infinite loops)
|
||||
retry_count = sum(1 for m in state.messages if "retry" in str(getattr(m, "content", "")).lower())
|
||||
if retry_count < 2: # Allow max 2 retries
|
||||
logger.info("No results found in tool output, adding retry instruction")
|
||||
retry_message = HumanMessage(
|
||||
content="WICHTIG: Die vorherige Suche hat keine Ergebnisse gefunden. "
|
||||
"Bitte versuche eine alternative Suchstrategie:\n"
|
||||
"1. Wenn die Frage im Format 'X von Y' war (z.B. 'Lampen von Eaton'), "
|
||||
"verwende IMMER eine Kombination aus Lieferanten-Filter (WHERE a.\"Lieferant\" LIKE '%Y%') "
|
||||
"UND Produkttyp-Filter (WHERE a.\"Artikelbezeichnung\" LIKE '%X%' OR ...)\n"
|
||||
"2. Verwende mehrere Synonyme für den Produkttyp (z.B. bei 'Lampen': Lampe, LED, Beleuchtung, Licht, Leuchte, Strahler)\n"
|
||||
"3. Führe zuerst eine COUNT-Abfrage durch, dann die Detail-Abfrage mit Lagerbeständen\n"
|
||||
"4. Verwende LIKE '%Lieferant%' für den Lieferanten-Filter, um auch Varianten zu finden"
|
||||
)
|
||||
result["messages"].append(retry_message)
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
# Compose the workflow
|
||||
workflow = StateGraph(ChatState)
|
||||
workflow.add_node("agent", agent_node)
|
||||
workflow.add_node("tools", tools_with_retry)
|
||||
workflow.add_edge(START, "agent")
|
||||
workflow.add_conditional_edges("agent", should_continue)
|
||||
workflow.add_edge("tools", "agent")
|
||||
return workflow.compile(checkpointer=memory)
|
||||
|
||||
async def chat(self, message: str, chat_id: str = "default") -> List[BaseMessage]:
|
||||
"""Processes a chat message by calling the LLM and tools and returns the chat history.
|
||||
|
||||
Args:
|
||||
message: The user message to process.
|
||||
chat_id: The chat thread ID.
|
||||
|
||||
Returns:
|
||||
The list of messages in the chat history.
|
||||
"""
|
||||
# Set the right thread ID for memory
|
||||
config = {"configurable": {"thread_id": chat_id}}
|
||||
|
||||
# Single-turn chat (non-streaming)
|
||||
result = await self.app.ainvoke(
|
||||
{"messages": [HumanMessage(content=message)]}, config=config
|
||||
)
|
||||
|
||||
# Extract and return the messages from the result
|
||||
return result["messages"]
|
||||
|
||||
async def stream_events(
|
||||
self, *, message: str, chat_id: str = "default"
|
||||
) -> AsyncIterator[dict]:
|
||||
"""Stream UI-focused events using astream_events v2.
|
||||
|
||||
Args:
|
||||
message: The user message to process.
|
||||
chat_id: Logical thread identifier; forwarded in the runnable config so
|
||||
memory and tools are scoped per thread.
|
||||
|
||||
Yields:
|
||||
dict: One of:
|
||||
- ``{"type": "status", "label": str}`` for short progress updates.
|
||||
- ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}``
|
||||
where ``chat_history`` only includes ``user``/``assistant`` roles.
|
||||
- ``{"type": "error", "message": str}`` if an exception occurs.
|
||||
"""
|
||||
# Thread-aware config for LangGraph/LangChain
|
||||
config = {"configurable": {"thread_id": chat_id}}
|
||||
|
||||
def _is_root(ev: dict) -> bool:
|
||||
"""Return True if the event is from the root run (v2: empty parent_ids)."""
|
||||
return not ev.get("parent_ids")
|
||||
|
||||
try:
|
||||
async for event in self.app.astream_events(
|
||||
{"messages": [HumanMessage(content=message)]},
|
||||
config=config,
|
||||
version="v2",
|
||||
):
|
||||
etype = event.get("event")
|
||||
ename = event.get("name") or ""
|
||||
edata = event.get("data") or {}
|
||||
|
||||
# Stream human-readable progress via the special send_streaming_message tool
|
||||
# Match the legacy implementation exactly (line 267-272 in legacy/chatbot.py)
|
||||
if etype == "on_tool_start":
|
||||
# Log all tool starts to debug
|
||||
logger.debug(f"Tool start event: name='{ename}', event='{etype}'")
|
||||
if ename == "send_streaming_message":
|
||||
tool_in = edata.get("input") or {}
|
||||
msg = tool_in.get("message")
|
||||
logger.info(f"send_streaming_message tool called with input: {tool_in}")
|
||||
if isinstance(msg, str) and msg.strip():
|
||||
logger.info(f"Status-Update gesendet: {msg.strip()}")
|
||||
yield {"type": "status", "label": msg.strip()}
|
||||
continue
|
||||
|
||||
# Emit the final payload when the root run finishes
|
||||
if etype == "on_chain_end" and _is_root(event):
|
||||
output_obj = edata.get("output")
|
||||
|
||||
# Extract message list from the graph's final output
|
||||
final_msgs = ChatStreamingHelper.extract_messages_from_output(
|
||||
output_obj=output_obj
|
||||
)
|
||||
|
||||
# Normalize for the frontend (only user/assistant with text content)
|
||||
chat_history_payload: List[dict] = []
|
||||
for m in final_msgs:
|
||||
if isinstance(m, BaseMessage):
|
||||
d = ChatStreamingHelper.message_to_dict(msg=m)
|
||||
elif isinstance(m, dict):
|
||||
d = ChatStreamingHelper.dict_message_to_dict(obj=m)
|
||||
else:
|
||||
continue
|
||||
if d.get("role") in ("user", "assistant") and d.get("content"):
|
||||
chat_history_payload.append(d)
|
||||
|
||||
yield {
|
||||
"type": "final",
|
||||
"response": {
|
||||
"thread": chat_id,
|
||||
"chat_history": chat_history_payload,
|
||||
},
|
||||
}
|
||||
return
|
||||
|
||||
except Exception as exc:
|
||||
# Emit a single error envelope and end the stream
|
||||
logger.error(f"Exception in stream_events: {exc}", exc_info=True)
|
||||
yield {"type": "error", "message": f"Fehler beim Verarbeiten: {exc}"}
|
||||
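The agent/tools loop compiled above reduces to a simple control flow: run the agent, route to the tools node while the last message carries tool calls, then return to the agent until it answers plainly. As a minimal sketch with LangGraph stripped away (the `Msg` class, the canned agent, and the tool names are illustrative stand-ins, not the module's real types):

```python
from dataclasses import dataclass, field

@dataclass
class Msg:
    content: str
    tool_calls: list = field(default_factory=list)

def agent(history: list[Msg]) -> Msg:
    # Stand-in model: ask for one tool call first, then answer plainly
    if not any(m.content.startswith("tool:") for m in history):
        return Msg("need data", tool_calls=[{"name": "sqlite_query"}])
    return Msg("final answer")

def run(history: list[Msg]) -> list[Msg]:
    while True:
        msg = agent(history)            # agent node
        history.append(msg)
        if not msg.tool_calls:          # should_continue -> END
            return history
        for call in msg.tool_calls:     # tools node, then edge back to agent
            history.append(Msg(f"tool:{call['name']} result"))

history = run([Msg("Lampen von Eaton?")])
```

Here `history` ends with the plain answer, with the tool result message inserted between the two agent turns, which is exactly the shape the retry logic in `tools_with_retry` inspects.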
modules/features/chatbot/chatbotConstants.py (new file, 170 lines)
@@ -0,0 +1,170 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Chatbot constants and helper functions.
"""

import logging
from typing import Optional

from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, PriorityEnum, ProcessingModeEnum

logger = logging.getLogger(__name__)


async def generate_conversation_name(
    services,
    prompt: str,
    user_language: Optional[str] = None
) -> str:
    """
    Generate a conversation name from the user's prompt using AI.
    Creates a concise, informative summary name in German based on the user input.

    Args:
        services: Services object with AI service
        prompt: User's input prompt (always in German)
        user_language: User's language preference (unused; output is always German)

    Returns:
        A short, informative conversation name in German
    """
    if not prompt or not prompt.strip():
        return "Neue Unterhaltung"

    try:
        # Check whether the AI service is available
        if not hasattr(services, 'ai') or services.ai is None:
            logger.warning("AI service not available, generating name from prompt")
            return _generate_name_from_prompt(prompt)

        # Ensure the AI service is initialized before use
        await services.ai.ensureAiObjectsInitialized()

        # Create the AI prompt - very explicit that the answer must be in German
        ai_prompt = f"""Du bist ein deutscher Assistent. Der Benutzer hat folgende Anfrage auf Deutsch gestellt:

"{prompt.strip()}"

Erstelle einen kurzen, zusammenfassenden Titel für diese Unterhaltung. Der Titel muss:
- Auf Deutsch sein (KEIN Englisch!)
- Maximal 50 Zeichen lang sein
- Das Hauptthema zusammenfassen
- Informativ sein

Beispiele für gute deutsche Titel:
- "LED-Artikel Suche"
- "Lagerbestandsabfrage"
- "Produktinformationen"
- "Artikel-Suche"

Antworte NUR mit dem deutschen Titel, ohne Anführungszeichen oder Erklärungen."""

        # Create the AI request
        request = AiCallRequest(
            prompt=ai_prompt,
            context="",
            options=AiCallOptions(
                operationType=OperationTypeEnum.DATA_GENERATE,
                priority=PriorityEnum.SPEED,
                processingMode=ProcessingModeEnum.BASIC,
                compressPrompt=False,
                compressContext=False,
                temperature=0.3  # Lower temperature for more consistent German output
            )
        )

        # Call the AI service
        logger.info(f"Calling AI to generate conversation name for prompt: {prompt[:50]}...")
        response = await services.ai.callAi(request)

        if not response or not hasattr(response, 'content') or not response.content:
            logger.warning("AI response invalid, generating name from prompt")
            return _generate_name_from_prompt(prompt)

        logger.info(f"AI response received: {response.content[:100]}...")

        # Clean up the AI response
        name = str(response.content).strip()
        name = name.strip('"\'')

        # Remove markdown code blocks if present
        if name.startswith('```'):
            lines = name.split('\n')
            if len(lines) > 1:
                name = '\n'.join(lines[1:-1]) if lines[-1].strip() == '```' else '\n'.join(lines[1:])

        # Remove newlines and extra spaces
        name = " ".join(name.split())

        # If the name contains English words, generate from the prompt instead
        name_lower = name.lower()
        english_words = ["search", "find", "show", "display", "query", "article", "product", "item", "led articles", "product search"]
        if any(word in name_lower for word in english_words):
            logger.warning(f"AI generated English name '{name}', generating from prompt instead")
            return _generate_name_from_prompt(prompt)

        # Limit to 50 characters
        if len(name) > 50:
            name = name[:47] + "..."

        # If we got a valid name, return it
        if name and len(name) >= 3:
            logger.info(f"Successfully generated conversation name via AI: '{name}'")
            return name
        else:
            logger.warning(f"Generated name is too short: '{name}', generating from prompt")
            return _generate_name_from_prompt(prompt)

    except Exception as e:
        logger.error(f"Error generating conversation name with AI: {e}", exc_info=True)
        return _generate_name_from_prompt(prompt)


def _generate_name_from_prompt(prompt: str) -> str:
    """
    Generate a conversation name directly from the German prompt.
    Creates a concise title by extracting key words and formatting them.

    Args:
        prompt: User's input prompt in German

    Returns:
        A short conversation name in German
    """
    if not prompt or not prompt.strip():
        return "Neue Unterhaltung"

    # Clean up the prompt
    name = prompt.strip()

    # Remove newlines and extra spaces
    name = " ".join(name.split())

    # Remove common question words and phrases
    question_words = ["wie", "was", "wo", "wann", "wer", "welche", "welcher", "welches"]
    words = name.split()
    filtered_words = [w for w in words if w.lower() not in question_words]

    if filtered_words:
        name = " ".join(filtered_words)

    # Capitalize the first letter
    if name:
        name = name[0].upper() + name[1:] if len(name) > 1 else name.upper()

    # Limit to 50 characters
    if len(name) > 50:
        # Try to cut at a word boundary
        truncated = name[:47]
        last_space = truncated.rfind(' ')
        if last_space > 20:  # Only cut at a word boundary if reasonable
            name = truncated[:last_space] + "..."
        else:
            name = truncated + "..."

    # If the name is empty or too short, use the default
    if not name or len(name) < 3:
        return "Neue Unterhaltung"

    return name
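The fallback naming above is pure string handling, so its behavior is easy to exercise standalone. The helper below mirrors `_generate_name_from_prompt` as an illustrative sketch (it is not the module itself, and it keeps only the core steps: collapse whitespace, drop leading German question words, capitalize, truncate at a word boundary):

```python
def fallback_title(prompt: str) -> str:
    """Illustrative re-implementation of the prompt-based fallback naming."""
    if not prompt or not prompt.strip():
        return "Neue Unterhaltung"
    # Collapse whitespace, then drop common German question words
    words = " ".join(prompt.split()).split()
    question_words = {"wie", "was", "wo", "wann", "wer", "welche", "welcher", "welches"}
    filtered = [w for w in words if w.lower() not in question_words] or words
    name = " ".join(filtered)
    # Capitalize the first letter
    name = name[0].upper() + name[1:] if len(name) > 1 else name.upper()
    # Truncate to 50 characters, preferring a word boundary
    if len(name) > 50:
        truncated = name[:47]
        cut = truncated.rfind(" ")
        name = (truncated[:cut] if cut > 20 else truncated) + "..."
    return name if len(name) >= 3 else "Neue Unterhaltung"
```

For example, `fallback_title("wie viele LED Lampen sind auf Lager?")` drops "wie" and capitalizes the remainder, matching the intent of the real helper.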
modules/features/chatbot/config.py (new file, 130 lines)
@@ -0,0 +1,130 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Configuration system for chatbot instances.
Loads JSON configuration files from the configs/ directory.
"""

import logging
import json
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, Dict, Any

logger = logging.getLogger(__name__)

# Cache for loaded configs
_config_cache: Dict[str, 'ChatbotConfig'] = {}


@dataclass
class DatabaseConfig:
    """Database configuration for a chatbot instance."""
    schema: Dict[str, Any]
    connector: str = "preprocessor"


@dataclass
class ToolConfig:
    """Tool configuration for a chatbot instance."""
    sql: Dict[str, Any]
    tavily: Optional[Dict[str, Any]] = None
    streaming: Optional[Dict[str, Any]] = None


@dataclass
class ModelConfig:
    """Model configuration for a chatbot instance."""
    operationType: str = "DATA_ANALYSE"
    processingMode: str = "DETAILED"


@dataclass
class ChatbotConfig:
    """Configuration for a chatbot instance."""
    id: str
    name: str
    systemPrompt: str
    database: DatabaseConfig
    tools: ToolConfig
    model: ModelConfig

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ChatbotConfig':
        """Create a ChatbotConfig from a dictionary."""
        return cls(
            id=data.get("id", "default"),
            name=data.get("name", "Default Chatbot"),
            systemPrompt=data.get("systemPrompt", "You are a helpful assistant."),
            database=DatabaseConfig(
                schema=data.get("database", {}).get("schema", {}),
                connector=data.get("database", {}).get("connector", "preprocessor")
            ),
            tools=ToolConfig(
                sql=data.get("tools", {}).get("sql", {"enabled": True}),
                tavily=data.get("tools", {}).get("tavily"),
                streaming=data.get("tools", {}).get("streaming", {"enabled": True})
            ),
            model=ModelConfig(
                operationType=data.get("model", {}).get("operationType", "DATA_ANALYSE"),
                processingMode=data.get("model", {}).get("processingMode", "DETAILED")
            )
        )


def load_chatbot_config(config_id: str) -> ChatbotConfig:
    """
    Load a chatbot configuration from a JSON file.

    Args:
        config_id: Configuration ID (e.g., "althaus", "default")

    Returns:
        ChatbotConfig instance

    Raises:
        FileNotFoundError: If the config file is not found
        ValueError: If the config file is invalid
    """
    # Check the cache first
    if config_id in _config_cache:
        logger.debug(f"Returning cached config for {config_id}")
        return _config_cache[config_id]

    # Resolve the path to the configs directory
    current_dir = Path(__file__).parent
    configs_dir = current_dir / "configs"
    config_file = configs_dir / f"{config_id}.json"

    if not config_file.exists():
        # Fall back to the default config if the requested one is not found
        if config_id != "default":
            logger.warning(f"Config {config_id} not found, trying default")
            return load_chatbot_config("default")
        raise FileNotFoundError(f"Chatbot config file not found: {config_file}")

    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        config = ChatbotConfig.from_dict(data)

        # Cache the config
        _config_cache[config_id] = config
        logger.info(f"Loaded chatbot config: {config_id} ({config.name})")

        return config

    except json.JSONDecodeError as e:
        logger.error(f"Error parsing chatbot config JSON {config_file}: {e}")
        raise ValueError(f"Invalid JSON in config file {config_file}: {e}")
    except Exception as e:
        logger.error(f"Error loading chatbot config {config_file}: {e}")
        raise


def clear_config_cache():
    """Clear the configuration cache."""
    _config_cache.clear()
    logger.debug("Cleared chatbot config cache")
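Because `from_dict` supplies a default at every nested `.get()`, even a near-empty JSON document yields a fully populated config. A standalone sketch of that defaulting pattern (the `MiniConfig` class and its fields are illustrative, reduced from ChatbotConfig):

```python
import json
from dataclasses import dataclass
from typing import Any, Dict

@dataclass
class MiniConfig:
    """Illustrative subset of the ChatbotConfig defaulting behavior."""
    id: str
    name: str
    operation_type: str

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MiniConfig":
        # Each .get() supplies a default, so missing keys never raise
        return cls(
            id=data.get("id", "default"),
            name=data.get("name", "Default Chatbot"),
            operation_type=data.get("model", {}).get("operationType", "DATA_ANALYSE"),
        )

cfg = MiniConfig.from_dict(json.loads('{"id": "althaus"}'))
```

Here only `id` comes from the document; `name` and `operation_type` fall back to their defaults, which is why a partial `configs/<id>.json` still produces a usable chatbot.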
modules/features/chatbot/configs/althaus.json (new file, 156 lines)
File diff suppressed because one or more lines are too long

modules/features/chatbot/configs/default.json (new file)
@@ -0,0 +1,31 @@
{
  "id": "default",
  "name": "Default Chatbot",
  "systemPrompt": "You are a helpful assistant. You have access to SQL query tools and web search tools. Use them to help answer user questions.",
  "database": {
    "schema": {
      "database": {
        "path": "/data/database.db",
        "type": "SQLite"
      },
      "tables": {},
      "relationships": []
    },
    "connector": "preprocessor"
  },
  "tools": {
    "sql": {
      "enabled": true
    },
    "tavily": {
      "enabled": false
    },
    "streaming": {
      "enabled": true
    }
  },
  "model": {
    "operationType": "DATA_ANALYSE",
    "processingMode": "DETAILED"
  }
}
@@ -29,8 +29,11 @@ from .datamodelFeatureChatbot import ChatWorkflow, UserInputRequest, WorkflowMod
 from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata

 # Import chatbot feature
-from . import chatProcess
-from .eventManager import get_event_manager
+from modules.features.chatbot import chatProcess
+from modules.features.chatbot.streaming.events import get_event_manager
+
+# Import workflow control functions
+from modules.features.workflow import chatStop

 # Configure logger
 logger = logging.getLogger(__name__)
@@ -241,17 +244,26 @@ async def stream_chatbot_start(
             event_type = event.get("type")
             event_data = event.get("data", {})

-            # Emit chatdata events (messages, logs, stats) in exact chatData format
+            # Emit chatdata events (messages, logs, stats, status) in exact chatData format
             if event_type == "chatdata" and event_data:
-                # Emit item directly in exact chatData format: {type, createdAt, item}
-                chatdata_item = event_data
-                # Ensure item field is serializable (convert Pydantic models to dicts)
-                if isinstance(chatdata_item, dict) and "item" in chatdata_item:
-                    item_obj = chatdata_item.get("item")
-                    if hasattr(item_obj, "dict"):
-                        chatdata_item = chatdata_item.copy()
-                        chatdata_item["item"] = item_obj.dict()
-                yield f"data: {json.dumps(chatdata_item)}\n\n"
+                # Handle status events (transient UI feedback)
+                if event_data.get("type") == "status":
+                    # Status events have a simple structure: {type: "status", label: "..."}
+                    status_item = {
+                        "type": "status",
+                        "label": event_data.get("label", "")
+                    }
+                    yield f"data: {json.dumps(status_item)}\n\n"
+                else:
+                    # Emit other chatdata items (messages, logs, stats) in exact chatData format
+                    chatdata_item = event_data
+                    # Ensure the item field is serializable (convert Pydantic models to dicts)
+                    if isinstance(chatdata_item, dict) and "item" in chatdata_item:
+                        item_obj = chatdata_item.get("item")
+                        if hasattr(item_obj, "dict"):
+                            chatdata_item = chatdata_item.copy()
+                            chatdata_item["item"] = item_obj.dict()
+                    yield f"data: {json.dumps(chatdata_item)}\n\n"

             # Handle completion/stopped events to close stream
             elif event_type == "complete":
modules/features/chatbot/service.py (new file, 1262 lines)
File diff suppressed because it is too large

modules/features/chatbot/streaming/__init__.py (new file)
@@ -0,0 +1,3 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Streaming infrastructure for chatbot events."""

modules/features/chatbot/streaming/events.py (new file)
@@ -0,0 +1,159 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Event manager for chatbot streaming.
Manages event queues for Server-Sent Events (SSE) streaming.
"""

import logging
import asyncio
from typing import Dict, Optional, Any

logger = logging.getLogger(__name__)


class EventManager:
    """
    Manages event queues for chatbot streaming.
    Each workflow has its own async queue for events.
    """

    def __init__(self):
        """Initialize the event manager."""
        self._queues: Dict[str, asyncio.Queue] = {}
        self._cleanup_tasks: Dict[str, asyncio.Task] = {}

    def create_queue(self, workflow_id: str) -> asyncio.Queue:
        """
        Create an event queue for a workflow.

        Args:
            workflow_id: Workflow ID

        Returns:
            Async queue for events
        """
        if workflow_id not in self._queues:
            self._queues[workflow_id] = asyncio.Queue()
            logger.debug(f"Created event queue for workflow {workflow_id}")
        return self._queues[workflow_id]

    def get_queue(self, workflow_id: str) -> Optional[asyncio.Queue]:
        """
        Get the event queue for a workflow.

        Args:
            workflow_id: Workflow ID

        Returns:
            Async queue if it exists, None otherwise
        """
        return self._queues.get(workflow_id)

    def has_queue(self, workflow_id: str) -> bool:
        """
        Check whether a queue exists for a workflow.

        Args:
            workflow_id: Workflow ID

        Returns:
            True if the queue exists, False otherwise
        """
        return workflow_id in self._queues

    async def emit_event(
        self,
        context_id: str,
        event_type: str,
        data: Dict[str, Any],
        event_category: str = "chat",
        message: Optional[str] = None,
        step: Optional[str] = None
    ) -> None:
        """
        Emit an event to the queue for a workflow.

        Args:
            context_id: Workflow ID (context)
            event_type: Type of event (e.g., "chatdata", "complete", "error")
            data: Event data dictionary
            event_category: Category of event (e.g., "chat", "workflow")
            message: Optional message string
            step: Optional step identifier
        """
        queue = self._queues.get(context_id)
        if not queue:
            logger.warning(f"No queue found for workflow {context_id}, event not emitted")
            return

        event = {
            "type": event_type,
            "data": data,
            "category": event_category,
            "message": message,
            "step": step
        }

        try:
            await queue.put(event)
            logger.debug(f"Emitted {event_type} event for workflow {context_id}")
        except Exception as e:
            logger.error(f"Error emitting event for workflow {context_id}: {e}", exc_info=True)

    async def cleanup(self, workflow_id: str, delay: float = 60.0) -> None:
        """
        Schedule cleanup of a queue after a delay.

        Args:
            workflow_id: Workflow ID
            delay: Delay in seconds before cleanup
        """
        # Cancel the existing cleanup task, if any
        if workflow_id in self._cleanup_tasks:
            self._cleanup_tasks[workflow_id].cancel()

        async def _cleanup():
            try:
                await asyncio.sleep(delay)
                if workflow_id in self._queues:
                    # Drain remaining events
                    queue = self._queues[workflow_id]
                    while not queue.empty():
                        try:
                            queue.get_nowait()
                        except asyncio.QueueEmpty:
                            break

                    # Remove the queue
                    del self._queues[workflow_id]
                    logger.info(f"Cleaned up event queue for workflow {workflow_id}")
            except asyncio.CancelledError:
                logger.debug(f"Cleanup cancelled for workflow {workflow_id}")
            except Exception as e:
                logger.error(f"Error during cleanup for workflow {workflow_id}: {e}", exc_info=True)
            finally:
                if workflow_id in self._cleanup_tasks:
                    del self._cleanup_tasks[workflow_id]

        # Schedule the cleanup
        task = asyncio.create_task(_cleanup())
        self._cleanup_tasks[workflow_id] = task


# Global event manager instance
_event_manager: Optional[EventManager] = None


def get_event_manager() -> EventManager:
    """
    Get the global event manager instance.

    Returns:
        EventManager instance
    """
    global _event_manager
    if _event_manager is None:
        _event_manager = EventManager()
    return _event_manager
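The queue lifecycle above (create, emit, consume) rests entirely on `asyncio.Queue`, so the contract can be exercised standalone without importing the module. This sketch mimics the EventManager's per-workflow queue pattern with a plain dict (the workflow id and event payload are illustrative):

```python
import asyncio

async def demo() -> dict:
    queues: dict[str, asyncio.Queue] = {}

    # create_queue: lazily allocate one queue per workflow id
    queues.setdefault("wf-1", asyncio.Queue())

    # emit_event: a producer puts an event envelope on the workflow's queue
    await queues["wf-1"].put({"type": "chatdata", "data": {"label": "searching"}})

    # The SSE generator on the consumer side awaits events from the same queue
    return await queues["wf-1"].get()

event = asyncio.run(demo())
```

The real manager adds the missing-queue warning, the full event envelope (category, message, step), and delayed cleanup on top of this core pattern.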
modules/features/chatbot/streaming/helpers.py (new file, 242 lines)
@@ -0,0 +1,242 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Streaming helper utilities for chat message processing and normalization."""

from __future__ import annotations

from typing import Any, Dict, List, Literal, Mapping, Optional

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)

Role = Literal["user", "assistant", "system", "tool"]


class ChatStreamingHelper:
    """Pure helper methods for streaming and message normalization.

    This class provides static utility methods for converting between different
    message formats, extracting content, and normalizing message structures
    for streaming chat applications.
    """

    @staticmethod
    def role_from_message(*, msg: BaseMessage) -> Role:
        """Extract the role from a BaseMessage instance.

        Args:
            msg: The BaseMessage instance to extract the role from.

        Returns:
            The role as a string literal: "user", "assistant", "system", or "tool".
            Defaults to "assistant" if the message type is not recognized.

        Examples:
            >>> from langchain_core.messages import HumanMessage
            >>> msg = HumanMessage(content="Hello")
            >>> ChatStreamingHelper.role_from_message(msg=msg)
            'user'
        """
        if isinstance(msg, HumanMessage):
            return "user"
        if isinstance(msg, AIMessage):
            return "assistant"
        if isinstance(msg, SystemMessage):
            return "system"
        if isinstance(msg, ToolMessage):
            return "tool"
        return getattr(msg, "role", "assistant")

    @staticmethod
    def flatten_content(*, content: Any) -> str:
        """Convert complex content structures to plain text.

        This method handles various content formats including strings, lists of
        content parts, and dictionaries with text fields. It is designed to
        normalize content from different message sources into a consistent
        plain text format.

        Args:
            content: The content to flatten. Can be:
                - str: Returned as-is after stripping whitespace
                - list: Each item processed and joined with newlines
                - dict: Text extracted from "text" or "content" fields
                - None: Returns empty string
                - Any other type: Converted to string

        Returns:
            The flattened content as a plain text string with whitespace stripped.

        Examples:
            >>> content = [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}]
            >>> ChatStreamingHelper.flatten_content(content=content)
            'Hello\nworld'

            >>> content = {"text": "Simple message"}
            >>> ChatStreamingHelper.flatten_content(content=content)
            'Simple message'
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content.strip()
        if isinstance(content, list):
            parts: List[str] = []
            for part in content:
                if isinstance(part, dict):
                    if "text" in part and isinstance(part["text"], str):
                        parts.append(part["text"])
                    elif part.get("type") == "text" and isinstance(
                        part.get("text"), str
                    ):
                        parts.append(part["text"])
                    elif "content" in part and isinstance(part["content"], str):
                        parts.append(part["content"])
                    else:
                        # Fallback for unknown dictionary structures
                        val = part.get("value")
                        if isinstance(val, str):
                            parts.append(val)
                        else:
                            parts.append(str(part))
            return "\n".join(p.strip() for p in parts if p is not None)
        if isinstance(content, dict):
            if "text" in content and isinstance(content["text"], str):
                return content["text"].strip()
            if "content" in content and isinstance(content["content"], str):
                return content["content"].strip()
        return str(content).strip()

    @staticmethod
    def message_to_dict(*, msg: BaseMessage) -> Dict[str, Any]:
        """Convert a BaseMessage instance to a dictionary for streaming output.

        This method normalizes BaseMessage instances into a consistent dictionary
        format suitable for JSON serialization and streaming to clients.

        Args:
            msg: The BaseMessage instance to convert.

        Returns:
            A dictionary containing:
                - "role": The message role (user, assistant, system, tool)
                - "content": The flattened message content as plain text
                - "tool_calls": Tool calls if present (optional)
                - "name": Message name if present (optional)
Examples:
|
||||
>>> from langchain_core.messages import HumanMessage
|
||||
>>> msg = HumanMessage(content="Hello there")
|
||||
>>> result = ChatStreamingHelper.message_to_dict(msg=msg)
|
||||
>>> result["role"]
|
||||
'user'
|
||||
>>> result["content"]
|
||||
'Hello there'
|
||||
"""
|
||||
payload: Dict[str, Any] = {
|
||||
"role": ChatStreamingHelper.role_from_message(msg=msg),
|
||||
"content": ChatStreamingHelper.flatten_content(
|
||||
content=getattr(msg, "content", "")
|
||||
),
|
||||
}
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
if tool_calls:
|
||||
payload["tool_calls"] = tool_calls
|
||||
name = getattr(msg, "name", None)
|
||||
if name:
|
||||
payload["name"] = name
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def dict_message_to_dict(*, obj: Mapping[str, Any]) -> Dict[str, Any]:
|
||||
"""Convert a dictionary-shaped message to a normalized dictionary.
|
||||
|
||||
This method handles messages that come from serialized state and are
|
||||
represented as dictionaries rather than BaseMessage instances. It
|
||||
normalizes various dictionary formats into a consistent structure.
|
||||
|
||||
Args:
|
||||
obj: The dictionary-shaped message to convert. Expected to contain
|
||||
fields like "role", "type", "content", "text", etc.
|
||||
|
||||
Returns:
|
||||
A normalized dictionary containing:
|
||||
- "role": The message role (user, assistant, system, tool)
|
||||
- "content": The flattened message content as plain text
|
||||
- "tool_calls": Tool calls if present (optional)
|
||||
- "name": Message name if present (optional)
|
||||
|
||||
Examples:
|
||||
>>> obj = {"type": "human", "content": "Hello"}
|
||||
>>> result = ChatStreamingHelper.dict_message_to_dict(obj=obj)
|
||||
>>> result["role"]
|
||||
'user'
|
||||
>>> result["content"]
|
||||
'Hello'
|
||||
"""
|
||||
role: Optional[str] = obj.get("role")
|
||||
if not role:
|
||||
# Handle alternative type field mappings
|
||||
typ = obj.get("type")
|
||||
if typ in ("human", "user"):
|
||||
role = "user"
|
||||
elif typ in ("ai", "assistant"):
|
||||
role = "assistant"
|
||||
elif typ in ("system",):
|
||||
role = "system"
|
||||
elif typ in ("tool", "function"):
|
||||
role = "tool"
|
||||
|
||||
content = obj.get("content")
|
||||
if content is None and "text" in obj:
|
||||
content = obj["text"]
|
||||
|
||||
out: Dict[str, Any] = {
|
||||
"role": role or "assistant",
|
||||
"content": ChatStreamingHelper.flatten_content(content=content),
|
||||
}
|
||||
if "tool_calls" in obj:
|
||||
out["tool_calls"] = obj["tool_calls"]
|
||||
if obj.get("name"):
|
||||
out["name"] = obj["name"]
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def extract_messages_from_output(*, output_obj: Any) -> List[Any]:
|
||||
"""Extract messages from LangGraph output objects.
|
||||
|
||||
This method handles various output formats from LangGraph execution,
|
||||
extracting the messages list from different possible structures.
|
||||
|
||||
Args:
|
||||
output_obj: The output object from LangGraph execution. Can be:
|
||||
- An object with a "messages" attribute
|
||||
- A dictionary with a "messages" key
|
||||
- Any other object (returns empty list)
|
||||
|
||||
Returns:
|
||||
A list of extracted messages, or an empty list if no messages
|
||||
are found or if the output object is None.
|
||||
|
||||
Examples:
|
||||
>>> output = {"messages": [{"role": "user", "content": "Hello"}]}
|
||||
>>> messages = ChatStreamingHelper.extract_messages_from_output(output_obj=output)
|
||||
>>> len(messages)
|
||||
1
|
||||
"""
|
||||
if output_obj is None:
|
||||
return []
|
||||
|
||||
# Try to parse dicts first
|
||||
if isinstance(output_obj, dict):
|
||||
msgs = output_obj.get("messages")
|
||||
return msgs if isinstance(msgs, list) else []
|
||||
|
||||
# Then try to get messages attribute
|
||||
msgs = getattr(output_obj, "messages", None)
|
||||
return msgs if isinstance(msgs, list) else []
|
||||
|
|
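The flattening rules above can be exercised in isolation. The sketch below is a standalone re-implementation of the documented contract so it runs without the langchain dependency; it is illustrative only, not the module's actual API.

```python
from typing import Any, List

def flatten_content(content: Any) -> str:
    """Standalone sketch of the flattening contract documented above."""
    if content is None:
        return ""
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        parts: List[str] = []
        for part in content:
            if isinstance(part, dict):
                # Prefer "text", then "content", mirroring the documented order
                text = part.get("text") or part.get("content")
                if isinstance(text, str):
                    parts.append(text)
            else:
                parts.append(str(part))
        return "\n".join(p.strip() for p in parts)
    if isinstance(content, dict):
        for key in ("text", "content"):
            if isinstance(content.get(key), str):
                return content[key].strip()
    return str(content).strip()
```

A list of mixed content parts collapses to newline-joined text, matching the docstring examples.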
@@ -1109,6 +1109,29 @@ class ChatObjects:
            actionName=createdMessage.get("actionName")
        )

        # Emit message event for streaming (if event manager is available)
        try:
            from modules.features.chatbot.streaming.events import get_event_manager
            event_manager = get_event_manager()
            message_timestamp = parseTimestamp(chat_message.publishedAt, default=getUtcTimestamp())
            # Emit message event in exact chatData format: {type, createdAt, item}
            asyncio.create_task(event_manager.emit_event(
                context_id=workflowId,
                event_type="chatdata",
                data={
                    "type": "message",
                    "createdAt": message_timestamp,
                    "item": chat_message.dict()
                },
                event_category="chat"
            ))
        except Exception as e:
            # Event manager not available or an error occurred - continue without emitting
            logger.debug(f"Could not emit message event: {e}")

        # Debug: Store message and documents for debugging - only if debug enabled
        storeDebugMessageAndDocuments(chat_message, self.currentUser)
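Both hunks in this file follow the same fire-and-forget pattern: schedule `emit_event` with `asyncio.create_task` and swallow failures so the database write path is never blocked by streaming. A minimal sketch of that pattern follows; the `EventManager` class here is a stand-in assumption, and only its `emit_event` keyword signature mirrors the diff.

```python
import asyncio
from typing import Any, Dict, List

class EventManager:
    """Stand-in for the real event manager; collects events in memory."""
    def __init__(self) -> None:
        self.events: List[Dict[str, Any]] = []

    async def emit_event(self, *, context_id: str, event_type: str,
                         data: Dict[str, Any], event_category: str) -> None:
        self.events.append({
            "contextId": context_id,
            "type": event_type,
            "data": data,
            "category": event_category,
        })

async def store_and_emit(manager: EventManager) -> None:
    # ... database write happens here ...
    try:
        # Fire-and-forget: the write path schedules the emit and moves on.
        task = asyncio.create_task(manager.emit_event(
            context_id="wf-1",
            event_type="chatdata",
            data={"type": "message", "createdAt": 0, "item": {}},
            event_category="chat",
        ))
        await task  # awaited here only so the sketch is deterministic
    except Exception:
        # Never let a streaming failure break the write path.
        pass

manager = EventManager()
asyncio.run(store_and_emit(manager))
```

The broad `except` is deliberate: event emission is best-effort, and the stored record must survive even when no event manager is wired up.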
@@ -1469,6 +1492,29 @@ class ChatObjects:
        # Create log in normalized table
        createdLog = self.db.recordCreate(ChatLog, log_model)

        # Emit log event for streaming (if event manager is available)
        try:
            from modules.features.chatbot.streaming.events import get_event_manager
            event_manager = get_event_manager()
            log_timestamp = parseTimestamp(createdLog.get("timestamp"), default=getUtcTimestamp())
            # Emit log event in exact chatData format: {type, createdAt, item}
            asyncio.create_task(event_manager.emit_event(
                context_id=workflowId,
                event_type="chatdata",
                data={
                    "type": "log",
                    "createdAt": log_timestamp,
                    "item": ChatLog(**createdLog).dict()
                },
                event_category="chat"
            ))
        except Exception as e:
            # Event manager not available or an error occurred - continue without emitting
            logger.debug(f"Could not emit log event: {e}")

        # Return validated ChatLog instance
        return ChatLog(**createdLog)
@@ -78,6 +78,9 @@ azure-communication-email>=1.0.0 # Azure Communication Services Email
pytest>=8.0.0
pytest-asyncio>=0.21.0

## Configuration Validation
jsonschema>=4.0.0 # Required for chatbot workflow config validation

## For Scheduling / Repeated Tasks
APScheduler==3.11.0
3 tests/functional/chatbot/__init__.py Normal file

@@ -0,0 +1,3 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Chatbot functional tests."""
217 tests/functional/chatbot/test_chatbot.py Normal file

@@ -0,0 +1,217 @@
#!/usr/bin/env python3
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Chatbot Functional Tests

Tests the chatbot implementation to ensure:
1. Chatbot initialization works correctly
2. Streaming events are emitted properly
3. Tool calls execute correctly
4. Messages are stored in the database
5. No infinite loops occur
"""

import asyncio
import os
import sys
from pathlib import Path

# Add the gateway to the path (go up 3 levels from tests/functional/chatbot/)
_gateway_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
if _gateway_path not in sys.path:
    sys.path.insert(0, _gateway_path)

import pytest
from modules.features.chatbot.chatbot import Chatbot
from modules.features.chatbot.chatbotAIBridge import AICenterChatModel
from modules.features.chatbot.chatbotMemory import DatabaseCheckpointer
from modules.features.chatbot.chatbotConfig import load_chatbot_config
from modules.features.chatbot.streamingHelper import ChatStreamingHelper
from modules.datamodels.datamodelUam import User
from modules.datamodels.datamodelAi import OperationTypeEnum, ProcessingModeEnum


class TestChatbot:
    """Test suite for chatbot functionality."""

    @pytest.fixture
    def test_user(self):
        """Create a test user."""
        return User(
            id="test_user_chatbot",
            username="test_chatbot",
            email="test@example.com",
            fullName="Test Chatbot User",
            language="de",
            mandateId="test_mandate",
        )

    @pytest.fixture
    def workflow_id(self):
        """Generate a test workflow ID."""
        import uuid
        return str(uuid.uuid4())

    @pytest.mark.asyncio
    async def test_chatbot_initialization(self, test_user, workflow_id):
        """Test that the chatbot can be initialized correctly."""
        # Load config
        config = load_chatbot_config("althaus")

        # Create system prompt
        from datetime import datetime
        system_prompt = config.systemPrompt.replace(
            "{{DATE}}",
            datetime.now().strftime("%d.%m.%Y")
        )

        # Create AI Center model
        operation_type = OperationTypeEnum[config.model.operationType]
        processing_mode = ProcessingModeEnum[config.model.processingMode]

        model = AICenterChatModel(
            user=test_user,
            operation_type=operation_type,
            processing_mode=processing_mode
        )

        # Create memory/checkpointer
        memory = DatabaseCheckpointer(user=test_user, workflow_id=workflow_id)

        # Create chatbot instance
        chatbot = await Chatbot.create(
            model=model,
            memory=memory,
            system_prompt=system_prompt,
            workflow_id=workflow_id
        )

        assert chatbot is not None
        assert chatbot.model is not None
        assert chatbot.memory is not None
        assert chatbot.app is not None
        assert chatbot.system_prompt == system_prompt

    @pytest.mark.asyncio
    async def test_streaming_helper_role_from_message(self):
        """Test ChatStreamingHelper.role_from_message."""
        from langchain_core.messages import HumanMessage, AIMessage, SystemMessage

        human_msg = HumanMessage(content="Hello")
        assert ChatStreamingHelper.role_from_message(msg=human_msg) == "user"

        ai_msg = AIMessage(content="Hi there")
        assert ChatStreamingHelper.role_from_message(msg=ai_msg) == "assistant"

        system_msg = SystemMessage(content="You are a helpful assistant")
        assert ChatStreamingHelper.role_from_message(msg=system_msg) == "system"

    @pytest.mark.asyncio
    async def test_streaming_helper_flatten_content(self):
        """Test ChatStreamingHelper.flatten_content."""
        # Test string
        assert ChatStreamingHelper.flatten_content(content="Hello") == "Hello"

        # Test list
        content_list = [{"type": "text", "text": "Hello"}, {"type": "text", "text": "World"}]
        result = ChatStreamingHelper.flatten_content(content=content_list)
        assert "Hello" in result
        assert "World" in result

        # Test dict
        content_dict = {"text": "Simple message"}
        assert ChatStreamingHelper.flatten_content(content=content_dict) == "Simple message"

        # Test None
        assert ChatStreamingHelper.flatten_content(content=None) == ""

    @pytest.mark.asyncio
    async def test_streaming_helper_message_to_dict(self):
        """Test ChatStreamingHelper.message_to_dict."""
        from langchain_core.messages import HumanMessage

        msg = HumanMessage(content="Hello there")
        result = ChatStreamingHelper.message_to_dict(msg=msg)

        assert result["role"] == "user"
        assert result["content"] == "Hello there"

    @pytest.mark.asyncio
    async def test_streaming_helper_extract_messages_from_output(self):
        """Test ChatStreamingHelper.extract_messages_from_output."""
        # Test dict with messages
        output_dict = {"messages": [{"role": "user", "content": "Hello"}]}
        messages = ChatStreamingHelper.extract_messages_from_output(output_obj=output_dict)
        assert len(messages) == 1

        # Test None
        messages = ChatStreamingHelper.extract_messages_from_output(output_obj=None)
        assert len(messages) == 0

        # Test object with a messages attribute
        class MockOutput:
            def __init__(self):
                self.messages = [{"role": "assistant", "content": "Hi"}]

        mock_output = MockOutput()
        messages = ChatStreamingHelper.extract_messages_from_output(output_obj=mock_output)
        assert len(messages) == 1

    @pytest.mark.asyncio
    async def test_chatbot_should_continue_logic(self, test_user, workflow_id):
        """Test that the should_continue logic works correctly (no infinite loops)."""
        # Load config
        config = load_chatbot_config("althaus")

        # Create system prompt
        from datetime import datetime
        system_prompt = config.systemPrompt.replace(
            "{{DATE}}",
            datetime.now().strftime("%d.%m.%Y")
        )

        # Create AI Center model
        operation_type = OperationTypeEnum[config.model.operationType]
        processing_mode = ProcessingModeEnum[config.model.processingMode]

        model = AICenterChatModel(
            user=test_user,
            operation_type=operation_type,
            processing_mode=processing_mode
        )

        # Create memory/checkpointer
        memory = DatabaseCheckpointer(user=test_user, workflow_id=workflow_id)

        # Create chatbot instance
        chatbot = await Chatbot.create(
            model=model,
            memory=memory,
            system_prompt=system_prompt,
            workflow_id=workflow_id
        )

        # The should_continue logic is internal to the workflow;
        # we can at least verify that the workflow compiles successfully.
        assert chatbot.app is not None

        # Invoking the workflow exercises should_continue internally.
        # Use a simple message that should not cause infinite loops.
        try:
            result = await chatbot.chat(
                message="Hallo",
                chat_id=workflow_id
            )
            # Should return messages without an infinite loop
            assert result is not None
            assert isinstance(result, list)
        except Exception as e:
            # If there is an error, it should not be an infinite-loop error
            # (infinite loops would time out or hit max iterations)
            assert "infinite" not in str(e).lower()
            assert "loop" not in str(e).lower()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
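The extraction contract these tests exercise can also be sketched standalone, with no project imports; the names below are illustrative, not the module's API.

```python
from typing import Any, List

def extract_messages(output_obj: Any) -> List[Any]:
    """Standalone sketch of the extraction rules the tests above verify."""
    if output_obj is None:
        return []
    if isinstance(output_obj, dict):
        # Dicts first: serialized LangGraph state carries a "messages" key
        msgs = output_obj.get("messages")
    else:
        # Otherwise fall back to a "messages" attribute on the object
        msgs = getattr(output_obj, "messages", None)
    return msgs if isinstance(msgs, list) else []

class MockOutput:
    messages = [{"role": "assistant", "content": "Hi"}]
```

Anything without a list-valued `messages` field degrades to an empty list, which is what keeps the streaming loop robust against unexpected output shapes.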