Distributed training with TorchDistributor
This article describes how to perform distributed training on PyTorch ML models using TorchDistributor.
TorchDistributor is an open-source module in PySpark that lets you launch PyTorch training jobs as Spark jobs on your Spark clusters. Under the hood, it initializes the environment and the communication channels between the workers and uses the CLI command `torch.distributed.run` to run distributed training across the worker nodes.
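To make "initializes the environment" more concrete: processes started through `torch.distributed.run` receive their position in the job through environment variables such as `RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `MASTER_ADDR`, and `MASTER_PORT`. The following is a minimal sketch that reads them with the standard library only. The helper name `read_dist_env` and the single-process fallback values are assumptions made for this illustration and are not part of TorchDistributor.

```python
import os


def read_dist_env(env=None):
    """Collect the rendezvous settings exported to each training process.

    `env` defaults to os.environ; passing a dict makes the helper easy to test.
    The fallback values describe a single local process and are illustrative.
    """
    env = os.environ if env is None else env
    return {
        "rank": int(env.get("RANK", 0)),              # global index of this process
        "local_rank": int(env.get("LOCAL_RANK", 0)),  # index of this process on its node
        "world_size": int(env.get("WORLD_SIZE", 1)),  # total number of processes
        "master_addr": env.get("MASTER_ADDR", "localhost"),
        "master_port": int(env.get("MASTER_PORT", 29500)),
    }


# Simulated environment for the fourth process of a four-process job.
example = read_dist_env({
    "RANK": "3",
    "LOCAL_RANK": "1",
    "WORLD_SIZE": "4",
    "MASTER_ADDR": "10.0.0.5",
    "MASTER_PORT": "29400",
})
print(example)
```

Calling `read_dist_env()` with no argument inside a training function reads the real values of the current process.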
The TorchDistributor API supports the methods shown in the following table.
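Method and signature | Description |
---|---|
`TorchDistributor(num_processes, local_mode, use_gpu)` | Create an instance of TorchDistributor. |
`run(main, *args)` | Runs distributed training by invoking `main(*args)` if `main` is a function, or by running the CLI command `torchrun main *args` if `main` is a file path. |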
Development workflow for notebooks
If the model creation and training process happens entirely from a notebook on your local machine or a Databricks Notebook, you only have to make minor changes to get your code ready for distributed training.
Prepare single node code: Prepare and test the single node code with PyTorch, PyTorch Lightning, or other frameworks based on PyTorch or PyTorch Lightning, such as the HuggingFace Trainer API.
Prepare code for standard distributed training: Convert your single-process training to distributed training, and encapsulate all of the distributed code in one training function that you can use with the `TorchDistributor`.
Move imports within training function: Add the necessary imports, such as `import torch`, within the training function. Doing so avoids common pickling errors. In addition, the `device_id` that models and data are tied to is determined by `device_id = int(os.environ["LOCAL_RANK"])`.
Launch distributed training: Instantiate the `TorchDistributor` with the desired parameters and call `.run(*args)` to launch training.
The following is a training code example:
from pyspark.ml.torch.distributor import TorchDistributor

def train(learning_rate, use_gpu):
    # Imports live inside the function so they are available on every worker.
    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data import DistributedSampler, DataLoader

    backend = "nccl" if use_gpu else "gloo"
    dist.init_process_group(backend)
    device = int(os.environ["LOCAL_RANK"]) if use_gpu else "cpu"

    # createModel, dataset, and train_loop are placeholders for your own code.
    model = createModel().to(device)
    model = DDP(model, device_ids=[device] if use_gpu else None)
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, sampler=sampler)

    output = train_loop(model, loader, learning_rate)
    dist.destroy_process_group()
    return output

distributor = TorchDistributor(num_processes=2, local_mode=False, use_gpu=True)
distributor.run(train, 1e-3, True)
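The example relies on `DistributedSampler` to give each process a different slice of the dataset. As a rough illustration of that idea, the sketch below approximates how indices are divided when shuffling is disabled and incomplete shards are padded rather than dropped. It uses only the standard library; the function name `shard_indices` is hypothetical, and the sketch assumes the dataset has at least as many samples as there are processes.

```python
import math


def shard_indices(dataset_len, num_replicas, rank):
    """Return the sample indices one process would see (no shuffling)."""
    # Every process gets the same number of samples, rounded up.
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas

    # Pad by repeating samples from the start so the split is even.
    indices = list(range(dataset_len))
    indices += indices[: total_size - dataset_len]

    # Each process takes every num_replicas-th index, offset by its rank.
    return indices[rank:total_size:num_replicas]


# Ten samples split across four processes.
shards = [shard_indices(10, 4, rank) for rank in range(4)]
for rank, shard in enumerate(shards):
    print(rank, shard)
```

Because the shards are interleaved and padded, every process runs the same number of steps per epoch, which keeps collective operations such as gradient all-reduce in lockstep.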
Migrate training from external repositories
If you have an existing distributed training procedure stored in an external repository, you can easily migrate to Databricks by doing the following:
Import the repository: Import the external repository as a Databricks Git folder.
Create a new notebook: Initialize a new Databricks Notebook within the repository.
Launch distributed training: In a notebook cell, call `TorchDistributor` like the following:
from pyspark.ml.torch.distributor import TorchDistributor
train_file = "/path/to/train.py"
args = ["--learning_rate=0.001", "--batch_size=16"]
distributor = TorchDistributor(num_processes=2, local_mode=False, use_gpu=True)
distributor.run(train_file, *args)
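For the arguments above to reach your code, the training script has to parse them from the command line. The following is a minimal sketch of how a script such as `train.py` might do that with `argparse`. The function names, default values, and returned dictionary are assumptions made for illustration; your repository's script will have its own structure.

```python
import argparse


def parse_args(argv=None):
    """Parse the flags passed through distributor.run(train_file, *args)."""
    parser = argparse.ArgumentParser(description="Hypothetical train.py entry point")
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--batch_size", type=int, default=32)
    return parser.parse_args(argv)


def main(argv=None):
    args = parse_args(argv)
    # The distributed training loop would be started here.
    return {"learning_rate": args.learning_rate, "batch_size": args.batch_size}


# In a real script, call main() under `if __name__ == "__main__":` so the
# flags are read from sys.argv. Here the same flags are passed explicitly.
config = main(["--learning_rate=0.001", "--batch_size=16"])
print(config)
```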
Troubleshooting
A common error for the notebook workflow is that objects cannot be found or pickled when running distributed training. This can happen when the library import statements are not distributed to other executors.
To avoid this issue, include all import statements (for example, `import torch`) both at the top of the training function that is called with `TorchDistributor(...).run(<func>)` and inside any other user-defined functions called in the training method.
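The pattern can be sketched as follows. To keep the example runnable anywhere, standard-library modules stand in for `torch` and other training libraries; the function names and the computation itself are purely illustrative.

```python
def normalize(values):
    # Helper used by the training function: it carries its own import,
    # so it does not rely on names defined at the top of the notebook.
    import math

    norm = math.sqrt(sum(v * v for v in values))
    return [v / norm for v in values]


def train(values, learning_rate):
    # In real training code, this is where `import torch` and similar go.
    import statistics

    scaled = [v * learning_rate for v in normalize(values)]
    return statistics.fmean(scaled)


result = train([3.0, 4.0], 0.5)
print(result)
```

Neither function depends on a module imported at notebook scope, so both carry everything they need when the training function is sent to the workers.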
NCCL failure: ncclInternalError: Internal check failed.
When you encounter this error during multi-node training, it typically indicates a problem with network communication among GPUs. This issue arises when NCCL (NVIDIA Collective Communications Library) cannot use certain network interfaces for GPU communication.
To resolve this error, add the following snippet in your training code to use the primary network interface.
import os
os.environ["NCCL_SOCKET_IFNAME"] = "eth0"
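The interface name can differ between machines, so you may prefer to set the variable only when the interface actually exists. The following is a hedged sketch of that check using the standard library. The helper name `set_nccl_interface` and its parameters are hypothetical, and the sketch assumes it is called inside the training function before `dist.init_process_group`, so that the setting applies to each worker process.

```python
import os
import socket


def set_nccl_interface(preferred="eth0", available=None, env=None):
    """Set NCCL_SOCKET_IFNAME if the preferred interface is present.

    `available` and `env` can be injected for testing; by default the helper
    inspects the interfaces of the current machine and updates os.environ.
    Returns True if the variable was set, False otherwise.
    """
    env = os.environ if env is None else env
    if available is None:
        # socket.if_nameindex() yields (index, name) pairs for each interface.
        available = [name for _, name in socket.if_nameindex()]

    if preferred in available:
        env["NCCL_SOCKET_IFNAME"] = preferred
        return True
    return False


# Demonstration with an injected interface list and environment.
demo_env = {}
applied = set_nccl_interface("eth0", available=["lo", "eth0"], env=demo_env)
print(applied, demo_env)
```

If the helper returns `False`, listing the interfaces on the worker is a reasonable next step before choosing a different value.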