148 lines
No EOL
5.4 KiB
Python
148 lines
No EOL
5.4 KiB
Python
import logging
from datetime import UTC, datetime
from enum import Enum
from typing import Any, Callable, Dict, List, Literal, Optional

from pydantic import BaseModel, Field
|
|
|
|
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
class AuthSource(str, Enum):
    """Enumerate the supported authentication sources.

    Each member is also a ``str`` (via the mixin), so it compares equal to its
    plain string value, e.g. ``AuthSource.MSFT == "msft"``.
    """

    # Member order is the iteration order of the enum.
    LOCAL = "local"
    MSFT = "msft"
    GOOGLE = "google"
    # Further sources can be appended here as they become supported.
|
|
|
|
class MethodParameter(BaseModel):
    """Model for method parameters.

    Describes a single parameter: its name, a textual type name, whether it is
    required, an optional validator callable and a human-readable description.
    """

    name: str
    # Type name as a string; the accepted spellings are not defined in this file.
    type: str
    required: bool
    # Optional validator. Annotated with ``typing.Callable`` rather than the
    # builtin ``callable`` function: ``callable`` is not a type, and pydantic
    # cannot generate a schema for it, so the model failed at class definition.
    # The expected call signature is not fixed here — TODO confirm with callers.
    validation: Optional[Callable[..., Any]] = None
    description: str
|
|
|
|
class MethodResult(BaseModel):
    """Outcome of a method action.

    Attributes:
        success: Whether the action completed successfully.
        data: Payload produced by the action.
        metadata: Free-form extra information; defaults to an empty dict.
        validation: Validation messages attached to this result; defaults to
            an empty list.
        error: Error message, or ``None`` when there is none.
    """

    success: bool
    data: Dict[str, Any]
    # Factories give every instance its own fresh container.
    metadata: Dict[str, Any] = Field(default_factory=dict)
    validation: List[str] = Field(default_factory=list)
    error: Optional[str] = Field(default=None, description="Error message if any")
|
|
|
|
class MethodBase:
    """Base class for all methods.

    Subclasses describe their operations through :attr:`actions` and implement
    :meth:`_executeAction` (and optionally :meth:`_rollbackAction`). Callers use
    :meth:`execute`, which checks the action name, the presence of required
    parameters and the authentication data before dispatching.
    """

    def __init__(self, serviceContainer: Any):
        """Initialize method with service container.

        Args:
            serviceContainer: Container object stored as ``self.service``.
        """
        self.service = serviceContainer
        # Annotation-only statements: they declare ``name`` and ``description``
        # for type checkers but do NOT assign them, so reading them before they
        # are set raises AttributeError. NOTE(review): presumably assigned by
        # subclasses — confirm.
        self.name: str
        self.description: str
        self.authSource: AuthSource = AuthSource.LOCAL  # Default to local auth
        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")

    @property
    def actions(self) -> Dict[str, Dict[str, Any]]:
        """Available actions and their parameters.

        Maps an action name to its definition. :meth:`validateParameters` reads
        ``definition['parameters']``: a mapping from parameter name to a spec
        carrying a ``required`` flag (either a subscriptable spec with a
        ``'required'`` key, or an object with a ``required`` attribute such as a
        ``MethodParameter``).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    async def execute(
        self,
        action: str,
        parameters: Dict[str, Any],
        authData: Optional[Dict[str, Any]] = None,
    ) -> MethodResult:
        """Execute method action with authentication data.

        Checks run in this order: the action name, the required parameters and
        the authentication data. Any ``Exception`` raised along the way
        (including from ``actions`` or ``_executeAction``) is logged with its
        traceback and converted into a failed result. This coroutine therefore
        does not propagate such exceptions.

        Args:
            action: The action to execute; must be a key of :attr:`actions`.
            parameters: Action parameters.
            authData: Authentication data, checked by :meth:`_validateAuth`.

        Returns:
            The result of ``_executeAction`` when all checks pass. Otherwise a
            result with ``success=False`` and ``error`` set to
            ``"Invalid parameters"``, ``"Authentication failed"``, or the
            message of the caught exception. An empty exception message falls
            back to the exception class name, so ``error`` is never empty.
        """
        try:
            if action not in self.actions:
                # Handled by the except clause below and reported as a failure.
                raise ValueError(f"Unsupported action: {action}")

            if not await self.validateParameters(action, parameters):
                return self._createResult(
                    success=False,
                    data={},
                    error="Invalid parameters",
                )

            if not self._validateAuth(authData):
                return self._createResult(
                    success=False,
                    data={},
                    error="Authentication failed",
                )

            return await self._executeAction(action, parameters, authData)

        except Exception as e:
            # Boundary catch: keep the traceback in the log, report a failure.
            self.logger.exception("Error executing action %s: %s", action, e)
            # e.g. a bare NotImplementedError has an empty message; without the
            # fallback the failed result would carry a falsy ``error``.
            return self._createResult(
                success=False,
                data={},
                error=str(e) or type(e).__name__,
            )

    async def _executeAction(
        self,
        action: str,
        parameters: Dict[str, Any],
        authData: Optional[Dict[str, Any]] = None,
    ) -> MethodResult:
        """Execute specific action - to be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    async def validateParameters(self, action: str, parameters: Dict[str, Any]) -> bool:
        """Check that every required parameter of *action* is present.

        Only the presence of keys in *parameters* is checked; values and any
        per-parameter validator are not inspected here.

        Args:
            action: Action name to look up in :attr:`actions`.
            parameters: Parameters supplied by the caller.

        Returns:
            True if *action* is known and all of its required parameters are
            present. False otherwise, including when evaluating the action
            definition raises (the error is logged).
        """
        try:
            actions = self.actions  # evaluate the property only once
            if action not in actions:
                return False

            # An action without a 'parameters' entry has no required parameters.
            paramSpecs = actions[action].get('parameters') or {}
            requiredParams = {
                paramName
                for paramName, spec in paramSpecs.items()
                if self._isRequired(spec)
            }
            return all(paramName in parameters for paramName in requiredParams)

        except Exception as e:
            self.logger.error("Error validating parameters: %s", e)
            return False

    @staticmethod
    def _isRequired(spec: Any) -> bool:
        """Return the ``required`` flag of a parameter spec.

        Subscriptable specs (e.g. dicts) are read via ``spec['required']``.
        Non-subscriptable ones (e.g. a pydantic model such as MethodParameter)
        are read via the ``required`` attribute. A missing flag raises, and the
        caller treats that as failed validation.
        """
        try:
            return bool(spec['required'])
        except TypeError:
            # Not subscriptable: fall back to attribute access.
            return bool(spec.required)

    async def rollback(
        self,
        action: str,
        parameters: Dict[str, Any],
        authData: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Roll back *action* by delegating to :meth:`_rollbackAction`.

        Raises:
            Exception: Whatever ``_rollbackAction`` raised, after logging it.
        """
        try:
            await self._rollbackAction(action, parameters, authData)
        except Exception as e:
            self.logger.error("Error rolling back action %s: %s", action, e)
            raise

    async def _rollbackAction(
        self,
        action: str,
        parameters: Dict[str, Any],
        authData: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Roll back a specific action; no-op here, overridable by subclasses."""
        pass

    def _validateAuth(self, authData: Optional[Dict[str, Any]] = None) -> bool:
        """Return whether *authData* satisfies this method's auth source.

        LOCAL methods accept any *authData*, including None. For any other
        source, ``authData['source']`` must equal ``self.authSource``. Because
        AuthSource mixes in ``str``, the plain string value (e.g. ``"msft"``)
        also matches.
        """
        try:
            if self.authSource == AuthSource.LOCAL:
                return True
            return bool(authData and authData.get('source') == self.authSource)
        except Exception as e:
            # e.g. authData without a .get method: treat as failed auth.
            self.logger.error("Error validating auth: %s", e)
            return False

    def _createResult(
        self,
        success: bool,
        data: Dict[str, Any],
        metadata: Optional[Dict[str, Any]] = None,
        error: Optional[str] = None,
    ) -> MethodResult:
        """Build a MethodResult with an empty validation list.

        A falsy *metadata* (None or empty) is replaced by a new empty dict.
        """
        return MethodResult(
            success=success,
            data=data,
            metadata=metadata or {},
            validation=[],
            error=error,
        )

    def _addValidationMessage(self, result: MethodResult, message: str) -> None:
        """Append *message* to ``result.validation`` in place."""
        result.validation.append(message)