Added ruff Python formatter proxy (#1038)
nfx authored Oct 30, 2024
1 parent 5574efe commit d69ffdd
Showing 6 changed files with 127 additions and 7 deletions.
20 changes: 20 additions & 0 deletions .github/workflows/push.yml
@@ -142,6 +142,16 @@ jobs:
          path: ~/.m2
          key: ${{ github.job }}-${{ hashFiles('**/pom.xml') }}

      - name: Install Python
        uses: actions/setup-python@v5
        with:
          cache: 'pip'
          cache-dependency-path: '**/pyproject.toml'
          python-version: '3.10'

      - name: Initialize Python virtual environment for StandardInputPythonSubprocess
        run: make dev

      - name: Run Unit Tests with Maven
        run: mvn --update-snapshots scoverage:report --file pom.xml --fail-at-end

@@ -203,6 +213,16 @@ jobs:
          path: ~/.m2
          key: ${{ github.job }}-${{ hashFiles('**/pom.xml') }}

      - name: Install Python
        uses: actions/setup-python@v5
        with:
          cache: 'pip'
          cache-dependency-path: '**/pyproject.toml'
          python-version: '3.10'

      - name: Initialize Python virtual environment for StandardInputPythonSubprocess
        run: make dev

      - name: Create dummy test file
        run: |
          mkdir $INPUT_DIR_PARENT/snowflake
@@ -8,6 +8,7 @@ object WorkflowStage {
  case object PLAN extends WorkflowStage
  case object OPTIMIZE extends WorkflowStage
  case object GENERATE extends WorkflowStage
  case object FORMAT extends WorkflowStage
}

sealed trait Result[+A] {
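For context, a minimal sketch of how the new FORMAT stage tags failures coming out of the formatting step; it mirrors the KoResult construction in StandardInputPythonSubprocess further down, and the error message and object name here are purely illustrative:

import java.io.IOException

import com.databricks.labs.remorph.intermediate.TranspileFailure
import com.databricks.labs.remorph.{KoResult, WorkflowStage}

object FormatFailureSketch extends App {
  // A failure in the external formatter is reported against the new FORMAT workflow stage.
  val failed = KoResult(WorkflowStage.FORMAT, new TranspileFailure(new IOException("ruff exited with a non-zero status")))
  println(failed)
}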
@@ -0,0 +1,9 @@
package com.databricks.labs.remorph.generators.py

import com.databricks.labs.remorph.utils.StandardInputPythonSubprocess
import com.databricks.labs.remorph.Result

class RuffFormatter {
  private val ruffFmt = new StandardInputPythonSubprocess("ruff format -")
  def format(input: String): Result[String] = ruffFmt(input)
}
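A minimal usage sketch for the new formatter proxy, assuming ruff is importable by the interpreter that StandardInputPythonSubprocess resolves (PYTHON_BIN, or the project's .venv created by `make dev`); the unformatted snippet and object name are illustrative:

import com.databricks.labs.remorph.generators.py.RuffFormatter
import com.databricks.labs.remorph.{KoResult, OkResult, PartialResult}

object RuffFormatterSketch extends App {
  val formatter = new RuffFormatter
  formatter.format("x = {  'a':1,'b':2 }") match {
    case OkResult(formatted) => print(formatted) // ruff-normalised Python source
    case KoResult(stage, error) => Console.err.println(s"$stage failed: ${error.msg}")
    case PartialResult(output, error) => Console.err.println(s"partial result: $output ($error)")
  }
}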
@@ -0,0 +1,78 @@
package com.databricks.labs.remorph.utils

import com.databricks.labs.remorph.intermediate.TranspileFailure
import com.databricks.labs.remorph.{KoResult, OkResult, Result, WorkflowStage}

import java.io._
import scala.annotation.tailrec
import scala.sys.process.{Process, ProcessIO}
import scala.util.control.NonFatal

class StandardInputPythonSubprocess(passArgs: String) {
  // Pipes `input` to `<python> -m <passArgs>` via stdin and returns captured stdout.
  def apply(input: String): Result[String] = {
    val process = Process(s"$getEffectivePythonBin -m $passArgs", None)
    val output = new StringBuilder
    val error = new StringBuilder
    try {
      val result = process.run(createIO(input, output, error)).exitValue()
      if (result != 0) {
        KoResult(WorkflowStage.FORMAT, new TranspileFailure(new IOException(error.toString)))
      } else {
        OkResult(output.toString)
      }
    } catch {
      case e: IOException if e.getMessage.contains("Cannot run") =>
        val failure = new TranspileFailure(new IOException("Invalid $PYTHON_BIN environment variable"))
        KoResult(WorkflowStage.FORMAT, failure)
      case NonFatal(e) =>
        KoResult(WorkflowStage.FORMAT, new TranspileFailure(e))
    }
  }

  private def createIO(input: String, output: StringBuilder, error: StringBuilder) = new ProcessIO(
    stdin => {
      stdin.write(input.getBytes)
      stdin.close()
    },
    stdout => {
      val reader = new BufferedReader(new InputStreamReader(stdout))
      var line: String = reader.readLine()
      while (line != null) {
        output.append(s"$line\n")
        line = reader.readLine()
      }
      reader.close()
    },
    stderr => {
      val reader = new BufferedReader(new InputStreamReader(stderr))
      var line: String = reader.readLine()
      while (line != null) {
        error.append(s"$line\n")
        line = reader.readLine()
      }
      reader.close()
    })

  // Prefer $PYTHON_BIN when set; otherwise fall back to the project's .venv/bin/python
  // (see the `make dev` workflow step above).
  private def getEffectivePythonBin: String = {
    sys.env.getOrElse(
      "PYTHON_BIN", {
        val projectRoot = findLabsYmlFolderIn(new File(System.getProperty("user.dir")))
        val venvPython = new File(projectRoot, ".venv/bin/python")
        venvPython.getAbsolutePath
      })
  }

  // The project root is the closest ancestor directory that contains labs.yml.
  @tailrec private def findLabsYmlFolderIn(path: File): File = {
    if (new File(path, "labs.yml").exists()) {
      path
    } else {
      val parent = path.getParentFile
      if (parent == null) {
        throw new FileNotFoundException(
          "labs.yml not found anywhere in the project hierarchy. " +
            "Please set PYTHON_BIN environment variable to point to the correct Python binary.")
      }
      findLabsYmlFolderIn(parent)
    }
  }
}
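Because StandardInputPythonSubprocess only prepends `-m` to the resolved interpreter and pipes stdin/stdout, it can proxy any stdin-driven Python module, not just ruff. A minimal sketch using the standard-library json.tool module (so nothing beyond the resolved interpreter is required); the object name and JSON input are illustrative, and PYTHON_BIN can be set to bypass the labs.yml/.venv discovery:

import com.databricks.labs.remorph.utils.StandardInputPythonSubprocess
import com.databricks.labs.remorph.{KoResult, OkResult}

object JsonToolSketch extends App {
  // Runs `<python> -m json.tool`, feeding the JSON document through stdin.
  val prettyPrint = new StandardInputPythonSubprocess("json.tool")
  prettyPrint("""{"a":1,"b":[2,3]}""") match {
    case OkResult(pretty) => print(pretty)
    case KoResult(stage, error) => Console.err.println(s"$stage failed: $error")
    case other => Console.err.println(s"unexpected result: $other")
  }
}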
@@ -1,10 +1,17 @@
package com.databricks.labs.remorph.transpilers

import com.databricks.labs.remorph.{KoResult, OkResult, PartialResult}
import com.databricks.labs.remorph.generators.py.RuffFormatter
import org.scalatest.wordspec.AnyWordSpec

class SnowflakeToPySparkTranspilerTest extends AnyWordSpec with TranspilerTestCommon {
  protected val transpiler = new SnowflakeToPySparkTranspiler
  protected override val reformat = false
  private val formatter = new RuffFormatter
  override def format(input: String): String = formatter.format(input) match {
    case OkResult(formatted) => formatted
    case KoResult(_, error) => fail(error.msg)
    case PartialResult(output, error) => fail(s"Partial result: $output, error: $error")
  }

"Snowflake SQL" should {
"transpile window functions" in {
@@ -14,7 +21,16 @@ class SnowflakeToPySparkTranspilerTest extends AnyWordSpec with TranspilerTestCommon {
        |FROM t1;""".stripMargin transpilesTo
        """import pyspark.sql.functions as F
          |from pyspark.sql.window import Window
          |spark.table('t1').select(F.last(F.col('c1')).over(Window.partitionBy(F.col('t1.c2')).orderBy(F.col('t1.c3').desc_nulls_first()).rangeBetween(Window.unboundedPreceding, Window.currentRow)).alias('dc4'))
          |
          |spark.table("t1").select(
          |    F.last(F.col("c1"))
          |    .over(
          |        Window.partitionBy(F.col("t1.c2"))
          |        .orderBy(F.col("t1.c3").desc_nulls_first())
          |        .rangeBetween(Window.unboundedPreceding, Window.currentRow)
          |    )
          |    .alias("dc4")
          |)
|""".stripMargin
}
}
@@ -9,14 +9,10 @@ trait TranspilerTestCommon extends Matchers with Formatter {

  protected def transpiler: Transpiler

  protected def reformat = true

  private def formatResult(result: String): String = if (reformat) format(result) else result

  implicit class TranspilerTestOps(input: String) {
    def transpilesTo(expectedOutput: String): Assertion = {
      transpiler.transpile(SourceCode(input)) match {
        case OkResult(output) => formatResult(output) shouldBe formatResult(expectedOutput)
        case OkResult(output) => format(output) shouldBe format(expectedOutput)
        case PartialResult(_, err) => fail(write(err))
        case KoResult(_, err) => fail(write(err))
      }
