Skip to content

Commit

Permalink
feat(typing): Generate annotations based on known datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
dangotbanned committed Nov 6, 2024
1 parent b89e6dc commit 7b0fe29
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 0 deletions.
62 changes: 62 additions & 0 deletions tools/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

import polars as pl

from tools.codemod import ruff
from tools.datasets.github import GitHub
from tools.datasets.models import QueryTree
from tools.datasets.npm import Npm
from tools.schemapi import utils

if TYPE_CHECKING:
import sys
Expand All @@ -37,10 +39,17 @@
else:
from typing_extensions import TypeAlias

# Keys accepted by ``Application.read``/``Application.scan``; each one names a
# cached parquet file (see ``Application._aliases`` for the mapping).
_PathAlias: TypeAlias = Literal["npm_tags", "gh_tags", "gh_trees"]

# Placeholder alias for signatures whose final type is still undecided.
WorkInProgress: TypeAlias = Any

__all__ = ["app", "data"]

# Header prepended to generated modules so readers know not to hand-edit them.
HEADER_COMMENT = """\
# The contents of this file are automatically written by
# tools/datasets.__init__.py. Do not modify directly.
"""


class Application:
"""
Expand Down Expand Up @@ -78,6 +87,14 @@ def github(self) -> GitHub:
def npm(self) -> Npm:
return self._npm

@property
def _aliases(self) -> dict[_PathAlias, Path]:
    """Map each short metadata alias onto the parquet file it refers to."""
    paths: dict[_PathAlias, Path] = {}
    paths["npm_tags"] = self.npm._paths["tags"]
    paths["gh_tags"] = self.github._paths["tags"]
    paths["gh_trees"] = self.github._paths["trees"]
    return paths

def refresh(self) -> pl.DataFrame:
npm_tags = self.npm.tags()
self.write_parquet(npm_tags, self.npm._paths["tags"])
Expand All @@ -89,6 +106,21 @@ def refresh(self) -> pl.DataFrame:
self.write_parquet(gh_trees, self.github._paths["trees"])
return gh_trees

def read(self, name: _PathAlias, /) -> pl.DataFrame:
    """Eagerly load the cached metadata table identified by ``name``."""
    fp = self._from_alias(name)
    return pl.read_parquet(fp)

def scan(self, name: _PathAlias, /) -> pl.LazyFrame:
    """Lazily open the cached metadata table identified by ``name``."""
    fp = self._from_alias(name)
    return pl.scan_parquet(fp)

def _from_alias(self, name: _PathAlias, /) -> Path:
    """
    Resolve ``name`` to the path of its cached parquet file.

    Raises
    ------
    TypeError
        If ``name`` is not one of the known aliases.

    Notes
    -----
    The valid names and the error message are both derived from
    ``self._aliases`` so they cannot drift out of sync with the mapping
    (previously the alias list was hard-coded here twice).
    """
    aliases = self._aliases
    if name in aliases:
        return aliases[name]
    # list(aliases) preserves insertion order, reproducing the original
    # message: Expected one of ['npm_tags', 'gh_tags', 'gh_trees'], ...
    msg = f"Expected one of {list(aliases)!r}, but got: {name!r}"
    raise TypeError(msg)

def write_parquet(self, frame: pl.DataFrame | pl.LazyFrame, fp: Path, /) -> None:
"""Write ``frame`` to ``fp``, with some extra safety."""
if not fp.exists():
Expand Down Expand Up @@ -118,6 +150,36 @@ def write_parquet(self, frame: pl.DataFrame | pl.LazyFrame, fp: Path, /) -> None
"""


def generate_datasets_typing(application: Application, output: Path, /) -> None:
    """
    Write a ``_typing`` module of ``Literal`` aliases derived from cached metadata.

    Dataset names come from the ``gh_trees`` table (rows with a supported
    extension only, deduplicated and sorted by ``name_js``) and version tags
    from the ``gh_tags`` table; both are read from the parquet files the
    application maintains.  The generated source is linted/formatted before
    being written to ``output``.
    """
    app = application
    tags = app.scan("gh_tags").select("tag").collect().to_series()
    names = (
        app.scan("gh_trees")
        .filter("ext_supported")
        .unique("name_js")
        .select("name_js")
        .sort("name_js")
        .collect()
        .to_series()
    )
    dataset_name = "DatasetName"
    version_tag = "VersionTag"
    extension = "Extension"
    # ``__all__`` and the first alias are emitted as one element so no blank
    # separator is inserted between them by the writer.
    all_and_names = (
        f"__all__ = {[dataset_name, version_tag, extension]}\n\n"
        f"{dataset_name}: TypeAlias = {utils.spell_literal(names)}"
    )
    contents = (
        HEADER_COMMENT,
        "from __future__ import annotations\n",
        "import sys",
        "from typing import Literal, TYPE_CHECKING",
        utils.import_typing_extensions((3, 10), "TypeAlias"),
        "\n",
        all_and_names,
        f"{version_tag}: TypeAlias = {utils.spell_literal(tags)}",
        f'{extension}: TypeAlias = {utils.spell_literal([".csv", ".json", ".tsv", ".arrow"])}',
    )
    ruff.write_lint_format(output, contents)


def is_ext_supported(suffix: str) -> TypeIs[ExtSupported]:
    """Narrow ``suffix`` to ``ExtSupported`` when it is a recognised dataset extension."""
    supported = (".csv", ".json", ".tsv", ".arrow")
    return suffix in supported

Expand Down
137 changes: 137 additions & 0 deletions tools/datasets/_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# The contents of this file are automatically written by
# tools/datasets.__init__.py. Do not modify directly.

from __future__ import annotations

import sys
from typing import Literal

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias


__all__ = ["DatasetName", "Extension", "VersionTag"]

# NOTE(review): this module is produced by generate_datasets_typing(); edit
# the generator, not this file, or changes will be overwritten on refresh.

# Dataset names with a supported file extension — presumably the entries of
# the vega-datasets repository; confirm against the gh_trees metadata source.
DatasetName: TypeAlias = Literal[
    "airports",
    "annual-precip",
    "anscombe",
    "barley",
    "birdstrikes",
    "budget",
    "budgets",
    "burtin",
    "cars",
    "climate",
    "co2-concentration",
    "countries",
    "crimea",
    "disasters",
    "driving",
    "earthquakes",
    "flare",
    "flare-dependencies",
    "flights-10k",
    "flights-200k",
    "flights-20k",
    "flights-2k",
    "flights-3m",
    "flights-5k",
    "flights-airport",
    "football",
    "gapminder",
    "gapminder-health-income",
    "github",
    "global-temp",
    "graticule",
    "income",
    "iowa-electricity",
    "iris",
    "jobs",
    "la-riots",
    "londonBoroughs",
    "londonCentroids",
    "londonTubeLines",
    "lookup_groups",
    "lookup_people",
    "miserables",
    "monarchs",
    "movies",
    "normal-2d",
    "obesity",
    "ohlc",
    "penguins",
    "platformer-terrain",
    "points",
    "political-contributions",
    "population",
    "population_engineers_hurricanes",
    "seattle-temps",
    "seattle-weather",
    "seattle-weather-hourly-normals",
    "sf-temps",
    "sp500",
    "sp500-2000",
    "stocks",
    "udistrict",
    "unemployment",
    "unemployment-across-industries",
    "uniform-2d",
    "us-10m",
    "us-employment",
    "us-state-capitals",
    "volcano",
    "weather",
    "weball26",
    "wheat",
    "windvectors",
    "world-110m",
    "zipcodes",
]
# Release tags known at generation time, newest first.
VersionTag: TypeAlias = Literal[
    "v2.9.0",
    "v2.8.1",
    "v2.8.0",
    "v2.7.0",
    "v2.5.4",
    "v2.5.3",
    "v2.5.3-next.0",
    "v2.5.2",
    "v2.5.2-next.0",
    "v2.5.1",
    "v2.5.1-next.0",
    "v2.5.0",
    "v2.5.0-next.0",
    "v2.4.0",
    "v2.3.1",
    "v2.3.0",
    "v2.1.0",
    "v2.0.0",
    "v1.31.1",
    "v1.31.0",
    "v1.30.4",
    "v1.30.3",
    "v1.30.2",
    "v1.30.1",
    "v1.29.0",
    "v1.24.0",
    "v1.22.0",
    "v1.21.1",
    "v1.21.0",
    "v1.20.0",
    "v1.19.0",
    "v1.18.0",
    "v1.17.0",
    "v1.16.0",
    "v1.15.0",
    "v1.14.0",
    "v1.12.0",
    "v1.11.0",
    "v1.10.0",
    "v1.8.0",
    "v1.7.0",
    "v1.5.0",
]
# File extensions the tooling treats as loadable.
Extension: TypeAlias = Literal[".csv", ".json", ".tsv", ".arrow"]

0 comments on commit 7b0fe29

Please sign in to comment.