Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions agent_memory_server/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,23 @@
_TASK_TTL_SECONDS = 7 * 24 * 60 * 60 # 7 days


class _UnsetType:
"""Sentinel type for distinguishing 'not provided' from ``None``."""

_instance: "_UnsetType | None" = None

def __new__(cls) -> "_UnsetType":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance

def __repr__(self) -> str:
return "<UNSET>"


_UNSET = _UnsetType()


def _task_key(task_id: str) -> str:
"""Return the Redis key for a task JSON payload."""

Expand Down Expand Up @@ -60,11 +77,14 @@ async def update_task_status(
status: TaskStatusEnum | None = None,
started_at: datetime | None = None,
completed_at: datetime | None = None,
error_message: str | None = None,
error_message: str | None | _UnsetType = _UNSET,
) -> None:
"""Update status and timestamps for an existing Task.

If the task does not exist, this is a no-op.

Pass ``error_message=""`` to clear a previously set error message.
Omit ``error_message`` (or pass the default) to leave it unchanged.
"""

redis = await get_redis_conn()
Expand All @@ -89,8 +109,11 @@ async def update_task_status(
task.started_at = started_at
if completed_at is not None:
task.completed_at = completed_at
if error_message is not None:
task.error_message = error_message
if error_message is not _UNSET:
if error_message is None or error_message == "":
task.error_message = None
else:
task.error_message = error_message

# Ensure created_at is always set
if task.created_at is None:
Expand Down
72 changes: 72 additions & 0 deletions tests/integration/test_task_error_message_clearable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Test that task error_message can be cleared.

Regression test for https://github.com/redis/agent-memory-server/issues/206
"""

import pytest
from ulid import ULID

from agent_memory_server.models import Task, TaskStatusEnum, TaskTypeEnum
from agent_memory_server.tasks import create_task, get_task, update_task_status


def _make_task(**overrides) -> Task:
defaults = {
"id": str(ULID()),
"type": TaskTypeEnum.SUMMARY_VIEW_FULL_RUN,
"view_id": "test-view",
}
defaults.update(overrides)
return Task(**defaults)


class TestErrorMessageClearable:
"""error_message should be clearable by passing empty string."""

@pytest.mark.asyncio
async def test_clear_error_message_with_empty_string(self, async_redis_client):
"""Passing error_message='' should clear a previously set error."""
task = _make_task(status=TaskStatusEnum.FAILED)
await create_task(task)

# Set an error
await update_task_status(
task.id,
error_message="Something broke",
)
t1 = await get_task(task.id)
assert t1.error_message == "Something broke"

# Clear it
await update_task_status(
task.id,
error_message="",
)
t2 = await get_task(task.id)
assert (
t2.error_message is None
), "Empty string should clear error_message to None"

@pytest.mark.asyncio
async def test_none_does_not_change_error_message(self, async_redis_client):
"""Omitting error_message (defaults to _UNSET) should leave the field unchanged."""
task = _make_task(status=TaskStatusEnum.FAILED)
await create_task(task)

await update_task_status(task.id, error_message="Original error")
await update_task_status(task.id) # error_message defaults to None

t = await get_task(task.id)
assert t.error_message == "Original error"

@pytest.mark.asyncio
async def test_set_new_error_replaces_old(self, async_redis_client):
"""Passing a non-empty error_message should replace the existing one."""
task = _make_task(status=TaskStatusEnum.FAILED)
await create_task(task)

await update_task_status(task.id, error_message="First")
await update_task_status(task.id, error_message="Second")

t = await get_task(task.id)
assert t.error_message == "Second"
Loading