Skip to content

Commit

Permalink
feat: ✨ numpy-array-use-type-var flag (#188)
Browse files Browse the repository at this point in the history
  • Loading branch information
ringohoffman authored Nov 21, 2023
1 parent af3c6af commit 60728d3
Show file tree
Hide file tree
Showing 178 changed files with 1,158 additions and 24 deletions.
20 changes: 14 additions & 6 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
if: ${{ failure() || success() }}

tests:
name: "Test 🐍 ${{ matrix.python }} • pybind-${{ matrix.pybind11-branch }}"
name: "Test 🐍 ${{ matrix.python }} • pybind-${{ matrix.pybind11-branch }} • ${{ matrix.numpy-format }}"
runs-on: ubuntu-latest
strategy:
fail-fast: false
Expand All @@ -52,11 +52,18 @@ jobs:
- "3.9"
- "3.8"
- "3.7"
numpy-format:
- "numpy-array-wrap-with-annotated"
include:
- pybind11-branch: "v2.9"
python: "3.12"
- pybind11-branch: "v2.11"
python: "3.12"
- python: "3.12"
pybind11-branch: "v2.9"
numpy-format: "numpy-array-wrap-with-annotated"
- python: "3.12"
pybind11-branch: "v2.11"
numpy-format: "numpy-array-wrap-with-annotated"
- python: "3.12"
pybind11-branch: "master"
numpy-format: "numpy-array-use-type-var"
steps:
- uses: actions/checkout@v3

Expand Down Expand Up @@ -84,7 +91,7 @@ jobs:

- name: Check stubs generation
shell: bash
run: ./tests/check-demo-stubs-generation.sh --stubs-sub-dir "stubs/python-${{ matrix.python }}/pybind11-${{ matrix.pybind11-branch }}"
run: ./tests/check-demo-stubs-generation.sh --stubs-sub-dir "stubs/python-${{ matrix.python }}/pybind11-${{ matrix.pybind11-branch }}/${{ matrix.numpy-format }}" --${{ matrix.numpy-format }}

- name: Archive patch
uses: actions/upload-artifact@v3
Expand Down Expand Up @@ -137,6 +144,7 @@ jobs:
pybind11-stubgen "${{ matrix.test-package }}" -o flavour-1 --numpy-array-wrap-with-annotated
pybind11-stubgen "${{ matrix.test-package }}" -o flavour-2 --numpy-array-remove-parameters
pybind11-stubgen "${{ matrix.test-package }}" -o flavour-3 --print-invalid-expressions-as-is
pybind11-stubgen "${{ matrix.test-package }}" -o flavour-4 --numpy-array-use-type-var
pybind11-stubgen "${{ matrix.test-package }}" --dry-run
publish:
Expand Down
12 changes: 12 additions & 0 deletions pybind11_stubgen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@
FixMissingImports,
FixMissingNoneHashFieldAnnotation,
FixNumpyArrayDimAnnotation,
FixNumpyArrayDimTypeVar,
FixNumpyArrayFlags,
FixNumpyArrayRemoveParameters,
FixNumpyDtype,
FixPEP585CollectionNames,
FixPybind11EnumStrDoc,
FixRedundantBuiltinsAnnotation,
FixRedundantMethodsFromBuiltinObject,
FixScipyTypeArguments,
FixTypingTypeNames,
FixValueReprRandomAddress,
OverridePrintSafeValues,
Expand All @@ -66,6 +68,7 @@ class CLIArgs(Namespace):
ignore_all_errors: bool
enum_class_locations: list[tuple[re.Pattern, str]]
numpy_array_wrap_with_annotated: bool
numpy_array_use_type_var: bool
numpy_array_remove_parameters: bool
print_invalid_expressions_as_is: bool
print_safe_value_reprs: re.Pattern | None
Expand Down Expand Up @@ -156,6 +159,13 @@ def regex_colon_path(regex_path: str) -> tuple[re.Pattern, str]:
"'ARRAY_T[TYPE, [*DIMS], *FLAGS]' format with "
"'Annotated[ARRAY_T, TYPE, FixedSize|DynamicSize(*DIMS), *FLAGS]'",
)
numpy_array_fix.add_argument(
"--numpy-array-use-type-var",
default=False,
action="store_true",
help="Replace 'numpy.ndarray[numpy.float32[m, 1]]' with "
"'numpy.ndarray[tuple[M, typing.Literal[1]], numpy.dtype[numpy.float32]]'",
)

numpy_array_fix.add_argument(
"--numpy-array-remove-parameters",
Expand Down Expand Up @@ -230,6 +240,7 @@ def stub_parser_from_args(args: CLIArgs) -> IParser:

numpy_fixes: list[type] = [
*([FixNumpyArrayDimAnnotation] if args.numpy_array_wrap_with_annotated else []),
*([FixNumpyArrayDimTypeVar] if args.numpy_array_use_type_var else []),
*(
[FixNumpyArrayRemoveParameters]
if args.numpy_array_remove_parameters
Expand All @@ -246,6 +257,7 @@ class Parser(
FilterTypingModuleAttributes,
FixPEP585CollectionNames,
FixTypingTypeNames,
FixScipyTypeArguments,
FixMissingFixedSizeImport,
FixMissingEnumMembersAnnotation,
OverridePrintSafeValues,
Expand Down
184 changes: 174 additions & 10 deletions pybind11_stubgen/parser/mixins/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Property,
QualifiedName,
ResolvedType,
TypeVar_,
Value,
)
from pybind11_stubgen.typing_ext import DynamicSize, FixedSize
Expand Down Expand Up @@ -335,6 +336,7 @@ class FixTypingTypeNames(IParser):
"Iterator",
"KeysView",
"List",
"Literal",
"Optional",
"Sequence",
"Set",
Expand All @@ -360,12 +362,28 @@ def __init__(self):
super().__init__()
if sys.version_info < (3, 9):
self.__typing_extensions_names.add(Identifier("Annotated"))
if sys.version_info < (3, 8):
self.__typing_extensions_names.add(Identifier("Literal"))

def parse_annotation_str(
self, annotation_str: str
) -> ResolvedType | InvalidExpression | Value:
result = super().parse_annotation_str(annotation_str)
if not isinstance(result, ResolvedType) or len(result.name) != 1:
return self._parse_annotation_str(result)

def _parse_annotation_str(
self, result: ResolvedType | InvalidExpression | Value
) -> ResolvedType | InvalidExpression | Value:
if not isinstance(result, ResolvedType):
return result

result.parameters = (
[self._parse_annotation_str(p) for p in result.parameters]
if result.parameters is not None
else None
)

if len(result.name) != 1:
return result

word = result.name[0]
Expand Down Expand Up @@ -582,6 +600,136 @@ def report_error(self, error: ParserError) -> None:
super().report_error(error)


class FixNumpyArrayDimTypeVar(IParser):
    """Rewrite ``numpy.ndarray`` annotations into shape/dtype generic form.

    Annotations matching ``numpy.ndarray[PRIMITIVE_TYPE[*DIMS], *FLAGS]`` are
    replaced with ``numpy.ndarray[SHAPE, numpy.dtype[PRIMITIVE_TYPE]]``, where
    ``SHAPE`` is the parsed ``Tuple`` annotation holding one entry per
    dimension: ``Literal[N]`` for an integer dimension and the dimension's name
    otherwise.

    Every unqualified, single-letter type name seen while parsing is
    upper-cased and recorded as a dimension variable; ``handle_module`` then
    declares each recorded variable as a ``TypeVar_`` bound to ``int``.
    """

    __array_names: set[QualifiedName] = {QualifiedName.from_str("numpy.ndarray")}
    numpy_primitive_types = FixNumpyArrayDimAnnotation.numpy_primitive_types

    # Dimension variables collected for the module currently being handled.
    # The class-level default is immutable and is only ever *rebound* on the
    # instance, so parser instances never share (or leak) collected names.
    __dim_vars: frozenset[str] = frozenset()

    def handle_module(
        self, path: QualifiedName, module: types.ModuleType
    ) -> Module | None:
        """Handle ``module`` and declare the dimension variables it used.

        The variables collected so far are set aside for the duration of the
        call and restored afterwards.
        NOTE(review): this matters if ``super().handle_module`` re-enters
        ``handle_module`` for sub-modules (not visible from this class --
        confirm); each module then only receives the variables recorded
        while it was being handled itself.

        Returns:
            The module produced by the rest of the parser chain with the
            ``typing`` import and ``TypeVar_`` entries added, or ``None`` when
            the chain returned ``None``.
        """
        enclosing_dim_vars = self.__dim_vars
        self.__dim_vars = frozenset()
        try:
            result = super().handle_module(path, module)
            if result is None:
                return None

            if self.__dim_vars:
                # the TypeVar_'s generated code will reference `typing`
                result.imports.add(
                    Import(name=None, origin=QualifiedName.from_str("typing"))
                )

                # sorted() gives a stable declaration order for the
                # collected names
                for name in sorted(self.__dim_vars):
                    result.type_vars.append(
                        TypeVar_(
                            name=Identifier(name),
                            bound=self.parse_annotation_str("int"),
                        ),
                    )

            return result
        finally:
            # Runs on every exit path (including `None` results and errors) so
            # names from this module never carry over into another one.
            self.__dim_vars = enclosing_dim_vars

    def parse_annotation_str(
        self, annotation_str: str
    ) -> ResolvedType | InvalidExpression | Value:
        """Parse ``annotation_str`` and convert ndarray annotations.

        Returns the parsed annotation unchanged unless it is a ``ResolvedType``
        that is either a single-letter name (renamed to upper case and recorded
        as a dimension variable) or ``numpy.ndarray`` (given shape and dtype
        type arguments).
        """
        # Affects types of the following pattern:
        #       numpy.ndarray[PRIMITIVE_TYPE[*DIMS], *FLAGS]
        # Replace with:
        #       numpy.ndarray[tuple[M, Literal[1]], numpy.dtype[numpy.float32]]

        result = super().parse_annotation_str(annotation_str)

        if not isinstance(result, ResolvedType):
            return result

        # handle unqualified, single-letter annotation as a TypeVar
        if len(result.name) == 1 and len(result.name[0]) == 1:
            result.name = QualifiedName.from_str(result.name[0].upper())
            # rebind instead of mutating in place: the default is shared
            self.__dim_vars = self.__dim_vars | {result.name[0]}

        if result.name not in self.__array_names:
            return result

        # ndarray is generic and should have 2 type arguments
        if not result.parameters:
            result.parameters = [
                self.parse_annotation_str("Any"),
                ResolvedType(
                    name=QualifiedName.from_str("numpy.dtype"),
                    parameters=[self.parse_annotation_str("Any")],
                ),
            ]
            return result

        scalar_with_dims = result.parameters[0]  # e.g. numpy.float64[32, 32]

        if (
            not isinstance(scalar_with_dims, ResolvedType)
            or scalar_with_dims.name not in self.numpy_primitive_types
        ):
            return result

        dtype = ResolvedType(
            name=QualifiedName.from_str("numpy.dtype"),
            parameters=[ResolvedType(name=scalar_with_dims.name)],
        )

        # Fall back to `Any` when there are no dimensions or they cannot be
        # converted by `__to_dims`.
        shape = self.parse_annotation_str("Any")
        if scalar_with_dims.parameters:
            dims = self.__to_dims(scalar_with_dims.parameters)
            if dims is not None:
                shape = self.parse_annotation_str("Tuple")
                assert isinstance(shape, ResolvedType)
                shape.parameters = []
                for dim in dims:
                    if isinstance(dim, int):
                        # self.parse_annotation_str will qualify Literal with
                        # either typing or typing_extensions and add the import
                        # to the module
                        literal_dim = self.parse_annotation_str("Literal")
                        assert isinstance(literal_dim, ResolvedType)
                        literal_dim.parameters = [Value(repr=str(dim))]
                        shape.parameters.append(literal_dim)
                    else:
                        shape.parameters.append(
                            ResolvedType(name=QualifiedName.from_str(dim))
                        )

        # Any remaining original parameters (the *FLAGS part) are dropped here.
        result.parameters = [shape, dtype]
        return result

    def __to_dims(
        self, dimensions: Sequence[ResolvedType | Value | InvalidExpression]
    ) -> list[int | str] | None:
        """Convert parsed dimension parameters to ints and names.

        A ``Value`` must have an integer ``repr`` and becomes an ``int``; a
        ``ResolvedType`` becomes its ``str()`` form.

        Returns:
            One entry per dimension, or ``None`` as soon as a dimension is a
            non-integer ``Value`` or any other kind of expression.
        """
        result: list[int | str] = []
        for dim_param in dimensions:
            if isinstance(dim_param, Value):
                try:
                    dim = int(dim_param.repr)
                except ValueError:
                    return None
            elif isinstance(dim_param, ResolvedType):
                dim = str(dim_param)
            else:
                return None
            result.append(dim)
        return result

    def report_error(self, error: ParserError) -> None:
        """Forward ``error`` unless it is an unresolved dimension variable."""
        if (
            isinstance(error, NameResolutionError)
            and len(error.name) == 1
            and error.name[0] in self.__dim_vars
        ):
            # allow type variables, which are manually resolved in `handle_module`
            return
        super().report_error(error)


class FixNumpyArrayRemoveParameters(IParser):
__ndarray_name = QualifiedName.from_str("numpy.ndarray")

Expand All @@ -594,24 +742,40 @@ def parse_annotation_str(
return result


class FixScipyTypeArguments(IParser):
    """Strip type arguments from ``scipy.sparse`` annotations."""

    def parse_annotation_str(
        self, annotation_str: str
    ) -> ResolvedType | InvalidExpression | Value:
        """Parse ``annotation_str`` and drop parameters of scipy.sparse types.

        Anything that is not a ``ResolvedType``, or whose qualified name does
        not start with ``scipy.sparse``, is returned exactly as parsed.
        """
        parsed = super().parse_annotation_str(annotation_str)

        # scipy.sparse arrays/matrices are not currently generic and do not
        # accept type arguments
        is_scipy_sparse = isinstance(parsed, ResolvedType) and parsed.name[:2] == (
            "scipy",
            "sparse",
        )
        if is_scipy_sparse:
            parsed.parameters = None

        return parsed


class FixNumpyDtype(IParser):
__numpy_dtype = QualifiedName.from_str("numpy.dtype")

def parse_annotation_str(
self, annotation_str: str
) -> ResolvedType | InvalidExpression | Value:
result = super().parse_annotation_str(annotation_str)
if (
not isinstance(result, ResolvedType)
or len(result.name) != 1
or result.parameters is not None
):
return result

word = result.name[0]
if word != Identifier("dtype"):
if not isinstance(result, ResolvedType) or result.parameters:
return result
return ResolvedType(name=self.__numpy_dtype)

# numpy.dtype is generic and should have a type argument
if result.name[:1] == ("dtype",) or result.name[:2] == ("numpy", "dtype"):
result.name = self.__numpy_dtype
result.parameters = [self.parse_annotation_str("Any")]

return result


class FixNumpyArrayFlags(IParser):
Expand Down
3 changes: 3 additions & 0 deletions pybind11_stubgen/parser/mixins/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Property,
QualifiedName,
ResolvedType,
TypeVar_,
Value,
)

Expand Down Expand Up @@ -103,6 +104,8 @@ def handle_module(
result.sub_modules.append(obj)
elif isinstance(obj, Attribute):
result.attributes.append(obj)
elif isinstance(obj, TypeVar_):
result.type_vars.append(obj)
elif obj is None:
pass
else:
Expand Down
7 changes: 7 additions & 0 deletions pybind11_stubgen/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Module,
Property,
ResolvedType,
TypeVar_,
Value,
)

Expand Down Expand Up @@ -81,6 +82,9 @@ def print_class(self, class_: Class) -> list[str]:
*indent_lines(self.print_class_body(class_)),
]

def print_type_var(self, type_var: TypeVar_) -> list[str]:
    """Return the stub lines for ``type_var``: a single line, its ``str()`` form."""
    rendered = str(type_var)
    return [rendered]

def print_class_body(self, class_: Class) -> list[str]:
result = []
if class_.doc is not None:
Expand Down Expand Up @@ -215,6 +219,9 @@ def print_module(self, module: Module) -> list[str]:
result.extend(self.print_attribute(attr))
break

for type_var in sorted(module.type_vars, key=lambda t: t.name):
result.extend(self.print_type_var(type_var))

for class_ in sorted(module.classes, key=lambda c: c.name):
result.extend(self.print_class(class_))

Expand Down
Loading

0 comments on commit 60728d3

Please sign in to comment.