Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 37 additions & 20 deletions apps/worker/tasks/compute_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,36 +135,53 @@ def create_or_update_flag_comparisons(
comparison_proxy: ComparisonProxy,
):
repository_id = comparison.compare_commit.repository.repoid
for flag_name in head_report_flags.keys():
totals = self.get_flag_comparison_totals(flag_name, comparison_proxy)
repositoryflag = (
db_session.query(RepositoryFlag)
.filter_by(
flag_name=flag_name,
repository_id=repository_id,
)
.first()
flag_names = list(head_report_flags.keys())

# Batch-fetch all RepositoryFlag records for this repository matching
# the flag names in the head report, to avoid one query per flag.
existing_repository_flags: dict[str, RepositoryFlag] = {
rf.flag_name: rf
for rf in db_session.query(RepositoryFlag)
.filter(
RepositoryFlag.repository_id == repository_id,
RepositoryFlag.flag_name.in_(flag_names),
)
if not repositoryflag:
.all()
}

# Create any missing RepositoryFlag records up front so we can
# build a complete id→flag map before querying CompareFlag.
for flag_name in flag_names:
if flag_name not in existing_repository_flags:
log.warning(
"Repository flag not found for flag. Created repository flag.",
extra={"repoid": repository_id, "flag_name": flag_name},
)
repositoryflag = RepositoryFlag(
new_flag = RepositoryFlag(
repository_id=repository_id,
flag_name=flag_name,
)
db_session.add(repositoryflag)
db_session.add(new_flag)
db_session.flush()

flag_comparison_entry = (
db_session.query(CompareFlag)
.filter_by(
commit_comparison_id=comparison.id,
repositoryflag_id=repositoryflag.id,
)
.first()
existing_repository_flags[flag_name] = new_flag

# Batch-fetch all CompareFlag records for this comparison whose
# repositoryflag_id is among the flags we care about.
repositoryflag_ids = [rf.id for rf in existing_repository_flags.values()]
existing_flag_comparisons: dict[int, CompareFlag] = {
cf.repositoryflag_id: cf
for cf in db_session.query(CompareFlag)
.filter(
CompareFlag.commit_comparison_id == comparison.id,
CompareFlag.repositoryflag_id.in_(repositoryflag_ids),
)
.all()
}

for flag_name in flag_names:
totals = self.get_flag_comparison_totals(flag_name, comparison_proxy)
repositoryflag = existing_repository_flags[flag_name]
flag_comparison_entry = existing_flag_comparisons.get(repositoryflag.id)

if not flag_comparison_entry:
log.debug(
Expand Down
Loading