597 lines
No EOL
23 KiB
Python
597 lines
No EOL
23 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""
|
|
Connection routes for the backend API.
|
|
Implements the endpoints for connection management.
|
|
|
|
SECURITY NOTE:
|
|
- Regular connections endpoint (/api/connections/) only returns connections for the current user
|
|
- Admin endpoint (/api/connections/admin/all) provides access to all connections for management purposes
|
|
- This prevents security vulnerabilities where admin users could see other users' connections
|
|
"""
|
|
|
|
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response, Query
|
|
from typing import List, Dict, Any, Optional
|
|
from fastapi import status
|
|
import logging
|
|
import json
|
|
import math
|
|
|
|
from modules.datamodels.datamodelUam import User, UserConnection, AuthAuthority, ConnectionStatus
|
|
from modules.datamodels.datamodelSecurity import Token
|
|
from modules.auth import getCurrentUser, limiter
|
|
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
|
|
from modules.interfaces.interfaceDbApp import getInterface
|
|
from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp
|
|
from modules.interfaces.interfaceDbManagement import ComponentObjects
|
|
|
|
# Configure logger
|
|
logger = logging.getLogger(__name__)  # module-level logger, named after this module via __name__
|
|
|
|
def getTokenStatusForConnection(interface, connectionId: str) -> tuple[str, Optional[float]]:
    """
    Get token status and expiration for a connection.

    The most recently created token (highest ``createdAt``) stored for the
    connection is inspected. Tokens whose ``createdAt`` is missing or cannot
    be parsed are ranked as oldest (timestamp 0) but are still candidates, so
    a connection whose only token lacks ``createdAt`` is not misreported as
    having no token at all.

    Args:
        interface: The database interface
        connectionId: The connection ID to check

    Returns:
        tuple: (tokenStatus, tokenExpiresAt)
            - tokenStatus: 'active', 'expired', or 'none'
            - tokenExpiresAt: UTC timestamp or None

        ("none", None) is also returned when the latest token has no usable
        ``expiresAt`` or when the lookup fails for any reason; failures are
        logged and never propagated to the caller.
    """
    try:
        # Query tokens table for all tokens belonging to this connection
        tokens = interface.db.getRecordset(
            Token,
            recordFilter={"connectionId": connectionId}
        )

        if not tokens:
            return "none", None

        # Pick the most recent token. max() keeps the first of several equally
        # recent tokens, which matches a strict ">" scan, and still yields a
        # token when every createdAt falls back to the default of 0.
        latestToken = max(
            tokens,
            key=lambda tokenData: parseTimestamp(tokenData.get("createdAt"), default=0),
        )

        expiresAt = parseTimestamp(latestToken.get("expiresAt"))
        if not expiresAt:
            return "none", None

        currentTime = getUtcTimestamp()
        if expiresAt <= currentTime:
            return "expired", expiresAt

        # Tokens inside the 5 minute buffer are still reported as active;
        # the debug log only flags them as candidates for proactive refresh.
        bufferTime = 5 * 60  # 5 minutes in seconds
        if expiresAt <= (currentTime + bufferTime):
            logger.debug(f"Token for connection {connectionId} expires soon (in {expiresAt - currentTime} seconds)")

        return "active", expiresAt

    except Exception as e:
        # Best-effort lookup: a failing token query must not break the caller
        logger.error(f"Error getting token status for connection {connectionId}: {str(e)}")
        return "none", None
|
|
|
|
# Router for the connection-management endpoints; every route below is
# registered under the /api/connections prefix.
router = APIRouter(
    prefix="/api/connections",
    tags=["Manage Connections"],
    responses={404: {"description": "Not found"}},
)
|
|
|
|
|
|
# ============================================================================
|
|
# OPTIONS ENDPOINTS (for dropdowns)
|
|
# ============================================================================
|
|
|
|
@router.get("/statuses/options", response_model=List[Dict[str, Any]])
@limiter.limit("60/minute")
async def getConnectionStatusOptions(
    request: Request,
    currentUser: User = Depends(getCurrentUser)
) -> List[Dict[str, Any]]:
    """
    List every ConnectionStatus member as a dropdown option.

    Each option has the standardized shape { value, label }: ``value`` is the
    enum member's value and ``label`` is that value capitalized.
    """
    options: List[Dict[str, Any]] = []
    for member in ConnectionStatus:
        rawValue = member.value
        options.append({"value": rawValue, "label": rawValue.capitalize()})
    return options
|
|
|
|
|
|
@router.get("/authorities/options", response_model=List[Dict[str, Any]])
@limiter.limit("60/minute")
async def getAuthAuthorityOptions(
    request: Request,
    currentUser: User = Depends(getCurrentUser)
) -> List[Dict[str, Any]]:
    """
    List every AuthAuthority member as a dropdown option.

    Each option has the standardized shape { value, label }. Known authority
    values get a human-readable label; any other value is used as its own label.
    """
    displayNames = {"local": "Local", "google": "Google", "msft": "Microsoft"}

    options: List[Dict[str, Any]] = []
    for member in AuthAuthority:
        key = member.value
        options.append({"value": key, "label": displayNames.get(key, key)})
    return options
|
|
|
|
|
|
# ============================================================================
|
|
# CRUD ENDPOINTS
|
|
# ============================================================================
|
|
|
|
def _connectionDictToModel(connDict: Dict[str, Any]) -> UserConnection:
    """Build a UserConnection from an enriched connection dict.

    The "authority" and "status" entries are stored as plain strings while the
    dicts are filtered and sorted; they are converted back to their enum types
    here. The input dict is not modified.
    """
    data = dict(connDict)
    for fieldName, enumType in (("authority", AuthAuthority), ("status", ConnectionStatus)):
        rawValue = data.get(fieldName)
        if isinstance(rawValue, str):
            try:
                data[fieldName] = enumType(rawValue)
            except ValueError:
                # Unknown enum value: deliberately keep the raw string and let
                # the model's own validation decide what to do with it.
                pass
    return UserConnection(**data)


@router.get("/", response_model=PaginatedResponse[UserConnection])
@limiter.limit("30/minute")
async def get_connections(
    request: Request,
    pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
    currentUser: User = Depends(getCurrentUser)
) -> PaginatedResponse[UserConnection]:
    """Get connections for the current user with optional pagination, sorting, and filtering.

    SECURITY: This endpoint is secure - users can only see their own connections.
    Before the response is built, a silent refresh of the user's expired OAuth
    tokens is attempted; a failing refresh is logged and otherwise ignored.

    Query Parameters:
    - pagination: JSON-encoded PaginationParams object, or None for no pagination.
      An empty/falsy JSON value (e.g. {} or null) also means "no pagination".

    Examples:
    - GET /api/connections/ (no pagination - returns all items)
    - GET /api/connections/?pagination={"page":1,"pageSize":10,"sort":[]}
    - GET /api/connections/?pagination={"page":1,"pageSize":10,"filters":{"status":"active"}}

    Raises:
        HTTPException 400: The pagination parameter is not valid JSON, is not a
            JSON object, or does not validate as PaginationParams.
        HTTPException 500: Any other failure while collecting the connections.
    """
    try:
        interface = getInterface(currentUser)

        # Parse pagination parameter
        paginationParams = None
        if pagination:
            try:
                paginationDict = json.loads(pagination)
                if paginationDict:
                    # A truthy non-object (list, string, number) cannot describe
                    # pagination; reject it as a client error instead of letting
                    # it fail later as an internal error.
                    if not isinstance(paginationDict, dict):
                        raise ValueError("pagination must be a JSON object")
                    # Normalize pagination dict (handles top-level "search" field)
                    paginationDict = normalize_pagination_dict(paginationDict)
                    paginationParams = PaginationParams(**paginationDict)
            except (json.JSONDecodeError, ValueError) as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid pagination parameter: {str(e)}"
                ) from e

        # SECURITY FIX: All users (including admins) can only see their own connections
        # This prevents admin from seeing other users' connections and causing confusion
        connections = interface.getUserConnections(currentUser.id)

        # Perform silent token refresh for expired OAuth connections
        try:
            from modules.auth import token_refresh_service
            refresh_result = await token_refresh_service.refresh_expired_tokens(currentUser.id)
            if refresh_result.get("refreshed", 0) > 0:
                logger.info(f"Silently refreshed {refresh_result['refreshed']} tokens for user {currentUser.id}")
                # Re-fetch connections to get updated token status
                connections = interface.getUserConnections(currentUser.id)
        except Exception as e:
            logger.warning(f"Silent token refresh failed for user {currentUser.id}: {str(e)}")
            # Continue with original connections even if refresh fails

        # Enhance each connection with token status information and convert to dict
        # (plain dicts with string enum values are what filtering/sorting operate on)
        enhanced_connections_dict = []
        for connection in connections:
            tokenStatus, tokenExpiresAt = getTokenStatusForConnection(interface, connection.id)

            connection_dict = {
                "id": connection.id,
                "userId": connection.userId,
                "authority": connection.authority.value if hasattr(connection.authority, 'value') else str(connection.authority),
                "externalId": connection.externalId,
                "externalUsername": connection.externalUsername or "",
                "externalEmail": connection.externalEmail,  # Keep None instead of converting to empty string
                "status": connection.status.value if hasattr(connection.status, 'value') else str(connection.status),
                "connectedAt": connection.connectedAt,
                "lastChecked": connection.lastChecked,
                "expiresAt": connection.expiresAt,
                "tokenStatus": tokenStatus,
                "tokenExpiresAt": tokenExpiresAt
            }
            enhanced_connections_dict.append(connection_dict)

        # If no pagination requested, return all items
        if paginationParams is None:
            return PaginatedResponse(
                items=[_connectionDictToModel(conn_dict) for conn_dict in enhanced_connections_dict],
                pagination=None
            )

        # Apply filtering and sorting if provided (one shared helper instance)
        if paginationParams.filters or paginationParams.sort:
            component_interface = ComponentObjects()
            component_interface.setUserContext(currentUser)

            if paginationParams.filters:
                enhanced_connections_dict = component_interface._applyFilters(
                    enhanced_connections_dict,
                    paginationParams.filters
                )

            if paginationParams.sort:
                enhanced_connections_dict = component_interface._applySorting(
                    enhanced_connections_dict,
                    paginationParams.sort
                )

        # Count total items after filters
        totalItems = len(enhanced_connections_dict)
        totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0

        # Apply pagination (skip/limit)
        startIdx = (paginationParams.page - 1) * paginationParams.pageSize
        endIdx = startIdx + paginationParams.pageSize
        paged_connections = enhanced_connections_dict[startIdx:endIdx]

        # Convert back to UserConnection objects (enum strings back to enums)
        items = [_connectionDictToModel(conn_dict) for conn_dict in paged_connections]

        return PaginatedResponse(
            items=items,
            pagination=PaginationMetadata(
                currentPage=paginationParams.page,
                pageSize=paginationParams.pageSize,
                totalItems=totalItems,
                totalPages=totalPages,
                sort=paginationParams.sort,
                filters=paginationParams.filters
            )
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting connections: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get connections: {str(e)}"
        )
|
|
|
|
@router.post("/", response_model=UserConnection)
@limiter.limit("10/minute")
async def create_connection(
    request: Request,
    connection_data: Dict[str, Any] = Body(...),
    currentUser: User = Depends(getCurrentUser)
) -> UserConnection:
    """Create a new connection for the current user

    SECURITY: This endpoint is secure - it always creates connections for the current user
    and cannot be used to create connections for other users.

    The request body must contain a "type" of 'msft' or 'google'. The new
    connection is created with PENDING status and empty external identity
    fields, which are expected to be filled in after the OAuth flow.

    Raises:
        HTTPException 400: "type" is missing, not a string, or not a supported type.
        HTTPException 404: The current user no longer exists in the database.
        HTTPException 500: Any other failure while creating the connection.
    """
    try:
        interface = getInterface(currentUser)

        # Map type to authority
        authority_map = {
            'msft': AuthAuthority.MSFT,
            'google': AuthAuthority.GOOGLE
        }

        # The body is untrusted JSON: "type" may be a list or dict, which is
        # unhashable and would make the dict lookup raise instead of yielding
        # the intended 400 response.
        requested_type = connection_data.get('type')
        authority = authority_map.get(requested_type) if isinstance(requested_type, str) else None
        if not authority:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported connection type: {connection_data.get('type')}"
            )

        # Get fresh copy of user from database
        user = interface.getUser(currentUser.id)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="User not found"
            )

        # Always create a new connection with PENDING status
        connection = interface.addUserConnection(
            userId=currentUser.id,
            authority=authority,
            externalId="",  # Will be set after OAuth
            externalUsername="",  # Will be set after OAuth
            status=ConnectionStatus.PENDING  # Start with PENDING status
        )

        # Save connection record - models now handle timestamp serialization automatically
        interface.db.recordModify(UserConnection, connection.id, connection.model_dump())

        return connection

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error creating connection: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create connection: {str(e)}"
        )
|
|
|
|
# Fields that identify a connection or its owner. They must never be writable
# through the update endpoint, otherwise a client could re-key a connection or
# hand it to another user.
_PROTECTED_CONNECTION_FIELDS = frozenset({"id", "userId"})


@router.put("/{connectionId}", response_model=UserConnection)
@limiter.limit("10/minute")
async def update_connection(
    request: Request,
    connectionId: str = Path(..., description="The ID of the connection to update"),
    connection_data: Dict[str, Any] = Body(...),
    currentUser: User = Depends(getCurrentUser)
) -> UserConnection:
    """Update an existing connection for the current user

    SECURITY: This endpoint is secure - users can only update their own connections.
    Only declared fields of the connection model are written, and the identity
    fields "id" and "userId" are never changed; such entries in the request
    body are silently ignored so clients may still send back the full object.

    ``lastChecked`` is always set to the current UTC timestamp. The returned
    connection carries freshly computed ``tokenStatus`` / ``tokenExpiresAt``.

    Raises:
        HTTPException 404: The connection does not belong to the current user.
        HTTPException 500: Any other failure while updating the connection.
    """
    try:
        interface = getInterface(currentUser)

        # SECURITY FIX: All users (including admins) can only update their own connections
        # This prevents admin from updating other users' connections and causing confusion
        connections = interface.getUserConnections(currentUser.id)
        connection = next((conn for conn in connections if conn.id == connectionId), None)

        if not connection:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Connection not found"
            )

        # Update connection fields. Prefer the model's declared fields over
        # hasattr(), which would also match methods and other attributes.
        declaredFields = getattr(type(connection), "model_fields", None)
        for field, value in connection_data.items():
            if field in _PROTECTED_CONNECTION_FIELDS:
                continue
            if declaredFields is not None:
                if field not in declaredFields:
                    continue
            elif not hasattr(connection, field):
                continue
            setattr(connection, field, value)

        # Update lastChecked timestamp using UTC timestamp
        connection.lastChecked = getUtcTimestamp()

        # Update connection - models now handle timestamp serialization automatically
        interface.db.recordModify(UserConnection, connectionId, connection.model_dump())

        # Get token status for the updated connection
        tokenStatus, tokenExpiresAt = getTokenStatusForConnection(interface, connectionId)

        # Create enhanced connection with token status
        enhanced_connection = UserConnection(
            id=connection.id,
            userId=connection.userId,
            authority=connection.authority,
            externalId=connection.externalId,
            externalUsername=connection.externalUsername,
            externalEmail=connection.externalEmail,
            status=connection.status,
            connectedAt=connection.connectedAt,
            lastChecked=connection.lastChecked,
            expiresAt=connection.expiresAt,
            tokenStatus=tokenStatus,
            tokenExpiresAt=tokenExpiresAt
        )

        return enhanced_connection

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error updating connection: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to update connection: {str(e)}"
        )
|
|
|
|
@router.post("/{connectionId}/connect")
@limiter.limit("10/minute")
async def connect_service(
    request: Request,
    connectionId: str = Path(..., description="The ID of the connection to connect"),
    currentUser: User = Depends(getCurrentUser)
) -> Dict[str, Any]:
    """Start the OAuth connect flow for one of the current user's connections.

    SECURITY: The connection is looked up only among the current user's own
    connections, so other users' connections (even for admins) yield 404.

    Returns:
        {"authUrl": ...} - the provider login endpoint with a JSON-encoded
        ``state`` query value carrying the connect type, the connection ID and
        the current user's ID.

    Raises:
        HTTPException 404: The connection does not belong to the current user.
        HTTPException 400: The connection's authority is neither MSFT nor GOOGLE.
        HTTPException 500: Any other failure.
    """
    try:
        interface = getInterface(currentUser)

        # Only the caller's own connections are candidates
        ownConnections = interface.getUserConnections(currentUser.id)
        connection = next(
            (candidate for candidate in ownConnections if candidate.id == connectionId),
            None,
        )
        if not connection:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Connection not found"
            )

        # Both providers use their regular login endpoint; the state marks the
        # request as a "connect" so the provider flow ensures account selection.
        if connection.authority == AuthAuthority.MSFT:
            loginPath = "/api/msft/login"
        elif connection.authority == AuthAuthority.GOOGLE:
            loginPath = "/api/google/login"
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported authority: {connection.authority}"
            )

        statePayload = json.dumps({
            "type": "connect",
            "connectionId": connectionId,
            "userId": currentUser.id,
        })
        return {"authUrl": f"{loginPath}?state={statePayload}"}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error connecting service: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to connect service: {str(e)}"
        )
|
|
|
|
@router.post("/{connectionId}/disconnect")
@limiter.limit("10/minute")
async def disconnect_service(
    request: Request,
    connectionId: str = Path(..., description="The ID of the connection to disconnect"),
    currentUser: User = Depends(getCurrentUser)
) -> Dict[str, Any]:
    """Mark one of the current user's connections as disconnected.

    SECURITY: The connection is looked up only among the current user's own
    connections, so other users' connections (even for admins) yield 404.

    The connection's status is set to INACTIVE, ``lastChecked`` is set to the
    current UTC timestamp, and the record is written back to the database.

    Raises:
        HTTPException 404: The connection does not belong to the current user.
        HTTPException 500: Any other failure.
    """
    try:
        interface = getInterface(currentUser)

        # Only the caller's own connections are candidates
        ownConnections = interface.getUserConnections(currentUser.id)
        connection = next(
            (candidate for candidate in ownConnections if candidate.id == connectionId),
            None,
        )
        if not connection:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Connection not found"
            )

        # Flip the status and stamp the change time before persisting
        connection.status = ConnectionStatus.INACTIVE
        connection.lastChecked = getUtcTimestamp()

        # Models handle timestamp serialization in model_dump()
        updatedRecord = connection.model_dump()
        interface.db.recordModify(UserConnection, connectionId, updatedRecord)

        return {"message": "Service disconnected successfully"}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error disconnecting service: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to disconnect service: {str(e)}"
        )
|
|
|
|
@router.delete("/{connectionId}")
@limiter.limit("10/minute")
async def delete_connection(
    request: Request,
    connectionId: str = Path(..., description="The ID of the connection to delete"),
    currentUser: User = Depends(getCurrentUser)
) -> Dict[str, Any]:
    """Delete one of the current user's connections.

    SECURITY: The connection is looked up only among the current user's own
    connections, so other users' connections (even for admins) yield 404.

    Raises:
        HTTPException 404: The connection does not belong to the current user.
        HTTPException 500: Any other failure.
    """
    try:
        interface = getInterface(currentUser)

        # Ownership check: the ID must match one of the caller's own connections
        ownConnections = interface.getUserConnections(currentUser.id)
        connection = next(
            (candidate for candidate in ownConnections if candidate.id == connectionId),
            None,
        )
        if not connection:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Connection not found"
            )

        # Ownership was verified above, so the ID alone is enough for removal
        interface.removeUserConnection(connectionId)

        return {"message": "Connection deleted successfully"}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error deleting connection: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete connection: {str(e)}"
        )