# learningEngine.py
# Learning engine for adaptive Dynamic mode

import logging
|
|
from typing import Dict, Any, List
|
|
from datetime import datetime, timezone
|
|
|
|
logger = logging.getLogger(__name__)


class LearningEngine:
    """Learns from feedback and adapts future behavior.

    State is kept in memory only:

    * ``strategies`` maps a strategy key (``"<dataType>_<expectedFormat>"``)
      to a strategy dict with the keys ``strategyId``, ``successfulActions``,
      ``failedActions``, ``successRate``, ``lastModified``,
      ``recommendedPrompt`` and ``avoidPrompt``.
    * ``feedbackHistory`` is an append-only list of recorded feedback events.
    """

    # Both scores must exceed this for an action to count as a success.
    _SUCCESS_THRESHOLD = 0.7
    # Either score below this marks the action as failed.
    _FAILURE_THRESHOLD = 0.3
    # Amount the success rate moves per success/failure (clamped to [0, 1]).
    _RATE_STEP = 0.1
    # Success rate assumed for a strategy that has no recorded rate yet.
    _INITIAL_SUCCESS_RATE = 0.5

    # Per data type: (recommendedPrompt template, avoidPrompt text).
    _PROMPTS = {
        'numbers': (
            "Deliver {dataType} data in {expectedFormat} format. Provide actual numbers, not code to generate them.",
            "Do not ask AI to write code when user wants data. Deliver the data directly.",
        ),
        'text': (
            "Generate {dataType} content in {expectedFormat} format.",
            "Ensure content is readable and well-structured.",
        ),
        'documents': (
            "Create {dataType} in {expectedFormat} format with proper structure.",
            "Ensure document is properly formatted and organized.",
        ),
    }
    # Used for every data type not listed in _PROMPTS.
    _FALLBACK_PROMPTS = (
        "Deliver {dataType} content in {expectedFormat} format.",
        "Ensure content matches user requirements.",
    )

    def __init__(self):
        self.strategies = {}
        self.feedbackHistory = []

    @staticmethod
    def _now() -> float:
        """Returns the current UTC time as a POSIX timestamp."""
        return datetime.now(timezone.utc).timestamp()

    @staticmethod
    def _coerceScore(value: Any) -> float:
        """Converts a raw score to float; None or unconvertible values become 0.0."""
        if value is None:
            return 0.0
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            return 0.0

    @staticmethod
    def _intentField(intent: Any, key: str) -> Any:
        """Reads ``key`` from the intent, falling back to 'unknown'.

        Anything without a callable ``get`` (e.g. ``None``) is treated as an
        empty intent, so the error path of ``getImprovedStrategy`` can still
        build a default strategy instead of raising a second time.
        """
        getter = getattr(intent, 'get', None)
        if not callable(getter):
            return 'unknown'
        return getter(key, 'unknown')

    def learnFromFeedback(self, feedback: Dict[str, Any], context: Any, intent: Dict[str, Any]):
        """Learns from feedback and updates strategies.

        Records the event in ``feedbackHistory`` and then adjusts the strategy
        matching ``intent``. Never raises: any error is logged and swallowed
        (an event appended before the error stays in the history).

        Args:
            feedback: Dict read for 'actionAttempted', 'qualityScore' and
                'intentMatchScore'.
            context: Object whose 'taskStep', 'workflowId' and
                'availableDocuments' attributes are snapshotted if present.
            intent: Dict read for 'dataType' and 'expectedFormat'.
        """
        try:
            # Store feedback
            self.feedbackHistory.append({
                "feedback": feedback,
                "context": self._serializeContext(context),
                "intent": intent,
                "timestamp": self._now(),
            })

            # Update strategies based on feedback
            self._updateStrategies(feedback, intent)

            # Normalize scores so the ':.2f' format specs below cannot fail
            qualityScore = self._coerceScore(feedback.get('qualityScore', 0.0))
            intentMatchScore = self._coerceScore(feedback.get('intentMatchScore', 0.0))
            logger.info(
                f"Learning from feedback: {feedback.get('actionAttempted', 'unknown')} - "
                f"Quality: {qualityScore:.2f}, Intent Match: {intentMatchScore:.2f}"
            )

        except Exception as e:
            logger.error(f"Error learning from feedback: {str(e)}", exc_info=True)

    def getImprovedStrategy(self, context: Any, intent: Dict[str, Any]) -> Dict[str, Any]:
        """Returns improved strategy based on learning.

        Returns the stored strategy for the intent's key, creating and storing
        a default one on first use. The returned dict is the stored object
        itself, not a copy.

        Args:
            context: Unused here; kept for interface compatibility.
            intent: Dict read for 'dataType' and 'expectedFormat'.

        Returns:
            The strategy dict for this intent.
        """
        try:
            # Get strategy key based on intent
            strategyKey = self._getStrategyKey(intent)

            # Get existing strategy or create default
            if strategyKey in self.strategies:
                strategy = self.strategies[strategyKey]
                logger.info(f"Using learned strategy for {strategyKey}: {strategy}")
                return strategy

            # Create default strategy
            defaultStrategy = self._createDefaultStrategy(intent)
            self.strategies[strategyKey] = defaultStrategy
            logger.info(f"Created default strategy for {strategyKey}")
            return defaultStrategy

        except Exception as e:
            logger.error(f"Error getting improved strategy: {str(e)}", exc_info=True)
            # Safe for malformed intents because _createDefaultStrategy reads
            # the intent through _intentField.
            return self._createDefaultStrategy(intent)

    def _updateStrategies(self, feedback: Dict[str, Any], intent: Dict[str, Any]):
        """Updates strategies based on feedback.

        Both scores above 0.7 reinforce the action (+0.1 success rate); either
        score below 0.3 marks it as failed (-0.1). Scores in between only
        refresh 'lastModified'.
        """
        strategyKey = self._getStrategyKey(intent)
        actionAttempted = feedback.get('actionAttempted', 'unknown')
        # Coerce possibly None or non-numeric to floats
        qualityScore = self._coerceScore(feedback.get('qualityScore', 0.0))
        intentMatchScore = self._coerceScore(feedback.get('intentMatchScore', 0.0))

        # Get or create strategy
        if strategyKey not in self.strategies:
            self.strategies[strategyKey] = self._createDefaultStrategy(intent)
        strategy = self.strategies[strategyKey]

        # Update based on success/failure
        if qualityScore > self._SUCCESS_THRESHOLD and intentMatchScore > self._SUCCESS_THRESHOLD:
            # Successful action - reinforce it
            succeeded = strategy.setdefault('successfulActions', [])
            if actionAttempted not in succeeded:
                succeeded.append(actionAttempted)
            currentRate = strategy.get('successRate', self._INITIAL_SUCCESS_RATE)
            strategy['successRate'] = min(currentRate + self._RATE_STEP, 1.0)
            logger.info(f"Reinforced successful action: {actionAttempted}")

        elif qualityScore < self._FAILURE_THRESHOLD or intentMatchScore < self._FAILURE_THRESHOLD:
            # Failed action - avoid it
            failed = strategy.setdefault('failedActions', [])
            if actionAttempted not in failed:
                failed.append(actionAttempted)
            currentRate = strategy.get('successRate', self._INITIAL_SUCCESS_RATE)
            strategy['successRate'] = max(currentRate - self._RATE_STEP, 0.0)
            logger.info(f"Marked failed action to avoid: {actionAttempted}")

        # Update last modified
        strategy['lastModified'] = self._now()

    def _getStrategyKey(self, intent: Dict[str, Any]) -> str:
        """Gets strategy key based on intent ("<dataType>_<expectedFormat>")."""
        dataType = self._intentField(intent, 'dataType')
        expectedFormat = self._intentField(intent, 'expectedFormat')
        return f"{dataType}_{expectedFormat}"

    def _createDefaultStrategy(self, intent: Dict[str, Any]) -> Dict[str, Any]:
        """Creates a default strategy for the intent.

        Known data types ('numbers', 'text', 'documents') get their own
        prompts and a '<dataType>_<expectedFormat>' strategyId. Any other
        data type gets the generic prompts and an 'unknown_<expectedFormat>'
        strategyId.
        """
        dataType = self._intentField(intent, 'dataType')
        expectedFormat = self._intentField(intent, 'expectedFormat')

        # Equality scan rather than dict lookup so unhashable values simply
        # fall through to the generic strategy.
        idPrefix = 'unknown'
        recommendedTemplate, avoidPrompt = self._FALLBACK_PROMPTS
        for knownType, prompts in self._PROMPTS.items():
            if dataType == knownType:
                idPrefix = knownType
                recommendedTemplate, avoidPrompt = prompts
                break

        return {
            'strategyId': f"{idPrefix}_{expectedFormat}",
            'successfulActions': [],
            'failedActions': [],
            'successRate': self._INITIAL_SUCCESS_RATE,
            'lastModified': self._now(),
            'recommendedPrompt': recommendedTemplate.format(
                dataType=dataType, expectedFormat=expectedFormat
            ),
            'avoidPrompt': avoidPrompt,
        }

    def _serializeContext(self, context: Any) -> Dict[str, Any]:
        """Serializes context for storage.

        Reads 'taskStep', 'workflowId' and 'availableDocuments' from
        ``context``, each defaulting when absent. 'taskStep' may be a mapping
        (read via ``get``) or an object with an 'objective' attribute.
        Best-effort: returns an empty dict if reading the context raises.
        """
        try:
            taskStep = getattr(context, 'taskStep', None)
            stepGetter = getattr(taskStep, 'get', None)
            if callable(stepGetter):
                objective = stepGetter('objective', '')
            else:
                # Covers attribute-style steps and a missing/None taskStep.
                objective = getattr(taskStep, 'objective', '')

            documents = getattr(context, 'availableDocuments', [])
            if isinstance(documents, list):
                # Snapshot, so later changes to the live list do not rewrite history.
                documents = list(documents)

            return {
                "taskObjective": objective,
                "workflowId": getattr(context, 'workflowId', ''),
                "availableDocuments": documents,
            }
        except Exception:
            return {}

    def getLearningSummary(self) -> Dict[str, Any]:
        """Gets a summary of what has been learned.

        Returns:
            Dict with 'totalStrategies', 'totalFeedback', 'strategies' (list
            of strategy keys) and 'averageSuccessRate' (0 when no strategies
            exist).
        """
        strategyCount = len(self.strategies)
        rateTotal = sum(s.get('successRate', 0) for s in self.strategies.values())
        return {
            "totalStrategies": strategyCount,
            "totalFeedback": len(self.feedbackHistory),
            "strategies": list(self.strategies.keys()),
            # max(..., 1) avoids division by zero before anything is learned
            "averageSuccessRate": rateTotal / max(strategyCount, 1),
        }