182 lines
No EOL
6.8 KiB
Python
182 lines
No EOL
6.8 KiB
Python
from typing import Dict, Any, List, Optional
|
|
import logging
|
|
import json
|
|
from datetime import datetime, UTC
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
class AIPromptManager:
    """Manage AI prompt templates and validate AI responses against schemas.

    ``prompt_templates`` maps a prompt type name to a dict with two keys:
    ``"template"`` (a ``str.format`` template string) and ``"schema"`` (a
    JSON Schema describing the expected response).
    """

    # Prompt type used when the context passed to ``generate_prompt`` has no
    # "type" key.
    _DEFAULT_PROMPT_TYPE = "task_analysis"

    def __init__(self):
        """Create a manager pre-loaded with the built-in templates."""
        self.prompt_templates: Dict[str, Dict[str, Any]] = {}
        # Nothing in this class populates or reads this attribute; schemas
        # are stored next to their template in ``prompt_templates``.
        self.response_schemas: Dict[str, Any] = {}
        self._load_templates()

    def _load_templates(self) -> None:
        """Load the built-in prompt templates and their response schemas.

        Replaces ``prompt_templates`` with a fresh dict holding the
        ``task_analysis`` and ``result_analysis`` entries. Literal braces in
        the templates are doubled because the text is rendered with
        ``str.format``.
        """
        # Basic templates
        self.prompt_templates = {
            "task_analysis": {
                "template": (
                    "Analyze the following task and determine required actions:\n"
                    "Task: {task}\n"
                    "Context: {context}\n"
                    "Available Methods: {methods}\n"
                    "\n"
                    "Please provide:\n"
                    "1. Main objective\n"
                    "2. Required actions\n"
                    "3. Required data sources\n"
                    "4. Document processing requirements\n"
                    "5. Expected output format\n"
                    "\n"
                    "Format your response as JSON:\n"
                    "{{\n"
                    '    "objective": "string",\n'
                    '    "actions": [\n'
                    "        {{\n"
                    '            "method": "string",\n'
                    '            "action": "string",\n'
                    '            "parameters": {{\n'
                    '                "param1": "value1"\n'
                    "            }}\n"
                    "        }}\n"
                    "    ],\n"
                    '    "dataSources": ["string"],\n'
                    '    "documentRequirements": ["string"],\n'
                    '    "outputFormat": "string"\n'
                    "}}\n"
                ),
                "schema": {
                    "type": "object",
                    "required": ["objective", "actions"],
                    "properties": {
                        "objective": {"type": "string"},
                        "actions": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "required": ["method", "action"],
                                "properties": {
                                    "method": {"type": "string"},
                                    "action": {"type": "string"},
                                    "parameters": {"type": "object"},
                                },
                            },
                        },
                        "dataSources": {
                            "type": "array",
                            "items": {"type": "string"},
                        },
                        "documentRequirements": {
                            "type": "array",
                            "items": {"type": "string"},
                        },
                        "outputFormat": {"type": "string"},
                    },
                },
            },
            "result_analysis": {
                "template": (
                    "Analyze the following task results and determine next steps:\n"
                    "Task Results: {results}\n"
                    "Workflow History: {history}\n"
                    "\n"
                    "Please provide:\n"
                    "1. Task completion status\n"
                    "2. Next required actions\n"
                    "3. Required documents\n"
                    "4. Method recommendations\n"
                    "\n"
                    "Format your response as JSON:\n"
                    "{{\n"
                    '    "isComplete": boolean,\n'
                    '    "nextActions": ["string"],\n'
                    '    "requiredDocuments": ["string"],\n'
                    '    "recommendedMethods": ["string"]\n'
                    "}}\n"
                ),
                "schema": {
                    "type": "object",
                    "required": ["isComplete"],
                    "properties": {
                        "isComplete": {"type": "boolean"},
                        "nextActions": {
                            "type": "array",
                            "items": {"type": "string"},
                        },
                        "requiredDocuments": {
                            "type": "array",
                            "items": {"type": "string"},
                        },
                        "recommendedMethods": {
                            "type": "array",
                            "items": {"type": "string"},
                        },
                    },
                },
            },
        }

    def generate_prompt(
        self,
        context: Dict[str, Any],
        examples: Optional[List[Dict]] = None,
    ) -> str:
        """Generate a context-aware prompt with optional few-shot examples.

        Args:
            context: Prompt inputs. ``"type"`` selects the template (defaults
                to ``"task_analysis"``); ``"task"``, ``"context"``,
                ``"methods"``, ``"results"`` and ``"history"`` fill the
                template placeholders. All but ``"task"`` are rendered as
                indented JSON.
            examples: Optional list of dicts, each with ``"input"`` and
                ``"output"`` keys, appended after the rendered template.

        Returns:
            The rendered prompt text.

        Raises:
            ValueError: If no template is registered for the prompt type.
            KeyError: If an example lacks ``"input"`` or ``"output"``.
            Exception: Any other failure (for example a value that cannot be
                serialized to JSON) is logged and re-raised unchanged.
        """
        try:
            # Resolve the type once so the lookup and the error message agree.
            prompt_type = context.get("type", self._DEFAULT_PROMPT_TYPE)
            entry = self.prompt_templates.get(prompt_type)
            if not entry:
                raise ValueError(f"Unknown prompt type: {prompt_type}")

            # Every placeholder is always supplied; ``str.format`` ignores
            # the ones a given template does not use.
            prompt = entry["template"].format(
                task=context.get("task", ""),
                context=json.dumps(context.get("context", {}), indent=2),
                methods=json.dumps(context.get("methods", {}), indent=2),
                results=json.dumps(context.get("results", {}), indent=2),
                history=json.dumps(context.get("history", []), indent=2),
            )

            if examples:
                example_lines = "".join(
                    f"- {ex['input']} => {ex['output']}\n" for ex in examples
                )
                prompt = f"{prompt}\nExamples:\n{example_lines}"

            return prompt

        except Exception as exc:
            logger.error("Error generating prompt: %s", exc)
            raise

    def validate_response(self, response: Any, schema: Dict) -> bool:
        """Validate an AI response against a JSON Schema.

        Args:
            response: Either the raw JSON text of the response or an
                already-parsed JSON value.
            schema: JSON Schema the response must satisfy.

        Returns:
            ``True`` if the response satisfies the schema. ``False`` if the
            text is not valid JSON, the ``jsonschema`` package cannot be
            imported, or validation fails for any reason. This method logs
            failures and never raises.
        """
        if isinstance(response, str):
            try:
                response = json.loads(response)
            except (json.JSONDecodeError, RecursionError) as exc:
                logger.warning("Error validating response: invalid JSON: %s", exc)
                return False

        # Imported here rather than at module level, so only this method
        # requires the third-party package.
        try:
            import jsonschema
        except ImportError as exc:
            logger.error(
                "Error validating response: jsonschema is unavailable: %s", exc
            )
            return False

        try:
            jsonschema.validate(instance=response, schema=schema)
        except Exception as exc:
            # Covers both a response that violates the schema and a schema
            # that is itself invalid; callers only receive a boolean.
            logger.error("Error validating response: %s", exc)
            return False

        return True

    def get_schema(self, prompt_type: str) -> Optional[Dict]:
        """Return the response schema registered for ``prompt_type``.

        Returns ``None`` when the prompt type is unknown or its entry has no
        ``"schema"`` key.
        """
        entry = self.prompt_templates.get(prompt_type)
        return entry.get("schema") if entry else None

    def add_template(self, name: str, template: str, schema: Dict) -> None:
        """Register a prompt template, replacing any existing one of that name.

        Args:
            name: Prompt type name used to look the template up.
            template: ``str.format`` template text.
            schema: JSON Schema for responses to this prompt.
        """
        self.prompt_templates[name] = {
            "template": template,
            "schema": schema,
        }

    def remove_template(self, name: str) -> None:
        """Remove the template called ``name``; do nothing if it is absent."""
        self.prompt_templates.pop(name, None)