Deep learning model inference performance tuning guide
This section provides some tips for debugging and performance tuning for model inference on Databricks. For an overview, see the deep learning inference workflow.
Typically there are two main parts in model inference: the data input pipeline and model inference itself. The data input pipeline is heavy on data I/O, and model inference is heavy on computation. Determining the bottleneck of the workflow is simple. Here are some approaches:
Reduce the model to a trivial model and measure the examples per second. If the difference in end-to-end time between the full model and the trivial model is minimal, the data input pipeline is likely the bottleneck; otherwise, model inference is the bottleneck. See the sketch after this list.
If running model inference with a GPU, check the GPU utilization metrics. If GPU utilization is not continuously high, the data input pipeline may be the bottleneck.
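The comparison in the first approach can be as simple as timing a scoring loop over the same batches, once with the full model and once with a trivial stand-in. The following is a minimal sketch; score_batch and batches are hypothetical placeholders for your own scoring function and batched input.

import time

def measure_throughput(score_batch, batches):
    # score_batch: runs preprocessing + inference on one batch (hypothetical placeholder)
    # batches: an iterable of input batches (hypothetical placeholder)
    start, total = time.time(), 0
    for batch in batches:
        score_batch(batch)
        total += len(batch)
    return total / (time.time() - start)  # examples per second

# Run once with the full model and once with a trivial model (for example, an identity function);
# similar throughput means the data input pipeline, not the model, is the bottleneck.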
Optimize data input pipeline
GPUs can significantly speed up model inference. As GPUs and other accelerators become faster, it is important that the data input pipeline keep up with demand. The data input pipeline reads the data into Spark DataFrames, transforms it, and loads it as the input for model inference. If data input is the bottleneck, here are some tips to increase I/O throughput:
Set the max records per batch. A larger number of max records reduces the I/O overhead of calling the UDF, as long as the records fit in memory. To set the batch size, set the following config:
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "5000")
Load the data in batches and prefetch it when preprocessing the input data in the pandas UDF.
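For example, when scoring with a Scalar Iterator pandas UDF, the data arrives in Arrow batches and the model can be loaded once per UDF instantiation instead of once per batch. The following is a minimal sketch; load_model and preprocess are hypothetical placeholders for your own model loading and preprocessing code.

from typing import Iterator
import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("double")
def predict_udf(batch_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
    model = load_model()              # hypothetical: load the model once, reuse it across batches
    for batch in batch_iter:          # each pandas Series is one Arrow batch (up to maxRecordsPerBatch rows)
        features = preprocess(batch)  # hypothetical preprocessing step
        yield pd.Series(model.predict(features))

predictions = df.withColumn("prediction", predict_udf("features"))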
For TensorFlow, Databricks recommends using the tf.data API. You can parse records in parallel by setting num_parallel_calls in the map function and call prefetch and batch for prefetching and batching:
dataset.map(parse_example, num_parallel_calls=num_process).prefetch(prefetch_size).batch(batch_size)
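Under typical assumptions (TFRecord input files and an image parsing function), a fuller version of that pipeline looks roughly like the sketch below; the feature spec, image size, and tuning values are illustrative only.

import tensorflow as tf

def parse_example(serialized):
    # Hypothetical feature spec; adjust to match your TFRecord schema.
    features = {"image": tf.io.FixedLenFeature([], tf.string)}
    parsed = tf.io.parse_single_example(serialized, features)
    image = tf.io.decode_jpeg(parsed["image"], channels=3)
    return tf.image.resize(image, [224, 224])

num_process, prefetch_size, batch_size = 8, 2, 64
dataset = (tf.data.TFRecordDataset(filenames)                    # filenames: list of TFRecord paths
           .map(parse_example, num_parallel_calls=num_process)   # parse records in parallel
           .prefetch(prefetch_size)                              # overlap preprocessing with inference
           .batch(batch_size))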
For PyTorch, Databricks recommends using the DataLoader class. You can set batch_size for batching and num_workers for parallel data loading:
torch.utils.data.DataLoader(images, batch_size=batch_size, num_workers=num_process)
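As a slightly fuller illustration, the sketch below wraps a list of image files in a Dataset so that num_workers subprocesses load and preprocess in parallel while the GPU runs inference; paths, transform, and model are hypothetical placeholders for your own file list, preprocessing, and model.

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader

class ImageDataset(Dataset):
    # Hypothetical dataset over image file paths; transform is your preprocessing (for example, torchvision transforms).
    def __init__(self, paths, transform):
        self.paths = paths
        self.transform = transform
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, idx):
        return self.transform(Image.open(self.paths[idx]).convert("RGB"))

loader = DataLoader(ImageDataset(paths, transform), batch_size=64,
                    num_workers=8, pin_memory=True)   # workers load and decode in parallel

model.eval()
with torch.no_grad():
    for batch in loader:
        preds = model(batch.to("cuda", non_blocking=True))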