Scaling models using data parallelism with DTensor

In Google I/O, Machine Learning, Performance, Software Development by Prabhu Missier

Datasets are getting larger, and so are models: where models once had a few million parameters, today's models have billions.

Traditionally, parallelism was achieved by splitting the data across multiple copies (replicas) of the model. The limitation was that each replica had to fit entirely on a single device. As models get bigger, the model itself has to be spread across devices.
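To make the traditional approach concrete, here is a minimal pure-Python sketch of synchronous data parallelism, assuming a toy one-weight linear model trained with mean squared error. Each simulated replica holds a full copy of the weight `w`, computes the gradient on its own equal-sized slice of the batch, and the per-replica gradients are then averaged, which is the role an all-reduce plays across real devices. The helper names (`grad_mse`, `data_parallel_grad`) are illustrative and are not part of DTensor or TensorFlow.

```python
from typing import List, Tuple

def grad_mse(w: float, batch: List[Tuple[float, float]]) -> float:
    """Gradient of mean((w*x - y)^2) with respect to w over one batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

def data_parallel_grad(w: float, batch: List[Tuple[float, float]], n_replicas: int) -> float:
    """Split the batch evenly across replicas, then average their gradients."""
    assert len(batch) % n_replicas == 0, "sketch assumes equal-sized shards"
    shard = len(batch) // n_replicas
    # Every replica uses the same weight w (a full model copy) on a different data slice.
    local_grads = [grad_mse(w, batch[i * shard:(i + 1) * shard]) for i in range(n_replicas)]
    # Averaging the local gradients stands in for the cross-device all-reduce.
    return sum(local_grads) / n_replicas

batch = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2)]
full = grad_mse(0.5, batch)
parallel = data_parallel_grad(0.5, batch, n_replicas=2)
print(full, parallel)
```

The constraint described above is visible in the code: `w` is copied whole to every replica, so this scheme only works while the full model fits on one device.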

At the recently concluded Google I/O 2023, DTensor was showcased as a solution to this problem. With DTensor you can shard the data across model replicas running on multiple devices. In a similar vein, you can achieve model parallelism by sharding the model itself across devices and feeding every model shard the same batch of data.
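A similarly minimal sketch of model parallelism, again in plain Python with illustrative names: the weight matrix of a single dense layer is split column-wise across simulated devices, every device receives the same input batch, and concatenating the partial outputs reproduces the unsharded result. DTensor expresses this kind of split declaratively, through a layout on the weight, rather than through manual slicing; the code below is only a conceptual illustration.

```python
from typing import List

Matrix = List[List[float]]

def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def split_columns(w: Matrix, n_shards: int) -> List[Matrix]:
    """Give each device an equal, contiguous block of the weight's columns."""
    cols = len(w[0])
    assert cols % n_shards == 0, "sketch assumes columns divide evenly"
    step = cols // n_shards
    return [[row[s * step:(s + 1) * step] for row in w] for s in range(n_shards)]

def model_parallel_forward(x: Matrix, w: Matrix, n_shards: int) -> Matrix:
    # Each device holds only its weight shard but sees the full input batch x.
    partial_outputs = [matmul(x, w_shard) for w_shard in split_columns(w, n_shards)]
    # Concatenating along the output-feature axis rebuilds the full activations.
    return [[v for out in partial_outputs for v in out[i]] for i in range(len(x))]

x = [[1.0, 2.0], [3.0, 4.0]]   # batch of 2 examples, 2 input features
w = [[1.0, 0.0, 2.0, 1.0],
     [0.0, 1.0, 1.0, 3.0]]     # 2 input features -> 4 output features
print(model_parallel_forward(x, w, n_shards=2))
```

No single device ever stores all of `w`, which is what lets the model grow beyond one device's memory; the cost is that the outputs must be gathered back together.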

What's interesting is that DTensor lets you combine data parallelism and model parallelism in one place.
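Combining the two usually means arranging the devices in a two-dimensional mesh, with one axis for data and one for the model. The plain-Python sketch below assumes a hypothetical 8-device mesh arranged 4-by-2 with axes named `batch` and `model`, and works out which data shard and which weight shard each device would hold under row-major device numbering. In DTensor a mesh is described by a list of `(name, size)` pairs such as `[("batch", 4), ("model", 2)]`; the placement logic here is an illustration, not DTensor's internal algorithm.

```python
from itertools import product
from typing import Dict, List, Tuple

def mesh_assignments(mesh_dims: List[Tuple[str, int]]) -> Dict[int, Dict[str, int]]:
    """Map each flat device id to its coordinate along every named mesh axis.

    Devices are numbered in row-major order, so the last axis varies fastest.
    """
    names = [name for name, _ in mesh_dims]
    sizes = [size for _, size in mesh_dims]
    return {device_id: dict(zip(names, coords))
            for device_id, coords in enumerate(product(*(range(s) for s in sizes)))}

mesh_dims = [("batch", 4), ("model", 2)]  # hypothetical 8-device mesh
for device_id, coords in mesh_assignments(mesh_dims).items():
    # The "batch" coordinate picks the data slice; "model" picks the weight shard.
    print(f"device {device_id}: data shard {coords['batch']}, model shard {coords['model']}")
```

Devices sharing a `batch` coordinate see the same data and together hold one full copy of the model; devices sharing a `model` coordinate hold the same weight shard and average their gradients, exactly as in pure data parallelism.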

Google says this takes only a few lines of code, chiefly creating a mesh that describes how the devices are arranged:

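from tensorflow.experimental import dtensor
mesh = dtensor.create_distributed_mesh(mesh_dims=[("batch", 8)], device_type="GPU")  # hypothetical: 8 GPUs on one "batch" axis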
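Beyond the mesh, DTensor describes how each tensor is split with a layout (`dtensor.Layout`): for every tensor dimension, either the name of a mesh axis to shard over or `dtensor.UNSHARDED` to keep it whole on every device. As a rough illustration of what a layout implies, the plain-Python helper below (hypothetical, not part of the DTensor API) computes the shape of the piece each device would hold, assuming every sharded dimension divides evenly by its mesh axis size.

```python
from typing import Dict, List, Tuple

UNSHARDED = "unsharded"  # stand-in for DTensor's constant of the same role

def local_shard_shape(global_shape: Tuple[int, ...],
                      sharding_specs: List[str],
                      mesh_dims: Dict[str, int]) -> Tuple[int, ...]:
    """Shape of the per-device piece of a tensor under a layout-like spec."""
    if len(global_shape) != len(sharding_specs):
        raise ValueError("need exactly one sharding spec per tensor dimension")
    local = []
    for size, spec in zip(global_shape, sharding_specs):
        if spec == UNSHARDED:
            local.append(size)  # replicated: every device keeps the whole dimension
        else:
            parts = mesh_dims[spec]
            if size % parts:
                raise ValueError(f"dimension {size} not divisible by mesh axis {spec}={parts}")
            local.append(size // parts)  # split evenly across that mesh axis
    return tuple(local)

mesh_dims = {"batch": 4, "model": 2}
# Activations: shard the batch dimension, replicate features (data parallelism).
print(local_shard_shape((64, 512), ["batch", UNSHARDED], mesh_dims))
# Weights: replicate inputs, shard output features (model parallelism).
print(local_shard_shape((512, 1024), [UNSHARDED, "model"], mesh_dims))
```

Switching between data and model parallelism in this picture only changes the sharding specs, not the shapes of the global tensors the model works with.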

You don't have to rewrite your model for different parallelism strategies; switching strategies means changing the mesh and the layouts, not the model code. DTensor works whether you use one device or a hundred.

According to Google, performance for transformer training is on par with established systems such as NVIDIA's Megatron and Mesh TensorFlow, and Google aims to improve on those results.

DTensor is integrated with Keras and tf.distribute, and it offers a single entry point regardless of device type, whether CPU, GPU or TPU.