diff --git a/envier/env.py b/envier/env.py index e9ee301..9e9100a 100644 --- a/envier/env.py +++ b/envier/env.py @@ -67,6 +67,35 @@ def __init__( self.help_type = help_type self.help_default = help_default + def _cast(self, _type: t.Any, raw: str, env: "Env") -> t.Any: + if _type is bool: + return t.cast(T, raw.lower() in env.__truthy__) + elif _type in (list, tuple, set): + collection = raw.split(env.__item_separator__) + return t.cast( + T, + _type( # type: ignore[operator] + collection if self.map is None else map(self.map, collection) # type: ignore[arg-type] + ), + ) + elif _type is dict: + d = dict( + _.split(env.__value_separator__, 1) + for _ in raw.split(env.__item_separator__) + ) + if self.map is not None: + d = dict(self.map(*_) for _ in d.items()) + return t.cast(T, d) + + if _check_type(raw, _type): + return t.cast(T, raw) + + try: + return _type(raw) + except Exception as e: + msg = f"cannot cast {raw} to {self.type}" + raise TypeError(msg) from e + def _retrieve(self, env: "Env", prefix: str) -> T: source = env.source @@ -121,36 +150,14 @@ def _retrieve(self, env: "Env", prefix: str) -> T: ) return parsed - if self.type is bool: - return t.cast(T, raw.lower() in env.__truthy__) - elif self.type in (list, tuple, set): - collection = raw.split(env.__item_separator__) - return t.cast( - T, - self.type( # type: ignore[operator] - collection if self.map is None else map(self.map, collection) # type: ignore[arg-type] - ), - ) - elif self.type is dict: - d = dict( - _.split(env.__value_separator__, 1) - for _ in raw.split(env.__item_separator__) - ) - if self.map is not None: - d = dict(self.map(*_) for _ in d.items()) - return t.cast(T, d) - - if _check_type(raw, self.type): - return t.cast(T, raw) - if hasattr(self.type, "__origin__") and self.type.__origin__ is t.Union: # type: ignore[attr-defined,union-attr] for ot in self.type.__args__: # type: ignore[attr-defined,union-attr] try: - return t.cast(T, ot(raw)) + return t.cast(T, self._cast(ot, raw, env)) except TypeError: pass - return self.type(raw) # type: ignore[call-arg,operator] + return self._cast(self.type, raw, env) def __call__(self, env: "Env", prefix: str) -> T: value = self._retrieve(env, prefix) @@ -436,9 +443,11 @@ def add_entries(full_prefix: str, config: t.Type[Env]) -> None: ( f"``{private_prefix}{full_prefix}{_normalized(v.name)}``", help_type, # type: ignore[attr-defined] - v.help_default - if v.help_default is not None - else str(v.default), + ( + v.help_default + if v.help_default is not None + else str(v.default) + ), help_message, ) ) diff --git a/tests/test_env.py b/tests/test_env.py index 4374a00..15ff70d 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -284,11 +284,16 @@ class DictConfig(Env): assert DictConfig().foo == expected -def test_env_optional_default(): +def test_env_optional_default(monkeypatch): class DictConfig(Env): foo = Env.var(Optional[str], "foo", default=None) + bar = Env.var(Optional[bool], "bar", default=None) assert DictConfig().foo is None + assert DictConfig().bar is None + + monkeypatch.setenv("BAR", "0") + assert not DictConfig().bar @pytest.mark.parametrize("value,_type", [(1, int), ("1", str)])