NVIDIA ties long-context JAX training speedups to NVSHMEM inside XLA
NVIDIA described an XLA backend path that uses NVSHMEM to cut communication overhead in context-parallel ring attention for long sequences. Reported results include up to a 36% speedup over NCCL for multi-node long-context training.
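The communication being optimized is the ring exchange of key/value blocks: each device holds one query block, streams the KV blocks around the ring, and folds each incoming block into an online-softmax accumulator so no device ever materializes the full sequence. Below is a minimal single-process NumPy sketch of that compute/communication pattern, not NVIDIA's implementation; all function names are illustrative, the "ring step" stands in for the NVSHMEM/NCCL transfer, and causal masking is omitted:

```python
import numpy as np

def full_attention(q, k, v):
    # Reference: standard softmax attention over the full sequence.
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def ring_attention(q_blocks, k_blocks, v_blocks):
    """Context-parallel ring attention sketch (non-causal).

    "Device" p holds one query block and one KV block. KV blocks
    rotate around the ring for P steps; each device folds the incoming
    block into running online-softmax statistics (max m, normalizer l,
    weighted sum o), so attention over the full sequence is computed
    without any device holding all keys/values at once.
    """
    P = len(q_blocks)
    d = q_blocks[0].shape[-1]
    m = [np.full(q.shape[0], -np.inf) for q in q_blocks]  # running row max
    l = [np.zeros(q.shape[0]) for q in q_blocks]          # running normalizer
    o = [np.zeros_like(q) for q in q_blocks]              # running weighted sum
    k_cur, v_cur = list(k_blocks), list(v_blocks)
    for _ in range(P):
        for p in range(P):
            # Local attention of q_blocks[p] against the KV block
            # currently resident on device p.
            s = q_blocks[p] @ k_cur[p].T / np.sqrt(d)
            m_new = np.maximum(m[p], s.max(axis=-1))
            scale = np.exp(m[p] - m_new)      # rescale old accumulators
            e = np.exp(s - m_new[:, None])
            l[p] = l[p] * scale + e.sum(axis=-1)
            o[p] = o[p] * scale[:, None] + e @ v_cur[p]
            m[p] = m_new
        # Ring step: each device passes its KV block to its neighbor.
        # This rotation is the transfer that NVSHMEM accelerates over NCCL.
        k_cur = k_cur[-1:] + k_cur[:-1]
        v_cur = v_cur[-1:] + v_cur[:-1]
    return np.concatenate([o[p] / l[p][:, None] for p in range(P)])
```

The sketch matches `full_attention` on the concatenated sequence, which is the invariant any context-parallel backend must preserve; the backend change only affects how the ring step's block transfer overlaps with compute.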