Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 23 additions & 18 deletions apps/worker/services/processing/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,24 +74,29 @@ def process_upload(
# Safety net: trigger finisher if all uploads are done.
# Handles retries outside a chord (visibility timeout, task_reject_on_worker_lost).
# Will be replaced by gate key mechanism in a future PR.
# Use Redis-only state for counting (dual-write keeps Redis in sync).
redis_state = ProcessingState(repo_id, commit_sha)
upload_numbers = redis_state.get_upload_numbers()
if should_trigger_postprocessing(upload_numbers):
log.info(
"All uploads processed, triggering finisher",
extra={
"repo_id": repo_id,
"commit_sha": commit_sha,
"upload_id": upload_id,
},
)
celery_app.tasks[upload_finisher_task_name].apply_async(
kwargs={
"repoid": repo_id,
"commitid": commit_sha,
"commit_yaml": commit_yaml.to_dict(),
}
try:
upload_numbers = state.get_upload_numbers()
if should_trigger_postprocessing(upload_numbers):
log.info(
"All uploads processed, triggering finisher",
extra={
"repo_id": repo_id,
"commit_sha": commit_sha,
"upload_id": upload_id,
},
)
celery_app.tasks[upload_finisher_task_name].apply_async(
kwargs={
"repoid": repo_id,
"commitid": commit_sha,
"commit_yaml": commit_yaml.to_dict(),
}
)
except Exception:
log.warning(
"Safety-net finisher trigger failed (non-fatal)",
extra={"repo_id": repo_id, "commit_sha": commit_sha},
exc_info=True,
)

rewrite_or_delete_upload(archive_service, commit_yaml, report_info)
Expand Down
193 changes: 80 additions & 113 deletions apps/worker/services/processing/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"intermediate report".
"""

import logging
from dataclasses import dataclass

from sqlalchemy import case, func
Expand All @@ -29,16 +30,12 @@
from database.enums import ReportType
from database.models.core import Commit
from database.models.reports import CommitReport, Upload
from shared.helpers.redis import get_redis_connection
from shared.metrics import Counter
from shared.reports.enums import UploadState

MERGE_BATCH_SIZE = 10
log = logging.getLogger(__name__)

# TTL for processing state keys in Redis (24 hours, matches intermediate report TTL)
# This prevents state keys from accumulating indefinitely and ensures consistency
# with intermediate report expiration
PROCESSING_STATE_TTL = 24 * 60 * 60
MERGE_BATCH_SIZE = 10

CLEARED_UPLOADS = Counter(
"worker_processing_cleared_uploads",
Expand Down Expand Up @@ -82,65 +79,59 @@ def should_trigger_postprocessing(uploads: UploadNumbers) -> bool:


class ProcessingState:
def __init__(
self, repoid: int, commitsha: str, db_session: Session | None = None
) -> None:
self._redis = get_redis_connection()
def __init__(self, repoid: int, commitsha: str, db_session: Session) -> None:
self.repoid = repoid
self.commitsha = commitsha
self._db_session = db_session

def get_upload_numbers(self):
if self._db_session:
row = (
self._db_session.query(
func.count(
case(
(
Upload.state_id == UploadState.UPLOADED.db_id,
Upload.id_,
),
)
),
func.count(
case(
(
Upload.state_id == UploadState.PROCESSED.db_id,
Upload.id_,
),
)
),
)
.join(CommitReport, Upload.report_id == CommitReport.id_)
.join(Commit, CommitReport.commit_id == Commit.id_)
.filter(
Commit.repoid == self.repoid,
Commit.commitid == self.commitsha,
(CommitReport.report_type == None) # noqa: E711
| (CommitReport.report_type == ReportType.COVERAGE.value),
)
.one()
row = (
self._db_session.query(
func.count(
case(
(
Upload.state_id == UploadState.UPLOADED.db_id,
Upload.id_,
),
)
),
func.count(
case(
(
Upload.state_id == UploadState.PROCESSED.db_id,
Upload.id_,
),
)
),
)
return UploadNumbers(processing=row[0], processed=row[1])

processing = self._redis.scard(self._redis_key("processing"))
processed = self._redis.scard(self._redis_key("processed"))
return UploadNumbers(processing, processed)
.join(CommitReport, Upload.report_id == CommitReport.id_)
.join(Commit, CommitReport.commit_id == Commit.id_)
.filter(
Commit.repoid == self.repoid,
Commit.commitid == self.commitsha,
(CommitReport.report_type == None) # noqa: E711
| (CommitReport.report_type == ReportType.COVERAGE.value),
)
.one()
)
return UploadNumbers(processing=row[0], processed=row[1])

def mark_uploads_as_processing(self, upload_ids: list[int]):
if not upload_ids:
return
key = self._redis_key("processing")
self._redis.sadd(key, *upload_ids)
self._redis.expire(key, PROCESSING_STATE_TTL)
# No-op: uploads are created with state_id=UPLOADED, which
# get_upload_numbers() already counts as "processing".
pass

def clear_in_progress_uploads(self, upload_ids: list[int]):
if not upload_ids:
return
if self._db_session:
# Mark still-UPLOADED uploads as ERROR so they stop being counted
# as "processing" in get_upload_numbers(). Only matches UPLOADED --
# already-PROCESSED uploads (success path) are unaffected.
# Mark still-UPLOADED uploads as ERROR so they stop being counted
# as "processing" in get_upload_numbers(). Only matches UPLOADED --
# already-PROCESSED uploads (success path) are unaffected.
#
# This runs in a finally block, so the transaction may already be
# in a failed state. Best-effort: log and move on if the DB is
# unreachable — the upload stays UPLOADED, which is safe.
try:
updated = (
self._db_session.query(Upload)
.filter(
Expand All @@ -157,72 +148,48 @@ def clear_in_progress_uploads(self, upload_ids: list[int]):
)
if updated > 0:
CLEARED_UPLOADS.inc(updated)
self._redis.srem(self._redis_key("processing"), *upload_ids)
return
removed_uploads = self._redis.srem(self._redis_key("processing"), *upload_ids)
if removed_uploads > 0:
CLEARED_UPLOADS.inc(removed_uploads)
except Exception:
log.warning(
"Failed to clear in-progress uploads (transaction may be aborted)",
extra={"upload_ids": upload_ids},
exc_info=True,
)

def mark_upload_as_processed(self, upload_id: int):
if self._db_session:
upload = self._db_session.query(Upload).get(upload_id)
if upload:
upload.state_id = UploadState.PROCESSED.db_id
# Don't set upload.state here -- the finisher's idempotency check
# uses state="processed" to detect already-merged uploads.
# The state string is set by update_uploads() after merging.

processing_key = self._redis_key("processing")
processed_key = self._redis_key("processed")

res = self._redis.smove(processing_key, processed_key, upload_id)
if not res:
self._redis.sadd(processed_key, upload_id)

self._redis.expire(processed_key, PROCESSING_STATE_TTL)
upload = self._db_session.query(Upload).get(upload_id)
if upload:
upload.state_id = UploadState.PROCESSED.db_id
# Don't set upload.state here -- the finisher's idempotency check
# uses state="processed" to detect already-merged uploads.
# The state string is set by update_uploads() after merging.

def mark_uploads_as_merged(self, upload_ids: list[int]):
if not upload_ids:
return
if self._db_session:
self._db_session.query(Upload).filter(
Upload.id_.in_(upload_ids),
Upload.state_id == UploadState.PROCESSED.db_id,
).update(
{
Upload.state_id: UploadState.MERGED.db_id,
Upload.state: "merged",
},
synchronize_session="fetch",
)
self._redis.srem(self._redis_key("processed"), *upload_ids)
return
self._redis.srem(self._redis_key("processed"), *upload_ids)
self._db_session.query(Upload).filter(
Upload.id_.in_(upload_ids),
Upload.state_id == UploadState.PROCESSED.db_id,
).update(
{
Upload.state_id: UploadState.MERGED.db_id,
Upload.state: "merged",
},
synchronize_session="fetch",
)

def get_uploads_for_merging(self) -> set[int]:
if self._db_session:
rows = (
self._db_session.query(Upload.id_)
.join(CommitReport, Upload.report_id == CommitReport.id_)
.join(Commit, CommitReport.commit_id == Commit.id_)
.filter(
Commit.repoid == self.repoid,
Commit.commitid == self.commitsha,
(CommitReport.report_type == None) # noqa: E711
| (CommitReport.report_type == ReportType.COVERAGE.value),
Upload.state_id == UploadState.PROCESSED.db_id,
)
.limit(MERGE_BATCH_SIZE)
.all()
)
return {row[0] for row in rows}

return {
int(id)
for id in self._redis.srandmember(
self._redis_key("processed"), MERGE_BATCH_SIZE
rows = (
self._db_session.query(Upload.id_)
.join(CommitReport, Upload.report_id == CommitReport.id_)
.join(Commit, CommitReport.commit_id == Commit.id_)
.filter(
Commit.repoid == self.repoid,
Commit.commitid == self.commitsha,
(CommitReport.report_type == None) # noqa: E711
| (CommitReport.report_type == ReportType.COVERAGE.value),
Upload.state_id == UploadState.PROCESSED.db_id,
)
}

def _redis_key(self, state: str) -> str:
return f"upload-processing-state/{self.repoid}/{self.commitsha}/{state}"
.limit(MERGE_BATCH_SIZE)
.all()
)
return {row[0] for row in rows}
Loading
Loading