diff --git a/apps/worker/services/processing/state.py b/apps/worker/services/processing/state.py
index 4a1c7d973b..4f713a4e19 100644
--- a/apps/worker/services/processing/state.py
+++ b/apps/worker/services/processing/state.py
@@ -23,8 +23,15 @@
 
 from dataclasses import dataclass
 
+from sqlalchemy import case, func
+from sqlalchemy.orm import Session
+
+from database.enums import ReportType
+from database.models.core import Commit
+from database.models.reports import CommitReport, Upload
 from shared.helpers.redis import get_redis_connection
 from shared.metrics import Counter
+from shared.reports.enums import UploadState
 
 
 MERGE_BATCH_SIZE = 10
@@ -75,18 +82,53 @@ def should_trigger_postprocessing(uploads: UploadNumbers) -> bool:
 
 
 class ProcessingState:
-    def __init__(self, repoid: int, commitsha: str) -> None:
+    def __init__(
+        self, repoid: int, commitsha: str, db_session: Session | None = None
+    ) -> None:
         self._redis = get_redis_connection()
         self.repoid = repoid
         self.commitsha = commitsha
+        self._db_session = db_session
 
     def get_upload_numbers(self):
+        if self._db_session:
+            row = (
+                self._db_session.query(
+                    func.count(
+                        case(
+                            (
+                                Upload.state_id == UploadState.UPLOADED.db_id,
+                                Upload.id_,
+                            ),
+                        )
+                    ),
+                    func.count(
+                        case(
+                            (
+                                Upload.state_id == UploadState.PROCESSED.db_id,
+                                Upload.id_,
+                            ),
+                        )
+                    ),
+                )
+                .join(CommitReport, Upload.report_id == CommitReport.id_)
+                .join(Commit, CommitReport.commit_id == Commit.id_)
+                .filter(
+                    Commit.repoid == self.repoid,
+                    Commit.commitid == self.commitsha,
+                    (CommitReport.report_type == None)  # noqa: E711
+                    | (CommitReport.report_type == ReportType.COVERAGE.value),
+                )
+                .one()
+            )
+            return UploadNumbers(processing=row[0], processed=row[1])
+
         processing = self._redis.scard(self._redis_key("processing"))
         processed = self._redis.scard(self._redis_key("processed"))
         return UploadNumbers(processing, processed)
 
     def mark_uploads_as_processing(self, upload_ids: list[int]):
-        if not upload_ids:
+        if not upload_ids or self._db_session:
             return
         key = self._redis_key("processing")
         self._redis.sadd(key, *upload_ids)
@@ -97,6 +139,27 @@ def mark_uploads_as_processing(self, upload_ids: list[int]):
     def clear_in_progress_uploads(self, upload_ids: list[int]):
         if not upload_ids:
             return
+        if self._db_session:
+            # Mark still-UPLOADED uploads as ERROR so they stop being counted
+            # as "processing" in get_upload_numbers(). Only matches UPLOADED --
+            # already-PROCESSED uploads (success path) are unaffected.
+            updated = (
+                self._db_session.query(Upload)
+                .filter(
+                    Upload.id_.in_(upload_ids),
+                    Upload.state_id == UploadState.UPLOADED.db_id,
+                )
+                .update(
+                    {
+                        Upload.state_id: UploadState.ERROR.db_id,
+                        Upload.state: "error",
+                    },
+                    synchronize_session="fetch",
+                )
+            )
+            if updated > 0:
+                CLEARED_UPLOADS.inc(updated)
+            return
         removed_uploads = self._redis.srem(self._redis_key("processing"), *upload_ids)
         if removed_uploads > 0:
             # the normal flow would move the uploads from the "processing" set
@@ -106,6 +169,15 @@
             CLEARED_UPLOADS.inc(removed_uploads)
 
     def mark_upload_as_processed(self, upload_id: int):
+        if self._db_session:
+            upload = self._db_session.query(Upload).get(upload_id)
+            if upload:
+                upload.state_id = UploadState.PROCESSED.db_id
+                # Don't set upload.state here -- the finisher's idempotency check
+                # uses state="processed" to detect already-merged uploads.
+                # The state string is set by update_uploads() after merging.
+            return
+
         processing_key = self._redis_key("processing")
         processed_key = self._redis_key("processed")
 
@@ -124,9 +196,36 @@ def mark_upload_as_processed(self, upload_id: int):
     def mark_uploads_as_merged(self, upload_ids: list[int]):
         if not upload_ids:
             return
+        if self._db_session:
+            self._db_session.query(Upload).filter(Upload.id_.in_(upload_ids)).update(
+                {
+                    Upload.state_id: UploadState.MERGED.db_id,
+                    Upload.state: "merged",
+                },
+                synchronize_session="fetch",
+            )
+            return
+
         self._redis.srem(self._redis_key("processed"), *upload_ids)
 
     def get_uploads_for_merging(self) -> set[int]:
+        if self._db_session:
+            rows = (
+                self._db_session.query(Upload.id_)
+                .join(CommitReport, Upload.report_id == CommitReport.id_)
+                .join(Commit, CommitReport.commit_id == Commit.id_)
+                .filter(
+                    Commit.repoid == self.repoid,
+                    Commit.commitid == self.commitsha,
+                    (CommitReport.report_type == None)  # noqa: E711
+                    | (CommitReport.report_type == ReportType.COVERAGE.value),
+                    Upload.state_id == UploadState.PROCESSED.db_id,
+                )
+                .limit(MERGE_BATCH_SIZE)
+                .all()
+            )
+            return {row[0] for row in rows}
+
         return {
             int(id)
             for id in self._redis.srandmember(
diff --git a/apps/worker/services/tests/test_processing_state.py b/apps/worker/services/tests/test_processing_state.py
index 0416e49a63..3433c534da 100644
--- a/apps/worker/services/tests/test_processing_state.py
+++ b/apps/worker/services/tests/test_processing_state.py
@@ -3,11 +3,18 @@
 
 import pytest
 
+from database.tests.factories.core import (
+    CommitFactory,
+    ReportFactory,
+    RepositoryFactory,
+    UploadFactory,
+)
 from services.processing.state import (
     ProcessingState,
     should_perform_merge,
     should_trigger_postprocessing,
 )
+from shared.reports.enums import UploadState
 
 
 def test_single_upload():
@@ -189,3 +196,178 @@ def test_empty_list_guards_parametrized(method_name, upload_ids, should_call_red
         assert mock_redis.srem.call_count == 0, (
             f"{method_name} should not call srem"
         )
+
+
+@pytest.mark.django_db(databases={"default"})
+class TestProcessingStateDBPath:
+    """Tests for the DB-backed path of ProcessingState (when db_session is provided)."""
+
+    @pytest.fixture
+    def setup_commit(self, dbsession):
+        repository = RepositoryFactory.create()
+        dbsession.add(repository)
+        dbsession.flush()
+        commit = CommitFactory.create(repository=repository)
+        dbsession.add(commit)
+        dbsession.flush()
+        report = ReportFactory.create(commit=commit)
+        dbsession.add(report)
+        dbsession.flush()
+        return repository, commit, report
+
+    def _create_upload(self, dbsession, report, state_id):
+        upload = UploadFactory.create(
+            report=report,
+            state="uploaded",
+            state_id=state_id,
+        )
+        dbsession.add(upload)
+        dbsession.flush()
+        return upload
+
+    def test_get_upload_numbers_empty(self, dbsession, setup_commit):
+        _, commit, _ = setup_commit
+        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
+        numbers = state.get_upload_numbers()
+        assert numbers.processing == 0
+        assert numbers.processed == 0
+
+    def test_get_upload_numbers_with_uploads(self, dbsession, setup_commit):
+        _, commit, report = setup_commit
+        self._create_upload(dbsession, report, UploadState.UPLOADED.db_id)
+        self._create_upload(dbsession, report, UploadState.UPLOADED.db_id)
+        self._create_upload(dbsession, report, UploadState.PROCESSED.db_id)
+
+        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
+        numbers = state.get_upload_numbers()
+        assert numbers.processing == 2
+        assert numbers.processed == 1
+
+    def test_mark_upload_as_processed(self, dbsession, setup_commit):
+        _, commit, report = setup_commit
+        upload = self._create_upload(dbsession, report, UploadState.UPLOADED.db_id)
+
+        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
+        state.mark_upload_as_processed(upload.id_)
+        dbsession.flush()
+
+        dbsession.refresh(upload)
+        assert upload.state_id == UploadState.PROCESSED.db_id
+        # state string is intentionally NOT set here -- the finisher's
+        # idempotency check uses state="processed" to detect already-merged uploads
+        assert upload.state == "uploaded"
+
+    def test_mark_uploads_as_merged(self, dbsession, setup_commit):
+        _, commit, report = setup_commit
+        u1 = self._create_upload(dbsession, report, UploadState.PROCESSED.db_id)
+        u2 = self._create_upload(dbsession, report, UploadState.PROCESSED.db_id)
+
+        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
+        state.mark_uploads_as_merged([u1.id_, u2.id_])
+        dbsession.flush()
+
+        dbsession.refresh(u1)
+        dbsession.refresh(u2)
+        assert u1.state_id == UploadState.MERGED.db_id
+        assert u1.state == "merged"
+        assert u2.state_id == UploadState.MERGED.db_id
+        assert u2.state == "merged"
+
+    def test_mark_uploads_as_merged_empty_list(self, dbsession, setup_commit):
+        _, commit, _ = setup_commit
+        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
+        state.mark_uploads_as_merged([])
+
+    def test_get_uploads_for_merging(self, dbsession, setup_commit):
+        _, commit, report = setup_commit
+        u1 = self._create_upload(dbsession, report, UploadState.PROCESSED.db_id)
+        u2 = self._create_upload(dbsession, report, UploadState.PROCESSED.db_id)
+        self._create_upload(dbsession, report, UploadState.UPLOADED.db_id)
+        self._create_upload(dbsession, report, UploadState.MERGED.db_id)
+
+        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
+        merging = state.get_uploads_for_merging()
+        assert merging == {u1.id_, u2.id_}
+
+    def test_get_uploads_for_merging_respects_batch_size(self, dbsession, setup_commit):
+        _, commit, report = setup_commit
+        for _ in range(15):
+            self._create_upload(dbsession, report, UploadState.PROCESSED.db_id)
+
+        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
+        merging = state.get_uploads_for_merging()
+        assert len(merging) == 10
+
+    def test_mark_uploads_as_processing_is_noop(self, dbsession, setup_commit):
+        _, commit, report = setup_commit
+        upload = self._create_upload(dbsession, report, UploadState.UPLOADED.db_id)
+
+        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
+        state.mark_uploads_as_processing([upload.id_])
+
+        dbsession.refresh(upload)
+        assert upload.state_id == UploadState.UPLOADED.db_id
+
+    def test_clear_in_progress_uploads_sets_error_on_uploaded(
+        self, dbsession, setup_commit
+    ):
+        """Uploaded uploads are set to ERROR so they stop counting as 'processing'."""
+        _, commit, report = setup_commit
+        upload = self._create_upload(dbsession, report, UploadState.UPLOADED.db_id)
+
+        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
+        state.clear_in_progress_uploads([upload.id_])
+
+        dbsession.refresh(upload)
+        assert upload.state_id == UploadState.ERROR.db_id
+        assert upload.state == "error"
+
+    def test_clear_in_progress_uploads_skips_processed(self, dbsession, setup_commit):
+        """Already-processed uploads are not affected (success path in finally block)."""
+        _, commit, report = setup_commit
+        upload = self._create_upload(dbsession, report, UploadState.PROCESSED.db_id)
+
+        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
+        state.clear_in_progress_uploads([upload.id_])
+
+        dbsession.refresh(upload)
+        assert upload.state_id == UploadState.PROCESSED.db_id
+
+    def test_full_lifecycle(self, dbsession, setup_commit):
+        """End-to-end: UPLOADED -> PROCESSED -> MERGED with DB state."""
+        _, commit, report = setup_commit
+        u1 = self._create_upload(dbsession, report, UploadState.UPLOADED.db_id)
+        u2 = self._create_upload(dbsession, report, UploadState.UPLOADED.db_id)
+
+        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
+
+        numbers = state.get_upload_numbers()
+        assert numbers.processing == 2
+        assert numbers.processed == 0
+
+        state.mark_upload_as_processed(u1.id_)
+        dbsession.flush()
+
+        numbers = state.get_upload_numbers()
+        assert numbers.processing == 1
+        assert numbers.processed == 1
+        assert not should_perform_merge(numbers)
+
+        state.mark_upload_as_processed(u2.id_)
+        dbsession.flush()
+
+        numbers = state.get_upload_numbers()
+        assert numbers.processing == 0
+        assert numbers.processed == 2
+        assert should_perform_merge(numbers)
+
+        merging = state.get_uploads_for_merging()
+        assert merging == {u1.id_, u2.id_}
+
+        state.mark_uploads_as_merged(list(merging))
+        dbsession.flush()
+
+        numbers = state.get_upload_numbers()
+        assert numbers.processing == 0
+        assert numbers.processed == 0
+        assert should_trigger_postprocessing(numbers)
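
Usage sketch (not part of the diff): how a worker task might drive the DB-backed path end to end. `finish_upload`, the `db_session` argument, and the elided parse/merge steps are hypothetical stand-ins; only the `ProcessingState` and `should_perform_merge` calls are from this change, and their ordering mirrors `test_full_lifecycle` above.

```python
# Hypothetical driver -- finish_upload and the elided steps are assumptions
# for illustration; the ProcessingState API itself is from this diff.
from services.processing.state import ProcessingState, should_perform_merge


def finish_upload(db_session, repoid: int, commitsha: str, upload_id: int) -> None:
    state = ProcessingState(repoid, commitsha, db_session=db_session)
    try:
        # ... parse the raw upload and persist its report data ...
        state.mark_upload_as_processed(upload_id)  # state_id -> PROCESSED
    finally:
        # On failure the upload is still UPLOADED, so this flips it to ERROR
        # and it stops counting toward "processing"; on success it was already
        # marked PROCESSED above and is left untouched.
        state.clear_in_progress_uploads([upload_id])
    db_session.flush()

    numbers = state.get_upload_numbers()  # derived from Upload.state_id
    if should_perform_merge(numbers):
        batch = state.get_uploads_for_merging()  # at most MERGE_BATCH_SIZE ids
        # ... merge the reports for `batch`, then record the transition ...
        state.mark_uploads_as_merged(list(batch))
```

The try/finally shape matches the comment in `clear_in_progress_uploads`: only uploads still in UPLOADED are flipped to ERROR, so calling it unconditionally on the success path is safe.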