自動微分
Vibe Prompt
「從零實作一個簡單的自動微分引擎，支援 +、-、*、sin、cos，並計算 f(x,y)=sin(x)*cos(y) 的梯度。」
import math
class Var:
    """A scalar node in a reverse-mode automatic-differentiation graph.

    Each ``Var`` stores its forward value in ``val``. After :meth:`backward`
    has been called on some output node, ``grad`` holds the derivative of
    that output with respect to this node. Supported operations: ``+``,
    binary and unary ``-``, ``*``, :meth:`sin` and :meth:`cos`. Plain
    numbers are accepted on either side of ``+``, ``-`` and ``*`` and are
    wrapped as constant ``Var`` nodes.
    """

    def __init__(self, val, children=()):
        """Create a node holding ``val``.

        Args:
            val: The forward (numeric) value of this node.
            children: The input nodes this value was computed from; empty
                for leaf variables and constants.
        """
        self.val = val
        self.grad = 0
        # Leaves have nothing to propagate; each operation replaces this
        # with a closure that pushes ``out.grad`` into its inputs.
        self._backward = lambda: None
        # Stored as a set so an input used twice (e.g. ``x * x``) is visited
        # once by the topological sort; the operation's closure still adds
        # both contributions to its gradient.
        self._prev = set(children)

    def __repr__(self):
        return f"{type(self).__name__}(val={self.val!r}, grad={self.grad!r})"

    def __add__(self, other):
        """Return the node ``self + other`` (both partials are 1)."""
        other = other if isinstance(other, Var) else Var(other)
        out = Var(self.val + other.val, (self, other))

        def _backward():
            self.grad += out.grad
            other.grad += out.grad

        out._backward = _backward
        return out

    def __radd__(self, other):
        """Return ``other + self`` when the left operand is a plain number."""
        return self + other

    def __mul__(self, other):
        """Return the node ``self * other`` (product rule)."""
        other = other if isinstance(other, Var) else Var(other)
        out = Var(self.val * other.val, (self, other))

        def _backward():
            self.grad += other.val * out.grad
            other.grad += self.val * out.grad

        out._backward = _backward
        return out

    def __rmul__(self, other):
        """Return ``other * self`` when the left operand is a plain number."""
        return self * other

    def __neg__(self):
        """Return ``-self``, expressed as multiplication by the constant -1."""
        return self * Var(-1)

    def __sub__(self, other):
        """Return ``self - other``, expressed as ``self + (-other)``."""
        return self + (-other)

    def __rsub__(self, other):
        """Return ``other - self`` when the left operand is a plain number."""
        return (-self) + other

    def sin(self):
        """Return the node ``sin(self)``; d/dself = cos(self.val)."""
        out = Var(math.sin(self.val), (self,))

        def _backward():
            self.grad += math.cos(self.val) * out.grad

        out._backward = _backward
        return out

    def cos(self):
        """Return the node ``cos(self)``; d/dself = -sin(self.val)."""
        out = Var(math.cos(self.val), (self,))

        def _backward():
            self.grad += -math.sin(self.val) * out.grad

        out._backward = _backward
        return out

    def backward(self):
        """Back-propagate from this node, filling ``grad`` on every ancestor.

        Sets this node's ``grad`` to 1, then runs each node's local
        ``_backward`` closure in reverse topological order, so a node's
        gradient is complete before it is pushed to its inputs.

        Gradients of non-root nodes are *accumulated* (``+=``), so calling
        this again on a graph without resetting ``grad`` adds to the old
        values.
        """
        # Iterative post-order DFS. A recursive walk would raise
        # RecursionError on graphs deeper than the interpreter's recursion
        # limit, e.g. a long chain of additions.
        topo = []
        visited = set()
        stack = [(self, False)]
        while stack:
            node, children_done = stack.pop()
            if children_done:
                # All inputs of ``node`` are already in ``topo``.
                topo.append(node)
                continue
            if node in visited:
                continue
            visited.add(node)
            # Marker is pushed first so it pops only after every child.
            stack.append((node, True))
            for child in node._prev:
                if child not in visited:
                    stack.append((child, False))

        self.grad = 1
        for node in reversed(topo):
            node._backward()
# Demo: evaluate f(x, y) = sin(x) * cos(y) and compare the computed
# gradients with the closed-form partial derivatives.
x = Var(0.5)
y = Var(1.2)
f = x.sin() * y.cos()
f.backward()
print(f"f = sin({x.val}) * cos({y.val}) = {f.val:.4f}")
# Reference values are derived from x.val / y.val rather than repeated
# literals, so they stay correct if the inputs above are changed.
print(f"df/dx = {x.grad:.4f} (理論: cos(x)*cos(y) = {math.cos(x.val) * math.cos(y.val):.4f})")
print(f"df/dy = {y.grad:.4f} (理論: -sin(x)*sin(y) = {-math.sin(x.val) * math.sin(y.val):.4f})")