Adding Apple Silicon (MPS) Support to FlashABB

FlashABB is a fast antibody structure prediction library from the Oxford Protein Informatics Group (OXPIG), described in a preprint. It predicts paired VH/VL structures, scores developability (FlashTAP), and produces structure-aware sequence embeddings (FlashABB-SSS), all from a single "HEAVY|LIGHT" sequence string, with weights bundled in the repo.

The library was written for CUDA, and getting it running on Apple Silicon required fixing five separate issues. I worked through them with Claude Code and ended up with the following fixes:

  ---

Fix 1 — pip ≥ 26 editable install

pip 26 dropped the legacy setup.py-only editable install fallback. A minimal pyproject.toml was all that was needed:

  [build-system]
  requires = ["setuptools>=64", "wheel"]
  build-backend = "setuptools.build_meta"

  ---

  Fix 2 — Device auto-detection

  Three places hardcoded device='cuda'. Replaced with a helper:

  def _default_device():
      if torch.cuda.is_available(): return 'cuda'
      if torch.backends.mps.is_available(): return 'mps'
      return 'cpu'

Applied to pretrained, pretrained_tap, and flash_abb.py (including featurize() and to_pdbs(), which previously re-ran featurization on CPU regardless of where the model lived).

  ---

  Fix 3 — Deferred torch.utils.checkpoint import

  checkpointing.py had a top-level import torch.utils.checkpoint. This triggers PyTorch's symbolic shapes machinery, which accesses sympy.core at class-definition time. In IPython/Jupyter, where sympy may be partially initialised when the first from flash_abb import ... cell runs, this raises:

AttributeError: module 'sympy' has no attribute 'core'

Fix: defer the import to the body of get_checkpoint_fn(), which is only reached during training — never during inference under torch.no_grad().
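A minimal sketch of the deferred-import pattern is shown here. Only the function name get_checkpoint_fn comes from the library; the body is a simplified assumption and the real function may do more.

```python
def get_checkpoint_fn():
    # Importing here, rather than at module top level, means that merely
    # importing this module never touches torch.utils.checkpoint (and the
    # symbolic-shapes machinery it pulls in).
    import torch.utils.checkpoint

    return torch.utils.checkpoint.checkpoint
```

The import cost and any import-time failure are paid only by callers that actually request checkpointing.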

  ---

  Fix 4 — flex_attention import on PyTorch < 2.5

Both flashpoint_attention.py files imported torch.nn.attention.flex_attention, which landed in PyTorch 2.5. The symbols were never actually called (only a commented-out line referenced them). Wrapping the import in try/except lets the library load on any PyTorch 2.x release:

  try:
      from torch.nn.attention.flex_attention import BlockMask, flex_attention, _score_mod_signature
  except ImportError:
      pass  # requires PyTorch >= 2.5; unused in the inference path
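One caveat with a bare pass is that the imported names stay undefined on older PyTorch, so any future use would fail with a NameError. A possible variant, sketched below, binds placeholders and records a flag. HAS_FLEX_ATTENTION and require_flex_attention are hypothetical names introduced for illustration, not part of FlashABB or the PR.

```python
# Optional import: flex_attention exists only in PyTorch >= 2.5.
try:
    from torch.nn.attention.flex_attention import BlockMask, flex_attention
    HAS_FLEX_ATTENTION = True  # hypothetical flag, not part of FlashABB
except ImportError:
    # Bind placeholders so the names exist even when the import fails.
    BlockMask = None
    flex_attention = None
    HAS_FLEX_ATTENTION = False


def require_flex_attention():
    """Hypothetical guard: raise a clear error instead of a NameError."""
    if not HAS_FLEX_ATTENTION:
        raise RuntimeError("flex_attention requires PyTorch >= 2.5")
    return flex_attention
```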


  ---

  Fix 5 — The critical one: SDPA dimension mismatch

  FlashpointAttention is an IPA-style (Invariant Point Attention) geometry-aware attention module — not the Dao et al. Flash Attention algorithm. It combines scalar Q/K/V with 3D geometric point clouds per residue, which results in unequal query and value dimensions: d_q = 52, d_v = 168.

F.scaled_dot_product_attention on MPS silently returns an output tensor with dimension d_q (52) instead of d_v (168) when d_q ≠ d_v. No error is raised. The wrong-shaped tensor then hits a downstream split:

  RuntimeError: split_with_sizes expects split_sizes to sum exactly to 52
    (input tensor's size at dimension -1), but got split_sizes=[16, 24, 128]

CUDA handles unequal dims correctly. The fix falls back to a manual matmul + softmax, gated on MPS only so CUDA retains the efficient SDPA kernel:

  if q.device.type == 'mps' and q.shape[-1] != v.shape[-1]:
      attn_weights = torch.matmul(q, k.transpose(-1, -2)) + attn_mask
      o = torch.matmul(attn_weights.softmax(dim=-1), v)
  else:
      o = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=1.0)

  scale=1.0 is intentional: the original code pre-scales q by 1/√d_q before the call, so SDPA's default rescaling would double-count it. The manual path applies no implicit scaling — both branches are mathematically identical.

  ---

  Numerical verification

  To confirm the fallback is correct, I ran trastuzumab VH/VL through all three FlashABB capabilities on MPS (patched, this machine) and CUDA (original code, remote workstation) and compared outputs:


  coords               max_diff=5.22e-05  mean=1.63e-06  [PASS]
  bb_coords            max_diff=1.14e-05  mean=2.90e-06  [PASS]
  tap_tensor           max_diff=7.63e-05  mean=2.17e-05  [PASS]
  sss_embeddings       max_diff=1.58e-04  mean=9.48e-06  [PASS]
  sss_mask             max_diff=0.00e+00  mean=0.00e+00  [PASS]


  Differences are float32 rounding noise from hardware-level arithmetic differences between the two backends — about five orders of magnitude below the values themselves. The PR with these fixes is open at oxpig/FlashABB.
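A comparison of this kind could be produced with a small helper like the sketch below, which reports the maximum and mean absolute difference between two arrays saved from each backend. The compare function and its 1e-3 tolerance are assumptions made for illustration; the pass threshold actually used for the numbers above is not stated here.

```python
import numpy as np


def compare(name, a, b, tol=1e-3):
    """Report max and mean absolute difference between two same-shaped arrays.

    tol is an assumed pass threshold, not the one used in the post.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"{name}: shape mismatch {a.shape} vs {b.shape}")
    diff = np.abs(a - b)
    max_diff = float(diff.max())
    mean_diff = float(diff.mean())
    status = "PASS" if max_diff <= tol else "FAIL"
    return f"{name:<20} max_diff={max_diff:.2e}  mean={mean_diff:.2e}  [{status}]"


# Identical inputs stand in for a mask that matches exactly on both backends.
line = compare("sss_mask", np.ones((4, 8)), np.ones((4, 8)))
print(line)
```

Checking shapes before differencing matters here, because a silently wrong output dimension was the root cause of Fix 5.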
