diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/sql/LogicalPlanGenerator.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/sql/LogicalPlanGenerator.scala index 8dcb65b6a..c9dc50fbc 100644 --- a/core/src/main/scala/com/databricks/labs/remorph/generators/sql/LogicalPlanGenerator.scala +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/sql/LogicalPlanGenerator.scala @@ -437,12 +437,13 @@ class LogicalPlanGenerator( val child = generate(aggregate.child) val expressions = expr.commas(aggregate.grouping_expressions) aggregate.group_type match { + case ir.GroupByAll => code"$child GROUP BY ALL" case ir.GroupBy => code"$child GROUP BY $expressions" case ir.Pivot if aggregate.pivot.isDefined => val pivot = aggregate.pivot.get val col = expr.generate(pivot.col) - val values = pivot.values.map(expr.generate(_)).mkCode(" IN(", ", ", ")") + val values = pivot.values.map(expr.generate).mkCode(" IN(", ", ", ")") code"$child PIVOT($expressions FOR $col$values)" case a => partialResult(a, ir.UnsupportedGroupType(a.toString)) } diff --git a/core/src/main/scala/com/databricks/labs/remorph/intermediate/relations.scala b/core/src/main/scala/com/databricks/labs/remorph/intermediate/relations.scala index 931248c7a..88b958907 100644 --- a/core/src/main/scala/com/databricks/labs/remorph/intermediate/relations.scala +++ b/core/src/main/scala/com/databricks/labs/remorph/intermediate/relations.scala @@ -385,6 +385,8 @@ case object UnspecifiedGroupType extends GroupType case object GroupBy extends GroupType +case object GroupByAll extends GroupType + case object Pivot extends GroupType case object UnspecifiedFormat extends ParseFormat diff --git a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeRelationBuilder.scala b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeRelationBuilder.scala index 14985ddb6..5ba5bac07 100644 --- a/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeRelationBuilder.scala +++ b/core/src/main/scala/com/databricks/labs/remorph/parsers/snowflake/SnowflakeRelationBuilder.scala @@ -117,12 +117,12 @@ class SnowflakeRelationBuilder(override val vc: SnowflakeVisitorCoordinator) private def buildGroupBy(ctx: GroupByClauseContext, input: ir.LogicalPlan): ir.LogicalPlan = { Option(ctx).fold(input) { c => val groupingExpressions = - c.groupByList() - .groupByElem() - .asScala + Option(c.groupByList()).toSeq + .flatMap(_.groupByElem().asScala) .map(_.accept(vc.expressionBuilder)) + val groupType = if (c.ALL() != null) ir.GroupByAll else ir.GroupBy val aggregate = - ir.Aggregate(child = input, group_type = ir.GroupBy, grouping_expressions = groupingExpressions, pivot = None) + ir.Aggregate(child = input, group_type = groupType, grouping_expressions = groupingExpressions, pivot = None) buildHaving(c.havingClause(), aggregate) } } diff --git a/core/src/test/scala/com/databricks/labs/remorph/transpilers/SnowflakeToDatabricksTranspilerTest.scala b/core/src/test/scala/com/databricks/labs/remorph/transpilers/SnowflakeToDatabricksTranspilerTest.scala index 4536d9fbb..8975280af 100644 --- a/core/src/test/scala/com/databricks/labs/remorph/transpilers/SnowflakeToDatabricksTranspilerTest.scala +++ b/core/src/test/scala/com/databricks/labs/remorph/transpilers/SnowflakeToDatabricksTranspilerTest.scala @@ -258,6 +258,11 @@ class SnowflakeToDatabricksTranspilerTest extends AnyWordSpec with TranspilerTes "SELECT ARRAY_SORT([0, 2, 4, NULL, 5, NULL], 1 = 1, TRUE);".failsTranspilation } + "GROUP BY ALL" in { + "SELECT car_model, COUNT(DISTINCT city) FROM dealer GROUP BY ALL;" transpilesTo + "SELECT car_model, COUNT(DISTINCT city) FROM dealer GROUP BY ALL;" + } + } "Snowflake transpile function with optional brackets" should {