diff --git a/.github/workflows/set_cibw_build.py b/.github/workflows/set_cibw_build.py
index ec4383f4..d703cd56 100755
--- a/.github/workflows/set_cibw_build.py
+++ b/.github/workflows/set_cibw_build.py
@@ -5,7 +5,6 @@
 import os
 import sys
 
-
 # pylint: disable-next=consider-using-f-string
 CIBW_BUILD = 'CIBW_BUILD=*cp%d%d-*manylinux*' % sys.version_info[:2]
 
diff --git a/.github/workflows/set_release.py b/.github/workflows/set_release.py
index 6c437f19..0ac3e152 100755
--- a/.github/workflows/set_release.py
+++ b/.github/workflows/set_release.py
@@ -5,7 +5,6 @@
 import pathlib
 import re
 
-
 ROOT = pathlib.Path(__file__).absolute().parent.parent.parent
 VERSION_FILE = ROOT / 'torchopt' / 'version.py'
 
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7ab860a5..69794d2d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -6,10 +6,10 @@ ci:
   autofix_commit_msg: "fix: [pre-commit.ci] auto fixes [...]"
   autoupdate_commit_msg: "chore(pre-commit): [pre-commit.ci] autoupdate"
   autoupdate_schedule: monthly
-default_stages: [commit, push, manual]
+default_stages: [pre-commit, pre-push, manual]
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.6.0
+    rev: v6.0.0
     hooks:
       - id: check-symlinks
       - id: destroyed-symlinks
@@ -26,24 +26,24 @@ repos:
      - id: debug-statements
      - id: double-quote-string-fixer
  - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v18.1.8
+    rev: v22.1.0
    hooks:
      - id: clang-format
  - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.5.0
+    rev: v0.15.4
    hooks:
      - id: ruff
        args: [--fix, --exit-non-zero-on-fix]
  - repo: https://github.com/PyCQA/isort
-    rev: 5.13.2
+    rev: 8.0.1
    hooks:
      - id: isort
-  - repo: https://github.com/psf/black
-    rev: 24.4.2
+  - repo: https://github.com/psf/black-pre-commit-mirror
+    rev: 26.1.0
    hooks:
      - id: black-jupyter
  - repo: https://github.com/asottile/pyupgrade
-    rev: v3.16.0
+    rev: v3.21.2
    hooks:
      - id: pyupgrade
        args: [--py38-plus] # sync with requires-python
@@ -52,7 +52,7 @@
            ^examples/
          )
  - repo: https://github.com/pycqa/flake8
-    rev: 7.1.0
+    rev: 7.3.0
    hooks:
      - id: flake8
        additional_dependencies:
@@ -68,7 +68,7 @@
            ^docs/source/conf.py$
          )
  - repo: https://github.com/codespell-project/codespell
-    rev: v2.3.0
+    rev: v2.4.1
    hooks:
      - id: codespell
        additional_dependencies: [".[toml]"]
diff --git a/docs/source/conf.py b/docs/source/conf.py
index a4f23533..7668dc67 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Configuration file for the Sphinx documentation builder."""
+
 # This file only contains a selection of the most common options. For a full
 # list see the documentation:
 # https://www.sphinx-doc.org/en/master/usage/configuration.html
@@ -33,7 +34,6 @@
 import sphinx
 import sphinxcontrib.katex as katex
 
-
 HERE = pathlib.Path(__file__).absolute().parent
 PROJECT_ROOT = HERE.parent.parent
 
diff --git a/examples/FuncTorch/maml_omniglot_vmap.py b/examples/FuncTorch/maml_omniglot_vmap.py
index 2f42e050..707db05d 100644
--- a/examples/FuncTorch/maml_omniglot_vmap.py
+++ b/examples/FuncTorch/maml_omniglot_vmap.py
@@ -57,13 +57,11 @@
 import torchopt
 
-
 CWD = pathlib.Path(__file__).absolute().parent
 sys.path.append(str(CWD.parent / 'few-shot'))
 from helpers.omniglot_loaders import OmniglotNShot
 
-
 mpl.use('Agg')
 plt.style.use('bmh')
 
diff --git a/examples/MAML-RL/func_maml.py b/examples/MAML-RL/func_maml.py
index 475c1b12..eb5b02ff 100644
--- a/examples/MAML-RL/func_maml.py
+++ b/examples/MAML-RL/func_maml.py
@@ -25,7 +25,6 @@
 import torchopt
 from helpers.policy import CategoricalMLPPolicy
 
-
 TASK_NUM = 40
 TRAJ_NUM = 20
 TRAJ_LEN = 10
diff --git a/examples/MAML-RL/helpers/__init__.py b/examples/MAML-RL/helpers/__init__.py
index 31d45c37..c52ec7fc 100644
--- a/examples/MAML-RL/helpers/__init__.py
+++ b/examples/MAML-RL/helpers/__init__.py
@@ -18,7 +18,6 @@
 from gym.envs.registration import register
 
-
 register(
     'TabularMDP-v0',
     entry_point='helpers.tabular_mdp:TabularMDPEnv',
diff --git a/examples/MAML-RL/maml.py b/examples/MAML-RL/maml.py
index 0cb57a92..238cdbd0 100644
--- a/examples/MAML-RL/maml.py
+++ b/examples/MAML-RL/maml.py
@@ -24,7 +24,6 @@
 import torchopt
 from helpers.policy import CategoricalMLPPolicy
 
-
 TASK_NUM = 40
 TRAJ_NUM = 20
 TRAJ_LEN = 10
diff --git a/examples/MAML-RL/maml_torchrl.py b/examples/MAML-RL/maml_torchrl.py
index 56db91ef..2c541da4 100644
--- a/examples/MAML-RL/maml_torchrl.py
+++ b/examples/MAML-RL/maml_torchrl.py
@@ -25,7 +25,6 @@
 import torchopt
 from helpers.policy_torchrl import ActorCritic
 
-
 TASK_NUM = 40
 TRAJ_NUM = 20
 TRAJ_LEN = 10
diff --git a/examples/distributed/few-shot/maml_omniglot.py b/examples/distributed/few-shot/maml_omniglot.py
index f840e65e..23a4dc8b 100644
--- a/examples/distributed/few-shot/maml_omniglot.py
+++ b/examples/distributed/few-shot/maml_omniglot.py
@@ -58,7 +58,6 @@
 import torchopt.distributed as todist
 
 from helpers.omniglot_loaders import OmniglotNShot
 
-
 mpl.use('Agg')
 plt.style.use('bmh')
diff --git a/examples/distributed/few-shot/maml_omniglot_local_loader.py b/examples/distributed/few-shot/maml_omniglot_local_loader.py
index fb737d4f..a8ee543c 100644
--- a/examples/distributed/few-shot/maml_omniglot_local_loader.py
+++ b/examples/distributed/few-shot/maml_omniglot_local_loader.py
@@ -60,7 +60,6 @@
 import torchopt.distributed as todist
 
 from helpers.omniglot_loaders import OmniglotNShot
 
-
 mpl.use('Agg')
 plt.style.use('bmh')
diff --git a/examples/few-shot/maml_omniglot.py b/examples/few-shot/maml_omniglot.py
index 7f7f67fe..b8e1242a 100644
--- a/examples/few-shot/maml_omniglot.py
+++ b/examples/few-shot/maml_omniglot.py
@@ -54,7 +54,6 @@
 import torchopt
 
 from helpers.omniglot_loaders import OmniglotNShot
 
-
 mpl.use('Agg')
 plt.style.use('bmh')
diff --git a/examples/iMAML/imaml_omniglot.py b/examples/iMAML/imaml_omniglot.py
index 1db08427..b3af6939 100644
--- a/examples/iMAML/imaml_omniglot.py
+++ b/examples/iMAML/imaml_omniglot.py
@@ -36,7 +36,6 @@
 from helpers.omniglot_loaders import OmniglotNShot
 from torchopt.diff.implicit import ImplicitMetaGradientModule
 
-
 mpl.use('Agg')
 plt.style.use('bmh')
diff --git a/examples/iMAML/imaml_omniglot_functional.py b/examples/iMAML/imaml_omniglot_functional.py
index 7bc1e9da..b5a6ca84 100644
--- a/examples/iMAML/imaml_omniglot_functional.py
+++ b/examples/iMAML/imaml_omniglot_functional.py
@@ -37,7 +37,6 @@
 from helpers.omniglot_loaders import OmniglotNShot
 from torchopt import pytree
 
-
 mpl.use('Agg')
 plt.style.use('bmh')
diff --git a/include/adam_op/adam_op.h b/include/adam_op/adam_op.h
index 2d0abcd3..e18f9edb 100644
--- a/include/adam_op/adam_op.h
+++ b/include/adam_op/adam_op.h
@@ -27,51 +27,51 @@
 namespace py = pybind11;
 
 namespace adam_op {
-TensorArray<3> adamForwardInplace(const torch::Tensor &updates,
-                                  const torch::Tensor &mu,
-                                  const torch::Tensor &nu,
+TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
+                                  const torch::Tensor& mu,
+                                  const torch::Tensor& nu,
                                   const pyfloat_t b1,
                                   const pyfloat_t b2,
                                   const pyfloat_t eps,
                                   const pyfloat_t eps_root,
                                   const pyuint_t count);
 
-torch::Tensor adamForwardMu(const torch::Tensor &updates,
-                            const torch::Tensor &mu,
+torch::Tensor adamForwardMu(const torch::Tensor& updates,
+                            const torch::Tensor& mu,
                             const pyfloat_t b1);
 
-torch::Tensor adamForwardNu(const torch::Tensor &updates,
-                            const torch::Tensor &nu,
+torch::Tensor adamForwardNu(const torch::Tensor& updates,
+                            const torch::Tensor& nu,
                             const pyfloat_t b2);
 
-torch::Tensor adamForwardUpdates(const torch::Tensor &new_mu,
-                                 const torch::Tensor &new_nu,
+torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,
+                                 const torch::Tensor& new_nu,
                                  const pyfloat_t b1,
                                  const pyfloat_t b2,
                                  const pyfloat_t eps,
                                  const pyfloat_t eps_root,
                                  const pyuint_t count);
 
-TensorArray<2> adamBackwardMu(const torch::Tensor &dmu,
-                              const torch::Tensor &updates,
-                              const torch::Tensor &mu,
+TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,
+                              const torch::Tensor& updates,
+                              const torch::Tensor& mu,
                               const pyfloat_t b1);
 
-TensorArray<2> adamBackwardNu(const torch::Tensor &dnu,
-                              const torch::Tensor &updates,
-                              const torch::Tensor &nu,
+TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
+                              const torch::Tensor& updates,
+                              const torch::Tensor& nu,
                               const pyfloat_t b2);
 
-TensorArray<2> adamBackwardUpdates(const torch::Tensor &dupdates,
-                                   const torch::Tensor &updates,
-                                   const torch::Tensor &new_mu,
-                                   const torch::Tensor &new_nu,
+TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates,
+                                   const torch::Tensor& updates,
+                                   const torch::Tensor& new_mu,
+                                   const torch::Tensor& new_nu,
                                    const pyfloat_t b1,
                                    const pyfloat_t b2,
                                    const pyfloat_t eps_root,
                                    const pyuint_t count);
 
-void buildSubmodule(py::module &mod);  // NOLINT[runtime/references]
+void buildSubmodule(py::module& mod);  // NOLINT[runtime/references]
 }  // namespace adam_op
 }  // namespace torchopt
diff --git a/include/adam_op/adam_op_impl_cpu.h b/include/adam_op/adam_op_impl_cpu.h
index 4d54377e..c2125ff1 100644
--- a/include/adam_op/adam_op_impl_cpu.h
+++ b/include/adam_op/adam_op_impl_cpu.h
@@ -23,45 +23,45 @@ namespace torchopt {
 namespace adam_op {
-TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates,
-                                     const torch::Tensor &mu,
-                                     const torch::Tensor &nu,
+TensorArray<3> adamForwardInplaceCPU(const torch::Tensor& updates,
+                                     const torch::Tensor& mu,
+                                     const torch::Tensor& nu,
                                      const pyfloat_t b1,
                                      const pyfloat_t b2,
                                      const pyfloat_t eps,
                                      const pyfloat_t eps_root,
                                      const pyuint_t count);
 
-torch::Tensor adamForwardMuCPU(const torch::Tensor &updates,
-                               const torch::Tensor &mu,
+torch::Tensor adamForwardMuCPU(const torch::Tensor& updates,
+                               const torch::Tensor& mu,
                                const pyfloat_t b1);
 
-torch::Tensor adamForwardNuCPU(const torch::Tensor &updates,
-                               const torch::Tensor &nu,
+torch::Tensor adamForwardNuCPU(const torch::Tensor& updates,
+                               const torch::Tensor& nu,
                                const pyfloat_t b2);
 
-torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu,
-                                    const torch::Tensor &new_nu,
+torch::Tensor adamForwardUpdatesCPU(const torch::Tensor& new_mu,
+                                    const torch::Tensor& new_nu,
                                     const pyfloat_t b1,
                                     const pyfloat_t b2,
                                     const pyfloat_t eps,
                                     const pyfloat_t eps_root,
                                     const pyuint_t count);
 
-TensorArray<2> adamBackwardMuCPU(const torch::Tensor &dmu,
-                                 const torch::Tensor &updates,
-                                 const torch::Tensor &mu,
+TensorArray<2> adamBackwardMuCPU(const torch::Tensor& dmu,
+                                 const torch::Tensor& updates,
+                                 const torch::Tensor& mu,
                                  const pyfloat_t b1);
 
-TensorArray<2> adamBackwardNuCPU(const torch::Tensor &dnu,
-                                 const torch::Tensor &updates,
-                                 const torch::Tensor &nu,
+TensorArray<2> adamBackwardNuCPU(const torch::Tensor& dnu,
+                                 const torch::Tensor& updates,
+                                 const torch::Tensor& nu,
                                  const pyfloat_t b2);
 
-TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor &dupdates,
-                                      const torch::Tensor &updates,
-                                      const torch::Tensor &new_mu,
-                                      const torch::Tensor &new_nu,
+TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor& dupdates,
+                                      const torch::Tensor& updates,
+                                      const torch::Tensor& new_mu,
+                                      const torch::Tensor& new_nu,
                                       const pyfloat_t b1,
                                       const pyfloat_t b2,
                                       const pyfloat_t eps_root,
diff --git a/include/adam_op/adam_op_impl_cuda.cuh b/include/adam_op/adam_op_impl_cuda.cuh
index 17002b36..f38b3e7f 100644
--- a/include/adam_op/adam_op_impl_cuda.cuh
+++ b/include/adam_op/adam_op_impl_cuda.cuh
@@ -23,45 +23,45 @@ namespace torchopt {
 namespace adam_op {
-TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates,
-                                      const torch::Tensor &mu,
-                                      const torch::Tensor &nu,
+TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor& updates,
+                                      const torch::Tensor& mu,
+                                      const torch::Tensor& nu,
                                       const pyfloat_t b1,
                                       const pyfloat_t b2,
                                       const pyfloat_t eps,
                                       const pyfloat_t eps_root,
                                       const pyuint_t count);
 
-torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates,
-                                const torch::Tensor &mu,
+torch::Tensor adamForwardMuCUDA(const torch::Tensor& updates,
+                                const torch::Tensor& mu,
                                 const pyfloat_t b1);
 
-torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates,
-                                const torch::Tensor &nu,
+torch::Tensor adamForwardNuCUDA(const torch::Tensor& updates,
+                                const torch::Tensor& nu,
                                 const pyfloat_t b2);
 
-torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu,
-                                     const torch::Tensor &new_nu,
+torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor& new_mu,
+                                     const torch::Tensor& new_nu,
                                      const pyfloat_t b1,
                                      const pyfloat_t b2,
                                      const pyfloat_t eps,
                                      const pyfloat_t eps_root,
                                      const pyuint_t count);
 
-TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu,
-                                  const torch::Tensor &updates,
-                                  const torch::Tensor &mu,
+TensorArray<2> adamBackwardMuCUDA(const torch::Tensor& dmu,
+                                  const torch::Tensor& updates,
+                                  const torch::Tensor& mu,
                                   const pyfloat_t b1);
 
-TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu,
-                                  const torch::Tensor &updates,
-                                  const torch::Tensor &nu,
+TensorArray<2> adamBackwardNuCUDA(const torch::Tensor& dnu,
+                                  const torch::Tensor& updates,
+                                  const torch::Tensor& nu,
                                   const pyfloat_t b2);
 
-TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates,
-                                       const torch::Tensor &updates,
-                                       const torch::Tensor &new_mu,
-                                       const torch::Tensor &new_nu,
+TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor& dupdates,
+                                       const torch::Tensor& updates,
+                                       const torch::Tensor& new_mu,
+                                       const torch::Tensor& new_nu,
                                        const pyfloat_t b1,
                                        const pyfloat_t b2,
                                        const pyfloat_t eps_root,
diff --git a/include/utils.h b/include/utils.h
index cefabfac..3b029949 100644
--- a/include/utils.h
+++ b/include/utils.h
@@ -24,7 +24,7 @@
 #endif
 
 namespace torchopt {
-__forceinline__ size_t getTensorPlainSize(const torch::Tensor &tensor) {
+__forceinline__ size_t getTensorPlainSize(const torch::Tensor& tensor) {
   const auto dim = tensor.dim();
   size_t n = 1;
   for (std::decay_t<decltype(dim)> i = 0; i < dim; ++i) {
diff --git a/setup.py b/setup.py
index c50ba5ed..33448924 100644
--- a/setup.py
+++ b/setup.py
@@ -11,7 +11,6 @@
 from setuptools import Extension, setup
 from setuptools.command.build_ext import build_ext
 
-
 HERE = pathlib.Path(__file__).absolute().parent
 
 
diff --git a/src/adam_op/adam_op.cpp b/src/adam_op/adam_op.cpp
index 47f5d7f1..a0f61cc9 100644
--- a/src/adam_op/adam_op.cpp
+++ b/src/adam_op/adam_op.cpp
@@ -29,9 +29,9 @@
 namespace py = pybind11;
 
 namespace adam_op {
-TensorArray<3> adamForwardInplace(const torch::Tensor &updates,
-                                  const torch::Tensor &mu,
-                                  const torch::Tensor &nu,
+TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
+                                  const torch::Tensor& mu,
+                                  const torch::Tensor& nu,
                                   const pyfloat_t b1,
                                   const pyfloat_t b2,
                                   const pyfloat_t eps,
@@ -49,8 +49,8 @@ TensorArray<3> adamForwardInplace(const torch::Tensor &updates,
   }
 }
 
-torch::Tensor adamForwardMu(const torch::Tensor &updates,
-                            const torch::Tensor &mu,
+torch::Tensor adamForwardMu(const torch::Tensor& updates,
+                            const torch::Tensor& mu,
                             const pyfloat_t b1) {
 #if defined(__USE_CUDA__)
   if (updates.device().is_cuda()) {
@@ -64,8 +64,8 @@ torch::Tensor adamForwardMu(const torch::Tensor &updates,
   }
 }
 
-torch::Tensor adamForwardNu(const torch::Tensor &updates,
-                            const torch::Tensor &nu,
+torch::Tensor adamForwardNu(const torch::Tensor& updates,
+                            const torch::Tensor& nu,
                             const pyfloat_t b2) {
 #if defined(__USE_CUDA__)
   if (updates.device().is_cuda()) {
@@ -79,8 +79,8 @@ torch::Tensor adamForwardNu(const torch::Tensor &updates,
   }
 }
 
-torch::Tensor adamForwardUpdates(const torch::Tensor &new_mu,
-                                 const torch::Tensor &new_nu,
+torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,
+                                 const torch::Tensor& new_nu,
                                  const pyfloat_t b1,
                                  const pyfloat_t b2,
                                  const pyfloat_t eps,
@@ -98,9 +98,9 @@ torch::Tensor adamForwardUpdates(const torch::Tensor &new_mu,
   }
 }
 
-TensorArray<2> adamBackwardMu(const torch::Tensor &dmu,
-                              const torch::Tensor &updates,
-                              const torch::Tensor &mu,
+TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,
+                              const torch::Tensor& updates,
+                              const torch::Tensor& mu,
                               const pyfloat_t b1) {
 #if defined(__USE_CUDA__)
   if (dmu.device().is_cuda()) {
@@ -114,9 +114,9 @@ TensorArray<2> adamBackwardMu(const torch::Tensor &dmu,
   }
 }
 
-TensorArray<2> adamBackwardNu(const torch::Tensor &dnu,
-                              const torch::Tensor &updates,
-                              const torch::Tensor &nu,
+TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
+                              const torch::Tensor& updates,
+                              const torch::Tensor& nu,
                               const pyfloat_t b2) {
 #if defined(__USE_CUDA__)
   if (dnu.device().is_cuda()) {
@@ -130,10 +130,10 @@ TensorArray<2> adamBackwardNu(const torch::Tensor &dnu,
   }
 }
 
-TensorArray<2> adamBackwardUpdates(const torch::Tensor &dupdates,
-                                   const torch::Tensor &updates,
-                                   const torch::Tensor &new_mu,
-                                   const torch::Tensor &new_nu,
+TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates,
+                                   const torch::Tensor& updates,
+                                   const torch::Tensor& new_mu,
+                                   const torch::Tensor& new_nu,
                                    const pyfloat_t b1,
                                    const pyfloat_t b2,
                                    const pyfloat_t eps_root,
@@ -152,7 +152,7 @@ TensorArray<2> adamBackwardUpdates(const torch::Tensor &dupdates,
   }
 }
 
-void buildSubmodule(py::module &mod) {  // NOLINT[runtime/references]
+void buildSubmodule(py::module& mod) {  // NOLINT[runtime/references]
   py::module m = mod.def_submodule("adam_op", "Adam Ops");
   m.def("forward_",
         &adamForwardInplace,
diff --git a/src/adam_op/adam_op_impl_cpu.cpp b/src/adam_op/adam_op_impl_cpu.cpp
index 9c460685..38aa2bc0 100644
--- a/src/adam_op/adam_op_impl_cpu.cpp
+++ b/src/adam_op/adam_op_impl_cpu.cpp
@@ -37,9 +37,9 @@ void adamForwardInplaceCPUKernel(const other_t b1,
                                  const other_t eps,
                                  const other_t eps_root,
                                  const size_t n,
-                                 scalar_t *__restrict__ updates_ptr,
-                                 scalar_t *__restrict__ mu_ptr,
-                                 scalar_t *__restrict__ nu_ptr) {
+                                 scalar_t* __restrict__ updates_ptr,
+                                 scalar_t* __restrict__ mu_ptr,
+                                 scalar_t* __restrict__ nu_ptr) {
 #pragma omp parallel for num_threads(                                      \
     std::min(n / MIN_NUMEL_USE_OMP,                                        \
              static_cast<size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
@@ -61,9 +61,9 @@ void adamForwardInplaceCPUKernel(const other_t b1,
   }
 }
 
-TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates,
-                                     const torch::Tensor &mu,
-                                     const torch::Tensor &nu,
+TensorArray<3> adamForwardInplaceCPU(const torch::Tensor& updates,
+                                     const torch::Tensor& mu,
+                                     const torch::Tensor& nu,
                                      const pyfloat_t b1,
                                      const pyfloat_t b2,
                                      const pyfloat_t eps,
@@ -91,11 +91,11 @@ TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates,
 }
 
 template <typename scalar_t, typename other_t>
-void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr,
-                            const scalar_t *__restrict__ mu_ptr,
+void adamForwardMuCPUKernel(const scalar_t* __restrict__ updates_ptr,
+                            const scalar_t* __restrict__ mu_ptr,
                             const other_t b1,
                             const size_t n,
-                            scalar_t *__restrict__ mu_out_ptr) {
+                            scalar_t* __restrict__ mu_out_ptr) {
 #pragma omp parallel for num_threads(                                      \
     std::min(n / MIN_NUMEL_USE_OMP,                                        \
              static_cast<size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
@@ -107,8 +107,8 @@ void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr,
   }
 }
 
-torch::Tensor adamForwardMuCPU(const torch::Tensor &updates,
-                               const torch::Tensor &mu,
+torch::Tensor adamForwardMuCPU(const torch::Tensor& updates,
+                               const torch::Tensor& mu,
                                const pyfloat_t b1) {
   auto mu_out = torch::empty_like(mu);
 
@@ -125,11 +125,11 @@ torch::Tensor adamForwardMuCPU(const torch::Tensor &updates,
 }
 
 template <typename scalar_t, typename other_t>
-void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr,
-                            const scalar_t *__restrict__ nu_ptr,
+void adamForwardNuCPUKernel(const scalar_t* __restrict__ updates_ptr,
+                            const scalar_t* __restrict__ nu_ptr,
                             const other_t b2,
                             const size_t n,
-                            scalar_t *__restrict__ nu_out_ptr) {
+                            scalar_t* __restrict__ nu_out_ptr) {
 #pragma omp parallel for num_threads(                                      \
     std::min(n / MIN_NUMEL_USE_OMP,                                        \
              static_cast<size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
@@ -142,8 +142,8 @@ void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr,
   }
 }
 
-torch::Tensor adamForwardNuCPU(const torch::Tensor &updates,
-                               const torch::Tensor &nu,
+torch::Tensor adamForwardNuCPU(const torch::Tensor& updates,
+                               const torch::Tensor& nu,
                                const pyfloat_t b2) {
   auto nu_out = torch::empty_like(nu);
 
@@ -160,14 +160,14 @@ torch::Tensor adamForwardNuCPU(const torch::Tensor &updates,
 }
 
 template <typename scalar_t, typename other_t>
-void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr,
-                                 const scalar_t *__restrict__ new_nu_ptr,
+void adamForwardUpdatesCPUKernel(const scalar_t* __restrict__ new_mu_ptr,
+                                 const scalar_t* __restrict__ new_nu_ptr,
                                  const other_t inv_one_minus_pow_b1,
                                  const other_t inv_one_minus_pow_b2,
                                  const other_t eps,
                                  const other_t eps_root,
                                  const size_t n,
-                                 scalar_t *__restrict__ updates_out_ptr) {
+                                 scalar_t* __restrict__ updates_out_ptr) {
 #pragma omp parallel for num_threads(                                      \
     std::min(n / MIN_NUMEL_USE_OMP,                                        \
              static_cast<size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
@@ -180,8 +180,8 @@ void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr,
   }
 }
 
-torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu,
-                                    const torch::Tensor &new_nu,
+torch::Tensor adamForwardUpdatesCPU(const torch::Tensor& new_mu,
+                                    const torch::Tensor& new_nu,
                                     const pyfloat_t b1,
                                     const pyfloat_t b2,
                                     const pyfloat_t eps,
@@ -209,11 +209,11 @@ torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu,
 }
 
 template <typename scalar_t, typename other_t>
-void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr,
+void adamBackwardMuCPUKernel(const scalar_t* __restrict__ dmu_ptr,
                              const other_t b1,
                              const size_t n,
-                             scalar_t *__restrict__ dupdates_out_ptr,
-                             scalar_t *__restrict__ dmu_out_ptr) {
+                             scalar_t* __restrict__ dupdates_out_ptr,
+                             scalar_t* __restrict__ dmu_out_ptr) {
 #pragma omp parallel for num_threads(                                      \
     std::min(n / MIN_NUMEL_USE_OMP,                                        \
              static_cast<size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
@@ -225,9 +225,9 @@ void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr,
   }
 }
 
-TensorArray<2> adamBackwardMuCPU(const torch::Tensor &dmu,
-                                 const torch::Tensor &updates,
-                                 const torch::Tensor &mu,
+TensorArray<2> adamBackwardMuCPU(const torch::Tensor& dmu,
+                                 const torch::Tensor& updates,
+                                 const torch::Tensor& mu,
                                  const pyfloat_t b1) {
   auto dupdates_out = torch::empty_like(updates);
   auto dmu_out = torch::empty_like(mu);
@@ -245,12 +245,12 @@ TensorArray<2> adamBackwardMuCPU(const torch::Tensor &dmu,
 }
 
 template <typename scalar_t, typename other_t>
-void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr,
-                             const scalar_t *__restrict__ updates_ptr,
+void adamBackwardNuCPUKernel(const scalar_t* __restrict__ dnu_ptr,
+                             const scalar_t* __restrict__ updates_ptr,
                              const other_t b2,
                              const size_t n,
-                             scalar_t *__restrict__ dupdates_out_ptr,
-                             scalar_t *__restrict__ dnu_out_ptr) {
+                             scalar_t* __restrict__ dupdates_out_ptr,
+                             scalar_t* __restrict__ dnu_out_ptr) {
 #pragma omp parallel for num_threads(                                      \
     std::min(n / MIN_NUMEL_USE_OMP,                                        \
              static_cast<size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
@@ -263,9 +263,9 @@ void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr,
   }
 }
 
-TensorArray<2> adamBackwardNuCPU(const torch::Tensor &dnu,
-                                 const torch::Tensor &updates,
-                                 const torch::Tensor &nu,
+TensorArray<2> adamBackwardNuCPU(const torch::Tensor& dnu,
+                                 const torch::Tensor& updates,
+                                 const torch::Tensor& nu,
                                  const pyfloat_t b2) {
   auto dupdates_out = torch::empty_like(updates);
   auto dnu_out = torch::empty_like(nu);
@@ -284,14 +284,14 @@ TensorArray<2> adamBackwardNuCPU(const torch::Tensor &dnu,
 }
 
 template <typename scalar_t, typename other_t>
-void adamBackwardUpdatesCPUKernel(const scalar_t *__restrict__ dupdates_ptr,
-                                  const scalar_t *__restrict__ updates_ptr,
-                                  const scalar_t *__restrict__ new_mu_ptr,
+void adamBackwardUpdatesCPUKernel(const scalar_t* __restrict__ dupdates_ptr,
+                                  const scalar_t* __restrict__ updates_ptr,
+                                  const scalar_t* __restrict__ new_mu_ptr,
                                   const other_t one_minus_pow_b1,
                                   const other_t inv_one_minus_pow_b2,
                                   const size_t n,
-                                  scalar_t *__restrict__ dnew_mu_out_ptr,
-                                  scalar_t *__restrict__ dnew_nu_out_ptr) {
+                                  scalar_t* __restrict__ dnew_mu_out_ptr,
+                                  scalar_t* __restrict__ dnew_nu_out_ptr) {
 #pragma omp parallel for num_threads(                                      \
     std::min(n / MIN_NUMEL_USE_OMP,                                        \
              static_cast<size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
@@ -316,10 +316,10 @@ void adamBackwardUpdatesCPUKernel(const scalar_t *__restrict__ dupdates_ptr,
   }
 }
 
-TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor &dupdates,
-                                      const torch::Tensor &updates,
-                                      const torch::Tensor &new_mu,
-                                      const torch::Tensor &new_nu,
+TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor& dupdates,
+                                      const torch::Tensor& updates,
+                                      const torch::Tensor& new_mu,
+                                      const torch::Tensor& new_nu,
                                       const pyfloat_t b1,
                                       const pyfloat_t b2,
                                       const pyfloat_t eps_root,
diff --git a/src/adam_op/adam_op_impl_cuda.cu b/src/adam_op/adam_op_impl_cuda.cu
index a12eca4f..538ad7e5 100644
--- a/src/adam_op/adam_op_impl_cuda.cu
+++ b/src/adam_op/adam_op_impl_cuda.cu
@@ -35,9 +35,9 @@ __global__ void adamForwardInplaceCUDAKernel(const other_t b1,
                                              const other_t eps,
                                              const other_t eps_root,
                                              const size_t n,
-                                             scalar_t *__restrict__ updates_ptr,
-                                             scalar_t *__restrict__ mu_ptr,
-                                             scalar_t *__restrict__ nu_ptr) {
+                                             scalar_t* __restrict__ updates_ptr,
+                                             scalar_t* __restrict__ mu_ptr,
+                                             scalar_t* __restrict__ nu_ptr) {
   const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size;
 #pragma unroll
   for (int i = 0; i < unroll_size; ++i) {
@@ -62,9 +62,9 @@ __global__ void adamForwardInplaceCUDAKernel(const other_t b1,
   }
 }
 
-TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates,
-                                      const torch::Tensor &mu,
-                                      const torch::Tensor &nu,
+TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor& updates,
+                                      const torch::Tensor& mu,
+                                      const torch::Tensor& nu,
                                       const pyfloat_t b1,
                                       const pyfloat_t b2,
                                       const pyfloat_t eps,
@@ -112,11 +112,11 @@ TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates,
 }
 
 template <typename scalar_t, typename other_t, size_t unroll_size>
-__global__ void adamForwardMuCUDAKernel(const scalar_t *__restrict__ updates_ptr,
-                                        const scalar_t *__restrict__ mu_ptr,
+__global__ void adamForwardMuCUDAKernel(const scalar_t* __restrict__ updates_ptr,
+                                        const scalar_t* __restrict__ mu_ptr,
                                         const other_t b1,
                                         const size_t n,
-                                        scalar_t *__restrict__ mu_out_ptr) {
+                                        scalar_t* __restrict__ mu_out_ptr) {
   const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size;
 #pragma unroll
   for (int i = 0; i < unroll_size; ++i) {
@@ -132,8 +132,8 @@ __global__ void adamForwardMuCUDAKernel(const scalar_t *__restrict__ updates_ptr
   }
 }
 
-torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates,
-                                const torch::Tensor &mu,
+torch::Tensor adamForwardMuCUDA(const torch::Tensor& updates,
+                                const torch::Tensor& mu,
                                 const pyfloat_t b1) {
   auto mu_out = torch::empty_like(mu);
 
@@ -165,11 +165,11 @@ torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates,
 }
 
 template <typename scalar_t, typename other_t, size_t unroll_size>
-__global__ void adamForwardNuCUDAKernel(const scalar_t *__restrict__ updates_ptr,
-                                        const scalar_t *__restrict__ nu_ptr,
+__global__ void adamForwardNuCUDAKernel(const scalar_t* __restrict__ updates_ptr,
+                                        const scalar_t* __restrict__ nu_ptr,
                                         const other_t b2,
                                         const size_t n,
-                                        scalar_t *__restrict__ nu_out_ptr) {
+                                        scalar_t* __restrict__ nu_out_ptr) {
   const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size;
 #pragma unroll
   for (int i = 0; i < unroll_size; ++i) {
@@ -186,8 +186,8 @@ __global__ void adamForwardNuCUDAKernel(const scalar_t *__restrict__ updates_ptr
   }
 }
 
-torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates,
-                                const torch::Tensor &nu,
+torch::Tensor adamForwardNuCUDA(const torch::Tensor& updates,
+                                const torch::Tensor& nu,
                                 const pyfloat_t b2) {
   auto nu_out = torch::empty_like(nu);
 
@@ -219,14 +219,14 @@ torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates,
 }
 
 template <typename scalar_t, typename other_t, size_t unroll_size>
-__global__ void adamForwardUpdatesCUDAKernel(const scalar_t *__restrict__ new_mu_ptr,
-                                             const scalar_t *__restrict__ new_nu_ptr,
+__global__ void adamForwardUpdatesCUDAKernel(const scalar_t* __restrict__ new_mu_ptr,
+                                             const scalar_t* __restrict__ new_nu_ptr,
                                              const other_t inv_one_minus_pow_b1,
                                              const other_t inv_one_minus_pow_b2,
                                              const other_t eps,
                                              const other_t eps_root,
                                              const size_t n,
-                                             scalar_t *__restrict__ updates_out_ptr) {
+                                             scalar_t* __restrict__ updates_out_ptr) {
   const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size;
 #pragma unroll
   for (int i = 0; i < unroll_size; ++i) {
@@ -243,8 +243,8 @@ __global__ void adamForwardUpdatesCUDAKernel(const scalar_t *__restrict__ new_mu
   }
 }
 
-torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu,
-                                     const torch::Tensor &new_nu,
+torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor& new_mu,
+                                     const torch::Tensor& new_nu,
                                      const pyfloat_t b1,
                                      const pyfloat_t b2,
                                      const pyfloat_t eps,
@@ -291,11 +291,11 @@ torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu,
 }
 
 template <typename scalar_t, typename other_t, size_t unroll_size>
-__global__ void adamBackwardMuCUDAKernel(const scalar_t *__restrict__ dmu_ptr,
+__global__ void adamBackwardMuCUDAKernel(const scalar_t* __restrict__ dmu_ptr,
                                          const other_t b1,
                                          const size_t n,
-                                         scalar_t *__restrict__ dupdates_out_ptr,
-                                         scalar_t *__restrict__ dmu_out_ptr) {
+                                         scalar_t* __restrict__ dupdates_out_ptr,
+                                         scalar_t* __restrict__ dmu_out_ptr) {
   const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size;
 #pragma unroll
   for (int i = 0; i < unroll_size; ++i) {
@@ -311,9 +311,9 @@ __global__ void adamBackwardMuCUDAKernel(const scalar_t *__restrict__ dmu_ptr,
   }
 }
 
-TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu,
-                                  const torch::Tensor &updates,
-                                  const torch::Tensor &mu,
+TensorArray<2> adamBackwardMuCUDA(const torch::Tensor& dmu,
+                                  const torch::Tensor& updates,
+                                  const torch::Tensor& mu,
                                   const pyfloat_t b1) {
   auto dupdates_out = torch::empty_like(updates);
   auto dmu_out = torch::empty_like(mu);
@@ -346,12 +346,12 @@ TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu,
 }
 
 template <typename scalar_t, typename other_t, size_t unroll_size>
-__global__ void adamBackwardNuCUDAKernel(const scalar_t *__restrict__ dnu_ptr,
-                                         const scalar_t *__restrict__ updates_ptr,
+__global__ void adamBackwardNuCUDAKernel(const scalar_t* __restrict__ dnu_ptr,
+                                         const scalar_t* __restrict__ updates_ptr,
                                          const other_t b2,
                                          const size_t n,
-                                         scalar_t *__restrict__ dupdates_out_ptr,
-                                         scalar_t *__restrict__ dnu_out_ptr) {
+                                         scalar_t* __restrict__ dupdates_out_ptr,
+                                         scalar_t* __restrict__ dnu_out_ptr) {
   const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size;
 #pragma unroll
   for (int i = 0; i < unroll_size; ++i) {
@@ -368,9 +368,9 @@ __global__ void adamBackwardNuCUDAKernel(const scalar_t *__restrict__ dnu_ptr,
   }
 }
 
-TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu,
-                                  const torch::Tensor &updates,
-                                  const torch::Tensor &nu,
+TensorArray<2> adamBackwardNuCUDA(const torch::Tensor& dnu,
+                                  const torch::Tensor& updates,
+                                  const torch::Tensor& nu,
                                   const pyfloat_t b2) {
   auto dupdates_out = torch::empty_like(updates);
   auto dnu_out = torch::empty_like(nu);
@@ -405,14 +405,14 @@ TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu,
 }
 
 template <typename scalar_t, typename other_t, size_t unroll_size>
-__global__ void adamBackwardUpdatesCUDAKernel(const scalar_t *__restrict__ dupdates_ptr,
-                                              const scalar_t *__restrict__ updates_ptr,
-                                              const scalar_t *__restrict__ new_mu_ptr,
+__global__ void adamBackwardUpdatesCUDAKernel(const scalar_t* __restrict__ dupdates_ptr,
+                                              const scalar_t* __restrict__ updates_ptr,
+                                              const scalar_t* __restrict__ new_mu_ptr,
                                               const other_t one_minus_pow_b1,
                                               const other_t inv_one_minus_pow_b2,
                                               const size_t n,
-                                              scalar_t *__restrict__ dnew_mu_out_ptr,
-                                              scalar_t *__restrict__ dnew_nu_out_ptr) {
+                                              scalar_t* __restrict__ dnew_mu_out_ptr,
+                                              scalar_t* __restrict__ dnew_nu_out_ptr) {
   const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size;
 #pragma unroll
   for (int i = 0; i < unroll_size; ++i) {
@@ -441,10 +441,10 @@ __global__ void adamBackwardUpdatesCUDAKernel(const scalar_t *__restrict__ dupda
   }
 }
 
-TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates,
-                                       const torch::Tensor &updates,
-                                       const torch::Tensor &new_mu,
-                                       const torch::Tensor &new_nu,
+TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor& dupdates,
+                                       const torch::Tensor& updates,
+                                       const torch::Tensor& new_mu,
+                                       const torch::Tensor& new_nu,
                                        const pyfloat_t b1,
                                        const pyfloat_t b2,
                                        const pyfloat_t eps_root,
diff --git a/tests/conftest.py b/tests/conftest.py
index bb2b1cf2..77a6cbab 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -15,6 +15,5 @@
 import os
 
-
 os.environ['PYTHONHASHSEED'] = '0'
 os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'
diff --git a/tests/helpers.py b/tests/helpers.py
index ca5aa443..b5280fa1 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -31,7 +31,6 @@
 from torchopt import pytree
 
-
 if TYPE_CHECKING:
     from torchopt.typing import TensorTree
 
diff --git a/tests/test_accelerated_op.py b/tests/test_accelerated_op.py
index 668c9b9a..bed3a71d 100644
--- a/tests/test_accelerated_op.py
+++ b/tests/test_accelerated_op.py
@@ -20,7 +20,6 @@
 import helpers
 import torchopt
 
-
 try:
     import torchopt._C.adam_op
 except ImportError:
diff --git a/tests/test_alias.py b/tests/test_alias.py
index 3c42d7c8..8b218f3b 100644
--- a/tests/test_alias.py
+++ b/tests/test_alias.py
@@ -27,7 +27,6 @@
 from torchopt import pytree
 from torchopt.alias.utils import _set_use_chain_flat
 
-
 if TYPE_CHECKING:
     from torchopt.typing import TensorTree
 
diff --git a/tests/test_implicit.py b/tests/test_implicit.py
index 6cccb716..3f5752fd 100644
--- a/tests/test_implicit.py
+++ b/tests/test_implicit.py
@@ -34,7 +34,6 @@
 from torchopt import pytree
 from torchopt.diff.implicit import ImplicitMetaGradientModule
 
-
 try:
     import jax
     import jax.numpy as jnp
diff --git a/tests/test_pytree.py b/tests/test_pytree.py
index 6ee2939b..7a63355f 100644
--- a/tests/test_pytree.py
+++ b/tests/test_pytree.py
@@ -18,7 +18,6 @@
 import helpers
 
 from torchopt import pytree
 
-
 tree_a = (torch.randn(20, 10), torch.randn(20))
 tree_b = (torch.randn(20, 10), torch.randn(20))
diff --git a/tests/test_zero_order.py b/tests/test_zero_order.py
index 65642559..357f697c 100644
--- a/tests/test_zero_order.py
+++ b/tests/test_zero_order.py
@@ -23,7 +23,6 @@
 import helpers
 import torchopt
 
-
 BATCH_SIZE = 8
 NUM_UPDATES = 5
 
diff --git a/torchopt/__init__.py b/torchopt/__init__.py
index 830072e3..8b8a23cc 100644
--- a/torchopt/__init__.py
+++ b/torchopt/__init__.py
@@ -79,7 +79,6 @@
 )
 from torchopt.version import __version__
 
-
 __all__ = [
     'SGD',
     'AdaDelta',
diff --git a/torchopt/accelerated_op/__init__.py b/torchopt/accelerated_op/__init__.py
index 90452046..cebff564 100644
--- a/torchopt/accelerated_op/__init__.py
+++ b/torchopt/accelerated_op/__init__.py
@@ -22,7 +22,6 @@
 from torchopt.accelerated_op.adam_op import AdamOp
 
-
 if TYPE_CHECKING:
     from torchopt.typing import Device
 
diff --git a/torchopt/accelerated_op/_src/adam_op.py b/torchopt/accelerated_op/_src/adam_op.py
index d7f9796d..34d3f4d7 100644
--- a/torchopt/accelerated_op/_src/adam_op.py
+++ b/torchopt/accelerated_op/_src/adam_op.py
@@ -20,7 +20,6 @@
 from typing import TYPE_CHECKING
 
-
 if TYPE_CHECKING:
     import torch
 
diff --git a/torchopt/accelerated_op/adam_op.py b/torchopt/accelerated_op/adam_op.py
index 43ac26cd..bccdd7d0 100644
--- a/torchopt/accelerated_op/adam_op.py
+++ b/torchopt/accelerated_op/adam_op.py
@@ -23,7 +23,6 @@
 import torch
 
-
 try:
     from torchopt._C import adam_op  # pylint: disable=no-name-in-module
 except ImportError:
diff --git a/torchopt/alias/__init__.py b/torchopt/alias/__init__.py
index 5767c5d7..a7585a7f 100644
--- a/torchopt/alias/__init__.py
+++ b/torchopt/alias/__init__.py
@@ -40,7 +40,6 @@
 from torchopt.alias.rmsprop import rmsprop
 from torchopt.alias.sgd import sgd
 
-
 __all__ = [
     'adadelta',
     'adagrad',
diff --git a/torchopt/alias/adadelta.py b/torchopt/alias/adadelta.py
index 910cb13e..b292cd51 100644
--- a/torchopt/alias/adadelta.py
+++ b/torchopt/alias/adadelta.py
@@ -26,7 +26,6 @@
 from torchopt.combine import chain
 from torchopt.transform import scale_by_adadelta
 
-
 if TYPE_CHECKING:
     from torchopt.typing import GradientTransformation, ScalarOrSchedule
 
diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py
index 6fdb4aa3..47351bbb 100644
--- a/torchopt/alias/adagrad.py
+++ b/torchopt/alias/adagrad.py
@@ -42,7 +42,6 @@
 from torchopt.transform import scale_by_rss, scale_by_schedule
 from torchopt.typing import GradientTransformation, Numeric, Scalar, ScalarOrSchedule, Schedule
 
-
 __all__ = ['adagrad']
 
 
diff --git a/torchopt/alias/adam.py b/torchopt/alias/adam.py
index 0ae0eb8e..3f2ae4aa 100644
--- a/torchopt/alias/adam.py
+++ b/torchopt/alias/adam.py
@@ -43,7 +43,6 @@
 from torchopt.combine import chain
 from torchopt.transform import scale_by_accelerated_adam, scale_by_adam
 
-
 if TYPE_CHECKING:
     from torchopt.typing import GradientTransformation, ScalarOrSchedule
 
diff --git a/torchopt/alias/adamax.py b/torchopt/alias/adamax.py
index 3da16713..588fdf56 100644
--- a/torchopt/alias/adamax.py
+++ b/torchopt/alias/adamax.py
@@ -26,7 +26,6 @@
 from torchopt.combine import chain
 from torchopt.transform import scale_by_adamax
 
-
 if TYPE_CHECKING:
     from torchopt.typing import GradientTransformation, ScalarOrSchedule
 
diff --git a/torchopt/alias/adamw.py b/torchopt/alias/adamw.py
index 2dc72ef1..0547a2af 100644
--- a/torchopt/alias/adamw.py
+++ b/torchopt/alias/adamw.py
@@ -43,7 +43,6 @@
 from torchopt.combine import chain
 from torchopt.transform import add_decayed_weights, scale_by_accelerated_adam, scale_by_adam
 
-
 if TYPE_CHECKING:
     from torchopt.typing import GradientTransformation, OptState, Params, ScalarOrSchedule
 
diff --git a/torchopt/alias/radam.py b/torchopt/alias/radam.py
index 9e2880ee..8a23303c 100644
--- a/torchopt/alias/radam.py
+++ b/torchopt/alias/radam.py
@@ -26,7 +26,6 @@
 from torchopt.combine import chain
 from torchopt.transform import scale_by_radam
 
-
 if TYPE_CHECKING:
     from torchopt.typing import GradientTransformation, ScalarOrSchedule
 
diff --git a/torchopt/alias/rmsprop.py b/torchopt/alias/rmsprop.py
index 612e4f45..b49ccbe6 100644
--- a/torchopt/alias/rmsprop.py
+++ b/torchopt/alias/rmsprop.py
@@ -40,7 +40,6 @@
 from torchopt.transform import scale_by_rms, scale_by_stddev, trace
 from torchopt.typing import GradientTransformation, ScalarOrSchedule
 
-
 __all__ = ['rmsprop']
 
 
diff --git a/torchopt/alias/sgd.py b/torchopt/alias/sgd.py
index 6d5935bc..60f9f393 100644
--- a/torchopt/alias/sgd.py
+++ b/torchopt/alias/sgd.py
@@ -40,7 +40,6 @@
 from torchopt.transform import trace
 from torchopt.typing import GradientTransformation, ScalarOrSchedule
 
-
 __all__ = ['sgd']
 
 
diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py
index 0f41e822..eadabb52 100644
--- a/torchopt/alias/utils.py
+++ b/torchopt/alias/utils.py
@@ -23,7 +23,6 @@
 from torchopt.transform import scale, scale_by_schedule
 from torchopt.transform.utils import tree_map_flat, tree_map_flat_
 
-
 if TYPE_CHECKING:
     import torch
 
diff --git a/torchopt/base.py b/torchopt/base.py
index 81892e17..c167cce8 100644
--- a/torchopt/base.py
+++ b/torchopt/base.py
@@ -38,7 +38,6 @@
 from typing import TYPE_CHECKING, Callable, NamedTuple, Protocol
 from typing_extensions import Self  # Python 3.11+
 
-
 if TYPE_CHECKING:
     from torchopt.typing import OptState, Params, Updates
 
diff --git a/torchopt/clip.py b/torchopt/clip.py
index d64afc58..671c3914 100644
--- a/torchopt/clip.py
+++ b/torchopt/clip.py
@@ -26,7 +26,6 @@
 from torchopt import pytree
 from torchopt.base import EmptyState, GradientTransformation
 
-
 if TYPE_CHECKING:
     from torchopt.typing import OptState, Params, Updates
 
diff --git a/torchopt/combine.py b/torchopt/combine.py
index 15345286..6c72446a 100644
--- a/torchopt/combine.py
+++ b/torchopt/combine.py
@@ -38,7 +38,6 @@
 from torchopt import pytree
 from torchopt.base import ChainedGradientTransformation, GradientTransformation, identity
 
-
 if TYPE_CHECKING:
     from torchopt.typing import OptState, Params, Updates
 
diff --git a/torchopt/diff/implicit/__init__.py b/torchopt/diff/implicit/__init__.py
index 4cff14c6..4220ba05 100644
--- a/torchopt/diff/implicit/__init__.py
+++ b/torchopt/diff/implicit/__init__.py
@@ -18,5 +18,4 @@
 from torchopt.diff.implicit.decorator import custom_root
 from torchopt.diff.implicit.nn import ImplicitMetaGradientModule
 
-
 __all__ = ['ImplicitMetaGradientModule', 'custom_root']
diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py
index 11ba0153..5dcfaed1 100644
--- a/torchopt/diff/implicit/decorator.py
+++ b/torchopt/diff/implicit/decorator.py
@@ -45,7 +45,6 @@
 from torchopt import linear_solve, pytree
 
-
 if TYPE_CHECKING:
     from torchopt.typing import (
         ListOfOptionalTensors,
diff --git a/torchopt/diff/implicit/nn/__init__.py b/torchopt/diff/implicit/nn/__init__.py
index e91ef8ed..b62abcc2 100644
--- a/torchopt/diff/implicit/nn/__init__.py
+++ b/torchopt/diff/implicit/nn/__init__.py
@@ -17,7 +17,6 @@
 import torchopt.nn.module  # preload to resolve circular references
 from torchopt.diff.implicit.nn.module import ImplicitMetaGradientModule
 
-
 __all__ = ['ImplicitMetaGradientModule']
 
 del torchopt
diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py
index 6b214cb8..7c42dc21 100644
--- a/torchopt/diff/implicit/nn/module.py
+++ b/torchopt/diff/implicit/nn/module.py
@@ -30,7 +30,6 @@
 from torchopt.nn.module import MetaGradientModule
 from torchopt.nn.stateless import reparametrize, swap_state
 
-
 if TYPE_CHECKING:
     import torch
 
diff --git a/torchopt/diff/zero_order/__init__.py b/torchopt/diff/zero_order/__init__.py
index 4369f4e5..205e83ff 100644
--- a/torchopt/diff/zero_order/__init__.py
+++ b/torchopt/diff/zero_order/__init__.py
@@ -24,7 +24,6 @@
 from torchopt.diff.zero_order.decorator import zero_order
 from torchopt.diff.zero_order.nn import ZeroOrderGradientModule
 
-
 __all__ = ['ZeroOrderGradientModule', 'zero_order']
 
 
diff --git a/torchopt/diff/zero_order/decorator.py b/torchopt/diff/zero_order/decorator.py
index e498b43c..ea10702d 100644
--- a/torchopt/diff/zero_order/decorator.py
+++ b/torchopt/diff/zero_order/decorator.py
@@ -17,7 +17,6 @@
 from __future__ import annotations
 
 import functools
-import itertools
 from typing import Any, Callable, Literal, Sequence
 from typing_extensions import TypeAlias  # Python 3.10+
 
@@ -124,7 +123,7 @@ def add_perturbation(
         for _ in range(num_samples):
             noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params]
             flat_noisy_params = list(
-                itertools.starmap(add_perturbation, zip(flat_diff_params, noises)),
+                map(add_perturbation, flat_diff_params, noises),
             )
             noisy_params: list[Any] = pytree.tree_unflatten(  # type: ignore[assignment]
                 diff_params_treespec,
@@ -228,7 +227,7 @@ def add_perturbation(tensor: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
         for _ in range(num_samples):
             noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params]
             flat_noisy_params = list(
-                itertools.starmap(add_perturbation, zip(flat_diff_params, noises)),
+                map(add_perturbation, flat_diff_params, noises),
             )
             noisy_params: list[Any] = pytree.tree_unflatten(  # type: ignore[assignment]
                 diff_params_treespec,
diff --git a/torchopt/diff/zero_order/nn/__init__.py b/torchopt/diff/zero_order/nn/__init__.py
index f2753b27..06967fc7 100644
--- a/torchopt/diff/zero_order/nn/__init__.py
+++ b/torchopt/diff/zero_order/nn/__init__.py
@@ -17,7 +17,6 @@
 import torchopt.nn.module  # preload to resolve circular references
 from torchopt.diff.zero_order.nn.module import ZeroOrderGradientModule
 
-
 __all__ = ['ZeroOrderGradientModule']
 
 del torchopt
diff --git a/torchopt/diff/zero_order/nn/module.py b/torchopt/diff/zero_order/nn/module.py
index eeddabeb..fe95f8e3 100644
--- a/torchopt/diff/zero_order/nn/module.py
+++ b/torchopt/diff/zero_order/nn/module.py
@@ -28,7 +28,6 @@
 from torchopt.diff.zero_order.decorator import Method, Samplable, zero_order
 from torchopt.nn.stateless import reparametrize
 
-
 if TYPE_CHECKING:
     from torchopt.typing import Numeric, TupleOfTensors
 
diff --git a/torchopt/distributed/__init__.py b/torchopt/distributed/__init__.py
index 31f1283b..cf0e6dae 100644
--- a/torchopt/distributed/__init__.py
+++ b/torchopt/distributed/__init__.py
@@ -21,7 +21,6 @@
 from torchopt.distributed.api import *  # noqa: F403
 from torchopt.distributed.world import *  # noqa: F403
 
-
 __all__ = ['is_available', *api.__all__, *world.__all__]
 
 
diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py
index 97be682f..28bd34e8 100644
--- a/torchopt/distributed/api.py
+++ b/torchopt/distributed/api.py
@@ -39,7 +39,6 @@
 from torchopt.distributed.world import get_worker_id, get_world_rank, get_world_size
 from torchopt.typing import Future
 
-
 __all__ = [
     'TensorDimensionPartitioner',
     'batch_partitioner',
@@ -318,12 +317,12 @@ def remote_async_call(
         futures.append(fut)
 
     future = cast(
-        Future[List[T]],
+        'Future[List[T]]',
         torch.futures.collect_all(futures).then(lambda fut: [f.wait() for f in fut.wait()]),
     )
     if reducer is not None:
         return cast(
-            Future[U],
+            'Future[U]',
             future.then(lambda fut: reducer(fut.wait())),
         )
     return future
diff --git a/torchopt/distributed/autograd.py b/torchopt/distributed/autograd.py
index 71afdb86..87c73a1b 100644
--- a/torchopt/distributed/autograd.py
+++ b/torchopt/distributed/autograd.py
@@ -23,7 +23,6 @@
 import torch.distributed.autograd as autograd
 from torch.distributed.autograd import context
 
-
 if TYPE_CHECKING:
     from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors
 
diff --git a/torchopt/distributed/world.py b/torchopt/distributed/world.py
index 610e52a0..597f700e 100644
--- a/torchopt/distributed/world.py
+++ b/torchopt/distributed/world.py
@@ -24,7 +24,6 @@
 import torch.distributed.rpc as rpc
 from torch.distributed.elastic.multiprocessing.errors import record
 
-
 __all__ = [
     'auto_init_rpc',
     'barrier',
diff --git a/torchopt/hook.py b/torchopt/hook.py
index c11b92f6..767206b2 100644
--- a/torchopt/hook.py
+++ b/torchopt/hook.py
@@ -21,7 +21,6 @@
 from torchopt import pytree
 from torchopt.base import EmptyState, GradientTransformation
 
-
 if TYPE_CHECKING:
     import torch
 
diff --git a/torchopt/linalg/__init__.py b/torchopt/linalg/__init__.py
index fc499d67..00f69b22 100644
--- a/torchopt/linalg/__init__.py
+++ b/torchopt/linalg/__init__.py
@@ -34,5 +34,4 @@
 from torchopt.linalg.cg import cg
 from torchopt.linalg.ns import ns, ns_inv
 
-
 __all__ = ['cg', 'ns', 'ns_inv']
diff --git a/torchopt/linalg/cg.py b/torchopt/linalg/cg.py
index 1096a5af..146ec5d9 100644
--- a/torchopt/linalg/cg.py
+++ b/torchopt/linalg/cg.py
@@ -44,7 +44,6 @@
 from torchopt.linalg.utils import cat_shapes, normalize_matvec
 from torchopt.pytree import tree_vdot_real
 
-
 if TYPE_CHECKING:
     from torchopt.typing import TensorTree
 
diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py
index 5fc8d478..3fdac97c 100644
--- a/torchopt/linalg/ns.py
+++ b/torchopt/linalg/ns.py
@@ -26,7 +26,6 @@
 from torchopt import pytree
 from torchopt.linalg.utils import normalize_matvec
 
-
 if TYPE_CHECKING:
     from torchopt.typing import TensorTree
 
diff --git a/torchopt/linalg/utils.py b/torchopt/linalg/utils.py
index bbcc80aa..2f21d2b6 100644
--- a/torchopt/linalg/utils.py
+++ b/torchopt/linalg/utils.py
@@ -23,7 +23,6 @@
 from torchopt import pytree
 
-
 if TYPE_CHECKING:
     from torchopt.typing import TensorTree
 
diff --git a/torchopt/linear_solve/__init__.py b/torchopt/linear_solve/__init__.py
index 43ca1da0..6abde479 100644
--- a/torchopt/linear_solve/__init__.py
+++ b/torchopt/linear_solve/__init__.py
@@ -35,5 +35,4 @@
 from torchopt.linear_solve.inv import solve_inv
 from torchopt.linear_solve.normal_cg import solve_normal_cg
 
-
 __all__ = ['solve_cg', 'solve_inv', 'solve_normal_cg']
diff --git a/torchopt/linear_solve/cg.py b/torchopt/linear_solve/cg.py
index 23814cc2..83e18ed6 100644
--- a/torchopt/linear_solve/cg.py
+++ b/torchopt/linear_solve/cg.py
@@ -41,7 +41,6 @@
 from torchopt import linalg
 from torchopt.linear_solve.utils import make_ridge_matvec
 
-
 if TYPE_CHECKING:
     from torchopt.typing import LinearSolver, TensorTree
 
diff --git a/torchopt/linear_solve/inv.py b/torchopt/linear_solve/inv.py
index 4dbe1542..aaa62c75 100644
--- a/torchopt/linear_solve/inv.py
+++ b/torchopt/linear_solve/inv.py
@@ -43,7 +43,6 @@
 from torchopt import linalg, pytree
 from torchopt.linear_solve.utils import make_ridge_matvec, materialize_matvec
 
-
 if TYPE_CHECKING:
     from torchopt.typing import LinearSolver, TensorTree
 
diff --git a/torchopt/linear_solve/normal_cg.py b/torchopt/linear_solve/normal_cg.py
index a5af49b2..457e278c 100644
--- a/torchopt/linear_solve/normal_cg.py
+++ b/torchopt/linear_solve/normal_cg.py
@@ -41,7 +41,6 @@
 from torchopt import linalg
 from torchopt.linear_solve.utils import make_normal_matvec, make_ridge_matvec, make_rmatvec
 
-
 if TYPE_CHECKING:
     from torchopt.typing import LinearSolver, TensorTree
 
diff --git a/torchopt/linear_solve/utils.py b/torchopt/linear_solve/utils.py
index 9d1b8779..c075ee7d 100644
--- a/torchopt/linear_solve/utils.py
+++ b/torchopt/linear_solve/utils.py
@@ -39,7 +39,6 @@
 from torchopt import pytree
 
-
 if TYPE_CHECKING:
     from torchopt.typing import TensorTree
 
diff --git a/torchopt/nn/__init__.py b/torchopt/nn/__init__.py
index b55e49d7..b76ad5ba 100644
--- a/torchopt/nn/__init__.py
+++ b/torchopt/nn/__init__.py
@@ -19,7 +19,6 @@
 from torchopt.nn.module import MetaGradientModule
 from torchopt.nn.stateless import reparameterize, reparametrize, swap_state
 
-
 __all__ = [
     'ImplicitMetaGradientModule',
     'MetaGradientModule',
diff --git a/torchopt/nn/module.py b/torchopt/nn/module.py
index 8c40f58a..3af2f045 100644
--- a/torchopt/nn/module.py
+++ b/torchopt/nn/module.py
@@ -25,7 +25,6 @@
 from torchopt import pytree
 
-
 if TYPE_CHECKING:
     from torchopt.typing import TensorContainer
 
diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py
index c7f92b86..1b06cf0d 100644
--- a/torchopt/nn/stateless.py
+++ b/torchopt/nn/stateless.py
@@ -19,7 +19,6 @@
 import contextlib
 from typing import TYPE_CHECKING, Generator, Iterable
 
-
 if TYPE_CHECKING:
     import torch
     import torch.nn as nn
@@ -84,7 +83,7 @@ def reparametrize(
     module: nn.Module,
     named_tensors: dict[str, torch.Tensor] | Iterable[tuple[str, torch.Tensor]],
     allow_missing: bool = False,
-) -> Generator[nn.Module, None, None]:
+) -> Generator[nn.Module]:
     """Reparameterize the module parameters and/or buffers."""
     if not isinstance(named_tensors, dict):
         named_tensors = dict(named_tensors)
diff --git a/torchopt/optim/adadelta.py b/torchopt/optim/adadelta.py
index 600b69c5..85bb24d8 100644
--- a/torchopt/optim/adadelta.py
+++ b/torchopt/optim/adadelta.py
@@ -21,7 +21,6 @@
 from torchopt import alias
 from torchopt.optim.base import Optimizer
 
-
 if TYPE_CHECKING:
     import torch
 
diff --git a/torchopt/optim/adagrad.py b/torchopt/optim/adagrad.py
index 06091281..bd70aa20 100644
--- a/torchopt/optim/adagrad.py
+++ b/torchopt/optim/adagrad.py
@@ -21,7 +21,6 @@
 from torchopt import alias
 from torchopt.optim.base import Optimizer
 
-
 if TYPE_CHECKING:
     import torch
 
diff --git a/torchopt/optim/adam.py b/torchopt/optim/adam.py
index 555af22e..916edf1f 100644
--- a/torchopt/optim/adam.py
+++ b/torchopt/optim/adam.py
@@ -21,7 +21,6 @@
 from torchopt import alias
 from torchopt.optim.base import Optimizer
 
-
 if TYPE_CHECKING:
     import torch
 
diff --git a/torchopt/optim/adamax.py b/torchopt/optim/adamax.py
index e4996e85..dff41451 100644
--- a/torchopt/optim/adamax.py
+++ b/torchopt/optim/adamax.py
@@ -21,7 +21,6 @@
 from torchopt import alias
 from torchopt.optim.base import Optimizer
 
-
 if TYPE_CHECKING:
     import torch
 
diff --git a/torchopt/optim/adamw.py b/torchopt/optim/adamw.py
index a60061ea..ae95d158 100644
--- a/torchopt/optim/adamw.py
+++ b/torchopt/optim/adamw.py
@@ -21,7 +21,6 @@
 from torchopt import alias
 from torchopt.optim.base import Optimizer
 
-
 if TYPE_CHECKING:
     import torch
 
diff --git a/torchopt/optim/base.py b/torchopt/optim/base.py
index bdaa0d67..5163f1ff 100644
--- a/torchopt/optim/base.py
+++ b/torchopt/optim/base.py
@@ -25,7 +25,6 @@
 from torchopt.typing import GradientTransformation, OptState, Params, TupleOfTensors
 from torchopt.update import apply_updates
 
-
 __all__ = ['Optimizer']
 
 
diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py
index fa287f04..3795f250 100644
--- a/torchopt/optim/func/base.py
+++ b/torchopt/optim/func/base.py
@@ -23,7 +23,6 @@
 from torchopt.base import GradientTransformation, UninitializedState
 from torchopt.update import apply_updates
 
-
 if TYPE_CHECKING:
     from torchopt.typing import OptState, Params
 
diff --git a/torchopt/optim/meta/adadelta.py b/torchopt/optim/meta/adadelta.py
index eb386ae3..c5f4bf62 100644
--- a/torchopt/optim/meta/adadelta.py
+++ b/torchopt/optim/meta/adadelta.py
@@ -21,7 +21,6 @@
 from torchopt import alias
 from torchopt.optim.meta.base import MetaOptimizer
 
-
 if TYPE_CHECKING:
     import torch.nn as nn
 
diff --git a/torchopt/optim/meta/adagrad.py b/torchopt/optim/meta/adagrad.py
index 129c1338..97846c44 100644
--- a/torchopt/optim/meta/adagrad.py
+++ b/torchopt/optim/meta/adagrad.py
@@ -21,7 +21,6 @@
 from torchopt import alias
 from torchopt.optim.meta.base import MetaOptimizer
 
-
 if TYPE_CHECKING:
     import torch.nn as nn
 
diff --git a/torchopt/optim/meta/adam.py b/torchopt/optim/meta/adam.py
index 7a78ea7f..cb799a5f 100644
--- a/torchopt/optim/meta/adam.py
+++ b/torchopt/optim/meta/adam.py
@@ -21,7 +21,6 @@
 from torchopt import alias
 from torchopt.optim.meta.base import MetaOptimizer
 
-
 if TYPE_CHECKING:
     import torch.nn as nn
 
diff --git a/torchopt/optim/meta/adamax.py b/torchopt/optim/meta/adamax.py
index d6b40427..4ff38271 100644
--- a/torchopt/optim/meta/adamax.py
+++ b/torchopt/optim/meta/adamax.py
@@ -21,7 +21,6 @@
 from torchopt import alias
 from torchopt.optim.meta.base import MetaOptimizer
 
-
 if TYPE_CHECKING:
     import torch.nn as nn
 
diff --git a/torchopt/optim/meta/adamw.py b/torchopt/optim/meta/adamw.py
index 62864582..f04b3fa4 100644
--- a/torchopt/optim/meta/adamw.py
+++ b/torchopt/optim/meta/adamw.py
@@ -21,7 +21,6 @@
 from torchopt import alias
 from torchopt.optim.meta.base import MetaOptimizer
 
-
 if TYPE_CHECKING:
     import torch.nn as nn
 
diff --git a/torchopt/optim/meta/base.py b/torchopt/optim/meta/base.py
index 73ecdde7..337ed0cf 100644
--- a/torchopt/optim/meta/base.py
+++ b/torchopt/optim/meta/base.py
@@ -27,7 +27,6 @@
 from torchopt.update import apply_updates
 from torchopt.utils import extract_module_containers
 
-
 __all__ = ['MetaOptimizer']
 
 
diff --git a/torchopt/optim/meta/radam.py b/torchopt/optim/meta/radam.py
index bb07b5ba..02622e14 100644
--- a/torchopt/optim/meta/radam.py
+++ b/torchopt/optim/meta/radam.py
@@ -21,7 +21,6 @@
 from torchopt import alias
 from torchopt.optim.meta.base import MetaOptimizer
 
-
 if TYPE_CHECKING:
     import torch.nn as nn
 
diff --git a/torchopt/optim/meta/rmsprop.py b/torchopt/optim/meta/rmsprop.py
index a8b4abfa..1ba01cc1 100644
--- a/torchopt/optim/meta/rmsprop.py
+++ b/torchopt/optim/meta/rmsprop.py
@@ -20,7 +20,6 @@
 from torchopt.optim.meta.base import MetaOptimizer
 from torchopt.typing import ScalarOrSchedule
 
-
 __all__ = ['MetaRMSProp', 'MetaRMSprop']
 
 
diff --git a/torchopt/optim/meta/sgd.py b/torchopt/optim/meta/sgd.py
index 81e04413..0d722607 100644
--- a/torchopt/optim/meta/sgd.py
+++ b/torchopt/optim/meta/sgd.py
@@ -20,7 +20,6 @@
 from torchopt.optim.meta.base import MetaOptimizer
 from torchopt.typing import ScalarOrSchedule
 
-
 __all__ = ['MetaSGD']
 
 
diff --git a/torchopt/optim/radam.py b/torchopt/optim/radam.py
index 20e9dd22..d4f7f26e 100644
--- a/torchopt/optim/radam.py
+++ b/torchopt/optim/radam.py
@@ -21,7 +21,6 @@
 from torchopt import alias
 from torchopt.optim.base import Optimizer
 
-
 if TYPE_CHECKING:
     import torch
 
diff --git a/torchopt/optim/rmsprop.py b/torchopt/optim/rmsprop.py
index 032e5864..6753c348 100644
--- a/torchopt/optim/rmsprop.py
+++ b/torchopt/optim/rmsprop.py
@@ -22,7 +22,6 @@
 from torchopt.optim.base import Optimizer
 from torchopt.typing import ScalarOrSchedule
 
-
 __all__ = ['RMSProp', 'RMSprop']
 
 
diff --git a/torchopt/optim/sgd.py b/torchopt/optim/sgd.py
index 27cd53c1..ab7c3b4f 100644
--- a/torchopt/optim/sgd.py
+++ b/torchopt/optim/sgd.py
@@ -22,7 +22,6 @@
 from torchopt.optim.base import Optimizer
 from torchopt.typing import ScalarOrSchedule
 
-
 __all__ = ['SGD']
 
 
diff --git a/torchopt/pytree.py b/torchopt/pytree.py
index 53abc2d2..c67d8083 100644
--- a/torchopt/pytree.py
+++ b/torchopt/pytree.py
@@ -26,7 +26,6 @@
 import torch.distributed.rpc as rpc
 from optree import *  # pylint: disable=wildcard-import,unused-wildcard-import
 
-
 if TYPE_CHECKING:
     from torchopt.typing import Future, RRef, Scalar, T, TensorTree
 
diff --git a/torchopt/schedule/__init__.py b/torchopt/schedule/__init__.py
index d3d3eff5..3adc61c8 100644
--- a/torchopt/schedule/__init__.py
+++ b/torchopt/schedule/__init__.py
@@ -34,5 +34,4 @@
 from torchopt.schedule.exponential_decay import exponential_decay
 from torchopt.schedule.polynomial import linear_schedule, polynomial_schedule
 
-
 __all__ = ['exponential_decay', 'linear_schedule', 'polynomial_schedule']
diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py
index c19c54b9..c9469b8e 100644
--- a/torchopt/schedule/exponential_decay.py
+++ b/torchopt/schedule/exponential_decay.py
@@ -37,7 +37,6 @@
 import math
 from typing import TYPE_CHECKING
 
-
 if TYPE_CHECKING:
     from torchopt.typing import Numeric, Scalar, Schedule
 
diff --git a/torchopt/schedule/polynomial.py b/torchopt/schedule/polynomial.py
index d2a5160c..dc4104da 100644
--- a/torchopt/schedule/polynomial.py
+++ b/torchopt/schedule/polynomial.py
@@ -39,7 +39,6 @@
 import numpy as np
 import torch
 
-
 if TYPE_CHECKING:
     from torchopt.typing import Numeric, Scalar, Schedule
 
diff --git a/torchopt/transform/__init__.py b/torchopt/transform/__init__.py
index fa59a43b..c3dbf095 100644
--- a/torchopt/transform/__init__.py
+++ b/torchopt/transform/__init__.py
@@ -44,7 +44,6 @@
 from torchopt.transform.scale_by_stddev import scale_by_stddev
 from torchopt.transform.trace import trace
 
-
 __all__ = [
     'add_decayed_weights',
     'masked',
diff --git a/torchopt/transform/add_decayed_weights.py b/torchopt/transform/add_decayed_weights.py
index 0cb67837..c7984d57 100644
--- a/torchopt/transform/add_decayed_weights.py
+++ b/torchopt/transform/add_decayed_weights.py
@@ -40,7 +40,6 @@
 from torchopt.base import EmptyState, GradientTransformation, identity
 from torchopt.transform.utils import tree_map_flat, tree_map_flat_
 
-
 if TYPE_CHECKING:
     import torch
 
diff --git a/torchopt/transform/nan_to_num.py b/torchopt/transform/nan_to_num.py
index 740df1b0..481719e8 100644
--- a/torchopt/transform/nan_to_num.py
+++ b/torchopt/transform/nan_to_num.py
@@ -21,7 +21,6 @@
 from torchopt import pytree
 from torchopt.base import EmptyState, GradientTransformation
 
-
 if TYPE_CHECKING:
     import torch
 
diff --git a/torchopt/transform/scale.py b/torchopt/transform/scale.py
index 2b492bdf..86325ad1 100644
--- a/torchopt/transform/scale.py
+++ b/torchopt/transform/scale.py
@@ -39,7 +39,6 @@
 from torchopt.base import EmptyState, GradientTransformation
 from torchopt.transform.utils import tree_map_flat, tree_map_flat_
 
-
 if TYPE_CHECKING:
     import torch
 
diff --git a/torchopt/transform/scale_by_adadelta.py b/torchopt/transform/scale_by_adadelta.py
index 6d05e5dd..28518bab 100644
--- a/torchopt/transform/scale_by_adadelta.py
+++ b/torchopt/transform/scale_by_adadelta.py
@@ -27,7 +27,6 @@
 from torchopt.base import GradientTransformation
 from torchopt.transform.utils import tree_map_flat, update_moment
 
-
 if TYPE_CHECKING:
     from torchopt.typing import OptState, Params, Updates
 
diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py
index d45d1eb2..664915bc 100644
--- a/torchopt/transform/scale_by_adam.py
+++ b/torchopt/transform/scale_by_adam.py
@@ -44,7 +44,6 @@
 from torchopt.base import GradientTransformation
 from torchopt.transform.utils import inc_count, tree_map_flat, update_moment
 
-
 if TYPE_CHECKING:
     from torchopt.typing import OptState, Params, Updates
 
--git a/torchopt/transform/scale_by_adamax.py b/torchopt/transform/scale_by_adamax.py index cfacbf35..bf855e43 100644 --- a/torchopt/transform/scale_by_adamax.py +++ b/torchopt/transform/scale_by_adamax.py @@ -27,7 +27,6 @@ from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, update_moment - if TYPE_CHECKING: from torchopt.typing import OptState, Params, Updates diff --git a/torchopt/transform/scale_by_radam.py b/torchopt/transform/scale_by_radam.py index 95f26149..9132c69b 100644 --- a/torchopt/transform/scale_by_radam.py +++ b/torchopt/transform/scale_by_radam.py @@ -28,7 +28,6 @@ from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, update_moment - if TYPE_CHECKING: from torchopt.typing import OptState, Params, Updates diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py index f2141388..95ed7935 100644 --- a/torchopt/transform/scale_by_rms.py +++ b/torchopt/transform/scale_by_rms.py @@ -41,7 +41,6 @@ from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, tree_map_flat_, update_moment - if TYPE_CHECKING: from torchopt.typing import OptState, Params, Updates diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index 642b2e5c..0c5ef3e2 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -41,7 +41,6 @@ from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, update_moment - if TYPE_CHECKING: from torchopt.typing import OptState, Params, Updates diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py index 499e2adb..95a8b6b2 100644 --- a/torchopt/transform/scale_by_schedule.py +++ b/torchopt/transform/scale_by_schedule.py @@ -41,7 +41,6 @@ from torchopt.base import GradientTransformation from torchopt.transform.utils import inc_count, tree_map_flat, tree_map_flat_ - if TYPE_CHECKING: from torchopt.typing import Numeric, OptState, Params, Schedule, SequenceOfTensors, Updates diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py index 5a3e6655..48afeeff 100644 --- a/torchopt/transform/scale_by_stddev.py +++ b/torchopt/transform/scale_by_stddev.py @@ -43,7 +43,6 @@ from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, tree_map_flat_, update_moment - if TYPE_CHECKING: from torchopt.typing import OptState, Params, Updates diff --git a/torchopt/transform/trace.py b/torchopt/transform/trace.py index 219cbbec..9c87880b 100644 --- a/torchopt/transform/trace.py +++ b/torchopt/transform/trace.py @@ -43,7 +43,6 @@ from torchopt.base import GradientTransformation, identity from torchopt.transform.utils import tree_map_flat, tree_map_flat_ - if TYPE_CHECKING: from torchopt.typing import OptState, Params, Updates diff --git a/torchopt/transform/utils.py b/torchopt/transform/utils.py index 9b38d561..cb9e99fc 100644 --- a/torchopt/transform/utils.py +++ b/torchopt/transform/utils.py @@ -40,7 +40,6 @@ from torchopt import pytree - if TYPE_CHECKING: from torchopt.typing import TensorTree, Updates diff --git a/torchopt/typing.py b/torchopt/typing.py index fcd888fb..425941ac 100644 --- a/torchopt/typing.py +++ b/torchopt/typing.py @@ -45,7 +45,6 @@ UninitializedState, ) - __all__ = [ 'ChainedGradientTransformation', 'Device', diff --git a/torchopt/update.py b/torchopt/update.py index 3f2d71fe..3ebe7199 
100644 --- a/torchopt/update.py +++ b/torchopt/update.py @@ -37,7 +37,6 @@ from torchopt import pytree - if TYPE_CHECKING: import torch diff --git a/torchopt/utils.py b/torchopt/utils.py index 5f9202a3..8d098892 100644 --- a/torchopt/utils.py +++ b/torchopt/utils.py @@ -27,7 +27,6 @@ from torchopt import pytree from torchopt.typing import Device, ModuleTensorContainers, OptState, TensorContainer, TensorTree - if TYPE_CHECKING: from torchopt.optim.meta.base import MetaOptimizer @@ -79,13 +78,13 @@ def fn_(obj: Any) -> None: obj.detach_().requires_grad_(requires_grad) if isinstance(target, ModuleState): - true_target = cast(TensorTree, (target.params, target.buffers)) + true_target = cast('TensorTree', (target.params, target.buffers)) elif isinstance(target, nn.Module): - true_target = cast(TensorTree, tuple(target.parameters())) + true_target = cast('TensorTree', tuple(target.parameters())) elif isinstance(target, MetaOptimizer): - true_target = cast(TensorTree, target.state_dict()) + true_target = cast('TensorTree', target.state_dict()) else: - true_target = cast(TensorTree, target) # tree of tensors + true_target = cast('TensorTree', target) # tree of tensors pytree.tree_map_(fn_, true_target) @@ -325,7 +324,7 @@ def recover_state_dict( from torchopt.optim.meta.base import MetaOptimizer if isinstance(target, nn.Module): - params, buffers, *_ = state = cast(ModuleState, state) + params, buffers, *_ = state = cast('ModuleState', state) params_containers, buffers_containers = extract_module_containers(target, with_buffers=True) if state.detach_buffers: @@ -343,7 +342,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor: ): tgt.update(src) elif isinstance(target, MetaOptimizer): - state = cast(Sequence[OptState], state) + state = cast('Sequence[OptState]', state) target.load_state_dict(state) else: raise TypeError(f'Unexpected class of {target}') @@ -422,9 +421,9 @@ def module_clone( # noqa: C901 if isinstance(target, (nn.Module, MetaOptimizer)): if isinstance(target, nn.Module): - containers = cast(TensorTree, extract_module_containers(target, with_buffers=True)) + containers = cast('TensorTree', extract_module_containers(target, with_buffers=True)) else: - containers = cast(TensorTree, target.state_dict()) + containers = cast('TensorTree', target.state_dict()) tensors = pytree.tree_leaves(containers) memo = {id(t): t for t in tensors} cloned = copy.deepcopy(target, memo=memo) @@ -476,7 +475,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor: else: replicate = clone_detach_ - return pytree.tree_map(replicate, cast(TensorTree, target)) + return pytree.tree_map(replicate, cast('TensorTree', target)) @overload diff --git a/torchopt/version.py b/torchopt/version.py index 9fdcac9b..69aff7da 100644 --- a/torchopt/version.py +++ b/torchopt/version.py @@ -25,7 +25,7 @@ try: prefix, sep, suffix = ( - subprocess.check_output( # noqa: S603 + subprocess.check_output( ['git', 'describe', '--abbrev=7'], # noqa: S607 cwd=os.path.dirname(os.path.abspath(__file__)), stderr=subprocess.DEVNULL, diff --git a/torchopt/visual.py b/torchopt/visual.py index 7638d7ec..c3a462f1 100644 --- a/torchopt/visual.py +++ b/torchopt/visual.py @@ -27,7 +27,6 @@ from torchopt import pytree from torchopt.utils import ModuleState - if TYPE_CHECKING: from torchopt.typing import TensorTree @@ -129,7 +128,7 @@ def make_dot( # noqa: C901 elif isinstance(param, Generator): param_map.update({v: k for k, v in param}) else: - param_map.update({v: k for k, v in cast(Mapping, param).items()}) + param_map.update({v: k for k, v in 
cast('Mapping', param).items()}) node_attr = { 'style': 'filled', diff --git a/tutorials/1_Functional_Optimizer.ipynb b/tutorials/1_Functional_Optimizer.ipynb index afc55f38..231bceff 100644 --- a/tutorials/1_Functional_Optimizer.ipynb +++ b/tutorials/1_Functional_Optimizer.ipynb @@ -1,588 +1,588 @@ { "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# TorchOpt as Functional Optimizer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/1_Functional_Optimizer.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In this tutorial, we will introduce how TorchOpt can be treated as functional optimizer to conduct normal optimization with functional programming style. We will also illustrate how to conduct differentiable optimization with functional programming in PyTorch." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Basic API\n", - "\n", - "In this first part, we will illustrate how TorchOpt can be used as a functional optimizer. We compare it with different API in [JAX](https://github.com/google/jax) and [PyTorch](https://pytorch.org) to help understand the similarity and dissimilarity. We use simple network, Adam optimizer and MSE loss objective." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from collections import OrderedDict\n", - "\n", - "import functorch\n", - "import jax\n", - "import jax.numpy as jnp\n", - "import optax\n", - "import torch\n", - "import torch.autograd\n", - "import torch.nn as nn\n", - "\n", - "import torchopt\n", - "\n", - "\n", - "class Net(nn.Module):\n", - " def __init__(self, dim):\n", - " super().__init__()\n", - " self.fc = nn.Linear(dim, 1, bias=True)\n", - " nn.init.ones_(self.fc.weight)\n", - " nn.init.zeros_(self.fc.bias)\n", - "\n", - " def forward(self, x):\n", - " return self.fc(x)\n", - "\n", - "\n", - "def mse(inputs, targets):\n", - " return ((inputs - targets) ** 2).mean()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1.1 Original JAX implementation\n", - "\n", - "The first example is JAX implementation coupled with [Optax](https://github.com/deepmind/optax), which belongs to functional programming style." 
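A note on the `cast(TensorTree, ...)` to `cast('TensorTree', ...)` changes in `torchopt/utils.py` and `torchopt/visual.py` above: the string form matters because these names are imported only under `if TYPE_CHECKING:` and therefore do not exist at runtime. A minimal sketch of the pattern (the `as_tensor_tree` helper is illustrative and not part of the repository):

```python
from typing import TYPE_CHECKING, Any, cast

if TYPE_CHECKING:
    # Visible to static type checkers only; this import is skipped at runtime.
    from torchopt.typing import TensorTree


def as_tensor_tree(obj: Any) -> 'TensorTree':
    # cast() evaluates its first argument, so cast(TensorTree, obj) would raise
    # NameError at runtime here; the quoted form stays a plain string that only
    # the type checker resolves.
    return cast('TensorTree', obj)
```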
- ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "def origin_jax():\n", - " batch_size = 1\n", - " dim = 1\n", - " params = OrderedDict([('weight', jnp.ones((dim, 1))), ('bias', jnp.zeros((1,)))])\n", - "\n", - " def model(params, x):\n", - " return jnp.matmul(x, params['weight']) + params['bias']\n", - "\n", - " # Obtain the `opt_state` that contains statistics for the optimizer\n", - " learning_rate = 1.0\n", - " optimizer = optax.adam(learning_rate)\n", - " opt_state = optimizer.init(params)\n", - "\n", - " def compute_loss(params, x, y):\n", - " pred = model(params, x)\n", - " return mse(pred, y)\n", - "\n", - " xs = 2 * jnp.ones((batch_size, dim))\n", - " ys = jnp.ones((batch_size, 1))\n", - "\n", - " grads = jax.grad(compute_loss)(params, xs, ys)\n", - " updates, opt_state = optimizer.update(grads, opt_state)\n", - "\n", - " print('Parameters before update:', params)\n", - " params = optax.apply_updates(params, updates)\n", - " print('Parameters after update:', params)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameters before update:\n", - "OrderedDict([\n", - " ('weight', DeviceArray([[1.]], dtype=float32)),\n", - " ('bias', DeviceArray([0.], dtype=float32))\n", - "])\n", - "Parameters after update:\n", - "OrderedDict([\n", - " ('weight', DeviceArray([[6.735325e-06]], dtype=float32)),\n", - " ('bias', DeviceArray([-0.99999326], dtype=float32))\n", - "])\n" - ] - } - ], - "source": [ - "origin_jax()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1.2 `functorch` with TorchOpt\n", - "\n", - "The second example is [`functorch`](https://pytorch.org/functorch) coupled with TorchOpt. It basically follows the same structure with the JAX example." 
- ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "def interact_with_functorch():\n", - " batch_size = 1\n", - " dim = 1\n", - " net = Net(dim)\n", - " model, params = functorch.make_functional(net) # get the functional version of the model\n", - "\n", - " # Obtain the `opt_state` that contains statistics for the optimizer\n", - " learning_rate = 1.0\n", - " optimizer = torchopt.adam(learning_rate)\n", - " opt_state = optimizer.init(params)\n", - "\n", - " xs = 2 * torch.ones((batch_size, dim))\n", - " ys = torch.ones((batch_size, 1))\n", - "\n", - " pred = model(params, xs)\n", - " loss = mse(pred, ys)\n", - "\n", - " grads = torch.autograd.grad(loss, params)\n", - " updates, opt_state = optimizer.update(grads, opt_state)\n", - "\n", - " print('Parameters before update:', params)\n", - " params = torchopt.apply_updates(params, updates)\n", - " print('Parameters after update:', params)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameters before update:\n", - "(\n", - " Parameter containing: tensor([[1.]], requires_grad=True),\n", - " Parameter containing: tensor([0.], requires_grad=True)\n", - ")\n", - "Parameters after update:\n", - "(\n", - " Parameter containing: tensor([[6.6757e-06]], requires_grad=True),\n", - " Parameter containing: tensor([-1.0000], requires_grad=True)\n", - ")\n" - ] - } - ], - "source": [ - "interact_with_functorch()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "TorchOpt also offers a wrapper `torchopt.FuncOptimizer` to make it easier to maintain the optimizer states." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "def interact_with_functorch_with_wrapper():\n", - " batch_size = 1\n", - " dim = 1\n", - " net = Net(dim)\n", - " model, params = functorch.make_functional(net) # get the functional version of the model\n", - "\n", - " learning_rate = 1.0\n", - " optimizer = torchopt.FuncOptimizer(torchopt.adam(learning_rate))\n", - "\n", - " xs = 2 * torch.ones((batch_size, dim))\n", - " ys = torch.ones((batch_size, 1))\n", - "\n", - " pred = model(params, xs)\n", - " loss = mse(pred, ys)\n", - "\n", - " print('Parameters before update:', params)\n", - " params = optimizer.step(loss, params)\n", - " print('Parameters after update:', params)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameters before update:\n", - "(\n", - " Parameter containing: tensor([[1.]], requires_grad=True),\n", - " Parameter containing: tensor([0.], requires_grad=True)\n", - ")\n", - "Parameters after update:\n", - "(\n", - " tensor([[6.6757e-06]], grad_fn=),\n", - " tensor([-1.0000], grad_fn=)\n", - ")\n" - ] - } - ], - "source": [ - "interact_with_functorch_with_wrapper()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1.3 Full TorchOpt\n", - "\n", - "`torchopt.Optimizer` is the base class for our PyTorch-like optimizer. Combined with the functional optimizer `torchopt.sgd` and `torchopt.adam`, we can define our high-level API `torchopt.SGD` and `torchopt.Adam`. The third example is to illustrate that TorchOpt can also directly replace `torch.optim` with exactly the same usage. Note the API difference happens between `torchopt.adam()` and `torchopt.Adam()`." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "def full_torchopt():\n", - " batch_size = 1\n", - " dim = 1\n", - " net = Net(dim)\n", - "\n", - " learning_rate = 1.0\n", - " # High-level API\n", - " optim = torchopt.Adam(net.parameters(), lr=learning_rate)\n", - " # Low-level API\n", - " optim = torchopt.Optimizer(net.parameters(), torchopt.adam(lr=learning_rate))\n", - "\n", - " xs = 2 * torch.ones((batch_size, dim))\n", - " ys = torch.ones((batch_size, 1))\n", - "\n", - " pred = net(xs)\n", - " loss = mse(pred, ys)\n", - "\n", - " print('Parameters before update:', dict(net.named_parameters()))\n", - " optim.zero_grad()\n", - " loss.backward()\n", - " optim.step()\n", - " print('Parameters after update:', dict(net.named_parameters()))" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameters before update:\n", - "{\n", - " 'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),\n", - " 'fc.bias': Parameter containing: tensor([0.], requires_grad=True)\n", - "}\n", - "Parameters after update:\n", - "{\n", - " 'fc.weight': Parameter containing: tensor([[6.6757e-06]], requires_grad=True),\n", - " 'fc.bias': Parameter containing: tensor([-1.0000], requires_grad=True)\n", - "}\n" - ] - } - ], - "source": [ - "full_torchopt()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1.4 Original PyTorch\n", - "\n", - "The final example is to original PyTorch example with `torch.optim`." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "def origin_torch():\n", - " batch_size = 1\n", - " dim = 1\n", - " net = Net(dim)\n", - "\n", - " learning_rate = 1.0\n", - " optim = torch.optim.Adam(net.parameters(), lr=learning_rate)\n", - "\n", - " xs = 2 * torch.ones((batch_size, dim))\n", - " ys = torch.ones((batch_size, 1))\n", - "\n", - " pred = net(xs)\n", - " loss = mse(pred, ys)\n", - "\n", - " print('Parameters before update:', dict(net.named_parameters()))\n", - " optim.zero_grad()\n", - " loss.backward()\n", - " optim.step()\n", - " print('Parameters after update:', dict(net.named_parameters()))" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameters before update:\n", - "{\n", - " 'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),\n", - " 'fc.bias': Parameter containing: tensor([0.], requires_grad=True)\n", - "}\n", - "Parameters after update:\n", - "{\n", - " 'fc.weight': Parameter containing: tensor([[1.1921e-07]], requires_grad=True),\n", - " 'fc.bias': Parameter containing: tensor([-1.0000], requires_grad=True)\n", - "}\n" - ] - } - ], - "source": [ - "origin_torch()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. Differentiable Optimization with Functional Optimizer\n", - "\n", - "Coupled with functional optimizer, you can conduct differentiable optimization by setting the `inplace` flag as `False` in update and `apply_updates` function. (which might be helpful for meta-learning algorithm implementation with functional programming style). \n", - "\n", - "Note that `torchopt.SGD` and `torchopt.Adam` do not support differentiable optimization. Refer to the Meta-Optimizer notebook for PyTorch-like differentiable optimizers." 
- ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "def differentiable():\n", - " batch_size = 1\n", - " dim = 1\n", - " net = Net(dim)\n", - " model, params = functorch.make_functional(net) # get the functional version of the model\n", - "\n", - " # Meta-parameter\n", - " meta_param = nn.Parameter(torch.ones(1))\n", - "\n", - " # SGD example\n", - " learning_rate = 1.0\n", - " optimizer = torchopt.sgd(learning_rate)\n", - " opt_state = optimizer.init(params)\n", - "\n", - " xs = torch.ones((batch_size, dim))\n", - " ys = torch.ones((batch_size, 1))\n", - "\n", - " pred = model(params, xs)\n", - " # Where meta_param is used\n", - " pred = pred + meta_param\n", - " loss = mse(pred, ys)\n", - "\n", - " grads = torch.autograd.grad(loss, params, create_graph=True)\n", - " updates, opt_state = optimizer.update(grads, opt_state, inplace=False)\n", - " # Update parameters with single step SGD update\n", - " params = torchopt.apply_updates(params, updates, inplace=False)\n", - "\n", - " pred = model(params, xs)\n", - " loss = mse(pred, ys)\n", - " loss.backward()\n", - "\n", - " print('Gradient for the meta-parameter:', meta_param.grad)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Gradient for the meta-parameter: tensor([32.])\n" - ] - } - ], - "source": [ - "differentiable()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2.1 Track the Gradient of Momentum\n", - "\n", - "Note that most modern optimizers involve momentum term in the gradient update (basically only SGD with `momentum = 0` does not involve). We provide an option for user to choose whether to also track the meta-gradient through momentum term. The default option is `moment_requires_grad=True`." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "optim = torchopt.adam(lr=1.0, moment_requires_grad=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "optim = torchopt.adam(lr=1.0, moment_requires_grad=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "optim = torchopt.sgd(lr=1.0, momentum=0.8, moment_requires_grad=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. Accelerated Optimizer\n", - "\n", - "Users can use accelerated optimizer by setting the `use_accelerated_op` as `True`. Currently we only support the Adam optimizer." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Check whether the `accelerated_op` is available:" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n" - ] + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# TorchOpt as Functional Optimizer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/1_Functional_Optimizer.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this tutorial, we will introduce how TorchOpt can be treated as functional optimizer to conduct normal optimization with functional programming style. 
We will also illustrate how to conduct differentiable optimization with functional programming in PyTorch." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Basic API\n", + "\n", + "In this first part, we will illustrate how TorchOpt can be used as a functional optimizer. We compare it with different API in [JAX](https://github.com/google/jax) and [PyTorch](https://pytorch.org) to help understand the similarity and dissimilarity. We use simple network, Adam optimizer and MSE loss objective." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from collections import OrderedDict\n", + "\n", + "import functorch\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import optax\n", + "import torch\n", + "import torch.autograd\n", + "import torch.nn as nn\n", + "\n", + "import torchopt\n", + "\n", + "\n", + "class Net(nn.Module):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.fc = nn.Linear(dim, 1, bias=True)\n", + " nn.init.ones_(self.fc.weight)\n", + " nn.init.zeros_(self.fc.bias)\n", + "\n", + " def forward(self, x):\n", + " return self.fc(x)\n", + "\n", + "\n", + "def mse(inputs, targets):\n", + " return ((inputs - targets) ** 2).mean()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.1 Original JAX implementation\n", + "\n", + "The first example is JAX implementation coupled with [Optax](https://github.com/deepmind/optax), which belongs to functional programming style." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def origin_jax():\n", + " batch_size = 1\n", + " dim = 1\n", + " params = OrderedDict([('weight', jnp.ones((dim, 1))), ('bias', jnp.zeros((1,)))])\n", + "\n", + " def model(params, x):\n", + " return jnp.matmul(x, params['weight']) + params['bias']\n", + "\n", + " # Obtain the `opt_state` that contains statistics for the optimizer\n", + " learning_rate = 1.0\n", + " optimizer = optax.adam(learning_rate)\n", + " opt_state = optimizer.init(params)\n", + "\n", + " def compute_loss(params, x, y):\n", + " pred = model(params, x)\n", + " return mse(pred, y)\n", + "\n", + " xs = 2 * jnp.ones((batch_size, dim))\n", + " ys = jnp.ones((batch_size, 1))\n", + "\n", + " grads = jax.grad(compute_loss)(params, xs, ys)\n", + " updates, opt_state = optimizer.update(grads, opt_state)\n", + "\n", + " print('Parameters before update:', params)\n", + " params = optax.apply_updates(params, updates)\n", + " print('Parameters after update:', params)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Parameters before update:\n", + "OrderedDict([\n", + " ('weight', DeviceArray([[1.]], dtype=float32)),\n", + " ('bias', DeviceArray([0.], dtype=float32))\n", + "])\n", + "Parameters after update:\n", + "OrderedDict([\n", + " ('weight', DeviceArray([[6.735325e-06]], dtype=float32)),\n", + " ('bias', DeviceArray([-0.99999326], dtype=float32))\n", + "])\n" + ] + } + ], + "source": [ + "origin_jax()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.2 `functorch` with TorchOpt\n", + "\n", + "The second example is [`functorch`](https://pytorch.org/functorch) coupled with TorchOpt. It basically follows the same structure with the JAX example." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def interact_with_functorch():\n", + " batch_size = 1\n", + " dim = 1\n", + " net = Net(dim)\n", + " model, params = functorch.make_functional(net) # get the functional version of the model\n", + "\n", + " # Obtain the `opt_state` that contains statistics for the optimizer\n", + " learning_rate = 1.0\n", + " optimizer = torchopt.adam(learning_rate)\n", + " opt_state = optimizer.init(params)\n", + "\n", + " xs = 2 * torch.ones((batch_size, dim))\n", + " ys = torch.ones((batch_size, 1))\n", + "\n", + " pred = model(params, xs)\n", + " loss = mse(pred, ys)\n", + "\n", + " grads = torch.autograd.grad(loss, params)\n", + " updates, opt_state = optimizer.update(grads, opt_state)\n", + "\n", + " print('Parameters before update:', params)\n", + " params = torchopt.apply_updates(params, updates)\n", + " print('Parameters after update:', params)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Parameters before update:\n", + "(\n", + " Parameter containing: tensor([[1.]], requires_grad=True),\n", + " Parameter containing: tensor([0.], requires_grad=True)\n", + ")\n", + "Parameters after update:\n", + "(\n", + " Parameter containing: tensor([[6.6757e-06]], requires_grad=True),\n", + " Parameter containing: tensor([-1.0000], requires_grad=True)\n", + ")\n" + ] + } + ], + "source": [ + "interact_with_functorch()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "TorchOpt also offers a wrapper `torchopt.FuncOptimizer` to make it easier to maintain the optimizer states." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def interact_with_functorch_with_wrapper():\n", + " batch_size = 1\n", + " dim = 1\n", + " net = Net(dim)\n", + " model, params = functorch.make_functional(net) # get the functional version of the model\n", + "\n", + " learning_rate = 1.0\n", + " optimizer = torchopt.FuncOptimizer(torchopt.adam(learning_rate))\n", + "\n", + " xs = 2 * torch.ones((batch_size, dim))\n", + " ys = torch.ones((batch_size, 1))\n", + "\n", + " pred = model(params, xs)\n", + " loss = mse(pred, ys)\n", + "\n", + " print('Parameters before update:', params)\n", + " params = optimizer.step(loss, params)\n", + " print('Parameters after update:', params)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Parameters before update:\n", + "(\n", + " Parameter containing: tensor([[1.]], requires_grad=True),\n", + " Parameter containing: tensor([0.], requires_grad=True)\n", + ")\n", + "Parameters after update:\n", + "(\n", + " tensor([[6.6757e-06]], grad_fn=),\n", + " tensor([-1.0000], grad_fn=)\n", + ")\n" + ] + } + ], + "source": [ + "interact_with_functorch_with_wrapper()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.3 Full TorchOpt\n", + "\n", + "`torchopt.Optimizer` is the base class for our PyTorch-like optimizer. Combined with the functional optimizer `torchopt.sgd` and `torchopt.adam`, we can define our high-level API `torchopt.SGD` and `torchopt.Adam`. The third example is to illustrate that TorchOpt can also directly replace `torch.optim` with exactly the same usage. Note the API difference happens between `torchopt.adam()` and `torchopt.Adam()`." 
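To make the `torchopt.adam()` vs. `torchopt.Adam()` distinction above concrete, here is a minimal side-by-side sketch; a bare `nn.Linear` stands in for the notebook's `Net`, and the API usage mirrors the cells in this tutorial rather than adding anything new:

```python
import torch
import torch.nn as nn

import torchopt

net = nn.Linear(1, 1)
xs, ys = torch.ones(1, 1), torch.ones(1, 1)

# Object-oriented API: torchopt.Adam mirrors torch.optim.Adam.
optim = torchopt.Adam(net.parameters(), lr=1.0)
loss = ((net(xs) - ys) ** 2).mean()
optim.zero_grad()
loss.backward()
optim.step()

# Functional API: torchopt.adam() returns a gradient transformation whose
# init/update steps are driven explicitly and composed with apply_updates.
params = tuple(net.parameters())
optimizer = torchopt.adam(lr=1.0)
opt_state = optimizer.init(params)
loss = ((net(xs) - ys) ** 2).mean()
grads = torch.autograd.grad(loss, params)
updates, opt_state = optimizer.update(grads, opt_state)
params = torchopt.apply_updates(params, updates)
```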
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def full_torchopt():\n", + " batch_size = 1\n", + " dim = 1\n", + " net = Net(dim)\n", + "\n", + " learning_rate = 1.0\n", + " # High-level API\n", + " optim = torchopt.Adam(net.parameters(), lr=learning_rate)\n", + " # Low-level API\n", + " optim = torchopt.Optimizer(net.parameters(), torchopt.adam(lr=learning_rate))\n", + "\n", + " xs = 2 * torch.ones((batch_size, dim))\n", + " ys = torch.ones((batch_size, 1))\n", + "\n", + " pred = net(xs)\n", + " loss = mse(pred, ys)\n", + "\n", + " print('Parameters before update:', dict(net.named_parameters()))\n", + " optim.zero_grad()\n", + " loss.backward()\n", + " optim.step()\n", + " print('Parameters after update:', dict(net.named_parameters()))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Parameters before update:\n", + "{\n", + " 'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),\n", + " 'fc.bias': Parameter containing: tensor([0.], requires_grad=True)\n", + "}\n", + "Parameters after update:\n", + "{\n", + " 'fc.weight': Parameter containing: tensor([[6.6757e-06]], requires_grad=True),\n", + " 'fc.bias': Parameter containing: tensor([-1.0000], requires_grad=True)\n", + "}\n" + ] + } + ], + "source": [ + "full_torchopt()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.4 Original PyTorch\n", + "\n", + "The final example is to original PyTorch example with `torch.optim`." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "def origin_torch():\n", + " batch_size = 1\n", + " dim = 1\n", + " net = Net(dim)\n", + "\n", + " learning_rate = 1.0\n", + " optim = torch.optim.Adam(net.parameters(), lr=learning_rate)\n", + "\n", + " xs = 2 * torch.ones((batch_size, dim))\n", + " ys = torch.ones((batch_size, 1))\n", + "\n", + " pred = net(xs)\n", + " loss = mse(pred, ys)\n", + "\n", + " print('Parameters before update:', dict(net.named_parameters()))\n", + " optim.zero_grad()\n", + " loss.backward()\n", + " optim.step()\n", + " print('Parameters after update:', dict(net.named_parameters()))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Parameters before update:\n", + "{\n", + " 'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),\n", + " 'fc.bias': Parameter containing: tensor([0.], requires_grad=True)\n", + "}\n", + "Parameters after update:\n", + "{\n", + " 'fc.weight': Parameter containing: tensor([[1.1921e-07]], requires_grad=True),\n", + " 'fc.bias': Parameter containing: tensor([-1.0000], requires_grad=True)\n", + "}\n" + ] + } + ], + "source": [ + "origin_torch()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Differentiable Optimization with Functional Optimizer\n", + "\n", + "Coupled with functional optimizer, you can conduct differentiable optimization by setting the `inplace` flag as `False` in update and `apply_updates` function. (which might be helpful for meta-learning algorithm implementation with functional programming style). \n", + "\n", + "Note that `torchopt.SGD` and `torchopt.Adam` do not support differentiable optimization. Refer to the Meta-Optimizer notebook for PyTorch-like differentiable optimizers." 
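For reference, the meta-gradient printed by `differentiable()` in the next cell can be checked by hand. With one data point $x = y = 1$, initial weight $w = 1$, bias $b = 0$, meta-parameter $m = 1$, a single SGD step with learning rate $1$, and $r$ denoting the inner residual:

$$
r = w + b + m - y = 1, \qquad w' = w - 2 r x = -1, \qquad b' = b - 2 r = -2,
$$

$$
\frac{\partial w'}{\partial m} = -2 x^2 = -2, \quad
\frac{\partial b'}{\partial m} = -2, \quad
\frac{\partial}{\partial m} \big( w' x + b' - y \big)^2
= 2 \, (w' x + b' - y) \Big( x \, \frac{\partial w'}{\partial m} + \frac{\partial b'}{\partial m} \Big)
= 2 \cdot (-4) \cdot (-4) = 32,
$$

which matches the reported `tensor([32.])`.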
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "def differentiable():\n", + " batch_size = 1\n", + " dim = 1\n", + " net = Net(dim)\n", + " model, params = functorch.make_functional(net) # get the functional version of the model\n", + "\n", + " # Meta-parameter\n", + " meta_param = nn.Parameter(torch.ones(1))\n", + "\n", + " # SGD example\n", + " learning_rate = 1.0\n", + " optimizer = torchopt.sgd(learning_rate)\n", + " opt_state = optimizer.init(params)\n", + "\n", + " xs = torch.ones((batch_size, dim))\n", + " ys = torch.ones((batch_size, 1))\n", + "\n", + " pred = model(params, xs)\n", + " # Where meta_param is used\n", + " pred = pred + meta_param\n", + " loss = mse(pred, ys)\n", + "\n", + " grads = torch.autograd.grad(loss, params, create_graph=True)\n", + " updates, opt_state = optimizer.update(grads, opt_state, inplace=False)\n", + " # Update parameters with single step SGD update\n", + " params = torchopt.apply_updates(params, updates, inplace=False)\n", + "\n", + " pred = model(params, xs)\n", + " loss = mse(pred, ys)\n", + " loss.backward()\n", + "\n", + " print('Gradient for the meta-parameter:', meta_param.grad)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Gradient for the meta-parameter: tensor([32.])\n" + ] + } + ], + "source": [ + "differentiable()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.1 Track the Gradient of Momentum\n", + "\n", + "Note that most modern optimizers involve momentum term in the gradient update (basically only SGD with `momentum = 0` does not involve). We provide an option for user to choose whether to also track the meta-gradient through momentum term. The default option is `moment_requires_grad=True`." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "optim = torchopt.adam(lr=1.0, moment_requires_grad=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "optim = torchopt.adam(lr=1.0, moment_requires_grad=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "optim = torchopt.sgd(lr=1.0, momentum=0.8, moment_requires_grad=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Accelerated Optimizer\n", + "\n", + "Users can use accelerated optimizer by setting the `use_accelerated_op` as `True`. Currently we only support the Adam optimizer." 
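A small sketch of one way to opt into the fused op only when it is supported on the target device; the conditional wrapper is illustrative, and only `accelerated_op_available` and `use_accelerated_op` come from the notebook itself:

```python
import torch

import torchopt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Query support for the fused (accelerated) Adam op on this device and
# fall back to the default implementation when it is unavailable.
use_fused = torchopt.accelerated_op_available(device)
optim = torchopt.adam(lr=1.0, use_accelerated_op=use_fused)
```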
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Check whether the `accelerated_op` is available:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True\n" + ] + } + ], + "source": [ + "torchopt.accelerated_op_available(torch.device('cpu'))" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True\n" + ] + } + ], + "source": [ + "torchopt.accelerated_op_available(torch.device('cuda'))" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "net = Net(1).cuda()\n", + "optim = torchopt.Adam(net.parameters(), lr=1.0, use_accelerated_op=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "optim = torchopt.adam(lr=1.0, use_accelerated_op=True)" + ] } - ], - "source": [ - "torchopt.accelerated_op_available(torch.device('cpu'))" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n" - ] + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + }, + "vscode": { + "interpreter": { + "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" + } } - ], - "source": [ - "torchopt.accelerated_op_available(torch.device('cuda'))" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "net = Net(1).cuda()\n", - "optim = torchopt.Adam(net.parameters(), lr=1.0, use_accelerated_op=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [], - "source": [ - "optim = torchopt.adam(lr=1.0, use_accelerated_op=True)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.15" }, - "vscode": { - "interpreter": { - "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" - } - } - }, - "nbformat": 4, - "nbformat_minor": 4 + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/tutorials/2_Visualization.ipynb b/tutorials/2_Visualization.ipynb index dd58c48d..07ae9c00 100644 --- a/tutorials/2_Visualization.ipynb +++ b/tutorials/2_Visualization.ipynb @@ -1,5 +1,5 @@ { - "cells": [ + "cells": [ { "cell_type": "markdown", "metadata": {}, @@ -57,7 +57,6 @@ "\n", "import torchopt\n", "\n", - "\n", "x = torch.tensor(1.0, requires_grad=True)\n", "y = 2 * x\n", "display(torchopt.visual.make_dot(y, params={'x': x, 'y': y}))" @@ -181,8 +180,9 @@ "# Draw computation graph\n", "display(\n", " torchopt.visual.make_dot(\n", - " loss, [net_state_0, net_state_1, {'meta_param': meta_param, 'loss': loss}]\n", - " )\n", + " loss,\n", + " [net_state_0, net_state_1, {'meta_param': meta_param, 
'loss': loss}],\n", + " ),\n", ")" ] } diff --git a/tutorials/3_Meta_Optimizer.ipynb b/tutorials/3_Meta_Optimizer.ipynb index 69be77ed..6c254f33 100644 --- a/tutorials/3_Meta_Optimizer.ipynb +++ b/tutorials/3_Meta_Optimizer.ipynb @@ -1,5 +1,5 @@ { - "cells": [ + "cells": [ { "cell_type": "markdown", "metadata": {}, @@ -200,8 +200,9 @@ "outer_loss = F.mse_loss(net(x), y)\n", "display(\n", " torchopt.visual.make_dot(\n", - " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", - " )\n", + " outer_loss,\n", + " params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}],\n", + " ),\n", ")" ] }, @@ -247,8 +248,9 @@ "outer_loss = F.mse_loss(net(x), y)\n", "display(\n", " torchopt.visual.make_dot(\n", - " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", - " )\n", + " outer_loss,\n", + " params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}],\n", + " ),\n", ")" ] }, @@ -513,21 +515,30 @@ "source": [ "functional_adam = torchopt.adam(\n", " lr=torchopt.schedule.linear_schedule(\n", - " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", - " )\n", + " init_value=1e-3,\n", + " end_value=1e-4,\n", + " transition_steps=10000,\n", + " transition_begin=2000,\n", + " ),\n", ")\n", "\n", "adam = torchopt.Adam(\n", " net.parameters(),\n", " lr=torchopt.schedule.linear_schedule(\n", - " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", + " init_value=1e-3,\n", + " end_value=1e-4,\n", + " transition_steps=10000,\n", + " transition_begin=2000,\n", " ),\n", ")\n", "\n", "meta_adam = torchopt.MetaAdam(\n", " net,\n", " lr=torchopt.schedule.linear_schedule(\n", - " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", + " init_value=1e-3,\n", + " end_value=1e-4,\n", + " transition_steps=10000,\n", + " transition_begin=2000,\n", " ),\n", ")" ] @@ -610,19 +621,26 @@ "optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=True, use_accelerated_op=True)\n", "\n", "net_state_0 = torchopt.extract_state_dict(\n", - " net, by='reference', enable_visual=True, visual_prefix='step0.'\n", + " net,\n", + " by='reference',\n", + " enable_visual=True,\n", + " visual_prefix='step0.',\n", ")\n", "inner_loss = F.mse_loss(net(x), y)\n", "optim.step(inner_loss)\n", "net_state_1 = torchopt.extract_state_dict(\n", - " net, by='reference', enable_visual=True, visual_prefix='step1.'\n", + " net,\n", + " by='reference',\n", + " enable_visual=True,\n", + " visual_prefix='step1.',\n", ")\n", "\n", "outer_loss = F.mse_loss(net(x), y)\n", "display(\n", " torchopt.visual.make_dot(\n", - " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", - " )\n", + " outer_loss,\n", + " params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}],\n", + " ),\n", ")" ] }, diff --git a/tutorials/4_Stop_Gradient.ipynb b/tutorials/4_Stop_Gradient.ipynb index d8c24bc6..d6f03aa9 100644 --- a/tutorials/4_Stop_Gradient.ipynb +++ b/tutorials/4_Stop_Gradient.ipynb @@ -1,5 +1,5 @@ { - "cells": [ + "cells": [ { "cell_type": "markdown", "metadata": {}, @@ -192,7 +192,7 @@ " one_step_net_state,\n", " {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n", " ),\n", - " )\n", + " ),\n", ")" ] }, @@ -393,7 +393,7 @@ " one_step_net_state,\n", " {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n", " ),\n", - " )\n", + " ),\n", ")\n", "\n", "# Outer update\n", @@ -457,7 +457,9 @@ "torchopt.stop_gradient(net)\n", 
"torchopt.stop_gradient(optim)\n", "one_step_net_state_detached = torchopt.extract_state_dict(\n", - " net, enable_visual=True, visual_prefix='step1.detached.'\n", + " net,\n", + " enable_visual=True,\n", + " visual_prefix='step1.detached.',\n", ")\n", "\n", "# Inner update\n", @@ -480,7 +482,7 @@ " one_step_net_state_detached,\n", " {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n", " ),\n", - " )\n", + " ),\n", ")" ] }, diff --git a/tutorials/5_Implicit_Differentiation.ipynb b/tutorials/5_Implicit_Differentiation.ipynb index 23407801..5f4d3357 100644 --- a/tutorials/5_Implicit_Differentiation.ipynb +++ b/tutorials/5_Implicit_Differentiation.ipynb @@ -1,576 +1,578 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "8850c832-3b54-4971-8ee0-2cd64b585ea8", - "metadata": {}, - "source": [ - "# TorchOpt for Implicit Differentiation" - ] - }, - { - "cell_type": "markdown", - "id": "2b547376", - "metadata": {}, - "source": [ - "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/5_Implicit_Differentiation.ipynb)" - ] - }, - { - "cell_type": "markdown", - "id": "8d7f9865-dc02-43d4-be90-da1160c4e4dd", - "metadata": {}, - "source": [ - "By treating the solution $\\phi^{\\star}$ as an implicit function of $\\theta$, the idea of implicit differentiation is to directly get analytical best-response derivatives $\\partial \\phi^{\\star}(\\theta)/ \\partial \\theta$ by implicit function theorem. This is suitable for algorithms when the inner-level optimal solution is achieved ${\\left. \\frac{\\partial F (\\phi, \\theta)}{\\partial \\phi} \\right\\rvert}_{\\phi = \\phi^{\\star}} = 0$ or reaches some stationary conditions $F (\\phi^{\\star}, \\theta) = 0$, such as [iMAML](https://arxiv.org/abs/1909.04630) and [DEQ](https://arxiv.org/abs/1909.01377)." - ] - }, - { - "cell_type": "markdown", - "id": "d7e4b9e1-115f-45ad-a9b3-ea338bcfe6dd", - "metadata": {}, - "source": [ - "In this tutorial, we will introduce how TorchOpt can be used to conduct implicit differentiation." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "8f13ae67-e328-409f-84a8-1fc425c03a66", - "metadata": {}, - "outputs": [], - "source": [ - "import functorch\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "\n", - "import torchopt" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "0cdaac49-4b94-4900-9bb5-a39057ac8b21", - "metadata": {}, - "source": [ - "## 1. Functional API\n", - "\n", - "The basic functional API is `torchopt.diff.implicit.custom_root`, which is used as the decorator for the forward process implicit gradient procedures. Users are required to implement the stationary conditions for the inner-loop process, which will be used as the input of custom_root decorator. 
We show the pseudo code in the following part.\n", - "\n", - "```python\n", - "# Functional API for implicit gradient\n", - "def stationary(params, meta_params, data):\n", - " # stationary condition construction\n", - " return stationary condition\n", - "\n", - "# Decorator that wraps the function\n", - "# Optionally specify the linear solver (conjugate gradient or Neumann series)\n", - "@torchopt.diff.implicit.custom_root(stationary, solve=linear_solver)\n", - "def solve(params, meta_params, data):\n", - " # Forward optimization process for params\n", - " return optimal_params\n", - "\n", - "# Define params, meta_params and get data\n", - "params, meta_prams, data = ..., ..., ...\n", - "optimal_params = solve(params, meta_params, data)\n", - "loss = outer_loss(optimal_params)\n", - "\n", - "meta_grads = torch.autograd.grad(loss, meta_params)\n", - "```" - ] - }, - { - "cell_type": "markdown", - "id": "dbef87df-2164-4f1d-8919-37a6fbdc5011", - "metadata": {}, - "source": [ - "Here we use the example of [iMAML](https://arxiv.org/abs/1909.04630) as a real example. For iMAML, the inner-loop objective is described by the following equation.\n", - "\n", - "$$\n", - "{\\mathcal{Alg}}^{\\star} \\left( \\boldsymbol{\\theta}, \\mathcal{D}_{i}^{\\text{tr}} \\right) = \\underset{\\phi'}{\\operatorname{\\arg \\min}} ~ G \\left( \\boldsymbol{\\phi}', \\boldsymbol{\\theta} \\right) \\triangleq \\mathcal{L} \\left( \\boldsymbol{\\phi}', \\mathcal{D}_{i}^{\\text{tr}} \\right) + \\frac{\\lambda}{2} {\\left\\| \\boldsymbol{\\phi}' - \\boldsymbol{\\theta} \\right\\|}^{2}\n", - "$$\n", - "\n", - "According to this function, we can define the forward function `inner_solver`, where we solve this equation based on sufficient gradient descents. For such inner-loop process, the optimality condition is that the gradient w.r.t inner-loop parameter is $0$.\n", - "\n", - "$$\n", - "{\\left. \\nabla_{\\boldsymbol{\\phi}'} G \\left( \\boldsymbol{\\phi}', \\boldsymbol{\\theta} \\right) \\right\\rvert}_{\\boldsymbol{\\phi}' = \\boldsymbol{\\phi}^{\\star}} = 0\n", - "$$\n", - "\n", - "Thus we can define the optimality function by defining `imaml_objective` and make it first-order gradient w.r.t the inner-loop parameter as $0$. We achieve so by calling out `functorch.grad(imaml_objective, argnums=0)`. Finally, the forward function is decorated by the `@torchopt.diff.implicit.custom_root` decorator and the optimality condition we define." 
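One detail worth spelling out before the code below: the regularization term is implemented as $0.5 \cdot {\lVert \boldsymbol{\phi}' - \boldsymbol{\theta} \rVert}^2$, i.e. $\lambda = 1$ in the objective above, so the optimality condition handed to `custom_root` reduces to

$$
\nabla_{\boldsymbol{\phi}'} G \left( \boldsymbol{\phi}', \boldsymbol{\theta} \right)
= \nabla_{\boldsymbol{\phi}'} \mathcal{L} \left( \boldsymbol{\phi}', \mathcal{D}_{i}^{\text{tr}} \right)
+ \left( \boldsymbol{\phi}' - \boldsymbol{\theta} \right) = 0 .
$$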
- ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "8d623b2f-48ee-4df6-a2ce-cf306b4c9067", - "metadata": {}, - "outputs": [], - "source": [ - "# Inner-loop objective function\n", - "# The optimality function: grad(imaml_objective)\n", - "def imaml_objective(params, meta_params, data):\n", - " x, y, fmodel = data\n", - " y_pred = fmodel(params, x)\n", - " regularization_loss = 0.0\n", - " for p1, p2 in zip(params, meta_params):\n", - " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", - " loss = F.mse_loss(y_pred, y) + regularization_loss\n", - " return loss\n", - "\n", - "\n", - "# Optimality Condition is: the gradient w.r.t inner-loop optimal params is 0 (we achieve so by\n", - "# specifying argnums=0 in functorch.grad) the argnums=1 specify which meta-parameter we want to\n", - "# backpropogate, in this case we want to backpropogate to the initial parameters so we set it as 1.\n", - "# You can also set argnums as (1, 2) if you want to backpropogate through multiple meta-parameters\n", - "\n", - "\n", - "# Here we pass argnums=1 to the custom_root. That means we want to compute the gradient of\n", - "# optimal_params w.r.t. the 1-indexed argument in inner_solver, i.e., params.\n", - "# torchopt.linear_solve.solve_normal_cg specify that we use the conjugate gradient based linear solver\n", - "@torchopt.diff.implicit.custom_root(\n", - " functorch.grad(imaml_objective, argnums=0), # optimality function\n", - " argnums=1,\n", - " solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", - ")\n", - "def inner_solver(params, meta_params, data):\n", - " # Initial functional optimizer based on TorchOpt\n", - " x, y, fmodel = data\n", - " optimizer = torchopt.sgd(lr=2e-2)\n", - " opt_state = optimizer.init(params)\n", - " with torch.enable_grad():\n", - " # Temporarily enable gradient computation for conducting the optimization\n", - " for i in range(100):\n", - " pred = fmodel(params, x)\n", - " loss = F.mse_loss(pred, y) # compute loss\n", - "\n", - " # Compute regularization loss\n", - " regularization_loss = 0.0\n", - " for p1, p2 in zip(params, meta_params):\n", - " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", - " final_loss = loss + regularization_loss\n", - "\n", - " grads = torch.autograd.grad(final_loss, params) # compute gradients\n", - " updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates\n", - " params = torchopt.apply_updates(params, updates, inplace=True)\n", - "\n", - " optimal_params = params\n", - " return optimal_params\n", - "\n", - "\n", - "# torchopt.linear_solve.solve_inv specify that we use the Neumann Series inversion linear solver\n", - "@torchopt.diff.implicit.custom_root(\n", - " functorch.grad(imaml_objective, argnums=0), # optimality function\n", - " argnums=1,\n", - " solve=torchopt.linear_solve.solve_inv(ns=True, maxiter=100, alpha=0.1),\n", - ")\n", - "def inner_solver_inv_ns(params, meta_params, data):\n", - " # Initial functional optimizer based on TorchOpt\n", - " x, y, fmodel = data\n", - " optimizer = torchopt.sgd(lr=2e-2)\n", - " opt_state = optimizer.init(params)\n", - " with torch.enable_grad():\n", - " # Temporarily enable gradient computation for conducting the optimization\n", - " for i in range(100):\n", - " pred = fmodel(params, x)\n", - " loss = F.mse_loss(pred, y) # compute loss\n", - "\n", - " # Compute regularization loss\n", - " regularization_loss = 0.0\n", - " for p1, p2 in zip(params, meta_params):\n", - " 
regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", - " final_loss = loss + regularization_loss\n", - "\n", - " grads = torch.autograd.grad(final_loss, params) # compute gradients\n", - " updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates\n", - " params = torchopt.apply_updates(params, updates, inplace=True)\n", - "\n", - " optimal_params = params\n", - " return optimal_params" - ] - }, - { - "cell_type": "markdown", - "id": "32a75c81-d479-4120-a73d-5b2b488358d0", - "metadata": {}, - "source": [ - "In the next step, we consider a specific case for one layer neural network to fit the linear data." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "fb95538b-1fd9-4ec8-9f57-6360bedc05b7", - "metadata": {}, - "outputs": [], - "source": [ - "torch.manual_seed(0)\n", - "x = torch.randn(20, 4)\n", - "w = torch.randn(4, 1)\n", - "b = torch.randn(1)\n", - "y = x @ w + b + 0.5 * torch.randn(20, 1)" - ] - }, - { - "cell_type": "markdown", - "id": "eeb1823a-2231-4471-bb68-cce7724f2578", - "metadata": {}, - "source": [ - "We instantiate an one layer neural network, where the weights and bias are initialized with constant." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "d50a7bfe-ac69-4089-8cf8-3cbd69d6d4e7", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "class Net(nn.Module):\n", - " def __init__(self, dim):\n", - " super().__init__()\n", - " self.fc = nn.Linear(dim, 1, bias=True)\n", - " nn.init.ones_(self.fc.weight)\n", - " nn.init.zeros_(self.fc.bias)\n", - "\n", - " def forward(self, x):\n", - " return self.fc(x)\n", - "\n", - "\n", - "model = Net(4)\n", - "fmodel, meta_params = functorch.make_functional(model)\n", - "data = (x, y, fmodel)\n", - "\n", - "\n", - "# Clone function for parameters\n", - "def clone(params):\n", - " cloned = []\n", - " for item in params:\n", - " if isinstance(item, torch.Tensor):\n", - " cloned.append(item.clone().detach_().requires_grad_(True))\n", - " else:\n", - " cloned.append(item)\n", - " return tuple(cloned)" - ] - }, - { - "cell_type": "markdown", - "id": "065c36c4-89e2-4a63-8213-63db6ee3b08e", - "metadata": {}, - "source": [ - "We take the forward process by calling out the forward function, then we pass the optimal params into the outer-loop loss function." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "115e79c6-911f-4743-a2ed-e50a71c3a813", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "optimal_params = inner_solver(clone(meta_params), meta_params, data)\n", - "\n", - "outer_loss = fmodel(optimal_params, x).mean()" - ] - }, - { - "cell_type": "markdown", - "id": "e2812351-f635-496e-9732-c80831ac04a6", - "metadata": {}, - "source": [ - "Finally, we can get the meta-gradient as shown below." 
- ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "6bdcbe8d-2336-4f80-b124-eb43c5a2fc0a", - "metadata": {}, - "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" - ] + "cells": + [ + { + "cell_type": "markdown", + "id": "8850c832-3b54-4971-8ee0-2cd64b585ea8", + "metadata": {}, + "source": [ + "# TorchOpt for Implicit Differentiation" + ] + }, + { + "cell_type": "markdown", + "id": "2b547376", + "metadata": {}, + "source": [ + "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/5_Implicit_Differentiation.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "8d7f9865-dc02-43d4-be90-da1160c4e4dd", + "metadata": {}, + "source": [ + "By treating the solution $\\phi^{\\star}$ as an implicit function of $\\theta$, the idea of implicit differentiation is to directly get analytical best-response derivatives $\\partial \\phi^{\\star}(\\theta)/ \\partial \\theta$ by implicit function theorem. This is suitable for algorithms when the inner-level optimal solution is achieved ${\\left. \\frac{\\partial F (\\phi, \\theta)}{\\partial \\phi} \\right\\rvert}_{\\phi = \\phi^{\\star}} = 0$ or reaches some stationary conditions $F (\\phi^{\\star}, \\theta) = 0$, such as [iMAML](https://arxiv.org/abs/1909.04630) and [DEQ](https://arxiv.org/abs/1909.01377)." + ] + }, + { + "cell_type": "markdown", + "id": "d7e4b9e1-115f-45ad-a9b3-ea338bcfe6dd", + "metadata": {}, + "source": [ + "In this tutorial, we will introduce how TorchOpt can be used to conduct implicit differentiation." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "8f13ae67-e328-409f-84a8-1fc425c03a66", + "metadata": {}, + "outputs": [], + "source": [ + "import functorch\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import torchopt" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "0cdaac49-4b94-4900-9bb5-a39057ac8b21", + "metadata": {}, + "source": [ + "## 1. Functional API\n", + "\n", + "The basic functional API is `torchopt.diff.implicit.custom_root`, which is used as the decorator for the forward process implicit gradient procedures. Users are required to implement the stationary conditions for the inner-loop process, which will be used as the input of custom_root decorator. We show the pseudo code in the following part.\n", + "\n", + "```python\n", + "# Functional API for implicit gradient\n", + "def stationary(params, meta_params, data):\n", + " # stationary condition construction\n", + " return stationary condition\n", + "\n", + "# Decorator that wraps the function\n", + "# Optionally specify the linear solver (conjugate gradient or Neumann series)\n", + "@torchopt.diff.implicit.custom_root(stationary, solve=linear_solver)\n", + "def solve(params, meta_params, data):\n", + " # Forward optimization process for params\n", + " return optimal_params\n", + "\n", + "# Define params, meta_params and get data\n", + "params, meta_prams, data = ..., ..., ...\n", + "optimal_params = solve(params, meta_params, data)\n", + "loss = outer_loss(optimal_params)\n", + "\n", + "meta_grads = torch.autograd.grad(loss, meta_params)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "dbef87df-2164-4f1d-8919-37a6fbdc5011", + "metadata": {}, + "source": [ + "Here we use the example of [iMAML](https://arxiv.org/abs/1909.04630) as a real example. 
For iMAML, the inner-loop objective is described by the following equation.\n", + "\n", + "$$\n", + "{\\mathcal{Alg}}^{\\star} \\left( \\boldsymbol{\\theta}, \\mathcal{D}_{i}^{\\text{tr}} \\right) = \\underset{\\phi'}{\\operatorname{\\arg \\min}} ~ G \\left( \\boldsymbol{\\phi}', \\boldsymbol{\\theta} \\right) \\triangleq \\mathcal{L} \\left( \\boldsymbol{\\phi}', \\mathcal{D}_{i}^{\\text{tr}} \\right) + \\frac{\\lambda}{2} {\\left\\| \\boldsymbol{\\phi}' - \\boldsymbol{\\theta} \\right\\|}^{2}\n", + "$$\n", + "\n", + "Based on this objective, we can define the forward function `inner_solver`, which solves the problem with a sufficient number of gradient-descent steps. For such an inner-loop process, the optimality condition is that the gradient w.r.t. the inner-loop parameters is $0$.\n", + "\n", + "$$\n", + "{\\left. \\nabla_{\\boldsymbol{\\phi}'} G \\left( \\boldsymbol{\\phi}', \\boldsymbol{\\theta} \\right) \\right\\rvert}_{\\boldsymbol{\\phi}' = \\boldsymbol{\\phi}^{\\star}} = 0\n", + "$$\n", + "\n", + "Thus we can obtain the optimality function by defining `imaml_objective` and taking its first-order gradient w.r.t. the inner-loop parameters, which we achieve by calling `functorch.grad(imaml_objective, argnums=0)`. Finally, the forward function is decorated by the `@torchopt.diff.implicit.custom_root` decorator together with the optimality condition we defined." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8d623b2f-48ee-4df6-a2ce-cf306b4c9067", + "metadata": {}, + "outputs": [], + "source": [ + "# Inner-loop objective function\n", + "# The optimality function: grad(imaml_objective)\n", + "def imaml_objective(params, meta_params, data):\n", + " x, y, fmodel = data\n", + " y_pred = fmodel(params, x)\n", + " regularization_loss = 0.0\n", + " for p1, p2 in zip(params, meta_params):\n", + " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", + " loss = F.mse_loss(y_pred, y) + regularization_loss\n", + " return loss\n", + "\n", + "\n", + "# The optimality condition is: the gradient w.r.t. the inner-loop optimal params is 0 (we achieve\n", + "# this by specifying argnums=0 in functorch.grad). The argnums=1 passed to custom_root specifies\n", + "# which meta-parameter we want to backpropagate to; in this case we want to backpropagate to the\n", + "# initial parameters, so we set it to 1.\n", + "# You can also set argnums to (1, 2) if you want to backpropagate through multiple meta-parameters.\n", + "\n", + "\n", + "# Here we pass argnums=1 to the custom_root. That means we want to compute the gradient of\n", + "# optimal_params w.r.t.
the argument at index 1 in inner_solver, i.e., meta_params.\n", + "# torchopt.linear_solve.solve_normal_cg specifies that we use the conjugate-gradient-based linear solver\n", + "@torchopt.diff.implicit.custom_root(\n", + " functorch.grad(imaml_objective, argnums=0), # optimality function\n", + " argnums=1,\n", + " solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", + ")\n", + "def inner_solver(params, meta_params, data):\n", + " # Initialize the functional optimizer from TorchOpt\n", + " x, y, fmodel = data\n", + " optimizer = torchopt.sgd(lr=2e-2)\n", + " opt_state = optimizer.init(params)\n", + " with torch.enable_grad():\n", + " # Temporarily enable gradient computation for conducting the optimization\n", + " for i in range(100):\n", + " pred = fmodel(params, x)\n", + " loss = F.mse_loss(pred, y) # compute loss\n", + "\n", + " # Compute regularization loss\n", + " regularization_loss = 0.0\n", + " for p1, p2 in zip(params, meta_params):\n", + " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", + " final_loss = loss + regularization_loss\n", + "\n", + " grads = torch.autograd.grad(final_loss, params) # compute gradients\n", + " updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates\n", + " params = torchopt.apply_updates(params, updates, inplace=True)\n", + "\n", + " optimal_params = params\n", + " return optimal_params\n", + "\n", + "\n", + "# torchopt.linear_solve.solve_inv specifies that we use the Neumann-series inversion linear solver\n", + "@torchopt.diff.implicit.custom_root(\n", + " functorch.grad(imaml_objective, argnums=0), # optimality function\n", + " argnums=1,\n", + " solve=torchopt.linear_solve.solve_inv(ns=True, maxiter=100, alpha=0.1),\n", + ")\n", + "def inner_solver_inv_ns(params, meta_params, data):\n", + " # Initialize the functional optimizer from TorchOpt\n", + " x, y, fmodel = data\n", + " optimizer = torchopt.sgd(lr=2e-2)\n", + " opt_state = optimizer.init(params)\n", + " with torch.enable_grad():\n", + " # Temporarily enable gradient computation for conducting the optimization\n", + " for i in range(100):\n", + " pred = fmodel(params, x)\n", + " loss = F.mse_loss(pred, y) # compute loss\n", + "\n", + " # Compute regularization loss\n", + " regularization_loss = 0.0\n", + " for p1, p2 in zip(params, meta_params):\n", + " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", + " final_loss = loss + regularization_loss\n", + "\n", + " grads = torch.autograd.grad(final_loss, params) # compute gradients\n", + " updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates\n", + " params = torchopt.apply_updates(params, updates, inplace=True)\n", + "\n", + " optimal_params = params\n", + " return optimal_params" + ] + }, + { + "cell_type": "markdown", + "id": "32a75c81-d479-4120-a73d-5b2b488358d0", + "metadata": {}, + "source": [ + "In the next step, we consider the specific case of a one-layer neural network fitted to linear data."
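A short aside on the two `solve` options used above: as its name suggests, `solve_normal_cg` runs conjugate gradient on the normal equations of the implicit linear system, while `solve_inv(ns=True, maxiter=100, alpha=0.1)` approximates the inverse with a truncated Neumann series. The snippet below is only a conceptual sketch of the Neumann-series idea (the helper `neumann_solve` is hypothetical, not TorchOpt's implementation): for $\rho(I - \alpha A) < 1$ one has $A^{-1} b = \alpha \sum_{k \geq 0} (I - \alpha A)^{k} b$, truncated here after `maxiter` terms.

```python
import torch


def neumann_solve(matvec, b, alpha=0.1, maxiter=100):
    """Approximate A^{-1} b with a truncated Neumann series, where matvec(v) = A @ v."""
    x = torch.zeros_like(b)
    v = b.clone()
    for _ in range(maxiter):
        x = x + alpha * v          # accumulate alpha * (I - alpha * A)^k b
        v = v - alpha * matvec(v)  # v <- (I - alpha * A) v
    return x


# Tiny self-contained check on an explicit SPD matrix
A = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
b = torch.tensor([1.0, -1.0])
print(neumann_solve(lambda v: A @ v, b))  # should be close to ...
print(torch.linalg.solve(A, b))           # ... the exact solution
```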
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "fb95538b-1fd9-4ec8-9f57-6360bedc05b7", + "metadata": {}, + "outputs": [], + "source": [ + "torch.manual_seed(0)\n", + "x = torch.randn(20, 4)\n", + "w = torch.randn(4, 1)\n", + "b = torch.randn(1)\n", + "y = x @ w + b + 0.5 * torch.randn(20, 1)" + ] + }, + { + "cell_type": "markdown", + "id": "eeb1823a-2231-4471-bb68-cce7724f2578", + "metadata": {}, + "source": [ + "We instantiate a one-layer neural network, where the weights and bias are initialized with constants." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d50a7bfe-ac69-4089-8cf8-3cbd69d6d4e7", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "class Net(nn.Module):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.fc = nn.Linear(dim, 1, bias=True)\n", + " nn.init.ones_(self.fc.weight)\n", + " nn.init.zeros_(self.fc.bias)\n", + "\n", + " def forward(self, x):\n", + " return self.fc(x)\n", + "\n", + "\n", + "model = Net(4)\n", + "fmodel, meta_params = functorch.make_functional(model)\n", + "data = (x, y, fmodel)\n", + "\n", + "\n", + "# Clone function for parameters\n", + "def clone(params):\n", + " cloned = []\n", + " for item in params:\n", + " if isinstance(item, torch.Tensor):\n", + " cloned.append(item.clone().detach_().requires_grad_(True))\n", + " else:\n", + " cloned.append(item)\n", + " return tuple(cloned)" + ] + }, + { + "cell_type": "markdown", + "id": "065c36c4-89e2-4a63-8213-63db6ee3b08e", + "metadata": {}, + "source": [ + "We run the forward process by calling the decorated solver, then pass the optimal params into the outer-loop loss function." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "115e79c6-911f-4743-a2ed-e50a71c3a813", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "optimal_params = inner_solver(clone(meta_params), meta_params, data)\n", + "\n", + "outer_loss = fmodel(optimal_params, x).mean()" + ] + }, + { + "cell_type": "markdown", + "id": "e2812351-f635-496e-9732-c80831ac04a6", + "metadata": {}, + "source": [ + "Finally, we can get the meta-gradient as shown below." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6bdcbe8d-2336-4f80-b124-eb43c5a2fc0a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" + ] + } + ], + "source": [ + "torch.autograd.grad(outer_loss, meta_params)" + ] + }, + { + "cell_type": "markdown", + "id": "926ae8bb", + "metadata": {}, + "source": [ + "We can also switch to the Neumann-series inversion linear solver." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "43df0374", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" + ] + } + ], + "source": [ + "optimal_params = inner_solver_inv_ns(clone(meta_params), meta_params, data)\n", + "outer_loss = fmodel(optimal_params, x).mean()\n", + "torch.autograd.grad(outer_loss, meta_params)" + ] + },
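Because the implicit gradient is only exact at a (near-)stationary inner solution, a quick check on the optimality residual can be reassuring. The snippet below is a hypothetical sanity check added for illustration (it is not part of the original notebook); it reuses `imaml_objective`, `optimal_params`, `meta_params`, and `data` from the cells above.

```python
# Hypothetical sanity check: the optimality function evaluated at the returned solution
# should be close to zero; otherwise the implicit meta-gradient is only a rough approximation.
optimality_fn = functorch.grad(imaml_objective, argnums=0)
residual = optimality_fn(optimal_params, meta_params, data)
print(max(g.abs().max().item() for g in residual))  # expected to be small after 100 SGD steps
```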
+ { + "attachments": {}, + "cell_type": "markdown", + "id": "c92e67ea-b220-4a14-a1ea-4eb3c5f52b6b", + "metadata": {}, + "source": [ + "## 2. OOP API\n", + "\n", + "The basic OOP class is `ImplicitMetaGradientModule`. The network is built as an `nn.Module`, following the classical PyTorch style. Users need to define the stationary condition/objective function and the inner-loop solve function to enable implicit gradient computation. The pseudo code is shown below.\n", + "\n", + "```python\n", + "from torchopt.nn import ImplicitMetaGradientModule\n", + "\n", + "# Inherit from the class ImplicitMetaGradientModule\n", + "# Optionally specify the linear solver (conjugate gradient or Neumann series)\n", + "class InnerNet(ImplicitMetaGradientModule, linear_solve=linear_solver):\n", + " def __init__(self, meta_module):\n", + " ...\n", + "\n", + " def forward(self, batch):\n", + " # Forward process\n", + " ...\n", + "\n", + " def optimality(self, batch, labels):\n", + " # Stationary condition construction for calculating implicit gradient\n", + " # NOTE: If this method is not implemented, it will be automatically derived from the\n", + " # gradient of the `objective` function.\n", + " ...\n", + "\n", + " def objective(self, batch, labels):\n", + " # Define the inner-loop optimization objective\n", + " # NOTE: This method is optional if method `optimality` is implemented.\n", + " ...\n", + "\n", + " def solve(self, batch, labels):\n", + " # Conduct the inner-loop optimization\n", + " ...\n", + " return self # optimized module\n", + "\n", + "# Get meta_module and data\n", + "meta_module, data = ..., ...\n", + "inner_net = InnerNet(meta_module)\n", + "\n", + "# Solve the inner-loop process related to the meta-parameters\n", + "optimal_inner_net = inner_net.solve(*data)\n", + "\n", + "# Get the outer loss and solve for the meta-gradient\n", + "loss = outer_loss(optimal_inner_net)\n", + "meta_grad = torch.autograd.grad(loss, meta_module.parameters())\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "62fbe520-11d0-41ff-9b0a-c6508b1d01cf", + "metadata": {}, + "source": + [ + "The class `ImplicitMetaGradientModule` enables gradient flow from `self.parameters()` (the inner parameters) to `self.meta_parameters()` (the meta-parameters). In the `__init__` method, users need to define the inner parameters and meta-parameters. By default, `ImplicitMetaGradientModule` treats all tensors and modules passed to `__init__` as `self.meta_parameters()`, while all tensors and modules created inside `__init__` are regarded as `self.parameters()`. Users can also register `self.parameters()` and `self.meta_parameters()` explicitly by calling `self.register_parameter()` and `self.register_meta_parameter()`, respectively."
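The default registration rule above can also be made explicit. The toy module below is purely illustrative and rests on assumptions not shown in this notebook: it assumes `register_meta_parameter(name, tensor)` mirrors the signature of PyTorch's `register_parameter(name, param)`, and that omitting the `linear_solve` class argument falls back to a default solver. The inner problem $\min_{\phi} (\phi - \theta)^2 + \phi^2$ has the closed-form solution $\phi^{\star} = \theta / 2$, so the implicit gradient of $\phi^{\star}$ w.r.t. $\theta$ should be about $0.5$.

```python
class ToyInner(torchopt.nn.ImplicitMetaGradientModule):
    def __init__(self, theta):
        super().__init__()
        # Explicit registration of the outer variable as a meta-parameter
        # (assumed signature: register_meta_parameter(name, tensor)).
        self.register_meta_parameter('theta', theta)
        # Explicit registration of the inner variable as an inner parameter.
        self.register_parameter('phi', nn.Parameter(theta.clone().detach()))

    def objective(self):
        # Inner objective whose stationary point is phi* = theta / 2.
        return torch.sum((self.phi - self.theta) ** 2 + self.phi**2)

    def solve(self):
        with torch.no_grad():
            self.phi.copy_(self.theta / 2)  # closed-form inner solution
        return self


theta = torch.tensor([1.0], requires_grad=True)
toy = ToyInner(theta).solve()
print(torch.autograd.grad(toy.phi.sum(), theta))  # expected to be roughly (0.5,)
```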
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c3999684-f4d3-4bc0-86ab-a7e803b2fe80", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" + ] + } + ], + "source": [ + "class InnerNet(\n", + " torchopt.nn.ImplicitMetaGradientModule,\n", + " linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", + "):\n", + " def __init__(self, meta_net, n_inner_iter, reg_param):\n", + " super().__init__()\n", + " # Declaration of the meta-parameter\n", + " self.meta_net = meta_net\n", + " # Get a deepcopy, register inner-parameter\n", + " self.net = torchopt.module_clone(meta_net, by='deepcopy', detach_buffers=True)\n", + " self.n_inner_iter = n_inner_iter\n", + " self.reg_param = reg_param\n", + "\n", + " def forward(self, x):\n", + " return self.net(x)\n", + "\n", + " def objective(self, x, y):\n", + " # We do not implement the optimality conditions, so it will be automatically derived from\n", + " # the gradient of the `objective` function.\n", + " y_pred = self(x)\n", + " loss = F.mse_loss(y_pred, y)\n", + " regularization_loss = 0\n", + " for p1, p2 in zip(\n", + " self.parameters(), # parameters of `self.net`\n", + " self.meta_parameters(), # parameters of `self.meta_net`\n", + " ):\n", + " regularization_loss += (\n", + " 0.5 * self.reg_param * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", + " )\n", + " return loss + regularization_loss\n", + "\n", + " def solve(self, x, y):\n", + " params = tuple(self.parameters())\n", + " inner_optim = torchopt.SGD(params, lr=2e-2)\n", + " with torch.enable_grad():\n", + " # Temporarily enable gradient computation for conducting the optimization\n", + " for _ in range(self.n_inner_iter):\n", + " loss = self.objective(x, y)\n", + " inner_optim.zero_grad()\n", + " # NOTE: The parameter inputs should be explicitly specified in `backward` function\n", + " # as argument `inputs`. Otherwise, if not provided, the gradient is accumulated into\n", + " # all the leaf Tensors (including the meta-parameters) that were used to compute the\n", + " # objective output. Alternatively, please use `torch.autograd.grad` instead.\n", + " loss.backward(inputs=params) # backward pass in inner-loop\n", + " inner_optim.step() # update inner parameters\n", + " return self\n", + "\n", + "\n", + "# Initialize the meta-network\n", + "meta_net = Net(4)\n", + "inner_net = InnerNet(meta_net, 100, reg_param=1)\n", + "\n", + "# Solve for inner-loop\n", + "optimal_inner_net = inner_net.solve(x, y)\n", + "outer_loss = optimal_inner_net(x).mean()\n", + "\n", + "# Derive the meta-gradient\n", + "torch.autograd.grad(outer_loss, meta_net.parameters())" + ] + }, + { + "cell_type": "markdown", + "id": "2b69a5d6-b5e4-4f08-af0a-40afc2382b45", + "metadata": {}, + "source": [ + "We also show an example on how to implement implicit gradient calculation when the inner-level optimal solution reaches some stationary conditions $F (\\phi^{\\star}, \\theta) = 0$, such as [DEQ](https://arxiv.org/abs/1909.01377), based on the OOP API. 
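For the fixed-point case shown next, the same implicit-function-theorem reasoning applies with the residual $F(x, \theta) = x - f_{\theta}(x)$. At the fixed point $x^{\star} = f_{\theta}(x^{\star})$,

$$
\frac{\partial x^{\star}}{\partial \theta} = {\left( I - \frac{\partial f_{\theta}}{\partial x} \right)}^{-1} \frac{\partial f_{\theta}}{\partial \theta} ,
$$

which is why the `optimality` method in the following cell simply returns the residual `self.x - self(self.x)`.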
" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "de87c308-d847-4491-9aa1-bc393e6dd1d8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(\n", + "│ tensor([[ 0.0272, 0.0031, -0.0156, -0.0238],\n", + "│ │ [ 0.1004, 0.0113, -0.0573, -0.0878],\n", + "│ │ [ 0.0666, 0.0075, -0.0380, -0.0583],\n", + "│ │ [ 0.1446, 0.0163, -0.0826, -0.1265]]),\n", + "│ tensor([0.0574, 0.2114, 0.1403, 0.3046])\n", + ")\n" + ] + } + ], + "source": [ + "class Net(nn.Module):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.fc = nn.Linear(dim, dim)\n", + "\n", + " def forward(self, x):\n", + " return self.fc(x)\n", + "\n", + "\n", + "class InnerNet(\n", + " torchopt.nn.ImplicitMetaGradientModule,\n", + " linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", + "):\n", + " def __init__(self, meta_net, x0):\n", + " super().__init__()\n", + " # Register meta-parameter\n", + " self.meta_net = meta_net\n", + " # Declaration of the inner-parameter, register inner-parameter\n", + " self.x = nn.Parameter(x0.clone().detach_(), requires_grad=True)\n", + "\n", + " def forward(self, x):\n", + " return self.meta_net(x)\n", + "\n", + " def optimality(self):\n", + " # Fixed-point condition\n", + " return (self.x - self(self.x),)\n", + "\n", + " def solve(self):\n", + " # Solving inner-loop fixed-point iteration\n", + " # This is just an illustrating example for solving fixed-point iteration\n", + " # one can use more advanced method to solve fixed-point iteration\n", + " # such as anderson acceleration.\n", + " for _ in range(10):\n", + " self.x.copy_(self(self.x))\n", + " return self\n", + "\n", + "\n", + "# Initialize meta-network\n", + "torch.manual_seed(0)\n", + "meta_net = Net(4)\n", + "x0 = torch.randn(1, 4)\n", + "inner_net = InnerNet(meta_net, x0)\n", + "\n", + "# Solve for inner-loop\n", + "optimal_inner_net = inner_net.solve()\n", + "outer_loss = optimal_inner_net.x.mean()\n", + "\n", + "# Derive the meta-gradient\n", + "torch.autograd.grad(outer_loss, meta_net.parameters())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + }, + "vscode": { + "interpreter": { + "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 } - ], - "source": [ - "torch.autograd.grad(outer_loss, meta_params)" - ] - }, - { - "cell_type": "markdown", - "id": "926ae8bb", - "metadata": {}, - "source": [ - "Also we can switch to the Neumann Series inversion linear solver." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "43df0374", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" - ] - } - ], - "source": [ - "optimal_params = inner_solver_inv_ns(clone(meta_params), meta_params, data)\n", - "outer_loss = fmodel(optimal_params, x).mean()\n", - "torch.autograd.grad(outer_loss, meta_params)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "c92e67ea-b220-4a14-a1ea-4eb3c5f52b6b", - "metadata": {}, - "source": [ - "## 2. 
OOP API\n", - "\n", - "The basic OOP class is the class `ImplicitMetaGradientModule`. We make the network as an `nn.Module` following a classical PyTorch style. Users need to define the stationary condition/objective function and the inner-loop solve function to enable implicit gradient computation. We show the pseudo code in the following part.\n", - "\n", - "```python\n", - "from torchopt.nn import ImplicitMetaGradientModule\n", - "\n", - "# Inherited from the class ImplicitMetaGradientModule\n", - "# Optionally specify the linear solver (conjugate gradient or Neumann series)\n", - "class InnerNet(ImplicitMetaGradientModule, linear_solve=linear_solver):\n", - " def __init__(self, meta_module):\n", - " ...\n", - "\n", - " def forward(self, batch):\n", - " # Forward process\n", - " ...\n", - "\n", - " def optimality(self, batch, labels):\n", - " # Stationary condition construction for calculating implicit gradient\n", - " # NOTE: If this method is not implemented, it will be automatically derived from the\n", - " # gradient of the `objective` function.\n", - " ...\n", - "\n", - " def objective(self, batch, labels):\n", - " # Define the inner-loop optimization objective\n", - " # NOTE: This method is optional if method `optimality` is implemented.\n", - " ...\n", - "\n", - " def solve(self, batch, labels):\n", - " # Conduct the inner-loop optimization\n", - " ...\n", - " return self # optimized module\n", - "\n", - "# Get meta_params and data\n", - "meta_params, data = ..., ...\n", - "inner_net = InnerNet()\n", - "\n", - "# Solve for inner-loop process related to the meta-parameters\n", - "optimal_inner_net = inner_net.solve(meta_params, *data)\n", - "\n", - "# Get outer-loss and solve for meta-gradient\n", - "loss = outer_loss(optimal_inner_net)\n", - "meta_grad = torch.autograd.grad(loss, meta_params)\n", - "```" - ] - }, - { - "cell_type": "markdown", - "id": "62fbe520-11d0-41ff-9b0a-c6508b1d01cf", - "metadata": {}, - "source": [ - "The class `ImplicitMetaGradientModule` is to enable the gradient flow from `self.parameters()` to `self.meta_parameters()`. In `__init__` function, users need to define the inner parameters and meta-parameters. By default, `ImplicitMetaGradientModule` treats all tensors and modules from input as `self.meta_parameters()`, and all tensors and modules defined in the `__init__` are regarded as `self.parameters()`. Users can also register `self.parameters()` and `self.meta_parameters()` by calling `self.register_parameter()` and `self.register_meta_parameter()` respectively." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "c3999684-f4d3-4bc0-86ab-a7e803b2fe80", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" - ] - } - ], - "source": [ - "class InnerNet(\n", - " torchopt.nn.ImplicitMetaGradientModule,\n", - " linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", - "):\n", - " def __init__(self, meta_net, n_inner_iter, reg_param):\n", - " super().__init__()\n", - " # Declaration of the meta-parameter\n", - " self.meta_net = meta_net\n", - " # Get a deepcopy, register inner-parameter\n", - " self.net = torchopt.module_clone(meta_net, by='deepcopy', detach_buffers=True)\n", - " self.n_inner_iter = n_inner_iter\n", - " self.reg_param = reg_param\n", - "\n", - " def forward(self, x):\n", - " return self.net(x)\n", - "\n", - " def objective(self, x, y):\n", - " # We do not implement the optimality conditions, so it will be automatically derived from\n", - " # the gradient of the `objective` function.\n", - " y_pred = self(x)\n", - " loss = F.mse_loss(y_pred, y)\n", - " regularization_loss = 0\n", - " for p1, p2 in zip(\n", - " self.parameters(), # parameters of `self.net`\n", - " self.meta_parameters(), # parameters of `self.meta_net`\n", - " ):\n", - " regularization_loss += (\n", - " 0.5 * self.reg_param * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", - " )\n", - " return loss + regularization_loss\n", - "\n", - " def solve(self, x, y):\n", - " params = tuple(self.parameters())\n", - " inner_optim = torchopt.SGD(params, lr=2e-2)\n", - " with torch.enable_grad():\n", - " # Temporarily enable gradient computation for conducting the optimization\n", - " for _ in range(self.n_inner_iter):\n", - " loss = self.objective(x, y)\n", - " inner_optim.zero_grad()\n", - " # NOTE: The parameter inputs should be explicitly specified in `backward` function\n", - " # as argument `inputs`. Otherwise, if not provided, the gradient is accumulated into\n", - " # all the leaf Tensors (including the meta-parameters) that were used to compute the\n", - " # objective output. Alternatively, please use `torch.autograd.grad` instead.\n", - " loss.backward(inputs=params) # backward pass in inner-loop\n", - " inner_optim.step() # update inner parameters\n", - " return self\n", - "\n", - "\n", - "# Initialize the meta-network\n", - "meta_net = Net(4)\n", - "inner_net = InnerNet(meta_net, 100, reg_param=1)\n", - "\n", - "# Solve for inner-loop\n", - "optimal_inner_net = inner_net.solve(x, y)\n", - "outer_loss = optimal_inner_net(x).mean()\n", - "\n", - "# Derive the meta-gradient\n", - "torch.autograd.grad(outer_loss, meta_net.parameters())" - ] - }, - { - "cell_type": "markdown", - "id": "2b69a5d6-b5e4-4f08-af0a-40afc2382b45", - "metadata": {}, - "source": [ - "We also show an example on how to implement implicit gradient calculation when the inner-level optimal solution reaches some stationary conditions $F (\\phi^{\\star}, \\theta) = 0$, such as [DEQ](https://arxiv.org/abs/1909.01377), based on the OOP API. 
" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "de87c308-d847-4491-9aa1-bc393e6dd1d8", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "(\n", - "│ tensor([[ 0.0272, 0.0031, -0.0156, -0.0238],\n", - "│ │ [ 0.1004, 0.0113, -0.0573, -0.0878],\n", - "│ │ [ 0.0666, 0.0075, -0.0380, -0.0583],\n", - "│ │ [ 0.1446, 0.0163, -0.0826, -0.1265]]),\n", - "│ tensor([0.0574, 0.2114, 0.1403, 0.3046])\n", - ")\n" - ] - } - ], - "source": [ - "class Net(nn.Module):\n", - " def __init__(self, dim):\n", - " super().__init__()\n", - " self.fc = nn.Linear(dim, dim)\n", - "\n", - " def forward(self, x):\n", - " return self.fc(x)\n", - "\n", - "\n", - "class InnerNet(\n", - " torchopt.nn.ImplicitMetaGradientModule,\n", - " linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", - "):\n", - " def __init__(self, meta_net, x0):\n", - " super().__init__()\n", - " # Register meta-parameter\n", - " self.meta_net = meta_net\n", - " # Declaration of the inner-parameter, register inner-parameter\n", - " self.x = nn.Parameter(x0.clone().detach_(), requires_grad=True)\n", - "\n", - " def forward(self, x):\n", - " return self.meta_net(x)\n", - "\n", - " def optimality(self):\n", - " # Fixed-point condition\n", - " return (self.x - self(self.x),)\n", - "\n", - " def solve(self):\n", - " # Solving inner-loop fixed-point iteration\n", - " # This is just an illustrating example for solving fixed-point iteration\n", - " # one can use more advanced method to solve fixed-point iteration\n", - " # such as anderson acceleration.\n", - " for _ in range(10):\n", - " self.x.copy_(self(self.x))\n", - " return self\n", - "\n", - "\n", - "# Initialize meta-network\n", - "torch.manual_seed(0)\n", - "meta_net = Net(4)\n", - "x0 = torch.randn(1, 4)\n", - "inner_net = InnerNet(meta_net, x0)\n", - "\n", - "# Solve for inner-loop\n", - "optimal_inner_net = inner_net.solve()\n", - "outer_loss = optimal_inner_net.x.mean()\n", - "\n", - "# Derive the meta-gradient\n", - "torch.autograd.grad(outer_loss, meta_net.parameters())" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.15" - }, - "vscode": { - "interpreter": { - "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/tutorials/6_Zero_Order_Differentiation.ipynb b/tutorials/6_Zero_Order_Differentiation.ipynb index d6cb028c..683eb34d 100644 --- a/tutorials/6_Zero_Order_Differentiation.ipynb +++ b/tutorials/6_Zero_Order_Differentiation.ipynb @@ -1,5 +1,5 @@ { - "cells": [ + "cells": [ { "cell_type": "markdown", "id": "8850c832-3b54-4971-8ee0-2cd64b585ea8", @@ -175,7 +175,11 @@ "\n", "\n", "@torchopt.diff.zero_order(\n", - " distribution=distribution, method='forward', argnums=0, num_samples=100, sigma=0.01\n", + " distribution=distribution,\n", + " method='forward',\n", + " argnums=0,\n", + " num_samples=100,\n", + " sigma=0.01,\n", ")\n", "def forward_process(params, fn, x, y):\n", " y_pred = fn(params, x)\n",