diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 19eb6b524..eb61719ec 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -142,6 +142,16 @@ jobs: path: ~/.m2 key: ${{ github.job }}-${{ hashFiles('**/pom.xml') }} + - name: Install Python + uses: actions/setup-python@v5 + with: + cache: 'pip' + cache-dependency-path: '**/pyproject.toml' + python-version: '3.10' + + - name: Initialize Python virtual environment for StandardInputPythonSubprocess + run: make dev + - name: Run Unit Tests with Maven run: mvn --update-snapshots scoverage:report --file pom.xml --fail-at-end @@ -203,6 +213,16 @@ jobs: path: ~/.m2 key: ${{ github.job }}-${{ hashFiles('**/pom.xml') }} + - name: Install Python + uses: actions/setup-python@v5 + with: + cache: 'pip' + cache-dependency-path: '**/pyproject.toml' + python-version: '3.10' + + - name: Initialize Python virtual environment for StandardInputPythonSubprocess + run: make dev + - name: Create dummy test file run: | mkdir $INPUT_DIR_PARENT/snowflake diff --git a/core/src/main/scala/com/databricks/labs/remorph/Result.scala b/core/src/main/scala/com/databricks/labs/remorph/Result.scala index f3f7f8747..374f6abb5 100644 --- a/core/src/main/scala/com/databricks/labs/remorph/Result.scala +++ b/core/src/main/scala/com/databricks/labs/remorph/Result.scala @@ -8,6 +8,7 @@ object WorkflowStage { case object PLAN extends WorkflowStage case object OPTIMIZE extends WorkflowStage case object GENERATE extends WorkflowStage + case object FORMAT extends WorkflowStage } sealed trait Result[+A] { diff --git a/core/src/main/scala/com/databricks/labs/remorph/generators/py/RuffFormatter.scala b/core/src/main/scala/com/databricks/labs/remorph/generators/py/RuffFormatter.scala new file mode 100644 index 000000000..b2c5c8ae4 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/generators/py/RuffFormatter.scala @@ -0,0 +1,9 @@ +package com.databricks.labs.remorph.generators.py + +import com.databricks.labs.remorph.utils.StandardInputPythonSubprocess +import com.databricks.labs.remorph.Result + +class RuffFormatter { + private val ruffFmt = new StandardInputPythonSubprocess("ruff format -") + def format(input: String): Result[String] = ruffFmt(input) +} \ No newline at end of file diff --git a/core/src/main/scala/com/databricks/labs/remorph/utils/StandardInputPythonSubprocess.scala b/core/src/main/scala/com/databricks/labs/remorph/utils/StandardInputPythonSubprocess.scala new file mode 100644 index 000000000..67b678eb6 --- /dev/null +++ b/core/src/main/scala/com/databricks/labs/remorph/utils/StandardInputPythonSubprocess.scala @@ -0,0 +1,78 @@ +package com.databricks.labs.remorph.utils + +import com.databricks.labs.remorph.intermediate.TranspileFailure +import com.databricks.labs.remorph.{KoResult, OkResult, Result, WorkflowStage} + +import java.io._ +import scala.annotation.tailrec +import scala.sys.process.{Process, ProcessIO} +import scala.util.control.NonFatal + +class StandardInputPythonSubprocess(passArgs: String) { + def apply(input: String): Result[String] = { + val process = Process(s"$getEffectivePythonBin -m $passArgs", None) + val output = new StringBuilder + val error = new StringBuilder + try { + val result = process.run(createIO(input, output, error)).exitValue() + if (result != 0) { + KoResult(WorkflowStage.FORMAT, new TranspileFailure(new IOException(error.toString))) + } else { + OkResult(output.toString) + } + } catch { + case e: IOException if e.getMessage.contains("Cannot run") => + val failure = new TranspileFailure(new IOException("Invalid $PYTHON_BIN environment variable")) + KoResult(WorkflowStage.FORMAT, failure) + case NonFatal(e) => + KoResult(WorkflowStage.FORMAT, new TranspileFailure(e)) + } + } + + private def createIO(input: String, output: StringBuilder, error: StringBuilder) = new ProcessIO( + stdin => { + stdin.write(input.getBytes) + stdin.close() + }, + stdout => { + val reader = new BufferedReader(new InputStreamReader(stdout)) + var line: String = reader.readLine() + while (line != null) { + output.append(s"$line\n") + line = reader.readLine() + } + reader.close() + }, + stderr => { + val reader = new BufferedReader(new InputStreamReader(stderr)) + var line: String = reader.readLine() + while (line != null) { + error.append(s"$line\n") + line = reader.readLine() + } + reader.close() + }) + + private def getEffectivePythonBin: String = { + sys.env.getOrElse( + "PYTHON_BIN", { + val projectRoot = findLabsYmlFolderIn(new File(System.getProperty("user.dir"))) + val venvPython = new File(projectRoot, ".venv/bin/python") + venvPython.getAbsolutePath + }) + } + + @tailrec private def findLabsYmlFolderIn(path: File): File = { + if (new File(path, "labs.yml").exists()) { + path + } else { + val parent = path.getParentFile + if (parent == null) { + throw new FileNotFoundException( + "labs.yml not found anywhere in the project hierarchy. " + + "Please set PYTHON_BIN environment variable to point to the correct Python binary.") + } + findLabsYmlFolderIn(parent) + } + } +} diff --git a/core/src/test/scala/com/databricks/labs/remorph/transpilers/SnowflakeToPySparkTranspilerTest.scala b/core/src/test/scala/com/databricks/labs/remorph/transpilers/SnowflakeToPySparkTranspilerTest.scala index 033efec3a..984640bea 100644 --- a/core/src/test/scala/com/databricks/labs/remorph/transpilers/SnowflakeToPySparkTranspilerTest.scala +++ b/core/src/test/scala/com/databricks/labs/remorph/transpilers/SnowflakeToPySparkTranspilerTest.scala @@ -1,10 +1,17 @@ package com.databricks.labs.remorph.transpilers +import com.databricks.labs.remorph.{KoResult, OkResult, PartialResult} +import com.databricks.labs.remorph.generators.py.RuffFormatter import org.scalatest.wordspec.AnyWordSpec class SnowflakeToPySparkTranspilerTest extends AnyWordSpec with TranspilerTestCommon { protected val transpiler = new SnowflakeToPySparkTranspiler - protected override val reformat = false + private val formatter = new RuffFormatter + override def format(input: String): String = formatter.format(input) match { + case OkResult(formatted) => formatted + case KoResult(_, error) => fail(error.msg) + case PartialResult(output, error) => fail(s"Partial result: $output, error: $error") + } "Snowflake SQL" should { "transpile window functions" in { @@ -14,7 +21,16 @@ class SnowflakeToPySparkTranspilerTest extends AnyWordSpec with TranspilerTestCo |FROM t1;""".stripMargin transpilesTo """import pyspark.sql.functions as F |from pyspark.sql.window import Window - |spark.table('t1').select(F.last(F.col('c1')).over(Window.partitionBy(F.col('t1.c2')).orderBy(F.col('t1.c3').desc_nulls_first()).rangeBetween(Window.unboundedPreceding, Window.currentRow)).alias('dc4')) + | + |spark.table("t1").select( + | F.last(F.col("c1")) + | .over( + | Window.partitionBy(F.col("t1.c2")) + | .orderBy(F.col("t1.c3").desc_nulls_first()) + | .rangeBetween(Window.unboundedPreceding, Window.currentRow) + | ) + | .alias("dc4") + |) |""".stripMargin } } diff --git a/core/src/test/scala/com/databricks/labs/remorph/transpilers/TranspilerTestCommon.scala b/core/src/test/scala/com/databricks/labs/remorph/transpilers/TranspilerTestCommon.scala index 6350a96af..81213951d 100644 --- a/core/src/test/scala/com/databricks/labs/remorph/transpilers/TranspilerTestCommon.scala +++ b/core/src/test/scala/com/databricks/labs/remorph/transpilers/TranspilerTestCommon.scala @@ -9,14 +9,10 @@ trait TranspilerTestCommon extends Matchers with Formatter { protected def transpiler: Transpiler - protected def reformat = true - - private def formatResult(result: String): String = if (reformat) format(result) else result - implicit class TranspilerTestOps(input: String) { def transpilesTo(expectedOutput: String): Assertion = { transpiler.transpile(SourceCode(input)) match { - case OkResult(output) => formatResult(output) shouldBe formatResult(expectedOutput) + case OkResult(output) => format(output) shouldBe format(expectedOutput) case PartialResult(_, err) => fail(write(err)) case KoResult(_, err) => fail(write(err)) }