From 04a7c1343c4ef82c0fe8a228b16cc6817c74a573 Mon Sep 17 00:00:00 2001 From: Kyle Verhoog Date: Tue, 29 Aug 2023 16:30:37 -0400 Subject: [PATCH 1/3] Add config sources Add support for setting and storing configuration from additional sources. This enables a setting to have values coming from a default, environment and programmatic sources. The default priority of the sources is (lowest to highest): 1. Default 2. Environment 3. Programmatic The active source can be retrieved using a new method: `get_source()`: ```python class Config: foo = Env.var(int, "FOO", default=0) cfg = Config() assert cfg.foo == 0 assert cfg.source_type("foo") == "default" os.environ["FOO"] = 123 cfg = Config() assert cfg.foo == 123 assert cfg.source_type("foo") == "environment" ``` --- envier/env.py | 109 ++++++++++++++++++++++++++++++++++++++-------- tests/test_env.py | 36 +++++++++++++++ 2 files changed, 126 insertions(+), 19 deletions(-) diff --git a/envier/env.py b/envier/env.py index bb42f1a..2d39431 100644 --- a/envier/env.py +++ b/envier/env.py @@ -1,10 +1,13 @@ +from collections import defaultdict import os from typing import Any from typing import Callable +from typing import DefaultDict from typing import Dict from typing import Generic from typing import Iterator from typing import List +from typing import Literal from typing import Optional from typing import Tuple from typing import Type @@ -80,15 +83,15 @@ def __init__( self.help_default = help_default def _retrieve(self, env, prefix): - # type: (Env, str) -> T - source = env.source + # type: (Env, str) -> Tuple[ConfigSourceType, T] + env_source = env.env_source full_name = prefix + _normalized(self.name) - raw = source.get(full_name) + raw = env_source.get(full_name) if raw is None and self.deprecations: for name, deprecated_when, removed_when in self.deprecations: full_deprecated_name = prefix + _normalized(name) - raw = source.get(full_deprecated_name) + raw = env_source.get(full_deprecated_name) if raw is not None: deprecated_when_message = ( " in version %s" % deprecated_when @@ -114,12 +117,13 @@ def _retrieve(self, env, prefix): if raw is None: if not isinstance(self.default, NoDefaultType): - return self.default + return "default", self.default raise KeyError( "Mandatory environment variable {} is not set".format(full_name) ) + value = raw # type: Union[T, str] if self.parser is not None: parsed = self.parser(raw) if not _check_type(parsed, self.type): @@ -128,13 +132,13 @@ def _retrieve(self, env, prefix): type(parsed), self.type ) ) - return parsed + return "environment", parsed # type: ignore[return-value] if self.type is bool: - return cast(T, raw.lower() in env.__truthy__) + value = cast(T, raw.lower() in env.__truthy__) elif self.type in (list, tuple, set): collection = raw.split(env.__item_separator__) - return cast(T, self.type(collection if self.map is None else map(self.map, collection))) # type: ignore[call-arg,arg-type,operator] + value = cast(T, self.type(collection if self.map is None else map(self.map, collection))) # type: ignore[call-arg,arg-type,operator] elif self.type is dict: d = dict( _.split(env.__value_separator__, 1) @@ -142,23 +146,24 @@ def _retrieve(self, env, prefix): ) if self.map is not None: d = dict(self.map(*_) for _ in d.items()) - return cast(T, d) + value = cast(T, d) if _check_type(raw, self.type): - return cast(T, raw) + value = cast(T, raw) if hasattr(self.type, "__origin__") and self.type.__origin__ is Union: # type: ignore[attr-defined,union-attr] for t in self.type.__args__: # type: ignore[attr-defined,union-attr] try: - return cast(T, t(raw)) + value = cast(T, t(raw)) except TypeError: pass - return self.type(raw) # type: ignore[call-arg,operator] + return "environment", value # type: ignore[call-arg,operator] + # return "environment", self.type(value) # type: ignore[call-arg,operator] def __call__(self, env, prefix): - # type: (Env, str) -> T - value = self._retrieve(env, prefix) + # type: (Env, str) -> Tuple[ConfigSourceType, T] + source, value = self._retrieve(env, prefix) if self.validator is not None: try: @@ -169,7 +174,7 @@ def __call__(self, env, prefix): "Invalid value for environment variable %s: %s" % (full_name, e) ) - return value + return source, value class DerivedVariable(Generic[T]): @@ -190,6 +195,37 @@ def __call__(self, env): return value +ConfigSourceType = Literal["default", "environment", "programmatic"] + + +class _EnvSource(object): + _sentinel = object() + + def __init__(self, source_precedence): + # type: (List[ConfigSourceType]) -> None + self._source_precedence = source_precedence + self._sources = {} # type: Dict[ConfigSourceType, Any] + for s in self._source_precedence: + self._sources[s] = self._sentinel + + def set_source(self, source, value): + # type: (ConfigSourceType, Any) -> None + self._sources[source] = value + + def value(self): + for s in self._source_precedence: + if self._sources[s] is not self._sentinel: + return self._sources[s] + return None + + def value_source_type(self): + # type: () -> ConfigSourceType + for s in self._source_precedence: + if self._sources[s] is not self._sentinel: + return s + raise ValueError("No source set for setting") + + class Env(object): """Env base class. @@ -226,9 +262,20 @@ class Env(object): __item_separator__ = "," __value_separator__ = ":" - def __init__(self, source=None, parent=None): + def __init__(self, env_source=None, parent=None): # type: (Optional[Dict[str, str]], Optional[Env]) -> None - self.source = source or os.environ + # Has to come first to avoid issues with __setattr__ + self._items = defaultdict( + lambda: _EnvSource( + [ + "programmatic", + "environment", + "default", + ] + ) + ) # type: DefaultDict[str, _EnvSource] + + self.env_source = env_source or os.environ self.parent = parent self._full_prefix = ( @@ -243,20 +290,44 @@ def __init__(self, source=None, parent=None): derived = [] for name, e in list(self.__class__.__dict__.items()): if isinstance(e, EnvVariable): - setattr(self, name, e(self, self._full_prefix)) + source, v = e(self, self._full_prefix) + self.set_attr_source_value(name, source, v) elif isinstance(e, type) and issubclass(e, Env): if e.__item__ is not None and e.__item__ != name: # Move the subclass to the __item__ attribute setattr(self.spec, e.__item__, e) delattr(self.spec, name) name = e.__item__ - setattr(self, name, e(source, self)) + setattr(self, name, e(env_source, self)) elif isinstance(e, DerivedVariable): derived.append((name, e)) for n, d in derived: setattr(self, n, d(self)) + def __setattr__(self, name, value): + if name != "_items": + self._items[name].set_source("programmatic", value) + super(Env, self).__setattr__(name, value) + + def set_attr_source_value(self, name, source, value): + # type: (str, ConfigSourceType, Any) -> None + self._items[name].set_source(source, value) + super(Env, self).__setattr__(name, value) + + def source_type(self, name): + # type: (str) -> ConfigSourceType + return self._items[name].value_source_type() + + def __getitem__(self, item): + # type: (str) -> Any + if item in self._items: + return self._items[item].value() + raise AttributeError(item) + + def __contains__(self, item): + return item in self._items + @classmethod def var( cls, diff --git a/tests/test_env.py b/tests/test_env.py index 1e77a73..435ad25 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -344,3 +344,39 @@ def validate(value): Config() else: assert Config().foo == value + + +def test_getitem(): + class Config(Env): + foo = Env.var(int, "FOO", default=42) + + assert Config()["foo"] == 42 + + +def test_env_multi_source(monkeypatch): + """Ensure multiple sources are supported.""" + + class Config(Env): + foo = Env.var(int, "FOO", default=0) + + # Test the default source. + cfg = Config() + assert cfg.foo == 0 + assert cfg.source_type("foo") == "default" + cfg.set_attr_source_value("foo", "default", 5) + assert cfg.foo == 5 + + # Test the environment source. + monkeypatch.setenv("FOO", "1") + cfg = Config() + assert cfg.foo == 1 + assert cfg.source_type("foo") == "environment" + cfg.set_attr_source_value("foo", "environment", 6) + assert cfg.foo == 6 + + # Test the programmatic source. + cfg.foo = 2 + assert cfg.foo == 2 + assert cfg.source_type("foo") == "programmatic" + cfg.set_attr_source_value("foo", "programmatic", 10) + assert cfg.foo == 10 From a7a94c4d442cd525d2973eccb4529370c7c23fd4 Mon Sep 17 00:00:00 2001 From: Kyle Verhoog Date: Wed, 25 Oct 2023 13:54:53 -0400 Subject: [PATCH 2/3] simplify --- envier/env.py | 38 +++++++++++++++++++++----------------- setup.py | 5 ++++- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/envier/env.py b/envier/env.py index 2d39431..dd73cea 100644 --- a/envier/env.py +++ b/envier/env.py @@ -1,5 +1,6 @@ from collections import defaultdict import os +import sys from typing import Any from typing import Callable from typing import DefaultDict @@ -7,7 +8,6 @@ from typing import Generic from typing import Iterator from typing import List -from typing import Literal from typing import Optional from typing import Tuple from typing import Type @@ -17,6 +17,12 @@ import warnings +if sys.version_info < (3, 8): + from typing_extensions import Literal +else: + from typing import Literal + + class NoDefaultType(object): def __str__(self): return "" @@ -84,14 +90,14 @@ def __init__( def _retrieve(self, env, prefix): # type: (Env, str) -> Tuple[ConfigSourceType, T] - env_source = env.env_source + source = env.source full_name = prefix + _normalized(self.name) - raw = env_source.get(full_name) + raw = source.get(full_name) if raw is None and self.deprecations: for name, deprecated_when, removed_when in self.deprecations: full_deprecated_name = prefix + _normalized(name) - raw = env_source.get(full_deprecated_name) + raw = source.get(full_deprecated_name) if raw is not None: deprecated_when_message = ( " in version %s" % deprecated_when @@ -123,7 +129,6 @@ def _retrieve(self, env, prefix): "Mandatory environment variable {} is not set".format(full_name) ) - value = raw # type: Union[T, str] if self.parser is not None: parsed = self.parser(raw) if not _check_type(parsed, self.type): @@ -135,10 +140,10 @@ def _retrieve(self, env, prefix): return "environment", parsed # type: ignore[return-value] if self.type is bool: - value = cast(T, raw.lower() in env.__truthy__) + return "environment", cast(T, raw.lower() in env.__truthy__) elif self.type in (list, tuple, set): collection = raw.split(env.__item_separator__) - value = cast(T, self.type(collection if self.map is None else map(self.map, collection))) # type: ignore[call-arg,arg-type,operator] + return "environment", cast(T, self.type(collection if self.map is None else map(self.map, collection))) # type: ignore[call-arg,arg-type,operator] elif self.type is dict: d = dict( _.split(env.__value_separator__, 1) @@ -146,20 +151,19 @@ def _retrieve(self, env, prefix): ) if self.map is not None: d = dict(self.map(*_) for _ in d.items()) - value = cast(T, d) + return "environment", cast(T, d) if _check_type(raw, self.type): - value = cast(T, raw) + return "environment", cast(T, raw) if hasattr(self.type, "__origin__") and self.type.__origin__ is Union: # type: ignore[attr-defined,union-attr] for t in self.type.__args__: # type: ignore[attr-defined,union-attr] try: - value = cast(T, t(raw)) + return "environment", cast(T, t(raw)) except TypeError: pass - return "environment", value # type: ignore[call-arg,operator] - # return "environment", self.type(value) # type: ignore[call-arg,operator] + return "environment", self.type(raw) # type: ignore[call-arg,operator,return-type] def __call__(self, env, prefix): # type: (Env, str) -> Tuple[ConfigSourceType, T] @@ -262,7 +266,7 @@ class Env(object): __item_separator__ = "," __value_separator__ = ":" - def __init__(self, env_source=None, parent=None): + def __init__(self, source=None, parent=None): # type: (Optional[Dict[str, str]], Optional[Env]) -> None # Has to come first to avoid issues with __setattr__ self._items = defaultdict( @@ -275,7 +279,7 @@ def __init__(self, env_source=None, parent=None): ) ) # type: DefaultDict[str, _EnvSource] - self.env_source = env_source or os.environ + self.source = source or dict(os.environ) self.parent = parent self._full_prefix = ( @@ -290,15 +294,15 @@ def __init__(self, env_source=None, parent=None): derived = [] for name, e in list(self.__class__.__dict__.items()): if isinstance(e, EnvVariable): - source, v = e(self, self._full_prefix) - self.set_attr_source_value(name, source, v) + s, v = e(self, self._full_prefix) + self.set_attr_source_value(name, s, v) elif isinstance(e, type) and issubclass(e, Env): if e.__item__ is not None and e.__item__ != name: # Move the subclass to the __item__ attribute setattr(self.spec, e.__item__, e) delattr(self.spec, name) name = e.__item__ - setattr(self, name, e(env_source, self)) + setattr(self, name, e(source=self.source, parent=self)) elif isinstance(e, DerivedVariable): derived.append((name, e)) diff --git a/setup.py b/setup.py index 696c5ef..95e3b22 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,10 @@ license="MIT", packages=find_packages(exclude=["tests*"]), python_requires=">=2.7", - install_requires=["typing; python_version<'3.5'"], + install_requires=[ + "typing; python_version<'3.5'", + "typing_extensions; python_version<'3.8'", + ], extras_require={"mypy": ["mypy"]}, setup_requires=["setuptools_scm"], use_scm_version=True, From 563021236737cbad9b3f0865cc6a87452191f7a9 Mon Sep 17 00:00:00 2001 From: Kyle Verhoog Date: Fri, 27 Oct 2023 10:26:56 -0400 Subject: [PATCH 3/3] use "code" instead of "programmatic" --- envier/env.py | 6 +++--- tests/test_env.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/envier/env.py b/envier/env.py index dd73cea..9a189e0 100644 --- a/envier/env.py +++ b/envier/env.py @@ -199,7 +199,7 @@ def __call__(self, env): return value -ConfigSourceType = Literal["default", "environment", "programmatic"] +ConfigSourceType = Literal["default", "environment", "code"] class _EnvSource(object): @@ -272,9 +272,9 @@ def __init__(self, source=None, parent=None): self._items = defaultdict( lambda: _EnvSource( [ - "programmatic", "environment", "default", + "code", ] ) ) # type: DefaultDict[str, _EnvSource] @@ -311,7 +311,7 @@ def __init__(self, source=None, parent=None): def __setattr__(self, name, value): if name != "_items": - self._items[name].set_source("programmatic", value) + self._items[name].set_source("code", value) super(Env, self).__setattr__(name, value) def set_attr_source_value(self, name, source, value): diff --git a/tests/test_env.py b/tests/test_env.py index 435ad25..44e484e 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -377,6 +377,6 @@ class Config(Env): # Test the programmatic source. cfg.foo = 2 assert cfg.foo == 2 - assert cfg.source_type("foo") == "programmatic" - cfg.set_attr_source_value("foo", "programmatic", 10) + assert cfg.source_type("foo") == "code" + cfg.set_attr_source_value("foo", "code", 10) assert cfg.foo == 10