Optimize stateful Structured Streaming queries

Managing the intermediate state information of stateful Structured Streaming queries can help prevent unexpected latency and production problems.

Databricks recommends:

  • Use compute-optimized instances as workers.

  • Set the number of shuffle partitions to 1-2 times the number of cores in the cluster.

  • Set the spark.sql.streaming.noDataMicroBatches.enabled configuration to false in the SparkSession. This prevents the streaming micro-batch engine from processing micro-batches that do not contain data. Note that with this setting, stateful operations that rely on watermarks or processing time timeouts might not emit output until new data arrives, rather than emitting it immediately.
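As a rough illustration of the last two recommendations, the following sketch derives a shuffle partition count from the number of worker cores and collects both settings in one place. The helper names (recommended_streaming_conf, apply_conf) and the multiplier parameter are assumptions made for this example and are not part of any Databricks or Spark API; spark is assumed to be an existing SparkSession.

```python
def recommended_streaming_conf(total_cores, multiplier=2):
    """Build Spark SQL settings that follow the recommendations above.

    total_cores: total number of worker cores in the cluster.
    multiplier: hypothetical knob for this sketch; the recommendation is 1 to 2.
    """
    if total_cores < 1:
        raise ValueError("total_cores must be positive")
    if not 1 <= multiplier <= 2:
        raise ValueError("multiplier should be between 1 and 2")
    return {
        # 1-2 times the number of cores in the cluster
        "spark.sql.shuffle.partitions": str(int(total_cores * multiplier)),
        # Skip micro-batches that contain no data
        "spark.sql.streaming.noDataMicroBatches.enabled": "false",
    }


def apply_conf(spark, conf):
    """Apply each setting to an existing SparkSession (assumed to be `spark`)."""
    for key, value in conf.items():
        spark.conf.set(key, value)


conf = recommended_streaming_conf(total_cores=32, multiplier=2)
print(conf["spark.sql.shuffle.partitions"])  # prints 64
```

In a notebook, apply_conf(spark, conf) would be called before the streaming query is started for the first time, because a stateful query keeps the shuffle partition count recorded in its checkpoint across restarts.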

Databricks recommends using RocksDB with changelog checkpointing to manage the state for stateful streams. See Configure RocksDB state store on Databricks.

Note

The state management scheme cannot be changed between query restarts. That is, if a query has been started with the default state management, it cannot be changed without starting the query from scratch with a new checkpoint location.
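The following is a minimal sketch of how the RocksDB recommendation might be applied from a notebook. The two configuration keys are the ones generally associated with the Databricks RocksDB state store and changelog checkpointing; confirm them against the linked page for your runtime version. The helper name enable_rocksdb_state_store is hypothetical, and spark is assumed to be an existing SparkSession.

```python
# Settings assumed from the Databricks RocksDB state store documentation;
# verify the keys for your Databricks Runtime version before relying on them.
ROCKSDB_STATE_STORE_CONF = {
    "spark.sql.streaming.stateStore.providerClass":
        "com.databricks.sql.streaming.state.RocksDBStateStoreProvider",
    "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled":
        "true",
}


def enable_rocksdb_state_store(spark):
    """Apply the RocksDB settings to an existing SparkSession.

    Call this before the query's first start: the state management scheme
    cannot be changed for an existing checkpoint location.
    """
    for key, value in ROCKSDB_STATE_STORE_CONF.items():
        spark.conf.set(key, value)
```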

Work with multiple stateful operators in Structured Streaming

In Databricks Runtime 13.1 and above, Databricks offers advanced support for stateful operators in Structured Streaming workloads. You can now chain multiple stateful operators together, meaning that you can feed the output of an operation such as a windowed aggregation to another stateful operation such as a join.

The following examples demonstrate several patterns you can use.

Important

The following limitations exist when working with multiple stateful operators:

  • flatMapGroupsWithState is not supported.

  • Only the append output mode is supported.

Chained time window aggregation

from pyspark.sql.functions import window, window_time

words = ...  # streaming DataFrame of schema { timestamp: Timestamp, word: String }

# Group the data by window and word and compute the count of each group
windowedCounts = words.groupBy(
    window(words.timestamp, "10 minutes", "5 minutes"),
    words.word
).count()

# Group the windowed data by another window and word and compute the count of each group
anotherWindowedCounts = windowedCounts.groupBy(
    window(window_time(windowedCounts.window), "1 hour"),
    windowedCounts.word
).count()

import org.apache.spark.sql.functions.window
import spark.implicits._

val words = ... // streaming DataFrame of schema { timestamp: Timestamp, word: String }

// Group the data by window and word and compute the count of each group
val windowedCounts = words.groupBy(
  window($"timestamp", "10 minutes", "5 minutes"),
  $"word"
).count()

// Group the windowed data by another window and word and compute the count of each group
val anotherWindowedCounts = windowedCounts.groupBy(
  window($"window", "1 hour"),
  $"word"
).count()
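To see why the chained aggregation uses window_time rather than the raw end of the window, it can help to work through the arithmetic outside Spark. The sketch below is plain Python, not Spark code: it mimics the documented behavior that window_time(window) equals window.end minus 1 microsecond, and assigns the result to a 1-hour tumbling window. It assumes windows aligned to the Unix epoch and ignores time zones; the function names are chosen for this illustration only.

```python
from datetime import datetime, timedelta

EPOCH = datetime(1970, 1, 1)


def window_time(window_end):
    """Event time representing a window: its exclusive end minus 1 microsecond."""
    return window_end - timedelta(microseconds=1)


def tumbling_window(ts, size):
    """Return the [start, end) tumbling window of length `size` containing `ts`."""
    start = EPOCH + ((ts - EPOCH) // size) * size
    return start, start + size


# Inner 10-minute window [12:50, 13:00) produced by the first aggregation
inner_end = datetime(2024, 1, 1, 13, 0)

# Using window_time keeps the inner window inside the hour it belongs to
outer = tumbling_window(window_time(inner_end), timedelta(hours=1))
print(outer[0].time(), outer[1].time())  # prints 12:00:00 13:00:00

# Using the exclusive end directly would push it into the next hour
shifted = tumbling_window(inner_end, timedelta(hours=1))
print(shifted[0].time(), shifted[1].time())  # prints 13:00:00 14:00:00
```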

Time window aggregation in two different streams followed by stream-stream window join

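from pyspark.sql.functions import window

# clicksWithWatermark and impressionsWithWatermark are streaming DataFrames
# that already have event-time watermarks defined
clicksWindow = clicksWithWatermark.groupBy(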
  clicksWithWatermark.clickAdId,
  window(clicksWithWatermark.clickTime, "1 hour")
).count()

impressionsWindow = impressionsWithWatermark.groupBy(
  impressionsWithWatermark.impressionAdId,
  window(impressionsWithWatermark.impressionTime, "1 hour")
).count()

clicksWindow.join(impressionsWindow, "window", "inner")

import org.apache.spark.sql.functions.window
import spark.implicits._

// window() takes a Column, so reference the event-time column with $"..."
val clicksWindow = clicksWithWatermark
  .groupBy(window($"clickTime", "1 hour"))
  .count()

val impressionsWindow = impressionsWithWatermark
  .groupBy(window($"impressionTime", "1 hour"))
  .count()

clicksWindow.join(impressionsWindow, "window", "inner")
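The snippets in this section use clicksWithWatermark and impressionsWithWatermark without showing how they are created. A minimal sketch of one way to define them follows, using DataFrame.withWatermark. The helper name with_watermarks and the delay thresholds are assumptions for illustration; choose thresholds that match how late your events can arrive.

```python
def with_watermarks(impressions, clicks,
                    impression_delay="2 hours", click_delay="3 hours"):
    """Attach event-time watermarks to two streaming DataFrames.

    impressions: streaming DataFrame with an `impressionTime` timestamp column.
    clicks: streaming DataFrame with a `clickTime` timestamp column.
    The default delays are illustrative values, not recommendations.
    """
    impressions_with_watermark = impressions.withWatermark(
        "impressionTime", impression_delay)
    clicks_with_watermark = clicks.withWatermark("clickTime", click_delay)
    return impressions_with_watermark, clicks_with_watermark
```

With two streaming DataFrames in hand, impressionsWithWatermark, clicksWithWatermark = with_watermarks(impressions, clicks) yields the inputs that the aggregations and joins in these examples expect.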

Stream-stream time interval join followed by time window aggregation

from pyspark.sql.functions import expr, window

joined = impressionsWithWatermark.join(
  clicksWithWatermark,
  expr("""
    clickAdId = impressionAdId AND
    clickTime >= impressionTime AND
    clickTime <= impressionTime + interval 1 hour
    """),
  "leftOuter"                 # can be "inner", "leftOuter", "rightOuter", "fullOuter", "leftSemi"
)

joined.groupBy(
  joined.clickAdId,
  window(joined.clickTime, "1 hour")
).count()

import org.apache.spark.sql.functions.{expr, window}

val joined = impressionsWithWatermark.join(
  clicksWithWatermark,
  expr("""
    clickAdId = impressionAdId AND
    clickTime >= impressionTime AND
    clickTime <= impressionTime + interval 1 hour
  """),
  joinType = "leftOuter"      // can be "inner", "leftOuter", "rightOuter", "fullOuter", "leftSemi"
)

joined
  .groupBy($"clickAdId", window($"clickTime", "1 hour"))
  .count()

State rebalancing for Structured Streaming

State rebalancing is enabled by default for all streaming workloads in Delta Live Tables. In Databricks Runtime 11.1 and above, you can set the following configuration option in the Spark cluster configuration to enable state rebalancing:

spark.sql.streaming.statefulOperator.stateRebalancing.enabled true

State rebalancing benefits stateful Structured Streaming pipelines that undergo cluster resizing events. Stateless streaming operations do not benefit, regardless of changing cluster sizes.

Note

Compute auto-scaling has limitations scaling down cluster size for Structured Streaming workloads. Databricks recommends using Delta Live Tables with Enhanced Autoscaling for streaming workloads. See What is Enhanced Autoscaling?.

Cluster resizing events cause state rebalancing to trigger. During rebalancing events, micro-batches might have higher latency as the state loads from cloud storage to the new executors.

Specify initial state for mapGroupsWithState

You can specify a user-defined initial state for Structured Streaming stateful processing using flatMapGroupsWithState or mapGroupsWithState. This allows you to avoid reprocessing data when starting a stateful stream without a valid checkpoint.

def mapGroupsWithState[S: Encoder, U: Encoder](
    timeoutConf: GroupStateTimeout,
    initialState: KeyValueGroupedDataset[K, S])(
    func: (K, Iterator[V], GroupState[S]) => U): Dataset[U]

def flatMapGroupsWithState[S: Encoder, U: Encoder](
    outputMode: OutputMode,
    timeoutConf: GroupStateTimeout,
    initialState: KeyValueGroupedDataset[K, S])(
    func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U]

Example use case that specifies an initial state to the flatMapGroupsWithState operator:

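import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

// Assumed definition of the state type; the original example does not show it.
case class RunningCount(count: Long)

val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
  // Add the number of new values for this key to the previously stored count
  val count = state.getOption.map(_.count).getOrElse(0L) + values.size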
  state.update(new RunningCount(count))
  Iterator((key, count.toString))
}

val fruitCountInitialDS: Dataset[(String, RunningCount)] = Seq(
  ("apple", new RunningCount(1)),
  ("orange", new RunningCount(2)),
  ("mango", new RunningCount(5)),
).toDS()

val fruitCountInitial = fruitCountInitialDS.groupByKey(x => x._1).mapValues(_._2)

fruitStream
  .groupByKey(x => x)
  .flatMapGroupsWithState(OutputMode.Update(), GroupStateTimeout.NoTimeout, fruitCountInitial)(fruitCountFunc)

Example use case that specifies an initial state to the mapGroupsWithState operator:

val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
  // Add the number of new values for this key to the previously stored count
  val count = state.getOption.map(_.count).getOrElse(0L) + values.size
  state.update(new RunningCount(count))
  (key, count.toString)
}

val fruitCountInitialDS: Dataset[(String, RunningCount)] = Seq(
  ("apple", new RunningCount(1)),
  ("orange", new RunningCount(2)),
  ("mango", new RunningCount(5)),
).toDS()

val fruitCountInitial = fruitCountInitialDS.groupByKey(x => x._1).mapValues(_._2)

fruitStream
  .groupByKey(x => x)
  .mapGroupsWithState(GroupStateTimeout.NoTimeout, fruitCountInitial)(fruitCountFunc)
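To make the effect of the initial state concrete, the following plain Python sketch simulates the running count from the examples above. It is not Spark code and does not use any Spark API: a dict stands in for the state store, and a list of lists stands in for the micro-batches of fruitStream. As a simplification, it emits output only for keys that appear in a batch. The function name run_with_initial_state is hypothetical.

```python
def run_with_initial_state(initial_state, batches):
    """Simulate a per-key running count that is seeded with an initial state.

    initial_state: dict of key -> starting count (the role of fruitCountInitial).
    batches: list of micro-batches, each a list of keys (the role of fruitStream).
    Returns, for each micro-batch, the list of (key, count) updates emitted.
    """
    state = dict(initial_state)  # copy so the caller's dict is not mutated
    outputs = []
    for batch in batches:
        # Group the batch by key, as groupByKey would
        grouped = {}
        for fruit in batch:
            grouped[fruit] = grouped.get(fruit, 0) + 1
        emitted = []
        for key, new_values in sorted(grouped.items()):
            # Same logic as fruitCountFunc: previous count (or 0) plus new values
            state[key] = state.get(key, 0) + new_values
            emitted.append((key, str(state[key])))
        outputs.append(emitted)
    return outputs


initial = {"apple": 1, "orange": 2, "mango": 5}
result = run_with_initial_state(initial, [["apple", "apple", "kiwi"], ["mango"]])
print(result)  # prints [[('apple', '3'), ('kiwi', '1')], [('mango', '6')]]
```

The count for apple starts from the seeded value of 1 rather than 0, which is the behavior the initial state is meant to provide; kiwi has no initial state and starts from 0.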

Test the mapGroupsWithState update function

The TestGroupState API enables you to test the state update function used for Dataset.groupByKey(...).mapGroupsWithState(...) and Dataset.groupByKey(...).flatMapGroupsWithState(...).

The state update function takes the previous state as input using an object of type GroupState. See the Apache Spark GroupState reference documentation. For example:

import org.apache.spark.sql.streaming._
import org.apache.spark.api.java.Optional

test("flatMapGroupsWithState's state update function") {
  var prevState = TestGroupState.create[UserStatus](
    optionalState = Optional.empty[UserStatus],
    timeoutConf = GroupStateTimeout.EventTimeTimeout,
    batchProcessingTimeMs = 1L,
    eventTimeWatermarkMs = Optional.of(1L),
    hasTimedOut = false)

  val userId: String = ...
  val actions: Iterator[UserAction] = ...

  assert(!prevState.hasUpdated)

  // updateState is the state update function under test
  updateState(userId, actions, prevState)

  assert(prevState.hasUpdated)
}