Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions apps/worker/services/processing/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def process_upload(
upload = db_session.query(Upload).filter_by(id_=upload_id).first()
assert upload

state = ProcessingState(repo_id, commit_sha)
# this in a noop in normal cases, but relevant for task retries:
state = ProcessingState(repo_id, commit_sha, db_session=db_session)
# this is a noop in normal cases, but relevant for task retries:
state.mark_uploads_as_processing([upload_id])

report_service = ReportService(commit_yaml)
Expand Down Expand Up @@ -108,7 +108,6 @@ def process_upload(
celery_app.tasks[upload_finisher_task_name].apply_async(
kwargs=finisher_kwargs
)

rewrite_or_delete_upload(archive_service, commit_yaml, report_info)

except CeleryError:
Expand Down
151 changes: 107 additions & 44 deletions apps/worker/services/processing/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,21 @@
"intermediate report".
"""

import logging
from dataclasses import dataclass

from shared.helpers.redis import get_redis_connection
from sqlalchemy import case, func
from sqlalchemy.orm import Session

from database.enums import ReportType
from database.models.core import Commit
from database.models.reports import CommitReport, Upload
from shared.metrics import Counter
from shared.reports.enums import UploadState

MERGE_BATCH_SIZE = 10
log = logging.getLogger(__name__)

# TTL for processing state keys in Redis (24 hours, matches intermediate report TTL)
# This prevents state keys from accumulating indefinitely and ensures consistency
# with intermediate report expiration
PROCESSING_STATE_TTL = 24 * 60 * 60
MERGE_BATCH_SIZE = 10

CLEARED_UPLOADS = Counter(
"worker_processing_cleared_uploads",
Expand Down Expand Up @@ -75,59 +79,118 @@ def should_trigger_postprocessing(uploads: UploadNumbers) -> bool:


class ProcessingState:
def __init__(self, repoid: int, commitsha: str) -> None:
self._redis = get_redis_connection()
def __init__(self, repoid: int, commitsha: str, db_session: Session) -> None:
self.repoid = repoid
self.commitsha = commitsha
self._db_session = db_session

def get_upload_numbers(self):
processing = self._redis.scard(self._redis_key("processing"))
processed = self._redis.scard(self._redis_key("processed"))
return UploadNumbers(processing, processed)
row = (
self._db_session.query(
func.count(
case(
(
Upload.state_id == UploadState.UPLOADED.db_id,
Upload.id_,
),
)
),
func.count(
case(
(
Upload.state_id == UploadState.PROCESSED.db_id,
Upload.id_,
),
)
),
)
.join(CommitReport, Upload.report_id == CommitReport.id_)
.join(Commit, CommitReport.commit_id == Commit.id_)
.filter(
Commit.repoid == self.repoid,
Commit.commitid == self.commitsha,
(CommitReport.report_type == None) # noqa: E711
| (CommitReport.report_type == ReportType.COVERAGE.value),
)
.one()
)
return UploadNumbers(processing=row[0], processed=row[1])

def mark_uploads_as_processing(self, upload_ids: list[int]):
if not upload_ids:
return
key = self._redis_key("processing")
self._redis.sadd(key, *upload_ids)
# Set TTL to match intermediate report expiration (24 hours)
# This ensures state keys don't accumulate indefinitely
self._redis.expire(key, PROCESSING_STATE_TTL)
# No-op: uploads are created with state_id=UPLOADED, which
# get_upload_numbers() already counts as "processing".
pass

def clear_in_progress_uploads(self, upload_ids: list[int]):
if not upload_ids:
return
removed_uploads = self._redis.srem(self._redis_key("processing"), *upload_ids)
if removed_uploads > 0:
# the normal flow would move the uploads from the "processing" set
# to the "processed" set via `mark_upload_as_processed`.
# this function here is only called in the error case and we don't expect
# this to be triggered often, if at all.
CLEARED_UPLOADS.inc(removed_uploads)
# Mark still-UPLOADED uploads as ERROR so they stop being counted
# as "processing" in get_upload_numbers(). Only matches UPLOADED --
# already-PROCESSED uploads (success path) are unaffected.
#
# This runs in a finally block, so the transaction may already be
# in a failed state. Best-effort: log and move on if the DB is
# unreachable — the upload stays UPLOADED, which is safe.
try:
updated = (
self._db_session.query(Upload)
.filter(
Upload.id_.in_(upload_ids),
Upload.state_id == UploadState.UPLOADED.db_id,
)
.update(
{
Upload.state_id: UploadState.ERROR.db_id,
Upload.state: "error",
},
synchronize_session="fetch",
)
)
if updated > 0:
CLEARED_UPLOADS.inc(updated)
except Exception:
log.warning(
"Failed to clear in-progress uploads (transaction may be aborted)",
extra={"upload_ids": upload_ids},
exc_info=True,
)

def mark_upload_as_processed(self, upload_id: int):
processing_key = self._redis_key("processing")
processed_key = self._redis_key("processed")

res = self._redis.smove(processing_key, processed_key, upload_id)
if not res:
# this can happen when `upload_id` was never in the source set,
# which probably is the case during initial deployment as
# the code adding this to the initial set was not deployed yet
# TODO: make sure to remove this code after a grace period
self._redis.sadd(processed_key, upload_id)

# Set TTL on processed key to match intermediate report expiration
# This ensures uploads marked as processed have a bounded lifetime
self._redis.expire(processed_key, PROCESSING_STATE_TTL)
upload = self._db_session.query(Upload).get(upload_id)
if upload:
upload.state_id = UploadState.PROCESSED.db_id
# Don't set upload.state here -- the finisher's idempotency check
# uses state="processed" to detect already-merged uploads.
# The state string is set by update_uploads() after merging.

def mark_uploads_as_merged(self, upload_ids: list[int]):
if not upload_ids:
return
self._redis.srem(self._redis_key("processed"), *upload_ids)
self._db_session.query(Upload).filter(
Upload.id_.in_(upload_ids),
Upload.state_id == UploadState.PROCESSED.db_id,
).update(
{
Upload.state_id: UploadState.MERGED.db_id,
Upload.state: "merged",
},
synchronize_session="fetch",
)
self._db_session.commit()

def get_uploads_for_merging(self) -> set[int]:
return {int(id) for id in self._redis.smembers(self._redis_key("processed"))}

def _redis_key(self, state: str) -> str:
return f"upload-processing-state/{self.repoid}/{self.commitsha}/{state}"
rows = (
self._db_session.query(Upload.id_)
.join(CommitReport, Upload.report_id == CommitReport.id_)
.join(Commit, CommitReport.commit_id == Commit.id_)
.filter(
Commit.repoid == self.repoid,
Commit.commitid == self.commitsha,
(CommitReport.report_type == None) # noqa: E711
| (CommitReport.report_type == ReportType.COVERAGE.value),
Upload.state_id == UploadState.PROCESSED.db_id,
)
.limit(MERGE_BATCH_SIZE)
.all()
)
return {row[0] for row in rows}
68 changes: 68 additions & 0 deletions apps/worker/services/tests/test_merging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import pytest

from database.tests.factories.core import (
CommitFactory,
ReportFactory,
RepositoryFactory,
UploadFactory,
)
from services.processing.merging import update_uploads
from services.processing.types import MergeResult, ProcessingResult
from shared.reports.enums import UploadState
from shared.yaml import UserYaml


@pytest.mark.django_db(databases={"default"})
class TestUpdateUploadsState:
    """Checks that `update_uploads` moves Upload rows to their terminal state.

    Both tests start from an upload in the initial UPLOADED state (state
    string still "started") and verify that after `update_uploads` runs,
    the enum id and the legacy state string agree on the outcome.
    """

    def test_successful_uploads_set_to_merged(self, dbsession):
        # Arrange: a repository/commit/report chain with one fresh upload.
        repo = RepositoryFactory.create()
        commit = CommitFactory.create(repository=repo)
        report = ReportFactory.create(commit=commit)
        upload = UploadFactory.create(
            report=report,
            state="started",
            state_id=UploadState.UPLOADED.db_id,
        )
        dbsession.add_all([repo, commit, report, upload])
        dbsession.flush()

        # One successful processing result, merged into session slot 0.
        results: list[ProcessingResult] = [
            {"upload_id": upload.id_, "successful": True, "arguments": {}}
        ]
        merge = MergeResult(session_mapping={upload.id_: 0}, deleted_sessions=set())

        update_uploads(dbsession, UserYaml({}), results, [], merge)

        # Both the enum id and the state string must reflect MERGED.
        dbsession.refresh(upload)
        assert (upload.state_id, upload.state) == (UploadState.MERGED.db_id, "merged")

    def test_failed_uploads_set_to_error(self, dbsession):
        # Arrange: identical setup, one fresh upload awaiting processing.
        repo = RepositoryFactory.create()
        commit = CommitFactory.create(repository=repo)
        report = ReportFactory.create(commit=commit)
        upload = UploadFactory.create(
            report=report,
            state="started",
            state_id=UploadState.UPLOADED.db_id,
        )
        dbsession.add_all([repo, commit, report, upload])
        dbsession.flush()

        # One failed processing result; nothing was merged.
        results: list[ProcessingResult] = [
            {
                "upload_id": upload.id_,
                "successful": False,
                "arguments": {},
                "error": {"code": "report_empty", "params": {}},
            }
        ]
        merge = MergeResult(session_mapping={}, deleted_sessions=set())

        update_uploads(dbsession, UserYaml({}), results, [], merge)

        # Both the enum id and the state string must reflect ERROR.
        dbsession.refresh(upload)
        assert (upload.state_id, upload.state) == (UploadState.ERROR.db_id, "error")
13 changes: 7 additions & 6 deletions apps/worker/services/tests/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)
from services.processing.processing import process_upload
from services.processing.types import UploadArguments
from shared.reports.enums import UploadState
from shared.yaml import UserYaml


Expand All @@ -23,6 +24,7 @@ def test_triggers_finisher_when_gate_is_acquired(
upload = UploadFactory.create(
report__commit=commit,
state="started",
state_id=UploadState.UPLOADED.db_id,
)
dbsession.add_all([repository, commit, upload])
dbsession.flush()
Expand All @@ -34,7 +36,6 @@ def test_triggers_finisher_when_gate_is_acquired(
"reportid": str(upload.report.external_id),
}

# Mock dependencies
mock_report_service = mocker.patch(
"services.processing.processing.ReportService"
)
Expand Down Expand Up @@ -63,13 +64,10 @@ def test_triggers_finisher_when_gate_is_acquired(
mock_redis = mocker.patch("services.processing.processing.get_redis_connection")
mock_redis.return_value.set.return_value = True

# Mock other dependencies
mocker.patch("services.processing.processing.save_intermediate_report")
mocker.patch("services.processing.processing.rewrite_or_delete_upload")

commit_yaml = UserYaml({})

# Execute
result = process_upload(
on_processing_error=lambda error: None,
db_session=dbsession,
Expand All @@ -79,10 +77,8 @@ def test_triggers_finisher_when_gate_is_acquired(
arguments=arguments,
)

# Verify
assert result["successful"] is True
assert result["upload_id"] == upload.id_

# Verify finisher was triggered
mock_finisher_task.apply_async.assert_called_once_with(
kwargs={
Expand All @@ -92,6 +88,11 @@ def test_triggers_finisher_when_gate_is_acquired(
}
)
mock_redis.return_value.set.assert_called_once()
dbsession.refresh(upload)
assert upload.state_id == UploadState.PROCESSED.db_id
# state string is not updated by the processor -- the finisher sets it
# after merging (to avoid triggering the finisher's idempotency check early)
assert upload.state == "started"

def test_does_not_trigger_finisher_when_gate_exists(
self, dbsession, mocker, mock_storage
Expand Down
Loading
Loading