
Replace Pyro with NumPyro for fully Bayesian NUTS inference (96% reduction in fit time) (#5087)

Open

sdaulton wants to merge 1 commit into facebook:main from sdaulton:export-D97159025
Conversation

@sdaulton (Contributor) commented Mar 21, 2026

Summary:
X-link: meta-pytorch/botorch#3247

This diff replaces all Pyro usage in BoTorch's fully Bayesian models with
NumPyro (backed by JAX). NumPyro's JAX-based NUTS implementation is
significantly faster than Pyro's PyTorch-based one, delivering 25x
speedups on CPU with equivalent model quality.

Key changes:

- Replace `pyro.sample`/`pyro.distributions` with `numpyro.sample`/`numpyro.distributions`
- Replace `pyro.deterministic` with `numpyro.deterministic`
- Rewrite kernel functions (`matern52_kernel`, `linear_kernel`, `compute_dists`) in JAX
- Replace Pyro's `NUTS`/`MCMC` with NumPyro equivalents (`dense_mass` instead of `full_mass`, `num_warmup` instead of `warmup_steps`, explicit `jax.random.PRNGKey`)
- Remove `register_exception_handler` (Pyro-specific); NumPyro handles divergences via built-in NaN detection
- Add `botorch/utils/jax_utils.py` with `torch_to_jax`/`jax_to_torch` conversion helpers
- Convert training data to JAX arrays in `PyroModel.set_inputs()`; all `sample()` methods now operate in JAX-land
- Post-process MCMC samples by converting JAX arrays back to torch tensors at the boundary
- Pin `numpyro` to 0.18.0 in PACKAGE for vectorized chain support
- Update BUCK to replace the `pyro-ppl` dep with `jax` + `numpyro`

Differential Revision: D97159025
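As a rough illustration of the API migration described in the bullets above, a minimal NumPyro NUTS/MCMC setup might look like the following. This is a toy model for illustration only, not BoTorch's fully Bayesian GP:

```python
# Toy sketch of the NumPyro NUTS/MCMC API used after this migration.
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(y):
    # numpyro.sample / numpyro.distributions replace pyro.sample / pyro.distributions
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    # numpyro.deterministic replaces pyro.deterministic
    numpyro.deterministic("loc_sq", loc ** 2)
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=y)

# dense_mass replaces Pyro's full_mass; num_warmup replaces warmup_steps;
# the PRNG key is passed explicitly, unlike Pyro's global seeding.
kernel = NUTS(model, dense_mass=True)
mcmc = MCMC(kernel, num_warmup=64, num_samples=64, progress_bar=False)
mcmc.run(jax.random.PRNGKey(0), jnp.array([0.1, -0.2, 0.3]))
samples = mcmc.get_samples()
```

Unlike Pyro's stateful `pyro.set_rng_seed`, reproducibility here hinges entirely on the `PRNGKey` passed to `run`, which is what makes the explicit key argument a required part of the new interface.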
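The `torch_to_jax`/`jax_to_torch` helpers themselves are not shown in this summary; a plausible minimal sketch (names taken from the bullet above, implementation assumed, round-tripping through NumPy on the host) is:

```python
# Hypothetical sketch of the jax_utils conversion helpers; the actual
# botorch/utils/jax_utils.py implementation may differ.
import jax.numpy as jnp
import numpy as np
import torch

def torch_to_jax(t: torch.Tensor) -> jnp.ndarray:
    # Detach and move to host; jnp.asarray copies onto the default JAX device.
    # Note: with JAX's default x64 mode disabled, float64 inputs become float32.
    return jnp.asarray(t.detach().cpu().numpy())

def jax_to_torch(a: jnp.ndarray) -> torch.Tensor:
    # np.asarray forces a device-to-host transfer; copy so torch owns a
    # writable buffer (JAX arrays expose read-only memory).
    return torch.from_numpy(np.asarray(a).copy())
```

Keeping the conversions at the MCMC boundary, as the last two data-flow bullets describe, means the rest of BoTorch continues to see ordinary torch tensors.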
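The JAX kernel rewrite could look roughly like this sketch (signatures assumed from the bullet above; BoTorch's actual versions also handle outputscales and batch dimensions):

```python
# Illustrative JAX versions of compute_dists and matern52_kernel;
# simplified relative to the real BoTorch implementations.
import jax.numpy as jnp

def compute_dists(X, Z, lengthscale):
    # Pairwise Euclidean distances after rescaling inputs by the lengthscale.
    x, z = X / lengthscale, Z / lengthscale
    d2 = jnp.sum((x[:, None, :] - z[None, :, :]) ** 2, axis=-1)
    return jnp.sqrt(jnp.maximum(d2, 1e-12))  # clamp for a finite gradient at 0

def matern52_kernel(X, Z, lengthscale):
    # Matern-5/2: (1 + sqrt(5) r + 5 r^2 / 3) * exp(-sqrt(5) r)
    s = jnp.sqrt(5.0) * compute_dists(X, Z, lengthscale)
    return (1.0 + s + s ** 2 / 3.0) * jnp.exp(-s)
```

Because these are pure `jax.numpy` functions, they can be traced and JIT-compiled inside the NumPyro model, which is where the bulk of the CPU speedup comes from.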

@meta-cla bot added the CLA Signed label Mar 21, 2026
@meta-codesync bot commented Mar 21, 2026

@sdaulton has exported this pull request. If you are a Meta employee, you can view the originating Diff in D97159025.

sdaulton added a commit to sdaulton/botorch that referenced this pull request Mar 21, 2026
@meta-codesync meta-codesync bot changed the title Replace Pyro with NumPyro for fully Bayesian NUTS inference (96% reduction in fit time) Replace Pyro with NumPyro for fully Bayesian NUTS inference (96% reduction in fit time) (#5087) Mar 22, 2026
sdaulton added a commit to sdaulton/Ax-1 that referenced this pull request Mar 22, 2026
sdaulton added a commit to sdaulton/botorch that referenced this pull request Mar 22, 2026
sdaulton added a commit to sdaulton/Ax-1 that referenced this pull request Mar 23, 2026
sdaulton added a commit to sdaulton/botorch that referenced this pull request Mar 23, 2026
sdaulton added a commit to sdaulton/Ax-1 that referenced this pull request Mar 24, 2026
sdaulton added a commit to sdaulton/botorch that referenced this pull request Mar 24, 2026
sdaulton added a commit to sdaulton/Ax-1 that referenced this pull request Mar 24, 2026
@sdaulton force-pushed the export-D97159025 branch 2 times, most recently from b76911c to 675444d on March 24, 2026 at 19:08
sdaulton added a commit to sdaulton/Ax-1 that referenced this pull request Mar 24, 2026
sdaulton added a commit to sdaulton/botorch that referenced this pull request Mar 24, 2026
sdaulton added a commit to sdaulton/botorch that referenced this pull request Mar 24, 2026
sdaulton added a commit to sdaulton/Ax-1 that referenced this pull request Mar 24, 2026
sdaulton added a commit to sdaulton/botorch that referenced this pull request Mar 24, 2026
sdaulton added a commit to sdaulton/Ax-1 that referenced this pull request Mar 24, 2026
sdaulton added a commit to sdaulton/botorch that referenced this pull request Mar 24, 2026
sdaulton added a commit to sdaulton/botorch that referenced this pull request Mar 25, 2026
sdaulton added a commit to sdaulton/botorch that referenced this pull request Mar 25, 2026

Labels

CLA Signed · fb-exported · meta-exported

1 participant