Skip to content

Commit

Permalink
Add DocstringRemover (#197)
Browse files Browse the repository at this point in the history
* Add DocstringRemover

* fix test
  • Loading branch information
odashi authored Dec 3, 2023
1 parent 0cba4c9 commit 2114923
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/latexify/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,21 @@ def is_constant(node: ast.AST) -> bool:
return isinstance(node, ast.Constant)


def is_str(node: ast.AST) -> bool:
"""Checks if the node is a str constant.
Args:
node: The node to examine.
Returns:
True if the node is a str constant, False otherwise.
"""
if sys.version_info.minor < 8 and isinstance(node, ast.Str):
return True

return isinstance(node, ast.Constant) and isinstance(node.value, str)


def extract_int_or_none(node: ast.expr) -> int | None:
"""Extracts int constant from the given Constant node.
Expand Down
32 changes: 32 additions & 0 deletions src/latexify/ast_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,38 @@ def test_is_constant(value: ast.AST, expected: bool) -> None:
assert ast_utils.is_constant(value) is expected


@test_utils.require_at_most(7)
@pytest.mark.parametrize(
"value,expected",
[
(ast.Bytes(s=b"foo"), False),
(ast.Constant("bar"), True),
(ast.Ellipsis(), False),
(ast.NameConstant(value=None), False),
(ast.Num(n=123), False),
(ast.Str(s="baz"), True),
(ast.Expr(value=ast.Num(456)), False),
(ast.Global(names=["qux"]), False),
],
)
def test_is_str_legacy(value: ast.AST, expected: bool) -> None:
assert ast_utils.is_str(value) is expected


@test_utils.require_at_least(8)
@pytest.mark.parametrize(
"value,expected",
[
(ast.Constant(value=123), False),
(ast.Constant(value="foo"), True),
(ast.Expr(value=ast.Constant(value="foo")), False),
(ast.Global(names=["foo"]), False),
],
)
def test_is_str(value: ast.AST, expected: bool) -> None:
assert ast_utils.is_str(value) is expected


def test_extract_int_or_none() -> None:
assert ast_utils.extract_int_or_none(ast_utils.make_constant(-123)) == -123
assert ast_utils.extract_int_or_none(ast_utils.make_constant(0)) == 0
Expand Down
1 change: 1 addition & 0 deletions src/latexify/generate_latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def get_latex(
if merged_config.identifiers is not None:
tree = transformers.IdentifierReplacer(merged_config.identifiers).visit(tree)
if merged_config.reduce_assignments:
tree = transformers.DocstringRemover().visit(tree)
tree = transformers.AssignmentReducer().visit(tree)
if merged_config.expand_functions is not None:
tree = transformers.FunctionExpander(merged_config.expand_functions).visit(tree)
Expand Down
14 changes: 14 additions & 0 deletions src/latexify/generate_latex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,20 @@ def f(x):
assert generate_latex.get_latex(f, reduce_assignments=True) == latex_with_flag


def test_get_latex_reduce_assignments_with_docstring() -> None:
def f(x):
"""DocstringRemover is required."""
y = 3 * x
return y

latex_without_flag = r"\begin{array}{l} y = 3 x \\ f(x) = y \end{array}"
latex_with_flag = r"f(x) = 3 x"

assert generate_latex.get_latex(f) == latex_without_flag
assert generate_latex.get_latex(f, reduce_assignments=False) == latex_without_flag
assert generate_latex.get_latex(f, reduce_assignments=True) == latex_with_flag


def test_get_latex_reduce_assignments_with_aug_assign() -> None:
def f(x):
y = 3
Expand Down
2 changes: 2 additions & 0 deletions src/latexify/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from latexify.transformers.assignment_reducer import AssignmentReducer
from latexify.transformers.aug_assign_replacer import AugAssignReplacer
from latexify.transformers.docstring_remover import DocstringRemover
from latexify.transformers.function_expander import FunctionExpander
from latexify.transformers.identifier_replacer import IdentifierReplacer
from latexify.transformers.prefix_trimmer import PrefixTrimmer

__all__ = [
"AssignmentReducer",
"AugAssignReplacer",
"DocstringRemover",
"FunctionExpander",
"IdentifierReplacer",
"PrefixTrimmer",
Expand Down
20 changes: 20 additions & 0 deletions src/latexify/transformers/docstring_remover.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Transformer to remove all docstrings."""

from __future__ import annotations

import ast
from typing import Union

from latexify import ast_utils


class DocstringRemover(ast.NodeTransformer):
"""NodeTransformer to remove all docstrings.
Docstrings here are detected as Expr nodes with a single string constant.
"""

def visit_Expr(self, node: ast.Expr) -> Union[ast.Expr, None]:
if ast_utils.is_str(node.value):
return None
return node
32 changes: 32 additions & 0 deletions src/latexify/transformers/docstring_remover_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Tests for latexify.transformers.docstring_remover."""

import ast

from latexify import ast_utils, parser, test_utils
from latexify.transformers.docstring_remover import DocstringRemover


def test_remove_docstrings() -> None:
def f():
"""Test docstring."""
x = 42
f() # This Expr should not be removed.
"""This string constant should also be removed."""
return x

tree = parser.parse_function(f).body[0]
assert isinstance(tree, ast.FunctionDef)

expected = ast.FunctionDef(
name="f",
body=[
ast.Assign(
targets=[ast.Name(id="x", ctx=ast.Store())],
value=ast_utils.make_constant(42),
),
ast.Expr(value=ast.Call(func=ast.Name(id="f", ctx=ast.Load()))),
ast.Return(value=ast.Name(id="x", ctx=ast.Load())),
],
)
transformed = DocstringRemover().visit(tree)
test_utils.assert_ast_equal(transformed, expected)

0 comments on commit 2114923

Please sign in to comment.