Skip to content

Commit

Permalink
Add config sources
Browse files Browse the repository at this point in the history
Add support for setting and storing configuration from additional
sources. This enables a setting to have values coming from a default,
environment and programmatic sources.

The default priority of the sources is (lowest to highest):

1. Default
2. Environment
3. Programmatic

The active source can be retrieved using a new method: `get_source()`:

```python
class Config:
        foo = Env.var(int, "FOO", default=0)

cfg = Config()
assert cfg.foo == 0
assert cfg.source_type("foo") == "default"

os.environ["FOO"] = 123
cfg = Config()
assert cfg.foo == 123
assert cfg.source_type("foo") == "environment"
```
  • Loading branch information
Kyle-Verhoog committed Aug 29, 2023
1 parent 91a8446 commit 2a35499
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 19 deletions.
99 changes: 80 additions & 19 deletions envier/env.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from collections import defaultdict
import os
from typing import Any
from typing import Callable
from typing import DefaultDict
from typing import Dict
from typing import Generic
from typing import Iterator
from typing import List
from typing import Literal
from typing import Optional
from typing import Tuple
from typing import Type
Expand Down Expand Up @@ -80,15 +83,15 @@ def __init__(
self.help_default = help_default

def _retrieve(self, env, prefix):
# type: (Env, str) -> T
source = env.source
# type: (Env, str) -> Tuple[ConfigSourceType, T]
env_source = env.env_source

full_name = prefix + _normalized(self.name)
raw = source.get(full_name)
raw = env_source.get(full_name)
if raw is None and self.deprecations:
for name, deprecated_when, removed_when in self.deprecations:
full_deprecated_name = prefix + _normalized(name)
raw = source.get(full_deprecated_name)
raw = env_source.get(full_deprecated_name)
if raw is not None:
deprecated_when_message = (
" in version %s" % deprecated_when
Expand All @@ -114,12 +117,13 @@ def _retrieve(self, env, prefix):

if raw is None:
if not isinstance(self.default, NoDefaultType):
return self.default
return ("default", self.default)

raise KeyError(
"Mandatory environment variable {} is not set".format(full_name)
)

value = raw # type: Union[T, str]
if self.parser is not None:
parsed = self.parser(raw)
if not _check_type(parsed, self.type):
Expand All @@ -128,37 +132,37 @@ def _retrieve(self, env, prefix):
type(parsed), self.type
)
)
return parsed
value = parsed

if self.type is bool:
return cast(T, raw.lower() in env.__truthy__)
value = cast(T, raw.lower() in env.__truthy__)
elif self.type in (list, tuple, set):
collection = raw.split(env.__item_separator__)
return cast(T, self.type(collection if self.map is None else map(self.map, collection))) # type: ignore[call-arg,arg-type,operator]
value = cast(T, self.type(collection if self.map is None else map(self.map, collection))) # type: ignore[call-arg,arg-type,operator]
elif self.type is dict:
d = dict(
_.split(env.__value_separator__, 1)
for _ in raw.split(env.__item_separator__)
)
if self.map is not None:
d = dict(self.map(*_) for _ in d.items())
return cast(T, d)
value = cast(T, d)

if _check_type(raw, self.type):
return cast(T, raw)
value = cast(T, raw)

if hasattr(self.type, "__origin__") and self.type.__origin__ is Union: # type: ignore[attr-defined,union-attr]
for t in self.type.__args__: # type: ignore[attr-defined,union-attr]
try:
return cast(T, t(raw))
value = cast(T, t(raw))
except TypeError:
pass

return self.type(raw) # type: ignore[call-arg,operator]
return "environment", self.type(value) # type: ignore[call-arg,operator]

def __call__(self, env, prefix):
# type: (Env, str) -> T
value = self._retrieve(env, prefix)
# type: (Env, str) -> Tuple[ConfigSourceType, T]
source, value = self._retrieve(env, prefix)

if self.validator is not None:
try:
Expand All @@ -169,7 +173,7 @@ def __call__(self, env, prefix):
"Invalid value for environment variable %s: %s" % (full_name, e)
)

return value
return source, value


class DerivedVariable(Generic[T]):
Expand All @@ -190,6 +194,37 @@ def __call__(self, env):
return value


ConfigSourceType = Literal["default", "environment", "programmatic"]


class _EnvSource(object):
_sentinel = object()

def __init__(self, source_precedence):
# type: (List[ConfigSourceType]) -> None
self._source_precedence = source_precedence
self._sources = {} # type: Dict[ConfigSourceType, Any]
for s in self._source_precedence:
self._sources[s] = self._sentinel

def set_source(self, source, value):
# type: (ConfigSourceType, Any) -> None
self._sources[source] = value

def value(self):
for s in self._source_precedence:
if self._sources[s] is not self._sentinel:
return s
return None

def value_source_type(self):
# type: () -> ConfigSourceType
for s in self._source_precedence:
if self._sources[s] is not self._sentinel:
return s
raise ValueError("No source set for setting")


class Env(object):
"""Env base class.
Expand Down Expand Up @@ -226,9 +261,20 @@ class Env(object):
__item_separator__ = ","
__value_separator__ = ":"

def __init__(self, source=None, parent=None):
def __init__(self, env_source=None, parent=None):
# type: (Optional[Dict[str, str]], Optional[Env]) -> None
self.source = source or os.environ
# Has to come first to avoid issues with __setattr__
self._items = defaultdict(
lambda: _EnvSource(
[
"programmatic",
"environment",
"default",
]
)
) # type: DefaultDict[str, _EnvSource]

self.env_source = env_source or os.environ
self.parent = parent

self._full_prefix = (
Expand All @@ -243,20 +289,35 @@ def __init__(self, source=None, parent=None):
derived = []
for name, e in list(self.__class__.__dict__.items()):
if isinstance(e, EnvVariable):
setattr(self, name, e(self, self._full_prefix))
source, v = e(self, self._full_prefix)
self.set_attr_source_value(name, source, v)
elif isinstance(e, type) and issubclass(e, Env):
if e.__item__ is not None and e.__item__ != name:
# Move the subclass to the __item__ attribute
setattr(self.spec, e.__item__, e)
delattr(self.spec, name)
name = e.__item__
setattr(self, name, e(source, self))
setattr(self, name, e(env_source, self))
elif isinstance(e, DerivedVariable):
derived.append((name, e))

for n, d in derived:
setattr(self, n, d(self))

def __setattr__(self, name, value):
if name != "_items":
self._items[name].set_source("programmatic", value)
super(Env, self).__setattr__(name, value)

def set_attr_source_value(self, name, source, value):
# type: (str, ConfigSourceType, Any) -> None
self._items[name].set_source(source, value)
super(Env, self).__setattr__(name, value)

def source_type(self, name):
# type: (str) -> ConfigSourceType
return self._items[name].value_source_type()

@classmethod
def var(
cls,
Expand Down
29 changes: 29 additions & 0 deletions tests/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,3 +344,32 @@ def validate(value):
Config()
else:
assert Config().foo == value


def test_env_multi_source(monkeypatch):
"""Ensure multiple sources are supported."""

class Config(Env):
foo = Env.var(int, "FOO", default=0)

# Test the default source.
cfg = Config()
assert cfg.foo == 0
assert cfg.source_type("foo") == "default"
cfg.set_attr_source_value("foo", "default", 5)
assert cfg.foo == 5

# Test the environment source.
monkeypatch.setenv("FOO", "1")
cfg = Config()
assert cfg.foo == 1
assert cfg.source_type("foo") == "environment"
cfg.set_attr_source_value("foo", "environment", 6)
assert cfg.foo == 6

# Test the programmatic source.
cfg.foo = 2
assert cfg.foo == 2
assert cfg.source_type("foo") == "programmatic"
cfg.set_attr_source_value("foo", "programmatic", 10)
assert cfg.foo == 10

0 comments on commit 2a35499

Please sign in to comment.