gateway/modules/serviceCenter/services/serviceBackgroundJobs/mainBackgroundJobService.py
ValueOn AG 48c0f900af rag
2026-05-12 15:19:01 +02:00

337 lines
12 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Background job service.
Generic infrastructure for fire-and-forget async tasks. Any caller (HTTP route,
AI tool, scheduler) can submit work and get a `jobId` back immediately; status
is polled via `GET /api/jobs/{jobId}`.
Usage (registration, once at module load):
from modules.serviceCenter.services.serviceBackgroundJobs import registerJobHandler
async def _myHandler(job, progressCb):
progressCb(10, "starting...")
result = await doExpensiveWork(job["payload"])
return result # stored as job.result
registerJobHandler("myJobType", _myHandler)
Usage (submission):
from modules.serviceCenter.services.serviceBackgroundJobs import startJob
jobId = await startJob("myJobType", {"foo": "bar"}, mandateId=mid, triggeredBy=userId)
return {"jobId": jobId}
Restart semantics: jobs are tracked in DB. If the worker process dies mid-job,
`_recoverInterruptedJobs()` (called at boot) flips RUNNING jobs to ERROR with a
clear message. No silent zombies.
"""
import asyncio
import logging
import time
from datetime import datetime, timezone
from typing import Any, Awaitable, Callable, Dict, List, Optional
from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
from modules.shared.dbRegistry import registerDatabase
from modules.datamodels.datamodelBackgroundJob import (
BackgroundJob,
BackgroundJobStatusEnum,
TERMINAL_JOB_STATUSES,
)
# Module-level logger; used by every function below for info/warning/exception logs.
logger = logging.getLogger(__name__)

# Database that holds BackgroundJob rows; falls back to "poweron_app".
JOBS_DATABASE = APP_CONFIG.get("DB_DATABASE", "poweron_app")
registerDatabase(JOBS_DATABASE)

# Minimum seconds between DB reads in JobProgressCallback.isCancelled().
_CANCEL_CHECK_INTERVAL_S = 3.0
class JobProgressCallback:
    """Callable progress reporter with cooperative cancel-check for long-running walkers."""

    def __init__(self, jobId: str):
        self._jobId = jobId
        self._cancelledCache: Optional[bool] = None
        self._lastCheckedAt: float = 0.0

    def __call__(self, progress: int, message: Optional[str] = None) -> None:
        # Best-effort: a failed progress write must never take down the job itself.
        try:
            update: Dict[str, Any] = {"progress": min(100, max(0, int(progress)))}
            if message is not None:
                update["progressMessage"] = message[:500]
            _updateJob(self._jobId, update)
        except Exception as ex:
            logger.warning("Progress update failed for job %s: %s", self._jobId, ex)

    def isCancelled(self) -> bool:
        """Check if this job was cancelled. Reads DB at most every 3s to limit load."""
        if self._cancelledCache is True:
            # Cancellation is sticky; never ask the DB again once seen.
            return True
        now = time.time()
        if now - self._lastCheckedAt < _CANCEL_CHECK_INTERVAL_S:
            # Inside the throttle window: trust the last observation.
            return bool(self._cancelledCache)
        self._lastCheckedAt = now
        try:
            record = _loadJob(self._jobId)
        except Exception:
            # Read failures count as "not cancelled" (best-effort check).
            record = None
        if record and record.get("status") == BackgroundJobStatusEnum.CANCELLED.value:
            self._cancelledCache = True
            return True
        self._cancelledCache = False
        return False
# Signature every job handler must implement: (job row, progress callback) -> result dict.
JobHandler = Callable[[Dict[str, Any], JobProgressCallback], Awaitable[Optional[Dict[str, Any]]]]

# Global jobType -> handler registry, populated once at module load.
_JOB_HANDLERS: Dict[str, JobHandler] = {}


def registerJobHandler(jobType: str, handler: JobHandler) -> None:
    """Register a handler for a job type. Idempotent — last registration wins."""
    if jobType in _JOB_HANDLERS:
        if _JOB_HANDLERS[jobType] is not handler:
            logger.info("Re-registering background job handler for type %s", jobType)
    _JOB_HANDLERS[jobType] = handler
def _getDb() -> DatabaseConnector:
    """Build a fresh connector against the jobs database from APP_CONFIG settings."""
    host = APP_CONFIG.get("DB_HOST", "localhost")
    port = int(APP_CONFIG.get("DB_PORT", "5432"))
    return DatabaseConnector(
        dbDatabase=JOBS_DATABASE,
        dbHost=host,
        dbPort=port,
        dbUser=APP_CONFIG.get("DB_USER"),
        dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
    )
def _serialiseDatetimes(data: Dict[str, Any]) -> Dict[str, Any]:
"""Return copy of dict with datetime values converted to ISO 8601 strings."""
out: Dict[str, Any] = {}
for k, v in data.items():
if isinstance(v, datetime):
out[k] = v.isoformat()
else:
out[k] = v
return out
# Strong references to in-flight runner tasks. The event loop keeps only a weak
# reference to tasks created via asyncio.create_task, so without this set a
# runner could be garbage-collected mid-execution (documented asyncio pitfall).
_ACTIVE_JOB_TASKS: set = set()


async def startJob(
    jobType: str,
    payload: Optional[Dict[str, Any]] = None,
    *,
    mandateId: Optional[str] = None,
    featureInstanceId: Optional[str] = None,
    triggeredBy: Optional[str] = None,
) -> str:
    """Insert a new BackgroundJob, kick off its handler in the background, return jobId.

    Returns immediately; the handler runs via `asyncio.create_task`.

    Raises:
        ValueError: if no handler is registered for *jobType*.
    """
    if jobType not in _JOB_HANDLERS:
        raise ValueError(f"No handler registered for jobType '{jobType}'")
    job = BackgroundJob(
        jobType=jobType,
        mandateId=mandateId,
        featureInstanceId=featureInstanceId,
        triggeredBy=triggeredBy,
        payload=payload or {},
    )
    db = _getDb()
    record = db.recordCreate(BackgroundJob, _serialiseDatetimes(job.model_dump()))
    jobId = record["id"]
    # Keep a strong reference until the task finishes, then drop it again.
    task = asyncio.create_task(_runJob(jobId))
    _ACTIVE_JOB_TASKS.add(task)
    task.add_done_callback(_ACTIVE_JOB_TASKS.discard)
    logger.info(
        "BackgroundJob %s submitted: type=%s mandate=%s instance=%s by=%s",
        jobId, jobType, mandateId, featureInstanceId, triggeredBy,
    )
    return jobId
def _loadJob(jobId: str) -> Optional[Dict[str, Any]]:
    """Fetch a single job row by id; None when no row matches."""
    matches = _getDb().getRecordset(BackgroundJob, recordFilter={"id": jobId})
    if not matches:
        return None
    return dict(matches[0])
def _updateJob(jobId: str, fields: Dict[str, Any]) -> None:
    """Persist a partial field update for a job (datetimes serialised to ISO strings)."""
    serialised = _serialiseDatetimes(fields)
    _getDb().recordModify(BackgroundJob, jobId, serialised)
def _markStarted(jobId: str) -> None:
    """Transition a job to RUNNING and stamp its start time (UTC epoch seconds)."""
    startedAt = datetime.now(timezone.utc).timestamp()
    _updateJob(jobId, {"status": BackgroundJobStatusEnum.RUNNING.value, "startedAt": startedAt})
def _markSuccess(jobId: str, result: Optional[Dict[str, Any]]) -> None:
    """Transition a job to SUCCESS, store its result, and force progress to 100."""
    fields: Dict[str, Any] = {
        "status": BackgroundJobStatusEnum.SUCCESS.value,
        "result": result or {},
        "progress": 100,
        "finishedAt": datetime.now(timezone.utc).timestamp(),
    }
    _updateJob(jobId, fields)
def _markError(jobId: str, errorMessage: str) -> None:
    """Transition a job to ERROR, keeping at most 1000 chars of the message."""
    message = errorMessage if errorMessage else ""
    _updateJob(jobId, {
        "status": BackgroundJobStatusEnum.ERROR.value,
        "errorMessage": message[:1000],
        "finishedAt": datetime.now(timezone.utc).timestamp(),
    })
def _makeProgressCallback(jobId: str) -> JobProgressCallback:
    """Factory for the progress reporter handed to job handlers."""
    callback = JobProgressCallback(jobId)
    return callback
async def _runJob(jobId: str) -> None:
    """Execute the registered handler for *jobId* and record the terminal outcome.

    Loads the job row, resolves its handler, marks the job RUNNING, awaits the
    handler, then records SUCCESS or ERROR. If `cancelJob()` flipped the row to
    CANCELLED while the handler was running (cooperative cancellation), that
    terminal state is preserved instead of being overwritten.
    """
    job = _loadJob(jobId)
    if not job:
        logger.error("BackgroundJob %s vanished before runner started", jobId)
        return
    handler = _JOB_HANDLERS.get(job["jobType"])
    if not handler:
        msg = f"No handler registered for jobType '{job['jobType']}'"
        logger.error("BackgroundJob %s: %s", jobId, msg)
        _markError(jobId, msg)
        return
    _markStarted(jobId)
    try:
        result = await handler(job, _makeProgressCallback(jobId))
    except Exception as e:
        logger.exception("BackgroundJob %s (%s) failed", jobId, job["jobType"])
        if not _wasCancelledMeanwhile(jobId):
            _markError(jobId, str(e))
        return
    if _wasCancelledMeanwhile(jobId):
        # Do not clobber the terminal CANCELLED state with SUCCESS.
        logger.info(
            "BackgroundJob %s (%s) finished after cancellation; keeping CANCELLED",
            jobId, job["jobType"],
        )
        return
    _markSuccess(jobId, result if isinstance(result, dict) else None)
    logger.info("BackgroundJob %s (%s) completed successfully", jobId, job["jobType"])


def _wasCancelledMeanwhile(jobId: str) -> bool:
    """True if the job row is already CANCELLED (best-effort; False on read error)."""
    try:
        current = _loadJob(jobId)
    except Exception:
        return False
    return bool(current) and current.get("status") == BackgroundJobStatusEnum.CANCELLED.value
def getJobStatus(jobId: str) -> Optional[Dict[str, Any]]:
    """Load current job state. Returns None if not found."""
    record = _loadJob(jobId)
    return record
def listJobs(
    *,
    mandateId: Optional[str] = None,
    featureInstanceId: Optional[str] = None,
    jobType: Optional[str] = None,
    limit: int = 20,
) -> List[Dict[str, Any]]:
    """List recent jobs filtered by scope. Newest first.

    Args:
        mandateId: optional equality filter on the job's mandate.
        featureInstanceId: optional equality filter on the feature instance.
        jobType: optional equality filter on the job type.
        limit: maximum number of rows returned; non-positive values yield [].
    """
    db = _getDb()
    # Push the jobType filter to the DB (same recordFilter pattern used by
    # cancelJobsByConnection) instead of fetching every row and discarding most.
    if jobType is not None:
        rows = db.getRecordset(BackgroundJob, recordFilter={"jobType": jobType})
    else:
        rows = db.getRecordset(BackgroundJob)
    out = [dict(r) for r in rows]
    if mandateId is not None:
        out = [r for r in out if r.get("mandateId") == mandateId]
    if featureInstanceId is not None:
        out = [r for r in out if r.get("featureInstanceId") == featureInstanceId]
    out.sort(key=lambda r: r.get("createdAt") or 0, reverse=True)
    # Clamp: a negative limit would make out[:limit] silently drop the NEWEST rows.
    return out[: max(0, limit)]
# Precomputed once at import time; TERMINAL_JOB_STATUSES is a fixed enum set,
# so rebuilding the value set on every call was wasted work.
_TERMINAL_STATUS_VALUES = frozenset(s.value for s in TERMINAL_JOB_STATUSES)


def isTerminalStatus(status: str) -> bool:
    """True if the given status is one of {SUCCESS, ERROR, CANCELLED}."""
    return status in _TERMINAL_STATUS_VALUES
def cancelJob(jobId: str, *, reason: str = "user_requested") -> bool:
    """Mark a job as CANCELLED. Walkers detect this via JobProgressCallback.isCancelled().
    Returns False if the job is already in a terminal state or does not exist.
    """
    current = _loadJob(jobId)
    if not current or isTerminalStatus(current.get("status", "")):
        return False
    note = f"cancelled: {reason}"
    _updateJob(jobId, {
        "status": BackgroundJobStatusEnum.CANCELLED.value,
        "errorMessage": note[:1000],
        "finishedAt": datetime.now(timezone.utc).timestamp(),
    })
    logger.info("BackgroundJob %s cancelled (reason=%s)", jobId, reason)
    return True
def cancelJobsByConnection(connectionId: str, *, jobType: str = "connection.bootstrap") -> int:
    """Cancel all RUNNING/PENDING jobs whose payload.connectionId matches.
    Returns count of jobs marked as cancelled.
    """
    activeStatuses = (
        BackgroundJobStatusEnum.PENDING.value,
        BackgroundJobStatusEnum.RUNNING.value,
    )
    db = _getDb()
    cancelled = 0
    for row in db.getRecordset(BackgroundJob, recordFilter={"jobType": jobType}):
        if row.get("status", "") not in activeStatuses:
            continue
        rowPayload = row.get("payload") or {}
        if rowPayload.get("connectionId") != connectionId:
            continue
        if cancelJob(row["id"], reason=f"connection_stop:{connectionId[:8]}"):
            cancelled += 1
    return cancelled
# Strong references to recovery re-queue tasks: asyncio.create_task holds only
# a weak reference, so an unreferenced task may be garbage-collected before it
# runs (documented asyncio pitfall).
_RECOVERY_TASKS: set = set()


def recoverInterruptedJobs() -> int:
    """Flip any RUNNING jobs to ERROR and re-queue bootstrap jobs (called at worker boot).

    A RUNNING job in the DB after process restart means the previous worker
    died mid-execution; the asyncio task is gone and the job will never
    finish on its own. For connection.bootstrap jobs, a fresh job is
    automatically re-queued so the user doesn't have to manually retry.

    Returns:
        Number of interrupted jobs marked as ERROR (not the re-queue count).

    NOTE(review): uses asyncio.create_task, so this must run with an event
    loop already running — TODO confirm boot-time call site.
    """
    db = _getDb()
    try:
        rows = db.getRecordset(BackgroundJob, recordFilter={"status": BackgroundJobStatusEnum.RUNNING.value})
    except Exception as ex:
        # Best-effort recovery: a scan failure must not block worker boot.
        logger.warning("recoverInterruptedJobs: failed to scan RUNNING jobs: %s", ex)
        return 0
    count = 0
    requeued = 0
    for row in rows:
        try:
            _markError(row["id"], "Interrupted by worker restart")
            count += 1
        except Exception as ex:
            logger.warning("recoverInterruptedJobs: could not mark %s as ERROR: %s", row.get("id"), ex)
            continue
        if row.get("jobType") == "connection.bootstrap":
            payload = row.get("payload") or {}
            if payload.get("connectionId"):
                try:
                    newJob = BackgroundJob(
                        jobType="connection.bootstrap",
                        payload=payload,
                        triggeredBy="recovery.requeue",
                    )
                    record = db.recordCreate(BackgroundJob, _serialiseDatetimes(newJob.model_dump()))
                    # Hold a strong reference until the runner completes so the
                    # task cannot be garbage-collected mid-run.
                    task = asyncio.create_task(_runJob(record["id"]))
                    _RECOVERY_TASKS.add(task)
                    task.add_done_callback(_RECOVERY_TASKS.discard)
                    requeued += 1
                    logger.info(
                        "recoverInterruptedJobs: re-queued bootstrap for connectionId=%s (new jobId=%s)",
                        payload["connectionId"], record["id"],
                    )
                except Exception as reqEx:
                    logger.warning("recoverInterruptedJobs: re-queue failed for %s: %s", row.get("id"), reqEx)
    if count:
        logger.warning("Recovered %d interrupted background job(s) after restart (re-queued %d)", count, requeued)
    return count