class: middle, title-slide # Symmetry-Preserving Neural Networks ## CDS DS 595 ### Siddharth Mishra-Sharma [smsharma.io/teaching/ds595-ai4science](https://smsharma.io/teaching/ds595-ai4science.html) --- # Logistics 1. **Assignment 1:** due tomorrow (Wed Feb 18) 2. **Assignment 2:** released tomorrow (Wed Feb 18), due Wed Mar 4 3. **Office hours:** today (Tue) 3–5pm, CDS 1528 --- # Mildred Dresselhaus (1930–2017) .cols[ .col-1-2[ .center.width-60[] ] .col-1-2[ Physicists like Dresselhaus showed that **encoding symmetry** (via group theory) into your analysis was what made intractable problems tractable. .small.muted[Spiritual connection to today's lecture!] Predicted that a nanotube's electronic behavior depends on how the sheet is rolled up: **symmetry → physics**. **First woman tenured at MIT.** ] ] --- # Recap: MLPs .cols[ .col-1-2[ **Fully connected:** every input connects to every output. **No structure:** no assumptions about the data. .center[.eq-box[ $h\_{\ell+1} = \sigma\left( W^{(\ell)}\, h\_{\ell} + b^{(\ell)} \right)$ ]] ] .col-1-2[ .center.width-100[] ] ] --- # Recap: Inductive biases .cols[ .col-1-2[ An **inductive bias** = assumption built into the model. Many hypotheses fit the training data. The inductive bias determines which ones we **prefer**. The right inductive bias makes learning **much** easier. ] .col-1-2[ .center[  ] ] ] --- # Recap: Deep Sets .cols[ .col-1-2[ **Permutation invariance:** reorder the elements, output doesn't change. .center[.eq-box[ $f(\{x\_1, \ldots, x\_n\}) = \rho\left( \sum\_{i=1}^n \phi(x\_i) \right)$ ]] The sum doesn't care about order. ] .col-1-2[ .center.width-100[] ] ] --- # Recap: CNNs .cols[ .col-1-2[ **Translation equivariance:** shift the input, the output shifts the same way. Same kernel everywhere — the network processes every position identically. 
.center[.eq-box[ $z\_i = \sum\_{j \in \text{patch}(i)} w\_j \, x\_{i+j}$ ]] ] .col-1-2[ .center.width-90[] ] ] --- # Recap: GNNs .cols[ .col-1-2[ **Permutation equivariance:** relabel the nodes, the per-node outputs relabel the same way. .center[.eq-box[ $h\_v^{(\ell+1)} = \phi\big( h\_v^{(\ell)},\; \textstyle\sum\_{u \in \mathcal{N}(v)} \psi(h\_v^{(\ell)}, h\_u^{(\ell)}) \big)$ ]] Sum over neighbors is order-independent. ] .col-1-2[ .center.width-70[] ] ] --- class: center, middle, section-slide # Invariance and Equivariance --- # Invariance .cols[ .col-1-2[ Transform the input — output doesn't change. .center[.eq-box[ $f(g \cdot x) = f(x)$ ]] Energy, charge, binding affinity, mass — all scalars. .highlight[ **Scalars are invariant.** Rotate a molecule: same energy. ] ] .col-1-2[ .center.width-100[] ] ] --- # Equivariance .cols[ .col-1-2[ Transform the input — output transforms the same way. .center[.eq-box[ $f(g \cdot x) = g \cdot f(x)$ ]] Forces, velocities, dipole moments — all vectors. .highlight[ **Vectors are equivariant.** Rotate a molecule: force vectors rotate too. ] ] .col-1-2[ .center.width-100[] ] ] --- # Symmetry transformations .center.width-80[] --- # The symmetry groups | Group | Transformations | |---|---| | $S\_n$ | Permutations of $n$ objects | | $\mathrm{SO}(3)$ | Rotations in 3D | | $\mathrm{O}(3)$ | Rotations + reflections | | $\mathrm{SE}(3)$ | Rotations + translations | | $\mathrm{E}(3)$ | **Rotations + translations + reflections** | Most molecular properties: $\mathrm{E}(3)$-invariant (energy) or $\mathrm{E}(3)$-equivariant (forces). .small[Some quantities like chirality distinguish mirror images, requiring $\mathrm{SE}(3)$ instead.] 
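---

# Symmetries, numerically

The two definitions above can be unit-tested directly. A minimal NumPy sketch with toy `energy` (scalar) and `forces` (vector) functions, invented here for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

# Random rotation R in SO(3): orthogonalize, then fix the sign of det(R).
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
R = Q * np.sign(np.linalg.det(Q))

x = rng.normal(size=(5, 3))  # five points in 3D

def energy(x):
    # Scalar built from pairwise distances: rotation-invariant.
    return np.linalg.norm(x[:, None] - x[None, :], axis=-1).sum()

def forces(x):
    # Vectors pointing at the centroid: rotation-equivariant.
    return x.mean(axis=0) - x

assert np.isclose(energy(x @ R.T), energy(x))         # f(Rx) = f(x)
assert np.allclose(forces(x @ R.T), forces(x) @ R.T)  # f(Rx) = R f(x)
```

The same two assertions make a good regression test for any architecture claiming a symmetry: rotate the input, compare outputs.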
--- # From GNNs to geometric GNNs .cols[ .col-1-2[ Recall message passing: $$h\_v^{(\ell+1)} = \phi\big( h\_v^{(\ell)},\; \textstyle\sum\_{u \in \mathcal{N}(v)} \psi(h\_v^{(\ell)}, h\_u^{(\ell)}) \big)$$ ] .col-1-2[ .center.width-90[] ] ] --- count: false # From GNNs to geometric GNNs .cols[ .col-1-2[ Recall message passing: $$h\_v^{(\ell+1)} = \phi\big( h\_v^{(\ell)},\; \textstyle\sum\_{u \in \mathcal{N}(v)} \psi(h\_v^{(\ell)}, h\_u^{(\ell)}) \big)$$ What **geometric information** should enter the messages? .center[.eq-box[ $m\_{ij} = \psi\big(h\_i,\; h\_j,\; \underbrace{??\vphantom{d}}\_{\text{geometry}}\big)$ ]] ] .col-1-2[ .center.width-90[] ] ] --- # The geometric information hierarchy | Messages use | Symmetry | |---|---| | Nothing geometric | Permutation only (standard GNN) | | Relative positions $r\_i - r\_j$ | + Translation invariance | | Distances $\lVert r\_i - r\_j \rVert$ | + Rotation invariance | | Distances + relative vectors for coordinate updates | + Rotation **equivariance** | .highlight[ Each row restricts the geometric input further, gaining a symmetry — and shrinking the space of functions the network can represent. ] --- # SchNet: distance-based message passing .cols[ .col-1-2[ .center[.eq-box[ $h\_i' = h\_i + \sum\_{j \in \mathcal{N}(i)} \phi(h\_j) \cdot w(d\_{ij})$ ]] Neighbor features weighted by a learned function of distance. Distances discard directional information — two neighbors at the same distance but different angles are indistinguishable. .warning[How do we recover angular information?] ] .col-1-2[ .center.width-100[] .small.muted.center[General message passing. SchNet specializes it: messages depend only on distance.] ] ] .footnote[Schütt et al., "SchNet" (2017)] --- # Encoding distances: radial basis functions A single scalar $d$ doesn't give the network much to work with. Expand it: $$e\_k(d) = \exp\left( -\gamma\, (d - \mu\_k)^2 \right)$$ .center.width-80[] .small[Centers $\mu\_k$ spaced from 0 to a cutoff. 
Each Gaussian "activates" for distances near its center. One scalar becomes a $K$-dimensional vector.] --- # Towards equivariance .cols[ .col-1-2[ Distance-based networks are rotation-invariant — but they can only predict **scalars**. What if you need **forces**, **velocities**, **dipole moments**? These are **vectors** — they should rotate with the input. Distances alone can't give you that. ] .col-1-2[ .center.width-100[] ] ] --- # Nonlinearities on vectors .center.width-60[] --- count: false # Nonlinearities on vectors Three operations that **do** commute with rotation: .center.width-80[] .highlight[ Any function built from these three operations is automatically equivariant. ] --- # A recipe for equivariant networks .center.width-80[] .highlight[ Nonlinearity lives entirely in scalar-space. The vector pathway stays linear — just scaling and adding — which is all it needs to stay equivariant. ] --- # EGNN: equivariant message passing .cols[ .col-1-2[ Each node has scalar features $h\_i$ and coordinates $x\_i$. A layer updates both: **1.** $m\_{ij} = \phi\_e(h\_i, h\_j, d\_{ij}^2)$ .muted[scalar messages from invariant inputs] **2.** $x\_i' = x\_i + \textstyle\sum\_j (x\_i - x\_j)\,\phi\_x(m\_{ij})$ .muted[coordinate update: scalar $\times$ relative vector] **3.** $h\_i' = \phi\_h(h\_i, \textstyle\sum\_j m\_{ij})$ .muted[feature update from aggregated messages] ] .col-1-2[ .center.width-100[] ] ] .footnote[Satorras et al., "EGNN" (2021)] --- # EGNN: equivariant message passing .center.width-90[] .footnote[Satorras et al., "EGNN" (2021)] --- # EGNN with velocities .cols[ .col-1-2[ Each node can also carry a **velocity** $v\_i$ — another equivariant vector. $$v\_i' = \phi\_v(h\_i)\, v\_i + \textstyle\sum\_j (x\_i - x\_j)\,\phi\_x(m\_{ij})$$ $$x\_i' = x\_i + v\_i'$$ Useful for dynamics, where atoms have momenta. 
] .col-1-2[ .center.width-100[] ] ] .footnote[Satorras et al., "EGNN" (2021)] --- # Beyond scalars and vectors EGNN uses $L=0$ (scalars) and $L=1$ (vectors). But some local environments can only be distinguished with **higher-order features**. .center.width-50[] --- count: false # Beyond scalars and vectors .center.width-70[] Spherical harmonics $Y\_l^m$ are a basis for directional information — each $L$ block transforms predictably under rotation, which is why equivariant networks use them. Models using higher-order features ($L \geq 2$) are significantly more data-efficient. --- class: center, middle, section-slide # Applications --- # Geometric GNNs for atomic systems .center.width-90[] .small.muted[Duval et al., "A Hitchhiker's Guide to Geometric GNNs for 3D Atomic Systems" (2023)] --- # Molecular dynamics .cols[ .col-1-2[ $N$ atoms with positions $\{x\_i\}$, evolving under Newton's equations: $$m\_i \ddot{x}\_i = F\_i = -\nabla\_{x\_i} E(\{x\_j\})$$ $E(\{x\_j\})$ is the **potential energy surface**: all $N$ positions in, one scalar out. The force on each atom is the negative gradient — it captures the influence of every other atom. ] .col-1-2[ .center.width-90[] .small.muted.center[Protein in water] ] ] --- # Where forces come from .center.width-70[] .small.muted[Friederich et al., Nature Materials (2021)] --- # Symmetry and the gradient trick Energy is invariant: $\quad E(\{x\_i\}) = E(\{Rx\_i + t\})$ Forces are equivariant: $\quad F\_i(\{Rx\_j\}) = R\, F\_i(\{x\_j\})$ Train an **invariant** network to predict $E$, then get forces by automatic differentiation: .center[.eq-box[ $F\_i = -\nabla\_{x\_i} E$ ]] If $E$ is rotation-invariant, $F\_i$ is **automatically** rotation-equivariant — equivariance for free.
.center.width-60[] .small.muted[Batzner et al., NequIP (2022). Higher-order equivariant features ($L \geq 1$) achieve lower error at every training set size.] --- # Compute efficiency Even with infinite data, equivariant models use compute more efficiently. .center.width-30[] Both follow power-law scaling, but equivariant (red) maintains a consistent advantage at every compute budget. .small.muted[Brehmer et al. (2025)] --- # The alternative: data augmentation Instead of encoding symmetry, **train on transformed copies** of the data so the network learns invariance from examples. .center.width-50[] --- # The optimization perspective .cols[ .col-1-2[ Constraining the function space of the neural network can make optimization more challenging! A more general model has "more room to breathe" and can be easier to optimize. A tradeoff to consider. ] .col-1-2[ .center.width-120[] ] ] --- # Sometimes (?) the bitter lesson wins .cols[ .col-1-2[ AlphaFold 2 used SE(3)-equivariant layers throughout. AlphaFold 3 **dropped equivariance** in favor of a diffusion-based architecture with data augmentation. Apparently equivariance was simply not needed; simplicity won out. ] .col-1-2[ .center.width-100.shadow[] .small.muted.center[Abramson et al., Nature (2024)] ] ] --- # Summary .highlight[ **Invariance**: $f(g \cdot x) = f(x)$ **Equivariance**: $f(g \cdot x) = g \cdot f(x)$ ] .cols[ .col-1-2[ .center.width-100[] .center.small[Equivariant message passing updates both scalar features and coordinates.] ] .col-1-2[ .center.width-100[] .center.small[Built-in symmetry often means less data needed to reach the same accuracy.] ] ] --- # Next time: Generative models .center[
] .center.small.muted[Seedance 2.0 video gen]
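---

# Appendix: radial basis expansion in code

A minimal sketch of the Gaussian distance expansion from the SchNet slides. `n_rbf`, the cutoff, and `gamma` here are illustrative choices, not SchNet's published values:

```python
import numpy as np

def rbf_expand(d, n_rbf=16, cutoff=5.0, gamma=10.0):
    # Centers mu_k spaced evenly from 0 to the cutoff.
    mu = np.linspace(0.0, cutoff, n_rbf)
    # e_k(d) = exp(-gamma (d - mu_k)^2): one scalar -> K-dimensional vector.
    return np.exp(-gamma * (d[..., None] - mu) ** 2)

d = np.array([0.9, 1.6, 3.2])  # three pairwise distances
e = rbf_expand(d)              # shape (3, 16)
```

Each row peaks at the Gaussian whose center is nearest the input distance, giving the network a smooth, localized encoding to learn weights over instead of a single raw scalar.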
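---

# Appendix: the gradient trick in code

A sketch of "equivariance for free" in PyTorch. The toy pairwise-distance energy stands in for a trained invariant network:

```python
import torch

def energy(x):
    # Rotation-invariant toy energy: a function of pairwise distances only.
    return (torch.cdist(x, x) ** 2).sum()

def forces(x):
    # F_i = -dE/dx_i via automatic differentiation.
    x = x.clone().requires_grad_(True)
    return -torch.autograd.grad(energy(x), x)[0]

x = torch.randn(5, 3)

# Random rotation: orthogonalize, then fix the sign of the determinant.
Q, _ = torch.linalg.qr(torch.randn(3, 3))
R = Q * torch.sign(torch.linalg.det(Q))

# Invariant E  =>  equivariant F, with no equivariant layers anywhere.
assert torch.allclose(forces(x @ R.T), forces(x) @ R.T, atol=1e-4)
```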