Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 28 additions & 18 deletions codeflash/benchmarking/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,15 +470,27 @@ def _render_comparison(result: CompareResult) -> None:
# Find all benchmark keys across both refs
all_benchmark_keys = set(result.base_total_ns.keys()) | set(result.head_total_ns.keys())

base_total_ns = result.base_total_ns
head_total_ns = result.head_total_ns
base_function_ns = result.base_function_ns
head_function_ns = result.head_function_ns

for bm_key in sorted(all_benchmark_keys, key=str):
# Show only the test function name, not the full module path
bm_name = str(bm_key).rsplit("::", 1)[-1] if "::" in str(bm_key) else str(bm_key)
console.print()
console.rule(f"[bold]{bm_name}[/bold]")
console.print()
bm_key_str = str(bm_key)
bm_name = bm_key_str.rsplit("::", 1)[-1] if "::" in bm_key_str else bm_key_str

# Batch console operations
console.print(f"\n[blue]{'─' * 80}[/blue]\n[bold]{bm_name}[/bold]\n[blue]{'─' * 80}[/blue]\n")

base_ns = result.base_total_ns.get(bm_key)
head_ns = result.head_total_ns.get(bm_key)
base_ns = base_total_ns.get(bm_key)
head_ns = head_total_ns.get(bm_key)

# Pre-compute formatted values
base_ms_str = _fmt_ms(base_ns)
head_ms_str = _fmt_ms(head_ns)
delta_str = _fmt_delta(base_ns, head_ns)
speedup_str = _fmt_speedup(base_ns, head_ns)

# Table 1: Total benchmark time
t1 = Table(title="End-to-End", border_style="blue", show_lines=True, expand=False)
Expand All @@ -487,15 +499,13 @@ def _render_comparison(result: CompareResult) -> None:
t1.add_column("Delta", justify="right")
t1.add_column("Speedup", justify="right")

t1.add_row(f"{base_short} (base)", _fmt_ms(base_ns), "-", "-")
t1.add_row(
f"{head_short} (head)", _fmt_ms(head_ns), _fmt_delta(base_ns, head_ns), _fmt_speedup(base_ns, head_ns)
)
t1.add_row(f"{base_short} (base)", base_ms_str, "-", "-")
t1.add_row(f"{head_short} (head)", head_ms_str, delta_str, speedup_str)
console.print(t1, justify="center")

# Table 2: Per-function breakdown
all_funcs = set()
for d in [result.base_function_ns, result.head_function_ns]:
for d in [base_function_ns, head_function_ns]:
for func_name, bm_dict in d.items():
if bm_key in bm_dict:
all_funcs.add(func_name)
Expand All @@ -511,11 +521,11 @@ def _render_comparison(result: CompareResult) -> None:
t2.add_column("Speedup", justify="right")

def sort_key(fn: str, _bm_key: BenchmarkKey = bm_key) -> int:
return result.base_function_ns.get(fn, {}).get(_bm_key, 0)
return base_function_ns.get(fn, {}).get(_bm_key, 0)

for func_name in sorted(all_funcs, key=sort_key, reverse=True):
b_ns = result.base_function_ns.get(func_name, {}).get(bm_key)
h_ns = result.head_function_ns.get(func_name, {}).get(bm_key)
b_ns = base_function_ns.get(func_name, {}).get(bm_key)
h_ns = head_function_ns.get(func_name, {}).get(bm_key)

# Shorten function name for display
short_name = func_name.rsplit(".", 1)[-1] if "." in func_name else func_name
Expand All @@ -526,10 +536,10 @@ def sort_key(fn: str, _bm_key: BenchmarkKey = bm_key) -> int:
t2.add_section()
t2.add_row(
"[bold]TOTAL[/bold]",
f"[bold]{_fmt_ms(base_ns)}[/bold]",
f"[bold]{_fmt_ms(head_ns)}[/bold]",
_fmt_delta(base_ns, head_ns),
_fmt_speedup(base_ns, head_ns),
f"[bold]{base_ms_str}[/bold]",
f"[bold]{head_ms_str}[/bold]",
delta_str,
speedup_str,
)
console.print(t2, justify="center")

Expand Down
Loading