Fused Batched Thin SVD: Engineering a 5000× Speedup with Triton Kernels
A generalized batched thin SVD system exploiting the Gram matrix shortcut, fused Triton kernels for N=2,3, and subspace-preserving Procrustes alignment.
Abstract
https://huggingface.co/AbstractPhil/svd-triton
We present a generalized batched thin SVD system for GPU that achieves 3,800–5,500× speedup over torch.linalg.svd for the N=2,3 case and 400–580× for N=4–32. The system auto-dispatches across three strategies based on the thin dimension N: fused Triton kernels for N≤3, Gram-Eigh hybrid for N=4–32, and rank-projected SVD for N≥48. We also introduce a subspace-preserving Procrustes alignment method that achieves 1.000 nearest-neighbor agreement with full-rank Procrustes while staying in the sub-millisecond compute zone. All components are designed for integration into neural network training loops where batched SVD is called every forward pass.
1. Motivation
Modern geometric deep learning architectures use SVD decomposition as a structural primitive inside the forward pass — decomposing feature maps into spatial modes (U), energy distributions (S), and channel mixing patterns (Vh). On CIFAR-sized images (32×32 pixels, 3 channels), this means batched SVD of (B, 1024, 3) matrices where B=512. The "thin" dimension is only 3.
cuSOLVER, PyTorch's backend for torch.linalg.svd, dispatches through a general-purpose bidiagonalization pipeline. For M×3 matrices, 99.9% of the compute is dispatch overhead — kernel launch, workspace allocation, format conversion. The actual 3×3 eigensolve takes nanoseconds, but cuSOLVER doesn't know that.
As of this writing (March 24, 2026), this overhead matters; a future cuSOLVER release may tell a different story.
Benchmark on NVIDIA RTX PRO 6000 Blackwell (B=512, M=1024):
| Method | N=2 | N=3 | N=8 | N=32 |
|---|---|---|---|---|
| torch.linalg.svd | 79.8ms | 117.5ms | 169.3ms | 303.1ms |
| Our system | 0.021ms | 0.022ms | 0.290ms | 0.781ms |
| Speedup | 3,850× | 5,488× | 584× | 388× |
2. Mathematical Foundation
2.1 The Eckart-Young Shortcut
For an M×N matrix A where M >> N, the standard SVD approach works on the full M×N matrix. The thin-matrix shortcut exploits the Gram matrix:
G = A^T A (N×N, symmetric positive semi-definite)
The eigendecomposition of G gives us everything:
- Eigenvalues of G = σ² (squared singular values of A)
- Eigenvectors of G = V (right singular vectors of A)
- U = A V S^{-1} (left singular vectors recovered via matrix multiply)
This reduces an M×N SVD to an N×N eigendecomposition plus two matrix multiplies. For M=1024, N=3, we replace a 1024×3 bidiagonalization with a 3×3 eigensolver — a problem that fits entirely in scalar registers.
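The identity is easy to verify. Below is a minimal NumPy sketch (illustrative, not the fused kernel; `thin_svd_via_gram` is a name chosen here) that recovers the thin SVD from the N×N Gram eigendecomposition and checks it against the reference SVD:

```python
import numpy as np

def thin_svd_via_gram(A):
    """Thin SVD of an (M, N) matrix with M >> N via the N x N Gram matrix."""
    G = A.T @ A                           # (N, N) symmetric PSD
    evals, V = np.linalg.eigh(G)          # eigenvalues in ascending order
    order = np.argsort(evals)[::-1]       # descending singular-value order
    S = np.sqrt(np.clip(evals[order], 0.0, None))
    V = V[:, order]
    U = (A @ V) / S                       # recovery: U = A V S^{-1}
    return U, S, V.T

rng = np.random.default_rng(0)
A = rng.standard_normal((1024, 3))
U, S, Vh = thin_svd_via_gram(A)
assert np.allclose(S, np.linalg.svd(A, compute_uv=False))  # same singular values
assert np.allclose(U @ np.diag(S) @ Vh, A)                 # exact reconstruction
```

Note that squaring A into G also squares the condition number, which is harmless here because the thin dimension is tiny and the inputs are well-conditioned features.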
Mathematical lineage:
- Eckart & Young (1936): G = A^T A eigendecomposition equivalence
- Jacobi (1846): Cyclic Givens rotations for symmetric eigendecomposition
- Golub & Reinsch (1970): U = A V S^{-1} recovery formula
- Batcher (1968): Optimal sorting networks for eigenvalue ordering
2.2 Jacobi Eigensolve for 3×3 Symmetric Matrices
The 3×3 symmetric eigendecomposition decomposes into three cyclic Jacobi rotations per sweep, each zeroing one off-diagonal element:
Rotation (p, q): Given off-diagonal element g_pq and diagonal difference g_qq - g_pp:
τ = (g_qq - g_pp) / (2 · g_pq)
t = sign(τ) / (|τ| + √(1 + τ²))
c = 1 / √(1 + t²)
s = t · c
This computes the Givens rotation angle without trigonometric functions — only arithmetic and one square root.
One sweep applies rotations to pairs (0,1), (0,2), (1,2) in sequence. The matrix G converges to diagonal form (eigenvalues) while accumulated rotations form V (eigenvectors). For 3×3, 4–6 sweeps achieve machine precision. We use 6 for safety.
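The sweep can be checked with a plain NumPy sketch, where loops and explicit rotation matrices stand in for the kernel's unrolled scalar-register updates (`jacobi_eigh_3x3` is an illustrative name):

```python
import numpy as np

def jacobi_eigh_3x3(G, sweeps=6):
    """Eigendecomposition of a symmetric 3x3 matrix via cyclic Jacobi rotations."""
    G = G.copy()
    V = np.eye(3)
    for _ in range(sweeps):
        for p, q in [(0, 1), (0, 2), (1, 2)]:
            if abs(G[p, q]) < 1e-30:
                continue
            # Rotation without trig: tau, t, c, s exactly as in the text.
            tau = (G[q, q] - G[p, p]) / (2.0 * G[p, q])
            t = np.sign(tau) / (abs(tau) + np.sqrt(1.0 + tau * tau))
            c = 1.0 / np.sqrt(1.0 + t * t)
            s = t * c
            J = np.eye(3)
            J[p, p] = J[q, q] = c
            J[p, q], J[q, p] = s, -s
            G = J.T @ G @ J               # zeroes out G[p, q]
            V = V @ J                     # accumulate eigenvectors
    return np.diag(G), V

G = np.array([[4.0, 1.0, 0.5],
              [1.0, 3.0, 0.2],
              [0.5, 0.2, 2.0]])
w, V = jacobi_eigh_3x3(G)
assert np.allclose(np.sort(w), np.linalg.eigvalsh(G))
assert np.allclose(V @ np.diag(w) @ V.T, G)
```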
2.3 The 2×2 Special Case
For N=2, the eigendecomposition is a single Jacobi rotation — one application of the formula above, no iteration needed. This makes the N=2 kernel the simplest possible fused SVD: accumulate G (3 values), one rotation, sort, recover U.
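To make the single-rotation case concrete, here is a small NumPy sketch (illustrative names; the production kernel holds these six scalars in registers rather than arrays). The rotated diagonal entries follow the standard closed form g00 − t·g01 and g11 + t·g01:

```python
import numpy as np

def eigh_2x2(g00, g01, g11):
    """Eigendecomposition of [[g00, g01], [g01, g11]]: one Jacobi rotation, no iteration."""
    if g01 == 0.0:
        return np.array([g00, g11]), np.eye(2)
    tau = (g11 - g00) / (2.0 * g01)
    t = np.sign(tau) / (abs(tau) + np.sqrt(1.0 + tau * tau))
    c = 1.0 / np.sqrt(1.0 + t * t)
    s = t * c
    V = np.array([[c, s], [-s, c]])
    evals = np.array([g00 - t * g01, g11 + t * g01])  # rotated diagonal
    return evals, V

w, V = eigh_2x2(3.0, 1.0, 2.0)
G = np.array([[3.0, 1.0], [1.0, 2.0]])
assert np.allclose(V @ np.diag(w) @ V.T, G)
```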
3. Kernel Architecture
3.1 Fused Triton Kernel for N=3
The kernel is structured in three stages, all within a single program instance per batch element:
Stage 1 — Gram Accumulation (tiled over M):
for each tile of BLOCK_M rows:
    load a0, a1, a2 (3 columns)
    g00 += sum(a0 * a0)
    g01 += sum(a0 * a1)
    ...
Six scalar accumulators. No shared memory. The only global memory access is loading A in tiles.
Stage 2 — Jacobi Eigensolver (in registers):
v[3×3] = I (9 scalar registers)
for 6 sweeps:
    rotate(0,1): update g00,g01,g11,g02,g12 + v columns 0,1
    rotate(0,2): update g00,g02,g22,g01,g12 + v columns 0,2
    rotate(1,2): update g11,g12,g22,g01,g02 + v columns 1,2
15 scalar registers for G (6 unique) and V (9). Zero shared memory, zero global memory. The entire eigensolver is pure register arithmetic.
Stage 2b — Sort (Batcher network): Three compare-and-swap operations on (s0,s1), (s0,s2), (s1,s2) with corresponding V column permutation. Produces descending singular value order.
Stage 3 — U Recovery (tiled over M):
for each tile of BLOCK_M rows:
    load a0, a1, a2
    u_col_j = (a0*v[0,j] + a1*v[1,j] + a2*v[2,j]) / s_j
    store u0, u1, u2
Same tiling as Stage 1. V entries are in registers from Stage 2.
Resource usage per program:
- Registers: 6 (G) + 9 (V) + 3 (S) + 6 (tile loads) ≈ 24 scalar registers
- Shared memory: 0 bytes
- Global memory: 2 passes over A (read), 1 write each for U, S, Vh
- Programs: B (one per batch element)
3.2 Why This Is Fast
The kernel eliminates every source of overhead that makes cuSOLVER slow for small N:
- Single kernel launch vs cuSOLVER's multi-kernel pipeline (workspace alloc → format conversion → bidiagonalization → QR iteration → back-conversion)
- Zero shared memory — the entire 3×3 eigensolver lives in registers
- Minimal global memory — A is read twice (Gram + U recovery), S/Vh/U written once
- No synchronization — each program is independent, no inter-thread communication
- No workspace — cuSOLVER allocates temporary buffers per batch element
The result: the kernel is limited only by global memory bandwidth for the A loads. On Blackwell, the actual compute (Jacobi iterations) takes <1μs per batch element. The 0.022ms wall time at B=512 is dominated by the kernel launch overhead and the two tiled reads of A.
3.3 The N≥4 Cliff and Gram-Eigh Hybrid
For N≥4, the Jacobi eigensolver would need N(N-1)/2 pairs per sweep with N² register slots for V — quickly exceeding register files. Instead, we use a hybrid approach:
G = torch.bmm(A.transpose(1, 2), A)            # cuBLAS bmm, highly optimized
eigenvalues, V = torch.linalg.eigh(G)          # cuSOLVER eigh on N×N (ascending order)
S = eigenvalues.flip(-1).clamp_min(0).sqrt()   # descending singular values
U = torch.bmm(A, V.flip(-1)) / S.unsqueeze(1)  # cuBLAS bmm
The key insight: bmm on (B, N, M) × (B, M, N) → (B, N, N) is bandwidth-bound and highly optimized in cuBLAS. The eigh call on (B, N, N) is the bottleneck, but for N≤32, it's still sub-millisecond because the N×N matrices are tiny.
The cliff at N=48: torch.linalg.eigh serializes across the batch dimension for matrices larger than ~32×32. At N=48, timing jumps from 0.78ms to 344ms — a 440× wall. This is a cuSOLVER limitation, not a mathematical one.
4. Dispatch Strategy
def batched_svd(A, method='auto', target_rank=None):
    B, M, N = A.shape
    if N == 2:
        return batched_svd2(A)                # Fused Triton, 0.02ms
    elif N == 3:
        return batched_svd3(A)                # Fused Triton, 0.02ms
    elif N <= 32:
        return gram_eigh_svd(A)               # Gram + eigh, 0.25-0.78ms
    elif target_rank:
        return projected_svd(A, target_rank)  # Rank-k approx, ~11ms
    else:
        return gram_eigh_svd(A)               # Exact but slow, 344ms+
4.1 Rank-Projected SVD for N≥48
For applications that don't need all N singular values, we project to a smaller subspace first:
P = randn(N, k+oversampling) / sqrt(k) # Random projection
A_proj = A @ P # (B, M, k+oversampling)
U_k, S_k, Vh_k = gram_eigh_svd(A_proj) # Cheap: (k+8)×(k+8) eigh
Vh_full = Vh_k @ P.T # Lift back to N-d
U_full = A @ Vh_full.T / S_k # Recover U
This is a simplified Halko-Martinsson-Tropp (2011) randomized SVD. With k=24 and oversampling=8, the internal eigh operates on 32×32 matrices (sub-ms) regardless of the original N. The quality depends on spectrum concentration: neural network features are typically highly low-rank, making this approximation excellent in practice.
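For reference, here is a runnable single-matrix NumPy version of the closely related standard HMT scheme. It differs from the lighter sketch above in one place: it re-orthonormalizes the sketch with a QR factorization before the small SVD, which makes the recovered singular values exact for low-rank inputs (`randomized_svd` is a name chosen here, not the repo's API):

```python
import numpy as np

def randomized_svd(A, k=24, oversampling=8, seed=0):
    """Rank-k randomized SVD (Halko-Martinsson-Tropp): sketch, orthonormalize, small SVD."""
    rng = np.random.default_rng(seed)
    M, N = A.shape
    P = rng.standard_normal((N, k + oversampling))  # random projection
    Q, _ = np.linalg.qr(A @ P)                      # (M, k+os) orthonormal range basis
    B = Q.T @ A                                     # small (k+os, N) problem
    U_b, S, Vh = np.linalg.svd(B, full_matrices=False)
    return (Q @ U_b)[:, :k], S[:k], Vh[:k]

# Rank-10 test matrix: a k=24 sketch captures its range exactly (up to fp error).
rng = np.random.default_rng(1)
A = rng.standard_normal((1024, 10)) @ rng.standard_normal((10, 128))
U, S, Vh = randomized_svd(A, k=24)
S_ref = np.linalg.svd(A, compute_uv=False)
assert np.allclose(S[:10], S_ref[:10])
```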
5. Subspace-Preserving Procrustes Alignment
5.1 The Problem
Procrustes alignment finds the optimal rotation R that minimizes ||source @ R - target||. For N≤32, the standard approach (cross-covariance SVD) is sub-millisecond. For N≥48, it hits the eigh cliff.
5.2 The Solution
Project both spaces to k=24 dimensions, align there, and lift back to N-d. The critical insight: don't try to reconstruct the full rotation. Instead, decompose source into two orthogonal components and handle them separately:
P = QR(randn(N, k)).Q              # Orthonormal projection basis
src_in = source @ P                # Component in k-d subspace
src_perp = source - src_in @ P.T   # Component orthogonal to subspace
tgt_in = target @ P
# Procrustes in k-d (cheap)
C_k = src_in.T @ tgt_in
R_k = SVD(C_k).U @ SVD(C_k).Vh
# Subspace-preserving lift
aligned = src_in @ R_k @ P.T + src_perp  # Rotate seen, preserve unseen
The in-subspace component gets the full k-d rotation. The orthogonal complement is left untouched. This is mathematically exact for the subspace the projection can see, and identity (no-op) for dimensions it can't.
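A runnable NumPy sketch of the decomposition (illustrative single-matrix version; the repo's `batched_procrustes` is the batched equivalent). The assertions check the two defining properties: the k-d rotation is orthogonal, and the component outside the subspace passes through unchanged:

```python
import numpy as np

def subspace_procrustes(source, target, k=24, seed=0):
    """Align source to target with a rotation found in a random k-d subspace."""
    rng = np.random.default_rng(seed)
    N = source.shape[1]
    P, _ = np.linalg.qr(rng.standard_normal((N, k)))  # orthonormal (N, k) basis
    src_in = source @ P                               # coordinates inside the subspace
    src_perp = source - src_in @ P.T                  # component the projection can't see
    tgt_in = target @ P
    # Orthogonal Procrustes in k dimensions: R = U Vh of the cross-covariance.
    U, _, Vh = np.linalg.svd(src_in.T @ tgt_in)
    R_k = U @ Vh
    return src_in @ R_k @ P.T + src_perp, R_k, P

rng = np.random.default_rng(2)
source = rng.standard_normal((256, 64))
target = rng.standard_normal((256, 64))
aligned, R_k, P = subspace_procrustes(source, target, k=24)
assert np.allclose(R_k.T @ R_k, np.eye(24))           # proper k-d rotation
# The orthogonal complement of the subspace is untouched.
assert np.allclose(aligned - (aligned @ P) @ P.T, source - (source @ P) @ P.T)
```

Because the rotation acts only inside span(P), the alignment also preserves the Frobenius norm of the source exactly.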
5.3 Validation
We tested five lift-back methods across N=32-128 and k=8-64:
| Method | Cosine quality | NN agreement | Notes |
|---|---|---|---|
| Naive (P @ R @ pinv(P)) | 0.16-0.31 | 0.08-0.39 | Smears rotation, destroys perp dims |
| LERP (α blend with I) | 0.37-0.43 | 0.66-0.76 | Conservative but blurry |
| SLERP (geodesic) | Failed | — | matrix_log numerically unstable |
| Subspace-preserving | 0.38-0.44 | 1.000 | Exact for seen dims, identity for unseen |
| Stay in k-d | 0.38-0.44 | — | Reference (different space) |
Subspace-preserving achieves 1.000 nearest-neighbor agreement with full Procrustes across every configuration tested. The downstream task literally cannot distinguish between the two alignments.
5.4 Newton-Schulz Whitening
For Procrustes with whitening (required for proper alignment), we provide newton_schulz_invsqrt(G) which computes G^{-1/2} via pure batched matrix multiplies — zero eigensolvers:
Y, Z = G / trace(G), I
for _ in range(10):
    factor = 1.5*I - 0.5*Z@Y
    Y, Z = Y@factor, factor@Z
# Z ≈ (G/trace(G))^{-1/2}; rescale by 1/sqrt(trace(G)) to recover G^{-1/2}
Quadratic convergence. 10 iterations for machine precision. Each iteration is 2 bmm operations. The entire whitening chain is ~1ms regardless of N.
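A NumPy sketch of the same coupled iteration with the trace normalization and final rescale made explicit (illustrative; the shipped `newton_schulz_invsqrt` works on torch tensors via bmm):

```python
import numpy as np

def newton_schulz_invsqrt(G, iters=10):
    """Batched G^{-1/2} for SPD G of shape (B, N, N) -- pure matmuls, no eigensolver."""
    B, N, _ = G.shape
    I = np.broadcast_to(np.eye(N), (B, N, N))
    # Normalize so eigenvalues of Y lie in (0, 1]; Newton-Schulz needs this to converge.
    c = np.trace(G, axis1=-2, axis2=-1).reshape(B, 1, 1)
    Y, Z = G / c, I
    for _ in range(iters):
        T = 1.5 * I - 0.5 * (Z @ Y)
        Y, Z = Y @ T, T @ Z
    return Z / np.sqrt(c)          # Z converges to (G/c)^{-1/2}; undo the scaling

rng = np.random.default_rng(3)
A = rng.standard_normal((4, 8, 8))
G = A @ A.transpose(0, 2, 1) + 0.5 * np.eye(8)  # well-conditioned SPD batch
W = newton_schulz_invsqrt(G)
# W @ G @ W should be (approximately) the identity for every batch element.
assert np.allclose(W @ G @ W, np.eye(8), atol=1e-6)
```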
6. Integration with Neural Networks
6.1 AMP Compatibility
torch.linalg.eigh does not support bfloat16. When running under torch.amp.autocast, the AMP context silently casts float32 tensors back to bf16 before the eigh call, causing NotImplementedError. All linalg operations must be wrapped:
def gram_eigh_svd(A):
    with torch.amp.autocast('cuda', enabled=False):
        A_f = A.float()
        G = torch.bmm(A_f.transpose(1, 2), A_f)
        eigenvalues, V = torch.linalg.eigh(G)
        ...
6.2 Gradient Considerations
The backward pass through torch.linalg.eigh is numerically unstable for repeated or near-zero eigenvalues — common when the projected features don't have full rank. For observation taps (SVD features as input to a classifier), we recommend detaching the SVD output:
with torch.no_grad():
    _, S, Vh = gram_eigh_svd(features)
    S = S.clamp(min=1e-6)
Gradient flows through the conv/transformer backbone via the normal path. The SVD provides complementary structural features that the classifier learns to interpret.
6.3 Empirical Results
On CIFAR-100 with a 4-stage ConvEncoder (3.9M params), adding detached SVD observation taps at each depth:
| Model | Val accuracy | Params |
|---|---|---|
| Conv only (pooled features) | ~65-68% | 3.4M |
| Conv + SVD features | 70.9% | 3.9M |
| SVD contribution | +3-6 points | +0.5M |
Feature attribution analysis revealed that the SVD features and conv features are deeply entangled — zeroing either path collapses accuracy to near-chance. The classifier learned a nonlinear interaction where conv tells "what" and SVD tells "how the conv organized itself."
7. API Reference
# Unified dispatcher
U, S, Vh = batched_svd(A) # Auto-dispatch by N
U, S, Vh = batched_svd(A, method='triton') # Force fused kernel (N=2,3)
U, S, Vh = batched_svd(A, method='gram_eigh') # Force Gram+eigh path
# Standalone utilities
G_invsqrt = newton_schulz_invsqrt(G) # Batched G^{-1/2}, pure bmm
aligned, info = batched_procrustes(src, tgt, rank=24) # Subspace Procrustes
8. Reproducing Results
The profiling suite validates correctness against torch.linalg.svd and benchmarks all methods:
from triton_svd_general import main
main() # Runs validation + all benchmarks
Required: PyTorch ≥ 2.0, Triton ≥ 2.1, CUDA GPU.
Citation
@software{abstractphil2026svd,
  title={Fused Batched Thin SVD: Triton Kernels for Geometric Deep Learning},
  author={AbstractPhil and Claude},
  year={2026},
  url={https://huggingface.co/AbstractPhil}
}
License
Apache 2.0