Skip to content

Commit

Permalink
[SPARK-50785][SQL] Refactor FOR statement to utilize local variables …
Browse files Browse the repository at this point in the history
…properly

### What changes were proposed in this pull request?
This PR refactors FOR statement to use local variables instead of session variables to represent columns. Previously, FOR simulated local variables by artificially creating and dropping session variables, which caused a number of issues. In this PR, we create an internal `CompoundBodyExec` to represent the "scope" of the FOR statement. Within this body we declare local variables, which are automatically cleaned up when we exit the scope. We set the label of this body to the FOR variable name, if present, which enables easy access to the columns by qualifying with the FOR variable name.

### Why are the changes needed?
Previous version had a number of issues, e.g. nested for loops with same column names would fail.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
`SqlScriptingInterpreterSuite` and `SqlScriptingExecutionNodeSuite` were updated.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #50026 from dusantism-db/scripting-for-improvements-v2.

Authored-by: Dušan Tišma <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
dusantism-db authored and cloud-fan committed Feb 25, 2025
1 parent 7feb911 commit 0184c5b
Show file tree
Hide file tree
Showing 4 changed files with 486 additions and 240 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
package org.apache.spark.sql.scripting

import java.util
import java.util.{Locale, UUID}

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.{ExecuteImmediateQuery, NameParameterizedQuery, UnresolvedAttribute, UnresolvedIdentifier}
import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, UnresolvedAttribute, UnresolvedIdentifier}
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, EqualTo, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, LogicalPlan, OneRowRelation, Project, SetVariable}
import org.apache.spark.sql.catalyst.plans.logical.ExceptionHandlerType.ExceptionHandlerType
Expand Down Expand Up @@ -206,6 +207,15 @@ class TriggerToExceptionHandlerMap(
def getNotFoundHandler: Option[ExceptionHandlerExec] = notFoundHandler
}

object TriggerToExceptionHandlerMap {
def createEmptyMap: TriggerToExceptionHandlerMap = new TriggerToExceptionHandlerMap(
Map.empty[String, ExceptionHandlerExec],
Map.empty[String, ExceptionHandlerExec],
None,
None
)
}

/**
* Executable node for CompoundBody.
* @param statements
Expand All @@ -221,7 +231,7 @@ class TriggerToExceptionHandlerMap(
* Map of condition names/sqlstates to error handlers defined in this compound body.
*/
class CompoundBodyExec(
statements: Seq[CompoundStatementExec],
val statements: Seq[CompoundStatementExec],
label: Option[String] = None,
isScope: Boolean,
context: SqlScriptingExecutionContext,
Expand Down Expand Up @@ -888,31 +898,23 @@ class LoopStatementExec(
* Executable node for ForStatement.
* @param query Executable node for the query.
* @param variableName Name of variable used for accessing current row during iteration.
* @param body Executable node for the body.
* @param statements List of statements to be executed in the FOR body.
* @param label Label set to ForStatement by user or None otherwise.
* @param session Spark session that SQL script is executed within.
* @param context SqlScriptingExecutionContext keeps the execution state of current script.
*/
class ForStatementExec(
query: SingleStatementExec,
variableName: Option[String],
body: CompoundBodyExec,
statements: Seq[CompoundStatementExec],
val label: Option[String],
session: SparkSession,
context: SqlScriptingExecutionContext) extends NonLeafStatementExec {

private object ForState extends Enumeration {
val VariableAssignment, Body, VariableCleanup = Value
val VariableAssignment, Body = Value
}
private var state = ForState.VariableAssignment
private var areVariablesDeclared = false

// map of all variables created internally by the for statement
// (variableName -> variableExpression)
private var variablesMap: Map[String, Expression] = Map()

// compound body used for dropping variables while in ForState.VariableAssignment
private var dropVariablesExec: CompoundBodyExec = null

private var queryResult: util.Iterator[Row] = _
private var isResultCacheValid = false
Expand All @@ -925,6 +927,8 @@ class ForStatementExec(
queryResult
}

private var bodyWithVariables: CompoundBodyExec = null

/**
* For can be interrupted by LeaveStatementExec
*/
Expand All @@ -935,35 +939,43 @@ class ForStatementExec(

override def hasNext: Boolean = !interrupted && (state match {
case ForState.VariableAssignment => cachedQueryResult().hasNext
case ForState.Body => true
case ForState.VariableCleanup => dropVariablesExec.getTreeIterator.hasNext
case ForState.Body => bodyWithVariables.getTreeIterator.hasNext
})

@scala.annotation.tailrec
override def next(): CompoundStatementExec = state match {

case ForState.VariableAssignment =>
variablesMap = createVariablesMapFromRow(cachedQueryResult().next())

if (!areVariablesDeclared) {
// create and execute declare var statements
variablesMap.keys.toSeq
.map(colName => createDeclareVarExec(colName, variablesMap(colName)))
.foreach(declareVarExec => declareVarExec.buildDataFrame(session).collect())
areVariablesDeclared = true
}

// create and execute set var statements
variablesMap.keys.toSeq
.map(colName => createSetVarExec(colName, variablesMap(colName)))
.foreach(setVarExec => setVarExec.buildDataFrame(session).collect())
val row = cachedQueryResult().next()

val variableInitStatements = row.schema.names.toSeq
.map { colName => (colName, createExpressionFromValue(row.getAs(colName))) }
.flatMap { case (colName, expr) => Seq(
createDeclareVarExec(colName, expr),
createSetVarExec(colName, expr)
) }

bodyWithVariables = new CompoundBodyExec(
// NoOpStatementExec appended to end of body to prevent
// dropping variables before last statement is executed.
// This is necessary because we are calling exitScope before returning the last
// statement, so we need the last statement to be NoOp.
statements = variableInitStatements ++ statements :+ new NoOpStatementExec,
// We generate label name if FOR variable is not specified, similar to how
// compound bodies have generated label names if label is not specified.
label = variableName.orElse(Some(UUID.randomUUID().toString.toLowerCase(Locale.ROOT))),
isScope = true,
context = context,
triggerToExceptionHandlerMap = TriggerToExceptionHandlerMap.createEmptyMap
)

state = ForState.Body
body.reset()
bodyWithVariables.reset()
bodyWithVariables.enterScope()
next()

case ForState.Body =>
val retStmt = body.getTreeIterator.next()
val retStmt = bodyWithVariables.getTreeIterator.next()

// Handle LEAVE or ITERATE statement if it has been encountered.
retStmt match {
Expand All @@ -972,34 +984,28 @@ class ForStatementExec(
leaveStatementExec.hasBeenMatched = true
}
interrupted = true
// If this for statement encounters LEAVE, it will either not be executed
// again, or it will be reset before being executed.
// In either case, variables will not
// be dropped normally, from ForState.VariableCleanup, so we drop them here.
dropVars()
// If this for statement encounters LEAVE, we need to exit the scope, as
// we will not reach the point where we usually exit it.
bodyWithVariables.exitScope()
return retStmt
case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched =>
if (label.contains(iterStatementExec.label)) {
iterStatementExec.hasBeenMatched = true
} else {
// if an outer loop is being iterated, this for statement will either not be
// executed again, or it will be reset before being executed.
// In either case, variables will not
// be dropped normally, from ForState.VariableCleanup, so we drop them here.
dropVars()
// If an outer loop is being iterated, we need to exit the scope, as
// we will not reach the point where we usually exit it.
bodyWithVariables.exitScope()
}
switchStateFromBody()
state = ForState.VariableAssignment
return retStmt
case _ =>
}

if (!body.getTreeIterator.hasNext) {
switchStateFromBody()
if (!bodyWithVariables.getTreeIterator.hasNext) {
bodyWithVariables.exitScope()
state = ForState.VariableAssignment
}
retStmt

case ForState.VariableCleanup =>
dropVariablesExec.getTreeIterator.next()
}
}

Expand Down Expand Up @@ -1032,46 +1038,6 @@ class ForStatementExec(
case _ => Literal(value)
}

private def createVariablesMapFromRow(row: Row): Map[String, Expression] = {
var variablesMap = row.schema.names.toSeq.map { colName =>
colName -> createExpressionFromValue(row.getAs(colName))
}.toMap

if (variableName.isDefined) {
val namedStructArgs = variablesMap.keys.toSeq.flatMap { colName =>
Seq(Literal(colName), variablesMap(colName))
}
val forVariable = CreateNamedStruct(namedStructArgs)
variablesMap = variablesMap + (variableName.get -> forVariable)
}
variablesMap
}

/**
* Create and immediately execute dropVariable exec nodes for all variables in variablesMap.
*/
private def dropVars(): Unit = {
variablesMap.keys.toSeq
.map(colName => createDropVarExec(colName))
.foreach(dropVarExec => dropVarExec.buildDataFrame(session).collect())
areVariablesDeclared = false
}

private def switchStateFromBody(): Unit = {
state = if (cachedQueryResult().hasNext) ForState.VariableAssignment
else {
// create compound body for dropping nodes after execution is complete
dropVariablesExec = new CompoundBodyExec(
variablesMap.keys.toSeq.map(colName => createDropVarExec(colName)),
None,
isScope = false,
context,
new TriggerToExceptionHandlerMap(Map.empty, Map.empty, None, None)
)
ForState.VariableCleanup
}
}

private def createDeclareVarExec(varName: String, variable: Expression): SingleStatementExec = {
val defaultExpression = DefaultValueExpression(Literal(null, variable.dataType), "null")
val declareVariable = CreateVariable(
Expand All @@ -1097,24 +1063,13 @@ class ForStatementExec(
context)
}

private def createDropVarExec(varName: String): SingleStatementExec = {
// As DROP TEMPORARY VARIABLE is forbidden within a script, use EXECUTE IMMEDIATE to bypass
// this limitation. This will be removed once FOR is updated to properly use local variables.
val dropVar = ExecuteImmediateQuery(
Seq.empty, Left("DROP TEMPORARY VARIABLE IF EXISTS " + varName), Seq.empty)
new SingleStatementExec(dropVar, Origin(), Map.empty, isInternal = true, context)
}

override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator

override def reset(): Unit = {
state = ForState.VariableAssignment
isResultCacheValid = false
variablesMap = Map()
areVariablesDeclared = false
dropVariablesExec = null
interrupted = false
body.reset()
bodyWithVariables = null
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ case class SqlScriptingInterpreter(session: SparkSession) {
context)
val bodyExec =
transformTreeIntoExecutable(body, args, context).asInstanceOf[CompoundBodyExec]
new ForStatementExec(queryExec, variableNameOpt, bodyExec, label, session, context)
new ForStatementExec(
queryExec, variableNameOpt, bodyExec.statements, label, session, context)

case leaveStatement: LeaveStatement =>
new LeaveStatementExec(leaveStatement.label)
Expand Down
Loading

0 comments on commit 0184c5b

Please sign in to comment.