Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/set_cibw_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import os
import sys


# pylint: disable-next=consider-using-f-string
CIBW_BUILD = 'CIBW_BUILD=*cp%d%d-*manylinux*' % sys.version_info[:2]

Expand Down
1 change: 0 additions & 1 deletion .github/workflows/set_release.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pathlib
import re


ROOT = pathlib.Path(__file__).absolute().parent.parent.parent

VERSION_FILE = ROOT / 'torchopt' / 'version.py'
Expand Down
20 changes: 10 additions & 10 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ ci:
autofix_commit_msg: "fix: [pre-commit.ci] auto fixes [...]"
autoupdate_commit_msg: "chore(pre-commit): [pre-commit.ci] autoupdate"
autoupdate_schedule: monthly
default_stages: [commit, push, manual]
default_stages: [pre-commit, pre-push, manual]
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v6.0.0
hooks:
- id: check-symlinks
- id: destroyed-symlinks
Expand All @@ -26,24 +26,24 @@ repos:
- id: debug-statements
- id: double-quote-string-fixer
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.8
rev: v22.1.0
hooks:
- id: clang-format
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.0
rev: v0.15.4
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
rev: 8.0.1
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 24.4.2
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 26.1.0
hooks:
- id: black-jupyter
- repo: https://github.com/asottile/pyupgrade
rev: v3.16.0
rev: v3.21.2
hooks:
- id: pyupgrade
args: [--py38-plus] # sync with requires-python
Expand All @@ -52,7 +52,7 @@ repos:
^examples/
)
- repo: https://github.com/pycqa/flake8
rev: 7.1.0
rev: 7.3.0
hooks:
- id: flake8
additional_dependencies:
Expand All @@ -68,7 +68,7 @@ repos:
^docs/source/conf.py$
)
- repo: https://github.com/codespell-project/codespell
rev: v2.3.0
rev: v2.4.1
hooks:
- id: codespell
additional_dependencies: [".[toml]"]
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Configuration file for the Sphinx documentation builder."""

# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
Expand All @@ -33,7 +34,6 @@
import sphinx
import sphinxcontrib.katex as katex


HERE = pathlib.Path(__file__).absolute().parent
PROJECT_ROOT = HERE.parent.parent

Expand Down
2 changes: 0 additions & 2 deletions examples/FuncTorch/maml_omniglot_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,11 @@

import torchopt


CWD = pathlib.Path(__file__).absolute().parent
sys.path.append(str(CWD.parent / 'few-shot'))

from helpers.omniglot_loaders import OmniglotNShot


mpl.use('Agg')
plt.style.use('bmh')

Expand Down
1 change: 0 additions & 1 deletion examples/MAML-RL/func_maml.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import torchopt
from helpers.policy import CategoricalMLPPolicy


TASK_NUM = 40
TRAJ_NUM = 20
TRAJ_LEN = 10
Expand Down
1 change: 0 additions & 1 deletion examples/MAML-RL/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

from gym.envs.registration import register


register(
'TabularMDP-v0',
entry_point='helpers.tabular_mdp:TabularMDPEnv',
Expand Down
1 change: 0 additions & 1 deletion examples/MAML-RL/maml.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import torchopt
from helpers.policy import CategoricalMLPPolicy


TASK_NUM = 40
TRAJ_NUM = 20
TRAJ_LEN = 10
Expand Down
1 change: 0 additions & 1 deletion examples/MAML-RL/maml_torchrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import torchopt
from helpers.policy_torchrl import ActorCritic


TASK_NUM = 40
TRAJ_NUM = 20
TRAJ_LEN = 10
Expand Down
1 change: 0 additions & 1 deletion examples/distributed/few-shot/maml_omniglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
import torchopt.distributed as todist
from helpers.omniglot_loaders import OmniglotNShot


mpl.use('Agg')
plt.style.use('bmh')

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
import torchopt.distributed as todist
from helpers.omniglot_loaders import OmniglotNShot


mpl.use('Agg')
plt.style.use('bmh')

Expand Down
1 change: 0 additions & 1 deletion examples/few-shot/maml_omniglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
import torchopt
from helpers.omniglot_loaders import OmniglotNShot


mpl.use('Agg')
plt.style.use('bmh')

Expand Down
1 change: 0 additions & 1 deletion examples/iMAML/imaml_omniglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from helpers.omniglot_loaders import OmniglotNShot
from torchopt.diff.implicit import ImplicitMetaGradientModule


mpl.use('Agg')
plt.style.use('bmh')

Expand Down
1 change: 0 additions & 1 deletion examples/iMAML/imaml_omniglot_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from helpers.omniglot_loaders import OmniglotNShot
from torchopt import pytree


mpl.use('Agg')
plt.style.use('bmh')

Expand Down
40 changes: 20 additions & 20 deletions include/adam_op/adam_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,51 +27,51 @@ namespace py = pybind11;

namespace adam_op {

TensorArray<3> adamForwardInplace(const torch::Tensor &updates,
const torch::Tensor &mu,
const torch::Tensor &nu,
TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
const torch::Tensor& mu,
const torch::Tensor& nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps,
const pyfloat_t eps_root,
const pyuint_t count);

torch::Tensor adamForwardMu(const torch::Tensor &updates,
const torch::Tensor &mu,
torch::Tensor adamForwardMu(const torch::Tensor& updates,
const torch::Tensor& mu,
const pyfloat_t b1);

torch::Tensor adamForwardNu(const torch::Tensor &updates,
const torch::Tensor &nu,
torch::Tensor adamForwardNu(const torch::Tensor& updates,
const torch::Tensor& nu,
const pyfloat_t b2);

torch::Tensor adamForwardUpdates(const torch::Tensor &new_mu,
const torch::Tensor &new_nu,
torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,
const torch::Tensor& new_nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps,
const pyfloat_t eps_root,
const pyuint_t count);

TensorArray<2> adamBackwardMu(const torch::Tensor &dmu,
const torch::Tensor &updates,
const torch::Tensor &mu,
TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,
const torch::Tensor& updates,
const torch::Tensor& mu,
const pyfloat_t b1);

TensorArray<2> adamBackwardNu(const torch::Tensor &dnu,
const torch::Tensor &updates,
const torch::Tensor &nu,
TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
const torch::Tensor& updates,
const torch::Tensor& nu,
const pyfloat_t b2);

TensorArray<2> adamBackwardUpdates(const torch::Tensor &dupdates,
const torch::Tensor &updates,
const torch::Tensor &new_mu,
const torch::Tensor &new_nu,
TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates,
const torch::Tensor& updates,
const torch::Tensor& new_mu,
const torch::Tensor& new_nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps_root,
const pyuint_t count);

void buildSubmodule(py::module &mod); // NOLINT[runtime/references]
void buildSubmodule(py::module& mod); // NOLINT[runtime/references]

} // namespace adam_op
} // namespace torchopt
38 changes: 19 additions & 19 deletions include/adam_op/adam_op_impl_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,45 +23,45 @@

namespace torchopt {
namespace adam_op {
TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates,
const torch::Tensor &mu,
const torch::Tensor &nu,
TensorArray<3> adamForwardInplaceCPU(const torch::Tensor& updates,
const torch::Tensor& mu,
const torch::Tensor& nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps,
const pyfloat_t eps_root,
const pyuint_t count);

torch::Tensor adamForwardMuCPU(const torch::Tensor &updates,
const torch::Tensor &mu,
torch::Tensor adamForwardMuCPU(const torch::Tensor& updates,
const torch::Tensor& mu,
const pyfloat_t b1);

torch::Tensor adamForwardNuCPU(const torch::Tensor &updates,
const torch::Tensor &nu,
torch::Tensor adamForwardNuCPU(const torch::Tensor& updates,
const torch::Tensor& nu,
const pyfloat_t b2);

torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu,
const torch::Tensor &new_nu,
torch::Tensor adamForwardUpdatesCPU(const torch::Tensor& new_mu,
const torch::Tensor& new_nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps,
const pyfloat_t eps_root,
const pyuint_t count);

TensorArray<2> adamBackwardMuCPU(const torch::Tensor &dmu,
const torch::Tensor &updates,
const torch::Tensor &mu,
TensorArray<2> adamBackwardMuCPU(const torch::Tensor& dmu,
const torch::Tensor& updates,
const torch::Tensor& mu,
const pyfloat_t b1);

TensorArray<2> adamBackwardNuCPU(const torch::Tensor &dnu,
const torch::Tensor &updates,
const torch::Tensor &nu,
TensorArray<2> adamBackwardNuCPU(const torch::Tensor& dnu,
const torch::Tensor& updates,
const torch::Tensor& nu,
const pyfloat_t b2);

TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor &dupdates,
const torch::Tensor &updates,
const torch::Tensor &new_mu,
const torch::Tensor &new_nu,
TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor& dupdates,
const torch::Tensor& updates,
const torch::Tensor& new_mu,
const torch::Tensor& new_nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps_root,
Expand Down
38 changes: 19 additions & 19 deletions include/adam_op/adam_op_impl_cuda.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,45 +23,45 @@

namespace torchopt {
namespace adam_op {
TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates,
const torch::Tensor &mu,
const torch::Tensor &nu,
TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor& updates,
const torch::Tensor& mu,
const torch::Tensor& nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps,
const pyfloat_t eps_root,
const pyuint_t count);

torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates,
const torch::Tensor &mu,
torch::Tensor adamForwardMuCUDA(const torch::Tensor& updates,
const torch::Tensor& mu,
const pyfloat_t b1);

torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates,
const torch::Tensor &nu,
torch::Tensor adamForwardNuCUDA(const torch::Tensor& updates,
const torch::Tensor& nu,
const pyfloat_t b2);

torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu,
const torch::Tensor &new_nu,
torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor& new_mu,
const torch::Tensor& new_nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps,
const pyfloat_t eps_root,
const pyuint_t count);

TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu,
const torch::Tensor &updates,
const torch::Tensor &mu,
TensorArray<2> adamBackwardMuCUDA(const torch::Tensor& dmu,
const torch::Tensor& updates,
const torch::Tensor& mu,
const pyfloat_t b1);

TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu,
const torch::Tensor &updates,
const torch::Tensor &nu,
TensorArray<2> adamBackwardNuCUDA(const torch::Tensor& dnu,
const torch::Tensor& updates,
const torch::Tensor& nu,
const pyfloat_t b2);

TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates,
const torch::Tensor &updates,
const torch::Tensor &new_mu,
const torch::Tensor &new_nu,
TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor& dupdates,
const torch::Tensor& updates,
const torch::Tensor& new_mu,
const torch::Tensor& new_nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps_root,
Expand Down
2 changes: 1 addition & 1 deletion include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#endif

namespace torchopt {
__forceinline__ size_t getTensorPlainSize(const torch::Tensor &tensor) {
__forceinline__ size_t getTensorPlainSize(const torch::Tensor& tensor) {
const auto dim = tensor.dim();
size_t n = 1;
for (std::decay_t<decltype(dim)> i = 0; i < dim; ++i) {
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext


HERE = pathlib.Path(__file__).absolute().parent


Expand Down
Loading
Loading