# Repository viewer metadata: 275 lines, 11 KiB, Python, no trailing newline.
from enum import Enum
|
|
from typing import Dict, List, Optional, Any, Literal
|
|
from datetime import datetime, UTC
|
|
from pydantic import BaseModel, Field
|
|
import logging
|
|
from modules.interfaces.interfaceChatModel import ActionResult
|
|
from functools import wraps
|
|
import inspect
|
|
|
|
# Module-level logger named after this module. MethodBase instances create
# their own "<module>.<ClassName>" logger in __init__ rather than using this one.
logger = logging.getLogger(__name__)
|
|
|
|
def action(func):
    """Mark an async method as an action that can be discovered at runtime.

    The decorated coroutine function simply awaits ``func`` with the arguments
    it was given and returns its result. It keeps the metadata of ``func``
    (name, docstring, ``__wrapped__``) and carries ``is_action = True``, the
    flag that ``MethodBase.actions`` looks for.
    """

    async def _invoke(self, parameters: Dict[str, Any], *args, **kwargs):
        outcome = await func(self, parameters, *args, **kwargs)
        return outcome

    # wraps(func) copies the metadata onto _invoke and hands the same object back.
    decorated = wraps(func)(_invoke)
    decorated.is_action = True
    return decorated
|
|
|
|
class MethodBase:
    """Base class for all methods.

    A method groups a set of actions: coroutine methods flagged with an
    ``is_action`` attribute (set by the ``@action`` decorator). This class
    discovers those actions by introspection, renders textual signatures for
    them from their docstrings, validates call parameters and wraps execution
    in uniform error handling.

    NOTE(review): ``execute`` calls ``self._validateAuth``, which is not defined
    in this class - presumably subclasses or a mixin provide it; confirm.
    """

    def __init__(self, serviceCenter: Any):
        """Initialize method with service center."""
        self.service = serviceCenter
        # Only declared here, never assigned in this class; getActionSignature
        # reads self.name, so whoever instantiates/subclasses must set it.
        self.name: str
        self.description: str
        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")

    @property
    def actions(self) -> Dict[str, Dict[str, Any]]:
        """Dynamically collect all actions decorated with @action in the class.

        Returns:
            A dict keyed by attribute name. Each entry holds ``description``
            (the action's docstring or ''), ``parameters`` (per-parameter
            ``type`` / ``required`` / ``description`` / ``default``) and
            ``method`` (the bound callable). The parameters ``self``,
            ``parameters`` and ``authData`` as well as ``*args`` / ``**kwargs``
            are not listed.

        The mapping is rebuilt on every access, so callers that need it more
        than once should keep a local reference.
        """
        skippedNames = ('self', 'parameters', 'authData')
        # Variadic parameters can never be "missing", so they must not be
        # reported as required parameters.
        variadicKinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        empty = inspect.Parameter.empty

        actions = {}
        for attr_name in dir(self):
            # Skip the actions property itself to avoid recursion
            if attr_name == 'actions':
                continue
            try:
                attr = getattr(self, attr_name)
                if not callable(attr) or not getattr(attr, 'is_action', False):
                    continue

                params = {}
                for param_name, param in inspect.signature(attr).parameters.items():
                    if param_name in skippedNames or param.kind in variadicKinds:
                        continue
                    # Identity checks: the sentinel must not be compared with
                    # ==, because arbitrary defaults may define their own __eq__.
                    hasDefault = param.default is not empty
                    hasAnnotation = param.annotation is not empty
                    params[param_name] = {
                        'type': param.annotation if hasAnnotation else Any,
                        'required': not hasDefault,
                        'description': None,
                        'default': param.default if hasDefault else None,
                    }

                actions[attr_name] = {
                    'description': attr.__doc__ or '',
                    'parameters': params,
                    'method': attr,
                }
            except (AttributeError, RecursionError):
                # Skip attributes that cause issues
                continue
        return actions

    def getActionSignature(self, actionName: str) -> str:
        """Get formatted action signature for AI prompt generation (detailed version).

        The parameter list is taken from the ``Parameters:`` section of the
        action's docstring, not from its Python signature.

        Returns:
            ``"<name>.<action>(<p>:<type> # <desc>, ...) -> ActionResult # <summary>"``;
            the parenthesised part is omitted when the docstring lists no
            parameters and the ``-> ...`` part when it has no summary line.
            An empty string is returned for an unknown action.
        """
        # Evaluate the (expensive) actions property only once.
        actionInfo = self.actions.get(actionName)
        if actionInfo is None:
            return ""

        # Extract detailed parameter information from docstring
        docstring = actionInfo.get('description', '')
        paramDescriptions, paramTypes = self._extractParameterDetails(docstring)

        # Required parameters cannot be marked: the docstring does not say.
        paramList = []
        for paramName, paramDesc in paramDescriptions.items():
            paramType = paramTypes.get(paramName, 'Any')
            if paramDesc:
                paramList.append(f"{paramName}:{paramType} # {paramDesc}")
            else:
                paramList.append(f"{paramName}:{paramType}")

        signature = f"{self.name}.{actionName}"
        if paramList:
            signature += f"({', '.join(paramList)})"

        # Add return type and main description
        returnType = "ActionResult"
        mainDesc = self._extractMainDescription(docstring)
        if mainDesc:
            signature += f" -> {returnType} # {mainDesc}"

        return signature

    def _parseParameterLine(self, line: str) -> Optional[tuple[str, str, str]]:
        """Parse a ``name (type): description`` line.

        Returns:
            ``(name, type, description)`` or None when the line does not have
            that shape (no ':', no '(' before the first ':', or empty name).
        """
        head, colon, desc = line.partition(':')
        if not colon or '(' not in head:
            return None

        namePart, _, typePart = head.partition('(')
        paramName = namePart.strip()
        if not paramName:
            return None

        # Cut at the first ')' if there is one; an unclosed parenthesis keeps
        # the whole remainder instead of losing its last character.
        closing = typePart.find(')')
        if closing != -1:
            typePart = typePart[:closing]
        return paramName, typePart.strip(), desc.strip()

    def _extractParameterDetails(self, docstring: str) -> tuple[Dict[str, str], Dict[str, str]]:
        """Extract parameter names, types, and descriptions from docstring.

        Only the section following a line containing ``Parameters:`` is read;
        it ends at a line starting with ``Returns:``, ``Raises:`` or ``Args:``.
        Lines of the form ``name (type): description`` start a new parameter.
        Any other non-empty line is appended to the description of the most
        recently added parameter, except lines starting with 'Each document'
        or 'contains', which are ignored.

        Returns:
            ``(descriptions, types)``, both keyed by parameter name.
        """
        descriptions: Dict[str, str] = {}
        types: Dict[str, str] = {}
        if not docstring:
            return descriptions, types

        inParameters = False
        for rawLine in docstring.split('\n'):
            line = rawLine.strip()
            if 'Parameters:' in line:
                inParameters = True
                continue
            if not inParameters or not line:
                continue
            if line.startswith(('Returns:', 'Raises:', 'Args:')):
                break

            parsed = self._parseParameterLine(line)
            if parsed is not None:
                paramName, paramType, paramDesc = parsed
                descriptions[paramName] = paramDesc
                types[paramName] = paramType
            elif descriptions and not line.startswith(('Each document', 'contains')):
                # Multi-line description: attach to the last parameter added.
                lastParam = next(reversed(descriptions))
                descriptions[lastParam] += " " + line
        return descriptions, types

    def _validateDocumentListParameter(self, parameters: Dict[str, Any], paramName: str = "documentList") -> bool:
        """Validate that documentList parameter is a list of strings.

        Returns:
            False when ``paramName`` is missing or its value is not a list;
            otherwise whether every item is a ``str`` (True for an empty list).
        """
        if paramName not in parameters:
            return False

        value = parameters[paramName]
        if not isinstance(value, list):
            return False

        # Check that all items in the list are strings
        return all(isinstance(item, str) for item in value)

    def _extractMainDescription(self, docstring: str) -> str:
        """Extract main description from docstring.

        Returns:
            The first non-empty (stripped) line that does not start with
            ``Parameters:``, ``Returns:`` or ``Raises:``, or '' if none exists.
        """
        if not docstring:
            return ""

        sectionHeaders = ('Parameters:', 'Returns:', 'Raises:')
        for rawLine in docstring.split('\n'):
            line = rawLine.strip()
            if line and not line.startswith(sectionHeaders):
                return line
        return ""

    def _formatType(self, type_annotation) -> str:
        """Format type annotation for display.

        Returns "Any" for ``typing.Any``, otherwise the annotation's
        ``__name__``, else its ``_name``, else ``str(type_annotation)``.
        """
        # Any is a singleton; identity avoids invoking a custom __eq__.
        if type_annotation is Any:
            return "Any"
        elif hasattr(type_annotation, '__name__'):
            return type_annotation.__name__
        elif hasattr(type_annotation, '_name'):
            return type_annotation._name
        else:
            return str(type_annotation)

    async def execute(self, action: str, parameters: Dict[str, Any], authData: Optional[Dict[str, Any]] = None) -> ActionResult:
        """
        Execute method action with authentication data

        Args:
            action: The action to execute
            parameters: Action parameters
            authData: Authentication data

        Returns:
            ActionResult containing execution results. Every failure is
            reported through a result with ``success=False``: an unsupported
            action, "Invalid parameters", "Authentication failed", or the text
            of any exception raised along the way.

        Note:
            No exception escapes this method; the ValueError raised for an
            unsupported action is caught by the handler below as well.

            NOTE(review): ``_validateAuth`` is not defined in this class. If no
            subclass/mixin supplies it, the resulting AttributeError is caught
            here and every call ends as a failed result - confirm.
        """
        try:
            # Validate action
            if action not in self.actions:
                raise ValueError(f"Unsupported action: {action}")

            # Validate parameters
            if not await self.validateParameters(action, parameters):
                return self._createResult(
                    success=False,
                    data={},
                    error="Invalid parameters"
                )

            # Validate authentication
            if not self._validateAuth(authData):
                return self._createResult(
                    success=False,
                    data={},
                    error="Authentication failed"
                )

            # Execute action
            return await self._executeAction(action, parameters, authData)

        except Exception as e:
            self.logger.error(f"Error executing action {action}: {str(e)}")
            return self._createResult(
                success=False,
                data={},
                error=str(e)
            )

    async def _executeAction(self, action: str, parameters: Dict[str, Any], authData: Optional[Dict[str, Any]] = None) -> ActionResult:
        """Execute specific action - to be implemented by subclasses"""
        raise NotImplementedError

    async def validateParameters(self, action: str, parameters: Dict[str, Any]) -> bool:
        """Validate action parameters.

        Returns:
            True when ``action`` is known, every required parameter of the
            action is present in ``parameters`` and, if a ``documentList`` key
            is present, its value is a list of strings. False otherwise,
            including when validation itself raises (the error is logged).
        """
        try:
            # Evaluate the (expensive) actions property only once.
            actionDef = self.actions.get(action)
            if actionDef is None:
                return False

            requiredParams = {k for k, v in actionDef['parameters'].items() if v['required']}

            # Check required parameters
            if not all(param in parameters for param in requiredParams):
                return False

            # Validate documentList parameter if present
            if "documentList" in parameters:
                if not self._validateDocumentListParameter(parameters, "documentList"):
                    self.logger.error("documentList parameter must be a list of strings")
                    return False

            return True

        except Exception as e:
            self.logger.error(f"Error validating parameters: {str(e)}")
            return False

    async def rollback(self, action: str, parameters: Dict[str, Any], authData: Optional[Dict[str, Any]] = None) -> None:
        """Rollback action if needed.

        Delegates to ``_rollbackAction``; any exception is logged and re-raised.
        """
        try:
            await self._rollbackAction(action, parameters, authData)
        except Exception as e:
            self.logger.error(f"Error rolling back action {action}: {str(e)}")
            raise

    async def _rollbackAction(self, action: str, parameters: Dict[str, Any], authData: Optional[Dict[str, Any]] = None) -> None:
        """Rollback specific action - to be implemented by subclasses"""
        pass

    def _createResult(self, success: bool, data: Dict[str, Any], metadata: Optional[Dict[str, Any]] = None, error: Optional[str] = None) -> ActionResult:
        """Create a method result.

        ``metadata`` defaults to an empty dict and ``validation`` always starts
        as an empty dict.
        """
        return ActionResult(
            success=success,
            data=data,
            metadata=metadata or {},
            validation={},
            error=error
        )

    def _addValidationMessage(self, result: ActionResult, message: str) -> None:
        """Add a validation message to the result.

        Appends to ``result.validation['messages']``, creating the list first
        if it is missing.
        """
        if 'messages' not in result.validation:
            result.validation['messages'] = []
        result.validation['messages'].append(message)