Introduce an abstraction for handling stateful traversal of the tree (#1018)

While still being able to accumulate errors, produce partial results,
etc.

The basic principle is to build a *representation* of the computation
(parsing, visiting, optimizing, generating, etc.) that will eventually
get `run`. This representation keeps track of a state (of arbitrary type
in principle, though in practice we stick to `RemorphContext` for now)
along with the result it is producing.

This way, we can interleave any step of the computation with updates to
this state, such as counting the statements processed so far (for any
definition of what a statement is, including arbitrarily nested
subqueries) or tracking the branch of the IR tree currently being
processed.

Similarly, at any point in the computation we can inspect the current
state to guide the computation itself (for example, by enriching error
messages with contextual information).
vil1 authored Oct 30, 2024
1 parent d69ffdd commit 50a3f23
Showing 31 changed files with 797 additions and 741 deletions.
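To make the commit message concrete, here is a minimal, illustrative sketch of describing a computation first and running it later. It assumes the `Transformation`, `TransformationConstructors`, `Phase` and `Result` definitions added by this commit (shown below) are on the classpath; the `transpile` step and object name are made up.

```scala
import com.databricks.labs.remorph._

// Illustrative only: a step that records the source being processed in the threaded
// state, reads that state back, and does nothing until `run` is called.
object TransformationSketch extends TransformationConstructors[Phase] {

  def transpile(source: String): Transformation[Phase, String] =
    for {
      _     <- set(SourceCode(source)) // update the threaded state
      phase <- get                     // ...and inspect it at any later point
    } yield s"-- current phase: $phase"

  // Nothing has happened yet: `run` supplies the initial state and executes the whole
  // chain, producing the final state together with the output (or an error).
  val result: Result[(Phase, String)] = transpile("SELECT 1").run(Init)
}
```

Because the description and its execution are separate, bookkeeping steps (statement counts, the current IR branch) can be interleaved freely with the actual work before any state exists.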
23 changes: 23 additions & 0 deletions core/src/main/scala/com/databricks/labs/remorph/Phase.scala
@@ -0,0 +1,23 @@
package com.databricks.labs.remorph

import com.databricks.labs.remorph.intermediate.{LogicalPlan, TreeNode}
import org.antlr.v4.runtime.ParserRuleContext

sealed trait Phase

case object Init extends Phase

case class SourceCode(source: String, filename: String = "-- test source --") extends Phase

case class Parsed(tree: ParserRuleContext, sources: Option[SourceCode] = None) extends Phase

case class Ast(unoptimizedPlan: LogicalPlan, parsed: Option[Parsed] = None) extends Phase

case class Optimized(optimizedPlan: TreeNode[_], ast: Option[Ast] = None) extends Phase

case class Generating(
currentNode: TreeNode[_],
totalStatements: Int = 0,
transpiledStatements: Int = 0,
optimized: Option[Optimized] = None)
extends Phase
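Each phase optionally keeps a reference to the phase that produced it, so later stages can walk back to earlier artifacts. A small illustrative helper (not part of the commit; the name is made up) that recovers the original source from any phase:

```scala
import com.databricks.labs.remorph._

object PhaseOps {
  // Walk the optional back-references until the original SourceCode is reached.
  def originalSource(phase: Phase): Option[SourceCode] = phase match {
    case s: SourceCode          => Some(s)
    case Parsed(_, src)         => src
    case Ast(_, parsed)         => parsed.flatMap(_.sources)
    case Optimized(_, ast)      => ast.flatMap(_.parsed).flatMap(_.sources)
    case Generating(_, _, _, o) => o.flatMap(_.ast).flatMap(_.parsed).flatMap(_.sources)
    case Init                   => None
  }
}
```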
38 changes: 38 additions & 0 deletions core/src/main/scala/com/databricks/labs/remorph/Result.scala
@@ -11,6 +11,44 @@ object WorkflowStage {
case object FORMAT extends WorkflowStage
}

/**
 * Represents a stateful computation: a function from a state to a Result that carries
 * the updated state along with the produced output.
 * @param runF the state-transition function, itself wrapped in a Result
 * @tparam State the type of the state threaded through the computation
 * @tparam Out the type of the output the computation produces
 */
final class Transformation[State, +Out](val runF: Result[State => Result[(State, Out)]]) {

def map[B](f: Out => B): Transformation[State, B] = new Transformation(runF.map(_.andThen(_.map { case (s, a) =>
(s, f(a))
})))

def flatMap[B](f: Out => Transformation[State, B]): Transformation[State, B] = new Transformation(
runF.map(_.andThen(_.flatMap { case (s, a) =>
f(a).runF.flatMap(_.apply(s))
})))

def run(initialState: State): Result[(State, Out)] = runF.flatMap(_.apply(initialState))

/**
 * Runs the computation and discards the final state, keeping only the produced output.
 * @param initialState the state to start the computation from
 * @return the output of the computation, without the final state
 */
def runAndDiscardState(initialState: State): Result[Out] = run(initialState).map(_._2)
}

trait TransformationConstructors[S] {
def ok[A](a: A): Transformation[S, A] = new Transformation(OkResult(s => OkResult((s, a))))
def ko(stage: WorkflowStage, err: RemorphError): Transformation[S, Nothing] = new Transformation(
OkResult(s => KoResult(stage, err)))
def lift[X](res: Result[X]): Transformation[S, X] = new Transformation(OkResult(s => res.map(x => (s, x))))
def get: Transformation[S, S] = new Transformation(OkResult(s => OkResult((s, s))))
def set(newState: S): Transformation[S, Unit] = new Transformation(OkResult(_ => OkResult((newState, ()))))
def update[T](f: PartialFunction[S, S]): Transformation[S, Unit] = new Transformation(
OkResult(s => OkResult((f.applyOrElse(s, identity[S]), ()))))
}

sealed trait Result[+A] {
def map[B](f: A => B): Result[B]
def flatMap[B](f: A => Result[B]): Result[B]
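A short sketch of how these constructors might be used with the `Phase` state, along the lines described in the commit message; the object and value names are hypothetical:

```scala
import com.databricks.labs.remorph._

object PhaseHelpers extends TransformationConstructors[Phase] {

  // Bump the transpiled-statement counter while generating; `update` leaves any other
  // phase untouched because the partial function falls back to identity.
  val statementTranspiled: Transformation[Phase, Unit] = update {
    case g: Generating => g.copy(transpiledStatements = g.transpiledStatements + 1)
  }

  // Inspect the current phase, e.g. to add contextual information to an error message.
  val progress: Transformation[Phase, String] = get.map {
    case Generating(_, total, done, _) => s"transpiled $done of $total statements"
    case other                         => s"in phase ${other.getClass.getSimpleName}"
  }
}
```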
(next changed file)
@@ -3,8 +3,10 @@ package com.databricks.labs.remorph.coverage
import com.databricks.labs.remorph.WorkflowStage.PARSE
import com.databricks.labs.remorph.intermediate.{RemorphError, UnexpectedOutput}
import com.databricks.labs.remorph.queries.ExampleQuery
import com.databricks.labs.remorph.{KoResult, OkResult, PartialResult, SourceCode}
import com.databricks.labs.remorph.WorkflowStage.PARSE
import com.databricks.labs.remorph.intermediate.UnexpectedOutput
import com.databricks.labs.remorph.transpilers._
import com.databricks.labs.remorph.{KoResult, OkResult, PartialResult}

trait QueryRunner extends Formatter {
def runQuery(exampleQuery: ExampleQuery): ReportEntryReport
@@ -30,8 +32,7 @@ abstract class BaseQueryRunner(transpiler: Transpiler) extends QueryRunner {
}

override def runQuery(exampleQuery: ExampleQuery): ReportEntryReport = {
val x = transpiler.transpile(SourceCode(exampleQuery.query))
x match {
transpiler.transpile(SourceCode(exampleQuery.query)).runAndDiscardState(SourceCode(exampleQuery.query)) match {
case KoResult(PARSE, error) => ReportEntryReport(statements = 1, parsing_error = Some(error))
case KoResult(_, error) =>
// If we got past the PARSE stage, then remember to record that we parsed it correctly
(next changed file)
@@ -5,8 +5,8 @@ import com.databricks.labs.remorph.coverage._
import com.databricks.labs.remorph.discovery.{Anonymizer, ExecutedQuery, QueryHistoryProvider}
import com.databricks.labs.remorph.intermediate.{LogicalPlan, ParsingError, TranspileFailure}
import com.databricks.labs.remorph.parsers.PlanParser
import com.databricks.labs.remorph.{KoResult, OkResult}
import com.databricks.labs.remorph.transpilers.{SourceCode, SqlGenerator}
import com.databricks.labs.remorph.{KoResult, OkResult, SourceCode, Ast}
import com.databricks.labs.remorph.transpilers.SqlGenerator
import com.typesafe.scalalogging.LazyLogging

class Estimator(queryHistory: QueryHistoryProvider, planParser: PlanParser[_], analyzer: EstimationAnalyzer)
@@ -35,14 +35,17 @@ class Estimator(queryHistory: QueryHistoryProvider, planParser: PlanParser[_], a
anonymizer: Anonymizer,
parsedSet: scala.collection.mutable.Set[String]): Option[EstimationReportRecord] = {

val initialState = SourceCode(query.source)

// Skip entries that have already been seen as text but for which we were unable to parse or
// produce a plan for
val fingerprint = anonymizer(query.source)
if (!parsedSet.contains(fingerprint)) {
parsedSet += fingerprint
planParser
.parse(SourceCode(query.source, query.user.getOrElse("unknown") + "_" + query.id))
.flatMap(planParser.visit) match {
.flatMap(planParser.visit)
.run(initialState) match {
case KoResult(PARSE, error) =>
Some(
EstimationReportRecord(
@@ -60,13 +63,13 @@ class Estimator(queryHistory: QueryHistoryProvider, planParser: PlanParser[_], a
complexity = SqlComplexity.VERY_COMPLEX)))

case OkResult(plan) =>
val queryHash = anonymizer(plan)
val score = analyzer.evaluateTree(plan)
val queryHash = anonymizer(plan._2)
val score = analyzer.evaluateTree(plan._2)
// Note that the plan hash will generally be more accurate than the query hash, hence we check here
// as well as against the plain text
if (!parsedSet.contains(queryHash)) {
parsedSet += queryHash
Some(generateReportRecord(query, plan, score, anonymizer))
Some(generateReportRecord(query, plan._2, score, anonymizer))
} else {
None
}
@@ -93,7 +96,8 @@ class Estimator(queryHistory: QueryHistoryProvider, planParser: PlanParser[_], a
ruleScore: RuleScore,
anonymizer: Anonymizer): EstimationReportRecord = {
val generator = new SqlGenerator
planParser.optimize(plan).flatMap(generator.generate) match {
val initialState = Ast(plan, None)
planParser.optimize(plan).flatMap(generator.generate).run(initialState) match {
case KoResult(_, error) =>
// KoResult to transpile means that we need to increase the ruleScore as it will take some
// time to manually investigate and fix the issue
Expand All @@ -109,7 +113,7 @@ class Estimator(queryHistory: QueryHistoryProvider, planParser: PlanParser[_], a
score = tfr,
complexity = SqlComplexity.fromScore(tfr.rule.score)))

case OkResult(output: String) =>
case OkResult((_, output: String)) =>
val newScore =
RuleScore(SuccessfulTranspileRule().plusScore(ruleScore.rule.score), Seq(ruleScore))
EstimationReportRecord(
(next changed file)
@@ -2,8 +2,7 @@ package com.databricks.labs.remorph.discovery

import com.databricks.labs.remorph.parsers.PlanParser
import com.databricks.labs.remorph.intermediate._
import com.databricks.labs.remorph.{KoResult, OkResult, PartialResult, WorkflowStage}
import com.databricks.labs.remorph.transpilers.SourceCode
import com.databricks.labs.remorph.{KoResult, OkResult, PartialResult, SourceCode, WorkflowStage}
import com.typesafe.scalalogging.LazyLogging
import upickle.default._

@@ -71,7 +70,7 @@ class Anonymizer(parser: PlanParser[_]) extends LazyLogging {
def apply(query: String): String = fingerprint(query)

private[discovery] def fingerprint(query: ExecutedQuery): Fingerprint = {
parser.parse(SourceCode(query.source)).flatMap(parser.visit) match {
parser.parse(SourceCode(query.source)).flatMap(parser.visit).run(SourceCode(query.source)) match {
case KoResult(WorkflowStage.PARSE, error) =>
logger.warn(s"Failed to parse query: ${query.source} ${error.msg}")
Fingerprint(
@@ -92,7 +91,7 @@ class Anonymizer(parser: PlanParser[_]) extends LazyLogging {
query.user.getOrElse("unknown"),
WorkloadType.OTHER,
QueryType.OTHER)
case PartialResult(plan, error) =>
case PartialResult((_, plan), error) =>
logger.warn(s"Errors occurred while producing plan from query: ${query.source} ${error.msg}")
Fingerprint(
query.id,
@@ -102,7 +101,7 @@ class Anonymizer(parser: PlanParser[_]) extends LazyLogging {
query.user.getOrElse("unknown"),
workloadType(plan),
queryType(plan))
case OkResult(plan) =>
case OkResult((_, plan)) =>
Fingerprint(
query.id,
query.timestamp,
(next changed file)
@@ -1,10 +1,24 @@
package com.databricks.labs.remorph.generators

import com.databricks.labs.remorph.{KoResult, Result, WorkflowStage}
import com.databricks.labs.remorph.{Phase, Transformation, TransformationConstructors, WorkflowStage}
import com.databricks.labs.remorph.intermediate.{TreeNode, UnexpectedNode}

trait Generator[In <: TreeNode[In], Out] {
def generate(ctx: GeneratorContext, tree: In): Result[Out]
def unknown(tree: In): Result[Out] =
KoResult(WorkflowStage.GENERATE, UnexpectedNode(tree.getClass.getSimpleName))
trait Generator[In <: TreeNode[In], Out] extends TransformationConstructors[Phase] {
def generate(ctx: GeneratorContext, tree: In): Transformation[Phase, Out]
def unknown(tree: In): Transformation[Phase, Nothing] =
ko(WorkflowStage.GENERATE, UnexpectedNode(tree.getClass.getSimpleName))
}

trait CodeGenerator[In <: TreeNode[In]] extends Generator[In, String] {

private def generateAndJoin(
ctx: GeneratorContext,
trees: Seq[In],
separator: String): Transformation[Phase, String] = {
trees.map(generate(ctx, _)).sequence.map(_.mkString(separator))
}

def commas(ctx: GeneratorContext, trees: Seq[In]): Transformation[Phase, String] = generateAndJoin(ctx, trees, ", ")
def spaces(ctx: GeneratorContext, trees: Seq[In]): Transformation[Phase, String] = generateAndJoin(ctx, trees, " ")

}
(next changed file)
@@ -0,0 +1,68 @@
package com.databricks.labs.remorph

import com.databricks.labs.remorph.intermediate.UncaughtException

import scala.util.control.NonFatal

package object generators {

implicit class TBAInterpolator(sc: StringContext) extends TransformationConstructors[Phase] {
def code(args: Any*): Transformation[Phase, String] = {

args
.map {
case tba: Transformation[_, _] => tba.asInstanceOf[Transformation[Phase, String]]
case x => ok(x.toString)
}
.sequence
.map { a =>
val stringParts = sc.parts.iterator
val arguments = a.iterator
val sb = new StringBuilder(StringContext.treatEscapes(stringParts.next()))
while (arguments.hasNext) {
try {
sb.append(StringContext.treatEscapes(arguments.next()))
sb.append(StringContext.treatEscapes(stringParts.next()))
} catch {
case NonFatal(e) =>
return lift(KoResult(WorkflowStage.GENERATE, UncaughtException(e)))
}
}
sb.toString()

}
}
}

implicit class TBAOps(sql: Transformation[Phase, String]) {
def nonEmpty: Transformation[Phase, Boolean] = sql.map(_.nonEmpty)
def isEmpty: Transformation[Phase, Boolean] = sql.map(_.isEmpty)
}

implicit class TBASeqOps(tbas: Seq[Transformation[Phase, String]]) extends TransformationConstructors[Phase] {

def mkCode: Transformation[Phase, String] = mkCode("", "", "")

def mkCode(sep: String): Transformation[Phase, String] = mkCode("", sep, "")

def mkCode(start: String, sep: String, end: String): Transformation[Phase, String] = {
tbas.sequence.map(_.mkString(start, sep, end))
}

/**
 * Combines multiple Transformation[Phase, String] into a single Transformation producing a Seq[String].
 * The resulting Transformation runs each individual Transformation in sequence, accumulating all the effects
 * along the way.
 *
 * For example, when a Transformation in the input Seq modifies the state, the Transformations that come
 * after it in the input Seq will see the modified state.
 */
def sequence: Transformation[Phase, Seq[String]] =
tbas.foldLeft(ok(Seq.empty[String])) { case (agg, item) =>
for {
aggSeq <- agg
i <- item
} yield aggSeq :+ i
}
}
}
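A small usage sketch of the `code` interpolator and `mkCode` defined above (the object, value names and query fragments are made up): plain values and in-flight transformations can be mixed freely, and everything stays a `Transformation[Phase, String]` until it is run.

```scala
import com.databricks.labs.remorph._
import com.databricks.labs.remorph.generators._

object InterpolatorSketch extends TransformationConstructors[Phase] {
  // Two fragments combined with a separator...
  val columns: Transformation[Phase, String] = Seq(ok("a"), ok("b")).mkCode(", ")

  // ...then spliced into a template together with a plain value.
  val select: Transformation[Phase, String] = code"SELECT $columns FROM ${"t"}"

  // Only here does anything execute: this yields OkResult("SELECT a, b FROM t").
  val rendered: Result[String] = select.runAndDiscardState(Init)
}
```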
(next changed file)
@@ -1,14 +1,13 @@
package com.databricks.labs.remorph.generators.py

import com.databricks.labs.remorph.PartialResult
import com.databricks.labs.remorph.generators.{Generator, GeneratorContext}
import com.databricks.labs.remorph.generators._
import com.databricks.labs.remorph.intermediate.{RemorphError, TreeNode, UnexpectedNode}

abstract class BasePythonGenerator[In <: TreeNode[In]] extends Generator[In, String] {
def commas(ctx: GeneratorContext, nodes: Seq[In]): Python = nodes.map(generate(ctx, _)).mkPython(", ")
abstract class BasePythonGenerator[In <: TreeNode[In]] extends CodeGenerator[In] {

def partialResult(tree: In): Python = partialResult(tree, UnexpectedNode(tree.toString))
def partialResult(trees: Seq[Any], err: RemorphError): Python =
PartialResult(s"# FIXME: ${trees.mkString(" | ")} !!!", err)
def partialResult(tree: Any, err: RemorphError): Python = PartialResult(s"# FIXME: $tree !!!", err)
lift(PartialResult(s"# FIXME: ${trees.mkString(" | ")} !!!", err))
def partialResult(tree: Any, err: RemorphError): Python = lift(PartialResult(s"# FIXME: $tree !!!", err))
}