The PyTorch project is a Python package that provides GPU-accelerated tensor computation and high-level functionality for building deep learning networks. For licensing details, see the PyTorch license doc on GitHub.
To monitor and debug your PyTorch models, consider using TensorBoard.
PyTorch is included in Databricks Runtime for Machine Learning. If you are using the standard Databricks Runtime, see Install PyTorch for installation instructions.
This is not a comprehensive guide to PyTorch. For more information, see the PyTorch website.
To test and migrate single-machine workflows, use a Single Node cluster.
For distributed training options for deep learning, see Distributed training.
Databricks Runtime for Machine Learning includes PyTorch, so you can create a cluster and start using PyTorch right away. For the version of PyTorch installed in the Databricks Runtime ML version you are using, see the release notes.
Databricks recommends that you use the PyTorch included in Databricks Runtime for Machine Learning. However, if you must use the standard Databricks Runtime, PyTorch can be installed as a Databricks PyPI library. The following example shows how to install PyTorch 1.5.0:
On GPU clusters, install PyTorch and torchvision by specifying the following:
On CPU clusters, install PyTorch and torchvision by using the following Python wheel files:
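Whichever way PyTorch reaches the cluster, it can help to confirm which versions a notebook actually sees, for example to compare them against the release notes. The following is a minimal sketch that uses only the Python standard library (importlib.metadata), so it also works when a package is missing; the installed_version helper is a hypothetical name, not part of any Databricks or PyTorch API.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent.

    Hypothetical helper: it reads package metadata only and never imports the package.
    """
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Report the versions visible to this Python environment.
for pkg in ("torch", "torchvision"):
    print(pkg, installed_version(pkg) or "not installed")
```

Because the check reads metadata instead of importing torch, it is cheap and does not initialize CUDA.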
The following sections describe common error messages and troubleshooting guidance for the PyTorch DataParallel and DistributedDataParallel classes. Most of these errors can likely be resolved with TorchDistributor, which is available on Databricks Runtime ML 13.0 and above. If TorchDistributor is not a viable solution, each section also provides a recommended fix.
The following is an example of how to use TorchDistributor:
```python
from pyspark.ml.torch.distributor import TorchDistributor

def train_fn(learning_rate):
    # ...
    pass

num_processes = 2
# local_mode=True runs the training processes on the driver node.
distributor = TorchDistributor(num_processes=num_processes, local_mode=True)
distributor.run(train_fn, 1e-3)
```
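The # ... placeholder in train_fn is where the per-process training logic goes. TorchDistributor launches its workers through torchrun (torch.distributed.run), which exports rank information to each worker as environment variables such as RANK, LOCAL_RANK, and WORLD_SIZE. The sketch below only shows how a training function might read those values; TrainContext and read_train_context are hypothetical names, and the actual torch.distributed setup is left as comments so the example runs without PyTorch installed.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass
class TrainContext:
    """Hypothetical container for one worker's rank information."""
    rank: int        # global rank across all worker processes
    local_rank: int  # rank within the current node, often used to pick a GPU
    world_size: int  # total number of worker processes


def read_train_context(env: Optional[Mapping[str, str]] = None) -> TrainContext:
    # torchrun sets these variables for every worker it launches.
    env = os.environ if env is None else env
    return TrainContext(
        rank=int(env["RANK"]),
        local_rank=int(env["LOCAL_RANK"]),
        world_size=int(env["WORLD_SIZE"]),
    )


def train_fn(learning_rate):
    ctx = read_train_context()
    # A real function would continue roughly like this (assumption, not shown):
    #   torch.distributed.init_process_group(backend="nccl")  # or "gloo" on CPU
    #   place the model on device ctx.local_rank and wrap it in DistributedDataParallel
    #   train with the given learning_rate
    #   torch.distributed.destroy_process_group()
    # Returning a result only from rank 0 is a common convention.
    if ctx.rank == 0:
        return {"world_size": ctx.world_size, "learning_rate": learning_rate}
    return None
```

With distributor.run(train_fn, 1e-3), each worker calls train_fn(1e-3) and reads its own rank from its environment.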
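This error occurs when using notebooks, regardless of the environment (Databricks, a local machine, and so on). With the spawn start method, which torch.multiprocessing.spawn uses, each worker starts a fresh interpreter and must look up the training function by name, which fails for functions defined in a notebook. To avoid this error, call torch.multiprocessing.start_processes with start_method="fork" instead of using spawn: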
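```python
import torch

def train_fn(rank, learning_rate):
    # required setup, e.g. setup(rank)
    # ...
    pass

num_processes = 2
# start_processes calls train_fn(rank, *args) in each of the nprocs worker processes.
torch.multiprocessing.start_processes(
    train_fn, args=(1e-3,), nprocs=num_processes, start_method="fork"
)
```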
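torch.multiprocessing is a wrapper around Python's multiprocessing module, so the effect of the fork start method can be illustrated with the standard library alone. With fork, each worker is a copy of the parent process and already contains train_fn, so nothing has to be re-imported by name. The sketch below mimics the fn(rank, *args) calling convention of start_processes; start_processes_fork and run_demo are hypothetical helpers, not part of PyTorch, and fork is not available on Windows.

```python
import multiprocessing as mp


def train_fn(rank, learning_rate, results):
    # Stand-in for real training: each worker reports its rank and learning rate.
    results.put((rank, learning_rate))


def start_processes_fork(fn, args=(), nprocs=1):
    """Hypothetical analogue of start_processes(..., start_method="fork")."""
    ctx = mp.get_context("fork")  # raises ValueError where fork is unavailable
    procs = [ctx.Process(target=fn, args=(rank, *args)) for rank in range(nprocs)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    # Surface worker failures instead of silently ignoring them.
    failed = [p.exitcode for p in procs if p.exitcode != 0]
    if failed:
        raise RuntimeError(f"worker processes exited with codes {failed}")


def run_demo(num_processes=2, learning_rate=1e-3):
    # SimpleQueue writes synchronously, so results are readable after join().
    results = mp.get_context("fork").SimpleQueue()
    start_processes_fork(train_fn, args=(learning_rate, results), nprocs=num_processes)
    return sorted(results.get() for _ in range(num_processes))
```

Calling run_demo(2, 1e-3) returns [(0, 0.001), (1, 0.001)], one entry per forked worker.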
This error appears when you restart distributed training after interrupting the cell while training is still running.
To resolve it, restart the cluster. If restarting does not solve the problem, the training function code may contain an error.