Momentum 與 Adam

Momentum

傳統梯度下降在狹長山谷中會劇烈震盪。Momentum 加入「慣性」:

$$v_{t+1} = \beta v_t + \nabla L(w_t)$$ $$w_{t+1} = w_t - \eta v_{t+1}$$

Adam

Adam = Momentum + RMSProp,是目前最常用的最佳化器。

Vibe Prompt

「比較 SGD、Momentum、Adam 在 f(x,y)=x²+10y² 上的收斂速度,畫出三個最佳化器的參數軌跡。」

import numpy as np

def adam(grad, w0, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, steps=100):
    w = w0.copy()
    m = np.zeros_like(w)
    v = np.zeros_like(w)
    history = [w.copy()]
    for t in range(1, steps+1):
        g = grad(w)
        m = beta1 * m + (1-beta1) * g
        v = beta2 * v + (1-beta2) * g*g
        m_hat = m / (1-beta1**t)
        v_hat = v / (1-beta2**t)
        w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
        history.append(w.copy())
    return w, history

深入理解:三種最佳化器的核心差異

Momentum:用慣性克服振盪

標準 GD 在狹長山谷中會 zigzag,因為梯度在垂直方向來回擺動。Momentum 累積過去梯度的方向,像滾球一樣:

$$v_{t+1} = \beta v_t + \nabla L(w_t)$$

  • $\beta$(通常 0.9):控制慣性保留程度
  • $\beta=0$:退化為標準 GD
  • $\beta$ 越大,越平滑但可能 overshoot

RMSProp:為每個參數調整學習率

不同參數的梯度尺度可能差異很大。RMSProp 為每個參數維護獨立的學習率:

$$s_{t+1} = \beta_2 s_t + (1-\beta_2)(\nabla L(w_t))^2$$ $$w_{t+1} = w_t - \frac{\eta}{\sqrt{s_{t+1}} + \epsilon} \nabla L(w_t)$$

  • 梯度大的參數:學習率自動變小
  • 梯度小的參數:學習率自動變大

Adam:Momentum + RMSProp 的完美結合

| 特性 | SGD | Momentum | RMSProp | Adam | |------|:---:|:--------:|:-------:|:--------:| | 克服振盪 | ✗ | ✓ | ✗ | ✓ | | 自適應學習率 | ✗ | ✗ | ✓ | ✓ | | 慣性累積 | ✗ | ✓ | ✗ | ✓ | | 偏矯正 (Bias Correction) | ✗ | ✗ | ✗ | ✓ | | 收斂速度 | 慢 | 中 | 中 | 最快 |


實戰:三種最佳化器效能比較

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

# 建立一個 ill-conditioned 問題
A = np.array([[1, 0], [0, 20]])  # 條件數 20
b = np.array([0, 0])

def loss(w):
    return 0.5 * w.T @ A @ w - b.T @ w

def grad(w):
    return A @ w - b

def sgd_optimizer(grad_fn, w0, lr=0.1, steps=100):
    w = w0.copy()
    history = [w.copy()]
    for _ in range(steps):
        w = w - lr * grad_fn(w)
        history.append(w.copy())
    return np.array(history)

def momentum_optimizer(grad_fn, w0, lr=0.1, beta=0.9, steps=100):
    w = w0.copy()
    v = np.zeros_like(w)
    history = [w.copy()]
    for _ in range(steps):
        g = grad_fn(w)
        v = beta * v + g
        w = w - lr * v
        history.append(w.copy())
    return np.array(history)

def adam_optimizer(grad_fn, w0, lr=0.1, steps=100):
    w = w0.copy()
    m = np.zeros_like(w)
    v = np.zeros_like(w)
    history = [w.copy()]
    for t in range(1, steps+1):
        g = grad_fn(w)
        m = 0.9*m + 0.1*g
        v = 0.999*v + 0.001*g*g
        m_hat = m / (1-0.9**t)
        v_hat = v / (1-0.999**t)
        w = w - lr * m_hat / (np.sqrt(v_hat) + 1e-8)
        history.append(w.copy())
    return np.array(history)

w0 = np.array([5.0, 5.0])

h_sgd = sgd_optimizer(grad, w0, lr=0.05, steps=200)
h_mom = momentum_optimizer(grad, w0, lr=0.05, steps=200)
h_adam = adam_optimizer(grad, w0, lr=0.1, steps=200)

# 計算損失
print(f"SGD 終點損失: {loss(h_sgd[-1]):.6f}")
print(f"Momentum 終點損失: {loss(h_mom[-1]):.6f}")
print(f"Adam 終點損失: {loss(h_adam[-1]):.6f}")
print(f"SGD 震盪幅度: {np.std(h_sgd[-50:, 1]):.4f}")
print(f"Momentum 震盪: {np.std(h_mom[-50:, 1]):.4f}")

最佳化器選用指南

| 場景 | 推薦最佳化器 | 原因 | |------|:----------:|------| | 簡單凸問題 | SGD / Momentum | 不需要自適應學習率 | | Computer Vision | SGD + Momentum | 泛化能力較好 | | NLP / Transformer | Adam / AdamW | 處理稀疏梯度與不同尺度 | | GAN / 強化學習 | Adam | 穩定不穩定訓練 | | 超大模型 | AdamW | Adam + 權重衰減的正規化 | | 資源受限裝置 | RMSProp | 不需要儲存動量項 |


關鍵要點

  • ✅ Momentum 用慣性解決 GD 的振盪問題
  • ✅ RMSProp 為每個參數自適應調整學習率
  • ✅ Adam = Momentum + RMSProp + 偏矯正,是目前最通用的最佳化器
  • ✅ AdamW 改良了 Adam 的權重衰減方式,是 LLM 訓練的首選
  • ✅ 沒有萬能的最佳化器,需根據問題特性選擇


最佳化器的選擇實戰建議

沒有萬能的最佳化器。選擇取決於你的問題類型和資源。

實戰選擇指南

| 場景 | 推薦 | 理由 | |:----|:----|:----| | 簡單線性模型 | SGD | 不需要自適應學習率 | | CNN / 影像分類 | SGD + Momentum | 泛化能力較好 | | RNN / Transformer | Adam / AdamW | 處理稀疏梯度 | | GAN 訓練 | Adam | 穩定不穩定訓練 | | LLM 微調 | AdamW | Adam + 正確的權重衰減 |

Adam 的缺點

Adam 雖然收斂快,但泛化能力有時不如 SGD。AdamW 解決了這個問題——將權重衰減從 adaptive learning rate 中分離出來。

下一章預告:隨機梯度下降 SGD

Adam 每次用全部資料算梯度?不——它在 Mini-batch 上運作。下一章的 SGD 探討批次大小對訓練的影響。

解鎖完整教學內容

本章為付費內容。加入專案即可解鎖超過 5000 字的深度解析,包含 10 個以上神級 Prompt 與真實 Source Code 範例!