Skip to content

Commit

Permalink
Fixed handling of projected expressions in TreeNode
Browse files Browse the repository at this point in the history
Fix #1072
  • Loading branch information
nfx committed Nov 4, 2024
1 parent cbea0c0 commit a6da8cb
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ abstract class Plan[PlanType <: Plan[PlanType]] extends TreeNode[PlanType] {
* Returns all of the expressions present in this query (that is expression defined in this plan operator and in each
* its descendants).
*/
def expressions: Seq[Expression] = { // TODO: bring back `final` after "expressions" in Project is renamed
final def expressions: Seq[Expression] = {
// Recursively find all expressions from a traversable.
def seqToExpressions(seq: Iterable[Any]): Iterable[Expression] = seq.flatMap {
case e: Expression => e :: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ case class DataSource(
override def output: Seq[Attribute] = Seq.empty
}

case class Project(input: LogicalPlan, override val expressions: Seq[Expression]) extends UnaryNode {
case class Project(input: LogicalPlan, columns: Seq[Expression]) extends UnaryNode {
override def child: LogicalPlan = input
// TODO: add resolver for Star
override def output: Seq[Attribute] = expressions.map {
Expand Down Expand Up @@ -65,7 +65,6 @@ case class Join(
join_data_type: JoinDataType)
extends BinaryNode {
override def output: Seq[Attribute] = left.output ++ right.output
override def expressions: Seq[Expression] = super.expressions ++ join_condition.toSeq
}

case class SetOperation(
Expand Down Expand Up @@ -135,7 +134,6 @@ case class Deduplicate(
within_watermark: Boolean)
extends UnaryNode {
override def output: Seq[Attribute] = child.output
override def expressions: Seq[Expression] = super.expressions ++ column_names
}

case class LocalRelation(child: LogicalPlan, data: Array[Byte], schemaString: String) extends UnaryNode {
Expand Down Expand Up @@ -249,7 +247,6 @@ case class WithWatermark(child: LogicalPlan, event_time: String, delay_threshold

case class Hint(child: LogicalPlan, name: String, parameters: Seq[Expression]) extends UnaryNode {
override def output: Seq[Attribute] = child.output
override def expressions: Seq[Expression] = super.expressions ++ parameters
}

case class Values(values: Seq[Seq[Expression]]) extends LeafNode { // TODO: fix it
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.databricks.labs.remorph.intermediate

import org.scalatest.matchers.must.Matchers
import org.scalatest.wordspec.AnyWordSpec

class JoinTest extends AnyWordSpec with Matchers {
"Join" should {
"propagage expressions" in {
val join = Join(
NoopNode,
NoopNode,
Some(Name("foo")),
InnerJoin,
Seq.empty,
JoinDataType(is_left_struct = true, is_right_struct = true))
join.expressions mustBe Seq(Name("foo"))
}
}
}

0 comments on commit a6da8cb

Please sign in to comment.