PyTorch

PyTorch プロジェクトは 、GPU アクセラレーションによるテンソル計算と、ディープラーニング ネットワークを構築するための高レベルの機能を提供する Python パッケージです。 ライセンスの詳細については、 GitHub の PyTorch ライセンスに関するドキュメントを参照してください。

PyTorch モデルを監視およびデバッグするには、 TensorBoard の使用を検討してください。

PyTorch は Databricks ランタイム for Machine Learning に含まれています。 Databricks Runtimeを使用している場合、 PyTorch のインストール手順については、「PyTorch のインストール 」を参照してください。

これは PyTorch の包括的なガイドではありません。 詳細については、 PyTorch の Web サイトを参照してください。

単一ノードおよび分散トレーニング

単一マシンのワークフローをテストおよび移行するには、 単一ノード クラスターを使用します。

ディープラーニングの分散トレーニング オプションについては、「 分散トレーニング」を参照してください。

ノートブック の例

PyTorch ノートブック

ノートブックを新しいタブで開く

PyTorch をインストールする

Databricks Runtime for 機械学習

Databricks Runtime for Machine Learning には PyTorch が含まれているため、クラスターを作成して PyTorch の使用を開始できます。 使用している Databricks Runtime 機械学習バージョンにインストールされている PyTorch のバージョンについては、 リリースノートを参照してください。

Databricks Runtime

Databricks では、機械学習の Databricks Runtime に含まれている PyTorch を使用することをお勧めします。 ただし、 標準 Databricks Runtimeを使用する必要がある場合は、PyTorch を Databricks PyPI ライブラリとしてインストールできます。 次の例は、PyTorch 1.5.0 をインストールする方法を示しています。

  • GPU クラスターでは、以下を指定して pytorchtorchvision をインストールします。

    • torch==1.5.0

    • torchvision==0.6.0

  • CPU クラスターで、次の Python wheel ファイルを使用して pytorchtorchvision をインストールします。

    https://download.pytorch.org/whl/cpu/torch-1.5.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
    
    https://download.pytorch.org/whl/cpu/torchvision-0.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
    

分散 PyTorch のエラーとトラブルシューティング

次のセクションでは、一般的なエラー メッセージと、クラスのトラブルシューティング ガイダンスについて説明します: PyTorch DataParallel または PyTorch DistributedDataParallelこれらのエラーのほとんどは、機械学習 13.0 以降で利用可能な TorchDistributor Databricks Runtime で解決できる可能性があります。 ただし、 TorchDistributor 実行可能なソリューションではない場合は、各セクション内で推奨されるソリューションも提供されます。

トーチディストリビューターの使用方法の例を次に示します。


from pyspark.ml.torch.distributor import TorchDistributor

def train_fn(learning_rate):
        # ...

num_processes=2
distributor = TorchDistributor(num_processes=num_processes, local_mode=True)

distributor.run(train_fn, 1e-3)

process 0 terminated with exit code 1

このエラーは、環境 (Databricks、ローカル コンピューターなど) に関係なく、ノートブックを使用しているときに発生します。 このエラーを回避するには、 torch.multiprocessing.spawnの代わりに torch.multiprocessing.start_processesstart_method=fork と共に使用します。

例:

import torch

def train_fn(rank, learning_rate):
    # required setup, e.g. setup(rank)
        # ...

num_processes = 2
torch.multiprocessing.start_processes(train_fn, args=(1e-3,), nprocs=num_processes, start_method="fork")

The server socket has failed to bind to [::]:{PORT NUMBER} (errno: 98 - Address already in use).

これは、トレーニングの実行中にセルを中断した後に分散トレーニングを再開すると表示されるエラーです。

解決するには、クラスターを再起動します。 それでも問題が解決しない場合は、トレーニング関数コードにエラーがある可能性があります。