290 lines
No EOL
11 KiB
Python
290 lines
No EOL
11 KiB
Python
"""
|
|
Updated registry for all available agents in the system.
|
|
Provides centralized agent registration and access with improved error handling.
|
|
"""
|
|
|
|
import os
|
|
import logging
|
|
import importlib
|
|
from typing import Dict, Any, List, Optional
|
|
|
|
# Import direct base agent module
|
|
from modules.agentservice_base import BaseAgent
|
|
|
|
logger = logging.getLogger(__name__)


class AgentRegistry:
    """Singleton registry that discovers, stores, and looks up agents.

    Agents live in ``self.agents``, keyed by their ``type`` and, when it
    differs from the type, additionally by their ``id``. The same agent
    instance can therefore appear under two keys; methods that report on
    "all agents" de-duplicate by object identity.
    """

    _instance = None

    @classmethod
    def get_instance(cls):
        """Return the shared registry, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        """Create the registry state and load all agent modules.

        Raises:
            RuntimeError: If the singleton instance already exists.
        """
        if AgentRegistry._instance is not None:
            raise RuntimeError("Singleton instance already exists - use get_instance()")
        self.agents = {}
        self.ai_service = None
        self.document_handler = None
        self.lucydom_interface = None
        self._load_agents()

    def _load_agents(self):
        """Discover and register every ``agentservice_agent_*.py`` module.

        Modules are looked up in the directory of this file and imported as
        ``modules.<name>``. For a module whose name ends in ``_<kind>``, the
        agent is obtained from ``get_<kind>_agent()`` if the module defines
        it, otherwise by instantiating ``<Kind>Agent``. Failures in a single
        module are logged and do not stop the remaining modules from loading.
        """
        logger.info("Automatically loading agent modules...")

        # abspath: dirname() of a bare relative __file__ is "", which
        # os.listdir() rejects.
        module_dir = os.path.dirname(os.path.abspath(__file__))

        # os.listdir() returns entries in arbitrary order; sorting makes the
        # registration order (and thus first-match lookups) reproducible.
        agent_modules = sorted(
            filename[:-3]  # Remove .py extension
            for filename in os.listdir(module_dir)
            if filename.startswith("agentservice_agent_") and filename.endswith(".py")
        )

        if not agent_modules:
            logger.warning("No agent modules found")
            return
        logger.info(f"Found {len(agent_modules)} agent modules")

        for module_name in agent_modules:
            try:
                module = importlib.import_module(f"modules.{module_name}")

                # Derive the expected getter / class names from the module suffix.
                agent_type = module_name.split('_')[-1]
                class_name = f"{agent_type.capitalize()}Agent"
                getter_name = f"get_{agent_type}_agent"

                agent = None

                if hasattr(module, getter_name):
                    # Preferred: the module's own factory function.
                    getter_func = getattr(module, getter_name)
                    agent = getter_func()
                    logger.info(f"Agent '{agent.name}' (Type: {agent.type}) loaded via {getter_name}()")
                elif hasattr(module, class_name):
                    # Fallback: instantiate the conventionally named class.
                    agent_class = getattr(module, class_name)
                    agent = agent_class()
                    logger.info(f"Agent '{agent.name}' (Type: {agent.type}) directly instantiated")

                if agent:
                    self.register_agent(agent)
                else:
                    logger.warning(f"No agent class or getter function found in module {module_name}")

            except ImportError as e:
                logger.error(f"Module {module_name} could not be imported: {e}")
            except Exception as e:
                # Boundary handler: one broken agent must not block the others.
                logger.error(f"Error loading agent from module {module_name}: {e}")

    def _unique_agents(self) -> List["BaseAgent"]:
        """Return each registered agent instance once, in registration order.

        De-duplication is by object identity, so agents do not need to be
        hashable and two distinct agents that compare equal are both kept.
        """
        unique = []
        seen_ids = set()
        for agent in self.agents.values():
            marker = id(agent)
            if marker not in seen_ids:
                seen_ids.add(marker)
                unique.append(agent)
        return unique

    def set_dependencies(self, ai_service=None, document_handler=None, lucydom_interface=None):
        """
        Set system dependencies and push them to all registered agents.

        Args:
            ai_service: AI service for text generation. Always stored, even
                when None (unchanged from the previous behavior).
            document_handler: Document handler for document operations.
                Stored only when not None.
            lucydom_interface: LucyDOM interface for database access.
                Stored only when not None.
        """
        self.ai_service = ai_service
        # These two were previously accepted but silently dropped. They are
        # only overwritten when supplied, so a caller passing just ai_service
        # does not wipe values that were assigned to the attributes directly.
        if document_handler is not None:
            self.document_handler = document_handler
        if lucydom_interface is not None:
            self.lucydom_interface = lucydom_interface

        # Update all registered agents
        self.update_agent_dependencies()

    def update_agent_dependencies(self):
        """Pass the registry's current dependencies to every agent.

        Agents without a ``set_dependencies`` attribute are skipped. Each
        agent instance is updated once, even if registered under two keys.
        """
        for agent in self._unique_agents():
            if hasattr(agent, 'set_dependencies'):
                agent.set_dependencies(
                    ai_service=self.ai_service,
                    document_handler=self.document_handler,
                    lucydom_interface=self.lucydom_interface
                )

    def register_agent(self, agent: 'BaseAgent'):
        """
        Register an agent in the registry.

        The agent is stored under its ``type`` and, if it has an ``id`` that
        differs from the type, under that id as well. An existing entry with
        the same key is replaced.

        Args:
            agent: The agent to register
        """
        agent_type = agent.type
        agent_id = getattr(agent, 'id', agent_type)

        # Hand over the currently known dependencies before exposing the agent.
        if hasattr(agent, 'set_dependencies'):
            agent.set_dependencies(
                ai_service=self.ai_service,
                document_handler=self.document_handler,
                lucydom_interface=self.lucydom_interface
            )

        self.agents[agent_type] = agent
        # Also register by ID if it's different from type
        if agent_id != agent_type:
            self.agents[agent_id] = agent

        logger.debug(f"Agent '{agent.name}' (Type: {agent_type}, ID: {agent_id}) registered")

    def get_agent(self, agent_identifier: str) -> Optional['BaseAgent']:
        """
        Get an agent instance by ID or type.

        If there is no exact key, the identifier is retried with every
        ``_agent`` occurrence removed and with ``_agent`` appended.

        Args:
            agent_identifier: ID or type of the desired agent

        Returns:
            Agent instance or None if not found
        """
        if agent_identifier in self.agents:
            return self.agents[agent_identifier]

        # Tolerate callers that add or omit the "_agent" suffix.
        variants = (
            agent_identifier.replace('_agent', ''),
            f"{agent_identifier}_agent",
        )
        for variant in variants:
            if variant in self.agents:
                return self.agents[variant]

        logger.warning(f"Agent with identifier '{agent_identifier}' not found")
        return None

    def get_all_agents(self) -> Dict[str, 'BaseAgent']:
        """Get all registered agents.

        Returns the registry's own dictionary (not a copy); an agent may
        appear under both its type and its id.
        """
        return self.agents

    def get_agent_infos(self) -> List[Dict[str, Any]]:
        """Get ``get_agent_info()`` of every registered agent, once per instance."""
        return [agent.get_agent_info() for agent in self._unique_agents()]

    def get_agent_by_format(self, required_format: str) -> Optional['BaseAgent']:
        """
        Find an agent that can produce the required output format.

        Matching is case-insensitive against each agent's ``result_format``
        attribute: an exact match wins, otherwise the first agent whose
        format contains, or is contained in, the required format.

        Args:
            required_format: The required output format

        Returns:
            Agent that can produce the required format, or None if not found
            (also for an empty or None ``required_format``)
        """
        # An empty string is a substring of everything and would match the
        # first agent regardless of its format.
        if not required_format:
            return None
        wanted = required_format.lower()

        # Mapping of lower-cased result format -> agent; a later agent with
        # the same format replaces an earlier one.
        format_to_agent = {}
        for agent in self._unique_agents():
            agent_format = getattr(agent, 'result_format', None)
            if agent_format:
                format_to_agent[agent_format.lower()] = agent

        # Try to find an exact match
        if wanted in format_to_agent:
            return format_to_agent[wanted]

        # If no exact match, try to find a partial match
        for fmt, agent in format_to_agent.items():
            if wanted in fmt or fmt in wanted:
                return agent

        return None

    def initialize_agents_for_workflow(self) -> Dict[str, Dict[str, Any]]:
        """Return the info dict of every agent, keyed by the info's ``"id"`` entry."""
        initialized_agents = {}
        for agent in self._unique_agents():
            agent_info = agent.get_agent_info()
            initialized_agents[agent_info["id"]] = agent_info
        return initialized_agents

    def get_agent_capabilities(self) -> Dict[str, List[str]]:
        """
        Get a mapping of capabilities to agents.
        Useful for finding the right agent for a specific task.

        Capabilities come from ``agent.get_capabilities()`` when the agent has
        such a callable (used as returned); otherwise from the agent's
        ``capabilities`` attribute, which is split on commas when it is a
        string, and stripped and lower-cased either way.

        Returns:
            Dict mapping capability keywords to agent IDs
        """
        capabilities_map = {}

        for agent in self._unique_agents():
            agent_id = getattr(agent, 'id', agent.type)

            # Extract capabilities - check for get_capabilities method first
            if hasattr(agent, 'get_capabilities') and callable(getattr(agent, 'get_capabilities')):
                capabilities = agent.get_capabilities()
            else:
                raw = getattr(agent, 'capabilities', "")
                if raw is None:
                    keywords = []
                elif isinstance(raw, str):
                    keywords = raw.split(',')
                else:
                    # Tolerate a list/tuple of keywords instead of a string.
                    keywords = [str(item) for item in raw]
                capabilities = [kw.strip().lower() for kw in keywords if kw.strip()]

            # Add each capability to the mapping
            for capability in capabilities:
                agent_ids = capabilities_map.setdefault(capability, [])
                if agent_id not in agent_ids:
                    agent_ids.append(agent_id)

        return capabilities_map

    def get_agent_by_capability(self, capability: str) -> Optional['BaseAgent']:
        """
        Find an agent with a specific capability.

        Lookup is case-insensitive. An exact match is preferred; otherwise
        agents whose capability contains, or is contained in, the requested
        one are considered. The first matching agent is returned.

        Args:
            capability: The required capability

        Returns:
            Agent with the required capability, or None if not found
            (also for an empty or None ``capability``)
        """
        # An empty string is a substring of everything and would match any agent.
        if not capability:
            return None
        wanted = capability.lower()

        capability_map = self.get_agent_capabilities()
        matching_agents = []

        if wanted in capability_map:
            # Direct match
            matching_agents = capability_map[wanted]
        else:
            # Keys from get_capabilities() are not normalized, so compare
            # lower-cased copies before falling back to partial matches.
            lowered = [(str(cap).lower(), ids) for cap, ids in capability_map.items()]
            for cap, agent_ids in lowered:
                if cap == wanted:
                    matching_agents.extend(agent_ids)
            if not matching_agents:
                for cap, agent_ids in lowered:
                    if cap and (wanted in cap or cap in wanted):
                        matching_agents.extend(agent_ids)

        # Return the first matching agent
        if matching_agents:
            return self.get_agent(matching_agents[0])

        return None