PyTorch: Multi-GPU and multi-node data parallelism
This page explains how to distribute the training of a neural network implemented in PyTorch using data parallelism.
We document here the built-in DistributedDataParallel solution, which is the most efficient one according to the PyTorch documentation. It relies on multi-process parallelism and works equally well on a single node and on multiple nodes.
Multi-process configuration with SLURM
For multi-node jobs, the processes must be managed by SLURM (execution via the SLURM command srun). For single-node jobs, torch.multiprocessing.spawn can be used instead, as indicated in the PyTorch documentation. However, SLURM multi-processing works in both cases and is more practical, so it is the approach documented on this page.
When you launch a script with the SLURM srun command, the script is automatically started once per reserved task. For example, if we reserve four 8-GPU nodes and request 3 GPUs per node, we obtain:
- 4 nodes, indexed from 0 to 3.
- 3 GPUs/node, indexed from 0 to 2 on each node.
- 4 x 3 = 12 processes in total, allowing the execution of 12 tasks, with ranks from 0 to 11.
The collective inter-node communications are managed by the NCCL library.
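The numbering described above can be sketched in a few lines of plain Python. This is only an illustration of the arithmetic: the names NUM_NODES, TASKS_PER_NODE and global_rank are hypothetical, and the mapping assumes SLURM's default block distribution, in which consecutive ranks fill one node before moving on to the next.

```python
# Sketch of the task layout for the example above: 4 nodes, 3 tasks (GPUs) per node.
# Assumption: block distribution, i.e. consecutive ranks are placed on the same node.
NUM_NODES = 4
TASKS_PER_NODE = 3


def global_rank(node_index: int, local_rank: int) -> int:
    """Global rank of the task running on GPU `local_rank` of node `node_index`."""
    return node_index * TASKS_PER_NODE + local_rank


# Map every (node, local rank) pair to its global rank.
layout = {
    (node, local): global_rank(node, local)
    for node in range(NUM_NODES)
    for local in range(TASKS_PER_NODE)
}

if __name__ == "__main__":
    for (node, local), rank in layout.items():
        print(f"node {node}, local rank (GPU) {local} -> global rank {rank}")
```

Under this assumption, the local rank identifies the GPU to use on a node, while the global rank identifies the process among all the tasks of the job.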
The following are two examples of SLURM scripts for Jean Zay:
- For a reservation of N four-GPU V100 nodes via the default GPU partition:
    #!/bin/bash
    #SBATCH --job-name=torch-multi-gpu
    #SBATCH --nodes=N            # total number of nodes (N to be defined)
    #SBATCH --ntasks-per-node=4  # number of tasks per node (here 4 tasks, or 1 task per GPU)
    #SBATCH --gres=gpu:4         # number of GPUs reserved per node (here 4, or all the GPUs)
    #SBATCH --cpus-per-task=10   # number of cores per task (4x10 = 40 cores, or all the cores)
    #SBATCH --hint=nomultithread
    #SBATCH --time=40:00:00
    #SBATCH --output=torch-multi-gpu%j.out
    ##SBATCH --account=abc@v100

    module load pytorch-gpu/py3/1.11.0

    srun python myscript.py
Note: Here, the nodes are reserved exclusively. In particular, this gives access to the entire memory of each node.
- For a reservation of N eight-GPU A100 nodes:
    #!/bin/bash
    #SBATCH --job-name=torch-multi-gpu
    #SBATCH --nodes=N            # total number of nodes (N to be defined)
    #SBATCH --ntasks-per-node=8  # number of tasks per node (here 8 tasks, or 1 task per GPU)
    #SBATCH --gres=gpu:8         # number of GPUs reserved per node (here 8, or all the GPUs)
    #SBATCH --cpus-per-task=8    # number of cores per task (8x8 = 64 cores, or all the cores)
    #SBATCH --hint=nomultithread
    #SBATCH --time=40:00:00
    #SBATCH --output=torch-multi-gpu%j.out
    #SBATCH -C a100
    ##SBATCH --account=abc@a100

    module load cpuarch/amd
    module load pytorch-gpu/py3/1.11.0

    srun python myscript.py
Note: Here, the nodes are reserved exclusively. In particular, this gives access to the entire memory of each node.
Implementation of the DistributedDataParallel solution
To implement the DistributedDataParallel solution in PyTorch, it is necessary to:
- Define the environment variables linked to the master node.
  - MASTER_ADDR: The IP address or the hostname of the node corresponding to task 0 (the first node on the node list). If you are running on a single node, the value localhost is sufficient.
  - MASTER_PORT: The number of a random port. To avoid conflicts, and by convention, we will use a port number between 10001 and 20000 (for example, 12345).
  - On Jean Zay, a library developed by IDRIS has been included in the PyTorch modules to automatically define the MASTER_ADDR and MASTER_PORT variables. You simply need to import it in your script:
        import idr_torch
    This command alone will create the variables. For your information, the library is available here.
Note: The idr_torch module retrieves these values from the environment. You can reuse them in your script through idr_torch.rank, idr_torch.local_rank, idr_torch.size and/or idr_torch.cpus_per_task.
- Initialise the process group (i.e. the number of processes, the collective communication protocol or backend, …). The possible backends are NCCL, GLOO and MPI. NCCL is recommended both for performance and because it is guaranteed to work correctly on Jean Zay.
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=idr_torch.size,
                            rank=idr_torch.rank)
- Send the model to the GPU. Note that local_rank (numbered 0, 1, 2, … on each node) serves as the GPU identifier.
    torch.cuda.set_device(idr_torch.local_rank)
    gpu = torch.device("cuda")
    model = model.to(gpu)
- Transform the model into a distributed model associated with a GPU.
    ddp_model = DDP(model, device_ids=[idr_torch.local_rank])
- Send the micro-batches and labels to the dedicated GPU during the training.
    for (images, labels) in train_loader:
        images = images.to(gpu, non_blocking=True)
        labels = labels.to(gpu, non_blocking=True)
Note: Here, the non_blocking=True option is needed for the DataLoader's pin_memory functionality to accelerate the loading of inputs: host-to-GPU copies can only be asynchronous when the source tensors are in pinned memory.
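For the first step of the list above, the kind of values that a helper module exposes can be derived from standard SLURM environment variables. The sketch below is a hypothetical illustration and not the actual idr_torch implementation: the functions first_hostname and distributed_settings, the port formula and the example node names are assumptions. Only the SLURM variable names (SLURM_PROCID, SLURM_LOCALID, SLURM_NTASKS, SLURM_CPUS_PER_TASK, SLURM_JOB_NODELIST, SLURM_JOB_ID) and the MASTER_ADDR/MASTER_PORT names read by the env:// initialisation are real.

```python
import re


def first_hostname(nodelist: str) -> str:
    """First host of a SLURM node list such as 'r6i1n[0-3]' or 'nodeA,nodeB'.

    Simplified parser (assumption): handles at most one bracketed group per name.
    """
    match = re.match(r"([^\[,]+)(?:\[([^\]]+)\])?", nodelist)
    prefix, ranges = match.group(1), match.group(2)
    if ranges is None:
        return prefix
    # Keep the first value of the first range, e.g. '0' in '0-3,7'.
    return prefix + re.split(r"[,-]", ranges)[0]


def distributed_settings(env: dict) -> dict:
    """Derive the per-task settings from SLURM environment variables."""
    job_id = int(env["SLURM_JOB_ID"])
    return {
        "rank": int(env["SLURM_PROCID"]),           # global rank of the task
        "local_rank": int(env["SLURM_LOCALID"]),    # rank of the task on its node
        "size": int(env["SLURM_NTASKS"]),           # total number of tasks
        "cpus_per_task": int(env["SLURM_CPUS_PER_TASK"]),
        "MASTER_ADDR": first_hostname(env["SLURM_JOB_NODELIST"]),
        # Hypothetical convention: one port per job in the range 10001..20000.
        "MASTER_PORT": str(10001 + job_id % 10000),
    }


# Made-up environment of task 5 in a job with 2 nodes and 4 tasks per node.
example_env = {
    "SLURM_JOB_ID": "123456",
    "SLURM_PROCID": "5",
    "SLURM_LOCALID": "1",
    "SLURM_NTASKS": "8",
    "SLURM_CPUS_PER_TASK": "10",
    "SLURM_JOB_NODELIST": "r6i1n[0-1]",
}
settings = distributed_settings(example_env)
```

In a real job, one would pass os.environ instead of example_env and copy the MASTER_ADDR and MASTER_PORT entries into os.environ before calling dist.init_process_group.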
The code shown below illustrates the usage of the DataLoader with a sampler adapted to data parallelism.
    batch_size = args.batch_size
    batch_size_per_gpu = batch_size // idr_torch.size

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(ddp_model.parameters(), 1e-4)

    # Data loading code
    train_dataset = torchvision.datasets.MNIST(root=os.environ['DSDIR'],
                                               train=True,
                                               transform=transforms.ToTensor(),
                                               download=False)

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,
                                                                    num_replicas=idr_torch.size,
                                                                    rank=idr_torch.rank,
                                                                    shuffle=True)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size_per_gpu,
                                               shuffle=False,
                                               num_workers=0,
                                               pin_memory=True,
                                               sampler=train_sampler)
Be careful: shuffling is delegated to the DistributedSampler, which is why the DataLoader is created with shuffle=False. Furthermore, for the shuffling seed to differ at each epoch, you need to call train_sampler.set_epoch(epoch) at the beginning of each epoch.
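To see why the epoch matters, here is a pure-Python imitation of the way a distributed sampler splits the dataset between ranks. It is a simplified sketch rather than PyTorch's code: the function sampler_indices is hypothetical and uses Python's random module instead of torch.randperm, so the permutations differ from the real ones. The logic it mimics is: permutation seeded with seed + epoch, padding so that every rank gets the same number of samples, then a strided slice per rank.

```python
import math
import random


def sampler_indices(dataset_len, num_replicas, rank, epoch, seed=0, shuffle=True):
    """Indices seen by one rank during one epoch (pure-Python imitation)."""
    indices = list(range(dataset_len))
    if shuffle:
        # Same seed on every rank, changed at every epoch:
        # all ranks agree on the permutation, which differs between epochs.
        random.Random(seed + epoch).shuffle(indices)
    # Pad with the first indices so that every rank gets the same number of samples.
    per_rank = math.ceil(dataset_len / num_replicas)
    total = per_rank * num_replicas
    indices += indices[: total - len(indices)]
    # Rank r takes elements r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total:num_replicas]


# 10 samples shared between 4 ranks, for two consecutive epochs.
epoch0 = [sampler_indices(10, 4, rank, epoch=0) for rank in range(4)]
epoch1 = [sampler_indices(10, 4, rank, epoch=1) for rank in range(4)]
```

If the epoch passed to the sampler never changes, which is what happens when set_epoch is not called, every epoch replays the same permutation and each GPU sees the same samples in the same order.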
Saving and loading checkpoints
It is possible to put checkpoints in place during a distributed training on GPUs.
Saving
Since the model is replicated on each GPU, checkpoints can be saved from just one GPU to limit the writing operations. By convention, we use the GPU of rank 0:
    if idr_torch.rank == 0:
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)
Consequently, the checkpoint contains the information from the GPU of rank 0, saved in a format specific to distributed models (the parameter names carry the module. prefix added by the DDP wrapper).
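The "format specific to distributed models" is visible in the parameter names: because DDP keeps the original model in its module attribute, the keys of ddp_model.state_dict() start with module. The sketch below shows one possible way of removing that prefix so the checkpoint can be loaded into a model that is not wrapped in DDP. The helper strip_ddp_prefix and the layer names are hypothetical, and plain lists stand in for tensors so that the example runs without PyTorch.

```python
from collections import OrderedDict

PREFIX = "module."


def strip_ddp_prefix(state_dict):
    """Return a copy of `state_dict` whose keys no longer start with 'module.'."""
    return OrderedDict(
        (key[len(PREFIX):] if key.startswith(PREFIX) else key, value)
        for key, value in state_dict.items()
    )


# Stand-in for a checkpoint saved from a DDP model (lists replace the tensors).
ddp_checkpoint = OrderedDict([
    ("module.fc1.weight", [0.1, 0.2]),
    ("module.fc1.bias", [0.0]),
])
plain_checkpoint = strip_ddp_prefix(ddp_checkpoint)
```

An alternative is to save ddp_model.module.state_dict() in the first place, which stores the parameters without the prefix.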
Loading
At the beginning of the training, a checkpoint is first loaded by the CPU. Then, the information is sent to the GPU.
By default and by convention, it is sent to the memory location which was used during the saving step. In our example, only GPU 0 would load the model in memory.
For the information to be communicated to all the GPUs, it is necessary to use the map_location argument of the torch.load function to redirect the memory storage.
In the example below, the map_location argument redirects the memory storage to the GPU of local rank idr_torch.local_rank. Since this function is called by all the processes, each GPU loads the checkpoint into its own memory:
    map_location = {'cuda:%d' % 0: 'cuda:%d' % idr_torch.local_rank}  # remap storage from GPU 0 to local GPU
    ddp_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=map_location))  # load checkpoint
Note: If a checkpoint is loaded just after a save, as in the PyTorch tutorial, it is necessary to call the dist.barrier() method before the loading. This call synchronises the processes, guaranteeing that the saving of the checkpoint by the GPU of rank 0 has completely finished before the other GPUs attempt to load it.
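The role of the barrier can be illustrated without any GPU. In the following sketch, threads stand in for the processes, threading.Barrier stands in for dist.barrier(), and a JSON file stands in for the checkpoint. All the names (run_save_then_load, WORLD_SIZE, the file content) are hypothetical; this is an analogy for the ordering constraint, not PyTorch code.

```python
import json
import os
import tempfile
import threading

WORLD_SIZE = 4


def run_save_then_load():
    """Simulate WORLD_SIZE ranks with threads; return what each rank loaded."""
    barrier = threading.Barrier(WORLD_SIZE)   # stand-in for dist.barrier()
    loaded = [None] * WORLD_SIZE

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "checkpoint.json")

        def worker(rank):
            if rank == 0:
                # Only rank 0 writes the checkpoint.
                with open(path, "w") as handle:
                    json.dump({"epoch": 3, "weights": [0.1, 0.2]}, handle)
            # No rank goes past this point until all ranks, including rank 0
            # after its write, have reached it.
            barrier.wait()
            with open(path) as handle:
                loaded[rank] = json.load(handle)

        threads = [threading.Thread(target=worker, args=(rank,))
                   for rank in range(WORLD_SIZE)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    return loaded


results = run_save_then_load()
```

Without the barrier.wait() call, a rank other than 0 could try to open the file before rank 0 has finished writing it, which is the situation that dist.barrier() prevents in the distributed case.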
Distributed validation
The validation step performed after each epoch or after a set of training iterations can be distributed over all the GPUs engaged in model training. When data parallelism is used and the validation dataset is large, distributing the validation over the GPUs appears to be the most efficient and fastest solution.
Here, the challenge is to calculate the metrics (loss, accuracy, etc.) per batch and per GPU, then to compute their weighted average over the whole validation dataset.
For this, it is necessary to:
- Load the validation dataset in the same way as the training dataset but without randomized transformations such as data augmentation or shuffling (see documentation on loading PyTorch databases):
    # validation dataset loading (imagenet for example)
    val_dataset = torchvision.datasets.ImageNet(root=root, split='val', transform=val_transform)

    # define distributed sampler for validation
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset,
                                                                  num_replicas=idr_torch.size,
                                                                  rank=idr_torch.rank,
                                                                  shuffle=False)

    # define dataloader for validation
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=batch_size_per_gpu,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             prefetch_factor=2)
- Switch from “training” mode to “validation” mode to disable some training-specific features that are costly and unnecessary here:
  - model.eval() to switch the model to “validation” mode and disable the training behaviour of dropout, batchnorm, etc.
  - with torch.no_grad() to ignore gradient calculation
  - optionally, with autocast() to use AMP (mixed precision)
- Evaluate the model and calculate the metric by batch in the usual way (here we take the example of calculating the loss; it will be the same for other metrics):
  - outputs = model(val_images) followed by loss = criterion(outputs, val_labels)
- Weight and accumulate the metric per GPU:
  - val_loss += loss * val_images.size(0) / N with val_images.size(0) as the size of the batch and N the global size of the validation dataset. Since the batches do not necessarily have the same size (the last batch is sometimes smaller), it is better here to use the value val_images.size(0).
- Sum the metric weighted averages over all GPUs:
  - dist.all_reduce(val_loss, op=dist.ReduceOp.SUM) to sum the metric values calculated per GPU and communicate the result to all GPUs. This operation results in inter-GPU communications.
Example after loading validation data:
    from torch.cuda.amp import autocast                 # AMP context manager

    model.eval()                                        # switch into validation mode
    val_loss = torch.Tensor([0.]).to(gpu)               # initialize val_loss value
    N = len(val_dataset)                                # get validation dataset length

    for val_images, val_labels in val_loader:           # loop over validation batches

        val_images = val_images.to(gpu, non_blocking=True)   # transfer images and labels to GPUs
        val_labels = val_labels.to(gpu, non_blocking=True)

        with torch.no_grad():                           # deactivate gradient computation
            with autocast():                            # activate AMP
                outputs = model(val_images)             # evaluate model
                loss = criterion(outputs, val_labels)   # compute loss

        val_loss += loss * val_images.size(0) / N       # cumulate weighted mean per GPU

    dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)     # sum weighted means and broadcast value to each GPU

    model.train()                                       # switch again into training mode
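The weighting by val_images.size(0) / N can be checked with plain numbers. The sketch below simulates two GPUs with made-up per-batch losses and batch sizes (all values and names are hypothetical). It accumulates the weighted loss on each simulated GPU, sums the partial results as an all_reduce with the SUM operation would, and compares the outcome with the mean loss over all samples and with a naive average of the batch losses.

```python
# (mean loss of the batch, batch size) for each batch seen by each simulated GPU.
# All values are made up for the illustration.
batches_per_gpu = [
    [(0.50, 4), (0.30, 4), (0.90, 2)],   # GPU 0: the last batch is smaller
    [(0.40, 4), (0.60, 4), (0.20, 1)],   # GPU 1: the last batch is smaller
]
all_batches = [batch for gpu_batches in batches_per_gpu for batch in gpu_batches]
N = sum(size for _, size in all_batches)          # global number of samples


def local_weighted_loss(batches, n_total):
    """Accumulate loss * batch_size / N over the batches of one GPU."""
    val_loss = 0.0
    for loss, batch_size in batches:
        val_loss += loss * batch_size / n_total
    return val_loss


# Partial result held by each GPU, then their sum (the role of all_reduce with SUM).
per_gpu = [local_weighted_loss(batches, N) for batches in batches_per_gpu]
global_loss = sum(per_gpu)

# Reference: mean loss over all samples, computed directly.
reference = sum(loss * size for loss, size in all_batches) / N

# Naive alternative: unweighted average of the batch losses.
naive = sum(loss for loss, _ in all_batches) / len(all_batches)
```

The summed value matches the per-sample mean up to floating-point rounding, whereas the naive average gives too much weight to the smaller last batches.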
Application example
Multi-GPU and multi-node execution with "DistributedDataParallel"
An example is found on Jean Zay in $DSDIR/examples_IA/Torch_parallel/Example_DataParallelism_Pytorch-eng.ipynb; it uses the MNIST database and a simple dense network. The example is a Notebook which allows creating an execution script.
You can also download the notebook by clicking on this link.
This should be copied in your personal space (ideally in your $WORK).
    $ cp $DSDIR/examples_IA/Torch_parallel/Example_DataParallelism_PyTorch-eng.ipynb $WORK
You should then execute the Notebook from a Jean Zay front end after loading a PyTorch module (see our JupyterHub documentation for more information on how to run Jupyter Notebook).
Documentation and sources
Appendices
On Jean Zay, for a ResNet-101 model with a fixed mini-batch size per GPU (so the global batch size increases with the number of GPUs involved), we obtain the following throughputs, which grow with the number of GPUs involved. The NCCL communication protocol is always more efficient than GLOO. Communication between eight-GPU nodes appears slower than between four-GPU nodes.
For NCCL, here are the average times of a training iteration for different numbers of GPUs involved in the distribution. The time gaps correspond to the synchronization time between GPUs.