Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add config sources #20

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 90 additions & 15 deletions envier/env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from collections import defaultdict
import os
import sys
from typing import Any
from typing import Callable
from typing import DefaultDict
from typing import Dict
from typing import Generic
from typing import Iterator
Expand All @@ -14,6 +17,12 @@
import warnings


if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal


class NoDefaultType(object):
def __str__(self):
return ""
Expand Down Expand Up @@ -80,7 +89,7 @@ def __init__(
self.help_default = help_default

def _retrieve(self, env, prefix):
# type: (Env, str) -> T
# type: (Env, str) -> Tuple[ConfigSourceType, T]
source = env.source

full_name = prefix + _normalized(self.name)
Expand Down Expand Up @@ -114,7 +123,7 @@ def _retrieve(self, env, prefix):

if raw is None:
if not isinstance(self.default, NoDefaultType):
return self.default
return "default", self.default

raise KeyError(
"Mandatory environment variable {} is not set".format(full_name)
Expand All @@ -128,37 +137,37 @@ def _retrieve(self, env, prefix):
type(parsed), self.type
)
)
return parsed
return "environment", parsed # type: ignore[return-value]

if self.type is bool:
return cast(T, raw.lower() in env.__truthy__)
return "environment", cast(T, raw.lower() in env.__truthy__)
elif self.type in (list, tuple, set):
collection = raw.split(env.__item_separator__)
return cast(T, self.type(collection if self.map is None else map(self.map, collection))) # type: ignore[call-arg,arg-type,operator]
return "environment", cast(T, self.type(collection if self.map is None else map(self.map, collection))) # type: ignore[call-arg,arg-type,operator]
elif self.type is dict:
d = dict(
_.split(env.__value_separator__, 1)
for _ in raw.split(env.__item_separator__)
)
if self.map is not None:
d = dict(self.map(*_) for _ in d.items())
return cast(T, d)
return "environment", cast(T, d)

if _check_type(raw, self.type):
return cast(T, raw)
return "environment", cast(T, raw)

if hasattr(self.type, "__origin__") and self.type.__origin__ is Union: # type: ignore[attr-defined,union-attr]
for t in self.type.__args__: # type: ignore[attr-defined,union-attr]
try:
return cast(T, t(raw))
return "environment", cast(T, t(raw))
except TypeError:
pass

return self.type(raw) # type: ignore[call-arg,operator]
return "environment", self.type(raw) # type: ignore[call-arg,operator,return-type]

def __call__(self, env, prefix):
# type: (Env, str) -> T
value = self._retrieve(env, prefix)
# type: (Env, str) -> Tuple[ConfigSourceType, T]
source, value = self._retrieve(env, prefix)

if self.validator is not None:
try:
Expand All @@ -169,7 +178,7 @@ def __call__(self, env, prefix):
"Invalid value for environment variable %s: %s" % (full_name, e)
)

return value
return source, value


class DerivedVariable(Generic[T]):
Expand All @@ -190,6 +199,37 @@ def __call__(self, env):
return value


ConfigSourceType = Literal["default", "environment", "code"]


class _EnvSource(object):
_sentinel = object()

def __init__(self, source_precedence):
# type: (List[ConfigSourceType]) -> None
self._source_precedence = source_precedence
self._sources = {} # type: Dict[ConfigSourceType, Any]
for s in self._source_precedence:
self._sources[s] = self._sentinel

def set_source(self, source, value):
# type: (ConfigSourceType, Any) -> None
self._sources[source] = value

def value(self):
for s in self._source_precedence:
if self._sources[s] is not self._sentinel:
return self._sources[s]
return None

def value_source_type(self):
# type: () -> ConfigSourceType
for s in self._source_precedence:
if self._sources[s] is not self._sentinel:
return s
raise ValueError("No source set for setting")


class Env(object):
"""Env base class.

Expand Down Expand Up @@ -228,7 +268,18 @@ class Env(object):

def __init__(self, source=None, parent=None):
# type: (Optional[Dict[str, str]], Optional[Env]) -> None
self.source = source or os.environ
# Has to come first to avoid issues with __setattr__
self._items = defaultdict(
lambda: _EnvSource(
[
"environment",
"default",
"code",
]
)
) # type: DefaultDict[str, _EnvSource]

self.source = source or dict(os.environ)
self.parent = parent

self._full_prefix = (
Expand All @@ -243,20 +294,44 @@ def __init__(self, source=None, parent=None):
derived = []
for name, e in list(self.__class__.__dict__.items()):
if isinstance(e, EnvVariable):
setattr(self, name, e(self, self._full_prefix))
s, v = e(self, self._full_prefix)
self.set_attr_source_value(name, s, v)
elif isinstance(e, type) and issubclass(e, Env):
if e.__item__ is not None and e.__item__ != name:
# Move the subclass to the __item__ attribute
setattr(self.spec, e.__item__, e)
delattr(self.spec, name)
name = e.__item__
setattr(self, name, e(source, self))
setattr(self, name, e(source=self.source, parent=self))
elif isinstance(e, DerivedVariable):
derived.append((name, e))

for n, d in derived:
setattr(self, n, d(self))

def __setattr__(self, name, value):
if name != "_items":
self._items[name].set_source("code", value)
super(Env, self).__setattr__(name, value)

def set_attr_source_value(self, name, source, value):
# type: (str, ConfigSourceType, Any) -> None
self._items[name].set_source(source, value)
super(Env, self).__setattr__(name, value)

def source_type(self, name):
# type: (str) -> ConfigSourceType
return self._items[name].value_source_type()

def __getitem__(self, item):
# type: (str) -> Any
if item in self._items:
return self._items[item].value()
raise AttributeError(item)

def __contains__(self, item):
return item in self._items

@classmethod
def var(
cls,
Expand Down
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
license="MIT",
packages=find_packages(exclude=["tests*"]),
python_requires=">=2.7",
install_requires=["typing; python_version<'3.5'"],
install_requires=[
"typing; python_version<'3.5'",
"typing_extensions; python_version<'3.8'",
],
extras_require={"mypy": ["mypy"]},
setup_requires=["setuptools_scm"],
use_scm_version=True,
Expand Down
36 changes: 36 additions & 0 deletions tests/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,3 +344,39 @@ def validate(value):
Config()
else:
assert Config().foo == value


def test_getitem():
class Config(Env):
foo = Env.var(int, "FOO", default=42)

assert Config()["foo"] == 42


def test_env_multi_source(monkeypatch):
"""Ensure multiple sources are supported."""

class Config(Env):
foo = Env.var(int, "FOO", default=0)

# Test the default source.
cfg = Config()
assert cfg.foo == 0
assert cfg.source_type("foo") == "default"
cfg.set_attr_source_value("foo", "default", 5)
assert cfg.foo == 5

# Test the environment source.
monkeypatch.setenv("FOO", "1")
cfg = Config()
assert cfg.foo == 1
assert cfg.source_type("foo") == "environment"
cfg.set_attr_source_value("foo", "environment", 6)
assert cfg.foo == 6

# Test the programmatic source.
cfg.foo = 2
assert cfg.foo == 2
assert cfg.source_type("foo") == "code"
cfg.set_attr_source_value("foo", "code", 10)
assert cfg.foo == 10
Loading