Shreya Shankar

@sh_reya

6 Tweets Dec 09, 2022
Dear ML community, read past the abstract! Before you jadedly scream "why train a gajillion parameters on a gajillion TPUs," realize that @lepikhin et al. did so much more. Thread:
If you were just going to scale up params & machines, you might expect computation to grow linearly as you add layers. But naively splitting each layer's weights and computation across multiple devices adds communication overhead, so computation scales *super-linearly.*
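To see why, here's a toy cost model (all the costs below are made up, not from the paper): assume each layer's compute splits evenly across devices, but every split layer also pays a per-device communication cost. Doubling layers and devices together then blows up step time much faster than 2x.

```python
# Toy model of naive model parallelism. Costs are illustrative only;
# the point is the shape of the curve, not the numbers.
def step_time(num_layers: int, num_devices: int,
              compute_per_layer: float = 1.0,
              comm_per_layer_per_device: float = 0.05) -> float:
    compute = num_layers * compute_per_layer / num_devices                # splits nicely
    communication = num_layers * num_devices * comm_per_layer_per_device  # doesn't
    return compute + communication

# Doubling layers and devices together: step time grows super-linearly.
for layers, devices in [(12, 8), (24, 16), (48, 32), (96, 64)]:
    print(f"{layers} layers on {devices} devices -> {step_time(layers, devices):.1f}")
```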
@lepikhin et al. introduce two main systems optimizations: 1) an API for the user to intelligently partition portions of the computation graph across multiple devices, 2) a technique to make code compilation time independent of the number of devices. Now computation scales sub-linearly!
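A minimal sketch of the annotation idea in 1), using JAX's present-day sharding API as a stand-in for GShard's TensorFlow annotations (the mesh axis name and shapes are illustrative, not from the paper): the user marks how each tensor should be partitioned, and the compiler emits one per-device program plus the required communication.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis ("model") over whatever devices are available.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

# Annotate placement: split the weight's second dim across devices, replicate x.
w = jax.device_put(jnp.ones((512, 2048)), NamedSharding(mesh, P(None, "model")))
x = jax.device_put(jnp.ones((8, 512)), NamedSharding(mesh, P()))

@jax.jit
def layer(x, w):
    # The partitioner inserts the cross-device communication automatically.
    return jnp.dot(x, w)

y = layer(x, w)
print(y.sharding)  # output stays sharded along the "model" axis
```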
Additionally, there's a modeling contribution: a layer whose feedforward weights are sharded across multiple devices, rather than copied to every device (seems to be new for transformers). But again, this is a systems-driven optimization.
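Rough sketch of that sharded-feedforward idea, again in JAX rather than GShard's actual TensorFlow/Lingvo code, with made-up sizes and a hypothetical expert_ffn helper: the per-expert feedforward weights are split along the expert axis, so no device holds a full copy of all of them.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical sizes; assumes num_experts is divisible by the device count.
num_experts, d_model, d_ff, tokens_per_expert = 8, 256, 1024, 4
mesh = Mesh(np.array(jax.devices()), axis_names=("expert",))

# [experts, d_model, d_ff]: one feedforward weight per expert, split along the
# expert axis so each device stores only the experts it owns (no full copies).
w_experts = jax.device_put(
    jnp.ones((num_experts, d_model, d_ff)),
    NamedSharding(mesh, P("expert", None, None)),
)

# Tokens already routed to experts: [experts, tokens_per_expert, d_model].
routed = jax.device_put(
    jnp.ones((num_experts, tokens_per_expert, d_model)),
    NamedSharding(mesh, P("expert", None, None)),
)

@jax.jit
def expert_ffn(routed, w):
    # Each device applies only its local experts to its local tokens.
    return jax.nn.relu(jnp.einsum("etd,edf->etf", routed, w))

out = expert_ffn(routed, w_experts)
print(out.shape, out.sharding)
```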
Most people just don't realize that it's very difficult to train a ~100B+ parameter model quickly on many machines! I don't like that this paper gets a bad rep b/c people's main takeaway is a sarcastic, "look, the authors trained a 600B parameter model! Wowie zowie!"
This paper isn't just throwing more compute at a bigger transformer model. It's a big systems feat!
Also, no offense intended to @mark_riedl -- I just RTed that tweet because it was the most popular and represents the sentiment of many other tweets I've seen about this paper.
