Python ユーザー定義テーブル関数 (UDTF)

プレビュー

この機能は、Databricks Runtime 14.3 LTS 以降でパブリック プレビューされています。

ユーザー定義テーブル関数 (UDTF) を使用すると、スカラー値ではなくテーブルを返す関数を登録できます。 各呼び出しから単一の結果値を返すスカラー関数とは異なり、各 UDTF は SQL ステートメントのFROM句で呼び出され、出力としてテーブル全体を返します。

各 UDTF 呼び出しは、0 個以上の引数を受け入れることができます。 これらの引数は、入力テーブル全体を表すスカラー式またはテーブル引数にすることができます。

基本的な UDTF 構文

Apache Spark は、出力行を生成するためにyieldを使用する必須のevalメソッドを持つ Python クラスとして Python UDTF を実装します。

クラスを UDTF として使用するには、PySpark udtf関数をインポートする必要があります。 Databricks では、この関数をデコレータとして使用し、 returnTypeオプションを使用してフィールド名と型を明示的に指定することをお勧めします (後のセクションで説明するように、クラスでanalyzeメソッドが定義されている場合を除く)。

次の UDTF は、2 つの整数引数の固定リストを使用してテーブルを作成します。

from pyspark.sql.functions import lit, udtf

@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
    def eval(self, x: int, y: int):
        yield x + y, x - y

GetSumDiff(lit(1), lit(2)).show()
+----+-----+
| sum| diff|
+----+-----+
|   3|   -1|
+----+-----+

UDTFを登録する

UDTF はローカルSparkSessionに登録され、ノートブックまたはジョブ レベルで分離されます。

UDTF を Unity Catalogのオブジェクトとして登録したり、UDTF を SQLウェアハウスで使用することはできません。

関数 spark.udtf.register() を使用して、 SQLクエリで使用するために UDTF を現在の SparkSession に登録できます。 SQL 関数と Python UDTF クラスの名前を指定します。

spark.udtf.register("get_sum_diff", GetSumDiff)

登録済みの UDTF を呼び出す

登録したら、 %sqlマジック コマンドまたはspark.sql()関数のいずれかを使用して SQL で UDTF を使用できます。

spark.udtf.register("get_sum_diff", GetSumDiff)
spark.sql("SELECT * FROM get_sum_diff(1,2);")
%sql
SELECT * FROM get_sum_diff(1,2);

Apache Arrowを使用する

UDTF が入力として少量のデータを受け取り、大きなテーブルを出力する場合、Databricks では Apache Arrow の使用を推奨します。 UDTF を宣言するときにuseArrow引数を指定することで有効にできます。

@udtf(returnType="c1: int, c2: int", useArrow=True)

可変引数リスト - *args と **kwargs

Python *argsまたは**kwargs構文を使用して、不特定の数の入力値を処理するロジックを実装できます。

次の例では、引数の入力の長さと型を明示的にチェックしながら、同じ結果を返します。

@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
    def eval(self, *args):
        assert(len(args) == 2)
        assert(isinstance(arg, int) for arg in args)
        x = args[0]
        y = args[1]
        yield x + y, x - y

GetSumDiff(lit(1), lit(2)).show()

同じ例を次に示しますが、キーワード引数を使用します。

@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
    def eval(self, **kwargs):
        x = kwargs["x"]
        y = kwargs["y"]
        yield x + y, x - y

GetSumDiff(x=lit(1), y=lit(2)).show()

登録時に静的スキーマを定義する

UDTF は、列名と型の順序付けられたシーケンスで構成される出力スキーマを持つ行を返します。 UDTF スキーマがすべてのクエリで常に同じである必要がある場合は、 @udtf デコレーターの後に静的な固定スキーマを指定できます。 これは、次のいずれかの StructTypeである必要があります。

StructType().add("c1", StringType())

または、構造体型を表す DDL 文字列:

c1: string

関数呼び出し時に動的スキーマを作成する

UDTF は、入力引数の値に応じて、呼び出しごとに出力スキーマをプログラムでコンピュートすることもできます。 これを行うには、特定の UDTF 呼び出しに提供される引数に対応する 0 個以上のパラメーターを受け入れるanalyzeという静的メソッドを定義します。

analyze メソッドの各引数は、次のフィールドを含む AnalyzeArgument クラスのインスタンスです。

AnalyzeArgument class フィールド

説明

dataType

DataTypeとしての入力引数の型。入力テーブル引数の場合、これはテーブルの列を表す StructType です。

value

Optional[Any]としての入力引数の値。これは、定数ではないテーブル引数またはリテラル スカラー引数に対して None されます。

isTable

入力引数が table as a BooleanTypeであるかどうか。

isConstantExpression

入力引数が BooleanTypeとしての定数畳み込み式であるかどうか

analyze メソッドは、結果表のスキーマをStructTypeとして含み、いくつかのオプション・フィールドを含む AnalyzeResult クラスのインスタンスを戻します。UDTFが入力テーブル引数を受け入れる場合、 AnalyzeResult には、後で説明するように、複数のUDTF呼び出し間で入力テーブルの行を分割および順序付けする要求された方法を含めることもできます。

AnalyzeResult class フィールド

説明

schema

結果表のスキーマ ( StructType.

withSinglePartition

すべての入力行を BooleanTypeと同じ UDTF クラスインスタンスに送信するかどうか。

partitionBy

空でないに設定すると、パーティショニング式の値の一意の組み合わせを持つすべての行が、 UDTF クラスの個別のインスタンスによって消費されます。

orderBy

空でないに設定すると、各パーティション内の行の順序が指定されます。

select

空以外に設定されている場合、これは UDTF が Catalyst に入力 TABLE 引数の列に対して評価するように指定する一連の式です。 UDTFは、リスト内の名前ごとに1つの入力属性をリストされている順序で受け取ります。

このanalyzeの例では、入力文字列引数の各単語に対して 1 つの出力列が返されます。

@udtf
class MyUDTF:
  @staticmethod
  def analyze(text: AnalyzeArgument) -> AnalyzeResult:
    schema = StructType()
    for index, word in enumerate(sorted(list(set(text.value.split(" "))))):
      schema = schema.add(f"word_{index}", IntegerType())
    return AnalyzeResult(schema=schema)

  def eval(self, text: str):
    counts = {}
    for word in text.split(" "):
      if word not in counts:
            counts[word] = 0
      counts[word] += 1
    result = []
    for word in sorted(list(set(text.split(" ")))):
      result.append(counts[word])
    yield result
['word_0', 'word_1']

状態を将来のeval呼び出しに転送する

analyze メソッドは、初期化を実行し、その結果を同じ UDTF 呼び出しの将来のevalメソッド呼び出しに転送するのに便利な場所として機能します。

これを行うには、 AnalyzeResult のサブクラスを作成し、 analyze メソッドからサブクラスのインスタンスを返します。 次に、 __init__ メソッドに引数を追加して、そのインスタンスを受け入れます。

このanalyzeの例では、定数の出力スキーマが返されますが、将来の__init__メソッド呼び出しで使用される結果メタデータにカスタム情報が追加されます。

@dataclass
class AnalyzeResultWithBuffer(AnalyzeResult):
    buffer: str = ""

@udtf
class TestUDTF:
  def __init__(self, analyze_result=None):
    self._total = 0
    if analyze_result is not None:
      self._buffer = analyze_result.buffer
    else:
      self._buffer = ""

  @staticmethod
  def analyze(argument, _) -> AnalyzeResult:
    if (
      argument.value is None
      or argument.isTable
      or not isinstance(argument.value, str)
      or len(argument.value) == 0
    ):
      raise Exception("The first argument must be a non-empty string")
    assert argument.dataType == StringType()
    assert not argument.isTable
    return AnalyzeResultWithBuffer(
      schema=StructType()
        .add("total", IntegerType())
        .add("buffer", StringType()),
      withSinglePartition=True,
      buffer=argument.value,
    )

  def eval(self, argument, row: Row):
    self._total += 1

  def terminate(self):
    yield self._total, self._buffer

self.spark.udtf.register("test_udtf", TestUDTF)

spark.sql(
  """
  WITH t AS (
    SELECT id FROM range(1, 21)
  )
  SELECT total, buffer
  FROM test_udtf("abc", TABLE(t))
  """
).show()
+-------+-------+
| count | buffer|
+-------+-------+
|    20 |  "abc"|
+-------+-------+

出力行の yield

evalメソッドは、入力テーブル引数の各行に対して 1 回実行され (テーブル引数が指定されていない場合は 1 回だけ実行)、最後にterminateメソッドが 1 回呼び出されます。 どちらの方法でも、タプル、リスト、または pyspark.sql.Row オブジェクトを生成することで、結果スキーマに準拠する 0 個以上の行が出力されます。

この例では、次の 3 つの要素のタプルを指定して行を返します。

def eval(self, x, y, z):
  yield (x, y, z)

括弧を省略することもできます。

def eval(self, x, y, z):
  yield x, y, z

末尾にコンマを追加して、列が 1 つしかない行を返します。

def eval(self, x, y, z):
  yield x,

また、 pyspark.sql.Row オブジェクトを生成することもできます。

def eval(self, x, y, z)
  from pyspark.sql.types import Row
  yield Row(x, y, z)

この例では、Python リストを使用してterminateメソッドから出力行を生成します。 この目的のために、UDTF 評価の前のステップからクラス内に状態を保存できます。

def terminate(self):
  yield [self.x, self.y, self.z]

スカラー引数を UDTFに渡す

スカラー引数は、リテラル値またはそれらに基づく関数で構成される定数式として UDTF に渡すことができます。 例えば:

SELECT * FROM udtf(42, group => upper("finance_department"));

テーブル引数を UDTF に渡す

Python UDTF は、スカラー入力引数に加えて、入力テーブルを引数として受け入れることができます。 単一の UDTF は、テーブル引数と複数のスカラー引数を受け入れることもできます。

次に、任意の SQL クエリで、 TABLEキーワードに続いて適切なテーブル識別子を囲む括弧 ( TABLE(t)など) を使用して入力テーブルを提供できます。 または、 TABLE(SELECT a, b, c FROM t)TABLE(SELECT t1.a, t2.b FROM t1 INNER JOIN t2 USING (key))のようなテーブルサブクエリを渡すこともできます。

入力テーブル引数は、eval メソッドのpyspark.sql.Row引数として表され、入力テーブルの各行に対して eval メソッドが 1 回呼び出されます。標準の PySpark 列フィールド注釈を使用して、各行の列を操作できます。 次の例は、PySpark Row型を明示的にインポートし、渡されたテーブルをidフィールドでフィルタリングする方法を示しています。

from pyspark.sql.functions import udtf
from pyspark.sql.types import Row

@udtf(returnType="id: int")
class FilterUDTF:
    def eval(self, row: Row):
        if row["id"] > 5:
            yield row["id"],

spark.udtf.register("filter_udtf", FilterUDTF)

関数をクエリするには、 TABLE SQL キーワードを使用します。

SELECT * FROM filter_udtf(TABLE(SELECT * FROM range(10)));
+---+
| id|
+---+
|  6|
|  7|
|  8|
|  9|
+---+

関数呼び出しから入力行のパーティション分割を指定する

テーブル引数を使用して UDTF を呼び出す場合、任意の SQL クエリは、1 つ以上の入力テーブル列の値に基づいて、入力テーブルを複数の UDTF 呼び出しに分割できます。

パーティションを指定するには、関数呼び出しの TABLE 引数の後に PARTITION BY 句を使用します。これにより、パーティション分割列の値の一意の組み合わせを持つすべての入力行が、 UDTF クラスの 1 つのインスタンスによってのみ使用されることが保証されます。

単純な列参照に加えて、 PARTITION BY 句は入力テーブル列に基づく任意の式も受け入れることに注意してください。 たとえば、文字列のLENGTHを指定したり、日付から月を抽出したり、2 つの値を連結したりできます。

また、PARTITION BYの代わりに WITH SINGLE PARTITION を指定して、すべての入力行を UDTF クラスの 1 つのインスタンスだけで消費する必要がある 1 つのパーティションのみを要求することもできます。

各パーティション内では、 UDTF の eval メソッドが入力行を消費するときに、必要に応じて入力行の必要な順序を指定できます。 これを行うには、上記の PARTITION BY 句または WITH SINGLE PARTITION 句の後に ORDER BY 句を指定します。

たとえば、次の UDTF について考えてみます。

from pyspark.sql.functions import udtf
from pyspark.sql.types import Row

@udtf(returnType="a: string, b: int")
class FilterUDTF:
  def __init__(self):
    self.key = ""
    self.max = 0

  def eval(self, row: Row):
    self.key = row["a"]
    self.max = max(self.max, row["b"])

  def terminate(self):
    yield self.key, self.max

spark.udtf.register("filter_udtf", FilterUDTF)

複数の方法で入力テーブルに対して UDTF を呼び出すときに、パーティション分割オプションを指定できます。

-- Create an input table with some example values.
DROP TABLE IF EXISTS values_table;
CREATE TABLE values_table (a STRING, b INT);
INSERT INTO values_table VALUES ('abc', 2), ('abc', 4), ('def', 6), ('def', 8)";
SELECT * FROM values_table;
+-------+----+
|     a |  b |
+-------+----+
| "abc" | 2  |
| "abc" | 4  |
| "def" | 6  |
| "def" | 8  |
+-------+----+
-- Query the UDTF with the input table as an argument and a directive to partition the input
-- rows such that all rows with each unique value in the `a` column are processed by the same
-- instance of the UDTF class. Within each partition, the rows are ordered by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY a ORDER BY b) ORDER BY 1;
+-------+----+
|     a |  b |
+-------+----+
| "abc" | 4  |
| "def" | 8  |
+-------+----+
-- Query the UDTF with the input table as an argument and a directive to partition the input
-- rows such that all rows with each unique result of evaluating the "LENGTH(a)" expression are
-- processed by the same instance of the UDTF class. Within each partition, the rows are ordered
-- by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY LENGTH(a) ORDER BY b) ORDER BY 1;
+-------+---+
|     a | b |
+-------+---+
| "def" | 8 |
+-------+---+
-- Query the UDTF with the input table as an argument and a directive to consider all the input
-- rows in one single partition such that exactly one instance of the UDTF class consumes all of
-- the input rows. Within each partition, the rows are ordered by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) WITH SINGLE PARTITION ORDER BY b) ORDER BY 1;
+-------+----+
|     a |  b |
+-------+----+
| "def" | 8 |
+-------+----+

analyze メソッドからの入力行のパーティション方法を指定します

SQL クエリで UDTF を呼び出すときに入力テーブルをパーティション分割する上記の各方法には、代わりに UDTF のanalyzeメソッドで同じパーティション分割方法を自動的に指定する対応する方法があることに注意してください。

  • SELECT * FROM udtf(TABLE(t) PARTITION BY a)として UDTF を呼び出す代わりに、analyze メソッドを更新してフィールドpartitionBy=[PartitioningColumn("a")]を設定し、 SELECT * FROM udtf(TABLE(t))を使用して関数を呼び出すことができます。

  • 同じトークンを使用して、 SQLクエリで TABLE(t) WITH SINGLE PARTITION ORDER BY b を指定する代わりに、analyze でフィールド withSinglePartition=trueorderBy=[OrderingColumn("b")] を設定し、TABLE(t) を渡すこともできます。

  • SQL クエリでTABLE(SELECT a FROM t)を渡す代わりに、 analyzeselect=[SelectedColumn("a")]を設定し、 TABLE(t)を渡すこともできます。

次の例では、 analyze は定数出力スキーマを返し、入力テーブルから列のサブセットを選択し、 date 列の値に基づいて入力テーブルを複数の UDTF 呼び出しに分割することを指定します。

@staticmethod
def analyze(*args) -> AnalyzeResult:
  """
  The input table will be partitioned across several UDTF calls based on the monthly
  values of each `date` column. The rows within each partition will arrive ordered by the `date`
  column. The UDTF will only receive the `date` and `word` columns from the input table.
  """
  from pyspark.sql.functions import (
    AnalyzeResult,
    OrderingColumn,
    PartitioningColumn,
  )

  assert len(args) == 1, "This function accepts one argument only"
  assert args[0].isTable, "Only table arguments are supported"
  return AnalyzeResult(
    schema=StructType()
      .add("month", DateType())
      .add('longest_word", IntegerType()),
    partitionBy=[
      PartitioningColumn("extract(month from date)")],
    orderBy=[
      OrderingColumn("date")],
    select=[
      SelectedColumn("date"),
      SelectedColumn(
        name="length(word),
        alias="length_word")])