diff --git a/codeflash/benchmarking/compare.py b/codeflash/benchmarking/compare.py index fb98ef301..62ed33a20 100644 --- a/codeflash/benchmarking/compare.py +++ b/codeflash/benchmarking/compare.py @@ -470,15 +470,27 @@ def _render_comparison(result: CompareResult) -> None: # Find all benchmark keys across both refs all_benchmark_keys = set(result.base_total_ns.keys()) | set(result.head_total_ns.keys()) + base_total_ns = result.base_total_ns + head_total_ns = result.head_total_ns + base_function_ns = result.base_function_ns + head_function_ns = result.head_function_ns + for bm_key in sorted(all_benchmark_keys, key=str): # Show only the test function name, not the full module path - bm_name = str(bm_key).rsplit("::", 1)[-1] if "::" in str(bm_key) else str(bm_key) - console.print() - console.rule(f"[bold]{bm_name}[/bold]") - console.print() + bm_key_str = str(bm_key) + bm_name = bm_key_str.rsplit("::", 1)[-1] if "::" in bm_key_str else bm_key_str + + # Batch console operations + console.print(f"\n[blue]{'─' * 80}[/blue]\n[bold]{bm_name}[/bold]\n[blue]{'─' * 80}[/blue]\n") - base_ns = result.base_total_ns.get(bm_key) - head_ns = result.head_total_ns.get(bm_key) + base_ns = base_total_ns.get(bm_key) + head_ns = head_total_ns.get(bm_key) + + # Pre-compute formatted values + base_ms_str = _fmt_ms(base_ns) + head_ms_str = _fmt_ms(head_ns) + delta_str = _fmt_delta(base_ns, head_ns) + speedup_str = _fmt_speedup(base_ns, head_ns) # Table 1: Total benchmark time t1 = Table(title="End-to-End", border_style="blue", show_lines=True, expand=False) @@ -487,15 +499,13 @@ def _render_comparison(result: CompareResult) -> None: t1.add_column("Delta", justify="right") t1.add_column("Speedup", justify="right") - t1.add_row(f"{base_short} (base)", _fmt_ms(base_ns), "-", "-") - t1.add_row( - f"{head_short} (head)", _fmt_ms(head_ns), _fmt_delta(base_ns, head_ns), _fmt_speedup(base_ns, head_ns) - ) + t1.add_row(f"{base_short} (base)", base_ms_str, "-", "-") + t1.add_row(f"{head_short} (head)", head_ms_str, delta_str, speedup_str) console.print(t1, justify="center") # Table 2: Per-function breakdown all_funcs = set() - for d in [result.base_function_ns, result.head_function_ns]: + for d in [base_function_ns, head_function_ns]: for func_name, bm_dict in d.items(): if bm_key in bm_dict: all_funcs.add(func_name) @@ -511,11 +521,11 @@ def _render_comparison(result: CompareResult) -> None: t2.add_column("Speedup", justify="right") def sort_key(fn: str, _bm_key: BenchmarkKey = bm_key) -> int: - return result.base_function_ns.get(fn, {}).get(_bm_key, 0) + return base_function_ns.get(fn, {}).get(_bm_key, 0) for func_name in sorted(all_funcs, key=sort_key, reverse=True): - b_ns = result.base_function_ns.get(func_name, {}).get(bm_key) - h_ns = result.head_function_ns.get(func_name, {}).get(bm_key) + b_ns = base_function_ns.get(func_name, {}).get(bm_key) + h_ns = head_function_ns.get(func_name, {}).get(bm_key) # Shorten function name for display short_name = func_name.rsplit(".", 1)[-1] if "." in func_name else func_name @@ -526,10 +536,10 @@ def sort_key(fn: str, _bm_key: BenchmarkKey = bm_key) -> int: t2.add_section() t2.add_row( "[bold]TOTAL[/bold]", - f"[bold]{_fmt_ms(base_ns)}[/bold]", - f"[bold]{_fmt_ms(head_ns)}[/bold]", - _fmt_delta(base_ns, head_ns), - _fmt_speedup(base_ns, head_ns), + f"[bold]{base_ms_str}[/bold]", + f"[bold]{head_ms_str}[/bold]", + delta_str, + speedup_str, ) console.print(t2, justify="center")