Merge pull request #144 from valueonag/feat/demo-system-readieness
Feat/demo system readieness
This commit is contained in:
commit
e8abd553d0
14 changed files with 1180 additions and 20 deletions
3
app.py
3
app.py
|
|
@ -602,6 +602,9 @@ app.include_router(googleRouter)
|
||||||
from modules.routes.routeSecurityClickup import router as clickupRouter
|
from modules.routes.routeSecurityClickup import router as clickupRouter
|
||||||
app.include_router(clickupRouter)
|
app.include_router(clickupRouter)
|
||||||
|
|
||||||
|
from modules.routes.routeSecurityInfomaniak import router as infomaniakRouter
|
||||||
|
app.include_router(infomaniakRouter)
|
||||||
|
|
||||||
from modules.routes.routeClickup import router as clickupApiRouter
|
from modules.routes.routeClickup import router as clickupApiRouter
|
||||||
app.include_router(clickupApiRouter)
|
app.include_router(clickupApiRouter)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -40,3 +40,16 @@ msftDataScopes = [
|
||||||
def msftDataScopesForRefresh() -> str:
|
def msftDataScopesForRefresh() -> str:
|
||||||
"""Space-separated scope string identical to authorization request (Token v2 refresh)."""
|
"""Space-separated scope string identical to authorization request (Token v2 refresh)."""
|
||||||
return " ".join(msftDataScopes)
|
return " ".join(msftDataScopes)
|
||||||
|
|
||||||
|
|
||||||
|
# Infomaniak — Data app (kDrive + Mail; user_info needed for /1/profile lookup)
|
||||||
|
infomaniakDataScopes = [
|
||||||
|
"user_info",
|
||||||
|
"kdrive",
|
||||||
|
"mail",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def infomaniakDataScopesForRefresh() -> str:
|
||||||
|
"""Space-separated scope string identical to authorization request."""
|
||||||
|
return " ".join(infomaniakDataScopes)
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from modules.datamodels.datamodelSecurity import Token, TokenPurpose
|
||||||
from modules.datamodels.datamodelUam import AuthAuthority
|
from modules.datamodels.datamodelUam import AuthAuthority
|
||||||
from modules.shared.configuration import APP_CONFIG
|
from modules.shared.configuration import APP_CONFIG
|
||||||
from modules.shared.timeUtils import getUtcTimestamp, createExpirationTimestamp, parseTimestamp
|
from modules.shared.timeUtils import getUtcTimestamp, createExpirationTimestamp, parseTimestamp
|
||||||
from modules.auth.oauthProviderConfig import msftDataScopesForRefresh
|
from modules.auth.oauthProviderConfig import msftDataScopesForRefresh, infomaniakDataScopesForRefresh
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -29,6 +29,10 @@ class TokenManager:
|
||||||
# Google Data-app OAuth
|
# Google Data-app OAuth
|
||||||
self.google_client_id = APP_CONFIG.get("Service_GOOGLE_DATA_CLIENT_ID")
|
self.google_client_id = APP_CONFIG.get("Service_GOOGLE_DATA_CLIENT_ID")
|
||||||
self.google_client_secret = APP_CONFIG.get("Service_GOOGLE_DATA_CLIENT_SECRET")
|
self.google_client_secret = APP_CONFIG.get("Service_GOOGLE_DATA_CLIENT_SECRET")
|
||||||
|
|
||||||
|
# Infomaniak Data OAuth (kDrive + Mail)
|
||||||
|
self.infomaniak_client_id = APP_CONFIG.get("Service_INFOMANIAK_DATA_CLIENT_ID")
|
||||||
|
self.infomaniak_client_secret = APP_CONFIG.get("Service_INFOMANIAK_DATA_CLIENT_SECRET")
|
||||||
|
|
||||||
def refreshMicrosoftToken(self, refreshToken: str, userId: str, oldToken: Token) -> Optional[Token]:
|
def refreshMicrosoftToken(self, refreshToken: str, userId: str, oldToken: Token) -> Optional[Token]:
|
||||||
"""Refresh Microsoft OAuth token using refresh token"""
|
"""Refresh Microsoft OAuth token using refresh token"""
|
||||||
|
|
@ -161,7 +165,66 @@ class TokenManager:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error refreshing Google token: {str(e)}")
|
logger.error(f"Error refreshing Google token: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def refreshInfomaniakToken(self, refreshToken: str, userId: str, oldToken: Token) -> Optional[Token]:
|
||||||
|
"""Refresh Infomaniak OAuth token using refresh token"""
|
||||||
|
try:
|
||||||
|
logger.debug(f"refreshInfomaniakToken: Starting Infomaniak token refresh for user {userId}")
|
||||||
|
|
||||||
|
if not self.infomaniak_client_id or not self.infomaniak_client_secret:
|
||||||
|
logger.error("Infomaniak OAuth configuration not found")
|
||||||
|
return None
|
||||||
|
|
||||||
|
tokenUrl = "https://login.infomaniak.com/token"
|
||||||
|
data = {
|
||||||
|
"client_id": self.infomaniak_client_id,
|
||||||
|
"client_secret": self.infomaniak_client_secret,
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"refresh_token": refreshToken,
|
||||||
|
"scope": infomaniakDataScopesForRefresh(),
|
||||||
|
}
|
||||||
|
|
||||||
|
with httpx.Client(timeout=30.0) as client:
|
||||||
|
response = client.post(tokenUrl, data=data)
|
||||||
|
logger.debug(f"refreshInfomaniakToken: HTTP response status: {response.status_code}")
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
tokenData = response.json()
|
||||||
|
if "access_token" not in tokenData:
|
||||||
|
logger.error("Infomaniak token refresh response missing access_token")
|
||||||
|
return None
|
||||||
|
|
||||||
|
newToken = Token(
|
||||||
|
userId=userId,
|
||||||
|
authority=AuthAuthority.INFOMANIAK,
|
||||||
|
connectionId=oldToken.connectionId,
|
||||||
|
tokenPurpose=TokenPurpose.DATA_CONNECTION,
|
||||||
|
tokenAccess=tokenData["access_token"],
|
||||||
|
tokenRefresh=tokenData.get("refresh_token", refreshToken),
|
||||||
|
tokenType=tokenData.get("token_type", "bearer"),
|
||||||
|
expiresAt=createExpirationTimestamp(tokenData.get("expires_in", 3600)),
|
||||||
|
createdAt=getUtcTimestamp(),
|
||||||
|
)
|
||||||
|
return newToken
|
||||||
|
|
||||||
|
logger.error(
|
||||||
|
f"Failed to refresh Infomaniak token: {response.status_code} - {response.text}"
|
||||||
|
)
|
||||||
|
if response.status_code == 400:
|
||||||
|
try:
|
||||||
|
errorData = response.json()
|
||||||
|
if errorData.get("error") == "invalid_grant":
|
||||||
|
logger.warning(
|
||||||
|
"Infomaniak refresh token is invalid or expired - user needs to re-authenticate"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error refreshing Infomaniak token: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
def refreshToken(self, oldToken: Token) -> Optional[Token]:
|
def refreshToken(self, oldToken: Token) -> Optional[Token]:
|
||||||
"""Refresh an expired token using the appropriate OAuth service"""
|
"""Refresh an expired token using the appropriate OAuth service"""
|
||||||
try:
|
try:
|
||||||
|
|
@ -205,6 +268,9 @@ class TokenManager:
|
||||||
elif oldToken.authority == AuthAuthority.GOOGLE:
|
elif oldToken.authority == AuthAuthority.GOOGLE:
|
||||||
logger.debug(f"refreshToken: Refreshing Google token")
|
logger.debug(f"refreshToken: Refreshing Google token")
|
||||||
return self.refreshGoogleToken(oldToken.tokenRefresh, oldToken.userId, oldToken)
|
return self.refreshGoogleToken(oldToken.tokenRefresh, oldToken.userId, oldToken)
|
||||||
|
elif oldToken.authority == AuthAuthority.INFOMANIAK:
|
||||||
|
logger.debug(f"refreshToken: Refreshing Infomaniak token")
|
||||||
|
return self.refreshInfomaniakToken(oldToken.tokenRefresh, oldToken.userId, oldToken)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Unknown authority for token refresh: {oldToken.authority}")
|
logger.warning(f"Unknown authority for token refresh: {oldToken.authority}")
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
|
|
@ -143,7 +143,46 @@ class TokenRefreshService:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error refreshing Microsoft token for connection {connection.id}: {str(e)}")
|
logger.error(f"Error refreshing Microsoft token for connection {connection.id}: {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def _refresh_infomaniak_token(self, interface, connection: UserConnection) -> bool:
|
||||||
|
"""Refresh Infomaniak OAuth token"""
|
||||||
|
try:
|
||||||
|
logger.debug(f"Refreshing Infomaniak token for connection {connection.id}")
|
||||||
|
|
||||||
|
current_token = interface.getConnectionToken(connection.id)
|
||||||
|
if not current_token:
|
||||||
|
logger.warning(f"No Infomaniak token found for connection {connection.id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
from modules.auth.tokenManager import TokenManager
|
||||||
|
token_manager = TokenManager()
|
||||||
|
|
||||||
|
refreshedToken = token_manager.refreshToken(current_token)
|
||||||
|
if refreshedToken:
|
||||||
|
interface.saveConnectionToken(refreshedToken)
|
||||||
|
interface.db.recordModify(UserConnection, connection.id, {
|
||||||
|
"lastChecked": getUtcTimestamp(),
|
||||||
|
"expiresAt": refreshedToken.expiresAt,
|
||||||
|
})
|
||||||
|
logger.info(f"Successfully refreshed Infomaniak token for connection {connection.id}")
|
||||||
|
try:
|
||||||
|
audit_logger.logSecurityEvent(
|
||||||
|
userId=str(connection.userId),
|
||||||
|
mandateId="system",
|
||||||
|
action="token_refresh",
|
||||||
|
details=f"Infomaniak token refreshed for connection {connection.id}",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return True
|
||||||
|
|
||||||
|
logger.warning(f"Failed to refresh Infomaniak token for connection {connection.id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error refreshing Infomaniak token for connection {connection.id}: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
async def refresh_expired_tokens(self, user_id: str) -> Dict[str, Any]:
|
async def refresh_expired_tokens(self, user_id: str) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Refresh expired OAuth tokens for a user
|
Refresh expired OAuth tokens for a user
|
||||||
|
|
@ -177,7 +216,7 @@ class TokenRefreshService:
|
||||||
for connection in connections:
|
for connection in connections:
|
||||||
# Only refresh expired OAuth connections
|
# Only refresh expired OAuth connections
|
||||||
if (connection.tokenStatus == 'expired' and
|
if (connection.tokenStatus == 'expired' and
|
||||||
connection.authority in [AuthAuthority.GOOGLE, AuthAuthority.MSFT]):
|
connection.authority in [AuthAuthority.GOOGLE, AuthAuthority.MSFT, AuthAuthority.INFOMANIAK]):
|
||||||
|
|
||||||
# Check rate limiting
|
# Check rate limiting
|
||||||
if self._is_rate_limited(connection.id):
|
if self._is_rate_limited(connection.id):
|
||||||
|
|
@ -194,6 +233,8 @@ class TokenRefreshService:
|
||||||
success = await self._refresh_google_token(root_interface, connection)
|
success = await self._refresh_google_token(root_interface, connection)
|
||||||
elif connection.authority == AuthAuthority.MSFT:
|
elif connection.authority == AuthAuthority.MSFT:
|
||||||
success = await self._refresh_microsoft_token(root_interface, connection)
|
success = await self._refresh_microsoft_token(root_interface, connection)
|
||||||
|
elif connection.authority == AuthAuthority.INFOMANIAK:
|
||||||
|
success = await self._refresh_infomaniak_token(root_interface, connection)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
refreshed_count += 1
|
refreshed_count += 1
|
||||||
|
|
@ -248,7 +289,7 @@ class TokenRefreshService:
|
||||||
# Only refresh active tokens that expire soon
|
# Only refresh active tokens that expire soon
|
||||||
if (connection.tokenStatus == 'active' and
|
if (connection.tokenStatus == 'active' and
|
||||||
connection.tokenExpiresAt and
|
connection.tokenExpiresAt and
|
||||||
connection.authority in [AuthAuthority.GOOGLE, AuthAuthority.MSFT]):
|
connection.authority in [AuthAuthority.GOOGLE, AuthAuthority.MSFT, AuthAuthority.INFOMANIAK]):
|
||||||
|
|
||||||
# Check if token expires within 5 minutes
|
# Check if token expires within 5 minutes
|
||||||
time_until_expiry = connection.tokenExpiresAt - current_time
|
time_until_expiry = connection.tokenExpiresAt - current_time
|
||||||
|
|
@ -269,6 +310,8 @@ class TokenRefreshService:
|
||||||
success = await self._refresh_google_token(root_interface, connection)
|
success = await self._refresh_google_token(root_interface, connection)
|
||||||
elif connection.authority == AuthAuthority.MSFT:
|
elif connection.authority == AuthAuthority.MSFT:
|
||||||
success = await self._refresh_microsoft_token(root_interface, connection)
|
success = await self._refresh_microsoft_token(root_interface, connection)
|
||||||
|
elif connection.authority == AuthAuthority.INFOMANIAK:
|
||||||
|
success = await self._refresh_infomaniak_token(root_interface, connection)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
refreshed_count += 1
|
refreshed_count += 1
|
||||||
|
|
|
||||||
|
|
@ -58,6 +58,12 @@ class ConnectorResolver:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("ClickupConnector not available")
|
logger.warning("ClickupConnector not available")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from modules.connectors.providerInfomaniak.connectorInfomaniak import InfomaniakConnector
|
||||||
|
ConnectorResolver._providerRegistry["infomaniak"] = InfomaniakConnector
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("InfomaniakConnector not available")
|
||||||
|
|
||||||
async def resolve(self, connectionId: str) -> ProviderConnector:
|
async def resolve(self, connectionId: str) -> ProviderConnector:
|
||||||
"""Resolve connectionId to a ProviderConnector with a fresh access token."""
|
"""Resolve connectionId to a ProviderConnector with a fresh access token."""
|
||||||
connection = await self._loadConnection(connectionId)
|
connection = await self._loadConnection(connectionId)
|
||||||
|
|
|
||||||
3
modules/connectors/providerInfomaniak/__init__.py
Normal file
3
modules/connectors/providerInfomaniak/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
# Copyright (c) 2025 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Infomaniak Provider Connector -- 1 Connection : n Services (kDrive, Mail)."""
|
||||||
420
modules/connectors/providerInfomaniak/connectorInfomaniak.py
Normal file
420
modules/connectors/providerInfomaniak/connectorInfomaniak.py
Normal file
|
|
@ -0,0 +1,420 @@
|
||||||
|
# Copyright (c) 2025 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Infomaniak ProviderConnector -- kDrive and Mail via Infomaniak OAuth.
|
||||||
|
|
||||||
|
All ServiceAdapters share the same OAuth access token obtained from the
|
||||||
|
UserConnection (authority=infomaniak).
|
||||||
|
|
||||||
|
Path conventions (leading slash):
|
||||||
|
kDrive:
|
||||||
|
/ -- list drives the user has access to
|
||||||
|
/{driveId} -- root folder of a drive (children)
|
||||||
|
/{driveId}/{fileId} -- folder children OR file (download)
|
||||||
|
Mail:
|
||||||
|
/ -- list user's mailboxes
|
||||||
|
/{mailboxId} -- folders in mailbox
|
||||||
|
/{mailboxId}/{folderId} -- messages in folder
|
||||||
|
/{mailboxId}/{folderId}/{uid} -- single message (download as .eml)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from modules.connectors.connectorProviderBase import (
|
||||||
|
ProviderConnector,
|
||||||
|
ServiceAdapter,
|
||||||
|
DownloadResult,
|
||||||
|
)
|
||||||
|
from modules.datamodels.datamodelDataSource import ExternalEntry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_API_BASE = "https://api.infomaniak.com"
|
||||||
|
|
||||||
|
|
||||||
|
async def _infomaniakGet(token: str, endpoint: str) -> Dict[str, Any]:
|
||||||
|
"""Single GET call against the Infomaniak API. Returns parsed JSON or {'error': ...}."""
|
||||||
|
url = f"{_API_BASE}/{endpoint.lstrip('/')}"
|
||||||
|
headers = {"Authorization": f"Bearer {token}", "Accept": "application/json"}
|
||||||
|
timeout = aiohttp.ClientTimeout(total=20)
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
|
async with session.get(url, headers=headers) as resp:
|
||||||
|
if resp.status in (200, 201):
|
||||||
|
return await resp.json()
|
||||||
|
errorText = await resp.text()
|
||||||
|
logger.warning(f"Infomaniak API {resp.status}: {errorText[:300]}")
|
||||||
|
return {"error": f"{resp.status}: {errorText[:200]}"}
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
async def _infomaniakDownload(token: str, endpoint: str) -> Optional[bytes]:
|
||||||
|
"""Binary download from the Infomaniak API. Returns bytes or None on error."""
|
||||||
|
url = f"{_API_BASE}/{endpoint.lstrip('/')}"
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
timeout = aiohttp.ClientTimeout(total=120)
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
|
async with session.get(url, headers=headers) as resp:
|
||||||
|
if resp.status == 200:
|
||||||
|
return await resp.read()
|
||||||
|
logger.warning(f"Infomaniak download {resp.status}: {(await resp.text())[:300]}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Infomaniak download error: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _unwrapData(payload: Any) -> Any:
|
||||||
|
"""Infomaniak wraps successful responses as ``{result: 'success', data: ...}``."""
|
||||||
|
if isinstance(payload, dict) and "data" in payload and "result" in payload:
|
||||||
|
return payload.get("data")
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
class KdriveAdapter(ServiceAdapter):
|
||||||
|
"""kDrive ServiceAdapter -- browse drives, folders, and files."""
|
||||||
|
|
||||||
|
def __init__(self, accessToken: str):
|
||||||
|
self._token = accessToken
|
||||||
|
|
||||||
|
async def browse(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
filter: Optional[str] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
) -> List[ExternalEntry]:
|
||||||
|
cleanPath = (path or "").strip("/")
|
||||||
|
segments = [s for s in cleanPath.split("/") if s]
|
||||||
|
|
||||||
|
if not segments:
|
||||||
|
return await self._listDrives()
|
||||||
|
|
||||||
|
driveId = segments[0]
|
||||||
|
if len(segments) == 1:
|
||||||
|
return await self._listChildren(driveId, fileId=None, limit=limit)
|
||||||
|
|
||||||
|
fileId = segments[-1]
|
||||||
|
return await self._listChildren(driveId, fileId=fileId, limit=limit)
|
||||||
|
|
||||||
|
async def _listDrives(self) -> List[ExternalEntry]:
|
||||||
|
result = await _infomaniakGet(self._token, "/2/drive")
|
||||||
|
if isinstance(result, dict) and result.get("error"):
|
||||||
|
logger.warning(f"kDrive list-drives failed: {result['error']}")
|
||||||
|
return []
|
||||||
|
data = _unwrapData(result)
|
||||||
|
drives = data.get("drives", {}).get("accounts", []) if isinstance(data, dict) else []
|
||||||
|
if not drives and isinstance(data, list):
|
||||||
|
drives = data
|
||||||
|
entries: List[ExternalEntry] = []
|
||||||
|
for drive in drives:
|
||||||
|
driveId = str(drive.get("id", ""))
|
||||||
|
if not driveId:
|
||||||
|
continue
|
||||||
|
name = drive.get("name") or driveId
|
||||||
|
entries.append(ExternalEntry(
|
||||||
|
name=name,
|
||||||
|
path=f"/{driveId}",
|
||||||
|
isFolder=True,
|
||||||
|
metadata={"id": driveId, "kind": "drive"},
|
||||||
|
))
|
||||||
|
return entries
|
||||||
|
|
||||||
|
async def _listChildren(
|
||||||
|
self,
|
||||||
|
driveId: str,
|
||||||
|
fileId: Optional[str],
|
||||||
|
limit: Optional[int],
|
||||||
|
) -> List[ExternalEntry]:
|
||||||
|
# Infomaniak treats every folder (including drive root) as a file-id.
|
||||||
|
# When fileId is None, we ask the drive for root children via the
|
||||||
|
# documented `/files` collection endpoint.
|
||||||
|
if fileId is None:
|
||||||
|
endpoint = f"/2/drive/{driveId}/files"
|
||||||
|
else:
|
||||||
|
endpoint = f"/2/drive/{driveId}/files/{fileId}/files"
|
||||||
|
|
||||||
|
pageSize = max(1, min(int(limit or 200), 1000))
|
||||||
|
endpoint = f"{endpoint}?per_page={pageSize}"
|
||||||
|
|
||||||
|
result = await _infomaniakGet(self._token, endpoint)
|
||||||
|
if isinstance(result, dict) and result.get("error"):
|
||||||
|
logger.warning(f"kDrive list-children {driveId}/{fileId or 'root'} failed: {result['error']}")
|
||||||
|
return []
|
||||||
|
data = _unwrapData(result)
|
||||||
|
items = data if isinstance(data, list) else data.get("items", []) if isinstance(data, dict) else []
|
||||||
|
|
||||||
|
entries: List[ExternalEntry] = []
|
||||||
|
for item in items:
|
||||||
|
itemId = str(item.get("id", ""))
|
||||||
|
if not itemId:
|
||||||
|
continue
|
||||||
|
isFolder = item.get("type") == "dir"
|
||||||
|
entries.append(ExternalEntry(
|
||||||
|
name=item.get("name", itemId),
|
||||||
|
path=f"/{driveId}/{itemId}",
|
||||||
|
isFolder=isFolder,
|
||||||
|
size=item.get("size") if not isFolder else None,
|
||||||
|
mimeType=item.get("mime_type") if not isFolder else None,
|
||||||
|
lastModified=item.get("last_modified_at"),
|
||||||
|
metadata={"id": itemId, "kind": item.get("type", "")},
|
||||||
|
))
|
||||||
|
return entries
|
||||||
|
|
||||||
|
async def download(self, path: str) -> DownloadResult:
|
||||||
|
segments = [s for s in (path or "").strip("/").split("/") if s]
|
||||||
|
if len(segments) < 2:
|
||||||
|
return DownloadResult()
|
||||||
|
driveId, fileId = segments[0], segments[-1]
|
||||||
|
|
||||||
|
meta = await _infomaniakGet(self._token, f"/2/drive/{driveId}/files/{fileId}")
|
||||||
|
fileName = fileId
|
||||||
|
mimeType = "application/octet-stream"
|
||||||
|
if isinstance(meta, dict) and not meta.get("error"):
|
||||||
|
data = _unwrapData(meta)
|
||||||
|
if isinstance(data, dict):
|
||||||
|
fileName = data.get("name") or fileId
|
||||||
|
mimeType = data.get("mime_type") or mimeType
|
||||||
|
|
||||||
|
content = await _infomaniakDownload(self._token, f"/2/drive/{driveId}/files/{fileId}/download")
|
||||||
|
if content is None:
|
||||||
|
return DownloadResult()
|
||||||
|
return DownloadResult(data=content, fileName=fileName, mimeType=mimeType)
|
||||||
|
|
||||||
|
async def upload(self, path: str, data: bytes, fileName: str) -> dict:
|
||||||
|
return {"error": "kDrive upload not yet implemented"}
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
path: Optional[str] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
) -> List[ExternalEntry]:
|
||||||
|
segments = [s for s in (path or "").strip("/").split("/") if s]
|
||||||
|
if not segments:
|
||||||
|
drives = await self._listDrives()
|
||||||
|
if not drives:
|
||||||
|
return []
|
||||||
|
driveId = (drives[0].metadata or {}).get("id") or drives[0].path.strip("/")
|
||||||
|
else:
|
||||||
|
driveId = segments[0]
|
||||||
|
|
||||||
|
pageSize = max(1, min(int(limit or 50), 200))
|
||||||
|
endpoint = f"/2/drive/{driveId}/files/search?query={query}&per_page={pageSize}"
|
||||||
|
result = await _infomaniakGet(self._token, endpoint)
|
||||||
|
if isinstance(result, dict) and result.get("error"):
|
||||||
|
return []
|
||||||
|
data = _unwrapData(result)
|
||||||
|
items = data if isinstance(data, list) else data.get("items", []) if isinstance(data, dict) else []
|
||||||
|
|
||||||
|
entries: List[ExternalEntry] = []
|
||||||
|
for item in items:
|
||||||
|
itemId = str(item.get("id", ""))
|
||||||
|
if not itemId:
|
||||||
|
continue
|
||||||
|
isFolder = item.get("type") == "dir"
|
||||||
|
entries.append(ExternalEntry(
|
||||||
|
name=item.get("name", itemId),
|
||||||
|
path=f"/{driveId}/{itemId}",
|
||||||
|
isFolder=isFolder,
|
||||||
|
size=item.get("size") if not isFolder else None,
|
||||||
|
mimeType=item.get("mime_type") if not isFolder else None,
|
||||||
|
metadata={"id": itemId},
|
||||||
|
))
|
||||||
|
return entries
|
||||||
|
|
||||||
|
|
||||||
|
class MailAdapter(ServiceAdapter):
|
||||||
|
"""Infomaniak Mail ServiceAdapter -- browse mailboxes, folders and messages."""
|
||||||
|
|
||||||
|
_DEFAULT_MESSAGE_LIMIT = 100
|
||||||
|
_MAX_MESSAGE_LIMIT = 500
|
||||||
|
|
||||||
|
def __init__(self, accessToken: str):
|
||||||
|
self._token = accessToken
|
||||||
|
|
||||||
|
async def browse(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
filter: Optional[str] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
) -> List[ExternalEntry]:
|
||||||
|
cleanPath = (path or "").strip("/")
|
||||||
|
segments = [s for s in cleanPath.split("/") if s]
|
||||||
|
|
||||||
|
if not segments:
|
||||||
|
return await self._listMailboxes()
|
||||||
|
if len(segments) == 1:
|
||||||
|
return await self._listFolders(segments[0])
|
||||||
|
if len(segments) == 2:
|
||||||
|
return await self._listMessages(segments[0], segments[1], limit=limit)
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _listMailboxes(self) -> List[ExternalEntry]:
|
||||||
|
result = await _infomaniakGet(self._token, "/1/mail")
|
||||||
|
if isinstance(result, dict) and result.get("error"):
|
||||||
|
logger.warning(f"Mail list-mailboxes failed: {result['error']}")
|
||||||
|
return []
|
||||||
|
data = _unwrapData(result)
|
||||||
|
mailboxes = data if isinstance(data, list) else data.get("mailboxes", []) if isinstance(data, dict) else []
|
||||||
|
entries: List[ExternalEntry] = []
|
||||||
|
for mb in mailboxes:
|
||||||
|
mbId = str(mb.get("id") or mb.get("mailbox_id") or "")
|
||||||
|
if not mbId:
|
||||||
|
continue
|
||||||
|
entries.append(ExternalEntry(
|
||||||
|
name=mb.get("email") or mb.get("name") or mbId,
|
||||||
|
path=f"/{mbId}",
|
||||||
|
isFolder=True,
|
||||||
|
metadata={"id": mbId, "kind": "mailbox"},
|
||||||
|
))
|
||||||
|
return entries
|
||||||
|
|
||||||
|
async def _listFolders(self, mailboxId: str) -> List[ExternalEntry]:
|
||||||
|
result = await _infomaniakGet(self._token, f"/1/mail/{mailboxId}/folder")
|
||||||
|
if isinstance(result, dict) and result.get("error"):
|
||||||
|
logger.warning(f"Mail list-folders {mailboxId} failed: {result['error']}")
|
||||||
|
return []
|
||||||
|
data = _unwrapData(result)
|
||||||
|
folders = data if isinstance(data, list) else data.get("folders", []) if isinstance(data, dict) else []
|
||||||
|
entries: List[ExternalEntry] = []
|
||||||
|
for f in folders:
|
||||||
|
folderId = str(f.get("id") or f.get("path") or "")
|
||||||
|
if not folderId:
|
||||||
|
continue
|
||||||
|
entries.append(ExternalEntry(
|
||||||
|
name=f.get("name") or folderId,
|
||||||
|
path=f"/{mailboxId}/{folderId}",
|
||||||
|
isFolder=True,
|
||||||
|
metadata={"id": folderId, "kind": "folder"},
|
||||||
|
))
|
||||||
|
return entries
|
||||||
|
|
||||||
|
async def _listMessages(
|
||||||
|
self,
|
||||||
|
mailboxId: str,
|
||||||
|
folderId: str,
|
||||||
|
limit: Optional[int],
|
||||||
|
) -> List[ExternalEntry]:
|
||||||
|
effectiveLimit = self._DEFAULT_MESSAGE_LIMIT if limit is None else max(
|
||||||
|
1, min(int(limit), self._MAX_MESSAGE_LIMIT),
|
||||||
|
)
|
||||||
|
endpoint = f"/1/mail/{mailboxId}/folder/{folderId}/message?per_page={effectiveLimit}"
|
||||||
|
result = await _infomaniakGet(self._token, endpoint)
|
||||||
|
if isinstance(result, dict) and result.get("error"):
|
||||||
|
return []
|
||||||
|
data = _unwrapData(result)
|
||||||
|
messages = data if isinstance(data, list) else data.get("messages", []) if isinstance(data, dict) else []
|
||||||
|
|
||||||
|
entries: List[ExternalEntry] = []
|
||||||
|
for msg in messages:
|
||||||
|
uid = str(msg.get("uid") or msg.get("id") or "")
|
||||||
|
if not uid:
|
||||||
|
continue
|
||||||
|
subject = msg.get("subject") or "(no subject)"
|
||||||
|
entries.append(ExternalEntry(
|
||||||
|
name=subject,
|
||||||
|
path=f"/{mailboxId}/{folderId}/{uid}",
|
||||||
|
isFolder=False,
|
||||||
|
lastModified=msg.get("date") or msg.get("internal_date"),
|
||||||
|
metadata={
|
||||||
|
"uid": uid,
|
||||||
|
"from": msg.get("from") or msg.get("sender", ""),
|
||||||
|
"snippet": msg.get("preview", ""),
|
||||||
|
},
|
||||||
|
))
|
||||||
|
return entries
|
||||||
|
|
||||||
|
async def download(self, path: str) -> DownloadResult:
|
||||||
|
import re
|
||||||
|
segments = [s for s in (path or "").strip("/").split("/") if s]
|
||||||
|
if len(segments) < 3:
|
||||||
|
return DownloadResult()
|
||||||
|
mailboxId, folderId, uid = segments[0], segments[1], segments[2]
|
||||||
|
|
||||||
|
content = await _infomaniakDownload(
|
||||||
|
self._token, f"/1/mail/{mailboxId}/folder/{folderId}/message/{uid}/download",
|
||||||
|
)
|
||||||
|
if content is None:
|
||||||
|
return DownloadResult()
|
||||||
|
|
||||||
|
meta = await _infomaniakGet(
|
||||||
|
self._token, f"/1/mail/{mailboxId}/folder/{folderId}/message/{uid}",
|
||||||
|
)
|
||||||
|
subject = uid
|
||||||
|
if isinstance(meta, dict) and not meta.get("error"):
|
||||||
|
unwrapped = _unwrapData(meta)
|
||||||
|
if isinstance(unwrapped, dict):
|
||||||
|
subject = unwrapped.get("subject") or uid
|
||||||
|
safeName = re.sub(r'[<>:"/\\|?*\x00-\x1f]', "_", subject)[:80].strip(". ") or "email"
|
||||||
|
return DownloadResult(
|
||||||
|
data=content,
|
||||||
|
fileName=f"{safeName}.eml",
|
||||||
|
mimeType="message/rfc822",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def upload(self, path: str, data: bytes, fileName: str) -> dict:
|
||||||
|
return {"error": "Mail upload not applicable"}
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
path: Optional[str] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
) -> List[ExternalEntry]:
|
||||||
|
segments = [s for s in (path or "").strip("/").split("/") if s]
|
||||||
|
if not segments:
|
||||||
|
mailboxes = await self._listMailboxes()
|
||||||
|
if not mailboxes:
|
||||||
|
return []
|
||||||
|
mailboxId = (mailboxes[0].metadata or {}).get("id") or mailboxes[0].path.strip("/")
|
||||||
|
else:
|
||||||
|
mailboxId = segments[0]
|
||||||
|
|
||||||
|
effectiveLimit = self._DEFAULT_MESSAGE_LIMIT if limit is None else max(
|
||||||
|
1, min(int(limit), self._MAX_MESSAGE_LIMIT),
|
||||||
|
)
|
||||||
|
endpoint = f"/1/mail/{mailboxId}/message/search?query={query}&per_page={effectiveLimit}"
|
||||||
|
result = await _infomaniakGet(self._token, endpoint)
|
||||||
|
if isinstance(result, dict) and result.get("error"):
|
||||||
|
return []
|
||||||
|
data = _unwrapData(result)
|
||||||
|
messages = data if isinstance(data, list) else data.get("messages", []) if isinstance(data, dict) else []
|
||||||
|
|
||||||
|
entries: List[ExternalEntry] = []
|
||||||
|
for msg in messages:
|
||||||
|
uid = str(msg.get("uid") or msg.get("id") or "")
|
||||||
|
if not uid:
|
||||||
|
continue
|
||||||
|
folderId = str(msg.get("folder_id") or msg.get("folderId") or "")
|
||||||
|
entries.append(ExternalEntry(
|
||||||
|
name=msg.get("subject") or uid,
|
||||||
|
path=f"/{mailboxId}/{folderId}/{uid}" if folderId else f"/{mailboxId}/{uid}",
|
||||||
|
isFolder=False,
|
||||||
|
metadata={"uid": uid, "from": msg.get("from", "")},
|
||||||
|
))
|
||||||
|
return entries
|
||||||
|
|
||||||
|
|
||||||
|
class InfomaniakConnector(ProviderConnector):
|
||||||
|
"""Infomaniak ProviderConnector -- 1 connection -> kDrive + Mail."""
|
||||||
|
|
||||||
|
_SERVICE_MAP = {
|
||||||
|
"kdrive": KdriveAdapter,
|
||||||
|
"mail": MailAdapter,
|
||||||
|
}
|
||||||
|
|
||||||
|
def getAvailableServices(self) -> List[str]:
|
||||||
|
return list(self._SERVICE_MAP.keys())
|
||||||
|
|
||||||
|
def getServiceAdapter(self, service: str) -> ServiceAdapter:
|
||||||
|
adapterClass = self._SERVICE_MAP.get(service)
|
||||||
|
if not adapterClass:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown Infomaniak service: {service}. "
|
||||||
|
f"Available: {list(self._SERVICE_MAP.keys())}"
|
||||||
|
)
|
||||||
|
return adapterClass(self.accessToken)
|
||||||
|
|
@ -32,6 +32,7 @@ class AuthAuthority(str, Enum):
|
||||||
GOOGLE = "google"
|
GOOGLE = "google"
|
||||||
MSFT = "msft"
|
MSFT = "msft"
|
||||||
CLICKUP = "clickup"
|
CLICKUP = "clickup"
|
||||||
|
INFOMANIAK = "infomaniak"
|
||||||
|
|
||||||
class ConnectionStatus(str, Enum):
|
class ConnectionStatus(str, Enum):
|
||||||
ACTIVE = "active"
|
ACTIVE = "active"
|
||||||
|
|
|
||||||
|
|
@ -116,6 +116,7 @@ def get_auth_authority_options(
|
||||||
"google": "Google",
|
"google": "Google",
|
||||||
"msft": "Microsoft",
|
"msft": "Microsoft",
|
||||||
"clickup": "ClickUp",
|
"clickup": "ClickUp",
|
||||||
|
"infomaniak": "Infomaniak",
|
||||||
}
|
}
|
||||||
return [
|
return [
|
||||||
{"value": auth.value, "label": authorityLabels.get(auth.value, auth.value)}
|
{"value": auth.value, "label": authorityLabels.get(auth.value, auth.value)}
|
||||||
|
|
@ -329,6 +330,7 @@ def create_connection(
|
||||||
'msft': AuthAuthority.MSFT,
|
'msft': AuthAuthority.MSFT,
|
||||||
'google': AuthAuthority.GOOGLE,
|
'google': AuthAuthority.GOOGLE,
|
||||||
'clickup': AuthAuthority.CLICKUP,
|
'clickup': AuthAuthority.CLICKUP,
|
||||||
|
'infomaniak': AuthAuthority.INFOMANIAK,
|
||||||
}
|
}
|
||||||
|
|
||||||
authority = authority_map.get(connection_data.get('type'))
|
authority = authority_map.get(connection_data.get('type'))
|
||||||
|
|
@ -516,6 +518,8 @@ def connect_service(
|
||||||
auth_url = f"/api/google/auth/connect?connectionId={quote(connectionId, safe='')}"
|
auth_url = f"/api/google/auth/connect?connectionId={quote(connectionId, safe='')}"
|
||||||
elif connection.authority == AuthAuthority.CLICKUP:
|
elif connection.authority == AuthAuthority.CLICKUP:
|
||||||
auth_url = f"/api/clickup/auth/connect?connectionId={quote(connectionId, safe='')}"
|
auth_url = f"/api/clickup/auth/connect?connectionId={quote(connectionId, safe='')}"
|
||||||
|
elif connection.authority == AuthAuthority.INFOMANIAK:
|
||||||
|
auth_url = f"/api/infomaniak/auth/connect?connectionId={quote(connectionId, safe='')}"
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
|
||||||
328
modules/routes/routeSecurityInfomaniak.py
Normal file
328
modules/routes/routeSecurityInfomaniak.py
Normal file
|
|
@ -0,0 +1,328 @@
|
||||||
|
# Copyright (c) 2025 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Infomaniak OAuth for data connections (UserConnection + Token).
|
||||||
|
|
||||||
|
Pure DATA_CONNECTION flow -- Infomaniak is NOT a login authority for PowerOn.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request, status, Depends, Query
|
||||||
|
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Dict, Any
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
import httpx
|
||||||
|
from jose import jwt as jose_jwt
|
||||||
|
from jose import JWTError
|
||||||
|
|
||||||
|
from modules.shared.configuration import APP_CONFIG
|
||||||
|
from modules.interfaces.interfaceDbApp import getInterface, getRootInterface
|
||||||
|
from modules.datamodels.datamodelUam import AuthAuthority, User, ConnectionStatus, UserConnection
|
||||||
|
from modules.datamodels.datamodelSecurity import Token, TokenPurpose
|
||||||
|
from modules.auth import getCurrentUser, limiter, SECRET_KEY, ALGORITHM
|
||||||
|
from modules.auth.oauthProviderConfig import infomaniakDataScopes
|
||||||
|
from modules.shared.timeUtils import createExpirationTimestamp, getUtcTimestamp, parseTimestamp
|
||||||
|
from modules.shared.i18nRegistry import apiRouteContext
|
||||||
|
|
||||||
|
routeApiMsg = apiRouteContext("routeSecurityInfomaniak")
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_FLOW_CONNECT = "infomaniak_connect"
|
||||||
|
|
||||||
|
INFOMANIAK_AUTHORIZE_URL = "https://login.infomaniak.com/authorize"
|
||||||
|
INFOMANIAK_TOKEN_URL = "https://login.infomaniak.com/token"
|
||||||
|
INFOMANIAK_API_BASE = "https://api.infomaniak.com"
|
||||||
|
|
||||||
|
CLIENT_ID = APP_CONFIG.get("Service_INFOMANIAK_DATA_CLIENT_ID")
|
||||||
|
CLIENT_SECRET = APP_CONFIG.get("Service_INFOMANIAK_DATA_CLIENT_SECRET")
|
||||||
|
REDIRECT_URI = APP_CONFIG.get("Service_INFOMANIAK_OAUTH_REDIRECT_URI")
|
||||||
|
|
||||||
|
|
||||||
|
def _issue_oauth_state(claims: Dict[str, Any]) -> str:
|
||||||
|
body = {**claims, "exp": int(time.time()) + 600}
|
||||||
|
return jose_jwt.encode(body, SECRET_KEY, algorithm=ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_oauth_state(state: str) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
return jose_jwt.decode(state, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
except JWTError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid OAuth state: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
def _require_infomaniak_config():
|
||||||
|
if not CLIENT_ID or not CLIENT_SECRET or not REDIRECT_URI:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=routeApiMsg(
|
||||||
|
"Infomaniak OAuth is not configured "
|
||||||
|
"(Service_INFOMANIAK_DATA_CLIENT_ID, Service_INFOMANIAK_DATA_CLIENT_SECRET, "
|
||||||
|
"Service_INFOMANIAK_OAUTH_REDIRECT_URI)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/api/infomaniak",
|
||||||
|
tags=["Security Infomaniak"],
|
||||||
|
responses={
|
||||||
|
404: {"description": "Not found"},
|
||||||
|
400: {"description": "Bad request"},
|
||||||
|
401: {"description": "Unauthorized"},
|
||||||
|
500: {"description": "Internal server error"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/auth/connect")
|
||||||
|
@limiter.limit("5/minute")
|
||||||
|
def auth_connect(
|
||||||
|
request: Request,
|
||||||
|
connectionId: str = Query(..., description="UserConnection id"),
|
||||||
|
currentUser: User = Depends(getCurrentUser),
|
||||||
|
) -> RedirectResponse:
|
||||||
|
"""Start Infomaniak OAuth for an existing connection (requires gateway session)."""
|
||||||
|
try:
|
||||||
|
_require_infomaniak_config()
|
||||||
|
interface = getInterface(currentUser)
|
||||||
|
connections = interface.getUserConnections(currentUser.id)
|
||||||
|
connection = None
|
||||||
|
for conn in connections:
|
||||||
|
if conn.id == connectionId and conn.authority == AuthAuthority.INFOMANIAK:
|
||||||
|
connection = conn
|
||||||
|
break
|
||||||
|
if not connection:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=routeApiMsg("Infomaniak connection not found"),
|
||||||
|
)
|
||||||
|
|
||||||
|
state_jwt = _issue_oauth_state(
|
||||||
|
{
|
||||||
|
"flow": _FLOW_CONNECT,
|
||||||
|
"connectionId": connectionId,
|
||||||
|
"userId": str(currentUser.id),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
query = urlencode(
|
||||||
|
{
|
||||||
|
"client_id": CLIENT_ID,
|
||||||
|
"response_type": "code",
|
||||||
|
"access_type": "offline",
|
||||||
|
"redirect_uri": REDIRECT_URI,
|
||||||
|
"scope": " ".join(infomaniakDataScopes),
|
||||||
|
"state": state_jwt,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
auth_url = f"{INFOMANIAK_AUTHORIZE_URL}?{query}"
|
||||||
|
return RedirectResponse(auth_url)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error initiating Infomaniak connect: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to initiate Infomaniak connect: {str(e)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/auth/connect/callback")
|
||||||
|
async def auth_connect_callback(
|
||||||
|
code: str = Query(...),
|
||||||
|
state: str = Query(...),
|
||||||
|
) -> HTMLResponse:
|
||||||
|
"""OAuth callback for Infomaniak data connection."""
|
||||||
|
state_data = _parse_oauth_state(state)
|
||||||
|
if state_data.get("flow") != _FLOW_CONNECT:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail=routeApiMsg("Invalid OAuth flow for this callback")
|
||||||
|
)
|
||||||
|
connection_id = state_data.get("connectionId")
|
||||||
|
user_id = state_data.get("userId")
|
||||||
|
if not connection_id or not user_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail=routeApiMsg("Missing connection or user in OAuth state")
|
||||||
|
)
|
||||||
|
|
||||||
|
_require_infomaniak_config()
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
token_resp = await client.post(
|
||||||
|
INFOMANIAK_TOKEN_URL,
|
||||||
|
data={
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"client_id": CLIENT_ID,
|
||||||
|
"client_secret": CLIENT_SECRET,
|
||||||
|
"code": code,
|
||||||
|
"redirect_uri": REDIRECT_URI,
|
||||||
|
},
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
if token_resp.status_code != 200:
|
||||||
|
logger.error(
|
||||||
|
f"Infomaniak token exchange failed: {token_resp.status_code} {token_resp.text}"
|
||||||
|
)
|
||||||
|
return HTMLResponse(
|
||||||
|
content=f"<html><body><h1>Connection Failed</h1><p>{token_resp.text}</p></body></html>",
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
token_json = token_resp.json()
|
||||||
|
access_token = token_json.get("access_token")
|
||||||
|
refresh_token = token_json.get("refresh_token", "")
|
||||||
|
expires_in = int(token_json.get("expires_in", 0))
|
||||||
|
granted_scopes = token_json.get("scope", "")
|
||||||
|
|
||||||
|
if not access_token:
|
||||||
|
return HTMLResponse(
|
||||||
|
content="<html><body><h1>Connection Failed</h1><p>No access token.</p></body></html>",
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
rootInterface = getRootInterface()
|
||||||
|
if not refresh_token:
|
||||||
|
try:
|
||||||
|
existing_tokens = rootInterface.getTokensByConnectionIdAndAuthority(
|
||||||
|
connection_id, AuthAuthority.INFOMANIAK
|
||||||
|
)
|
||||||
|
if existing_tokens:
|
||||||
|
existing_tokens.sort(
|
||||||
|
key=lambda x: parseTimestamp(x.createdAt, default=0), reverse=True
|
||||||
|
)
|
||||||
|
refresh_token = existing_tokens[0].tokenRefresh or ""
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
profile_resp = await client.get(
|
||||||
|
f"{INFOMANIAK_API_BASE}/1/profile",
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
"Accept": "application/json",
|
||||||
|
},
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
if profile_resp.status_code != 200:
|
||||||
|
logger.error(
|
||||||
|
f"Infomaniak profile lookup failed: {profile_resp.status_code} {profile_resp.text}"
|
||||||
|
)
|
||||||
|
return HTMLResponse(
|
||||||
|
content="<html><body><h1>Connection Failed</h1><p>Could not load Infomaniak profile.</p></body></html>",
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
profile_payload = profile_resp.json()
|
||||||
|
profile = profile_payload.get("data") if isinstance(profile_payload, dict) else None
|
||||||
|
profile = profile or {}
|
||||||
|
|
||||||
|
user = rootInterface.getUser(user_id)
|
||||||
|
if not user:
|
||||||
|
return HTMLResponse(
|
||||||
|
content="""
|
||||||
|
<html><body><script>
|
||||||
|
if (window.opener) {
|
||||||
|
window.opener.postMessage({ type: 'infomaniak_connection_error', error: 'User not found' }, '*');
|
||||||
|
setTimeout(() => window.close(), 1000);
|
||||||
|
} else window.close();
|
||||||
|
</script></body></html>
|
||||||
|
""",
|
||||||
|
status_code=404,
|
||||||
|
)
|
||||||
|
|
||||||
|
interface = getInterface(user)
|
||||||
|
connections = interface.getUserConnections(user_id)
|
||||||
|
connection = None
|
||||||
|
for conn in connections:
|
||||||
|
if conn.id == connection_id:
|
||||||
|
connection = conn
|
||||||
|
break
|
||||||
|
if not connection:
|
||||||
|
return HTMLResponse(
|
||||||
|
content="""
|
||||||
|
<html><body><script>
|
||||||
|
if (window.opener) {
|
||||||
|
window.opener.postMessage({ type: 'infomaniak_connection_error', error: 'Connection not found' }, '*');
|
||||||
|
setTimeout(() => window.close(), 1000);
|
||||||
|
} else window.close();
|
||||||
|
</script></body></html>
|
||||||
|
""",
|
||||||
|
status_code=404,
|
||||||
|
)
|
||||||
|
|
||||||
|
ext_id = str(profile.get("id", "")) if profile.get("id") is not None else ""
|
||||||
|
username = profile.get("login") or profile.get("email") or ext_id
|
||||||
|
email = profile.get("email")
|
||||||
|
|
||||||
|
expires_at = createExpirationTimestamp(expires_in)
|
||||||
|
granted_scopes_list = (
|
||||||
|
granted_scopes
|
||||||
|
if isinstance(granted_scopes, list)
|
||||||
|
else (granted_scopes.split(" ") if granted_scopes else infomaniakDataScopes)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
connection.status = ConnectionStatus.ACTIVE
|
||||||
|
connection.lastChecked = getUtcTimestamp()
|
||||||
|
connection.expiresAt = expires_at
|
||||||
|
connection.externalId = ext_id
|
||||||
|
connection.externalUsername = username
|
||||||
|
if email:
|
||||||
|
connection.externalEmail = email
|
||||||
|
connection.grantedScopes = granted_scopes_list
|
||||||
|
rootInterface.db.recordModify(UserConnection, connection_id, connection.model_dump())
|
||||||
|
|
||||||
|
token = Token(
|
||||||
|
userId=user.id,
|
||||||
|
authority=AuthAuthority.INFOMANIAK,
|
||||||
|
connectionId=connection_id,
|
||||||
|
tokenPurpose=TokenPurpose.DATA_CONNECTION,
|
||||||
|
tokenAccess=access_token,
|
||||||
|
tokenRefresh=refresh_token,
|
||||||
|
tokenType=token_json.get("token_type", "bearer"),
|
||||||
|
expiresAt=expires_at,
|
||||||
|
createdAt=getUtcTimestamp(),
|
||||||
|
)
|
||||||
|
interface.saveConnectionToken(token)
|
||||||
|
|
||||||
|
return HTMLResponse(
|
||||||
|
content=f"""
|
||||||
|
<html>
|
||||||
|
<head><title>Connection Successful</title></head>
|
||||||
|
<body>
|
||||||
|
<script>
|
||||||
|
if (window.opener) {{
|
||||||
|
window.opener.postMessage({{
|
||||||
|
type: 'infomaniak_connection_success',
|
||||||
|
connection: {{
|
||||||
|
id: '{connection.id}',
|
||||||
|
status: 'connected',
|
||||||
|
type: 'infomaniak',
|
||||||
|
lastChecked: {getUtcTimestamp()},
|
||||||
|
expiresAt: {expires_at}
|
||||||
|
}}
|
||||||
|
}}, '*');
|
||||||
|
setTimeout(() => window.close(), 1000);
|
||||||
|
}} else {{
|
||||||
|
window.close();
|
||||||
|
}}
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating Infomaniak connection: {str(e)}", exc_info=True)
|
||||||
|
return HTMLResponse(
|
||||||
|
content=f"""
|
||||||
|
<html><body><script>
|
||||||
|
if (window.opener) {{
|
||||||
|
window.opener.postMessage({{ type: 'infomaniak_connection_error', error: {json.dumps(str(e))} }}, '*');
|
||||||
|
setTimeout(() => window.close(), 1000);
|
||||||
|
}} else window.close();
|
||||||
|
</script></body></html>
|
||||||
|
""",
|
||||||
|
status_code=500,
|
||||||
|
)
|
||||||
|
|
@ -297,6 +297,11 @@ async def runAgentLoop(
|
||||||
"featureInstanceId": featureInstanceId,
|
"featureInstanceId": featureInstanceId,
|
||||||
"mandateId": mandateId,
|
"mandateId": mandateId,
|
||||||
"modelMaxOutputTokens": getattr(aiResponse, "_modelMaxTokens", None) or 0,
|
"modelMaxOutputTokens": getattr(aiResponse, "_modelMaxTokens", None) or 0,
|
||||||
|
# Propagate the parent agent's budget to sub-agent tools (e.g.
|
||||||
|
# queryFeatureInstance) so they don't cap themselves at a smaller
|
||||||
|
# hardcoded round count than the user configured for the workspace.
|
||||||
|
"parentMaxRounds": state.maxRounds,
|
||||||
|
"parentMaxCostCHF": config.maxCostCHF,
|
||||||
})
|
})
|
||||||
state.totalToolCalls += len(results)
|
state.totalToolCalls += len(results)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -181,6 +181,15 @@ def _registerFeatureSubAgentTools(registry: ToolRegistry, services):
|
||||||
req.requireNeutralization = True
|
req.requireNeutralization = True
|
||||||
return await aiService.callAi(req)
|
return await aiService.callAi(req)
|
||||||
|
|
||||||
|
# Inherit the parent agent's round/cost budget so the sub-agent
|
||||||
|
# actually runs the maxRounds the user configured for the
|
||||||
|
# workspace (workspace user setting `maxAgentRounds` ->
|
||||||
|
# `AgentConfig.maxRounds` -> tool context). Without this the
|
||||||
|
# sub-agent caps itself at the legacy 8-round default and aborts
|
||||||
|
# mid-investigation on data-heavy questions.
|
||||||
|
parentMaxRounds = context.get("parentMaxRounds")
|
||||||
|
parentMaxCostCHF = context.get("parentMaxCostCHF")
|
||||||
|
|
||||||
answer = await runFeatureDataAgent(
|
answer = await runFeatureDataAgent(
|
||||||
question=question,
|
question=question,
|
||||||
featureInstanceId=featureInstanceId,
|
featureInstanceId=featureInstanceId,
|
||||||
|
|
@ -194,6 +203,8 @@ def _registerFeatureSubAgentTools(registry: ToolRegistry, services):
|
||||||
tableFilters=tableFilters,
|
tableFilters=tableFilters,
|
||||||
requestLang=requestLang,
|
requestLang=requestLang,
|
||||||
neutralizeFields=neutralizeFieldsPerTable if neutralizeFieldsPerTable else None,
|
neutralizeFields=neutralizeFieldsPerTable if neutralizeFieldsPerTable else None,
|
||||||
|
maxRounds=parentMaxRounds,
|
||||||
|
maxCostCHF=parentMaxCostCHF,
|
||||||
)
|
)
|
||||||
|
|
||||||
_featureQueryCache[cacheKey] = (time.time(), answer)
|
_featureQueryCache[cacheKey] = (time.time(), answer)
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,14 @@
|
||||||
"""Feature Data Sub-Agent.
|
"""Feature Data Sub-Agent.
|
||||||
|
|
||||||
Specialized mini-agent that queries feature-instance data tables. Receives
|
Specialized mini-agent that queries feature-instance data tables. Receives
|
||||||
schema context (fields, descriptions) for the selected tables and has two
|
schema context (fields, descriptions) for the selected tables and has its
|
||||||
tools: browseTable and queryTable. Runs its own agent loop (max 5 rounds,
|
tools: browseTable, queryTable, aggregateTable. Runs its own agent loop
|
||||||
low budget) and returns structured results back to the main agent.
|
and returns structured results back to the main agent.
|
||||||
|
|
||||||
|
Round/cost budgets are inherited from the parent agent (workspace user
|
||||||
|
setting `maxAgentRounds` -> `AgentConfig.maxRounds`) and propagated through
|
||||||
|
the tool-call context. Defaults below are only used when the sub-agent is
|
||||||
|
invoked outside an agent loop (e.g. in tests).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
@ -15,6 +20,7 @@ from typing import Any, Callable, Awaitable, Dict, List, Optional
|
||||||
from modules.datamodels.datamodelAi import (
|
from modules.datamodels.datamodelAi import (
|
||||||
AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum,
|
AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum,
|
||||||
)
|
)
|
||||||
|
from modules.datamodels.datamodelBase import MODEL_REGISTRY
|
||||||
from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop
|
from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop
|
||||||
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
|
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
|
||||||
AgentConfig, AgentEvent, AgentEventTypeEnum, ToolResult,
|
AgentConfig, AgentEvent, AgentEventTypeEnum, ToolResult,
|
||||||
|
|
@ -26,8 +32,11 @@ from modules.shared.timeUtils import getRequestNow, getRequestTimezone
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_MAX_ROUNDS = 8
|
_DEFAULT_MAX_ROUNDS = 8
|
||||||
_MAX_COST_CHF = 0.15
|
# Per-round CHF cap. Multiplied by `maxRounds` so the cost guard scales
|
||||||
|
# with the configured round budget instead of cutting the loop short.
|
||||||
|
# 0.15 / 8 ≈ 0.019 — round up to 0.02 for some headroom.
|
||||||
|
_MAX_COST_CHF_PER_ROUND = 0.02
|
||||||
|
|
||||||
|
|
||||||
async def runFeatureDataAgent(
|
async def runFeatureDataAgent(
|
||||||
|
|
@ -43,6 +52,8 @@ async def runFeatureDataAgent(
|
||||||
tableFilters: Optional[Dict[str, Dict[str, str]]] = None,
|
tableFilters: Optional[Dict[str, Dict[str, str]]] = None,
|
||||||
requestLang: Optional[str] = None,
|
requestLang: Optional[str] = None,
|
||||||
neutralizeFields: Optional[Dict[str, List[str]]] = None,
|
neutralizeFields: Optional[Dict[str, List[str]]] = None,
|
||||||
|
maxRounds: Optional[int] = None,
|
||||||
|
maxCostCHF: Optional[float] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Run the feature data sub-agent and return the textual result.
|
"""Run the feature data sub-agent and return the textual result.
|
||||||
|
|
||||||
|
|
@ -60,6 +71,12 @@ async def runFeatureDataAgent(
|
||||||
requestLang: ISO 639-1 code for resolving multilingual table labels in the schema prompt.
|
requestLang: ISO 639-1 code for resolving multilingual table labels in the schema prompt.
|
||||||
neutralizeFields: Per-table list of field names to mask with placeholders
|
neutralizeFields: Per-table list of field names to mask with placeholders
|
||||||
before returning data to the AI.
|
before returning data to the AI.
|
||||||
|
maxRounds: Inherited from the parent agent's configured `maxRounds`
|
||||||
|
(workspace user setting `maxAgentRounds` -> `AgentConfig.maxRounds`).
|
||||||
|
Falls back to the legacy 8-round default when not provided so direct
|
||||||
|
callers / tests still work.
|
||||||
|
maxCostCHF: Hard cost cap for the sub-agent run. When omitted, scales
|
||||||
|
with `maxRounds` to keep per-round budget constant.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Plain-text answer produced by the sub-agent.
|
Plain-text answer produced by the sub-agent.
|
||||||
|
|
@ -78,11 +95,22 @@ async def runFeatureDataAgent(
|
||||||
|
|
||||||
systemPrompt = _buildSchemaContext(featureCode, instanceLabel, selectedTables, requestLang)
|
systemPrompt = _buildSchemaContext(featureCode, instanceLabel, selectedTables, requestLang)
|
||||||
|
|
||||||
|
effectiveMaxRounds = int(maxRounds) if maxRounds and maxRounds > 0 else _DEFAULT_MAX_ROUNDS
|
||||||
|
effectiveMaxCost = (
|
||||||
|
float(maxCostCHF)
|
||||||
|
if maxCostCHF is not None and maxCostCHF > 0
|
||||||
|
else effectiveMaxRounds * _MAX_COST_CHF_PER_ROUND
|
||||||
|
)
|
||||||
|
|
||||||
config = AgentConfig(
|
config = AgentConfig(
|
||||||
maxRounds=_MAX_ROUNDS,
|
maxRounds=effectiveMaxRounds,
|
||||||
maxCostCHF=_MAX_COST_CHF,
|
maxCostCHF=effectiveMaxCost,
|
||||||
operationType=OperationTypeEnum.DATA_QUERY,
|
operationType=OperationTypeEnum.DATA_QUERY,
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
"Feature data sub-agent starting: featureInstanceId=%s, maxRounds=%d, maxCostCHF=%.4f",
|
||||||
|
featureInstanceId, effectiveMaxRounds, effectiveMaxCost,
|
||||||
|
)
|
||||||
|
|
||||||
costAccumulator = 0.0
|
costAccumulator = 0.0
|
||||||
|
|
||||||
|
|
@ -302,20 +330,26 @@ def _buildSchemaContext(
|
||||||
selectedTables: List[Dict[str, Any]],
|
selectedTables: List[Dict[str, Any]],
|
||||||
requestLang: Optional[str] = None,
|
requestLang: Optional[str] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Build a system prompt describing available tables and query strategy."""
|
"""Build a system prompt describing available tables and query strategy.
|
||||||
tableNames = []
|
|
||||||
tableBlocks = []
|
Per table the prompt now lists every selected column with its Python type,
|
||||||
|
German label, description and FK target (when available via the registered
|
||||||
|
Pydantic model). This gives the sub-agent enough context to:
|
||||||
|
* pick the right table (e.g. period-bucketed *AccountBalance over raw
|
||||||
|
JournalLine for "Saldo per <date>"),
|
||||||
|
* format date filters as UNIX timestamps when the type is float,
|
||||||
|
* recognise FK relations even though the tools cannot JOIN.
|
||||||
|
"""
|
||||||
|
tableNames: List[str] = []
|
||||||
|
tableBlocks: List[str] = []
|
||||||
|
|
||||||
for obj in selectedTables:
|
for obj in selectedTables:
|
||||||
meta = obj.get("meta", {})
|
meta = obj.get("meta", {})
|
||||||
tbl = meta.get("table", "?")
|
tbl = meta.get("table", "?")
|
||||||
fields = meta.get("fields", [])
|
fields = list(meta.get("fields") or [])
|
||||||
labelStr = resolveText(obj.get("label"), requestLang)
|
labelStr = resolveText(obj.get("label"), requestLang)
|
||||||
tableNames.append(tbl)
|
tableNames.append(tbl)
|
||||||
block = f" Table: {tbl} ({labelStr})"
|
tableBlocks.append(_buildTableSchemaBlock(tbl, labelStr, fields))
|
||||||
if fields:
|
|
||||||
block += f"\n Fields: {', '.join(fields)}"
|
|
||||||
tableBlocks.append(block)
|
|
||||||
|
|
||||||
header = f"You are a data query assistant for the '{featureCode}' feature"
|
header = f"You are a data query assistant for the '{featureCode}' feature"
|
||||||
if instanceLabel:
|
if instanceLabel:
|
||||||
|
|
@ -351,9 +385,96 @@ def _buildSchemaContext(
|
||||||
"",
|
"",
|
||||||
"RULES:",
|
"RULES:",
|
||||||
"- Do NOT invent table or field names. Do NOT prefix fields with UUIDs or dots.",
|
"- Do NOT invent table or field names. Do NOT prefix fields with UUIDs or dots.",
|
||||||
|
"- Float fields whose description mentions 'unix timestamp' (e.g. bookingDate, lastSyncAt) "
|
||||||
|
"store seconds since epoch. Convert dates to a unix-seconds float before filtering "
|
||||||
|
"(e.g. '2025-12-31' -> 1735603200.0); never compare such fields against ISO strings.",
|
||||||
|
"- The query tools operate on ONE table at a time and CANNOT JOIN. To combine related "
|
||||||
|
"tables (FK target shown in [FK -> Table.field]), query each separately and reason "
|
||||||
|
"about the link in your answer.",
|
||||||
|
"- When a table has period-bucketed aggregates (opening/closing balances or totals per "
|
||||||
|
"period), prefer it over recomputing the same aggregate from raw transactional rows.",
|
||||||
"- CRITICAL: Return data as compact JSON, NOT as markdown tables or prose.",
|
"- CRITICAL: Return data as compact JSON, NOT as markdown tables or prose.",
|
||||||
"- Do NOT reformat, rewrite, or narrate the tool results. Return the raw data directly.",
|
"- Do NOT reformat, rewrite, or narrate the tool results. Return the raw data directly.",
|
||||||
"- If the question asks for rows, return them as a JSON array. Do NOT generate a markdown table.",
|
"- If the question asks for rows, return them as a JSON array. Do NOT generate a markdown table.",
|
||||||
"- Keep your answer SHORT. The caller is a machine, not a human.",
|
"- Keep your answer SHORT. The caller is a machine, not a human.",
|
||||||
]
|
]
|
||||||
return "\n".join(parts)
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _buildTableSchemaBlock(tableName: str, tableLabel: str, fields: List[str]) -> str:
|
||||||
|
"""Render a single table's schema block, enriched from its Pydantic model.
|
||||||
|
|
||||||
|
Falls back to a flat field list when the model isn't registered (e.g. pure
|
||||||
|
UDB tables, or in early-startup contexts before datamodels are imported).
|
||||||
|
"""
|
||||||
|
headerLine = f' Table: {tableName} "{tableLabel}"'
|
||||||
|
|
||||||
|
modelClass = MODEL_REGISTRY.get(tableName)
|
||||||
|
docLine = ""
|
||||||
|
if modelClass is not None:
|
||||||
|
rawDoc = (modelClass.__doc__ or "").strip()
|
||||||
|
if rawDoc:
|
||||||
|
docLine = " Description: " + " ".join(rawDoc.split())
|
||||||
|
|
||||||
|
if not fields:
|
||||||
|
return headerLine + (("\n" + docLine) if docLine else "")
|
||||||
|
|
||||||
|
if modelClass is None:
|
||||||
|
return headerLine + f"\n Fields: {', '.join(fields)}"
|
||||||
|
|
||||||
|
fieldSet = set(fields)
|
||||||
|
fieldLines: List[str] = []
|
||||||
|
for fieldName, fieldInfo in modelClass.model_fields.items():
|
||||||
|
if fieldName not in fieldSet:
|
||||||
|
continue
|
||||||
|
fieldLines.append(" - " + _formatFieldLine(fieldName, fieldInfo))
|
||||||
|
|
||||||
|
extras = sorted(fieldSet.difference(modelClass.model_fields.keys()))
|
||||||
|
for extra in extras:
|
||||||
|
fieldLines.append(f" - {extra} (unknown)")
|
||||||
|
|
||||||
|
parts = [headerLine]
|
||||||
|
if docLine:
|
||||||
|
parts.append(docLine)
|
||||||
|
parts.append(" Fields:")
|
||||||
|
parts.extend(fieldLines)
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _formatFieldLine(fieldName: str, fieldInfo: Any) -> str:
|
||||||
|
"""Format one field as: '<name> (<type>) "<label>": <description> [FK -> Table.field]'."""
|
||||||
|
pyType = _summarizePythonType(getattr(fieldInfo, "annotation", None))
|
||||||
|
|
||||||
|
extra = getattr(fieldInfo, "json_schema_extra", None)
|
||||||
|
if not isinstance(extra, dict):
|
||||||
|
extra = {}
|
||||||
|
|
||||||
|
rawLabel = extra.get("label")
|
||||||
|
label = rawLabel if isinstance(rawLabel, str) else None
|
||||||
|
|
||||||
|
rawDesc = getattr(fieldInfo, "description", None)
|
||||||
|
desc = rawDesc.strip() if isinstance(rawDesc, str) else ""
|
||||||
|
|
||||||
|
line = f"{fieldName} ({pyType})"
|
||||||
|
if label and label != fieldName:
|
||||||
|
line += f' "{label}"'
|
||||||
|
if desc:
|
||||||
|
line += f": {desc}"
|
||||||
|
|
||||||
|
fkTarget = extra.get("fk_target")
|
||||||
|
if isinstance(fkTarget, dict) and fkTarget.get("table"):
|
||||||
|
targetField = fkTarget.get("targetField") or "id"
|
||||||
|
line += f" [FK -> {fkTarget['table']}.{targetField}]"
|
||||||
|
|
||||||
|
return line
|
||||||
|
|
||||||
|
|
||||||
|
def _summarizePythonType(annotation: Any) -> str:
|
||||||
|
"""Compact stringification of a Pydantic field annotation for AI prompts."""
|
||||||
|
if annotation is None:
|
||||||
|
return "any"
|
||||||
|
raw = str(annotation)
|
||||||
|
raw = raw.replace("typing.", "")
|
||||||
|
if raw.startswith("<class '") and raw.endswith("'>"):
|
||||||
|
raw = raw[len("<class '"):-len("'>")]
|
||||||
|
return raw
|
||||||
|
|
|
||||||
136
tests/unit/services/test_featureDataAgent_schema.py
Normal file
136
tests/unit/services/test_featureDataAgent_schema.py
Normal file
|
|
@ -0,0 +1,136 @@
|
||||||
|
# Copyright (c) 2026 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Unit test: feature data sub-agent schema context is rich enough.
|
||||||
|
|
||||||
|
The sub-agent's quality of answers depends almost entirely on the schema
|
||||||
|
prompt it receives. This test guards the contract that, for every selected
|
||||||
|
table, the prompt exposes:
|
||||||
|
|
||||||
|
* the technical table name + i18n label,
|
||||||
|
* every selected field with its Python type, German label, description and
|
||||||
|
FK target (when registered via Pydantic models),
|
||||||
|
* the structural rules around date-as-unix-timestamp, no JOINs, and
|
||||||
|
preference for period-bucketed aggregate tables.
|
||||||
|
|
||||||
|
Without that context the agent silently returns wrong numbers (e.g. summing
|
||||||
|
`TrusteeDataJournalLine.debitAmount` without a date filter when the user
|
||||||
|
asked for the closing balance per period).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from modules.shared import fkRegistry
|
||||||
|
from modules.serviceCenter.services.serviceAgent.featureDataAgent import (
|
||||||
|
_buildSchemaContext,
|
||||||
|
_buildTableSchemaBlock,
|
||||||
|
_formatFieldLine,
|
||||||
|
_summarizePythonType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", autouse=True)
|
||||||
|
def _ensureModels():
|
||||||
|
fkRegistry._ensureModelsLoaded()
|
||||||
|
|
||||||
|
|
||||||
|
def _trusteeAccountBalanceObj():
|
||||||
|
return {
|
||||||
|
"objectKey": "data.feature.trustee.TrusteeDataAccountBalance",
|
||||||
|
"label": {"de": "Kontosalden", "en": "Account balances"},
|
||||||
|
"meta": {
|
||||||
|
"table": "TrusteeDataAccountBalance",
|
||||||
|
"fields": [
|
||||||
|
"id", "accountNumber", "periodYear", "periodMonth",
|
||||||
|
"openingBalance", "debitTotal", "creditTotal",
|
||||||
|
"closingBalance", "currency",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _trusteeJournalLineObj():
|
||||||
|
return {
|
||||||
|
"objectKey": "data.feature.trustee.TrusteeDataJournalLine",
|
||||||
|
"label": {"de": "Buchungszeilen", "en": "Journal lines"},
|
||||||
|
"meta": {
|
||||||
|
"table": "TrusteeDataJournalLine",
|
||||||
|
"fields": [
|
||||||
|
"id", "journalEntryId", "accountNumber",
|
||||||
|
"debitAmount", "creditAmount", "currency", "description",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_summarizePythonType_compactsTypingPrefix():
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
assert _summarizePythonType(str) == "str"
|
||||||
|
assert _summarizePythonType(Optional[float]) == "Optional[float]"
|
||||||
|
assert _summarizePythonType(Dict[str, Any]) == "Dict[str, Any]"
|
||||||
|
assert _summarizePythonType(None) == "any"
|
||||||
|
|
||||||
|
|
||||||
|
def test_formatFieldLine_includesLabelDescriptionAndFk():
|
||||||
|
from modules.datamodels.datamodelBase import MODEL_REGISTRY
|
||||||
|
cls = MODEL_REGISTRY.get("TrusteeDataJournalLine")
|
||||||
|
assert cls is not None, "Trustee datamodels must be registered for this test"
|
||||||
|
journalEntryId = cls.model_fields["journalEntryId"]
|
||||||
|
line = _formatFieldLine("journalEntryId", journalEntryId)
|
||||||
|
assert line.startswith("journalEntryId (str)")
|
||||||
|
assert '"Buchung"' in line
|
||||||
|
assert "[FK -> TrusteeDataJournalEntry.id]" in line
|
||||||
|
|
||||||
|
|
||||||
|
def test_buildTableSchemaBlock_listsAccountBalanceFields():
|
||||||
|
obj = _trusteeAccountBalanceObj()
|
||||||
|
block = _buildTableSchemaBlock(
|
||||||
|
obj["meta"]["table"], "Kontosalden", obj["meta"]["fields"],
|
||||||
|
)
|
||||||
|
assert "Table: TrusteeDataAccountBalance" in block
|
||||||
|
assert "Description: Account balance per period" in block
|
||||||
|
assert "closingBalance (float)" in block
|
||||||
|
assert "periodYear (int)" in block
|
||||||
|
assert "periodMonth (int)" in block
|
||||||
|
|
||||||
|
|
||||||
|
def test_buildTableSchemaBlock_unknownTableFallsBackToFlatFields():
|
||||||
|
block = _buildTableSchemaBlock(
|
||||||
|
"NoSuchTable", "Demo", ["foo", "bar"],
|
||||||
|
)
|
||||||
|
assert "NoSuchTable" in block
|
||||||
|
assert "Fields: foo, bar" in block
|
||||||
|
|
||||||
|
|
||||||
|
def test_buildSchemaContext_containsRichFieldsAndKeyRules():
|
||||||
|
selected = [_trusteeJournalLineObj(), _trusteeAccountBalanceObj()]
|
||||||
|
prompt = _buildSchemaContext(
|
||||||
|
featureCode="trustee",
|
||||||
|
instanceLabel="Demo AG",
|
||||||
|
selectedTables=selected,
|
||||||
|
requestLang="de",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "TrusteeDataJournalLine" in prompt
|
||||||
|
assert "TrusteeDataAccountBalance" in prompt
|
||||||
|
assert 'debitAmount (float) "Soll"' in prompt
|
||||||
|
assert 'closingBalance (float) "Schlusssaldo"' in prompt
|
||||||
|
|
||||||
|
assert "[FK -> TrusteeDataJournalEntry.id]" in prompt
|
||||||
|
|
||||||
|
assert "unix timestamp" in prompt
|
||||||
|
assert "CANNOT JOIN" in prompt
|
||||||
|
assert "period-bucketed aggregates" in prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_buildTableSchemaBlock_journalLineHasNoBookingDate():
|
||||||
|
"""JournalLine has no bookingDate column. The agent must see this so it does
|
||||||
|
not invent a `bookingDate` filter on JournalLine and instead either joins to
|
||||||
|
JournalEntry or uses *AccountBalance for period filters."""
|
||||||
|
obj = _trusteeJournalLineObj()
|
||||||
|
block = _buildTableSchemaBlock(
|
||||||
|
obj["meta"]["table"], "Buchungszeilen", obj["meta"]["fields"],
|
||||||
|
)
|
||||||
|
assert "Table: TrusteeDataJournalLine" in block
|
||||||
|
assert "bookingDate" not in block
|
||||||
Loading…
Reference in a new issue