Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 101 additions & 2 deletions apps/worker/services/processing/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,15 @@

from dataclasses import dataclass

from sqlalchemy import case, func
from sqlalchemy.orm import Session

from database.enums import ReportType
from database.models.core import Commit
from database.models.reports import CommitReport, Upload
from shared.helpers.redis import get_redis_connection
from shared.metrics import Counter
from shared.reports.enums import UploadState

MERGE_BATCH_SIZE = 10

Expand Down Expand Up @@ -75,18 +82,53 @@ def should_trigger_postprocessing(uploads: UploadNumbers) -> bool:


class ProcessingState:
def __init__(self, repoid: int, commitsha: str) -> None:
def __init__(
self, repoid: int, commitsha: str, db_session: Session | None = None
) -> None:
self._redis = get_redis_connection()
self.repoid = repoid
self.commitsha = commitsha
self._db_session = db_session

def get_upload_numbers(self):
if self._db_session:
row = (
self._db_session.query(
func.count(
case(
(
Upload.state_id == UploadState.UPLOADED.db_id,
Upload.id_,
),
)
),
func.count(
case(
(
Upload.state_id == UploadState.PROCESSED.db_id,
Upload.id_,
),
)
),
)
.join(CommitReport, Upload.report_id == CommitReport.id_)
.join(Commit, CommitReport.commit_id == Commit.id_)
.filter(
Commit.repoid == self.repoid,
Commit.commitid == self.commitsha,
(CommitReport.report_type == None) # noqa: E711
| (CommitReport.report_type == ReportType.COVERAGE.value),
)
.one()
)
return UploadNumbers(processing=row[0], processed=row[1])

processing = self._redis.scard(self._redis_key("processing"))
processed = self._redis.scard(self._redis_key("processed"))
return UploadNumbers(processing, processed)

def mark_uploads_as_processing(self, upload_ids: list[int]):
if not upload_ids:
if not upload_ids or self._db_session:
return
key = self._redis_key("processing")
self._redis.sadd(key, *upload_ids)
Expand All @@ -97,6 +139,27 @@ def mark_uploads_as_processing(self, upload_ids: list[int]):
def clear_in_progress_uploads(self, upload_ids: list[int]):
if not upload_ids:
return
if self._db_session:
# Mark still-UPLOADED uploads as ERROR so they stop being counted
# as "processing" in get_upload_numbers(). Only matches UPLOADED --
# already-PROCESSED uploads (success path) are unaffected.
updated = (
self._db_session.query(Upload)
.filter(
Upload.id_.in_(upload_ids),
Upload.state_id == UploadState.UPLOADED.db_id,
)
.update(
{
Upload.state_id: UploadState.ERROR.db_id,
Upload.state: "error",
},
synchronize_session="fetch",
)
)
if updated > 0:
CLEARED_UPLOADS.inc(updated)
return
removed_uploads = self._redis.srem(self._redis_key("processing"), *upload_ids)
if removed_uploads > 0:
# the normal flow would move the uploads from the "processing" set
Expand All @@ -106,6 +169,15 @@ def clear_in_progress_uploads(self, upload_ids: list[int]):
CLEARED_UPLOADS.inc(removed_uploads)

def mark_upload_as_processed(self, upload_id: int):
if self._db_session:
upload = self._db_session.query(Upload).get(upload_id)
if upload:
upload.state_id = UploadState.PROCESSED.db_id
# Don't set upload.state here -- the finisher's idempotency check
# uses state="processed" to detect already-merged uploads.
# The state string is set by update_uploads() after merging.
return

processing_key = self._redis_key("processing")
processed_key = self._redis_key("processed")

Expand All @@ -124,9 +196,36 @@ def mark_upload_as_processed(self, upload_id: int):
def mark_uploads_as_merged(self, upload_ids: list[int]):
if not upload_ids:
return
if self._db_session:
self._db_session.query(Upload).filter(Upload.id_.in_(upload_ids)).update(
{
Upload.state_id: UploadState.MERGED.db_id,
Upload.state: "merged",
},
synchronize_session="fetch",
)
return

self._redis.srem(self._redis_key("processed"), *upload_ids)

def get_uploads_for_merging(self) -> set[int]:
if self._db_session:
rows = (
self._db_session.query(Upload.id_)
.join(CommitReport, Upload.report_id == CommitReport.id_)
.join(Commit, CommitReport.commit_id == Commit.id_)
.filter(
Commit.repoid == self.repoid,
Commit.commitid == self.commitsha,
(CommitReport.report_type == None) # noqa: E711
| (CommitReport.report_type == ReportType.COVERAGE.value),
Upload.state_id == UploadState.PROCESSED.db_id,
)
.limit(MERGE_BATCH_SIZE)
.all()
)
return {row[0] for row in rows}

return {
int(id)
for id in self._redis.srandmember(
Expand Down
182 changes: 182 additions & 0 deletions apps/worker/services/tests/test_processing_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,18 @@

import pytest

from database.tests.factories.core import (
CommitFactory,
ReportFactory,
RepositoryFactory,
UploadFactory,
)
from services.processing.state import (
ProcessingState,
should_perform_merge,
should_trigger_postprocessing,
)
from shared.reports.enums import UploadState


def test_single_upload():
Expand Down Expand Up @@ -189,3 +196,178 @@ def test_empty_list_guards_parametrized(method_name, upload_ids, should_call_red
assert mock_redis.srem.call_count == 0, (
f"{method_name} should not call srem"
)


@pytest.mark.django_db(databases={"default"})
class TestProcessingStateDBPath:
    """Tests for the DB-backed path of ProcessingState (when db_session is provided)."""

    # NOTE(review): the django_db marker is combined with the SQLAlchemy
    # `dbsession` fixture below — confirm both are actually required here.

    @pytest.fixture
    def setup_commit(self, dbsession):
        """Create and flush a repository -> commit -> report chain.

        Returns the (repository, commit, report) tuple used by each test.
        """
        repository = RepositoryFactory.create()
        dbsession.add(repository)
        dbsession.flush()
        commit = CommitFactory.create(repository=repository)
        dbsession.add(commit)
        dbsession.flush()
        report = ReportFactory.create(commit=commit)
        dbsession.add(report)
        dbsession.flush()
        return repository, commit, report

    def _create_upload(self, dbsession, report, state_id):
        """Create and flush an Upload in `report` with the given state_id.

        The legacy `state` string is always "uploaded" regardless of state_id;
        tests that care about the string assert on it explicitly.
        """
        upload = UploadFactory.create(
            report=report,
            state="uploaded",
            state_id=state_id,
        )
        dbsession.add(upload)
        dbsession.flush()
        return upload

    def test_get_upload_numbers_empty(self, dbsession, setup_commit):
        """A commit with no uploads reports zero processing and processed."""
        _, commit, _ = setup_commit
        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
        numbers = state.get_upload_numbers()
        assert numbers.processing == 0
        assert numbers.processed == 0

    def test_get_upload_numbers_with_uploads(self, dbsession, setup_commit):
        """UPLOADED rows count as processing; PROCESSED rows as processed."""
        _, commit, report = setup_commit
        self._create_upload(dbsession, report, UploadState.UPLOADED.db_id)
        self._create_upload(dbsession, report, UploadState.UPLOADED.db_id)
        self._create_upload(dbsession, report, UploadState.PROCESSED.db_id)

        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
        numbers = state.get_upload_numbers()
        assert numbers.processing == 2
        assert numbers.processed == 1

    def test_mark_upload_as_processed(self, dbsession, setup_commit):
        """Only state_id is advanced; the legacy state string is untouched."""
        _, commit, report = setup_commit
        upload = self._create_upload(dbsession, report, UploadState.UPLOADED.db_id)

        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
        state.mark_upload_as_processed(upload.id_)
        dbsession.flush()

        dbsession.refresh(upload)
        assert upload.state_id == UploadState.PROCESSED.db_id
        # state string is intentionally NOT set here -- the finisher's
        # idempotency check uses state="processed" to detect already-merged uploads
        assert upload.state == "uploaded"

    def test_mark_uploads_as_merged(self, dbsession, setup_commit):
        """Merging sets both state_id and the legacy state string."""
        _, commit, report = setup_commit
        u1 = self._create_upload(dbsession, report, UploadState.PROCESSED.db_id)
        u2 = self._create_upload(dbsession, report, UploadState.PROCESSED.db_id)

        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
        state.mark_uploads_as_merged([u1.id_, u2.id_])
        dbsession.flush()

        dbsession.refresh(u1)
        dbsession.refresh(u2)
        assert u1.state_id == UploadState.MERGED.db_id
        assert u1.state == "merged"
        assert u2.state_id == UploadState.MERGED.db_id
        assert u2.state == "merged"

    def test_mark_uploads_as_merged_empty_list(self, dbsession, setup_commit):
        """An empty id list is a no-op (must not raise)."""
        _, commit, _ = setup_commit
        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
        state.mark_uploads_as_merged([])

    def test_get_uploads_for_merging(self, dbsession, setup_commit):
        """Only PROCESSED uploads are candidates; UPLOADED/MERGED are excluded."""
        _, commit, report = setup_commit
        u1 = self._create_upload(dbsession, report, UploadState.PROCESSED.db_id)
        u2 = self._create_upload(dbsession, report, UploadState.PROCESSED.db_id)
        self._create_upload(dbsession, report, UploadState.UPLOADED.db_id)
        self._create_upload(dbsession, report, UploadState.MERGED.db_id)

        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
        merging = state.get_uploads_for_merging()
        assert merging == {u1.id_, u2.id_}

    def test_get_uploads_for_merging_respects_batch_size(self, dbsession, setup_commit):
        """At most MERGE_BATCH_SIZE (10) uploads are returned per call."""
        _, commit, report = setup_commit
        for _ in range(15):
            self._create_upload(dbsession, report, UploadState.PROCESSED.db_id)

        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
        merging = state.get_uploads_for_merging()
        assert len(merging) == 10

    def test_mark_uploads_as_processing_is_noop(self, dbsession, setup_commit):
        """On the DB path, mark_uploads_as_processing must not touch rows."""
        _, commit, report = setup_commit
        upload = self._create_upload(dbsession, report, UploadState.UPLOADED.db_id)

        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
        state.mark_uploads_as_processing([upload.id_])

        dbsession.refresh(upload)
        assert upload.state_id == UploadState.UPLOADED.db_id

    def test_clear_in_progress_uploads_sets_error_on_uploaded(
        self, dbsession, setup_commit
    ):
        """Uploaded uploads are set to ERROR so they stop counting as 'processing'."""
        _, commit, report = setup_commit
        upload = self._create_upload(dbsession, report, UploadState.UPLOADED.db_id)

        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
        state.clear_in_progress_uploads([upload.id_])

        dbsession.refresh(upload)
        assert upload.state_id == UploadState.ERROR.db_id
        assert upload.state == "error"

    def test_clear_in_progress_uploads_skips_processed(self, dbsession, setup_commit):
        """Already-processed uploads are not affected (success path in finally block)."""
        _, commit, report = setup_commit
        upload = self._create_upload(dbsession, report, UploadState.PROCESSED.db_id)

        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)
        state.clear_in_progress_uploads([upload.id_])

        dbsession.refresh(upload)
        assert upload.state_id == UploadState.PROCESSED.db_id

    def test_full_lifecycle(self, dbsession, setup_commit):
        """End-to-end: UPLOADED -> PROCESSED -> MERGED with DB state."""
        _, commit, report = setup_commit
        u1 = self._create_upload(dbsession, report, UploadState.UPLOADED.db_id)
        u2 = self._create_upload(dbsession, report, UploadState.UPLOADED.db_id)

        state = ProcessingState(commit.repoid, commit.commitid, db_session=dbsession)

        # Both uploads start in the "processing" bucket.
        numbers = state.get_upload_numbers()
        assert numbers.processing == 2
        assert numbers.processed == 0

        state.mark_upload_as_processed(u1.id_)
        dbsession.flush()

        # One processed is not enough to merge while another is in flight.
        numbers = state.get_upload_numbers()
        assert numbers.processing == 1
        assert numbers.processed == 1
        assert not should_perform_merge(numbers)

        state.mark_upload_as_processed(u2.id_)
        dbsession.flush()

        numbers = state.get_upload_numbers()
        assert numbers.processing == 0
        assert numbers.processed == 2
        assert should_perform_merge(numbers)

        merging = state.get_uploads_for_merging()
        assert merging == {u1.id_, u2.id_}

        state.mark_uploads_as_merged(list(merging))
        dbsession.flush()

        # Merged uploads leave both buckets, triggering postprocessing.
        numbers = state.get_upload_numbers()
        assert numbers.processing == 0
        assert numbers.processed == 0
        assert should_trigger_postprocessing(numbers)
Loading