Skip to content

Commit

Permalink
feat: add items iterator
Browse files Browse the repository at this point in the history
We extend the public API with an iterator that allows retrieving
all the configuration items. This returns a mapping between
attribute paths and environment variable instances. We also attach
the full variable name to each EnvVariable instance, which is
accessible via the ``full_name`` property.
  • Loading branch information
P403n1x87 committed Oct 21, 2024
1 parent ba018d7 commit fd66b90
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 38 deletions.
120 changes: 84 additions & 36 deletions envier/env.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import deque
from collections import namedtuple
import os
import typing as t
Expand Down Expand Up @@ -68,6 +69,12 @@ def __init__(
self.help_type = help_type
self.help_default = help_default

self._full_name = _normalized(name) # Will be set by the EnvMeta metaclass

@property
def full_name(self) -> str:
return f"_{self._full_name}" if self.private else self._full_name

def _cast(self, _type: t.Any, raw: str, env: "Env") -> t.Any:
if _type is bool:
return t.cast(T, raw.lower() in env.__truthy__)
Expand Down Expand Up @@ -100,9 +107,7 @@ def _cast(self, _type: t.Any, raw: str, env: "Env") -> t.Any:
def _retrieve(self, env: "Env", prefix: str) -> T:
source = env.source

full_name = prefix + _normalized(self.name)
if self.private:
full_name = f"_{full_name}"
full_name = self.full_name
raw = source.get(full_name.format(**env.dynamic))
if raw is None and self.deprecations:
for name, deprecated_when, removed_when in self.deprecations:
Expand Down Expand Up @@ -167,10 +172,8 @@ def __call__(self, env: "Env", prefix: str) -> T:
try:
self.validator(value)
except ValueError as e:
full_name = prefix + _normalized(self.name)
raise ValueError(
"Invalid value for environment variable %s: %s" % (full_name, e)
)
msg = f"Invalid value for environment variable {self.full_name}: {e}"
raise ValueError(msg)

return value

Expand All @@ -191,7 +194,22 @@ def __call__(self, env: "Env") -> T:
return value


class Env(object):
class EnvMeta(type):
def __new__(
cls, name: str, bases: t.Tuple[t.Type], ns: t.Dict[str, t.Any]
) -> t.Any:
env = t.cast("Env", super().__new__(cls, name, bases, ns))

prefix = ns.get("__prefix__")
if prefix:
for v in env.values(recursive=True):
if isinstance(v, EnvVariable):
v._full_name = f"{_normalized(prefix)}_{v._full_name}".upper()

return env


class Env(metaclass=EnvMeta):
"""Env base class.
This class is meant to be subclassed. The configuration is declared by using
Expand Down Expand Up @@ -336,26 +354,42 @@ def d(
return DerivedVariable(type, derivation)

@classmethod
def keys(cls) -> t.Iterator[str]:
"""Return the names of all the items."""
return (
k
for k, v in cls.__dict__.items()
if isinstance(v, (EnvVariable, DerivedVariable))
or isinstance(v, type)
and issubclass(v, Env)
)
def items(
cls, recursive: bool = False, include_derived: bool = False
) -> t.Iterator[t.Tuple[str, t.Union[EnvVariable, DerivedVariable]]]:
classes = (EnvVariable, DerivedVariable) if include_derived else (EnvVariable,)
q: t.Deque[t.Tuple[t.Tuple[str], t.Type["Env"]]] = deque()
path: t.Tuple[str] = tuple() # type: ignore[assignment]
q.append((path, cls))
while q:
path, env = q.popleft()
for k, v in env.__dict__.items():
if isinstance(v, classes):
yield (
".".join((*path, k)),
t.cast(t.Union[EnvVariable, DerivedVariable], v),
)
elif isinstance(v, type) and issubclass(v, Env) and recursive:
item_name = getattr(v, "__item__", k)
if item_name is None:
item_name = k
q.append(((*path, item_name), v)) # type: ignore[arg-type]

@classmethod
def values(cls) -> t.Iterator[t.Union[EnvVariable, DerivedVariable, t.Type["Env"]]]:
"""Return the names of all the items."""
return (
v
for v in cls.__dict__.values()
if isinstance(v, (EnvVariable, DerivedVariable))
or isinstance(v, type)
and issubclass(v, Env)
)
def keys(
cls, recursive: bool = False, include_derived: bool = False
) -> t.Iterator[str]:
"""Return the name of all the configuration items."""
for k, _ in cls.items(recursive, include_derived):
yield k

@classmethod
def values(
cls, recursive: bool = False, include_derived: bool = False
) -> t.Iterator[t.Union[EnvVariable, DerivedVariable, t.Type["Env"]]]:
"""Return the value of all the configuration items."""
for _, v in cls.items(recursive, include_derived):
yield v

@classmethod
def include(
Expand All @@ -371,14 +405,6 @@ def include(
operation would result in some variables being overwritten. This can
be disabled by setting the ``overwrite`` argument to ``True``.
"""
if namespace is not None:
if not overwrite and hasattr(cls, namespace):
raise ValueError("Namespace already in use: {}".format(namespace))

setattr(cls, namespace, env_spec)

return None

# Pick only the attributes that define variables.
to_include = {
k: v
Expand All @@ -387,14 +413,36 @@ def include(
or isinstance(v, type)
and issubclass(v, Env)
}

if not overwrite:
overlap = set(cls.__dict__.keys()) & set(to_include.keys())
if overlap:
raise ValueError("Configuration clashes detected: {}".format(overlap))

own_prefix = _normalized(getattr(cls, "__prefix__", ""))

if namespace is not None:
if not overwrite and hasattr(cls, namespace):
raise ValueError("Namespace already in use: {}".format(namespace))

if getattr(cls, namespace, None) is not env_spec:
setattr(cls, namespace, env_spec)

if own_prefix:
for _, v in to_include.items():
if isinstance(v, EnvVariable):
v._full_name = f"{own_prefix}_{v._full_name}"

return None

other_prefix = getattr(env_spec, "__prefix__", "")
for k, v in to_include.items():
setattr(cls, k, v)
if getattr(cls, k, None) is not v:
setattr(cls, k, v)
if isinstance(v, EnvVariable):
if other_prefix:
v._full_name = v._full_name[len(other_prefix) + 1 :] # noqa
if own_prefix:
v._full_name = f"{own_prefix}_{v._full_name}"

@classmethod
def help_info(
Expand Down
50 changes: 48 additions & 2 deletions tests/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,12 @@ class GlobalConfig(Env):
service = ServiceConfig

config = GlobalConfig()
assert set(config.keys()) == {"debug_mode", "service"}
assert set(config.keys()) == {"debug_mode"}
assert set(config.keys(recursive=True)) == {
"debug_mode",
"service.host",
"service.port",
}
assert config.service.port == 8080


Expand All @@ -178,11 +183,23 @@ class ServiceConfig(Env):

host = Env.var(str, "host", default="localhost")
port = Env.var(int, "port", default=3000)
_private = Env.var(int, "private", default=42, private=True)

config = GlobalConfig()
assert set(config.keys()) == {"debug_mode", "service"}
assert set(config.keys()) == {"debug_mode"}
assert set(config.keys(recursive=True)) == {
"debug_mode",
"service.host",
"service.port",
"service._private",
}
assert config.service.port == 8080

assert GlobalConfig.debug_mode.full_name == "MYAPP_DEBUG"
assert GlobalConfig.service.host.full_name == "MYAPP_SERVICE_HOST"
assert GlobalConfig.service.port.full_name == "MYAPP_SERVICE_PORT"
assert GlobalConfig.service._private.full_name == "_MYAPP_SERVICE_PRIVATE"


def test_env_include():
class GlobalConfig(Env):
Expand Down Expand Up @@ -383,3 +400,32 @@ class Config(Env):
("_PRIVATE_FOO", "int", "42", ""),
("PUBLIC_FOO", "int", "42", ""),
}

assert Config.private.full_name == "_PRIVATE_FOO"


def test_env_items(monkeypatch):
monkeypatch.setenv("MYAPP_SERVICE_PORT", "8080")

class GlobalConfig(Env):
__prefix__ = "myapp"

debug_mode = Env.var(bool, "debug", default=False)

class ServiceConfig(Env):
__item__ = __prefix__ = "service"

host = Env.var(str, "host", default="localhost")
port = Env.var(int, "port", default=3000)
_private = Env.var(int, "private", default=42, private=True)

items = list(GlobalConfig.items())
assert items == [("debug_mode", GlobalConfig.debug_mode)]

items = list(GlobalConfig.items(recursive=True))
assert items == [
("debug_mode", GlobalConfig.debug_mode),
("service.host", GlobalConfig.ServiceConfig.host),
("service.port", GlobalConfig.ServiceConfig.port),
("service._private", GlobalConfig.ServiceConfig._private),
]

0 comments on commit fd66b90

Please sign in to comment.