ディープラーニングモデル推論ワークフロー
ディープラーニング アプリケーションのモデル推論では、Databricks で次のワークフローをお勧めします。 TensorFlow と PyTorch を使用するノートブックの例については、「 ディープラーニング モデルの推論の例」を参照してください。
データを Spark DataFramesに読み込みます。 データ型に応じて、Databricks では次の方法でデータを読み込むことをお勧めします。
画像ファイル (JPG、PNG): 画像パスを Spark DataFrameに読み込みます。 画像の読み込みと入力データの前処理は、pandas UDF で行われます。
files_df = spark.createDataFrame(map(lambda path: (path,), file_paths), ["path"])
TFRecords: spark-tensorflow-connectorを使用してデータを読み込みます。
df = spark.read.format("tfrecords").load(image_path)
Parquet、CSV、JSON、JDBC、その他のメタデータなどのデータソース: Spark データソースを使用してデータを読み込みます。
pandas UDF を使用してモデルの推論を実行します。 pandas UDF は Apache Arrow を使用してデータを転送し、pandas はデータを操作します。 モデル推論を行うために、pandas UDF を使用したワークフローの大まかな手順を次に示します。
トレーニング済みモデルを読み込む: 効率を上げるために、Databricks では、ドライバーからモデルの重みをブロードキャストし、モデル グラフを読み込んで、pandas UDF でブロードキャストされた変数から重みを取得することをお勧めします。
入力データのロードと前処理: データをバッチでロードするには、 Databricks TensorFlow には tf.data API を使用し、PyTorch には DataLoader クラス を使用することをお勧めします。 どちらもプリフェッチとマルチスレッド読み込みをサポートして、IO バウンドの待機時間を隠すこともできます。
モデル予測の実行: データ バッチでモデルの推論を実行します。
予測を Spark DataFramesに送り返す: 予測結果を収集し、
pd.Series
として返します。