Best practices for deep learning on Databricks

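This article includes tips for deep learning on Databricks and information about built-in tools and libraries designed to optimize deep learning workloads, such as Delta Lake and Mosaic Streaming for loading data, Optuna for hyperparameter tuning, and Pandas UDFs for inference.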

Databricks Mosaic AI provides pre-built deep learning infrastructure with Databricks Runtime for Machine Learning, which includes the most common deep learning libraries like TensorFlow, PyTorch, and Keras. It also has built-in, pre-configured GPU support including drivers and supporting libraries.

Databricks Runtime ML also includes all of the capabilities of the Databricks workspace, such as cluster creation and management, library and environment management, code management with Databricks Git folders, automation support including Databricks Jobs and APIs, and integrated MLflow for model development tracking and model deployment and serving.

Resource and environment management

Databricks helps you to both customize your deep learning environment and keep the environment consistent across users.

Customize the development environment

With Databricks Runtime, you can customize your development environment at the notebook, cluster, and job levels.

Use cluster policies

You can create cluster policies to guide data scientists to the right choices, such as using a Single Node cluster for development and using an autoscaling cluster for large jobs.
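As an illustration, the two policy definitions below sketch that split. The first pins development clusters to Single Node, and the second requires autoscaling with a bounded worker range for large jobs. The attribute paths (`spark_conf.spark.databricks.cluster.profile`, `num_workers`, `autoscale.max_workers`, and so on) and the `fixed`/`range` policy types follow the cluster policy definition format as we understand it. The specific values and the helper `policy_definition_json` are hypothetical, so check the cluster policy reference for your workspace before relying on them.

```python
import json

# Hypothetical development policy: every cluster created with it is Single Node.
single_node_dev_policy = {
    "spark_conf.spark.databricks.cluster.profile": {"type": "fixed", "value": "singleNode", "hidden": True},
    "spark_conf.spark.master": {"type": "fixed", "value": "local[*]", "hidden": True},
    "custom_tags.ResourceClass": {"type": "fixed", "value": "SingleNode", "hidden": True},
    "num_workers": {"type": "fixed", "value": 0, "hidden": True},
}

# Hypothetical large-job policy: autoscaling is required, within bounded ranges.
autoscaling_job_policy = {
    "autoscale.min_workers": {"type": "range", "minValue": 1, "maxValue": 2, "defaultValue": 1},
    "autoscale.max_workers": {"type": "range", "minValue": 2, "maxValue": 8, "defaultValue": 4},
}


def policy_definition_json(policy):
    """Serialize a policy dict into the JSON text that a policy definition expects."""
    return json.dumps(policy, indent=2)


print(policy_definition_json(single_node_dev_policy))
```

The resulting JSON string is what you would paste into the policy editor, or pass as the policy definition when creating a policy through the API.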

Consider A100 GPUs for deep learning workloads

A100 GPUs are an efficient choice for many deep learning tasks, such as training and tuning large language models, natural language processing, object detection and classification, and recommendation engines.

  • Databricks supports A100 GPUs on all clouds. For the complete list of supported GPU types, see Supported instance types.

  • A100 GPUs usually have limited availability. Contact your cloud provider for resource allocation, or consider reserving capacity in advance.

GPU scheduling

To make full use of your GPUs for distributed deep learning training and inference, optimize GPU scheduling. See GPU scheduling.
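As a rough illustration of what GPU scheduling controls, the sketch below shows the two Apache Spark resource settings involved and the arithmetic they imply for how many GPU tasks one executor can run at once. The values are hypothetical. On Databricks GPU compute these settings are typically preconfigured, so `spark.task.resource.gpu.amount` is the one you would usually adjust. The helper `max_concurrent_gpu_tasks` is illustrative only.

```python
# Hypothetical values: each executor exposes 4 GPUs and each Spark task claims 1.
gpu_conf = {
    "spark.executor.resource.gpu.amount": "4",
    "spark.task.resource.gpu.amount": "1",
}


def max_concurrent_gpu_tasks(conf):
    """Upper bound on GPU tasks one executor can run at once.

    Only the GPU constraint is modeled here; CPU settings such as
    spark.task.cpus can lower the real number further.
    """
    executor_gpus = float(conf["spark.executor.resource.gpu.amount"])
    task_gpus = float(conf["spark.task.resource.gpu.amount"])
    return int(executor_gpus // task_gpus)


print(max_concurrent_gpu_tasks(gpu_conf))  # 4: one task per GPU
# A fractional task amount lets several tasks share one GPU, which can help
# when each task (for example, inference with a small model) underuses a GPU.
print(max_concurrent_gpu_tasks({**gpu_conf, "spark.task.resource.gpu.amount": "0.5"}))  # 8
```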

Best practices for loading data

Cloud data storage is typically not optimized for I/O, which can be a challenge for deep learning models that require large datasets. Databricks Runtime ML includes Delta Lake and Mosaic Streaming to optimize data throughput for deep learning applications.

Databricks recommends using Delta Lake tables for data storage. Delta Lake simplifies ETL and lets you access data efficiently. Especially for images, Delta Lake helps optimize ingestion for both training and inference. The reference solution for image applications provides an example of optimizing ETL for images using Delta Lake.

Databricks recommends Mosaic Streaming for data loading on PyTorch or Mosaic Composer, especially for distributed workloads. The provided StreamingDataset and StreamingDataLoader APIs help simplify training on large datasets while maximizing correctness guarantees, performance, flexibility, and ease of use in a distributed environment. See Load data using Mosaic Streaming for additional details.
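To make the "correctness" goal concrete: in distributed training, each epoch should read every sample exactly once, with no overlap between GPU processes. The sketch below shows the simplest form of that idea, deterministic shared-seed partitioning (the same basic approach as PyTorch's `DistributedSampler`). It is a simplified illustration, not Mosaic Streaming's implementation, which also handles concerns such as sharded remote storage and resuming mid-epoch. The function `rank_indices` is hypothetical.

```python
import random


def rank_indices(num_samples, rank, world_size, epoch, seed=0):
    """Sample indices that one rank (GPU process) reads in a given epoch.

    Every rank shuffles with the same (seed, epoch), so all ranks agree on
    the order; each rank then takes a strided slice. The slices are disjoint
    and together cover every sample exactly once per epoch.
    """
    order = list(range(num_samples))
    random.Random(f"{seed}:{epoch}").shuffle(order)  # str seeds are deterministic across processes
    return order[rank::world_size]


# Example: 10 samples split across 4 ranks for epoch 0.
shards = [rank_indices(10, r, world_size=4, epoch=0) for r in range(4)]
print([len(s) for s in shards])  # [3, 3, 2, 2]
```

Note that the shard sizes are uneven when the dataset size is not divisible by the number of ranks. Real loaders pad or drop samples so that every rank runs the same number of steps.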

Best practices for training deep learning models

Databricks recommends using Databricks Runtime for Machine Learning and MLflow tracking and autologging for all model training.

Start with a Single Node cluster

A Single Node (driver only) GPU cluster is typically the fastest and most cost-effective option for deep learning model development. One node with 4 GPUs is likely to be faster for deep learning training than 4 worker nodes with 1 GPU each, because distributed training incurs network communication overhead.

A Single Node cluster is a good option during fast, iterative development and for training models on small- to medium-size data. If your dataset is large enough to make training slow on a single machine, consider moving to multi-GPU and even distributed compute.

Use TensorBoard to monitor the training process

TensorBoard is preinstalled in Databricks Runtime ML. You can use it within a notebook or in a separate tab. See TensorBoard for details.

Optimize performance for deep learning

You can, and should, use deep learning performance optimization techniques on Databricks.

Early stopping

Early stopping monitors the value of a metric calculated on the validation set and stops training when the metric stops improving. This is a better approach than guessing at a good number of epochs to complete. Each deep learning library provides a native API for early stopping; for example, see the EarlyStopping callback APIs for TensorFlow/Keras and for PyTorch Lightning. For an example notebook, see TensorFlow Keras example notebook.
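In practice, use the library callbacks mentioned above. To show the logic they implement, here is a framework-agnostic sketch for a metric you want to minimize, such as validation loss. The class name `EarlyStopping` and its `patience`/`min_delta` parameters mirror common usage, but this is a hypothetical minimal version and not the Keras or Lightning implementation.

```python
class EarlyStopping:
    """Minimal early-stopping tracker for a metric to be minimized."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience    # epochs to wait after the last improvement
        self.min_delta = min_delta  # smallest decrease that counts as improvement
        self.best = float("inf")
        self.best_epoch = -1
        self.wait = 0

    def step(self, epoch, value):
        """Record one epoch's validation metric; return True if training should stop."""
        if value < self.best - self.min_delta:
            self.best, self.best_epoch, self.wait = value, epoch, 0
        else:
            self.wait += 1
        return self.wait >= self.patience


val_losses = [0.90, 0.70, 0.60, 0.61, 0.62, 0.59, 0.63]
stopper = EarlyStopping(patience=2)
for epoch, loss in enumerate(val_losses):
    if stopper.step(epoch, loss):
        # prints "Stopping after epoch 4; best epoch was 2"
        print(f"Stopping after epoch {epoch}; best epoch was {stopper.best_epoch}")
        break
```

Training stops after two epochs without improvement, so the slightly better loss at epoch 5 is never reached. A larger `patience` tolerates noisier metrics at the cost of extra epochs. Restoring the weights saved at `best_epoch` (as the Keras callback's `restore_best_weights` option does) avoids keeping the degraded final model.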

Batch size tuning

Batch size tuning helps optimize GPU utilization. If the batch size is too small, the calculations cannot fully use the GPU capabilities.

Adjust the batch size in conjunction with the learning rate. A good rule of thumb is that when you increase the batch size by a factor of n, you should increase the learning rate by a factor of sqrt(n). When tuning manually, try changing the batch size by a factor of 2 or 0.5. Then continue tuning to optimize performance, either manually or by testing a variety of hyperparameters with an automated tool like Optuna.
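The square-root rule of thumb and the halve-or-double starting point can be written as a small helper. This is a sketch for generating candidate settings to try, not a guarantee of good hyperparameters. The function names are hypothetical.

```python
import math


def scaled_learning_rate(base_lr, base_batch_size, new_batch_size):
    """Square-root rule: batch size x n  ->  learning rate x sqrt(n)."""
    n = new_batch_size / base_batch_size
    return base_lr * math.sqrt(n)


def halve_and_double(base_lr, base_batch_size):
    """(batch_size, learning_rate) candidates at 0.5x, 1x, and 2x the batch size."""
    return [
        (bs, scaled_learning_rate(base_lr, base_batch_size, bs))
        for bs in (base_batch_size // 2, base_batch_size, base_batch_size * 2)
    ]


print(scaled_learning_rate(1e-3, 32, 128))  # batch size x4 -> learning rate x2
for bs, lr in halve_and_double(1e-3, 64):
    print(f"batch_size={bs:4d}  lr={lr:.2e}")
```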

Transfer learning

With transfer learning, you start with a previously trained model and modify it as needed for your application. Transfer learning can significantly reduce the time required to train and tune a new model. See Featurization for transfer learning for more information and an example.

Move to distributed training

Databricks Runtime ML includes TorchDistributor, DeepSpeed, and Ray to facilitate the move from single-node to distributed training.

TorchDistributor

TorchDistributor is an open-source module in PySpark that facilitates distributed training with PyTorch on Spark clusters by letting you launch PyTorch training jobs as Spark jobs. See Distributed training with TorchDistributor.
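A skeleton of that pattern might look like the following. It assumes PySpark 3.4 or later (which provides `pyspark.ml.torch.distributor.TorchDistributor`), PyTorch with NCCL, and a GPU cluster. It also assumes one training process per GPU, which is why `total_processes` multiplies workers by GPUs per worker. The model and training loop are left as a placeholder, and `train_fn`, `launch`, and `total_processes` are hypothetical names. Imports happen inside the functions so that the heavy libraries load only where the code actually runs.

```python
def total_processes(num_workers, gpus_per_worker):
    """Assumed layout: one training process per GPU across all worker nodes."""
    return num_workers * gpus_per_worker


def train_fn(learning_rate):
    """Runs in every process that TorchDistributor starts.

    TorchDistributor launches processes with torchrun, which sets RANK,
    WORLD_SIZE, and LOCAL_RANK in each process's environment.
    """
    import os

    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")  # NCCL backend for multi-GPU training
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    # Placeholder: build the model, wrap it in
    # torch.nn.parallel.DistributedDataParallel, and train with learning_rate.
    dist.destroy_process_group()


def launch(num_workers, gpus_per_worker, learning_rate=1e-3):
    """Start train_fn on the cluster's worker nodes as a Spark job."""
    from pyspark.ml.torch.distributor import TorchDistributor

    distributor = TorchDistributor(
        num_processes=total_processes(num_workers, gpus_per_worker),
        local_mode=False,  # False: run on worker nodes; True: driver node only
        use_gpu=True,
    )
    # run() forwards the extra positional arguments to train_fn.
    return distributor.run(train_fn, learning_rate)
```

Moving from a Single Node cluster to a multi-node one then mostly means changing `num_workers` and `local_mode`, while `train_fn` stays the same.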

Optuna

Optuna provides adaptive hyperparameter tuning for machine learning.
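A minimal Optuna study, using the `trial.suggest_float`, `trial.suggest_categorical`, `optuna.create_study`, and `study.optimize` APIs, could look like this. It assumes `optuna` is installed on the cluster. In a real objective you would train the model and return its validation loss. Here `proxy_validation_loss` is a hypothetical toy function, with its minimum at a learning rate of 1e-3 and a batch size of 64, so that the example runs anywhere.

```python
import math


def proxy_validation_loss(lr, batch_size):
    """Hypothetical stand-in for "train the model, return validation loss"."""
    # Toy bowl-shaped surface whose minimum is at lr=1e-3, batch_size=64.
    return (math.log10(lr) + 3) ** 2 + (math.log2(batch_size) - 6) ** 2


def objective(trial):
    # Search the learning rate on a log scale and the batch size over powers of 2.
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64, 128, 256])
    return proxy_validation_loss(lr, batch_size)


def tune(n_trials=30):
    import optuna  # assumed to be installed on the cluster

    study = optuna.create_study(direction="minimize")
    study.optimize(objective, n_trials=n_trials)
    return study.best_params  # e.g. {"lr": ..., "batch_size": ...}
```

Optuna's default sampler uses the results of earlier trials to choose later ones. This adaptive search is what distinguishes it from a fixed grid.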

Best practices for inference

This section contains general tips about using models for inference with Databricks.

  • To minimize costs, consider both CPUs and inference-optimized GPUs such as the A2 machine family. There is no clear recommendation, as the best choice depends on model size, data dimensions, and other variables.

  • Use MLflow to simplify deployment and model serving. MLflow can log any deep learning model, including custom preprocessing and postprocessing logic. Models in Unity Catalog or models registered in the Workspace Model Registry can be deployed for batch, streaming, or online inference.

Online serving

The best option for low-latency serving is online serving behind a REST API. Databricks provides Model Serving for online inference. Model Serving provides a unified interface to deploy, govern, and query AI models and supports serving the following:

  • Custom models. These are Python models packaged in the MLflow format. Examples include scikit-learn, XGBoost, PyTorch, and Hugging Face transformer models.

  • External models. These are models that are hosted outside of Databricks, for example generative AI models such as OpenAI’s GPT-4, Anthropic’s Claude, and others. Endpoints that serve these models can be centrally governed, and customers can establish rate limits and access control for them.

MLflow provides APIs for deploying to various managed services for online inference, as well as APIs for creating Docker containers for custom serving solutions.
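Because Model Serving exposes models behind a REST API, any HTTP client can query an endpoint. The standard-library sketch below builds a scoring request for a custom model. It uses the `/serving-endpoints/<name>/invocations` path, bearer-token authentication, and the `dataframe_records` input format, which is one of several accepted formats. The workspace URL, endpoint name, token, and feature names are hypothetical placeholders.

```python
import json
import urllib.request


def build_invocation_request(host, endpoint_name, token, records):
    """Build (but do not send) a scoring request for a Model Serving endpoint."""
    url = f"{host.rstrip('/')}/serving-endpoints/{endpoint_name}/invocations"
    body = json.dumps({"dataframe_records": records}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        method="POST",
        headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"},
    )


def query_endpoint(host, endpoint_name, token, records, timeout=30):
    """Send the request and return the parsed JSON response."""
    req = build_invocation_request(host, endpoint_name, token, records)
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

For example, `query_endpoint("https://<workspace-url>", "my-endpoint", token, [{"feature_a": 1.0}])` would score one row, where every argument is a placeholder for your own values. Keep the token in a secret store rather than in notebook code.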

Batch and streaming inference

Batch and streaming scoring supports high-throughput, low-cost scoring at latencies as low as minutes. For more information, see Deploy models for batch inference and prediction.

  • If you expect to access data for inference more than once, consider creating a preprocessing job to ETL the data into a Delta Lake table before running the inference job. This way, the cost of ingesting and preparing the data is spread across multiple reads of the data. Separating preprocessing from inference also allows you to select different hardware for each job to optimize cost and performance. For example, you might use CPUs for ETL and GPUs for inference.

  • Use Spark Pandas UDFs to scale batch and streaming inference across a cluster.
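A common pattern for this is the iterator-of-Series Pandas UDF. Spark passes each task an iterator of batches, so the model is loaded once per task instead of once per batch. The sketch below separates that batching logic (`predict_batches`) from the Spark wrapper (`make_prediction_udf`). It assumes PySpark and pandas are available where the wrapper runs, and that `load_model` is your own hypothetical function returning a callable that maps a batch to one float prediction per row.

```python
from typing import Callable, Iterable, Iterator


def predict_batches(batches: Iterable, load_model: Callable) -> Iterator:
    """Load the model once, then score each incoming batch with it."""
    model = load_model()  # expensive step: done once per task, not once per batch
    for batch in batches:
        yield model(batch)  # must return one prediction per input row


def make_prediction_udf(load_model: Callable):
    """Wrap predict_batches in an iterator-of-Series Pandas UDF."""
    import pandas as pd
    from pyspark.sql.functions import pandas_udf

    @pandas_udf("double")
    def predict_udf(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
        for predictions in predict_batches(batches, load_model):
            yield pd.Series(predictions)

    return predict_udf
```

You would apply it with something like `df.withColumn("prediction", make_prediction_udf(my_loader)("features"))`, where `my_loader` and `"features"` are placeholders. For models logged with MLflow, `mlflow.pyfunc.spark_udf` builds a comparable UDF directly from a model URI.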