Adding Apple Silicon (MPS) Support to FlashABB
FlashABB is a fast antibody structure prediction library from the Oxford Protein Informatics Group (OXPIG), described in a preprint. It predicts paired VH/VL structures, scores developability (FlashTAP), and produces structure-aware sequence embeddings (FlashABB-SSS), all from a single "HEAVY|LIGHT" sequence string, with weights bundled in the repo.
The library was written for CUDA, and getting it to run on Apple Silicon meant fixing five separate issues. Working with Claude Code, I arrived at the following fixes:
---
Fix 1 — pip ≥ 26 editable install
pip 26 dropped the legacy setup.py-only editable install fallback. A minimal pyproject.toml declaring a PEP 660-capable build backend (setuptools ≥ 64) was all that was needed:
[build-system]
requires = ["setuptools>=64", "wheel"]
build-backend = "setuptools.build_meta"
---
Fix 2 — Device auto-detection
Three places hardcoded device='cuda'. I replaced them with a helper that prefers CUDA, then MPS, then CPU:
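import torch
def get_device() -> str:  # hypothetical name; the post does not give the helper's name
    if torch.cuda.is_available(): return 'cuda'
    if torch.backends.mps.is_available(): return 'mps'
    return 'cpu'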
Applied to pretrained, pretrained_tap, and flash_abb.py (including featurize() and to_pdbs(), which previously re-ran featurization on CPU regardless of where the model lived).
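The priority order can be unit-tested without Apple or NVIDIA hardware if the availability checks are passed in as arguments. A minimal sketch; `resolve_device` and its explicit-override parameter are hypothetical, not part of FlashABB:

```python
from typing import Optional


def resolve_device(requested: Optional[str], cuda_ok: bool, mps_ok: bool) -> str:
    """Choose a device string: an explicit request wins, else CUDA > MPS > CPU.

    Hypothetical helper. The availability flags are injected so the ordering
    can be tested anywhere; with PyTorch installed they would come from
    torch.cuda.is_available() and torch.backends.mps.is_available().
    """
    if requested is not None:
        return requested
    if cuda_ok:
        return "cuda"
    if mps_ok:
        return "mps"
    return "cpu"


print(resolve_device(None, cuda_ok=False, mps_ok=True))  # mps
print(resolve_device("cpu", cuda_ok=True, mps_ok=True))  # cpu
```

For the featurize()/to_pdbs() problem, a complementary PyTorch idiom is to read the device off the model instead of deciding it again, e.g. `device = next(model.parameters()).device`, and move featurized tensors there with `.to(device)`.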
---
Fix 3 — Deferred torch.utils.checkpoint import
checkpointing.py had a top-level import torch.utils.checkpoint. This triggers PyTorch's symbolic shapes machinery, which accesses sympy.core at class-definition time. In IPython/Jupyter, where sympy may be partially initialised when the first from flash_abb import ... cell runs, this raises:
AttributeError: module 'sympy' has no attribute 'core'
Fix: defer the import to the body of get_checkpoint_fn(), which is only reached during training — never during inference under torch.no_grad().
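The deferred-import pattern looks roughly like the sketch below. It is illustrative only: the real get_checkpoint_fn() in checkpointing.py may choose between several checkpoint implementations, and the point here is just where the import sits.

```python
import sys
from typing import Callable


def get_checkpoint_fn() -> Callable:
    """Return PyTorch's activation-checkpointing function (sketch).

    The import lives inside the function body, so importing the enclosing
    module does not load torch.utils.checkpoint or the sympy-dependent
    machinery behind it. That cost is paid only on the first call.
    """
    import torch.utils.checkpoint  # deferred: runs on first call, not at import time

    return torch.utils.checkpoint.checkpoint


# Defining the function ran no import, so in a fresh interpreter this prints False.
print("torch.utils.checkpoint" in sys.modules)
```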
---
Fix 4 — flex_attention import on PyTorch < 2.5
Both flashpoint_attention.py files imported torch.nn.attention.flex_attention, which first shipped in PyTorch 2.5. The imported symbols were never actually called (only a commented-out line referenced them), so wrapping the import in try/except lets the library load on any PyTorch 2.x release:
try:
    from torch.nn.attention.flex_attention import BlockMask, flex_attention, _score_mod_signature
except ImportError:
    pass  # requires PyTorch >= 2.5; unused in the inference path
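A slightly more defensive variant records whether the optional API loaded, so that any future caller gets a clear error rather than a NameError. This is a sketch; `HAS_FLEX_ATTENTION` and `require_flex_attention()` are hypothetical names, not FlashABB code:

```python
try:
    # torch.nn.attention.flex_attention first shipped in PyTorch 2.5
    from torch.nn.attention.flex_attention import BlockMask, flex_attention
    HAS_FLEX_ATTENTION = True
except ImportError:  # also catches ModuleNotFoundError when torch itself is absent
    BlockMask = None
    flex_attention = None
    HAS_FLEX_ATTENTION = False


def require_flex_attention():
    """Return flex_attention, or raise a readable error on older PyTorch."""
    if not HAS_FLEX_ATTENTION:
        raise RuntimeError("flex_attention requires PyTorch >= 2.5")
    return flex_attention


print(HAS_FLEX_ATTENTION)
```

Binding the names to None keeps module-level references valid, and the explicit flag makes the version dependency visible at the one place that would need it.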
---
Fix 5 — The critical one: SDPA dimension mismatch
FlashpointAttention is an IPA-style (Invariant Point Attention) geometry-aware attention module — not the Dao et al. Flash Attention algorithm. It combines scalar Q/K/V with 3D geometric point clouds per residue, which results in unequal query and value dimensions: d_q = 52, d_v = 168.
F.scaled_dot_product_attention on MPS silently returns an output tensor whose last dimension is d_q (52) instead of d_v (168) when d_q ≠ d_v. No error is raised at the attention call itself; the wrong-shaped tensor fails later, at a downstream split whose sizes (16 + 24 + 128) add up to d_v = 168:
(input tensor's size at dimension -1), but got split_sizes=[16, 24, 128]
CUDA handles unequal dims correctly. The fix falls back to a manual matmul + softmax, gated on MPS only so CUDA retains the efficient SDPA kernel:
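if q.device.type == "mps":  # MPS gate reconstructed; the PR's exact condition may differ
    # q is already scaled by 1/sqrt(d_q), so no extra scaling is applied here
    attn_weights = torch.matmul(q, k.transpose(-1, -2)) + attn_mask
    o = torch.matmul(attn_weights.softmax(dim=-1), v)
else:
    # scale=1.0 stops SDPA from applying its default 1/sqrt(d_q) a second time
    o = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=1.0)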
scale=1.0 is intentional: the original code pre-scales q by 1/√d_q before the call, so SDPA's default rescaling would double-count it. The manual path applies no implicit scaling — both branches are mathematically identical.
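To see why the two branches agree, and why the output width must be d_v, here is a dependency-free sketch of single-head attention on small lists. The sizes (2 queries, 3 keys, d_q = 4, d_v = 6) and random values are arbitrary stand-ins, not FlashABB's tensors, and the mask is omitted (an additive mask would be added to the logits identically in every case).

```python
import math
import random


def matmul(a, b):
    """Multiply matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax_rows(m):
    """Row-wise softmax, subtracting each row's max for numerical stability."""
    out = []
    for row in m:
        top = max(row)
        exps = [math.exp(x - top) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attention(q, k, v, scale):
    """softmax(scale * q @ k^T) @ v; the output has one column per column of v."""
    k_t = [list(col) for col in zip(*k)]
    logits = [[scale * x for x in row] for row in matmul(q, k_t)]
    return matmul(softmax_rows(logits), v)


def max_abs_diff(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


random.seed(0)
n_q, n_k, d_q, d_v = 2, 3, 4, 6  # toy sizes; only d_q != d_v matters here
q = [[random.gauss(0, 1) for _ in range(d_q)] for _ in range(n_q)]
k = [[random.gauss(0, 1) for _ in range(d_q)] for _ in range(n_k)]
v = [[random.gauss(0, 1) for _ in range(d_v)] for _ in range(n_k)]

ref = attention(q, k, v, scale=1 / math.sqrt(d_q))          # textbook scaled attention
q_pre = [[x / math.sqrt(d_q) for x in row] for row in q]    # pre-scaled q
patched = attention(q_pre, k, v, scale=1.0)                 # what both branches compute
double = attention(q_pre, k, v, scale=1 / math.sqrt(d_q))   # default scale on top

print(len(patched[0]))             # 6: the output width is d_v, not d_q
print(max_abs_diff(ref, patched))  # ~0: pre-scaling plus scale=1.0 matches
print(max_abs_diff(ref, double))   # clearly nonzero: scaling counted twice
```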
---
Numerical verification
To confirm the fallback is correct, I ran trastuzumab VH/VL through all three FlashABB capabilities on MPS (patched, this machine) and CUDA (original code, remote workstation) and compared outputs:
coords max_diff=5.22e-05 mean=1.63e-06 [PASS]
bb_coords max_diff=1.14e-05 mean=2.90e-06 [PASS]
tap_tensor max_diff=7.63e-05 mean=2.17e-05 [PASS]
sss_embeddings max_diff=1.58e-04 mean=9.48e-06 [PASS]
sss_mask max_diff=0.00e+00 mean=0.00e+00 [PASS]
The differences are float32 rounding noise from arithmetic differences between the two backends: the largest, 1.58e-04 in the SSS embeddings, is about five orders of magnitude below the values themselves. The PR with these fixes is open at oxpig/FlashABB.
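A report in the format of the log above can be produced with a small helper. This is a sketch of how such a check might look, not the script used here: `compare` is hypothetical, its `atol=1e-3` threshold is an assumption (the post does not state the pass criterion), and the inputs are plain lists standing in for flattened tensors.

```python
def compare(name, a, b, atol=1e-3):
    """Format max/mean absolute difference between two flat sequences.

    Hypothetical helper: atol is an illustrative threshold, not FlashABB's.
    """
    if len(a) != len(b):
        raise ValueError(f"{name}: length mismatch {len(a)} vs {len(b)}")
    diffs = [abs(x - y) for x, y in zip(a, b)]
    max_diff = max(diffs, default=0.0)
    mean_diff = sum(diffs) / len(diffs) if diffs else 0.0
    status = "PASS" if max_diff <= atol else "FAIL"
    return f"{name} max_diff={max_diff:.2e} mean={mean_diff:.2e} [{status}]"


mps_out = [1.0, 2.0, 3.0]       # toy stand-ins for flattened outputs
cuda_out = [1.0, 2.00001, 3.0]
print(compare("toy", mps_out, cuda_out))
```

With real tensors, `torch.allclose(a, b, rtol=..., atol=...)` gives a combined relative and absolute check, which suits outputs whose magnitudes differ widely (coordinates versus embeddings).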