Training MoEs at Scale with PyTorch and Databricks


Mixture-of-Experts (MoE) has emerged as a promising LLM architecture for efficient training and inference. MoE models like DBRX, which use multiple expert networks to make predictions, offer a significant reduction in inference cost compared to dense models of equal quality.

In this blog post, researchers at Databricks and Meta discuss the libraries and tools both teams have built to support MoE development within the PyTorch deep learning framework. MegaBlocks, a lightweight open source library for MoE training maintained by Databricks, is integrated into the LLM Foundry library to enable distributed model training workloads to scale to thousands of GPUs. PyTorch's low-level DTensor abstraction is used to represent parallelism strategies across GPUs, and Fully Sharded Data Parallel (FSDP), PyTorch's implementation of ZeRO-3, provides an API for sharding model parameters with data parallelism.

Communicating model parameters, gradients, and optimizer states across thousands of GPUs presents performance challenges, which are mitigated by PyTorch Hybrid Sharded Data Parallel (HSDP) to balance memory efficiency and communication cost. PyTorch also supports elastic sharded checkpointing for fault tolerance during long distributed training runs. To dive deeper into how PyTorch and Databricks are enabling training of state-of-the-art LLMs, read the complete blog post.
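
To make the HSDP and checkpointing ideas above concrete, here is a minimal sketch of a hybrid-sharded setup in recent PyTorch (assuming roughly 2.3 or later). It is not the configuration LLM Foundry or MegaBlocks actually uses; `NUM_NODES`, `GPUS_PER_NODE`, `build_moe_model()`, and the checkpoint path are placeholders for illustration.

```python
# Hypothetical HSDP sketch: shard parameters ZeRO-3 style within a node,
# replicate across nodes, then write a sharded checkpoint.
import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_model_state_dict
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

NUM_NODES, GPUS_PER_NODE = 4, 8  # placeholder cluster shape


def build_moe_model() -> torch.nn.Module:
    # Stand-in for the real MoE model definition.
    return torch.nn.Transformer(d_model=512, num_encoder_layers=2, num_decoder_layers=2)


def main() -> None:
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # 2D device mesh: replicate across nodes, shard within each node.
    mesh = init_device_mesh(
        "cuda", (NUM_NODES, GPUS_PER_NODE), mesh_dim_names=("replicate", "shard")
    )

    # HYBRID_SHARD shards parameters, gradients, and optimizer state inside a
    # node while keeping data-parallel replicas across nodes, trading memory
    # savings against cross-node communication cost.
    model = FSDP(
        build_moe_model().cuda(),
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
        device_mesh=mesh,
    )

    # ... training loop elided ...

    # Sharded checkpoint: each rank saves only its own shards, and the
    # checkpoint can later be loaded on a different world size.
    state = {"model": get_model_state_dict(model)}
    dcp.save(state, checkpoint_id="/tmp/checkpoints/step_0")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Launched with `torchrun` across the cluster, this pattern is one way to balance memory efficiency and communication cost as described above; the full blog post covers the production-grade details.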