diff --git a/agent_memory_server/tasks.py b/agent_memory_server/tasks.py index b05195b..4ea4320 100644 --- a/agent_memory_server/tasks.py +++ b/agent_memory_server/tasks.py @@ -13,6 +13,23 @@ _TASK_TTL_SECONDS = 7 * 24 * 60 * 60 # 7 days +class _UnsetType: + """Sentinel type for distinguishing 'not provided' from ``None``.""" + + _instance: "_UnsetType | None" = None + + def __new__(cls) -> "_UnsetType": + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __repr__(self) -> str: + return "" + + +_UNSET = _UnsetType() + + class InvalidTaskTransitionError(Exception): """Raised when a task status transition is not allowed.""" @@ -74,12 +91,15 @@ async def update_task_status( status: TaskStatusEnum | None = None, started_at: datetime | None = None, completed_at: datetime | None = None, - error_message: str | None = None, + error_message: str | None | _UnsetType = _UNSET, ) -> None: """Update status and timestamps for an existing Task. If the task does not exist, this is a no-op. + Pass ``error_message=""`` to clear a previously set error message. + Omit ``error_message`` (or pass the default) to leave it unchanged. + Raises: InvalidTaskTransitionError: If the requested status transition violates the task state machine. @@ -113,8 +133,11 @@ async def update_task_status( task.started_at = started_at if completed_at is not None: task.completed_at = completed_at - if error_message is not None: - task.error_message = error_message + if error_message is not _UNSET: + if error_message is None or error_message == "": + task.error_message = None + else: + task.error_message = error_message # Validate timestamp ordering only when timestamps are being changed. # This avoids rejecting status-only updates on tasks that already have diff --git a/tests/integration/test_task_error_message_clearable.py b/tests/integration/test_task_error_message_clearable.py new file mode 100644 index 0000000..4a2232e --- /dev/null +++ b/tests/integration/test_task_error_message_clearable.py @@ -0,0 +1,72 @@ +"""Test that task error_message can be cleared. + +Regression test for https://github.com/redis/agent-memory-server/issues/206 +""" + +import pytest +from ulid import ULID + +from agent_memory_server.models import Task, TaskStatusEnum, TaskTypeEnum +from agent_memory_server.tasks import create_task, get_task, update_task_status + + +def _make_task(**overrides) -> Task: + defaults = { + "id": str(ULID()), + "type": TaskTypeEnum.SUMMARY_VIEW_FULL_RUN, + "view_id": "test-view", + } + defaults.update(overrides) + return Task(**defaults) + + +class TestErrorMessageClearable: + """error_message should be clearable by passing empty string.""" + + @pytest.mark.asyncio + async def test_clear_error_message_with_empty_string(self, async_redis_client): + """Passing error_message='' should clear a previously set error.""" + task = _make_task(status=TaskStatusEnum.FAILED) + await create_task(task) + + # Set an error + await update_task_status( + task.id, + error_message="Something broke", + ) + t1 = await get_task(task.id) + assert t1.error_message == "Something broke" + + # Clear it + await update_task_status( + task.id, + error_message="", + ) + t2 = await get_task(task.id) + assert ( + t2.error_message is None + ), "Empty string should clear error_message to None" + + @pytest.mark.asyncio + async def test_none_does_not_change_error_message(self, async_redis_client): + """Omitting error_message (defaults to _UNSET) should leave the field unchanged.""" + task = _make_task(status=TaskStatusEnum.FAILED) + await create_task(task) + + await update_task_status(task.id, error_message="Original error") + await update_task_status(task.id) # error_message defaults to None + + t = await get_task(task.id) + assert t.error_message == "Original error" + + @pytest.mark.asyncio + async def test_set_new_error_replaces_old(self, async_redis_client): + """Passing a non-empty error_message should replace the existing one.""" + task = _make_task(status=TaskStatusEnum.FAILED) + await create_task(task) + + await update_task_status(task.id, error_message="First") + await update_task_status(task.id, error_message="Second") + + t = await get_task(task.id) + assert t.error_message == "Second"