MTraining: Distributed Dynamic Sparse Attention for Efficient Ultra-Long Context Training
Wenxuan Li ⋅ Chengruidong Zhang ⋅ Huiqiang Jiang ⋅ Yucheng Li ⋅ ⋅ Lili Qiu
Abstract
Long context windows have become a standard feature of Large Language Models (LLMs), as extended contexts significantly enhance their capacity for complex reasoning and broaden their applicability across diverse scenarios. Dynamic sparse attention is a promising approach for reducing the computational cost of long-context training. However, efficiently training LLMs with dynamic sparse attention on ultra-long contexts, especially in distributed settings, remains a significant challenge, largely due to worker- and step-level imbalance. This paper introduces MTraining, a novel distributed methodology that leverages dynamic sparse attention to enable efficient training of LLMs with ultra-long contexts. Specifically, MTraining integrates three key components: a distributed sparse index approximating algorithm, balanced sparse ring attention, and hierarchical sparse ring attention. These components are designed to work together to address the computational imbalance and communication overhead inherent in dynamic sparse attention when training LLMs with extensive context lengths. We demonstrate the efficacy of MTraining mainly by training Qwen2.5-3B and Llama-3.1-8B, successfully expanding their context windows from 32K and 128K, respectively, to 512K tokens on a cluster of 32$\times$A100 GPUs. Our evaluations on a comprehensive suite of downstream tasks, including RULER, PG-19, InfiniteBench, and NIAH, show that MTraining achieves up to 6$\times$ higher training throughput while preserving model accuracy. The core code is available at https://github.com/microsoft/MInference/tree/main/mtraining.
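To make the worker-level imbalance mentioned in the abstract concrete, the following minimal sketch counts how many (query, key) pairs each worker must compute when a sequence is split across workers. It is an illustration only and not MTraining's algorithm: a plain causal mask stands in for a dynamic sparse pattern, and the helper names (`causal_work`, `contiguous`, `striped`, `imbalance`) are hypothetical. It compares a contiguous split of query rows with a striped (round-robin) split, a common way to even out per-worker load.

```python
# Toy illustration (assumption: a causal mask stands in for a sparse pattern).
# All function names here are hypothetical and not taken from the MTraining code.

def causal_work(rows):
    """Count (query, key) pairs with key <= query for the given query rows."""
    return sum(q + 1 for q in rows)

def contiguous(n, workers):
    """Give each worker one contiguous chunk of query rows."""
    size = n // workers
    return [range(w * size, (w + 1) * size) for w in range(workers)]

def striped(n, workers):
    """Assign query rows to workers round-robin."""
    return [range(w, n, workers) for w in range(workers)]

def imbalance(loads):
    """Ratio of the heaviest worker's load to the mean load (1.0 is balanced)."""
    return max(loads) / (sum(loads) / len(loads))

n, workers = 16, 4
contig_loads = [causal_work(rows) for rows in contiguous(n, workers)]
striped_loads = [causal_work(rows) for rows in striped(n, workers)]

print("contiguous:", contig_loads, round(imbalance(contig_loads), 2))
print("striped:   ", striped_loads, round(imbalance(striped_loads), 2))
```

In this toy setting the contiguous split leaves the last worker with far more work than the first, while the striped split brings every worker close to the mean. With a dynamic sparse pattern the active blocks change per input and per step, which is why the abstract singles out both worker- and step-level imbalance.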