Databricks Connect for Scala のユーザー定義関数User-defined functions in Databricks Connect for Scala
注:
この記事では、Databricks Connect for Databricks Runtime 14.1 以降について説明します。
この記事では、Databricks Connect for Scala を使用してユーザー定義関数を実行する方法について説明します。 Databricks Connect を使用すると、一般的な IDE、ノートブック サーバー、カスタム アプリケーションを Databricks クラスターに接続できます。 この記事の Python バージョンについては、「 Databricks Connect for Python のユーザー定義関数」を参照してください。
注:
Databricks Connectの使用を開始する前に、Databricks Connect クライアントをセットアップする必要があります。
Databricks Runtime 14.1 以降では、Databricks Connect for Scala でユーザー定義関数 (UDF) の実行がサポートされています。
UDF を実行するには、UDF が必要とするコンパイル済みクラスと JAR をクラスターにアップロードする必要があります。 addCompiledArtifacts()
APIを使用して、アップロードする必要があるコンパイル済みクラスおよびJARファイルを指定できます。
注:
クライアントで使用される Scala は、Databricks クラスターの Scala バージョンと一致する必要があります。 クラスターの Scala バージョンを確認するには、「Databricks Runtime リリースノートのバージョンと互換性」のクラスターの Databricks Runtime バージョンに関する「システム環境」セクションを参照してください。
次の Scala プログラムは、列内の値を 2 乗する単純な UDF を設定します。
import com.databricks.connect.DatabricksSession
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}
object Main {
def main(args: Array[String]): Unit = {
val sourceLocation = getClass.getProtectionDomain.getCodeSource.getLocation.toURI
val spark = DatabricksSession.builder()
.addCompiledArtifacts(sourceLocation)
.getOrCreate()
def squared(x: Int): Int = x * x
val squared_udf = udf(squared _)
spark.range(3)
.withColumn("squared", squared_udf(col("id")))
.select("squared")
.show()
}
}
前の例では、UDF が Main
内に完全に含まれているため、 Main
のコンパイル済み成果物のみが追加されます。 UDFが他のクラスに分散している場合、または外部ライブラリ(つまり、JAR)を使用する場合は、これらのライブラリもすべて含める必要があります。
Sparkセッションがすでに初期化されている場合は、 spark.addArtifact()
APIを使用して、さらにコンパイルされたクラスとJARをアップロードできます。
注:
JARをアップロードする場合、すべての推移的な依存関係JARをアップロードに含める必要があります。 APIsは、推移的な依存関係の自動検出を実行しません。
型付きデータセットAPIs
UDF について前のセクションで説明したのと同じメカニズムは、型指定されたデータセット APIsにも適用されます。
型付きデータセット APIs を使用すると、結果のデータセットに対してマップ、フィルター、集計などの変換を実行できます。 これらも、Databricks クラスターの UDF と同様に実行されます。
次の Scala アプリケーションでは、 map()
API を使用して、結果列の数値をプレフィックス付きの文字列に変更します。
import com.databricks.connect.DatabricksSession
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}
object Main {
def main(args: Array[String]): Unit = {
val sourceLocation = getClass.getProtectionDomain.getCodeSource.getLocation.toURI
val spark = DatabricksSession.builder()
.addCompiledArtifacts(sourceLocation)
.getOrCreate()
spark.range(3).map(f => s"row-$f").show()
}
}
この例では map()
API を使用していますが、これは filter()
、mapPartitions()
など、他の型指定されたデータセット APIs にも当てはまります。