# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Connection routes for the backend API.

Implements the endpoints for connection management.

SECURITY NOTE:
- The regular connections endpoint (/api/connections/) only returns connections
  owned by the current user.
- An admin endpoint (/api/connections/admin/all) is referenced as the place for
  managing all connections. NOTE(review): it is not defined in this module.
- Scoping every regular endpoint to the current user prevents admin users from
  seeing or modifying other users' connections through these routes.
"""
import json
import logging
import math
from typing import List, Dict, Any, Optional
from urllib.parse import urlencode

from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response, Query
from fastapi import status

from modules.datamodels.datamodelUam import User, UserConnection, AuthAuthority, ConnectionStatus
from modules.datamodels.datamodelSecurity import Token
from modules.auth import getCurrentUser, limiter
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
from modules.interfaces.interfaceDbApp import getInterface
from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp
from modules.interfaces.interfaceDbManagement import ComponentObjects

# Configure logger
logger = logging.getLogger(__name__)

# Tokens expiring within this window are still reported as "active", but a debug
# message is logged so an upcoming (proactive) refresh is visible.
_TOKEN_REFRESH_BUFFER_SECONDS = 5 * 60

# Connection "type" values accepted by POST /api/connections/.
_CREATABLE_AUTHORITIES = {
    'msft': AuthAuthority.MSFT,
    'google': AuthAuthority.GOOGLE,
}

# Human-readable labels for the authority dropdown; unknown authorities fall back
# to their raw value.
_AUTHORITY_LABELS = {
    "local": "Local",
    "google": "Google",
    "msft": "Microsoft",
}

# OAuth login endpoint used to start the "connect" flow for each authority.
# Kept as (authority, path) pairs and matched with ==, exactly like the explicit
# comparisons this table replaces.
_OAUTH_LOGIN_PATHS = (
    (AuthAuthority.MSFT, "/api/msft/login"),
    (AuthAuthority.GOOGLE, "/api/google/login"),
)

# Fields the generic update endpoint must never overwrite: changing them would let
# a caller re-key the record or hand it to another user (mass assignment).
_IMMUTABLE_CONNECTION_FIELDS = frozenset({"id", "userId"})


def getTokenStatusForConnection(interface, connectionId: str) -> tuple[str, Optional[float]]:
    """
    Get token status and expiration for a connection.

    Args:
        interface: The database interface (its ``db.getRecordset`` is queried for Token rows).
        connectionId: The connection ID to check.

    Returns:
        tuple: (tokenStatus, tokenExpiresAt)
            - tokenStatus: 'active', 'expired', or 'none'
            - tokenExpiresAt: UTC timestamp or None

    Any error while reading tokens is logged and reported as ('none', None).
    """
    try:
        tokens = interface.db.getRecordset(
            Token,
            recordFilter={"connectionId": connectionId}
        )
        if not tokens:
            return "none", None

        # Most recently created token wins; ties keep the first one seen. Tokens
        # without a parseable createdAt count as 0 but are still eligible, so a
        # lone token missing that field is not misreported as "no token".
        latestToken = max(
            tokens,
            key=lambda tokenData: parseTimestamp(tokenData.get("createdAt"), default=0)
        )

        expiresAt = parseTimestamp(latestToken.get("expiresAt"))
        if not expiresAt:
            return "none", None

        currentTime = getUtcTimestamp()
        if expiresAt <= currentTime:
            return "expired", expiresAt

        if expiresAt <= currentTime + _TOKEN_REFRESH_BUFFER_SECONDS:
            # Still valid, but close to expiry - log so a proactive refresh is visible.
            logger.debug(f"Token for connection {connectionId} expires soon (in {expiresAt - currentTime} seconds)")
        return "active", expiresAt

    except Exception as e:
        logger.error(f"Error getting token status for connection {connectionId}: {str(e)}")
        return "none", None


def _findOwnConnection(interface, userId: str, connectionId: str) -> UserConnection:
    """Return the connection ``connectionId`` owned by ``userId``, or raise HTTP 404.

    Only the user's own connections are searched (admins included), so no endpoint
    can act on another user's connection.
    """
    for conn in interface.getUserConnections(userId):
        if conn.id == connectionId:
            return conn
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Connection not found"
    )


def _enumValue(value: Any) -> Any:
    """Return ``value.value`` for enum-like values, otherwise ``str(value)``."""
    return value.value if hasattr(value, 'value') else str(value)


def _connectionToRow(connection, tokenStatus: str, tokenExpiresAt: Optional[float]) -> Dict[str, Any]:
    """Flatten a connection plus its token status into a dict for filtering/sorting."""
    return {
        "id": connection.id,
        "userId": connection.userId,
        "authority": _enumValue(connection.authority),
        "externalId": connection.externalId,
        "externalUsername": connection.externalUsername or "",
        "externalEmail": connection.externalEmail,  # Keep None instead of converting to empty string
        "status": _enumValue(connection.status),
        "connectedAt": connection.connectedAt,
        "lastChecked": connection.lastChecked,
        "expiresAt": connection.expiresAt,
        "tokenStatus": tokenStatus,
        "tokenExpiresAt": tokenExpiresAt
    }


def _rowToUserConnection(row: Dict[str, Any]) -> UserConnection:
    """Build a UserConnection from a flattened row, turning enum strings back into enums.

    Strings that are not valid enum values are left as strings for the model to handle.
    """
    data = dict(row)
    for key, enumType in (("authority", AuthAuthority), ("status", ConnectionStatus)):
        value = data.get(key)
        if isinstance(value, str):
            try:
                data[key] = enumType(value)
            except ValueError:
                pass  # Keep as string if invalid
    return UserConnection(**data)


def _parsePaginationParam(pagination: Optional[str]) -> Optional[PaginationParams]:
    """Decode the JSON-encoded ``pagination`` query parameter.

    Returns None (no pagination) when the parameter is absent or decodes to a falsy
    value. Raises HTTP 400 when it is not valid JSON, is not a JSON object, or is
    rejected by PaginationParams.
    """
    if not pagination:
        return None
    try:
        paginationDict = json.loads(pagination)
        if not paginationDict:
            return None
        if not isinstance(paginationDict, dict):
            raise ValueError("expected a JSON object")
        # Normalize pagination dict (handles top-level "search" field)
        paginationDict = normalize_pagination_dict(paginationDict)
        return PaginationParams(**paginationDict)
    except (json.JSONDecodeError, ValueError) as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid pagination parameter: {str(e)}"
        ) from e


def _applyFiltersAndSorting(currentUser: User, rows: List[Dict[str, Any]],
                            paginationParams: PaginationParams) -> List[Dict[str, Any]]:
    """Apply the requested filters and sort order via the management interface helpers."""
    if not (paginationParams.filters or paginationParams.sort):
        return rows
    componentInterface = ComponentObjects()
    componentInterface.setUserContext(currentUser)
    if paginationParams.filters:
        rows = componentInterface._applyFilters(rows, paginationParams.filters)
    if paginationParams.sort:
        rows = componentInterface._applySorting(rows, paginationParams.sort)
    return rows


async def _refreshExpiredTokensSilently(interface, currentUser: User, connections: list) -> list:
    """Best-effort refresh of expired OAuth tokens; return the (possibly re-fetched) connections.

    Failures are logged as warnings and the original connections are returned.
    """
    try:
        # NOTE(review): `token_refresh_service` is not imported anywhere in this
        # module, so this raises NameError and the refresh never happens (only the
        # warning below is logged). Add the proper import for the token refresh
        # service to enable it.
        refresh_result = await token_refresh_service.refresh_expired_tokens(currentUser.id)  # noqa: F821
        if refresh_result.get("refreshed", 0) > 0:
            logger.info(f"Silently refreshed {refresh_result['refreshed']} tokens for user {currentUser.id}")
            # Re-fetch connections to get updated token status
            return interface.getUserConnections(currentUser.id)
    except Exception as e:
        logger.warning(f"Silent token refresh failed for user {currentUser.id}: {str(e)}")
    # Continue with original connections even if refresh fails or nothing was refreshed
    return connections


router = APIRouter(
    prefix="/api/connections",
    tags=["Manage Connections"],
    responses={404: {"description": "Not found"}}
)


# ============================================================================
# OPTIONS ENDPOINTS (for dropdowns)
# ============================================================================

@router.get("/statuses/options", response_model=List[Dict[str, Any]])
@limiter.limit("60/minute")
async def getConnectionStatusOptions(
    request: Request,
    currentUser: User = Depends(getCurrentUser)
) -> List[Dict[str, Any]]:
    """
    Get connection status options for select dropdowns.
    Returns standardized format: [{ value, label }]
    """
    # Loop variable deliberately not named `status` to avoid shadowing fastapi.status.
    return [
        {"value": connStatus.value, "label": connStatus.value.capitalize()}
        for connStatus in ConnectionStatus
    ]


@router.get("/authorities/options", response_model=List[Dict[str, Any]])
@limiter.limit("60/minute")
async def getAuthAuthorityOptions(
    request: Request,
    currentUser: User = Depends(getCurrentUser)
) -> List[Dict[str, Any]]:
    """
    Get authentication authority options for select dropdowns.
    Returns standardized format: [{ value, label }]
    """
    return [
        {"value": auth.value, "label": _AUTHORITY_LABELS.get(auth.value, auth.value)}
        for auth in AuthAuthority
    ]


# ============================================================================
# CRUD ENDPOINTS
# ============================================================================

@router.get("/", response_model=PaginatedResponse[UserConnection])
@limiter.limit("30/minute")
async def get_connections(
    request: Request,
    pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
    currentUser: User = Depends(getCurrentUser)
) -> PaginatedResponse[UserConnection]:
    """Get connections for the current user with optional pagination, sorting, and filtering.

    SECURITY: This endpoint is secure - users can only see their own connections.
    Attempts a silent refresh of expired OAuth tokens before reporting token status.

    Query Parameters:
    - pagination: JSON-encoded PaginationParams object, or None for no pagination

    Examples:
    - GET /api/connections/ (no pagination - returns all items)
    - GET /api/connections/?pagination={"page":1,"pageSize":10,"sort":[]}
    - GET /api/connections/?pagination={"page":1,"pageSize":10,"filters":{"status":"active"}}
    """
    try:
        interface = getInterface(currentUser)
        paginationParams = _parsePaginationParam(pagination)

        # SECURITY: all users (including admins) only see their own connections.
        connections = interface.getUserConnections(currentUser.id)
        connections = await _refreshExpiredTokensSilently(interface, currentUser, connections)

        # Flatten each connection (with token status) into a dict for filtering/sorting.
        rows = []
        for connection in connections:
            tokenStatus, tokenExpiresAt = getTokenStatusForConnection(interface, connection.id)
            rows.append(_connectionToRow(connection, tokenStatus, tokenExpiresAt))

        # If no pagination requested, return all items
        if paginationParams is None:
            return PaginatedResponse(
                items=[_rowToUserConnection(row) for row in rows],
                pagination=None
            )

        rows = _applyFiltersAndSorting(currentUser, rows, paginationParams)

        # Count total items after filters, then slice out the requested page.
        totalItems = len(rows)
        pageSize = paginationParams.pageSize
        totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0
        startIdx = (paginationParams.page - 1) * pageSize
        pageRows = rows[startIdx:startIdx + pageSize]

        return PaginatedResponse(
            items=[_rowToUserConnection(row) for row in pageRows],
            pagination=PaginationMetadata(
                currentPage=paginationParams.page,
                pageSize=pageSize,
                totalItems=totalItems,
                totalPages=totalPages,
                sort=paginationParams.sort,
                filters=paginationParams.filters
            )
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting connections: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get connections: {str(e)}"
        )


@router.post("/", response_model=UserConnection)
@limiter.limit("10/minute")
async def create_connection(
    request: Request,
    connection_data: Dict[str, Any] = Body(...),
    currentUser: User = Depends(getCurrentUser)
) -> UserConnection:
    """Create a new connection for the current user.

    The body's "type" must be 'msft' or 'google' (otherwise HTTP 400). The new
    connection starts in PENDING status with empty external identity fields.

    SECURITY: This endpoint is secure - it always creates connections for the current user
    and cannot be used to create connections for other users.
    """
    try:
        interface = getInterface(currentUser)

        # Map type to authority
        authority = _CREATABLE_AUTHORITIES.get(connection_data.get('type'))
        if not authority:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported connection type: {connection_data.get('type')}"
            )

        # Make sure the user still exists in the database
        user = interface.getUser(currentUser.id)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="User not found"
            )

        # Always create a new connection with PENDING status
        connection = interface.addUserConnection(
            userId=currentUser.id,
            authority=authority,
            externalId="",  # Will be set after OAuth
            externalUsername="",  # Will be set after OAuth
            status=ConnectionStatus.PENDING  # Start with PENDING status
        )

        # Save connection record - models now handle timestamp serialization automatically
        interface.db.recordModify(UserConnection, connection.id, connection.model_dump())

        return connection

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error creating connection: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create connection: {str(e)}"
        )


@router.put("/{connectionId}", response_model=UserConnection)
@limiter.limit("10/minute")
async def update_connection(
    request: Request,
    connectionId: str = Path(..., description="The ID of the connection to update"),
    connection_data: Dict[str, Any] = Body(...),
    currentUser: User = Depends(getCurrentUser)
) -> UserConnection:
    """Update an existing connection for the current user.

    Only declared UserConnection model fields are applied. Identity/ownership
    fields ("id", "userId") are ignored, so the connection cannot be re-keyed or
    handed to another user. Unknown keys are ignored. lastChecked is set to now.

    SECURITY: This endpoint is secure - users can only update their own connections.
    """
    try:
        interface = getInterface(currentUser)

        # SECURITY: all users (including admins) can only update their own connections.
        connection = _findOwnConnection(interface, currentUser.id, connectionId)

        # Update connection fields
        modelFields = type(connection).model_fields
        for field, value in connection_data.items():
            if field in _IMMUTABLE_CONNECTION_FIELDS:
                if value != getattr(connection, field, None):
                    logger.warning(
                        f"Ignoring attempt to change '{field}' of connection {connectionId} "
                        f"by user {currentUser.id}"
                    )
                continue
            if field in modelFields:
                setattr(connection, field, value)

        # Update lastChecked timestamp using UTC timestamp
        connection.lastChecked = getUtcTimestamp()

        # Update connection - models now handle timestamp serialization automatically
        interface.db.recordModify(UserConnection, connectionId, connection.model_dump())

        # Return the updated connection enriched with its current token status
        tokenStatus, tokenExpiresAt = getTokenStatusForConnection(interface, connectionId)
        return UserConnection(
            id=connection.id,
            userId=connection.userId,
            authority=connection.authority,
            externalId=connection.externalId,
            externalUsername=connection.externalUsername,
            externalEmail=connection.externalEmail,
            status=connection.status,
            connectedAt=connection.connectedAt,
            lastChecked=connection.lastChecked,
            expiresAt=connection.expiresAt,
            tokenStatus=tokenStatus,
            tokenExpiresAt=tokenExpiresAt
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error updating connection: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to update connection: {str(e)}"
        )


@router.post("/{connectionId}/connect")
@limiter.limit("10/minute")
async def connect_service(
    request: Request,
    connectionId: str = Path(..., description="The ID of the connection to connect"),
    currentUser: User = Depends(getCurrentUser)
) -> Dict[str, Any]:
    """Connect a service for the current user.

    Returns {"authUrl": ...}, pointing at the authority's login endpoint with a
    URL-encoded JSON ``state`` ({"type": "connect", "connectionId", "userId"}).
    Raises HTTP 400 for authorities without an OAuth login path.

    SECURITY: This endpoint is secure - users can only connect their own connections.
    """
    try:
        interface = getInterface(currentUser)

        # SECURITY: all users (including admins) can only connect their own connections.
        connection = _findOwnConnection(interface, currentUser.id, connectionId)

        loginPath = next(
            (path for authority, path in _OAUTH_LOGIN_PATHS if connection.authority == authority),
            None
        )
        if loginPath is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported authority: {connection.authority}"
            )

        # Reuse the login endpoint with state type "connect"; include the current
        # user ID so the callback can attach the account to this user.
        state_data = {
            "type": "connect",
            "connectionId": connectionId,
            "userId": currentUser.id
        }
        # urlencode escapes the JSON (quotes, spaces, braces); the login endpoint
        # receives the same decoded string as before.
        auth_url = f"{loginPath}?{urlencode({'state': json.dumps(state_data)})}"

        return {"authUrl": auth_url}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error connecting service: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to connect service: {str(e)}"
        )


@router.post("/{connectionId}/disconnect")
@limiter.limit("10/minute")
async def disconnect_service(
    request: Request,
    connectionId: str = Path(..., description="The ID of the connection to disconnect"),
    currentUser: User = Depends(getCurrentUser)
) -> Dict[str, Any]:
    """Disconnect a service for the current user.

    Marks the connection INACTIVE and updates lastChecked; the record is kept.

    SECURITY: This endpoint is secure - users can only disconnect their own connections.
    """
    try:
        interface = getInterface(currentUser)

        # SECURITY: all users (including admins) can only disconnect their own connections.
        connection = _findOwnConnection(interface, currentUser.id, connectionId)

        # Update connection status
        connection.status = ConnectionStatus.INACTIVE
        connection.lastChecked = getUtcTimestamp()

        # Update connection record - models now handle timestamp serialization automatically
        interface.db.recordModify(UserConnection, connectionId, connection.model_dump())

        return {"message": "Service disconnected successfully"}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error disconnecting service: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to disconnect service: {str(e)}"
        )


@router.delete("/{connectionId}")
@limiter.limit("10/minute")
async def delete_connection(
    request: Request,
    connectionId: str = Path(..., description="The ID of the connection to delete"),
    currentUser: User = Depends(getCurrentUser)
) -> Dict[str, Any]:
    """Delete a connection for the current user.

    SECURITY: This endpoint is secure - users can only delete their own connections.
    """
    try:
        interface = getInterface(currentUser)

        # SECURITY: ownership is verified here, so removal needs only the ID.
        _findOwnConnection(interface, currentUser.id, connectionId)

        interface.removeUserConnection(connectionId)

        return {"message": "Connection deleted successfully"}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error deleting connection: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete connection: {str(e)}"
        )