Skip to content

Commit

Permalink
Ensure to escape characters before constructing JSON profile trace (#…
Browse files Browse the repository at this point in the history
…21872)

Fixes #21858 by setting up special escapes for characters that might
corrupt the output JSON file produced by `-Yprofile-trace`
  • Loading branch information
WojciechMazur authored Nov 12, 2024
2 parents 896965c + 55d2bd7 commit 6f48c39
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 6 deletions.
46 changes: 46 additions & 0 deletions compiler/src/dotty/tools/dotc/profile/JsonNameTransformer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package dotty.tools.dotc.profile

import scala.annotation.internal.sharable

// Based on NameTransformer but dedicated for JSON encoding rules
object JsonNameTransformer {
private val nops = 128

@sharable private val op2code = new Array[String](nops)
private def enterOp(op: Char, code: String) = op2code(op.toInt) = code

enterOp('\"', "\\\"")
enterOp('\\', "\\\\")
// enterOp('/', "\\/") // optional, no need for escaping outside of html context
enterOp('\b', "\\b")
enterOp('\f', "\\f")
enterOp('\n', "\\n")
enterOp('\r', "\\r")
enterOp('\t', "\\t")

def encode(name: String): String = {
var buf: StringBuilder = null.asInstanceOf
val len = name.length
var i = 0
while (i < len) {
val c = name(i)
if (c < nops && (op2code(c.toInt) ne null)) {
if (buf eq null) {
buf = new StringBuilder()
buf.append(name.subSequence(0, i))
}
buf.append(op2code(c.toInt))
} else if (c <= 0x1F || c >= 0x7F) {
if (buf eq null) {
buf = new StringBuilder()
buf.append(name.subSequence(0, i))
}
buf.append("\\u%04X".format(c.toInt))
} else if (buf ne null) {
buf.append(c)
}
i += 1
}
if (buf eq null) name else buf.toString
}
}
15 changes: 10 additions & 5 deletions compiler/src/dotty/tools/dotc/profile/Profiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ private [profile] class RealProfiler(reporter : ProfileReporter)(using Context)
override def beforePhase(phase: Phase): (TracedEventId, ProfileSnap) = {
assert(mainThread eq Thread.currentThread())
traceThreadSnapshotCounters()
val eventId = traceDurationStart(Category.Phase, phase.phaseName)
val eventId = traceDurationStart(Category.Phase, escapeSpecialChars(phase.phaseName))
if (ctx.settings.YprofileRunGcBetweenPhases.value.contains(phase.toString))
doGC()
if (ctx.settings.YprofileExternalTool.value.contains(phase.toString)) {
Expand All @@ -287,7 +287,7 @@ private [profile] class RealProfiler(reporter : ProfileReporter)(using Context)
assert(mainThread eq Thread.currentThread())
if chromeTrace != null then
traceThreadSnapshotCounters()
traceDurationStart(Category.File, unit.source.name)
traceDurationStart(Category.File, escapeSpecialChars(unit.source.name))
else TracedEventId.Empty
}

Expand Down Expand Up @@ -325,7 +325,7 @@ private [profile] class RealProfiler(reporter : ProfileReporter)(using Context)
then EmptyCompletionEvent
else
val completionName = this.completionName(root, associatedFile)
val event = TracedEventId(associatedFile.name)
val event = TracedEventId(escapeSpecialChars(associatedFile.name))
chromeTrace.traceDurationEventStart(Category.Completion.name, "", colour = "thread_state_sleeping")
chromeTrace.traceDurationEventStart(Category.File.name, event)
chromeTrace.traceDurationEventStart(Category.Completion.name, completionName)
Expand All @@ -350,8 +350,13 @@ private [profile] class RealProfiler(reporter : ProfileReporter)(using Context)
if chromeTrace != null then
chromeTrace.traceDurationEventEnd(category.name, event, colour)

private def symbolName(sym: Symbol): String = s"${sym.showKind} ${sym.showName}"
private def completionName(root: Symbol, associatedFile: AbstractFile): String =
private inline def escapeSpecialChars(value: String): String =
JsonNameTransformer.encode(value)

private def symbolName(sym: Symbol): String = escapeSpecialChars:
s"${sym.showKind} ${sym.showName}"

private def completionName(root: Symbol, associatedFile: AbstractFile): String = escapeSpecialChars:
def isTopLevel = root.owner != NoSymbol && root.owner.is(Flags.Package)
if root.is(Flags.Package) || isTopLevel
then root.javaBinaryName
Expand Down
2 changes: 1 addition & 1 deletion compiler/test/dotty/tools/DottyTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ trait DottyTest extends ContextEscapeDetection {

protected def defaultCompiler: Compiler = new Compiler()

private def compilerWithChecker(phase: String)(assertion: (tpd.Tree, Context) => Unit) = new Compiler {
protected def compilerWithChecker(phase: String)(assertion: (tpd.Tree, Context) => Unit) = new Compiler {

private val baseCompiler = defaultCompiler

Expand Down
133 changes: 133 additions & 0 deletions compiler/test/dotty/tools/dotc/profile/TraceNameManglingTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package dotty.tools.dotc.profile

import org.junit.Assert.*
import org.junit.*

import scala.annotation.tailrec
import dotty.tools.DottyTest
import dotty.tools.dotc.util.SourceFile
import dotty.tools.dotc.core.Contexts.FreshContext
import java.nio.file.Files
import java.util.Locale

class TraceNameManglingTest extends DottyTest {

override protected def initializeCtx(fc: FreshContext): Unit = {
super.initializeCtx(fc)
val tmpDir = Files.createTempDirectory("trace_name_mangling_test").nn
fc.setSetting(fc.settings.YprofileEnabled, true)
fc.setSetting(
fc.settings.YprofileTrace,
tmpDir.resolve("trace.json").nn.toAbsolutePath().toString()
)
fc.setSetting(
fc.settings.YprofileDestination,
tmpDir.resolve("profiler.out").nn.toAbsolutePath().toString()
)
}

@Test def escapeBackslashes(): Unit = {
val isWindows = sys.props("os.name").toLowerCase(Locale.ROOT) == "windows"
val filename = if isWindows then "/.scala" else "\\.scala"
checkTraceEvents(
"""
|class /\ :
| var /\ = ???
|object /\{
| def /\ = ???
|}""".stripMargin,
filename = filename
)(
Set(
raw"class /\\",
raw"object /\\",
raw"method /\\",
raw"variable /\\",
raw"setter /\\_="
).map(TraceEvent("typecheck", _))
++ Set(
TraceEvent("file", if isWindows then "/.scala" else "\\\\.scala")
)
)
}

@Test def escapeDoubleQuotes(): Unit = {
val filename = "\"quoted\".scala"
checkTraceEvents(
"""
|class `"QuotedClass"`:
| var `"quotedVar"` = ???
|object `"QuotedObject"` {
| def `"quotedMethod"` = ???
|}""".stripMargin,
filename = filename
):
Set(
raw"class \"QuotedClass\"",
raw"object \"QuotedObject\"",
raw"method \"quotedMethod\"",
raw"variable \"quotedVar\""
).map(TraceEvent("typecheck", _))
++ Set(TraceEvent("file", "\\\"quoted\\\".scala"))
}
@Test def escapeNonAscii(): Unit = {
val filename = "unic😀de.scala"
checkTraceEvents(
"""
|class ΩUnicodeClass:
| var `中文Var` = ???
|object ΩUnicodeObject {
| def 中文Method = ???
|}""".stripMargin,
filename = filename
):
Set(
"class \\u03A9UnicodeClass",
"object \\u03A9UnicodeObject",
"method \\u4E2D\\u6587Method",
"variable \\u4E2D\\u6587Var"
).map(TraceEvent("typecheck", _))
++ Set(TraceEvent("file", "unic\\uD83D\\uDE00de.scala"))
}

case class TraceEvent(category: String, name: String)
private def compileWithTracer(
code: String,
filename: String,
afterPhase: String = "typer"
)(checkEvents: Seq[TraceEvent] => Unit) = {
val runCtx = locally:
val source = SourceFile.virtual(filename, code)
val c = compilerWithChecker(afterPhase) { (_, _) => () }
val run = c.newRun
run.compileSources(List(source))
run.runContext
assert(!runCtx.reporter.hasErrors, "compilation failed")
val outfile = ctx.settings.YprofileTrace.value
checkEvents:
scala.io.Source
.fromFile(outfile)
.getLines()
.collect:
case s"""${_}"cat":"${category}","name":${name},"ph":${_}""" =>
TraceEvent(category, name.stripPrefix("\"").stripSuffix("\""))
.distinct.toSeq
}

private def checkTraceEvents(code: String, filename: String = "test")(expected: Set[TraceEvent]): Unit = {
compileWithTracer(code, filename = filename, afterPhase = "typer"){ events =>
val missing = expected.diff(events.toSet)
def showFound = events
.groupBy(_.category)
.collect:
case (category, events)
if expected.exists(_.category == category) =>
s"- $category: [${events.map(_.name).mkString(", ")}]"
.mkString("\n")
assert(
missing.isEmpty,
s"""Missing ${missing.size} names [${missing.mkString(", ")}] in events, got:\n${showFound}"""
)
}
}
}

0 comments on commit 6f48c39

Please sign in to comment.