From 1badb8468aba11d4bff8419153b0a01587ad12d9 Mon Sep 17 00:00:00 2001 From: ValueOn AG Date: Tue, 21 Oct 2025 14:48:43 +0200 Subject: [PATCH] refactored maxToken handling for all models and all ai calls --- env_dev.env | 3 - env_int.env | 3 - env_prod.env | 3 - modules/connectors/connectorAiAnthropic.py | 10 +- modules/connectors/connectorAiOpenai.py | 17 +- modules/connectors/connectorAiPerplexity.py | 23 +- modules/interfaces/interfaceAiObjects.py | 31 +- modules/services/serviceAi/subCoreAi.py | 84 ++-- .../renderers/rendererImage.py | 2 +- .../serviceGeneration/subPromptBuilder.py | 122 ++---- .../serviceWorkflow/mainServiceWorkflow.py | 19 +- .../processing/core/messageCreator.py | 6 +- .../processing/modes/modeActionplan.py | 4 +- .../workflows/processing/modes/modeReact.py | 13 +- modules/workflows/workflowManager.py | 7 +- requirements.txt | 27 +- test_ai_behavior.py | 365 ++++++++++++++++++ .../chatBot/utils/test_toolRegistry.py | 198 ---------- 18 files changed, 528 insertions(+), 409 deletions(-) create mode 100644 test_ai_behavior.py delete mode 100644 tests/features/chatBot/utils/test_toolRegistry.py diff --git a/env_dev.env b/env_dev.env index 0c3fd25b..3fcef74e 100644 --- a/env_dev.env +++ b/env_dev.env @@ -55,21 +55,18 @@ Connector_AiOpenai_API_URL = https://api.openai.com/v1/chat/completions Connector_AiOpenai_API_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpEajBuZmtYTVdqLTBpQm9KZ2pCXzRCV3VhZzlYTEhKb1FqWXNrV3lyb25uZUN1WVVQUEY3dGYtejludV9MNGlKeVREanZGOGloV09mY2ttQ3k5SjBFOGFac2ZQTkNKNUZWVnRINVQyeWhsR2wyYnVrRDNzV2NqSHB0ajQ4UWtGeGZtbmR0Q3VvS0hDZlphVmpSc2Z6RG5nPT0= Connector_AiOpenai_MODEL_NAME = gpt-4o Connector_AiOpenai_TEMPERATURE = 0.2 -Connector_AiOpenai_MAX_TOKENS = 2000 # Anthropic configuration Connector_AiAnthropic_API_URL = https://api.anthropic.com/v1/messages Connector_AiAnthropic_API_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpENmFBWG16STFQUVZxNzZZRzRLYTA4X3lRanF1VkF4cU45OExNMzlsQmdISGFxTUxud1dXODBKcFhMVG9KNjdWVnlTTFFROVc3NDlsdlNHLUJXeG41NDBHaXhHR0VHVWl5UW9RNkVWbmlhakRKVW5pM0R4VHk0LUw0TV9LdkljNHdBLXJua21NQkl2b3l4UkVkMGN1YjBrMmJEeWtMay1jbmxrYWJNbUV0aktCXzU1djR2d2RSQXZORTNwcG92ZUVvVGMtQzQzTTVncEZTRGRtZUFIZWQ0dz09 Connector_AiAnthropic_MODEL_NAME = claude-3-5-sonnet-20241022 Connector_AiAnthropic_TEMPERATURE = 0.2 -Connector_AiAnthropic_MAX_TOKENS = 2000 # Perplexity AI configuration Connector_AiPerplexity_API_URL = https://api.perplexity.ai/chat/completions Connector_AiPerplexity_API_SECRET = DEV_ENC:Z0FBQUFBQm82Mzk2Q1MwZ0dNcUVBcUtuRDJIcTZkMXVvYnpjM3JEMzJiT1NKSHljX282ZDIyZTJYc09VSTdVNXAtOWU2UXp5S193NTk5dHJsWlFjRjhWektFOG1DVGY4ZUhHTXMzS0RPN1lNcF9nSlVWbW5BZ1hkZDVTejl6bVZNRFVvX29xamJidWRFMmtjQmkyRUQ2RUh6UTN1aWNPSUJBPT0= Connector_AiPerplexity_MODEL_NAME = sonar Connector_AiPerplexity_TEMPERATURE = 0.2 -Connector_AiPerplexity_MAX_TOKENS = 2000 # Agent Mail configuration Service_MSFT_CLIENT_ID = c7e7112d-61dc-4f3a-8cd3-08cc4cd7504c diff --git a/env_int.env b/env_int.env index c69f1f57..f7c35746 100644 --- a/env_int.env +++ b/env_int.env @@ -55,21 +55,18 @@ Connector_AiOpenai_API_URL = https://api.openai.com/v1/chat/completions Connector_AiOpenai_API_SECRET = INT_ENC:Z0FBQUFBQm8xSVRjSDBNYkptSkQxTUotYVVpZVNZc0dxNGNwSEtkOEE0T3RZWjROTEhSRlRXdlZmQUxxZ0w3Y0xOV2JNV19LNF9yTUZiU1pUNG15U2VDUDdSVlI4VlpnR3JXVFFtcXBaTEZiaUtSclVFd0lCZG1rWVhra1dfWTVQOTBEYUU0MjByYVNEMTFmeXNOcmpUT216MmJKdlVPeW5nPT0= Connector_AiOpenai_MODEL_NAME = gpt-4o Connector_AiOpenai_TEMPERATURE = 0.2 -Connector_AiOpenai_MAX_TOKENS = 2000 # Anthropic configuration Connector_AiAnthropic_API_URL = https://api.anthropic.com/v1/messages Connector_AiAnthropic_API_SECRET = INT_ENC:Z0FBQUFBQm8xSVRjT1ZlRWVJdVZMT3ljSFJDcFdxRFBRVkZhS204NnN5RDBlQ0tpenhTM0FFVktuWW9mWHNwRWx2dHB0eDBSZ0JFQnZKWlp6c01pVGREWHd1eGpERnU0Q2xhaks1clQ1ZXVsdnd2ZzhpNXNQS1BhY3FjSkdkVEhHalNaRGR4emhpakZncnpDQUVxOHVXQzVUWmtQc0FsYmFwTF9TSG5FOUFtWk5Ick1NcHFvY2s1T1c2WXlRUFFJZnh6TWhuaVpMYmppcDR0QUx0a0R6RXlwbGRYb1R4dzJkUT09 Connector_AiAnthropic_MODEL_NAME = claude-3-5-sonnet-20241022 Connector_AiAnthropic_TEMPERATURE = 0.2 -Connector_AiAnthropic_MAX_TOKENS = 2000 # Perplexity AI configuration Connector_AiPerplexity_API_URL = https://api.perplexity.ai/chat/completions Connector_AiPerplexity_API_SECRET = INT_ENC:Z0FBQUFBQm82Mzk2UWZJdUFhSW8yc3RKc0tKRXphd0xWMkZOVlFpSGZ4SGhFWnk0cTF5VjlKQVZjdS1QSWdkS0pUSWw4OFU5MjUxdTVQel9aeWVIZTZ5TXRuVmFkZG0zWEdTOGdHMHpsTzI0TGlWYURKU1Q0VVpKTlhxUk5FTmN6SUJScDZ3ZldIaUJZcWpaQVRiSEpyQm9tRTNDWk9KTnZBPT0= Connector_AiPerplexity_MODEL_NAME = sonar Connector_AiPerplexity_TEMPERATURE = 0.2 -Connector_AiPerplexity_MAX_TOKENS = 2000 # Agent Mail configuration Service_MSFT_CLIENT_ID = c7e7112d-61dc-4f3a-8cd3-08cc4cd7504c diff --git a/env_prod.env b/env_prod.env index 3d1aa40a..67092526 100644 --- a/env_prod.env +++ b/env_prod.env @@ -55,21 +55,18 @@ Connector_AiOpenai_API_URL = https://api.openai.com/v1/chat/completions Connector_AiOpenai_API_SECRET = PROD_ENC:Z0FBQUFBQm8xSU5pU05XM2hMaExPMnpYeFpwRVhyYl9JZmRITmlmRDlWOUJSSWE4NTFLZUptSkJhNlEycHBLZmh3WFA2ZmU5VmxHZks1UUNVOUZnckZNdXZ2MTY2dFg1Nl8yWDRrcTRlT0tHYkhyRGZINTEzU25iYVFRMzJGeUZIdlc4LU9GbmpQYmtmU3lJT2VVZ1UzLVd3R25ZQ092SUVnPT0= Connector_AiOpenai_MODEL_NAME = gpt-4o Connector_AiOpenai_TEMPERATURE = 0.2 -Connector_AiOpenai_MAX_TOKENS = 2000 # Anthropic configuration Connector_AiAnthropic_API_URL = https://api.anthropic.com/v1/messages Connector_AiAnthropic_API_SECRET = PROD_ENC:Z0FBQUFBQm8xSU5pNTA1RkZ3UllCOXVsNVZzbkw2Rkl1TWxCZ0wwWEVXUm9ReUhBcVl1cGFUdW9FRVh4elVxR0x3NVRxZkc4SkxHVFdzSU1YNG5Rb0FqSHJhdElwWm1iLWdubTVDcUl3UkVjVHNoU0xLa0ZTSFlfTlJUVXg4cVVwUWdlVDBTSFU5SnBzS0ZnVjlQcmtiNzV2UTNMck1IakZ0OWlubUtlWDZnMk4yX2JsZ1U4Wm1yT29fM2d2NVBNOWNBbWtTRWNyQ2tZNjhwSVF6bG5SU3dTenR2MzA3Z19NUT09 Connector_AiAnthropic_MODEL_NAME = claude-3-5-sonnet-20241022 Connector_AiAnthropic_TEMPERATURE = 0.2 -Connector_AiAnthropic_MAX_TOKENS = 2000 # Perplexity AI configuration Connector_AiPerplexity_API_URL = https://api.perplexity.ai/chat/completions Connector_AiPerplexity_API_SECRET = PROD_ENC:Z0FBQUFBQm82Mzk2Q1FGRkJEUkI4LXlQbHYzT2RkdVJEcmM4WGdZTWpJTEhoeUF1NW5LUVpJdDBYN3k1WFN4a2FQSWJSQmd0U0xJbzZDTmFFN05FcXl0Z3V1OEpsZjYydV94TXVjVjVXRTRYSWdLMkd5XzZIbFV6emRCZHpuOUpQeThadE5xcDNDVGV1RHJrUEN0c1BBYXctZFNWcFRuVXhRPT0= Connector_AiPerplexity_MODEL_NAME = sonar Connector_AiPerplexity_TEMPERATURE = 0.2 -Connector_AiPerplexity_MAX_TOKENS = 2000 # Agent Mail configuration Service_MSFT_CLIENT_ID = c7e7112d-61dc-4f3a-8cd3-08cc4cd7504c diff --git a/modules/connectors/connectorAiAnthropic.py b/modules/connectors/connectorAiAnthropic.py index e7eb07a2..85cf62f2 100644 --- a/modules/connectors/connectorAiAnthropic.py +++ b/modules/connectors/connectorAiAnthropic.py @@ -15,7 +15,6 @@ def loadConfigData(): "apiUrl": APP_CONFIG.get('Connector_AiAnthropic_API_URL'), "modelName": APP_CONFIG.get('Connector_AiAnthropic_MODEL_NAME'), "temperature": float(APP_CONFIG.get('Connector_AiAnthropic_TEMPERATURE')), - "maxTokens": int(APP_CONFIG.get('Connector_AiAnthropic_MAX_TOKENS')) } class AiAnthropic: @@ -60,8 +59,8 @@ class AiAnthropic: if temperature is None: temperature = self.config.get("temperature", 0.2) - if maxTokens is None: - maxTokens = self.config.get("maxTokens", 2000) + # Don't set maxTokens from config - let the model use its full context length + # Our continuation system handles stopping early via prompt engineering # Transform OpenAI-style messages to Anthropic format: # - Move any 'system' role content to top-level 'system' @@ -105,8 +104,11 @@ class AiAnthropic: "model": self.modelName, "messages": converted_messages, "temperature": temperature, - "max_tokens": maxTokens, } + + # Only add max_tokens if it's explicitly set + if maxTokens is not None: + payload["max_tokens"] = maxTokens if system_prompt: payload["system"] = system_prompt diff --git a/modules/connectors/connectorAiOpenai.py b/modules/connectors/connectorAiOpenai.py index 692fe422..c768888e 100644 --- a/modules/connectors/connectorAiOpenai.py +++ b/modules/connectors/connectorAiOpenai.py @@ -19,7 +19,6 @@ def loadConfigData(): "apiUrl": APP_CONFIG.get('Connector_AiOpenai_API_URL'), "modelName": APP_CONFIG.get('Connector_AiOpenai_MODEL_NAME'), "temperature": float(APP_CONFIG.get('Connector_AiOpenai_TEMPERATURE')), - "maxTokens": int(APP_CONFIG.get('Connector_AiOpenai_MAX_TOKENS')) } class AiOpenai: @@ -62,16 +61,19 @@ class AiOpenai: if temperature is None: temperature = self.config.get("temperature", 0.2) - if maxTokens is None: - maxTokens = self.config.get("maxTokens", 2000) + # Don't set maxTokens from config - let the model use its full context length + # Our continuation system handles stopping early via prompt engineering payload = { "model": self.modelName, "messages": messages, - "temperature": temperature, - "max_tokens": maxTokens + "temperature": temperature } + # Only add max_tokens if it's explicitly set + if maxTokens is not None: + payload["max_tokens"] = maxTokens + response = await self.httpClient.post( self.apiUrl, json=payload @@ -161,13 +163,12 @@ class AiOpenai: # Use parameters from configuration temperature = self.config.get("temperature", 0.2) - maxTokens = self.config.get("maxTokens", 2000) + # Don't set maxTokens - let the model use its full context length payload = { "model": visionModel, "messages": messages, - "temperature": temperature, - "max_tokens": maxTokens + "temperature": temperature } response = await self.httpClient.post( diff --git a/modules/connectors/connectorAiPerplexity.py b/modules/connectors/connectorAiPerplexity.py index fc97b885..b075a84d 100644 --- a/modules/connectors/connectorAiPerplexity.py +++ b/modules/connectors/connectorAiPerplexity.py @@ -15,7 +15,6 @@ def loadConfigData(): "apiUrl": APP_CONFIG.get('Connector_AiPerplexity_API_URL'), "modelName": APP_CONFIG.get('Connector_AiPerplexity_MODEL_NAME'), "temperature": float(APP_CONFIG.get('Connector_AiPerplexity_TEMPERATURE')), - "maxTokens": int(APP_CONFIG.get('Connector_AiPerplexity_MAX_TOKENS')) } class AiPerplexity: @@ -60,16 +59,19 @@ class AiPerplexity: if temperature is None: temperature = self.config.get("temperature", 0.2) - if maxTokens is None: - maxTokens = self.config.get("maxTokens", 2000) + # Don't set maxTokens from config - let the model use its full context length + # Our continuation system handles stopping early via prompt engineering payload = { "model": self.modelName, "messages": messages, - "temperature": temperature, - "max_tokens": maxTokens + "temperature": temperature } + # Only add max_tokens if it's explicitly set + if maxTokens is not None: + payload["max_tokens"] = maxTokens + response = await self.httpClient.post( self.apiUrl, json=payload @@ -116,8 +118,8 @@ class AiPerplexity: if temperature is None: temperature = self.config.get("temperature", 0.2) - if maxTokens is None: - maxTokens = self.config.get("maxTokens", 2000) + # Don't set maxTokens from config - let the model use its full context length + # Our continuation system handles stopping early via prompt engineering # For web search, we use the configured model name webSearchModel = self.modelName @@ -130,10 +132,13 @@ class AiPerplexity: "content": query } ], - "temperature": temperature, - "max_tokens": maxTokens + "temperature": temperature } + # Only add max_tokens if it's explicitly set + if maxTokens is not None: + payload["max_tokens"] = maxTokens + response = await self.httpClient.post( self.apiUrl, json=payload diff --git a/modules/interfaces/interfaceAiObjects.py b/modules/interfaces/interfaceAiObjects.py index d272fc41..2e9b6b38 100644 --- a/modules/interfaces/interfaceAiObjects.py +++ b/modules/interfaces/interfaceAiObjects.py @@ -512,16 +512,9 @@ class AiObjects: if temperature is None: temperature = 0.2 maxTokens = getattr(options, "maxTokens", None) - # Provide a generous default to avoid truncation for long outputs - if maxTokens is None: - # If resultFormat suggests large outputs (e.g., html, json), allow more tokens - wants_large = str(getattr(options, "resultFormat", "")).lower() in ["html", "json", "md", "markdown"] - maxTokens = 8000 if wants_large else 2000 + # Don't set artificial limits - let the model use its full context length + # Our continuation system handles stopping early via prompt engineering - messages: List[Dict[str, Any]] = [] - if context: - messages.append({"role": "system", "content": f"Context from documents:\n{context}"}) - messages.append({"role": "user", "content": prompt}) # Get fallback models for this operation type fallbackModels = self._getFallbackModels(options.operationType) @@ -532,6 +525,26 @@ class AiObjects: try: logger.info(f"Attempting AI call with model: {modelName} (attempt {attempt + 1}/{len(fallbackModels)})") + # Store the selected model for token limit resolution + self._lastSelectedModel = modelName + + # Replace placeholder in prompt and context if present + context_length = aiModels[modelName].get("contextLength", 0) + if context_length > 0: + token_limit = str(context_length) + else: + token_limit = "4000" # Default for text generation + + if "" in prompt: + prompt = prompt.replace("", token_limit) + logger.debug(f"Replaced with {token_limit} for model {modelName}") + + # Update messages array with replaced content + messages = [] + if context: + messages.append({"role": "system", "content": f"Context from documents:\n{context}"}) + messages.append({"role": "user", "content": prompt}) + # Start timing startTime = time.time() diff --git a/modules/services/serviceAi/subCoreAi.py b/modules/services/serviceAi/subCoreAi.py index 6e8555a6..6afaafce 100644 --- a/modules/services/serviceAi/subCoreAi.py +++ b/modules/services/serviceAi/subCoreAi.py @@ -6,12 +6,27 @@ from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, ModelCa logger = logging.getLogger(__name__) + # Loop instruction texts for different formats LoopInstructionTexts = { "json": """ -CRITICAL: -- If content is too long: deliver a valid partial JSON and set "continuation" to briefly describe the remaining content -- If content fits: deliver complete result and set \"continuation\": null +CRITICAL LIMITS: tokens total (reserve 20% for JSON structure) + +MANDATORY RULES: +1. STOP at approximately 80% of limit to ensure valid JSON completion +2. Return ONLY raw JSON (no ```json blocks, no text before/after) +3. ALWAYS include "continuation" field - this is MANDATORY + +CONTINUATION REQUIREMENTS: +- If you can complete the full request: {"continuation": null} +- If you must stop early: { + "continuation": { + "last_data_items": "exact last items you generated (copy them exactly)", + "next_instruction": "Continue from [exact last item] - generate next items" + } +} + +BE CONSERVATIVE: Stop generating content when you reach approximately 3200-3500 characters to ensure JSON completion. """, # Add more formats here as needed # "xml": "...", @@ -32,6 +47,8 @@ class SubCoreAi: self.services = services self.aiObjects = aiObjects + + # Shared Core Function for AI Calls with Looping async def _callAiWithLooping( self, @@ -70,6 +87,7 @@ class SubCoreAi: logger.error(f"Unsupported loopInstructionFormat for prompt: {loopInstructionFormat}") loopInstruction = "" + while iteration < max_iterations: iteration += 1 logger.debug(f"AI call iteration {iteration}/{max_iterations}") @@ -171,33 +189,36 @@ class SubCoreAi: """ Build continuation content for follow-up iterations. """ - # Extract continuation description from the last response if it exists + # Extract continuation description from the last response continuation_description = "" if accumulatedContent: try: last_response = accumulatedContent[-1] - parsed_response = json.loads(last_response) - if isinstance(parsed_response, dict) and parsed_response.get("continuation"): - continuation_description = parsed_response["continuation"] - except (json.JSONDecodeError, KeyError): + # Use the same JSON extraction logic as the main loop + extracted = self.services.utils.jsonExtractString(last_response) + parsed_response = json.loads(extracted) + if isinstance(parsed_response, dict): + # Check for continuation at root level or in metadata + continuation = parsed_response.get("continuation") + if continuation is None and "metadata" in parsed_response: + continuation = parsed_response["metadata"].get("continuation") + + if continuation: + continuation_description = continuation + except (json.JSONDecodeError, KeyError, ValueError): pass continuation_content = f"""CONTINUATION REQUEST (Iteration {iteration}): -Continue generating content from where you left off. +You are continuing a previous response. DO NOT repeat any previous content. -IMPORTANT: -- Maintain the same JSON structure -- Continue from the exact point where you stopped -- If you can complete the remaining content: set "continuation": null -- If content is still too long, deliver next part and set "continuation": "description of remaining content" +{f"CONTINUATION INSTRUCTIONS: {continuation_description}" if continuation_description else "No specific continuation instructions provided."} -Previous content: -{chr(10).join(accumulatedContent[-1:]) if accumulatedContent else "None"} - -{f"Continuation needed: {continuation_description}" if continuation_description else ""} - -Continue generating content now.""" - +CRITICAL REQUIREMENTS: +- Start from the exact point specified in continuation instructions +- DO NOT repeat any previous content +- BE CONSERVATIVE: Stop at approximately 3200-3500 characters to ensure JSON completion +- ALWAYS include continuation field - set to null if complete, or provide next instruction if incomplete +""" return continuation_content def _mergeJsonContent(self, accumulatedContent: List[str]) -> str: @@ -282,9 +303,7 @@ Continue generating content now.""" generation_prompt = await buildGenerationPrompt( outputFormat=outputFormat, userPrompt=prompt, - title=title, - aiService=self, - services=self.services + title=title ) # If we have extracted content, prepend it to the prompt @@ -537,14 +556,6 @@ Continue generating content now.""" return "text" - - - -# TO CHECK FUNCTIONS TODO - - - - def _getModelCapabilitiesForContent(self, prompt: str, documents: Optional[List[ChatDocument]], options: AiCallOptions) -> Dict[str, int]: """ Get model capabilities for content processing, including appropriate size limits for chunking. @@ -635,15 +646,6 @@ Continue generating content now.""" return full_prompt - def _exceedsTokenLimit(self, text: str, model: ModelCapabilities, safety_margin: float) -> bool: - """ - Check if text exceeds model token limit with safety margin. - """ - # Simple character-based estimation (4 chars per token) - estimated_tokens = len(text) // 4 - max_tokens = int(model.maxTokens * (1 - safety_margin)) - return estimated_tokens > max_tokens - def _reducePlanningPrompt( self, full_prompt: str, diff --git a/modules/services/serviceGeneration/renderers/rendererImage.py b/modules/services/serviceGeneration/renderers/rendererImage.py index 9da52466..f47dd54d 100644 --- a/modules/services/serviceGeneration/renderers/rendererImage.py +++ b/modules/services/serviceGeneration/renderers/rendererImage.py @@ -180,7 +180,7 @@ Return only the compressed prompt, no explanations. prompt=compression_prompt, options=AiCallOptions( operationType=OperationType.GENERAL, - maxTokens=2000, + maxTokens=None, # Let the model use its full context length temperature=0.3 # Lower temperature for more consistent compression ) ) diff --git a/modules/services/serviceGeneration/subPromptBuilder.py b/modules/services/serviceGeneration/subPromptBuilder.py index 37b97917..8f4afdb4 100644 --- a/modules/services/serviceGeneration/subPromptBuilder.py +++ b/modules/services/serviceGeneration/subPromptBuilder.py @@ -17,6 +17,36 @@ else: logger = logging.getLogger(__name__) +# Centralized JSON structure template for document generation +JSON_STRUCTURE_TEMPLATE = """{ + "continuation": null, + "metadata": { + "title": "{{DOCUMENT_TITLE}}", + "splitStrategy": "single_document", + "source_documents": [], + "extraction_method": "ai_generation" + }, + "documents": [{ + "id": "doc_1", + "title": "{{DOCUMENT_TITLE}}", + "filename": "document.json", + "sections": [ + { + "id": "section_1", + "content_type": "heading|paragraph|table|list|code", + "elements": [ + // heading: {"level": 1, "text": "..."} + // paragraph: {"text": "..."} + // table: {"headers": [...], "rows": [[...]], "caption": "..."} + // list: {"items": [{"text": "...", "subitems": [...]}], "list_type": "bullet|numbered"} + // code: {"code": "...", "language": "..."} + ], + "order": 1 + } + ] + }] +}""" + async def buildAdaptiveExtractionPrompt( outputFormat: str, userPrompt: str, @@ -149,6 +179,9 @@ async def buildGenerationPrompt( title: str ) -> str: """Build the unified generation prompt using a single JSON template.""" + # Create a template with the actual title + json_template = JSON_STRUCTURE_TEMPLATE.replace("{{DOCUMENT_TITLE}}", title) + # Always use the proper generation prompt template with LOOP_INSTRUCTION result = f"""Generate structured JSON content for document creation. @@ -156,94 +189,15 @@ USER REQUEST: "{userPrompt}" DOCUMENT TITLE: "{title}" TARGET FORMAT: {outputFormat} -Return ONLY valid JSON matching this structure (template below). Do not include any prose before/after. Use this as the single template reference for your output: -{{ - "metadata": {{ - "title": "{title}", - "splitStrategy": "single_document", - "source_documents": [], - "extraction_method": "ai_generation" - }}, - "documents": [ - {{ - "id": "doc_1", - "title": "{title}", - "filename": "document.{outputFormat}", - "sections": [ - {{ - "id": "section_1", - "content_type": "heading", - "elements": [ - {{ - "level": 1, - "text": "1. SECTION TITLE" - }} - ], - "order": 1 - }}, - {{ - "id": "section_2", - "content_type": "paragraph", - "elements": [ - {{ - "text": "This is the actual content that should be generated." - }} - ], - "order": 2 - }}, - {{ - "id": "section_3", - "content_type": "table", - "elements": [ - {{ - "headers": ["Column 1", "Column 2", "Column 3"], - "rows": [ - ["R1C1", "R1C2", "R1C3"], - ["R2C1", "R2C2", "R2C3"] - ], - "caption": "Example table" - }} - ], - "order": 3 - }}, - {{ - "id": "section_4", - "content_type": "list", - "elements": [ - {{ - "items": [ - {{ "text": "First item" }}, - {{ "text": "Second item", "subitems": [{{ "text": "Second.1" }}] }}, - {{ "text": "Third item" }} - ], - "list_type": "bullet" - }} - ], - "order": 4 - }}, - {{ - "id": "section_5", - "content_type": "code", - "elements": [ - {{ - "code": "print('Hello World')", - "language": "python" - }} - ], - "order": 5 - }} - ] - }} - ], - "continuation": null -}} - RULES: -- Follow the template structure above exactly; emit only one JSON object in the response +- Follow the template structure below exactly; emit only one JSON object in the response - Fill sections with content based on the user request -- Use appropriate content_type: "heading", "paragraph", "table", "list" +- Use appropriate content_type LOOP_INSTRUCTION + +Return ONLY valid JSON matching this structure (template below). Do not include any prose before/after. Use this as the single template reference for your output: +{json_template} """ return result.strip() diff --git a/modules/services/serviceWorkflow/mainServiceWorkflow.py b/modules/services/serviceWorkflow/mainServiceWorkflow.py index 5eae45c2..c30028b9 100644 --- a/modules/services/serviceWorkflow/mainServiceWorkflow.py +++ b/modules/services/serviceWorkflow/mainServiceWorkflow.py @@ -18,6 +18,7 @@ class WorkflowService: self.interfaceDbChat = serviceCenter.interfaceDbChat self.interfaceDbComponent = serviceCenter.interfaceDbComponent self.interfaceDbApp = serviceCenter.interfaceDbApp + self._progressLogger = None async def summarizeChat(self, messages: List[ChatMessage]) -> str: """ @@ -917,20 +918,32 @@ Please provide a comprehensive summary of this conversation.""" logger.error(f"Error getting connection reference list: {str(e)}") return [] + def setWorkflowContext(self, workflow): + """Set the current workflow context for this service""" + self.workflow = workflow + # Reset progress logger for new workflow + self._progressLogger = None + + def _getProgressLogger(self): + """Get or create the progress logger instance""" + if self._progressLogger is None: + self._progressLogger = ProgressLogger(self, self.workflow) + return self._progressLogger + def createProgressLogger(self, workflow) -> ProgressLogger: return ProgressLogger(self, workflow) def progressLogStart(self, operationId: str, serviceName: str, actionName: str, context: str = ""): """Wrapper for ProgressLogger.startOperation""" - progressLogger = self.createProgressLogger(self.workflow) + progressLogger = self._getProgressLogger() return progressLogger.startOperation(operationId, serviceName, actionName, context) def progressLogUpdate(self, operationId: str, progress: float, statusUpdate: str = ""): """Wrapper for ProgressLogger.updateOperation""" - progressLogger = self.createProgressLogger(self.workflow) + progressLogger = self._getProgressLogger() return progressLogger.updateOperation(operationId, progress, statusUpdate) def progressLogFinish(self, operationId: str, success: bool = True): """Wrapper for ProgressLogger.finishOperation""" - progressLogger = self.createProgressLogger(self.workflow) + progressLogger = self._getProgressLogger() return progressLogger.finishOperation(operationId, success) \ No newline at end of file diff --git a/modules/workflows/processing/core/messageCreator.py b/modules/workflows/processing/core/messageCreator.py index 6a195c16..bbae5610 100644 --- a/modules/workflows/processing/core/messageCreator.py +++ b/modules/workflows/processing/core/messageCreator.py @@ -142,12 +142,12 @@ class MessageCreator: # Build a user-friendly message based on success/failure if result.success: - messageText = f"**Action {currentAction}/{totalActions} ({action.execMethod}.{action.execAction})**\n\n" + messageText = f"**Action {currentAction} ({action.execMethod}.{action.execAction})**\n\n" messageText += f"βœ… {taskObjective}\n\n" else: # ⚠️ FAILURE MESSAGE - Show error details to user errorDetails = result.error if result.error else "Unknown error occurred" - messageText = f"**Action {currentAction}/{totalActions} ({action.execMethod}.{action.execAction})**\n\n" + messageText = f"**Action {currentAction} ({action.execMethod}.{action.execAction})**\n\n" messageText += f"❌ {taskObjective}\n\n" messageText += f"{errorDetails}\n\n" @@ -195,7 +195,7 @@ class MessageCreator: self._checkWorkflowStopped(workflow) # Create a task completion message for the user - taskProgress = f"{taskIndex}/{totalTasks}" if totalTasks is not None else str(taskIndex) + taskProgress = str(taskIndex) # Enhanced completion message with criteria details completionMessage = f"🎯 **Task {taskProgress}**\n\nβœ… {reviewResult.reason or 'Task completed successfully'}" diff --git a/modules/workflows/processing/modes/modeActionplan.py b/modules/workflows/processing/modes/modeActionplan.py index bbea997d..aa04a070 100644 --- a/modules/workflows/processing/modes/modeActionplan.py +++ b/modules/workflows/processing/modes/modeActionplan.py @@ -270,8 +270,6 @@ class ActionplanMode(BaseMode): actionNumber = actionIdx + 1 self._updateWorkflowBeforeExecutingAction(actionNumber) - # Update workflow context for this action - self.services.workflow.setWorkflowContext(action_number=actionNumber) # Log action start logger.info(f"Task {taskIndex} - Starting action {actionNumber}/{totalActions}") @@ -280,7 +278,7 @@ class ActionplanMode(BaseMode): actionStartMessage = { "workflowId": workflow.id, "role": "assistant", - "message": f"⚑ **Action {actionNumber}/{totalActions}** (Method {action.execMethod}.{action.execAction})", + "message": f"⚑ **Action {actionNumber}** (Method {action.execMethod}.{action.execAction})", "status": "step", "sequenceNr": len(workflow.messages) + 1, "publishedAt": self.services.utils.timestampGetUtc(), diff --git a/modules/workflows/processing/modes/modeReact.py b/modules/workflows/processing/modes/modeReact.py index 9aea47ea..4b581116 100644 --- a/modules/workflows/processing/modes/modeReact.py +++ b/modules/workflows/processing/modes/modeReact.py @@ -66,11 +66,7 @@ class ReactMode(BaseMode): # Update workflow object before executing task if taskIndex is not None: self._updateWorkflowBeforeExecutingTask(taskIndex) - - # Update workflow context for this task - if taskIndex is not None: - self.services.workflow.setWorkflowContext(task_number=taskIndex) - + # Create task start message await self.messageCreator.createTaskStartMessage(taskStep, workflow, taskIndex, totalTasks) @@ -87,7 +83,6 @@ class ReactMode(BaseMode): # Update workflow[currentAction] for UI self._updateWorkflowBeforeExecutingAction(step) - self.services.workflow.setWorkflowContext(action_number=step) try: t0 = time.time() @@ -309,7 +304,7 @@ class ReactMode(BaseMode): maxCost=0.05, maxProcessingTime=30, temperature=0.3, # Slightly higher temperature for better instruction following - # maxTokens not set - use model's maximum for big JSON responses + # max tokens not set - use model's maximum for big JSON responses resultFormat="json" # Explicitly request JSON format ) @@ -670,7 +665,7 @@ class ReactMode(BaseMode): if messageType == "before": # Message BEFORE action execution userMessage = await self._generateActionIntentionMessage(method, actionName, userLanguage) - messageContent = f"πŸ”„ **Step {step}/{maxSteps}**\n\n{userMessage}" + messageContent = f"πŸ”„ **Step {step}**\n\n{userMessage}" status = "step" actionProgress = "pending" documentsLabel = f"action_{step}_intention" @@ -679,7 +674,7 @@ class ReactMode(BaseMode): # Message AFTER action execution userMessage = await self._generateActionResultMessage(method, actionName, result, observation, userLanguage) successIcon = "βœ…" if result and result.success else "❌" - messageContent = f"{successIcon} **Step {step}/{maxSteps} Complete**\n\n{userMessage}" + messageContent = f"{successIcon} **Step {step} Complete**\n\n{userMessage}" status = "step" actionProgress = "success" if result and result.success else "fail" documentsLabel = observation.get('resultLabel') if observation else f"action_{step}_result" diff --git a/modules/workflows/workflowManager.py b/modules/workflows/workflowManager.py index 6c227f10..f32f8ad4 100644 --- a/modules/workflows/workflowManager.py +++ b/modules/workflows/workflowManager.py @@ -139,6 +139,10 @@ class WorkflowManager: # Store the current user prompt in services for easy access throughout the workflow self.services.rawUserPrompt = userInput.prompt self.services.currentUserPrompt = userInput.prompt + + # Update the workflow service with the current workflow context + self.services.workflow.setWorkflowContext(workflow) + self.workflowProcessor = WorkflowProcessor(self.services, workflow) await self._sendFirstMessage(userInput, workflow) task_plan = await self._planTasks(userInput, workflow) @@ -183,9 +187,6 @@ class WorkflowManager: "actionProgress": "pending" } - # Clear trace log for new workflow session - self.workflowProcessor.clearTraceLog() - # Analyze the user's input to detect language, normalize request, extract intent, and offload bulky context into documents created_docs = [] diff --git a/requirements.txt b/requirements.txt index 293bf0e6..5191019b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ websockets==12.0 uvicorn==0.23.2 python-multipart==0.0.6 httpx>=0.25.2 -pydantic>=2.0.0 # Upgraded to v2 for LangChain compatibility +pydantic>=2.0.0 # Upgraded to v2 for compatibility email-validator==2.0.0 # Required by Pydantic for email validation slowapi==0.1.8 # For rate limiting @@ -59,7 +59,6 @@ Pillow>=10.0.0 # FΓΌr Bildverarbeitung (als PIL importiert) python-dateutil==2.8.2 python-dotenv==1.0.0 pytz>=2023.3 # For timezone handling and UTC operations -anyio>=4.2.0 # Used by chatbot tools and async utilities ## Dependencies for trio (used by httpx) sortedcontainers>=2.4.0 # Required by trio @@ -72,17 +71,6 @@ google-cloud-texttospeech==2.16.3 ## MSFT Integration msal==1.24.1 -# Enhanced Office document processing -python-docx>=0.8.11 -openpyxl>=3.0.9 -python-pptx>=0.6.21 -xlrd>=2.0.1 # For legacy .xls files -Pillow>=9.0.0 # For image processing -PyPDF2>=3.0.0 -PyMuPDF>=1.20.0 -beautifulsoup4>=4.11.0 -chardet>=4.0.0 # For encoding detection - ## Testing Dependencies pytest>=8.0.0 pytest-asyncio>=0.21.0 @@ -111,15 +99,4 @@ xyzservices>=2021.09.1 # PostgreSQL connector dependencies psycopg2-binary==2.9.9 -asyncpg==0.30.0 - -## LangChain & LangGraph -langchain==0.3.27 -langgraph==0.6.8 -langchain-core==0.3.77 -langchain-anthropic==0.3.1 # For Claude models -psycopg[binary]==3.2.1 # For PostgreSQL async support (LangGraph checkpointer) -psycopg-pool==3.2.1 # Connection pooling for PostgreSQL -langgraph-checkpoint-postgres==2.0.24 - -greenlet==3.2.4 \ No newline at end of file +asyncpg==0.30.0 \ No newline at end of file diff --git a/test_ai_behavior.py b/test_ai_behavior.py new file mode 100644 index 00000000..cd6d1b32 --- /dev/null +++ b/test_ai_behavior.py @@ -0,0 +1,365 @@ +#!/usr/bin/env python3 +""" +AI Behavior Test - Tests actual AI responses with different prompt structures +""" + +import asyncio +import json +import sys +import os +from typing import Dict, Any, List + +# Add the gateway to path +sys.path.append(os.path.dirname(__file__)) + +# Import the service initialization +from modules.features.chatPlayground.mainChatPlayground import getServices +from modules.datamodels.datamodelAi import AiCallOptions, OperationType +from modules.datamodels.datamodelUam import User + +# The test uses the AI service which handles JSON template internally + +class AIBehaviorTester: + def __init__(self): + # Create a minimal user context for testing + testUser = User( + id="test_user", + username="test_user", + email="test@example.com", + fullName="Test User", + language="en", + mandateId="test_mandate" + ) + + # Initialize services using the existing system + self.services = getServices(testUser, None) # Test user, no workflow + self.testResults = [] + + async def initialize(self): + """Initialize the AI service.""" + # Set logging level to DEBUG to see debug messages + import logging + logging.getLogger().setLevel(logging.DEBUG) + + # The AI service needs to be recreated with proper initialization + from modules.services.serviceAi.mainServiceAi import AiService + self.services.ai = await AiService.create(self.services) + + # Create a minimal workflow context + from modules.datamodels.datamodelChat import ChatWorkflow + import uuid + + self.services.currentWorkflow = ChatWorkflow( + id=str(uuid.uuid4()), + name="Test Workflow", + status="running", + startedAt=self.services.utils.timestampGetUtc(), + lastActivity=self.services.utils.timestampGetUtc(), + currentRound=1, + currentTask=0, + currentAction=0, + totalTasks=0, + totalActions=0, + mandateId="test_mandate", + messageIds=[], + workflowMode="React", + maxSteps=5 + ) + + async def testPromptBehavior(self, promptName: str, prompt: str, maxIterations: int = 2) -> Dict[str, Any]: + """Test actual AI behavior with a specific prompt structure.""" + print(f"\n{'='*60}") + print(f"TESTING AI BEHAVIOR: {promptName}") + print(f"{'='*60}") + + print(f"User prompt: {prompt}") + print(f"Prompt length: {len(prompt)} characters") + + accumulatedContent = [] + + # Use the AI service directly with the user prompt - it will build the generation prompt internally + try: + # Use the existing AI service with JSON format - it handles looping internally + response = await self.services.ai.coreAi.callAiDocuments( + prompt=prompt, # Use the raw user prompt directly + documents=None, + outputFormat="json", + title="Prime Numbers Test", + loopInstructionFormat="json" # Use the JSON loop instructions + ) + + if isinstance(response, dict): + result = json.dumps(response, indent=2) + else: + result = str(response) + + print(f"Response length: {len(result)} characters") + print(f"Response preview: {result[:200]}...") + + # If we got an error response, try to extract the actual AI content from debug files + if isinstance(response, dict) and not response.get("success", True): + # The AI service wrapped the response in an error format + # We need to get the actual AI content from the debug files + print("⚠️ AI returned error response, but may have generated content") + + # Try to read the actual AI response from debug files + debug_content = self._getLatestDebugResponse() + if debug_content: + result = debug_content + print(f"πŸ“„ Found debug content: {len(result)} characters") + print(f"πŸ“„ Debug preview: {result[:200]}...") + + # Parse and analyze response + parsed_result = self._parseJsonResponse(result) + if parsed_result: + # Check if continuation + if parsed_result.get("continuation") is not None: + continuation_text = parsed_result.get("continuation", "") + print(f"βœ… Continuation detected: {continuation_text[:100]}...") + accumulatedContent.append(result) + + # Analyze continuation quality + continuation_quality = self._analyzeContinuationQuality(continuation_text) + print(f" Continuation quality: {continuation_quality['score']}/10") + print(f" Issues: {', '.join(continuation_quality['issues'])}") + else: + print("βœ… Final response received") + accumulatedContent.append(result) + else: + print("❌ Invalid JSON response") + accumulatedContent.append(result) + + except Exception as e: + print(f"❌ Error in AI call: {str(e)}") + accumulatedContent.append("") + + # Analyze results + result = self._analyzeBehaviorResults(promptName, accumulatedContent) + self.testResults.append(result) + return result + + def _extractContinuationInstruction(self, response: str) -> str: + """Extract continuation instruction from response.""" + try: + parsed = json.loads(response) + return parsed.get("continuation", "") + except: + return "" + + + def _getLatestDebugResponse(self) -> str: + """Get the latest AI response from debug files.""" + try: + import glob + import os + + # Look for the most recent debug response file + debug_pattern = "local/logs/debug/prompts/*document_generation_response*.txt" + debug_files = glob.glob(debug_pattern) + + if debug_files: + # Sort by modification time, get the most recent + latest_file = max(debug_files, key=os.path.getmtime) + with open(latest_file, 'r', encoding='utf-8') as f: + return f.read() + return "" + except Exception as e: + print(f"Error reading debug file: {e}") + return "" + + def _parseJsonResponse(self, response: str) -> Dict[str, Any]: + """Parse JSON response.""" + try: + # First try direct JSON parsing + return json.loads(response) + except: + try: + # Try extracting JSON from markdown code blocks + if "```json" in response: + start = response.find("```json") + 7 + end = response.find("```", start) + if end > start: + json_str = response[start:end].strip() + return json.loads(json_str) + elif "```" in response: + start = response.find("```") + 3 + end = response.find("```", start) + if end > start: + json_str = response[start:end].strip() + return json.loads(json_str) + return None + except: + return None + + def _analyzeContinuationQuality(self, continuation_text: str) -> Dict[str, Any]: + """Analyze the quality of continuation instructions.""" + score = 10 + issues = [] + + try: + # Parse the continuation object + if isinstance(continuation_text, str): + continuation_obj = json.loads(continuation_text) + else: + continuation_obj = continuation_text + + # Check for required fields + if not isinstance(continuation_obj, dict): + score -= 5 + issues.append("Not a valid object") + return {"score": max(0, score), "issues": issues} + + # Check for last_data_items + if "last_data_items" not in continuation_obj: + score -= 3 + issues.append("Missing last_data_items") + elif not continuation_obj["last_data_items"]: + score -= 2 + issues.append("Empty last_data_items") + + # Check for next_instruction + if "next_instruction" not in continuation_obj: + score -= 3 + issues.append("Missing next_instruction") + elif not continuation_obj["next_instruction"]: + score -= 2 + issues.append("Empty next_instruction") + + # Check for specific data points in last_data_items + if "last_data_items" in continuation_obj: + last_items = continuation_obj["last_data_items"] + if not any(char.isdigit() for char in str(last_items)): + score -= 1 + issues.append("No specific numbers in last_data_items") + + # Check for clear instruction in next_instruction + if "next_instruction" in continuation_obj: + instruction = continuation_obj["next_instruction"] + if "continue" not in instruction.lower(): + score -= 1 + issues.append("No 'continue' in next_instruction") + + except (json.JSONDecodeError, TypeError): + score -= 5 + issues.append("Invalid JSON format") + + return { + "score": max(0, score), + "issues": issues + } + + def _analyzeBehaviorResults(self, promptName: str, accumulatedContent: List[str]) -> Dict[str, Any]: + """Analyze AI behavior results.""" + totalContentLength = 0 + iterations = len(accumulatedContent) + continuationInstructions = [] + continuationQualities = [] + + for i, content in enumerate(accumulatedContent): + parsed = self._parseJsonResponse(content) + if parsed: + # Count content length in the response + contentLength = len(content) + totalContentLength += contentLength + + continuation = parsed.get("continuation") + if continuation: + continuationInstructions.append(continuation) + quality = self._analyzeContinuationQuality(continuation) + continuationQualities.append(quality) + + # Calculate averages + avgContinuationQuality = sum(q["score"] for q in continuationQualities) / len(continuationQualities) if continuationQualities else 0 + + return { + "promptName": promptName, + "iterations": iterations, + "totalContentLength": totalContentLength, + "continuationInstructions": continuationInstructions, + "avgContinuationQuality": avgContinuationQuality, + "success": totalContentLength > 0, + "efficiency": totalContentLength / iterations if iterations > 0 else 0 + } + + def _countPrimesInResponse(self, parsed: Dict[str, Any]) -> int: + """Count prime numbers in the parsed response.""" + count = 0 + + if "documents" in parsed: + for doc in parsed["documents"]: + if "sections" in doc: + for section in doc["sections"]: + if section.get("content_type") == "table" and "elements" in section: + for element in section["elements"]: + if "rows" in element: + for row in element["rows"]: + for cell in row: + if isinstance(cell, (str, int)) and str(cell).isdigit(): + count += 1 + + return count + + def printBehaviorResults(self): + """Print AI behavior test results.""" + print(f"\n{'='*80}") + print("AI BEHAVIOR TEST RESULTS") + print(f"{'='*80}") + + for result in self.testResults: + print(f"\n{result['promptName']}:") + print(f" Iterations: {result['iterations']}") + print(f" Total Content Length: {result['totalContentLength']}") + print(f" Efficiency: {result['efficiency']:.1f} chars/iteration") + print(f" Avg Continuation Quality: {result['avgContinuationQuality']:.1f}/10") + print(f" Success: {'βœ…' if result['success'] else '❌'}") + + if result['continuationInstructions']: + print(f" Continuation Instructions:") + for i, instruction in enumerate(result['continuationInstructions']): + print(f" {i+1}: {instruction[:80]}...") + + # Find best performing prompt + if self.testResults: + bestEfficiency = max(self.testResults, key=lambda x: x['efficiency']) + bestQuality = max(self.testResults, key=lambda x: x['avgContinuationQuality']) + + print(f"\n{'='*80}") + print("BEST PERFORMERS") + print(f"{'='*80}") + print(f"πŸ† Best Efficiency: {bestEfficiency['promptName']} ({bestEfficiency['efficiency']:.1f} chars/iteration)") + print(f"🎯 Best Continuation Quality: {bestQuality['promptName']} ({bestQuality['avgContinuationQuality']:.1f}/10)") + +# Test prompt scenarios for GENERIC continuation behavior +# These test different approaches to handle ANY user prompt and ANY data type +PROMPT_SCENARIOS = { + "Prime Numbers Test": """Generate the first 5000 prime numbers in a table with 10 columns per row.""", + + "Fibonacci Sequence": """Generate the first 1000 Fibonacci numbers in a table with 5 columns per row.""", + + "Multiplication Table": """Generate multiplication tables from 1 to 50, each table with 10 columns per row.""", + + "Random Data": """Generate 2000 random numbers between 1 and 10000 in a table with 8 columns per row.""", + + "Text Content": """Generate a comprehensive guide about machine learning with 50 sections, each containing detailed explanations and examples.""" +} + +async def main(): + """Run AI behavior testing.""" + tester = AIBehaviorTester() + + print("Starting AI Behavior Testing...") + print("Initializing AI service...") + await tester.initialize() + + print(f"Testing {len(PROMPT_SCENARIOS)} different prompt scenarios") + + for promptName, prompt in PROMPT_SCENARIOS.items(): + try: + await tester.testPromptBehavior(promptName, prompt, maxIterations=2) + except Exception as e: + print(f"❌ Failed to test {promptName}: {str(e)}") + + tester.printBehaviorResults() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/features/chatBot/utils/test_toolRegistry.py b/tests/features/chatBot/utils/test_toolRegistry.py deleted file mode 100644 index 219752b7..00000000 --- a/tests/features/chatBot/utils/test_toolRegistry.py +++ /dev/null @@ -1,198 +0,0 @@ -"""Pytest tests for the tool registry. - -This module tests that the tool registry correctly discovers and catalogs -all tools in the chatbotTools directory. -""" - -import logging -import pytest -from modules.features.chatBot.utils.toolRegistry import ( - ToolMetadata, - ToolRegistry, - get_registry, - reinitialize_registry, -) -from langchain_core.tools import BaseTool - -logger = logging.getLogger(__name__) - - -class TestToolRegistry: - """Test suite for ToolRegistry class.""" - - @pytest.fixture - def registry(self) -> ToolRegistry: - """Provide a fresh registry instance for each test.""" - return reinitialize_registry() - - def test_registry_initialization(self, registry: ToolRegistry) -> None: - """Test that registry initializes correctly.""" - assert registry.is_initialized - assert isinstance(registry._tools, dict) - - def test_get_all_tools(self, registry: ToolRegistry) -> None: - """Test getting all registered tools.""" - all_tools = registry.get_all_tools() - assert isinstance(all_tools, list) - assert len(all_tools) > 0 - assert all(isinstance(tool, ToolMetadata) for tool in all_tools) - - # Log all discovered tools - logger.info(f"Found {len(all_tools)} tools in registry:") - for tool in all_tools: - logger.info(f"\n{tool}") - - def test_tool_metadata_structure(self, registry: ToolRegistry) -> None: - """Test that tool metadata has correct structure.""" - all_tools = registry.get_all_tools() - for tool in all_tools: - assert isinstance(tool.tool_id, str) - assert isinstance(tool.name, str) - assert isinstance(tool.category, str) - assert tool.category in ["shared", "customer"] - assert isinstance(tool.description, str) - assert isinstance(tool.tool_instance, BaseTool) - assert isinstance(tool.module_path, str) - - def test_list_tool_ids(self, registry: ToolRegistry) -> None: - """Test listing all tool IDs.""" - tool_ids = registry.list_tool_ids() - assert isinstance(tool_ids, list) - assert len(tool_ids) > 0 - assert all(isinstance(tool_id, str) for tool_id in tool_ids) - - # Check that tool IDs follow expected format - for tool_id in tool_ids: - assert "." in tool_id - category, name = tool_id.split(".", 1) - assert category in ["shared", "customer"] - - def test_get_specific_tool(self, registry: ToolRegistry) -> None: - """Test retrieving a specific tool by ID.""" - # Get all tool IDs first - tool_ids = registry.list_tool_ids() - if tool_ids: - # Test with first available tool - test_tool_id = tool_ids[0] - tool_metadata = registry.get_tool(tool_id=test_tool_id) - - assert tool_metadata is not None - assert isinstance(tool_metadata, ToolMetadata) - assert tool_metadata.tool_id == test_tool_id - - def test_get_nonexistent_tool(self, registry: ToolRegistry) -> None: - """Test retrieving a tool that doesn't exist.""" - tool_metadata = registry.get_tool(tool_id="nonexistent.tool") - assert tool_metadata is None - - def test_get_tools_by_category_shared(self, registry: ToolRegistry) -> None: - """Test getting all shared tools.""" - shared_tools = registry.get_tools_by_category(category="shared") - assert isinstance(shared_tools, list) - assert all(tool.category == "shared" for tool in shared_tools) - - def test_get_tools_by_category_customer(self, registry: ToolRegistry) -> None: - """Test getting all customer tools.""" - customer_tools = registry.get_tools_by_category(category="customer") - assert isinstance(customer_tools, list) - assert all(tool.category == "customer" for tool in customer_tools) - - def test_get_tool_instances(self, registry: ToolRegistry) -> None: - """Test getting tool instances by IDs.""" - tool_ids = registry.list_tool_ids() - if len(tool_ids) >= 2: - # Test with first two tools - test_ids = tool_ids[:2] - instances = registry.get_tool_instances(tool_ids=test_ids) - - assert isinstance(instances, list) - assert len(instances) == 2 - assert all(isinstance(inst, BaseTool) for inst in instances) - - def test_get_tool_instances_with_invalid_id(self, registry: ToolRegistry) -> None: - """Test getting tool instances with some invalid IDs.""" - tool_ids = registry.list_tool_ids() - if tool_ids: - # Mix valid and invalid IDs - test_ids = [tool_ids[0], "invalid.tool"] - instances = registry.get_tool_instances(tool_ids=test_ids) - - # Should only return the valid one - assert len(instances) == 1 - assert isinstance(instances[0], BaseTool) - - def test_global_registry_singleton(self) -> None: - """Test that get_registry returns same instance.""" - registry1 = get_registry() - registry2 = get_registry() - assert registry1 is registry2 - - def test_reinitialize_registry(self) -> None: - """Test that reinitialize creates new instance.""" - registry1 = get_registry() - registry2 = reinitialize_registry() - # Should be different instances after reinitialize - assert registry1 is not registry2 - assert registry2.is_initialized - - -class TestToolDiscovery: - """Test suite for tool discovery functionality.""" - - def test_discovers_at_least_one_tool(self) -> None: - """Test that at least one tool is discovered.""" - registry = get_registry() - tool_ids = registry.list_tool_ids() - - # At least one tool should be successfully loaded - assert len(tool_ids) >= 1, "Expected at least one tool to be discovered" - - def test_query_althaus_database_if_available(self) -> None: - """Test query_althaus_database tool if it was successfully loaded.""" - registry = get_registry() - tool = registry.get_tool(tool_id="customer.query_althaus_database") - - if tool is not None: - assert tool.name == "query_althaus_database" - assert tool.category == "customer" - assert "database" in tool.description.lower() - else: - # Tool may not have loaded due to import errors - log warning - import logging - - logging.warning( - "customer.query_althaus_database tool not found - " - "may have failed to import" - ) - - def test_tavily_search_if_available(self) -> None: - """Test tavily_search tool if it was successfully loaded.""" - registry = get_registry() - tool = registry.get_tool(tool_id="shared.tavily_search") - - if tool is not None: - assert tool.name == "tavily_search" - assert tool.category == "shared" - assert "search" in tool.description.lower() - else: - # Tool may not have loaded due to import errors - log warning - import logging - - logging.warning( - "shared.tavily_search tool not found - may have failed to import" - ) - - def test_tool_ids_have_correct_format(self) -> None: - """Test that all discovered tool IDs follow the expected format.""" - registry = get_registry() - tool_ids = registry.list_tool_ids() - - for tool_id in tool_ids: - # All tool IDs should have format: category.toolname - assert "." in tool_id, f"Tool ID {tool_id} missing category separator" - category, name = tool_id.split(".", 1) - assert category in [ - "shared", - "customer", - ], f"Tool {tool_id} has invalid category: {category}" - assert len(name) > 0, f"Tool {tool_id} has empty name"