Poster

Scaling Deep Learning Training with MPMD Pipeline Parallelism

Anxhelo Xhebraj · Sean Lee · Hanfeng Chen · Vinod Grover


Abstract: We present JaxPP, a system for efficiently scaling the training of large deep learning models with flexible pipeline parallelism. We introduce a seamless programming model that allows implementing user-defined pipeline schedules for gradient accumulation. JaxPP automatically distributes tasks, corresponding to pipeline stages, over a cluster of nodes and automatically infers the communication among them. We implement an MPMD runtime for asynchronous execution of SPMD tasks. The pipeline parallelism implementation of JaxPP improves hardware utilization by up to $1.11\times$ with respect to the best-performing SPMD configuration.
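The abstract does not show JaxPP's actual interface. As a rough illustration of the gradient-accumulation pattern that pipeline schedules reorder across stages and microbatches, here is a minimal plain-JAX sketch; all names (`loss_fn`, `accumulate_grads`, the microbatch layout) are hypothetical and not JaxPP's API:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-microbatch loss; stands in for one pipeline stage's
# forward pass. Not JaxPP's programming model.
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

grad_fn = jax.grad(loss_fn)

def accumulate_grads(params, microbatches):
    # Sum gradients over microbatches sequentially; a pipeline schedule
    # would instead interleave these steps across stages on different devices.
    def step(acc, mb):
        g = grad_fn(params, mb["x"], mb["y"])
        return jax.tree_util.tree_map(jnp.add, acc, g), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(step, zeros, microbatches)
    num_mb = microbatches["x"].shape[0]
    return jax.tree_util.tree_map(lambda g: g / num_mb, total)

# Example: 8 microbatches of 16 samples, 4 features each.
params = {"w": jnp.ones((4, 1)), "b": jnp.zeros((1,))}
mbs = {"x": jnp.ones((8, 16, 4)), "y": jnp.ones((8, 16, 1))}
grads = accumulate_grads(params, mbs)
```

In this sequential form the microbatch steps run one after another on a single device; the point of a user-defined pipeline schedule is to assign these steps to stages on different nodes and overlap their execution, with the inter-stage communication inferred automatically.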
