Skip to content

Commit

Permalink
include stub locations in errors
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham committed Oct 3, 2024
1 parent 2e376c1 commit 17a3426
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 50 deletions.
6 changes: 3 additions & 3 deletions tests/test_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import libcst as cst
import pytest
from unpy.exceptions import StubError
from unpy.exceptions import StubError, StubSyntaxError
from unpy.visitors import StubVisitor


Expand Down Expand Up @@ -39,7 +39,7 @@ def test_illegal_future_import():
],
)
def test_illegal_stringified_annotations(source: str):
with pytest.raises(StubError):
with pytest.raises(StubSyntaxError):
_visit(source)


Expand All @@ -51,7 +51,7 @@ def test_illegal_stringified_annotations(source: str):
],
)
def test_illegal_special_functions_at_module_lvl(source: str):
with pytest.raises(StubError):
with pytest.raises(StubSyntaxError):
_visit(source)


Expand Down
8 changes: 6 additions & 2 deletions unpy/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
__all__ = ("StubError",)
__all__ = ("StubError", "StubSyntaxError")


class StubError(TypeError):
class StubError(Exception):
pass


class StubSyntaxError(SyntaxError):
pass
4 changes: 3 additions & 1 deletion unpy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,10 @@ def build(
) -> None:
assert not version

filename = "<stdin>" if str(source) == "-" else str(source.resolve())

source_str = _read_source(source)
output_str = transform_source(source_str, target=target.version)
output_str = transform_source(source_str, filename=filename, target=target.version)

if diff:
_echo_diff(str(source), source_str, str(output), output_str)
Expand Down
17 changes: 13 additions & 4 deletions unpy/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,10 +797,15 @@ def leave_Module(
return updated_node.with_changes(body=new_body)


def transform_module(original: cst.Module, /, target: PythonVersion) -> cst.Module:
def transform_module(
original: cst.Module,
/,
filename: str = "<stdin>",
target: PythonVersion = PythonVersion.PY310,
) -> cst.Module:
wrapper = cst.MetadataWrapper(original)

visitor = StubVisitor()
visitor = StubVisitor(filename=filename)
_ = wrapper.visit(visitor)

transformer = StubTransformer(visitor, target=target)
Expand All @@ -810,7 +815,11 @@ def transform_module(original: cst.Module, /, target: PythonVersion) -> cst.Modu
def transform_source(
source: str,
/,
*,
filename: str = "<stdin>",
target: PythonVersion = PythonVersion.PY310,
) -> str:
return transform_module(cst.parse_module(source), target=target).code
return transform_module(
cst.parse_module(source),
filename=filename,
target=target,
).code
135 changes: 95 additions & 40 deletions unpy/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import libcst.metadata as cst_meta

import unpy._cst as uncst
from unpy.exceptions import StubError
from unpy.exceptions import StubError, StubSyntaxError

__all__ = ("StubVisitor",)

Expand All @@ -14,26 +14,18 @@
_MODULE_TPX: Final = "typing_extensions"


def _check_annotation_expr(
node: cst.BaseExpression,
/,
name: str | None = None,
) -> None:
if isinstance(node, cst.BaseString):
error = StubError("quoted annotations should not be included in stubs")
if name:
error.add_note(f"in {name!r}")
raise error


class StubVisitor(cst.CSTVisitor): # noqa: PLR0904
"""
Collect all PEP-695 type-parameters & required imports in the module's functions,
classes, and type-aliases.
"""

METADATA_DEPENDENCIES = (cst_meta.ScopeProvider,)
METADATA_DEPENDENCIES = cst_meta.PositionProvider, cst_meta.ScopeProvider

# for error reporting
filename: Final[str]

module: cst.Module
_global_scope: cst_meta.GlobalScope

_stack_scope: Final[collections.deque[str]]
Expand Down Expand Up @@ -62,7 +54,9 @@ class StubVisitor(cst.CSTVisitor): # noqa: PLR0904

nested_classvar_final: bool

def __init__(self, /) -> None:
def __init__(self, /, filename: str = "<stdin>") -> None:
self.filename = filename

self._stack_scope = collections.deque()
self._stack_attr = collections.deque()
self._in_import = False
Expand Down Expand Up @@ -100,6 +94,35 @@ def global_qualnames(self, /) -> frozenset[str]:
def global_names(self, /) -> frozenset[str]:
return frozenset({qn.split(".", 1)[0] for qn in self.global_qualnames})

def meta_position(self, node: cst.CSTNode, /) -> cst_meta.CodeRange:
position = self.get_metadata(cst_meta.PositionProvider, node, default=None)
assert position
return position

def _syntax_details(
self,
node: cst.CSTNode,
/,
) -> tuple[str, int, int, str, int, int]:
"""
Get the `SyntaxError` constructor `details` as
`(filename, lineno, offset, test, end_lineno, end_offset)`.
See Also:
- https://docs.python.org/3.14/library/exceptions.html#SyntaxError
"""
span = self.meta_position(node)
lineno, offset = span.start.line, span.start.column + 1
end_lineno, end_offset = span.end.line, span.end.column + 1

line = self.module.code.splitlines(keepends=True)[lineno - 1]
return self.filename, lineno, offset, line, end_lineno, end_offset

def meta_scope(self, node: cst.CSTNode, /) -> cst_meta.Scope:
scope = self.get_metadata(cst_meta.ScopeProvider, node, default=None)
assert scope
return scope

def imported_as(self, module: str, name: str, /) -> str | None:
"""
Find the alias or attribute path used to access `{module}.{name}`, or return
Expand Down Expand Up @@ -260,7 +283,7 @@ def _build_type_param( # noqa: C901
default = tpar.default

if default:
_check_annotation_expr(default)
self.__check_annotation(default)

name_any = self.imported_from_typing_as("Any")
name_object = self.imported_as(_MODULE_BUILTINS, "object")
Expand Down Expand Up @@ -317,7 +340,7 @@ def _build_type_param( # noqa: C901
constraints = tuple(cons)
bound = None
else:
_check_annotation_expr(bound)
self.__check_annotation(bound)

if _default_any and bound is not None:
# if `default=Any`, replace it the value of `bound` (`Any` is horrible)
Expand Down Expand Up @@ -378,6 +401,21 @@ def __after_import(self, /) -> None:
assert self._in_import
self._in_import = False

def __check_annotation(
self,
node: cst.BaseExpression,
/,
name: str | None = None,
) -> None:
if isinstance(node, cst.BaseString):
error = StubSyntaxError(
"Quoted annotations should not be included in stubs",
self._syntax_details(node),
)
if name:
error.add_note(f"in {name!r}")
raise error

def __check_assign_imported(self, node: cst.Assign | cst.AnnAssign, /) -> None:
if not isinstance(node.value, cst.Name | cst.Attribute):
return
Expand All @@ -394,40 +432,49 @@ def on_visit(self, /, node: cst.CSTNode) -> bool:
if isinstance(node, cst.BaseSmallStatement):
if isinstance(
node,
cst.Del
| cst.Pass
| cst.Break
| cst.Continue
| cst.Return
| cst.Raise
| cst.Assert
| cst.Global
| cst.Nonlocal,
cst.Del | cst.Pass | cst.Break | cst.Continue | cst.Raise | cst.Assert,
):
keyword = type(node).__name__.lower()
raise StubError(f"{keyword!r} statements are useless in stubs")
raise StubSyntaxError(
f"{keyword!r} statements are useless in stubs",
self._syntax_details(node),
)
elif isinstance(node, cst.BaseCompoundStatement):
if isinstance(
node,
cst.Try | cst.TryStar | cst.With | cst.For | cst.While | cst.Match,
):
keyword = type(node).__name__.lower()
raise StubError(f"{keyword!r} statements are useless in stubs")
raise StubSyntaxError(
f"{keyword!r} statements are useless in stubs",
self._syntax_details(node),
)
elif isinstance(node, cst.BaseExpression):
if isinstance(node, cst.BooleanOperation):
raise StubError("boolean operations are useless in stubs")
raise StubSyntaxError(
"Boolean operations are useless in stubs",
self._syntax_details(node),
)
if isinstance(node, cst.FormattedString):
raise StubError("f-strings are useless in stubs")
raise StubSyntaxError(
"Format-strings are useless in stubs",
self._syntax_details(node),
)
if isinstance(node, cst.Lambda | cst.Await | cst.Yield):
keyword = type(node).__name__.lower()
raise StubError(f"{keyword!r} is an invalid expression")
raise StubSyntaxError(
f"{keyword!r} is an invalid expression",
self._syntax_details(node),
)

return super().on_visit(node)

@override
def visit_Module(self, /, node: cst.Module) -> None:
node.validate_types_deep()

self.module = node

scope = self.get_metadata(cst_meta.ScopeProvider, node)
assert isinstance(scope, cst_meta.Scope)
self._global_scope = scope.globals
Expand Down Expand Up @@ -491,7 +538,7 @@ def leave_Attribute(self, /, original_node: cst.Attribute) -> None:

@override
def visit_Annotation(self, /, node: cst.Annotation) -> None:
_check_annotation_expr(node.annotation)
self.__check_annotation(node.annotation)

@override
def visit_Assign(self, node: cst.Assign) -> None:
Expand All @@ -518,7 +565,7 @@ def visit_Assign(self, node: cst.Assign) -> None:
# `bound` or `default` kwargs
if arg.keyword is None or arg.keyword.value in {"bound", "default"}:
fname = fname or uncst.get_name_strict(node.value.func)
_check_annotation_expr(arg.value, f"{target.value} = {fname}(...)")
self.__check_annotation(arg.value, f"{target.value} = {fname}(...)")

@override
def visit_AnnAssign(self, node: cst.AnnAssign) -> None:
Expand All @@ -535,7 +582,7 @@ def visit_AnnAssign(self, node: cst.AnnAssign) -> None:
):
# this is a legacy `typing[_extensions].TypeAlias`
# TODO(jorenham): either warn user & register, or just disallow this
_check_annotation_expr(node.value, f"{node.target.value}")
self.__check_annotation(node.value, f"{node.target.value}")

# check for nested `ClassVar` and `Final`
if (
Expand Down Expand Up @@ -582,7 +629,7 @@ def visit_TypeAlias(self, /, node: cst.TypeAlias) -> None:
name = node.name.value
assert name not in self.type_aliases

_check_annotation_expr(node.value, f"type {name}")
self.__check_annotation(node.value, f"type {name}")

if tpars := node.type_parameters:
self.type_aliases[name] = self._register_type_params(name, tpars)
Expand All @@ -598,13 +645,21 @@ def visit_FunctionDef(self, /, node: cst.FunctionDef) -> None:
stack.append(name := node.name.value)

if len(stack) == 1 and name in {"__getattr__", "__dir__"}:
raise StubError(f"module-level {name}() cannot be used in a stub")
raise StubSyntaxError(
f"Module-level {name}() cannot be used in a stub",
self._syntax_details(node),
)

assert isinstance(node.body, cst.SimpleStatementSuite | cst.IndentedBlock)
if len(node.body.body) != 1 or not isinstance(node.body.body[0], cst.Ellipsis):
error = StubError("function body must contain only `...`")
qualname = ".".join(stack)
error.add_note(qualname)
if len(node.body.body) != 1 or not (
isinstance(body_expr := node.body.body[0], cst.Ellipsis)
or isinstance(body_expr, cst.Expr)
and isinstance(body_expr.value, cst.Ellipsis)
):
raise StubSyntaxError(
f"Function body must contain only `...`\n{node.body.body[0]}",
self._syntax_details(node.body.body[0]),
) from None

if tpars := node.type_parameters:
self._register_type_params(stack[0], tpars)
Expand Down

0 comments on commit 17a3426

Please sign in to comment.