Train a PyTorch model

Note

The managed MLflow integration with Databricks on Google Cloud requires Databricks Runtime for Machine Learning 8.1 or above.

PyTorch is a Python package that provides GPU-accelerated tensor computation and high-level functionality for building deep learning networks.

The MLflow PyTorch notebook fits a neural network on MNIST handwritten digit recognition data and logs run results to an MLflow server. It logs training metrics and weights in TensorFlow event format locally and then uploads them to the MLflow run’s artifact directory. Finally, it starts TensorBoard and reads the events logged locally.

MLflow PyTorch model training notebook
