diff --git a/apps/worker/tasks/compute_comparison.py b/apps/worker/tasks/compute_comparison.py index 8265073d89..ab61915aa6 100644 --- a/apps/worker/tasks/compute_comparison.py +++ b/apps/worker/tasks/compute_comparison.py @@ -135,36 +135,53 @@ def create_or_update_flag_comparisons( comparison_proxy: ComparisonProxy, ): repository_id = comparison.compare_commit.repository.repoid - for flag_name in head_report_flags.keys(): - totals = self.get_flag_comparison_totals(flag_name, comparison_proxy) - repositoryflag = ( - db_session.query(RepositoryFlag) - .filter_by( - flag_name=flag_name, - repository_id=repository_id, - ) - .first() + flag_names = list(head_report_flags.keys()) + + # Batch-fetch all RepositoryFlag records for this repository matching + # the flag names in the head report, to avoid one query per flag. + existing_repository_flags: dict[str, RepositoryFlag] = { + rf.flag_name: rf + for rf in db_session.query(RepositoryFlag) + .filter( + RepositoryFlag.repository_id == repository_id, + RepositoryFlag.flag_name.in_(flag_names), ) - if not repositoryflag: + .all() + } + + # Create any missing RepositoryFlag records up front so we can + # build a complete id→flag map before querying CompareFlag. + for flag_name in flag_names: + if flag_name not in existing_repository_flags: log.warning( "Repository flag not found for flag. Created repository flag.", extra={"repoid": repository_id, "flag_name": flag_name}, ) - repositoryflag = RepositoryFlag( + new_flag = RepositoryFlag( repository_id=repository_id, flag_name=flag_name, ) - db_session.add(repositoryflag) + db_session.add(new_flag) db_session.flush() - - flag_comparison_entry = ( - db_session.query(CompareFlag) - .filter_by( - commit_comparison_id=comparison.id, - repositoryflag_id=repositoryflag.id, - ) - .first() + existing_repository_flags[flag_name] = new_flag + + # Batch-fetch all CompareFlag records for this comparison whose + # repositoryflag_id is among the flags we care about. + repositoryflag_ids = [rf.id for rf in existing_repository_flags.values()] + existing_flag_comparisons: dict[int, CompareFlag] = { + cf.repositoryflag_id: cf + for cf in db_session.query(CompareFlag) + .filter( + CompareFlag.commit_comparison_id == comparison.id, + CompareFlag.repositoryflag_id.in_(repositoryflag_ids), ) + .all() + } + + for flag_name in flag_names: + totals = self.get_flag_comparison_totals(flag_name, comparison_proxy) + repositoryflag = existing_repository_flags[flag_name] + flag_comparison_entry = existing_flag_comparisons.get(repositoryflag.id) if not flag_comparison_entry: log.debug(