Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
refactor dependency injection to be more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
thought-tobi committed Apr 2, 2024
1 parent bafa309 commit e293049
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 80 deletions.
32 changes: 13 additions & 19 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@
import os

from telegram import Update
from telegram.ext import (
ApplicationBuilder,
CommandHandler,
CallbackQueryHandler
)
from telegram.ext import ApplicationBuilder, CommandHandler, CallbackQueryHandler

from src.config import ConfigurationProvider, Configuration
import src.repository.user_repository as user_repository
Expand Down Expand Up @@ -55,21 +51,20 @@ def initialize_handlers(self):
def run(self):
self.application.run_polling(allowed_updates=Update.ALL_TYPES)


@autowire
def initialize_notifications(notifier: Notifier) -> None:
"""
Adds reminders to the job queue for all users that have configured reminders.
"""
for user in user_repository.find_all_users():
user_id = user.user_id
notifications = user.notifications
logging.info(f"Setting up notifications for for user {user_id}")
for notification in notifications:
notifier.set_notification(user_id, notification)
@autowire("notifier")
def initialize_notifications(self, notifier: Notifier) -> None:
"""
Adds reminders to the job queue for all users that have configured reminders.
"""
for user in user_repository.find_all_users():
user_id = user.user_id
notifications = user.notifications
logging.info(f"Setting up notifications for for user {user_id}")
for notification in notifications:
notifier.set_notification(user_id, notification)


@autowire
@autowire("configuration")
def refresh_user_configs(configuration: Configuration):
"""
Overwrites the user configurations in the database with the configurations in the config file.
Expand All @@ -88,7 +83,6 @@ def main():
# Create application
refresh_user_configs()
application = MoodTrackerApplication(TOKEN)
initialize_notifications()
application.run()


Expand Down
49 changes: 35 additions & 14 deletions src/autowiring/inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,49 @@
from src.autowiring.injectable import Injectable


def autowire(func):
class ParameterNotInSignatureError(Exception):
def __init__(self, parameter):
self.parameter = parameter
self.message = f"Parameter '{self.parameter}' specified in @autowire is not in the function signature"
super().__init__(self.message)


class ParameterNotInCacheError(Exception):
def __init__(self, parameter):
self.parameter = parameter
self.message = f"Parameter '{self.parameter}' does not exist in the 'di' cache"
super().__init__(self.message)


def autowire(*autowire_params):
"""
Decorator that autowires the parameters of a function.
Decorator that autowires the specified parameters of a function.
Parameters are autowired if they
- Are specified in the decorator arguments
- Are of a class that is a subclass of Injectable
- They exist within kink's dependency injection container
- They do not possess a default value as per the method signature
:param func: decorated function
:param autowire_params: names of parameters to autowire
:return: fully autowired function
"""
sig = inspect.signature(func)

def wrapper(*args, **kwargs):
for name, param in sig.parameters.items():
if (
param.default == param.empty and name not in kwargs
): # Check if the parameter has a default value
def decorator(func):
sig = inspect.signature(func)

def wrapper(*args, **kwargs):
for (
name
) in autowire_params: # Check if the parameter is in the autowire list
if name not in sig.parameters:
raise ParameterNotInSignatureError(name)
param = sig.parameters[name]
param_type = param.annotation
if inspect.isclass(param_type) and issubclass(param_type, Injectable):
fully_qualified_name = param_type.get_fully_qualified_name()
if fully_qualified_name in di:
kwargs[name] = di[fully_qualified_name]
return func(*args, **kwargs)
if fully_qualified_name not in di:
raise ParameterNotInCacheError(fully_qualified_name)
kwargs[name] = di[fully_qualified_name]
return func(*args, **kwargs)

return wrapper

return wrapper
return decorator
7 changes: 2 additions & 5 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ def get_notifications(self) -> list[Notification]:
class ConfigurationProvider:
_configuration: Configuration

def __init__(self):
self.configuration = ConfigurationProvider.load("config.yaml")
def __init__(self, config_file: str = "config.yaml"):
self.configuration = ConfigurationProvider.load(config_file)

@staticmethod
def load(config_file: str) -> Configuration:
Expand All @@ -36,6 +36,3 @@ def load(config_file: str) -> Configuration:

def get_configuration(self) -> Configuration:
return self.configuration


_configuration = ConfigurationProvider().configuration
39 changes: 21 additions & 18 deletions src/handlers/user_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from telegram import Update

from src.autowiring.inject import autowire
from src.config import _configuration
from src.config import Configuration
from src.handlers.util import send
from src.model.notification import Notification
from src.notifier import Notifier
Expand All @@ -18,25 +18,13 @@ async def create_user(update: Update, _) -> None:
:param _: CallbackContext: is irrelevant
:return:
"""
# Declare introduction text.
bullet_point_list = "\n".join(
[f"- {metric.name.capitalize()}" for metric in _configuration.get_metrics()]
)
introduction_text = (
"Hi! You can track your mood with me. "
"Simply type /record to get started. By default, "
f"I will track the following metrics:\n {bullet_point_list}"
)

# Handle registration
user_id = update.effective_user.id
if not user_repository.find_user(user_id):
logging.info(f"Creating user {user_id}")
# todo all of this could use some decoupling
user_repository.create_user(user_id)
for notification in _configuration.get_notifications():
create_notification(user_id=user_id, notification=notification)
await send(update, text=introduction_text)
setup_notifications(user_id)
await send(update, text=introduction_text())
# User already exists
else:
logging.info(f"Received /start, but user {user_id} already exists")
Expand All @@ -47,6 +35,21 @@ async def create_user(update: Update, _) -> None:
)


@autowire
def create_notification(user_id: int, notification: Notification, notifier: Notifier):
notifier.set_notification(user_id, notification)
@autowire("configuration", "notifier")
def setup_notifications(
user_id: int, configuration: Configuration, notifier: Notifier
) -> None:
for notification in configuration.get_notifications():
notifier.set_notification(user_id, notification)


@autowire("configuration")
def introduction_text(configuration: Configuration) -> str:
bullet_point_list = "\n".join(
[f"- {metric.name.capitalize()}" for metric in configuration.get_metrics()]
)
return (
"Hi! You can track your mood with me. "
"Simply type /record to get started. By default, "
f"I will track the following metrics:\n {bullet_point_list}"
)
10 changes: 6 additions & 4 deletions src/repository/user_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import pymongo

from src.autowiring.inject import autowire
from src.config import Configuration
from src.model.metric import Metric
from src.model.notification import Notification
from src.model.user import User
from src.config import _configuration

mongo_url = os.getenv("MONGODB_HOST", "localhost:27017")
mongo_client = pymongo.MongoClient(mongo_url)
Expand All @@ -29,14 +30,15 @@ def parse_user(result: dict) -> User:
return User(**result)


def create_user(user_id: int) -> None:
@autowire("configuration")
def create_user(user_id: int, configuration: Configuration) -> None:
user.insert_one(
{
"user_id": user_id,
"metrics": [metric.model_dump() for metric in _configuration.get_metrics()],
"metrics": [metric.model_dump() for metric in configuration.get_metrics()],
"notifications": [
notification.model_dump()
for notification in _configuration.get_notifications()
for notification in configuration.get_notifications()
],
}
)
Expand Down
22 changes: 22 additions & 0 deletions test/config.test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
metrics:
- name: mood
user_prompt: "How do you feel right now?"
emoji: true # I'm not brave enough to put emojis in a YAML file
values:
":zany_face:": 3
":grinning_face_with_smiling_eyes:": 2
":slightly_smiling_face:": 1
":face_without_mouth:": 0
":slightly_frowning_face:": -1
":frowning_face:": -2
":skull:": -3
- name: sleep
user_prompt: "How much sleep did you get today?"
type: numeric
values:
lower_bound: 4
upper_bound: 12

notifications:
- text: "It's time to record your mood!"
time: "18:00"
12 changes: 2 additions & 10 deletions test/test_autowiring.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,15 @@ def setup():


def test_trivial_autowiring():
@autowire
@autowire("test_class")
def test_func(test_class: SomeClass):
return test_class.field

assert test_func() == "some-value"


def test_autowiring_with_default_value():
@autowire
def test_func(test_class: SomeClass = SomeClass("another-value")):
return test_class.field

assert test_func() == "another-value"


def test_autowiring_with_args():
@autowire
@autowire("test_class")
def test_func(string: str, test_class: SomeClass):
return test_class.field + string

Expand Down
23 changes: 13 additions & 10 deletions test/test_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
import pytest
from expiringdict import ExpiringDict

from src.config import ConfigurationProvider
import src.handlers.record_handlers as command_handlers
import src.repository.record_repository as record_repository
from src.config import _configuration
from src.handlers.record_handlers import create_temporary_record, button
from src.handlers.user_handlers import create_user
from src.model.metric import Metric
from src.model.user import User

expiry_time = 1
Expand Down Expand Up @@ -51,8 +52,13 @@ def patch_command_handler_methods():


@pytest.fixture
def user() -> User:
return User(user_id=1, metrics=test_metrics, notifications=[])
def metrics() -> list[Metric]:
return ConfigurationProvider('test/config.test.yaml').get_configuration().get_metrics()


@pytest.fixture
def user(metrics) -> User:
return User(user_id=1, metrics=metrics, notifications=[])


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -103,13 +109,13 @@ async def test_record_registration(button_update, update):


@pytest.mark.asyncio
async def test_finish_record_creation(update, button_update, mocker, user):
async def test_finish_record_creation(update, button_update, mocker, user, metrics):
"""
Tests state transition from recording Metric N to Finished.
"""
# given only one metric is defined
mocker.patch("src.repository.user_repository.find_user", return_value=user)
user.metrics = [test_metrics[0]]
user.metrics = [metrics[0]]

# when user calls /record
await command_handlers.record_handler(update, None)
Expand Down Expand Up @@ -167,9 +173,6 @@ async def test_record_with_offset(update):

# then the temp record's timestamp should be offset by 1 day
assert (
command_handlers.get_temp_record(1).timestamp.day
== (datetime.datetime.now() - datetime.timedelta(days=1)).day
command_handlers.get_temp_record(1).timestamp.day
== (datetime.datetime.now() - datetime.timedelta(days=1)).day
)


test_metrics = _configuration.get_metrics()
6 changes: 6 additions & 0 deletions test/test_user_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from src.app import MoodTrackerApplication
from kink import di

from src.config import ConfigurationProvider
from src.handlers.user_handlers import create_user
from src.notifier import Notifier

Expand All @@ -25,6 +26,11 @@ def mock_notifier():
assert di[Notifier.get_fully_qualified_name()] is not None


@pytest.fixture(autouse=True)
def configuration():
ConfigurationProvider().load("test/config.test.yaml").register()


def test_querying_nonexistent_user_returns_none():
assert user_repository.find_user(1) is None

Expand Down

0 comments on commit e293049

Please sign in to comment.