diff --git a/.translate/state/autodiff.md.yml b/.translate/state/autodiff.md.yml new file mode 100644 index 0000000..8662fff --- /dev/null +++ b/.translate/state/autodiff.md.yml @@ -0,0 +1,6 @@ +source-sha: 05ce95691fd97e48da39dd6d58fe032c03e8813d +synced-at: "2026-04-09" +model: claude-sonnet-4-6 +mode: NEW +section-count: 5 +tool-version: 0.14.1 diff --git a/.translate/state/jax_intro.md.yml b/.translate/state/jax_intro.md.yml index baa2850..bbee313 100644 --- a/.translate/state/jax_intro.md.yml +++ b/.translate/state/jax_intro.md.yml @@ -1,6 +1,6 @@ -source-sha: c4c03c80c1eb4318f627d869707d242d19c8cf09 -synced-at: "2026-03-20" +source-sha: 05ce95691fd97e48da39dd6d58fe032c03e8813d +synced-at: "2026-04-09" model: claude-sonnet-4-6 -mode: NEW -section-count: 6 -tool-version: 0.13.0 +mode: UPDATE +section-count: 7 +tool-version: 0.14.1 diff --git a/.translate/state/numpy_vs_numba_vs_jax.md.yml b/.translate/state/numpy_vs_numba_vs_jax.md.yml index 61c6ef5..0798448 100644 --- a/.translate/state/numpy_vs_numba_vs_jax.md.yml +++ b/.translate/state/numpy_vs_numba_vs_jax.md.yml @@ -1,6 +1,6 @@ -source-sha: c4c03c80c1eb4318f627d869707d242d19c8cf09 -synced-at: "2026-03-20" +source-sha: 05ce95691fd97e48da39dd6d58fe032c03e8813d +synced-at: "2026-04-09" model: claude-sonnet-4-6 -mode: NEW -section-count: 2 -tool-version: 0.13.0 +mode: UPDATE +section-count: 3 +tool-version: 0.14.1 diff --git a/lectures/_toc.yml b/lectures/_toc.yml index 97c429c..0ea886b 100644 --- a/lectures/_toc.yml +++ b/lectures/_toc.yml @@ -25,6 +25,7 @@ parts: - file: numba - file: jax_intro - file: numpy_vs_numba_vs_jax + - file: autodiff - caption: Working with Data numbered: true chapters: diff --git a/lectures/autodiff.md b/lectures/autodiff.md new file mode 100644 index 0000000..c7c896e --- /dev/null +++ b/lectures/autodiff.md @@ -0,0 +1,524 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.17.2 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +translation: + title: 自动微分探险 + headings: + Overview: 概述 + What is automatic differentiation?: 什么是自动微分? + What is automatic differentiation?::Autodiff is not finite differences: 自动微分不是有限差分 + What is automatic differentiation?::Autodiff is not symbolic calculus: 自动微分不是符号微积分 + What is automatic differentiation?::Autodiff: 自动微分 + Some experiments: 一些实验 + Some experiments::A differentiable function: 一个可微函数 + Some experiments::Absolute value function: 绝对值函数 + Some experiments::Differentiating through control flow: 对控制流进行微分 + Some experiments::Differentiating through a linear interpolation: 对线性插值进行微分 + Gradient Descent: 梯度下降 + Gradient Descent::A function for gradient descent: 梯度下降函数 + Gradient Descent::Simulated data: 模拟数据 + Gradient Descent::Minimizing squared loss by gradient descent: 通过梯度下降最小化平方损失 + Gradient Descent::Adding a squared term: 添加二次项 + Exercises: 练习 +--- + +# 自动微分探险 + + +```{include} _admonition/gpu.md +``` + +## 概述 + +本讲座以 {doc}`我们的简要预览 ` 为基础,使用 Google JAX 对自动微分进行更深入的介绍。 + +自动微分是现代机器学习和人工智能的关键要素之一。 + +正因如此,它吸引了大量的投资,目前已有几个强大的实现可供使用。 + +其中最优秀的之一是 JAX 中包含的自动微分例程。 + +虽然其他软件包也提供此功能,但 JAX 版本特别强大,因为它与 JAX 的其他核心组件(例如 JIT 编译和并行化)集成得非常好。 + +自动微分不仅可以用于人工智能,还可以用于数学建模中面临的许多问题,例如多维非线性优化和求根问题。 + +除了 Anaconda 中已有的内容外,本讲座还需要以下库: + +```{code-cell} ipython3 +:tags: [hide-output] + +!pip install jax +``` + +我们需要以下导入: + +```{code-cell} ipython3 +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +from sympy import symbols +``` + +## 什么是自动微分? + +自动微分(Autodiff)是一种在计算机上计算导数的技术。 + +### 自动微分不是有限差分 + +$f(x) = \exp(2x)$ 的导数为 + +$$ + f'(x) = 2 \exp(2x) +$$ + +不知道如何求导的计算机可能会用有限差分比率来近似: + +$$ + (Df)(x) := \frac{f(x+h) - f(x)}{h} +$$ + +其中 $h$ 是一个小正数。 + +```{code-cell} ipython3 +def f(x): + "Original function." + return np.exp(2 * x) + +def f_prime(x): + "True derivative." + return 2 * np.exp(2 * x) + +def Df(x, h=0.1): + "Approximate derivative (finite difference)." + return (f(x + h) - f(x))/h + +x_grid = np.linspace(-2, 1, 200) +fig, ax = plt.subplots() +ax.plot(x_grid, f_prime(x_grid), label="$f'$") +ax.plot(x_grid, Df(x_grid), label="$Df$") +ax.legend() +plt.show() +``` + +这种数值导数通常不准确且不稳定。 + +原因之一是: + +$$ + \frac{f(x+h) - f(x)}{h} \approx \frac{0}{0} +$$ + +分子和分母中的小数值会导致舍入误差。 + +在高维情况下或对高阶导数而言,情况会呈指数级恶化。 + ++++ + +### 自动微分不是符号微积分 + ++++ + +符号微积分尝试使用微分规则来生成表示导数的单一封闭形式表达式。 + +```{code-cell} ipython3 +m, a, b, x = symbols('m a b x') +f_x = (a*x + b)**m +f_x.diff((x, 6)) # 6-th order derivative +``` + +符号微积分不适合高性能计算。 + +一个缺点是符号微积分无法对控制流进行微分。 + +此外,使用符号微积分可能涉及冗余计算。 + +例如,考虑: + +$$ + (f g h)' + = (f' g + g' f) h + (f g) h' +$$ + +如果我们在 $x$ 处求值,那么 $f(x)$ 和 $g(x)$ 各会被计算两次。 + +另外,计算 $f'(x)$ 和 $f(x)$ 可能涉及类似的项(例如,$f(x) = \exp(2x) \implies f'(x) = 2f(x)$),但符号代数并不利用这一点。 + ++++ + +### 自动微分 + +自动微分生成的函数在调用代码传入数值时对导数进行求值,而不是生成表示整个导数的单一符号表达式。 + +导数通过链式法则将计算分解为各个组成部分来构建。 + +链式法则被反复应用,直到各项化简为程序知道如何精确微分的原始函数(加法、减法、指数、正弦和余弦等)。 + ++++ + +## 一些实验 + ++++ + +让我们从 $\mathbb R$ 上的一些实值函数开始。 + ++++ + +### 一个可微函数 + ++++ + +让我们用一个相对简单的函数来测试 JAX 的自动微分。 + +```{code-cell} ipython3 +def f(x): + return jnp.sin(x) - 2 * jnp.cos(3 * x) * jnp.exp(- x**2) +``` + +我们使用 `grad` 来计算实值函数的梯度: + +```{code-cell} ipython3 +f_prime = jax.grad(f) +``` + +让我们绘制结果: + +```{code-cell} ipython3 +x_grid = jnp.linspace(-5, 5, 100) +``` + +```{code-cell} ipython3 +fig, ax = plt.subplots() +ax.plot(x_grid, [f(x) for x in x_grid], label="$f$") +ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$") +ax.legend() +plt.show() +``` + +### 绝对值函数 + ++++ + +如果函数不可微会发生什么? + +```{code-cell} ipython3 +def f(x): + return jnp.abs(x) +``` + +```{code-cell} ipython3 +f_prime = jax.grad(f) +``` + +```{code-cell} ipython3 +fig, ax = plt.subplots() +ax.plot(x_grid, [f(x) for x in x_grid], label="$f$") +ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$") +ax.legend() +plt.show() +``` + +在不可微点 $0$ 处,`jax.grad` 返回右导数: + +```{code-cell} ipython3 +f_prime(0.0) +``` + +### 对控制流进行微分 + ++++ + +让我们尝试对一些循环和条件进行微分。 + +```{code-cell} ipython3 +def f(x): + def f1(x): + for i in range(2): + x *= 0.2 * x + return x + def f2(x): + x = sum((x**i + i) for i in range(3)) + return x + y = f1(x) if x < 0 else f2(x) + return y +``` + +```{code-cell} ipython3 +f_prime = jax.grad(f) +``` + +```{code-cell} ipython3 +x_grid = jnp.linspace(-5, 5, 100) +``` + +```{code-cell} ipython3 +fig, ax = plt.subplots() +ax.plot(x_grid, [f(x) for x in x_grid], label="$f$") +ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$") +ax.legend() +plt.show() +``` + +### 对线性插值进行微分 + ++++ + +我们可以对线性插值进行微分,即使函数不光滑: + +```{code-cell} ipython3 +n = 20 +xp = jnp.linspace(-5, 5, n) +yp = jnp.cos(2 * xp) + +fig, ax = plt.subplots() +ax.plot(x_grid, jnp.interp(x_grid, xp, yp)) +plt.show() +``` + +```{code-cell} ipython3 +f_prime = jax.grad(jnp.interp) +``` + +```{code-cell} ipython3 +f_prime_vec = jax.vmap(f_prime, in_axes=(0, None, None)) +``` + +```{code-cell} ipython3 +fig, ax = plt.subplots() +ax.plot(x_grid, f_prime_vec(x_grid, xp, yp)) +plt.show() +``` + +## 梯度下降 + ++++ + +让我们尝试实现梯度下降。 + +作为一个简单的应用,我们将使用梯度下降来求解简单线性回归中的普通最小二乘法参数估计值。 + ++++ + +### 梯度下降函数 + ++++ + +以下是梯度下降的实现。 + +```{code-cell} ipython3 +def grad_descent(f, # Function to be minimized + args, # Extra arguments to the function + x0, # Initial condition + λ=0.1, # Initial learning rate + tol=1e-5, + max_iter=1_000): + """ + Minimize the function f via gradient descent, starting from guess x0. + + The learning rate is computed according to the Barzilai-Borwein method. + + """ + + f_grad = jax.grad(f) + x = jnp.array(x0) + df = f_grad(x, args) + ϵ = tol + 1 + i = 0 + while ϵ > tol and i < max_iter: + new_x = x - λ * df + new_df = f_grad(new_x, args) + Δx = new_x - x + Δdf = new_df - df + λ = jnp.abs(Δx @ Δdf) / (Δdf @ Δdf) + ϵ = jnp.max(jnp.abs(Δx)) + x, df = new_x, new_df + i += 1 + + return x + +``` + +### 模拟数据 + +我们将通过最小化回归问题中的最小二乘和来测试我们的梯度下降函数。 + +让我们生成一些模拟数据: + +```{code-cell} ipython3 +n = 100 +key = jax.random.key(1234) +x = jax.random.uniform(key, (n,)) + +α, β, σ = 0.5, 1.0, 0.1 # Set the true intercept and slope. +key, subkey = jax.random.split(key) +ϵ = jax.random.normal(subkey, (n,)) + +y = α * x + β + σ * ϵ +``` + +```{code-cell} ipython3 +fig, ax = plt.subplots() +ax.scatter(x, y) +plt.show() +``` + +让我们首先使用封闭形式解来计算估计的斜率和截距。 + +```{code-cell} ipython3 +mx = x.mean() +my = y.mean() +α_hat = jnp.sum((x - mx) * (y - my)) / jnp.sum((x - mx)**2) +β_hat = my - α_hat * mx +``` + +```{code-cell} ipython3 +α_hat, β_hat +``` + +```{code-cell} ipython3 +fig, ax = plt.subplots() +ax.scatter(x, y) +ax.plot(x, α_hat * x + β_hat, 'k-') +ax.text(0.1, 1.55, rf'$\hat \alpha = {α_hat:.3}$') +ax.text(0.1, 1.50, rf'$\hat \beta = {β_hat:.3}$') +plt.show() +``` + +### 通过梯度下降最小化平方损失 + ++++ + +让我们看看是否可以用我们的梯度下降函数得到相同的值。 + +首先我们建立最小二乘损失函数。 + +```{code-cell} ipython3 +@jax.jit +def loss(params, data): + a, b = params + x, y = data + return jnp.sum((y - a * x - b)**2) +``` + +现在我们对其进行最小化: + +```{code-cell} ipython3 +p0 = jnp.zeros(2) # Initial guess for α, β +data = x, y +α_hat, β_hat = grad_descent(loss, data, p0) +``` + +让我们绘制结果。 + +```{code-cell} ipython3 +fig, ax = plt.subplots() +x_grid = jnp.linspace(0, 1, 100) +ax.scatter(x, y) +ax.plot(x_grid, α_hat * x_grid + β_hat, 'k-', alpha=0.6) +ax.text(0.1, 1.55, rf'$\hat \alpha = {α_hat:.3}$') +ax.text(0.1, 1.50, rf'$\hat \beta = {β_hat:.3}$') +plt.show() +``` + +注意,我们得到了与封闭形式解相同的估计值。 + ++++ + +### 添加二次项 + +现在让我们尝试拟合一个二次多项式。 + +以下是我们新的损失函数。 + +```{code-cell} ipython3 +@jax.jit +def loss(params, data): + a, b, c = params + x, y = data + return jnp.sum((y - a * x**2 - b * x - c)**2) +``` + +现在我们在三维空间中进行最小化。 + +让我们试试看。 + +```{code-cell} ipython3 +p0 = jnp.zeros(3) +α_hat, β_hat, γ_hat = grad_descent(loss, data, p0) + +fig, ax = plt.subplots() +ax.scatter(x, y) +ax.plot(x_grid, α_hat * x_grid**2 + β_hat * x_grid + γ_hat, 'k-', alpha=0.6) +ax.text(0.1, 1.55, rf'$\hat \alpha = {α_hat:.3}$') +ax.text(0.1, 1.50, rf'$\hat \beta = {β_hat:.3}$') +plt.show() +``` + +## 练习 + +```{exercise-start} +:label: auto_ex1 +``` + +函数 `jnp.polyval` 用于求多项式的值。 + +例如,如果 `len(p)` 为 3,那么 `jnp.polyval(p, x)` 返回: + +$$ + f(p, x) := p_0 x^2 + p_1 x + p_2 +$$ + +使用该函数进行多项式回归。 + +(经验)损失函数为: + +$$ + \ell(p, x, y) + = \sum_{i=1}^n (y_i - f(p, x_i))^2 +$$ + +设 $k=4$,将 `params` 的初始猜测值设为 `jnp.zeros(k)`。 + +使用梯度下降找到使损失函数最小化的数组 `params`,并绘制结果(参照上面的示例)。 + + +```{exercise-end} +``` + +```{solution-start} auto_ex1 +:class: dropdown +``` + +以下是一种解法。 + +```{code-cell} ipython3 +def loss(params, data): + x, y = data + return jnp.sum((y - jnp.polyval(params, x))**2) +``` + +```{code-cell} ipython3 +k = 4 +p0 = jnp.zeros(k) +p_hat = grad_descent(loss, data, p0) +print('Estimated parameter vector:') +print(p_hat) +print('\n\n') + +fig, ax = plt.subplots() +ax.scatter(x, y) +ax.plot(x_grid, jnp.polyval(p_hat, x_grid), 'k-', alpha=0.6) +plt.show() +``` + + +```{solution-end} +``` \ No newline at end of file diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index ce8bbea..e1b89e0 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -34,10 +34,14 @@ translation: JIT compilation::Evaluating a more complicated function: 评估更复杂的函数 JIT compilation::Evaluating a more complicated function::With NumPy: 使用 NumPy JIT compilation::Evaluating a more complicated function::With JAX: 使用 JAX - JIT compilation::Compiling the Whole Function: 编译整个函数 + JIT compilation::How JIT compilation works: JIT 编译的工作原理 + JIT compilation::Compiling the whole function: 编译整个函数 JIT compilation::Compiling non-pure functions: 编译非纯函数 JIT compilation::Summary: 总结 - Gradients: 梯度 + Vectorization with `vmap`: 使用 `vmap` 进行向量化 + Vectorization with `vmap`::A simple example: 一个简单的示例 + Vectorization with `vmap`::Combining transformations: 组合变换 + 'Automatic differentiation: a preview': 自动微分:预览 Exercises: 练习 --- @@ -78,20 +82,20 @@ JAX 的一个吸引人之处在于,它的数组处理操作在尽可能的情 ```{code-cell} ipython3 import jax -import quantecon as qe -``` - -此外,我们用以下代码替换 `import numpy as np`: - -```{code-cell} ipython3 import jax.numpy as jnp +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +import numpy as np +import quantecon as qe import matplotlib as mpl # i18n FONTPATH = "_fonts/SourceHanSerifSC-SemiBold.otf" # i18n mpl.font_manager.fontManager.addfont(FONTPATH) # i18n mpl.rcParams['font.family'] = ['Source Han Serif SC'] # i18n ``` -现在我们可以用 `jnp` 代替 `np` 来进行常规数组操作: +注意我们导入了 `jax.numpy as jnp`,它提供了类似 NumPy 的接口。 + +以下是使用 `jnp` 进行的一些标准数组操作: ```{code-cell} ipython3 a = jnp.asarray((1.0, 3.2, -1.5)) @@ -147,7 +151,6 @@ jnp.linalg.inv(B) # Inverse of identity is identity jnp.linalg.eigh(B) # Computes eigenvalues and eigenvectors ``` - ### 差异 现在让我们来看看 JAX 和 NumPy 数组操作之间的一些差异。 @@ -181,7 +184,6 @@ jnp.ones(3) 例如,在 NumPy 中我们可以这样写: ```{code-cell} ipython3 -import numpy as np a = np.linspace(0, 1, 3) a ``` @@ -220,13 +222,13 @@ a_new = a.sort() # Instead, the sort method returns a new sorted array a, a_new ``` -JAX 的设计者选择将数组设为不可变的,因为 JAX 使用[函数式编程](https://en.wikipedia.org/wiki/Functional_programming)风格。 +JAX 的设计者选择将数组设为不可变的,因为 JAX 使用 [函数式编程](https://en.wikipedia.org/wiki/Functional_programming) 风格。 这个设计选择有重要的含义,我们接下来将对此进行探讨! #### 变通方法 -我们注意到 JAX 确实提供了一种使用 [`at` 方法](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html)进行原地数组修改的版本。 +我们注意到 JAX 确实提供了一种使用 [`at` 方法](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) 进行原地数组修改的版本。 ```{code-cell} ipython3 a = jnp.linspace(0, 1, 3) @@ -248,7 +250,6 @@ a (尽管它在 JIT 编译的函数中实际上可以很高效——但现在先把这个放在一边。) - ## 函数式编程 来自 JAX 的文档: @@ -278,8 +279,6 @@ a * 不会改变全局状态 * 不会修改传递给函数的数据(不可变数据) - - ### 示例 以下是一个*非纯*函数的示例: @@ -315,7 +314,6 @@ def add_tax_pure(prices, tax_rate): 现在我们理解了什么是纯函数,让我们探索 JAX 处理随机数的方法如何维护这种纯粹性。 - ## 随机数 与 NumPy 或 Matlab 中的随机数相比,JAX 中的随机数有很大不同。 @@ -326,7 +324,6 @@ def add_tax_pure(prices, tax_rate): 此外,对随机状态的完全控制对于并行编程至关重要,例如当我们想要沿多个线程运行独立实验时。 - ### 随机数生成 在 JAX 中,随机数生成器的状态被显式控制。 @@ -335,7 +332,7 @@ def add_tax_pure(prices, tax_rate): ```{code-cell} ipython3 seed = 1234 -key = jax.random.PRNGKey(seed) +key = jax.random.key(seed) ``` 现在我们可以使用密钥生成一些随机数: @@ -365,6 +362,78 @@ jax.random.normal(key, (3, 3)) jax.random.normal(subkey, (3, 3)) ``` +下图说明了 `split` 如何从单个根密钥生成密钥树,每个密钥生成独立的随机抽取。 + +```{code-cell} ipython3 +:tags: [hide-input] + +fig, ax = plt.subplots(figsize=(8, 4)) +ax.set_xlim(-0.5, 6.5) +ax.set_ylim(-0.5, 3.5) +ax.set_aspect('equal') +ax.axis('off') + +box_style = dict(boxstyle="round,pad=0.3", facecolor="white", + edgecolor="black", linewidth=1.5) +box_used = dict(boxstyle="round,pad=0.3", facecolor="#d4edda", + edgecolor="black", linewidth=1.5) + +# Root key +ax.text(3, 3, "key₀", ha='center', va='center', fontsize=11, + bbox=box_style) + +# Level 1 +ax.annotate("", xy=(1.5, 2), xytext=(3, 2.7), + arrowprops=dict(arrowstyle="->", lw=1.5)) +ax.annotate("", xy=(4.5, 2), xytext=(3, 2.7), + arrowprops=dict(arrowstyle="->", lw=1.5)) +ax.text(1.5, 2, "key₁", ha='center', va='center', fontsize=11, + bbox=box_style) +ax.text(4.5, 2, "subkey₁", ha='center', va='center', fontsize=11, + bbox=box_used) +ax.text(5.7, 2, "→ draw", ha='left', va='center', fontsize=10, + color='green') + +# Label the split +ax.text(2, 2.65, "split", ha='center', va='center', fontsize=9, + fontstyle='italic', color='gray') + +# Level 2 +ax.annotate("", xy=(0.5, 1), xytext=(1.5, 1.7), + arrowprops=dict(arrowstyle="->", lw=1.5)) +ax.annotate("", xy=(2.5, 1), xytext=(1.5, 1.7), + arrowprops=dict(arrowstyle="->", lw=1.5)) +ax.text(0.5, 1, "key₂", ha='center', va='center', fontsize=11, + bbox=box_style) +ax.text(2.5, 1, "subkey₂", ha='center', va='center', fontsize=11, + bbox=box_used) +ax.text(3.7, 1, "→ draw", ha='left', va='center', fontsize=10, + color='green') + +ax.text(0.7, 1.65, "split", ha='center', va='center', fontsize=9, + fontstyle='italic', color='gray') + +# Level 3 +ax.annotate("", xy=(0, 0), xytext=(0.5, 0.7), + arrowprops=dict(arrowstyle="->", lw=1.5)) +ax.annotate("", xy=(1.5, 0), xytext=(0.5, 0.7), + arrowprops=dict(arrowstyle="->", lw=1.5)) +ax.text(0, 0, "key₃", ha='center', va='center', fontsize=11, + bbox=box_style) +ax.text(1.5, 0, "subkey₃", ha='center', va='center', fontsize=11, + bbox=box_used) +ax.text(2.7, 0, "→ draw", ha='left', va='center', fontsize=10, + color='green') +ax.text(0, 0.65, "split", ha='center', va='center', fontsize=9, + fontstyle='italic', color='gray') + +ax.text(3, -0.5, "⋮", ha='center', va='center', fontsize=14) + +ax.set_title("PRNG Key Splitting Tree", fontsize=13, pad=10) +plt.tight_layout() +plt.show() +``` + 对于 NumPy 或 Matlab 用户来说,这种语法看起来很不寻常——但当我们进入并行编程时,就会很有意义。 下面的函数使用 `split` 生成 `k` 个(准)独立的随机 `n x n` 矩阵。 @@ -382,7 +451,7 @@ def gen_random_matrices(key, n=2, k=3): ```{code-cell} ipython3 seed = 42 -key = jax.random.PRNGKey(seed) +key = jax.random.key(seed) matrices = gen_random_matrices(key) ``` @@ -400,11 +469,10 @@ def gen_random_matrices(key, n=2, k=3): ``` ```{code-cell} ipython3 -key = jax.random.PRNGKey(seed) +key = jax.random.key(seed) matrices = gen_random_matrices(key) ``` - ### 为什么要显式随机状态? 为什么 JAX 需要这种相对冗长的随机数生成方法? @@ -432,7 +500,6 @@ print(np.random.randn()) # Updates state of random number generator * 它是非确定性的:相同的输入(在这种情况下,没有输入)产生不同的输出 * 它有副作用:它修改了全局随机数生成器状态 - #### JAX 的方法 如上所示,JAX 采用了不同的方法,通过密钥使随机性显式化。 @@ -450,7 +517,7 @@ def random_sum_jax(key): 使用相同的密钥,我们总是得到相同的结果: ```{code-cell} ipython3 -key = jax.random.PRNGKey(42) +key = jax.random.key(42) random_sum_jax(key) ``` @@ -474,7 +541,6 @@ JAX 的显式性带来了显著的好处: 最后一点将在下一节中进行扩展。 - ## JIT 编译 JAX 的即时(JIT)编译器通过生成随任务大小和硬件变化的高效机器码来加速执行。 @@ -556,7 +622,6 @@ with qe.Timer(): 这就是为什么 JAX 要等到看到数组大小后再进行编译——这需要 JIT 编译方法,而不是提供预编译的二进制文件。 - #### 更改数组大小 这里我们更改输入大小并观察运行时间。 @@ -582,14 +647,13 @@ with qe.Timer(): 这是因为 JIT 编译器针对数组大小进行专门优化以利用并行化——因此当数组大小改变时会生成新的编译代码。 - ### 评估更复杂的函数 让我们用一个更复杂的函数尝试同样的操作。 ```{code-cell} def f(x): - y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - 0.1 * x**2 + y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2 return y ``` @@ -607,8 +671,6 @@ with qe.Timer(): y = f(x) ``` - - #### 使用 JAX 现在让我们用 JAX 再试一次。 @@ -641,8 +703,60 @@ with qe.Timer(): 结果与 `cos` 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。 -此外,使用 JAX,我们还有另一个技巧: +此外,使用 JAX,我们还有另一个技巧——我们可以对*整个*函数进行 JIT 编译,而不仅仅是单个操作。 + +### JIT 编译的工作原理 +当我们对一个函数应用 `jax.jit` 时,JAX 会对其进行*追踪*:它不会立即执行操作,而是将操作序列记录为计算图,并将该计算图交给 [XLA](https://openxla.org/xla) 编译器。 + +XLA 随后将这些操作融合并优化为针对可用硬件(CPU、GPU 或 TPU)定制的单个编译内核。 + +下图展示了一个简单函数的编译流程: + +```{code-cell} ipython3 +:tags: [hide-input] + +fig, ax = plt.subplots(figsize=(7, 2)) +ax.set_xlim(-0.2, 7.2) +ax.set_ylim(0.2, 2.2) +ax.axis('off') + +# Boxes for pipeline stages +stages = [ + (0.7, 1.2, "Python\nfunction"), + (2.6, 1.2, "computational\ngraph"), + (4.5, 1.2, "optimized\nkernel"), + (6.4, 1.2, "fast\nexecution"), +] + +colors = ["#e3f2fd", "#fff9c4", "#f3e5f5", "#d4edda"] + +for (x, y, label), color in zip(stages, colors): + box = mpatches.FancyBboxPatch( + (x - 0.7, y - 0.5), 1.4, 1.0, + boxstyle="round,pad=0.15", + facecolor=color, edgecolor="black", linewidth=1.5) + ax.add_patch(box) + ax.text(x, y, label, ha='center', va='center', fontsize=9) + +# Arrows with labels +arrows = [ + (1.4, 1.9, "trace"), + (3.3, 3.8, "XLA"), + (5.2, 5.7, "run"), +] + +for x_start, x_end, label in arrows: + ax.annotate("", xy=(x_end, 1.2), xytext=(x_start, 1.2), + arrowprops=dict(arrowstyle="->", lw=1.5, color="gray")) + ax.text((x_start + x_end) / 2, 1.55, label, + ha='center', fontsize=8, color='gray') + +plt.tight_layout() +plt.show() +``` + +对 JIT 编译函数的第一次调用会产生编译开销,但后续使用相同输入形状和类型的调用将复用缓存的编译代码,以全速运行。 ### 编译整个函数 @@ -727,7 +841,6 @@ f(x) 这个故事的寓意:使用 JAX 时请编写纯函数! - ### 总结 现在我们可以理解为什么开发者和编译器都受益于纯函数。 @@ -744,25 +857,77 @@ f(x) * 纯函数更容易进行微分(自动微分) * 纯函数更容易并行化和优化(不依赖于共享可变状态) +## 使用 `vmap` 进行向量化 -## 梯度 +另一个强大的 JAX 变换是 `jax.vmap`,它能自动将针对单个输入编写的函数 向量化,使其可以在批量数据上运行。 -JAX 可以使用自动微分来计算梯度。 +这避免了手动编写向量化代码或使用显式循环的需要。 -这对于优化和求解非线性系统非常有用。 +### 一个简单的示例 + +假设我们有一个函数,用于计算单个数组的汇总统计量: -我们将在本讲座系列后面看到重要的应用。 +```{code-cell} ipython3 +def summary(x): + return jnp.mean(x), jnp.median(x) +``` -现在,这里有一个非常简单的说明,涉及函数: +我们可以将其应用于单个向量: ```{code-cell} ipython3 -def f(x): - return (x**2) / 2 +x = jnp.array([1.0, 2.0, 5.0]) +summary(x) +``` + +现在假设我们有一个矩阵,并希望对每一行计算这些统计量。 + +不使用 `vmap` 时,我们需要一个显式循环: + +```{code-cell} ipython3 +X = jnp.array([[1.0, 2.0, 5.0], + [4.0, 5.0, 6.0], + [1.0, 8.0, 9.0]]) + +for row in X: + print(summary(row)) +``` + +然而,Python 循环速度较慢,无法被 JAX 高效编译或并行化。 + +使用 `vmap` 可以让计算保留在加速器上,并与其他 JAX 变换(如 `jit` 和 `grad`)组合使用: + +```{code-cell} ipython3 +batch_summary = jax.vmap(summary) +batch_summary(X) +``` + +函数 `summary` 是针对单个数组编写的,而 `vmap` 自动将其提升为按行作用于矩阵——无需循环,无需重塑。 + +### 组合变换 + +JAX 的一大优势在于变换可以自然地组合。 + +例如,我们可以对向量化函数进行 JIT 编译: + +```{code-cell} ipython3 +fast_batch_summary = jax.jit(jax.vmap(summary)) +fast_batch_summary(X) ``` -让我们求导数: +`jit`、`vmap` 以及(我们接下来将看到的)`grad` 的这种组合是 JAX 设计的核心,使其在科学计算和机器学习中尤为强大。 + +## 自动微分:预览 + +JAX 可以使用自动微分来计算梯度。 + +这对于优化和求解非线性系统非常有用。 + +以下是一个涉及函数 $f(x) = x^2 / 2$ 的简单示例: ```{code-cell} ipython3 +def f(x): + return (x**2) / 2 + f_prime = jax.grad(f) ``` @@ -770,11 +935,9 @@ f_prime = jax.grad(f) f_prime(10.0) ``` -让我们绘制函数和导数,注意 $f'(x) = x$。 +让我们绘制函数及其导数,注意 $f'(x) = x$。 ```{code-cell} ipython3 -import matplotlib.pyplot as plt - fig, ax = plt.subplots() x_grid = jnp.linspace(-4, 4, 200) ax.plot(x_grid, f(x_grid), label="$f$") @@ -783,8 +946,7 @@ ax.legend(loc='upper center') plt.show() ``` -我们将进一步探索 JAX 自动微分的内容推迟到 {doc}`jax:autodiff`。 - +自动微分是一个深刻的话题,在经济学和金融领域有许多应用。我们在 {doc}`关于自动微分的讲座 ` 中提供了更为深入的介绍。 ## 练习 @@ -826,7 +988,7 @@ def compute_call_price_jax(β=β, ρ=ρ, ν=ν, M=M, - key=jax.random.PRNGKey(1)): + key=jax.random.key(1)): s = jnp.full(M, np.log(S0)) h = jnp.full(M, h0) @@ -870,4 +1032,4 @@ with qe.Timer(): ``` ```{solution-end} -``` \ No newline at end of file +``` diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index 04c4a94..1450c47 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -24,6 +24,7 @@ translation: Sequential operations::Numba Version: Numba 版本 Sequential operations::JAX Version: JAX 版本 Sequential operations::Summary: 总结 + Overall recommendations: 总体建议 --- (parallel)= @@ -69,7 +70,10 @@ tags: [hide-output] ```{code-cell} ipython3 import random +from functools import partial + import numpy as np +import numba import quantecon as qe import matplotlib.pyplot as plt import matplotlib as mpl # i18n @@ -80,6 +84,7 @@ from mpl_toolkits.mplot3d.axes3d import Axes3D from matplotlib import cm import jax import jax.numpy as jnp +from jax import lax ``` ## 向量化运算 @@ -117,7 +122,7 @@ ax.plot_surface(x, y, f(x, y), rstride=2, cstride=2, - cmap=cm.jet, + cmap=cm.viridis, alpha=0.7, linewidth=0.25) ax.set_zlim(-0.5, 1.0) @@ -143,7 +148,6 @@ for x in grid: m = z ``` - ### NumPy 向量化 如果我们切换到 NumPy 风格的向量化,就可以使用更大的网格,并且代码执行速度相对较快。 @@ -168,14 +172,11 @@ print(f"NumPy result: {z_max_numpy:.6f}") (并行化效率不高,因为二进制文件在看到数组 `x` 和 `y` 的大小之前就已经被编译了。) - ### 与 Numba 的比较 现在让我们看看能否使用简单循环的 Numba 获得更好的性能。 ```{code-cell} ipython3 -import numba - @numba.jit def compute_max_numba(grid): m = -np.inf @@ -189,9 +190,9 @@ def compute_max_numba(grid): grid = np.linspace(-3, 3, 3_000) with qe.Timer(precision=8): - z_max_numpy = compute_max_numba(grid) + z_max_numba = compute_max_numba(grid) -print(f"Numba result: {z_max_numpy:.6f}") +print(f"Numba result: {z_max_numba:.6f}") ``` 让我们再次运行以消除编译时间。 @@ -207,7 +208,6 @@ with qe.Timer(precision=8): 另一方面,Numba 例程使用的内存少得多,因为我们只处理一个一维网格。 - ### 并行化的 Numba 现在让我们使用 `prange` 尝试 Numba 的并行化: @@ -282,7 +282,6 @@ with qe.Timer(precision=8): 对于更强大的机器和更大的网格尺寸,即使在 CPU 上,并行化也能带来显著的速度提升。 - ### 使用 JAX 的向量化代码 表面上,JAX 中的向量化代码与 NumPy 代码类似。 @@ -303,7 +302,7 @@ def f(x, y): ```{code-cell} ipython3 grid = jnp.linspace(-3, 3, 3_000) -x_mesh, y_mesh = np.meshgrid(grid, grid) +x_mesh, y_mesh = jnp.meshgrid(grid, grid) with qe.Timer(precision=8): z_max = jnp.max(f(x_mesh, y_mesh)) @@ -320,11 +319,10 @@ with qe.Timer(precision=8): z_max.block_until_ready() ``` -编译完成后,由于 GPU 加速,JAX 明显快于 NumPy。 +编译完成后,JAX 明显快于 NumPy,尤其是在 GPU 上。 编译开销是一次性成本,当函数被反复调用时,这种开销是值得的。 - ### JAX 加 vmap NumPy 代码和 JAX 代码都存在一个问题: @@ -386,7 +384,6 @@ with qe.Timer(precision=8): 当我们处理更大的问题时,将进一步探讨这些想法。 - ### vmap 版本 2 我们可以使用 vmap 进一步提高内存效率。 @@ -421,7 +418,7 @@ def compute_max_vmap_v2(grid): with qe.Timer(precision=8): z_max = compute_max_vmap_v2(grid).block_until_ready() -print(f"JAX vmap v1 result: {z_max:.6f}") +print(f"JAX vmap v2 result: {z_max:.6f}") ``` 让我们再次运行以消除编译时间: @@ -433,7 +430,6 @@ with qe.Timer(precision=8): 如果您像我们一样在 GPU 上运行,应该能看到又一个不小的速度提升。 - ### 总结 在我们看来,JAX 是向量化运算的赢家。 @@ -448,15 +444,13 @@ with qe.Timer(precision=8): 对于经济学、计量经济学和金融学中遇到的大多数情况,将高效并行化的工作交给 JAX 编译器,远比尝试手工编写这些例程要好得多。 - ## 顺序运算 某些运算本质上是顺序的——因此难以或不可能向量化。 在这种情况下,NumPy 是一个较差的选择,我们只剩下 Numba 或 JAX 可以选择。 -为了比较这两种选择,我们将重新回顾在{doc}`Numba 讲座 `中看到的迭代二次映射问题。 - +为了比较这两种选择,我们将重新回顾在 {doc}`Numba 讲座 ` 中看到的迭代二次映射问题。 ### Numba 版本 @@ -501,9 +495,6 @@ Numba 的编译通常相当快,对于像这样的顺序运算,生成的代 (我们将 `n` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。) ```{code-cell} ipython3 -from jax import lax -from functools import partial - cpu = jax.devices("cpu")[0] @partial(jax.jit, static_argnums=(1,), device=cpu) @@ -546,7 +537,6 @@ JAX 对于这种顺序运算也相当高效。 JAX 和 Numba 在编译后都能提供出色的性能,对于纯顺序运算,Numba 通常(但并非总是)提供略快的速度。 - ### 总结 虽然 Numba 和 JAX 在顺序运算中都能提供出色的性能,但**在代码可读性和易用性方面存在显著差异**。 @@ -559,4 +549,30 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然 此外,JAX 的不可变数组意味着我们无法简单地就地更新数组元素,这使得直接复制 Numba 使用的算法变得困难。 -对于这类顺序运算,在代码清晰度、实现便利性以及高性能方面,Numba 是明显的赢家。 \ No newline at end of file +对于这类顺序运算,在代码清晰度、实现便利性以及高性能方面,Numba 是明显的赢家。 + +## 总体建议 + +让我们退一步,总结一下各方案的权衡取舍。 + +对于**向量化操作**,JAX 是最强的选择。 + +得益于 JIT 编译和跨 CPU 与 GPU 的高效并行化,它在速度上与 NumPy 持平甚至超越 NumPy。 + +`vmap` 变换可以减少内存使用,并且通常比基于传统网格(meshgrid)的向量化方式产生更清晰的代码。 + +此外,JAX 函数支持自动微分,我们将在 {doc}`autodiff` 中进行探讨。 + +对于**顺序操作**,Numba 具有明显优势。 + +代码自然易读——只需一个带装饰器的 Python 循环——且性能出色。 + +JAX 可以通过 `lax.scan` 处理顺序问题,但对于纯顺序工作而言,其语法不够直观,性能提升也十分有限。 + +话虽如此,`lax.scan` 有一个重要优势:它支持对循环进行自动微分,而 Numba 无法做到这一点。 + +如果需要对顺序计算进行微分(例如,计算轨迹对模型参数的敏感性),尽管语法不够自然,JAX 仍是更好的选择。 + +在实践中,许多问题往往同时涉及两种模式。 + +一个实用的经验法则是:新项目默认使用 JAX,尤其是在硬件加速或可微分性可能有用的情况下;当需要一个快速且可读的紧凑顺序循环时,则选用 Numba。