Databricks Connect for Scala のユーザー定義関数User-defined functions in Databricks Connect for Scala

注:

この記事では、Databricks Connect for Databricks Runtime 14.1 以降について説明します。

この記事では、Databricks Connect for Scala を使用してユーザー定義関数を実行する方法について説明します。 Databricks Connect を使用すると、一般的な IDE、ノートブック サーバー、カスタム アプリケーションを Databricks クラスターに接続できます。 この記事の Python バージョンについては、「 Databricks Connect for Python のユーザー定義関数」を参照してください。

注:

Databricks Connectの使用を開始する前に、Databricks Connect クライアントをセットアップする必要があります。

Databricks Runtime 14.1 以降では、Databricks Connect for Scala でユーザー定義関数 (UDF) の実行がサポートされています。

UDF を実行するには、UDF が必要とするコンパイル済みクラスと JAR をクラスターにアップロードする必要があります。 addCompiledArtifacts() APIを使用して、アップロードする必要があるコンパイル済みクラスおよびJARファイルを指定できます。

注:

クライアントで使用される Scala は、Databricks クラスターの Scala バージョンと一致する必要があります。 クラスターの Scala バージョンを確認するには、「Databricks Runtime リリースノートのバージョンと互換性」のクラスターの Databricks Runtime バージョンに関する「システム環境」セクションを参照してください。

次の Scala プログラムは、列内の値を 2 乗する単純な UDF を設定します。

import com.databricks.connect.DatabricksSession
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object Main {
  def main(args: Array[String]): Unit = {
    val sourceLocation = getClass.getProtectionDomain.getCodeSource.getLocation.toURI

    val spark = DatabricksSession.builder()
      .addCompiledArtifacts(sourceLocation)
      .getOrCreate()

    def squared(x: Int): Int = x * x

    val squared_udf = udf(squared _)

    spark.range(3)
      .withColumn("squared", squared_udf(col("id")))
      .select("squared")
      .show()
  }
}

前の例では、UDF が Main内に完全に含まれているため、 Main のコンパイル済み成果物のみが追加されます。 UDFが他のクラスに分散している場合、または外部ライブラリ(つまり、JAR)を使用する場合は、これらのライブラリもすべて含める必要があります。

Sparkセッションがすでに初期化されている場合は、 spark.addArtifact() APIを使用して、さらにコンパイルされたクラスとJARをアップロードできます。

注:

JARをアップロードする場合、すべての推移的な依存関係JARをアップロードに含める必要があります。 APIsは、推移的な依存関係の自動検出を実行しません。

型付きデータセットAPIs

UDF について前のセクションで説明したのと同じメカニズムは、型指定されたデータセット APIsにも適用されます。

型付きデータセット APIs を使用すると、結果のデータセットに対してマップ、フィルター、集計などの変換を実行できます。 これらも、Databricks クラスターの UDF と同様に実行されます。

次の Scala アプリケーションでは、 map() API を使用して、結果列の数値をプレフィックス付きの文字列に変更します。

import com.databricks.connect.DatabricksSession
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object Main {
  def main(args: Array[String]): Unit = {
    val sourceLocation = getClass.getProtectionDomain.getCodeSource.getLocation.toURI

    val spark = DatabricksSession.builder()
      .addCompiledArtifacts(sourceLocation)
      .getOrCreate()

    spark.range(3).map(f => s"row-$f").show()
  }
}

この例では map() API を使用していますが、これは filter()mapPartitions() など、他の型指定されたデータセット APIs にも当てはまります。