PyTorchName

O projeto PyTorch é um pacote Python que fornece computação tensor acelerada por GPU e funcionalidades de alto nível para a construção de redes profundas de aprendizagem. Para obter detalhes de licenciamento, consulte o documento de licença do PyTorch no GitHub.

Para monitorar e depurar seus modelos PyTorch, considere o uso do TensorBoard.

O PyTorch está incluído no Databricks Runtime para Machine Learning. Se você estiver usando Databricks Runtime, consulte Instalar PyTorch para obter instruções sobre como instalar o PyTorch.

Observação

Este não é um guia completo do PyTorch. Para obter mais informações, consulte o site do PyTorch.

Nó único e treinamento distribuído

Para testar e migrar o fluxo de trabalho de uma única máquina, use um cluster de nó único.

Para opções de treinamento distribuído para aprendizagem profunda, consulte treinamento distribuído.

Notebook de exemplo

NotebookPyTorch

Abra o bloco de anotações em outra guia

Instalar o PyTorch

Databricks Runtime para ML

O Databricks Runtime for Machine Learning inclui PyTorch para que você possa criar os clusters e começar a usar o PyTorch. Para a versão do PyTorch instalada na versão do Databricks Runtime ML que você está usando, consulte as notas sobre a versão.

Databricks Runtime

A Databricks recomenda que você use o PyTorch incluído no Databricks Runtime for Machine Learning. No entanto, se você precisar usar o Databricks Runtime padrão, o PyTorch poderá ser instalado como uma biblioteca PyPI do Databricks. O exemplo a seguir mostra como instalar o PyTorch 1.5.0:

  • Em clusters de GPU, instale pytorch e torchvision especificando o seguinte:

    • torch==1.5.0

    • torchvision==0.6.0

  • Em clusters de CPU, instale pytorch e torchvision usando os seguintes arquivos Python wheel :

    https://download.pytorch.org/whl/cpu/torch-1.5.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
    
    https://download.pytorch.org/whl/cpu/torchvision-0.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
    

Erros e solução de problemas para PyTorch distribuído

As seções a seguir descrevem mensagens de erro comuns e orientações de solução de problemas para as classes: PyTorch DataParallel ou PyTorch DistributedDataParallel. A maioria desses erros provavelmente pode ser resolvida com TorchDistributor, que está disponível no Databricks Runtime ML 13.0e acima. No entanto, se TorchDistributor não for uma solução viável, as soluções recomendadas também serão fornecidas em cada seção.

Veja a seguir um exemplo de como usar o TorchDistributor:


from pyspark.ml.torch.distributor import TorchDistributor

def train_fn(learning_rate):
        # ...

num_processes=2
distributor = TorchDistributor(num_processes=num_processes, local_mode=True)

distributor.run(train_fn, 1e-3)

process 0 terminated with exit code 1

Este erro ocorre ao usar Notebook, independentemente do ambiente: Databricks, máquina local, etc. Para evitar esse erro, use torch.multiprocessing.start_processes com start_method=fork em vez de torch.multiprocessing.spawn.

Por exemplo:

import torch

def train_fn(rank, learning_rate):
    # required setup, e.g. setup(rank)
        # ...

num_processes = 2
torch.multiprocessing.start_processes(train_fn, args=(1e-3,), nprocs=num_processes, start_method="fork")

The server socket has failed to bind to [::]:{PORT NUMBER} (errno: 98 - Address already in use).

Este erro aparece quando você reinicia o treinamento distribuído após interromper a célula enquanto o treinamento está acontecendo.

Para resolver, reinicie os clusters. Se isso não resolver o problema, pode haver um erro no código da função de treinamento.