Skip to content

Commit

Permalink
support removal of multiply operator (#182)
Browse files Browse the repository at this point in the history
  • Loading branch information
odashi authored Oct 15, 2023
1 parent 37d47c9 commit e0ddde4
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 39 deletions.
4 changes: 2 additions & 2 deletions src/integration_tests/algorithmic_style_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def collatz(n):
\If{$n \mathbin{\%} 2 = 0$}
\State $n \gets \left\lfloor\frac{n}{2}\right\rfloor$
\Else
\State $n \gets 3 \cdot n + 1$
\State $n \gets 3 n + 1$
\EndIf
\State $\mathrm{iterations} \gets \mathrm{iterations} + 1$
\EndWhile
Expand All @@ -80,7 +80,7 @@ def collatz(n):
r" \hspace{2em} \mathbf{if} \ n \mathbin{\%} 2 = 0 \\"
r" \hspace{3em} n \gets \left\lfloor\frac{n}{2}\right\rfloor \\"
r" \hspace{2em} \mathbf{else} \\"
r" \hspace{3em} n \gets 3 \cdot n + 1 \\"
r" \hspace{3em} n \gets 3 n + 1 \\"
r" \hspace{2em} \mathbf{end \ if} \\"
r" \hspace{2em}"
r" \mathrm{iterations} \gets \mathrm{iterations} + 1 \\"
Expand Down
21 changes: 9 additions & 12 deletions src/integration_tests/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@ def test_quadratic_solution() -> None:
def solve(a, b, c):
return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a)

latex = (
r"\mathrm{solve}(a, b, c) ="
r" \frac{-b + \sqrt{ b^{2} - 4 \cdot a \cdot c }}{2 \cdot a}"
)
latex = r"\mathrm{solve}(a, b, c) =" r" \frac{-b + \sqrt{ b^{2} - 4 a c }}{2 a}"
integration_utils.check_function(solve, latex)


Expand Down Expand Up @@ -47,7 +44,7 @@ def xtimesbeta(x, beta):
xtimesbeta, latex_without_symbols, use_math_symbols=False
)

latex_with_symbols = r"\mathrm{xtimesbeta}(x, \beta) = x \cdot \beta"
latex_with_symbols = r"\mathrm{xtimesbeta}(x, \beta) = x \beta"
integration_utils.check_function(
xtimesbeta, latex_with_symbols, use_math_symbols=True
)
Expand Down Expand Up @@ -145,7 +142,7 @@ def test_nested_function() -> None:
def nested(x):
return 3 * x

integration_utils.check_function(nested, r"\mathrm{nested}(x) = 3 \cdot x")
integration_utils.check_function(nested, r"\mathrm{nested}(x) = 3 x")


def test_double_nested_function() -> None:
Expand All @@ -155,7 +152,7 @@ def inner(y):

return inner

integration_utils.check_function(nested(3), r"\mathrm{inner}(y) = x \cdot y")
integration_utils.check_function(nested(3), r"\mathrm{inner}(y) = x y")


def test_reduce_assignments() -> None:
Expand All @@ -165,11 +162,11 @@ def f(x):

integration_utils.check_function(
f,
r"\begin{array}{l} a = x + x \\ f(x) = 3 \cdot a \end{array}",
r"\begin{array}{l} a = x + x \\ f(x) = 3 a \end{array}",
)
integration_utils.check_function(
f,
r"f(x) = 3 \cdot \mathopen{}\left( x + x \mathclose{}\right)",
r"f(x) = 3 \mathopen{}\left( x + x \mathclose{}\right)",
reduce_assignments=True,
)

Expand All @@ -184,15 +181,15 @@ def f(x):
r"\begin{array}{l}"
r" a = x^{2} \\"
r" b = a + a \\"
r" f(x) = 3 \cdot b"
r" f(x) = 3 b"
r" \end{array}"
)

integration_utils.check_function(f, latex_without_option)
integration_utils.check_function(f, latex_without_option, reduce_assignments=False)
integration_utils.check_function(
f,
r"f(x) = 3 \cdot \mathopen{}\left( x^{2} + x^{2} \mathclose{}\right)",
r"f(x) = 3 \mathopen{}\left( x^{2} + x^{2} \mathclose{}\right)",
reduce_assignments=True,
)

Expand Down Expand Up @@ -228,7 +225,7 @@ def solve(a, b):
r"\mathrm{solve}(a, b) ="
r" \frac{a + b - b}{a - b} - \mathopen{}\left("
r" a + b \mathclose{}\right) - \mathopen{}\left("
r" a - b \mathclose{}\right) - a \cdot b"
r" a - b \mathclose{}\right) - a b"
)
integration_utils.check_function(solve, latex)

Expand Down
83 changes: 83 additions & 0 deletions src/latexify/codegen/expression_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import ast
import re

from latexify import analyzers, ast_utils, exceptions
from latexify.codegen import codegen_utils, expression_rules, identifier_converter
Expand Down Expand Up @@ -406,12 +407,94 @@ def _wrap_binop_operand(

return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)"

_l_bracket_pattern = re.compile(r"^\\mathopen.*")
_r_bracket_pattern = re.compile(r".*\\mathclose[^ ]+$")
_r_word_pattern = re.compile(r"\\mathrm\{[^ ]+\}$")

def _should_remove_multiply_op(
self, l_latex: str, r_latex: str, l_expr: ast.expr, r_expr: ast.expr
):
"""Determine whether the multiply operator should be removed or not.
See also:
https://github.com/google/latexify_py/issues/89#issuecomment-1344967636
This is an ad-hoc implementation.
This function doesn't fully implements the above requirements, but only
essential ones necessary to release v0.3.
"""

# NOTE(odashi): For compatibility with Python 3.7, we compare the generated
# caracter type directly to determine the "numeric" type.

if isinstance(l_expr, ast.Call):
l_type = "f"
elif self._r_bracket_pattern.match(l_latex):
l_type = "b"
elif self._r_word_pattern.match(l_latex):
l_type = "w"
elif l_latex[-1].isnumeric():
l_type = "n"
else:
le = l_expr
while True:
if isinstance(le, ast.UnaryOp):
le = le.operand
elif isinstance(le, ast.BinOp):
le = le.right
elif isinstance(le, ast.Compare):
le = le.comparators[-1]
elif isinstance(le, ast.BoolOp):
le = le.values[-1]
else:
break
l_type = "a" if isinstance(le, ast.Name) and len(le.id) == 1 else "m"

if isinstance(r_expr, ast.Call):
r_type = "f"
elif self._l_bracket_pattern.match(r_latex):
r_type = "b"
elif r_latex.startswith("\\mathrm"):
r_type = "w"
elif r_latex[0].isnumeric():
r_type = "n"
else:
re = r_expr
while True:
if isinstance(re, ast.UnaryOp):
if isinstance(re.op, ast.USub):
# NOTE(odashi): Unary "-" always require \cdot.
return False
re = re.operand
elif isinstance(re, ast.BinOp):
re = re.left
elif isinstance(re, ast.Compare):
re = re.left
elif isinstance(re, ast.BoolOp):
re = re.values[0]
else:
break
r_type = "a" if isinstance(re, ast.Name) and len(re.id) == 1 else "m"

if r_type == "n":
return False
if l_type in "bn":
return True
if l_type in "am" and r_type in "am":
return True
return False

def visit_BinOp(self, node: ast.BinOp) -> str:
"""Visit a BinOp node."""
prec = expression_rules.get_precedence(node)
rule = self._bin_op_rules[type(node.op)]
lhs = self._wrap_binop_operand(node.left, prec, rule.operand_left)
rhs = self._wrap_binop_operand(node.right, prec, rule.operand_right)

if type(node.op) in [ast.Mult, ast.MatMult]:
if self._should_remove_multiply_op(lhs, rhs, node.left, node.right):
return f"{rule.latex_left}{lhs} {rhs}{rule.latex_right}"

return f"{rule.latex_left}{lhs}{rule.latex_middle}{rhs}{rule.latex_right}"

def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
Expand Down
Loading

0 comments on commit e0ddde4

Please sign in to comment.