gateway/tests/unit/workflows/test_state_management.py
2025-12-15 21:55:26 +01:00

172 lines
5.3 KiB
Python

#!/usr/bin/env python3
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Unit tests for workflow state management in ChatWorkflow and TaskContext.
Covers the state increment methods, the index getter helpers, and TaskContext.updateFromSelection().
"""
import pytest
import uuid
from modules.datamodels.datamodelChat import ChatWorkflow, TaskContext, TaskStep
from modules.datamodels.datamodelWorkflow import ActionDefinition
class TestChatWorkflowStateManagement:
    """Tests for the round/task/action counters on ChatWorkflow.

    Covers the initial counter values, the getRoundIndex()/getTaskIndex()/
    getActionIndex() accessors, the increment*() mutators, and the resets
    expected when a higher-level counter advances.
    """

    @staticmethod
    def _newWorkflow(**counters):
        """Build a ChatWorkflow with fixed test metadata and a fresh random id.

        Keyword arguments (e.g. ``currentRound=2``) are forwarded unchanged to
        the ChatWorkflow constructor so a test can preset the counters.
        """
        return ChatWorkflow(
            id=str(uuid.uuid4()),
            name="Test Workflow",
            mandateId="test_mandate",
            **counters,
        )

    def test_chatWorkflow_initial_state(self):
        """A freshly created workflow starts with every counter at zero."""
        wf = self._newWorkflow()
        assert (wf.currentRound, wf.currentTask, wf.currentAction) == (0, 0, 0)

    def test_chatWorkflow_getRoundIndex(self):
        """getRoundIndex() reports the preset currentRound."""
        assert self._newWorkflow(currentRound=2).getRoundIndex() == 2

    def test_chatWorkflow_getTaskIndex(self):
        """getTaskIndex() reports the preset currentTask."""
        assert self._newWorkflow(currentTask=3).getTaskIndex() == 3

    def test_chatWorkflow_getActionIndex(self):
        """getActionIndex() reports the preset currentAction."""
        assert self._newWorkflow(currentAction=5).getActionIndex() == 5

    def test_chatWorkflow_incrementRound(self):
        """incrementRound() advances currentRound by one."""
        wf = self._newWorkflow(currentRound=1)
        wf.incrementRound()
        assert wf.currentRound == 2

    def test_chatWorkflow_incrementTask(self):
        """incrementTask() advances currentTask by one."""
        wf = self._newWorkflow(currentTask=1)
        wf.incrementTask()
        assert wf.currentTask == 2

    def test_chatWorkflow_incrementAction(self):
        """incrementAction() advances currentAction by one."""
        wf = self._newWorkflow(currentAction=1)
        wf.incrementAction()
        assert wf.currentAction == 2

    def test_chatWorkflow_state_sequence(self):
        """Walk action -> task -> round increments and check the resets."""
        wf = self._newWorkflow()

        # Everything starts at zero.
        assert (wf.currentRound, wf.currentTask, wf.currentAction) == (0, 0, 0)

        wf.incrementAction()
        assert wf.currentAction == 1

        # Advancing the task is expected to reset the action counter.
        wf.incrementTask()
        assert (wf.currentTask, wf.currentAction) == (1, 0)

        # Advancing the round is expected to reset both task and action.
        wf.incrementRound()
        assert (wf.currentRound, wf.currentTask, wf.currentAction) == (1, 0, 0)
class TestTaskContextUpdateFromSelection:
    """Test TaskContext.updateFromSelection() method.

    Each test builds a fresh TaskContext around a minimal TaskStep, applies an
    ActionDefinition via updateFromSelection(), and checks which fields end
    up on the context.
    """

    @staticmethod
    def _makeContext():
        """Return a fresh TaskContext wrapping a minimal TaskStep ("step1")."""
        taskStep = TaskStep(
            id="step1",
            objective="Test objective"
        )
        return TaskContext(taskStep=taskStep)

    def test_taskContext_updateFromSelection(self):
        """Test updateFromSelection() with a fully populated ActionDefinition."""
        context = self._makeContext()
        actionDef = ActionDefinition(
            action="ai.process",
            actionObjective="Process documents",
            parametersContext="Some context",
            learnings=["Learning 1", "Learning 2"]
        )
        context.updateFromSelection(actionDef)
        assert context.actionObjective == "Process documents"
        assert context.parametersContext == "Some context"
        # Check that both learnings were carried over. A length check plus a
        # single membership test would also pass if, for example, the first
        # entry had been duplicated and the second dropped.
        assert len(context.learnings) == 2
        assert "Learning 1" in context.learnings
        assert "Learning 2" in context.learnings

    def test_taskContext_updateFromSelection_partial(self):
        """Test updateFromSelection() with a partial ActionDefinition."""
        context = self._makeContext()
        actionDef = ActionDefinition(
            action="ai.process",
            actionObjective="Process documents"
        )
        context.updateFromSelection(actionDef)
        assert context.actionObjective == "Process documents"
        # Fields not supplied on the ActionDefinition are expected to be
        # empty on the context afterwards.
        assert context.parametersContext is None
        assert len(context.learnings) == 0
if __name__ == "__main__":
    # pytest.main() returns the session exit code rather than exiting.
    # Pass it to SystemExit so that running this file directly (e.g. from a
    # shell script or CI step) gives a non-zero status when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))