NVIDIA ties long-context JAX training speedups to NVSHMEM inside XLA
NVIDIA described an XLA backend path that uses NVSHMEM to reduce communication overhead in context-parallel ring attention for long sequences. Reported results include up to a 36% speedup versus NCCL on multi-node long-context training.
NVIDIA detailed how wiring NVSHMEM into XLA can speed up long-context LLM training in JAX by targeting the communication bottlenecks of ring attention.
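For orientation, here is a minimal JAX sketch of the context-parallel ring attention pattern the post targets: the sequence is sharded across devices, and KV blocks are rotated around the ring so each transfer sits on the attention loop's critical path. This illustrates the general technique only (no causal masking or numerically stable softmax), not NVIDIA's NVSHMEM-backed XLA path; the function names, shapes, and the "cp" axis name are assumptions.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

def make_ring_attention(mesh, axis_name="cp"):
    n_dev = mesh.shape[axis_name]                     # devices on the context-parallel axis
    perm = [(j, (j + 1) % n_dev) for j in range(n_dev)]

    def ring_attention_shard(q, k, v):
        # q, k, v: this device's shard of the sequence, shape (seq // n_dev, d).
        out = jnp.zeros_like(q)
        denom = jnp.zeros(q.shape[:-1] + (1,), q.dtype)
        for _ in range(n_dev):
            scores = jnp.einsum("qd,kd->qk", q, k) / jnp.sqrt(q.shape[-1])
            w = jnp.exp(scores)                       # simplified: no running-max rescaling
            out = out + jnp.einsum("qk,kd->qd", w, v)
            denom = denom + w.sum(-1, keepdims=True)
            # Rotate the KV shard to the next device in the ring. These
            # device-to-device transfers are the traffic on the critical path
            # that a lower-latency communication backend would accelerate.
            k = jax.lax.ppermute(k, axis_name, perm)
            v = jax.lax.ppermute(v, axis_name, perm)
        return out / denom

    return shard_map(ring_attention_shard, mesh=mesh,
                     in_specs=(P(axis_name, None),) * 3,
                     out_specs=P(axis_name, None),
                     check_rep=False)

# Usage: shard the sequence dimension across all local devices.
mesh = Mesh(jax.devices(), ("cp",))
attn = make_ring_attention(mesh)
q = k = v = jnp.ones((8 * len(jax.devices()), 64))
print(attn(q, k, v).shape)
```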
- The approach focuses on context parallelism (splitting the sequence dimension across devices) using ring attention, where KV blocks move device-to-device and transfers sit on the critical path.
- NVIDIA reports training a Llama 3 8B workload with sequences up to 256K tokens using this integration.
- In the published benchmarks, NVSHMEM delivered up to a 36% speedup compared with NCCL for long-context runs, with gains increasing as sequence length grows.
- The post highlights NVSHMEM capabilities used by the compiler path, including symmetric GPU-resident memory, stream-aware operations, and interoperability with CUDA Graphs.
- Reported improvements are strongest in multi-node setups and when combined with tensor parallelism (hybrid parallelism), where fine-grained, latency-sensitive transfers dominate; see the mesh sketch after this list.
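To ground the hybrid-parallelism point, the sketch below sets up a 2D JAX device mesh with a context-parallel axis for the sequence dimension and a tensor-parallel axis for weights. The axis names, sizes, and shardings are hypothetical, chosen only to illustrate the layout in which the reported gains are strongest, and say nothing about NVIDIA's actual configuration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 2D mesh: "cp" shards the sequence (context parallelism),
# "tp" shards weights (tensor parallelism). On 8 GPUs this might be
# reshape(4, 2); reshape(-1, 1) keeps the sketch runnable on any device count.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, ("cp", "tp"))

# Activations (seq, batch, hidden): sequence dimension sharded over "cp".
act_sharding = NamedSharding(mesh, P("cp", None, None))
# FFN weight (hidden, ffn): output dimension sharded over "tp".
w_sharding = NamedSharding(mesh, P(None, "tp"))

x = jax.device_put(jnp.ones((256, 2, 128)), act_sharding)
w = jax.device_put(jnp.ones((128, 512)), w_sharding)

# Under jit, XLA inserts the collectives implied by these shardings; the post's
# NVSHMEM work concerns how such communication is lowered, not the sharding API.
y = jax.jit(lambda x, w: x @ w)(x, w)
print(y.shape, y.sharding)
```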