Skip to content

Commit

Permalink
Ensure that unparsable text is not lost in the generated output (#1012)
Browse files Browse the repository at this point in the history
The default ErrorStrategy for ANTLR generated parsers performs a
sophisticated error recovery synchronization process for both
PLUS_LOOP_BACK and STAR_LOOP_BACK as well as token manufacture/insertion
and single token deletion within token sequences. This culminates in a
call to `recover()` which finds the next token in the followSet for an
alt that allows the parser to resume. We intercept `recover()` in order
to record where we were unable to parse text.

While errors are reported, the default error strategy does not preserve
any discarded input in the generated ParseTree, and so that input is lost if
we generate output only from those parts of the ParseTree that were
successfully parsed.

Here, we implement custom strategies that gather un-parsable input and
preserve it as custom error nodes in the ParseTree at strategic
insertion points in the higher level rules such as `sqlCommand` (in the
case of Snowflake) and `sqlClauses` in the case of TSQL.

The visitors for these rules can then first check for an error node in
the children and generate an Ir node representing the unparsed text.

For this PR to be usable, we need our PlanParser to no longer stop when
syntax errors are discovered as it is now safe to walk the ParseTree.
That improvement is for a separate PR.
  • Loading branch information
jimidle authored Oct 21, 2024
1 parent c18f189 commit c484f6a
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package com.databricks.labs.remorph.parsers
import com.databricks.labs.remorph.{intermediate => ir}
import com.typesafe.scalalogging.LazyLogging
import org.antlr.v4.runtime.misc.Interval
import org.antlr.v4.runtime.tree.{AbstractParseTreeVisitor, ParseTree, ParseTreeVisitor, RuleNode}
import org.antlr.v4.runtime.tree._
import org.antlr.v4.runtime.{ParserRuleContext, RuleContext, Token}

import scala.collection.JavaConverters._
Expand Down Expand Up @@ -123,4 +123,18 @@ trait ParserCommon[A] extends ParseTreeVisitor[A] with LazyLogging { self: Abstr
result
}
}

/**
 * Visits an ErrorNode found in the parse tree.
 *
 * When the parser recognizes a syntax error it generates an ErrorNode, which represents the text in error
 * and contains a manufactured token that encapsulates all the text the parser ignored while recovering
 * from that error. Note that if the error recovery strategy inserts a token rather than deleting one, then
 * no error node is created.
 *
 * The text of the node is logged as a warning and then handed to `unresolved`.
 *
 * @param node the ErrorNode to visit
 * @return the unresolved object representing the error and containing the text that was skipped
 */
override def visitErrorNode(node: ErrorNode): A = {
  // Read the text once; it is used for both the log message and the unresolved result.
  val skippedText = node.getText
  logger.warn(s"Error node encountered: $skippedText")
  unresolved(skippedText, "Unparsed input - ErrorNode encountered")
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package com.databricks.labs.remorph.parsers

import com.databricks.labs.remorph.parsers.snowflake.SnowflakeParser.SqlCommandContext
import com.databricks.labs.remorph.parsers.tsql.TSqlParser.SqlClausesContext
import org.antlr.v4.runtime._
import org.antlr.v4.runtime.misc.IntervalSet
import org.antlr.v4.runtime.misc.{Interval, IntervalSet, Pair}
import org.antlr.v4.runtime.tree.ErrorNodeImpl

/**
* Custom error strategy for SQL parsing <p> While we do not do anything super special here, we wish to override a
Expand All @@ -15,6 +18,53 @@ import org.antlr.v4.runtime.misc.IntervalSet
*/
abstract class SqlErrorStrategy extends DefaultErrorStrategy {

/**
 * Delegates synchronization to the default ANTLR strategy and records any input that the token stream
 * moved past during that call as an error node, so that the text is not lost from the parse tree.
 *
 * The token stream index is sampled before and after `super.sync`. If it has advanced, a token of type
 * `Token.INVALID_TYPE` is manufactured on the default channel, positioned at the line and column of the
 * token that was current on entry, wrapped in an `ErrorNodeImpl`, and attached to the context returned by
 * `findHighestContext`.
 *
 * @param recognizer the parser that is being synchronized
 * @throws RecognitionException if `super.sync` raises one; it propagates unchanged after the skipped
 *                              input (if any) has been recorded
 */
@throws[RecognitionException]
override def sync(recognizer: Parser): Unit = {
  val tokens: TokenStream = recognizer.getInputStream
  val startIndex: Int = tokens.index
  val first = tokens.LT(1)
  try {
    super.sync(recognizer)
  } finally {
    // Runs on both normal return and when a RecognitionException propagates to the parser.
    val endIndex: Int = tokens.index
    if (startIndex < endIndex) {
      // startIndex and endIndex are TOKEN indices, so the text must be requested from the token stream.
      // A CharStream interprets an Interval as character offsets and would return unrelated text.
      // NOTE(review): both the interval end and the stop index below include the token the stream is now
      // positioned on (LT(1)), i.e. the token at which parsing resumes - confirm that this is intended.
      val interval = new Interval(startIndex, endIndex)
      val errorToken: CommonToken = new CommonToken(
        new Pair(first.getTokenSource, first.getInputStream),
        Token.INVALID_TYPE,
        Token.DEFAULT_CHANNEL,
        first.getStartIndex,
        tokens.LT(1).getStopIndex)
      errorToken.setText(tokens.getText(interval))
      errorToken.setLine(first.getLine)
      errorToken.setCharPositionInLine(first.getCharPositionInLine)
      val errorNode = new ErrorNodeImpl(errorToken)

      // Here we add the error node to the highest level context in the tree for the particular parser,
      // so that we do not have to search for error nodes in the children of every visitor context.
      // If we added it to the current context, it would mean that every visitor method would
      // have to check for it, which would soon become unwieldy. We may come back to that though
      // if preserved skipped text is generated far away from the original text in error.
      findHighestContext(recognizer.getContext).addErrorNode(errorNode)
    }
  }
}

/**
 * Walks up the parent chain from `ctx` and returns the first context that is either a
 * `SqlClausesContext` or a `SqlCommandContext`. If no such context is found, the context that has
 * no parent is returned instead.
 *
 * @param ctx the context from which to start climbing
 * @return the nearest enclosing `SqlClausesContext`/`SqlCommandContext`, or the parentless context
 */
def findHighestContext(ctx: ParserRuleContext): ParserRuleContext = {
  // True when the candidate is one of the rule contexts chosen to hold error nodes.
  def isInsertionPoint(candidate: ParserRuleContext): Boolean = candidate match {
    case _: SqlClausesContext | _: SqlCommandContext => true
    case _ => false
  }

  @annotation.tailrec
  def climb(node: ParserRuleContext): ParserRuleContext =
    if (isInsertionPoint(node) || node.getParent == null) node
    else climb(node.getParent.asInstanceOf[ParserRuleContext])

  climb(ctx)
}

// Note that it is not possible to get this error from the current grammar, we would have to do an inordinate
// amount of mocking to raise this. It isn't worth the effort.
// $COVERAGE-OFF$
Expand Down

0 comments on commit c484f6a

Please sign in to comment.