diff --git a/core/src/main/antlr4/com/databricks/labs/remorph/parsers/snowflake/SnowflakeParser.g4 b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/snowflake/SnowflakeParser.g4 index 1009f5158..72f3bf1b9 100644 --- a/core/src/main/antlr4/com/databricks/labs/remorph/parsers/snowflake/SnowflakeParser.g4 +++ b/core/src/main/antlr4/com/databricks/labs/remorph/parsers/snowflake/SnowflakeParser.g4 @@ -3329,14 +3329,11 @@ switchSection: WHEN expr THEN expr queryStatement: withExpression? selectStatement setOperators* ; -withExpression: WITH commonTableExpression (COMMA commonTableExpression)* +withExpression: WITH RECURSIVE? commonTableExpression (COMMA commonTableExpression)* ; commonTableExpression - : tableName = id (L_PAREN columns += id (COMMA columns += id)* R_PAREN)? AS L_PAREN ( - (selectStatement setOperators*) - | expr - ) R_PAREN + : tableName = id (L_PAREN columnList R_PAREN)? AS L_PAREN ((selectStatement setOperators*) | expr) R_PAREN ; selectStatement diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/extensions.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/extensions.scala index 94357be00..8328fb290 100644 --- a/core/src/main/scala/com/databricks/labs/remorph/intermediate/extensions.scala +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/extensions.scala @@ -37,6 +37,11 @@ case class WithCTE(ctes: Seq[LogicalPlan], query: LogicalPlan) extends RelationC override def children: Seq[LogicalPlan] = ctes :+ query } +case class WithRecursiveCTE(ctes: Seq[LogicalPlan], query: LogicalPlan) extends RelationCommon { + override def output: Seq[Attribute] = query.output + override def children: Seq[LogicalPlan] = ctes :+ query +} + // TODO: (nfx) refactor to align more with catalyst, rename to UnresolvedStar case class Star(objectName: Option[ObjectReference] = None) extends LeafExpression with StarOrAlias { override def dataType: DataType = UnresolvedType diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeAstBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeAstBuilder.scala index a34fad511..a228a9f1a 100644 --- a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeAstBuilder.scala +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeAstBuilder.scala @@ -90,8 +90,14 @@ class SnowflakeAstBuilder(override val vc: SnowflakeVisitorCoordinator) errorCheck(ctx) match { case Some(errorResult) => errorResult case None => - val ctes = vc.relationBuilder.visitMany(ctx.commonTableExpression()) - ir.WithCTE(ctes, relation) + if (ctx.RECURSIVE() == null) { + val ctes = vc.relationBuilder.visitMany(ctx.commonTableExpression()) + ir.WithCTE(ctes, relation) + } else { + // TODO With Recursive CTE are not support by default, will require a custom implementation IR to be redefined + val ctes = vc.relationBuilder.visitMany(ctx.commonTableExpression()) + ir.WithRecursiveCTE(ctes, relation) + } } } diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeRelationBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeRelationBuilder.scala index 80ed4e7f4..973146816 100644 --- a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeRelationBuilder.scala +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeRelationBuilder.scala @@ -328,9 +328,14 @@ class SnowflakeRelationBuilder(override val vc: SnowflakeVisitorCoordinator) case Some(errorResult) => errorResult case None => val tableName = vc.expressionBuilder.buildId(ctx.tableName) - val columns = ctx.columns.asScala.map(vc.expressionBuilder.buildId) + val columns = ctx.columnList() match { + case null => Seq.empty[ir.Id] + case c => c.columnName().asScala.flatMap(_.id.asScala.map(vc.expressionBuilder.buildId)) + } + val query = ctx.selectStatement().accept(this) ir.SubqueryAlias(query, tableName, columns) + } private def buildNum(ctx: NumContext): BigDecimal = { diff --git a/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeAstBuilderSpec.scala b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeAstBuilderSpec.scala index 0e90cc487..82becf90e 100644 --- a/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeAstBuilderSpec.scala +++ b/core/src/test/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeAstBuilderSpec.scala @@ -494,5 +494,56 @@ class SnowflakeAstBuilderSpec extends AnyWordSpec with SnowflakeParserTestCommon Project(Filter(namedTable("a"), Equals(Id("b"), Id("$ids"))), Seq(Star()))) } } + + "translate with recursive" should { + """WITH RECURSIVE employee_hierarchy""".stripMargin in { + singleQueryExample( + """WITH RECURSIVE employee_hierarchy AS ( + | SELECT + | employee_id, + | manager_id, + | employee_name, + | 1 AS level + | FROM + | employees + | WHERE + | manager_id IS NULL + | UNION ALL + | SELECT + | e.employee_id, + | e.manager_id, + | e.employee_name, + | eh.level + 1 AS level + | FROM + | employees e + | INNER JOIN + | employee_hierarchy eh ON e.manager_id = eh.employee_id + |) + |SELECT * + |FROM employee_hierarchy + |ORDER BY level, employee_id;""".stripMargin, + WithRecursiveCTE( + Seq( + SubqueryAlias( + Project( + Filter(NamedTable("employees", Map.empty, false), IsNull(Id("manager_id", false))), + Seq( + Id("employee_id", false), + Id("manager_id", false), + Id("employee_name", false), + Alias(Literal(1, IntegerType), Id("level", false)))), + Id("employee_hierarchy", false), + Seq.empty)), + Project( + Sort( + NamedTable("employee_hierarchy", Map.empty, false), + Seq( + SortOrder(Id("level", false), Ascending, NullsLast), + SortOrder(Id("employee_id", false), Ascending, NullsLast)), + false), + Seq(Star(None))))) + } + + } } } diff --git a/tests/resources/functional/snowflake/core_engine/test_cte/cte_simple.sql b/tests/resources/functional/snowflake/core_engine/test_cte/cte_simple.sql new file mode 100644 index 000000000..25e60ae3f --- /dev/null +++ b/tests/resources/functional/snowflake/core_engine/test_cte/cte_simple.sql @@ -0,0 +1,16 @@ +-- snowflake sql: +WITH employee_hierarchy AS ( + SELECT + employee_id, + manager_id, + employee_name + FROM + employees + WHERE + manager_id IS NULL +) +SELECT * +FROM employee_hierarchy; + +-- databricks sql: +WITH employee_hierarchy AS (SELECT employee_id, manager_id, employee_name FROM employees WHERE manager_id IS NULL) SELECT * FROM employee_hierarchy; \ No newline at end of file