diff --git a/engine/apps/api/tests/test_user.py b/engine/apps/api/tests/test_user.py index 900a48eb23..b89a0ba340 100644 --- a/engine/apps/api/tests/test_user.py +++ b/engine/apps/api/tests/test_user.py @@ -1,5 +1,5 @@ from unittest import mock -from unittest.mock import Mock, patch +from unittest.mock import Mock, PropertyMock, patch import pytest from django.core.cache import cache @@ -1775,17 +1775,10 @@ def test_invalid_working_hours( @patch("apps.phone_notifications.phone_backend.PhoneBackend.send_verification_sms", return_value=Mock()) @patch("apps.phone_notifications.phone_backend.PhoneBackend.verify_phone_number", return_value=True) -@patch( - "apps.api.throttlers.GetPhoneVerificationCodeThrottlerPerUser.get_throttle_limits", - return_value=(1, 10 * 60), -) -@patch("apps.api.throttlers.VerifyPhoneNumberThrottlerPerUser.get_throttle_limits", return_value=(1, 10 * 60)) @pytest.mark.django_db def test_phone_number_verification_flow_ratelimit_per_user( mock_verification_start, mocked_verification_check, - mocked_get_phone_verification_code_get_throttle_limits, - mocked_get_phone_verify_phone_number_limits, make_organization_and_user_with_plugin_token, make_user_auth_headers, ): @@ -1794,40 +1787,44 @@ def test_phone_number_verification_flow_ratelimit_per_user( client = APIClient() url = reverse("api-internal:user-get-verification-code", kwargs={"pk": user.public_primary_key}) - # first get_verification_code request is succesfull - response = client.get(url, format="json", **make_user_auth_headers(user, token)) - assert response.status_code == status.HTTP_200_OK + with patch( + "apps.api.throttlers.GetPhoneVerificationCodeThrottlerPerUser.rate", + new_callable=PropertyMock, + ) as mocked_rate: + mocked_rate.return_value = "1/10m" + # first get_verification_code request is succesfull + response = client.get(url, format="json", **make_user_auth_headers(user, token)) + assert response.status_code == status.HTTP_200_OK - # second get_verification_code request is ratelimited - response = client.get(url, format="json", **make_user_auth_headers(user, token)) - assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS + # second get_verification_code request is ratelimited + response = client.get(url, format="json", **make_user_auth_headers(user, token)) + assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key}) - # first verify_number request is succesfull, because it uses different ratelimit scope - response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token)) - assert response.status_code == status.HTTP_200_OK + with patch( + "apps.api.throttlers.VerifyPhoneNumberThrottlerPerUser.rate", + new_callable=PropertyMock, + ) as mocked_rate: + mocked_rate.return_value = "1/10m" - url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key}) + # first verify_number request is succesfull, because it uses different ratelimit scope + response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token)) + assert response.status_code == status.HTTP_200_OK - # second verify_number request is succesfull, because it ratelimited - response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token)) - assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS + url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key}) + + # second verify_number request is succesfull, because it ratelimited + response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token)) + assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS @patch("apps.phone_notifications.phone_backend.PhoneBackend.send_verification_sms", return_value=Mock()) @patch("apps.phone_notifications.phone_backend.PhoneBackend.verify_phone_number", return_value=True) -@patch( - "apps.api.throttlers.GetPhoneVerificationCodeThrottlerPerOrg.get_throttle_limits", - return_value=(1, 10 * 60), -) -@patch("apps.api.throttlers.VerifyPhoneNumberThrottlerPerOrg.get_throttle_limits", return_value=(1, 10 * 60)) @pytest.mark.django_db def test_phone_number_verification_flow_ratelimit_per_org( mock_verification_start, mocked_verification_check, - mocked_get_phone_verification_code_get_throttle_limits, - mocked_get_phone_verify_phone_number_limits, make_organization_and_user_with_plugin_token, make_user_auth_headers, make_user_for_organization, @@ -1841,21 +1838,33 @@ def test_phone_number_verification_flow_ratelimit_per_org( client = APIClient() - url = reverse("api-internal:user-get-verification-code", kwargs={"pk": user.public_primary_key}) - response = client.get(url, format="json", **make_user_auth_headers(user, token)) - assert response.status_code == status.HTTP_200_OK + with patch( + "apps.api.throttlers.GetPhoneVerificationCodeThrottlerPerOrg.rate", + new_callable=PropertyMock, + ) as mocked_rate: + mocked_rate.return_value = "1/10m" - url = reverse("api-internal:user-get-verification-code", kwargs={"pk": second_user.public_primary_key}) - response = client.get(url, format="json", **make_user_auth_headers(second_user, token)) - assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS + url = reverse("api-internal:user-get-verification-code", kwargs={"pk": user.public_primary_key}) + response = client.get(url, format="json", **make_user_auth_headers(user, token)) + assert response.status_code == status.HTTP_200_OK - url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key}) - response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token)) - assert response.status_code == status.HTTP_200_OK + url = reverse("api-internal:user-get-verification-code", kwargs={"pk": second_user.public_primary_key}) + response = client.get(url, format="json", **make_user_auth_headers(second_user, token)) + assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS - url = reverse("api-internal:user-verify-number", kwargs={"pk": second_user.public_primary_key}) - response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(second_user, token)) - assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS + with patch( + "apps.api.throttlers.VerifyPhoneNumberThrottlerPerOrg.rate", + new_callable=PropertyMock, + ) as mocked_rate: + mocked_rate.return_value = "1/10m" + + url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key}) + response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token)) + assert response.status_code == status.HTTP_200_OK + + url = reverse("api-internal:user-verify-number", kwargs={"pk": second_user.public_primary_key}) + response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(second_user, token)) + assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS @patch("apps.phone_notifications.phone_backend.PhoneBackend.send_verification_sms", return_value=Mock()) diff --git a/engine/apps/api/throttlers/demo_alert_throttler.py b/engine/apps/api/throttlers/demo_alert_throttler.py index e28cd8f8a7..eb622ec192 100644 --- a/engine/apps/api/throttlers/demo_alert_throttler.py +++ b/engine/apps/api/throttlers/demo_alert_throttler.py @@ -1,6 +1,6 @@ -from rest_framework.throttling import UserRateThrottle +from common.api_helpers.custom_rate_scoped_throttler import CustomRateUserThrottler -class DemoAlertThrottler(UserRateThrottle): +class DemoAlertThrottler(CustomRateUserThrottler): scope = "send_demo_alert" rate = "30/m" diff --git a/engine/apps/api/throttlers/phone_verification_throttler.py b/engine/apps/api/throttlers/phone_verification_throttler.py index 1bf6fe02c7..7e6df29e8f 100644 --- a/engine/apps/api/throttlers/phone_verification_throttler.py +++ b/engine/apps/api/throttlers/phone_verification_throttler.py @@ -1,49 +1,21 @@ -from common.api_helpers.custom_rate_scoped_throttler import CustomRateScopedThrottler +from common.api_helpers.custom_rate_scoped_throttler import CustomRateOrganizationThrottler, CustomRateUserThrottler -class GetPhoneVerificationCodeThrottlerPerUser(CustomRateScopedThrottler): - def get_scope(self): - return "get_phone_verification_code_per_user" +class GetPhoneVerificationCodeThrottlerPerUser(CustomRateUserThrottler): + rate = "5/10m" + scope = "get_phone_verification_code_per_user" - def get_throttle_limits(self): - return 5, 10 * 60 +class VerifyPhoneNumberThrottlerPerUser(CustomRateUserThrottler): + rate = "50/10m" + scope = "verify_phone_number_per_user" -class VerifyPhoneNumberThrottlerPerUser(CustomRateScopedThrottler): - def get_scope(self): - return "verify_phone_number_per_user" - def get_throttle_limits(self): - return 50, 10 * 60 +class GetPhoneVerificationCodeThrottlerPerOrg(CustomRateOrganizationThrottler): + rate = "50/10m" + scope = "get_phone_verification_code_per_org" -class GetPhoneVerificationCodeThrottlerPerOrg(CustomRateScopedThrottler): - def get_scope(self): - return "get_phone_verification_code_per_org" - - def get_throttle_limits(self): - return 50, 10 * 60 - - def get_cache_key(self, request, view): - if request.user.is_authenticated: - ident = request.user.organization.pk - else: - ident = self.get_ident(request) - - return self.cache_format % {"scope": self.scope, "ident": ident} - - -class VerifyPhoneNumberThrottlerPerOrg(CustomRateScopedThrottler): - def get_scope(self): - return "verify_phone_number_per_org" - - def get_throttle_limits(self): - return 50, 10 * 60 - - def get_cache_key(self, request, view): - if request.user.is_authenticated: - ident = request.user.organization.pk - else: - ident = self.get_ident(request) - - return self.cache_format % {"scope": self.scope, "ident": ident} +class VerifyPhoneNumberThrottlerPerOrg(CustomRateOrganizationThrottler): + rate = "50/10m" + scope = "verify_phone_number_per_org" diff --git a/engine/apps/api/throttlers/test_call_throttler.py b/engine/apps/api/throttlers/test_call_throttler.py index 5bcb9d356c..45d556073b 100644 --- a/engine/apps/api/throttlers/test_call_throttler.py +++ b/engine/apps/api/throttlers/test_call_throttler.py @@ -1,7 +1,7 @@ -from rest_framework.throttling import UserRateThrottle +from common.api_helpers.custom_rate_scoped_throttler import CustomRateUserThrottler -class TestCallThrottler(UserRateThrottle): +class TestCallThrottler(CustomRateUserThrottler): """ set a __test__ = False attribute in classes that pytest should ignore otherwise we end up getting the following: PytestCollectionWarning: cannot collect test class 'TestCallThrottler' because it has a __init__ constructor @@ -13,7 +13,7 @@ class TestCallThrottler(UserRateThrottle): rate = "5/m" -class TestPushThrottler(UserRateThrottle): +class TestPushThrottler(CustomRateUserThrottler): """ set a __test__ = False attribute in classes that pytest should ignore otherwise we end up getting the following: PytestCollectionWarning: cannot collect test class 'TestPushThrottler' because it has a __init__ constructor diff --git a/engine/apps/integrations/mixins/__init__.py b/engine/apps/integrations/mixins/__init__.py index f34c1d41d5..67adc120f5 100644 --- a/engine/apps/integrations/mixins/__init__.py +++ b/engine/apps/integrations/mixins/__init__.py @@ -3,5 +3,6 @@ from .ratelimit_mixin import ( # noqa: F401 IntegrationHeartBeatRateLimitMixin, IntegrationRateLimitMixin, + RateLimitMixin, is_ratelimit_ignored, ) diff --git a/engine/apps/integrations/mixins/ratelimit_mixin.py b/engine/apps/integrations/mixins/ratelimit_mixin.py index ea90086816..e3e2616efc 100644 --- a/engine/apps/integrations/mixins/ratelimit_mixin.py +++ b/engine/apps/integrations/mixins/ratelimit_mixin.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from functools import wraps +from django.conf import settings from django.core.cache import cache from django.http import HttpRequest, HttpResponse from django.views import View @@ -16,6 +17,8 @@ RATELIMIT_INTEGRATION = "300/5m" RATELIMIT_TEAM = "900/5m" +RATELIMIT_INTEGRATION_GROUP_NAME = "integration" +RATELIMIT_TEAM_GROUP_NAME = "team" RATELIMIT_REASON_INTEGRATION = "channel" RATELIMIT_REASON_TEAM = "team" INTEGRATION_TOKEN_TO_IGNORE_KEY = "integration_tokens_to_ignore_ratelimit" @@ -30,13 +33,30 @@ def get_rate_limit_per_channel_key(_, request): return str(request.alert_receive_channel.pk) -def get_rate_limit_per_team_key(_, request): +def get_rate_limit_per_organization_key(_, request): """ Rate limiting based on AlertReceiveChannel's team PK """ return str(request.alert_receive_channel.organization_id) +def get_rate_limit(group, request): + custom_ratelimits = settings.CUSTOM_RATELIMITS + + organization_id = str(request.alert_receive_channel.organization_id) + + if group == RATELIMIT_INTEGRATION_GROUP_NAME: + if organization_id in custom_ratelimits: + return custom_ratelimits[organization_id]["integration"] + return RATELIMIT_INTEGRATION + elif group == RATELIMIT_TEAM_GROUP_NAME: + if organization_id in custom_ratelimits: + return custom_ratelimits[organization_id]["organization"] + return RATELIMIT_TEAM + else: + raise Exception("Unknown group") + + def ratelimit(group=None, key=None, rate=None, method=ALL, block=False, reason=None): """ This decorator is an updated version of: @@ -171,7 +191,11 @@ def notify(self): block=True, # use block=True so integration rate limit 429s are not counted towards the team rate limit ) @ratelimit( - key=get_rate_limit_per_team_key, rate=RATELIMIT_TEAM, group="team", reason=RATELIMIT_REASON_TEAM, block=True + key=get_rate_limit_per_organization_key, + rate=RATELIMIT_TEAM, + group="team", + reason=RATELIMIT_REASON_TEAM, + block=True, ) def execute_rate_limit(self, *args, **kwargs): pass @@ -201,13 +225,17 @@ class IntegrationRateLimitMixin(RateLimitMixin, View): @ratelimit( key=get_rate_limit_per_channel_key, - rate=RATELIMIT_INTEGRATION, - group="integration", + rate=get_rate_limit, + group=RATELIMIT_INTEGRATION_GROUP_NAME, reason=RATELIMIT_REASON_INTEGRATION, block=True, # use block=True so integration rate limit 429s are not counted towards the team rate limit ) @ratelimit( - key=get_rate_limit_per_team_key, rate=RATELIMIT_TEAM, group="team", reason=RATELIMIT_REASON_TEAM, block=True + key=get_rate_limit_per_organization_key, + rate=get_rate_limit, + group=RATELIMIT_TEAM_GROUP_NAME, + reason=RATELIMIT_REASON_TEAM, + block=True, ) def execute_rate_limit(self, *args, **kwargs): pass diff --git a/engine/apps/integrations/tests/test_ratelimit.py b/engine/apps/integrations/tests/test_ratelimit.py index c38edec2a8..78d4be09d4 100644 --- a/engine/apps/integrations/tests/test_ratelimit.py +++ b/engine/apps/integrations/tests/test_ratelimit.py @@ -1,8 +1,9 @@ +import json from unittest import mock import pytest from django.core.cache import cache -from django.test import Client +from django.test import Client, override_settings from django.urls import reverse from rest_framework import status @@ -35,9 +36,9 @@ def test_ratelimit_alerts_per_integration( c = Client() - response = c.post(url, data={"message": "This is the test alert from amixr"}) + response = c.post(url, data={"message": "This is the test alert"}) assert response.status_code == 200 - response = c.post(url, data={"message": "This is the test alert from amixr"}) + response = c.post(url, data={"message": "This is the test alert"}) assert response.status_code == 429 assert mocked_task.call_count == 1 @@ -150,3 +151,82 @@ def test_ratelimit_integration_and_organization( response = client.post(urls[3]) assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS assert response.content.decode() == IntegrationRateLimitMixin.TEXT_WORKSPACE + + +@pytest.mark.django_db +def test_custom_throttling(make_organization, make_alert_receive_channel): + organization_with_custom_ratelimit = make_organization() + integration_with_custom_ratelimit = make_alert_receive_channel( + organization_with_custom_ratelimit, integration=AlertReceiveChannel.INTEGRATION_WEBHOOK + ) + url_with_custom_ratelimit = reverse( + "integrations:universal", + kwargs={ + "integration_type": AlertReceiveChannel.INTEGRATION_WEBHOOK, + "alert_channel_key": integration_with_custom_ratelimit.token, + }, + ) + + integration_with_custom_ratelimit_2 = make_alert_receive_channel( + organization_with_custom_ratelimit, integration=AlertReceiveChannel.INTEGRATION_WEBHOOK + ) + url_with_custom_ratelimit_2 = reverse( + "integrations:universal", + kwargs={ + "integration_type": AlertReceiveChannel.INTEGRATION_WEBHOOK, + "alert_channel_key": integration_with_custom_ratelimit_2.token, + }, + ) + + organization_with_default_ratelimit = make_organization() + integration_with_default_ratelimit = make_alert_receive_channel( + organization_with_default_ratelimit, integration=AlertReceiveChannel.INTEGRATION_WEBHOOK + ) + url_with_default_ratelimit = reverse( + "integrations:universal", + kwargs={ + "integration_type": AlertReceiveChannel.INTEGRATION_WEBHOOK, + "alert_channel_key": integration_with_default_ratelimit.token, + }, + ) + cache.clear() + + CUSTOM_RATELIMITS_STR = ( + '{"' + + str(organization_with_custom_ratelimit.pk) + + '": {"integration": "2/m","organization": "3/m","public_api": "1/m"}}' + ) + + with override_settings(CUSTOM_RATELIMITS=json.loads(CUSTOM_RATELIMITS_STR)): + client = Client() + + # Organization without custom ratelimit should use default ratelimit + for _ in range(5): + response = client.post(url_with_default_ratelimit) + assert response.status_code == status.HTTP_200_OK + + # Organization with custom ratelimit will be ratelimited after 2 requests because of integration rate limit + response = client.post(url_with_custom_ratelimit) + + assert response.status_code == status.HTTP_200_OK + + response = client.post(url_with_custom_ratelimit) + + assert response.status_code == status.HTTP_200_OK + + response = client.post(url_with_custom_ratelimit) + + assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS + assert response.content.decode() == IntegrationRateLimitMixin.TEXT_INTEGRATION.format( + integration=integration_with_custom_ratelimit.verbal_name + ) + + # Organization with custom ratelimit will be ratelimited after 3 requests because of organization rate limit + response = client.post(url_with_custom_ratelimit_2) + + assert response.status_code == status.HTTP_200_OK + + response = client.post(url_with_custom_ratelimit_2) + + assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS + assert response.content.decode() == IntegrationRateLimitMixin.TEXT_WORKSPACE diff --git a/engine/apps/mobile_app/fcm_relay.py b/engine/apps/mobile_app/fcm_relay.py index 925d9fa601..e9fc13d274 100644 --- a/engine/apps/mobile_app/fcm_relay.py +++ b/engine/apps/mobile_app/fcm_relay.py @@ -6,19 +6,19 @@ from rest_framework import status from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response -from rest_framework.throttling import UserRateThrottle from rest_framework.views import APIView from apps.auth_token.auth import ApiTokenAuthentication from apps.mobile_app.models import FCMDevice from apps.mobile_app.utils import send_message_to_fcm_device +from common.api_helpers.custom_rate_scoped_throttler import CustomRateUserThrottler from common.custom_celery_tasks import shared_dedicated_queue_retry_task task_logger = get_task_logger(__name__) task_logger.setLevel(logging.DEBUG) -class FCMRelayThrottler(UserRateThrottle): +class FCMRelayThrottler(CustomRateUserThrottler): scope = "fcm_relay" rate = "300/m" diff --git a/engine/apps/public_api/tests/test_ratelimit.py b/engine/apps/public_api/tests/test_ratelimit.py index d6a74587a6..8e2b7691c1 100644 --- a/engine/apps/public_api/tests/test_ratelimit.py +++ b/engine/apps/public_api/tests/test_ratelimit.py @@ -1,7 +1,9 @@ +import json from unittest.mock import PropertyMock, patch import pytest from django.core.cache import cache +from django.test import override_settings from django.urls import reverse from rest_framework import status from rest_framework.test import APIClient @@ -29,3 +31,38 @@ def test_throttling(make_organization_and_user_with_token): # make sure RateLimitHeadersMixin used assert response.has_header("RateLimit-Reset") + + +@pytest.mark.django_db +def test_custom_throttling(make_organization_and_user_with_token): + organization_with_custom_ratelimit, _, token_with_custom_ratelimit = make_organization_and_user_with_token() + _, _, token_with_default_ratelimit = make_organization_and_user_with_token() + cache.clear() + + CUSTOM_RATELIMITS_STR = ( + '{"' + + str(organization_with_custom_ratelimit.pk) + + '": {"integration": "10/5m","organization": "15/5m","public_api": "1/m"}}' + ) + + with override_settings(CUSTOM_RATELIMITS=json.loads(CUSTOM_RATELIMITS_STR)): + client = APIClient() + + url = reverse("api-public:alert_groups-list") + + # Organization without custom ratelimit should use default ratelimit + for _ in range(5): + response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token_with_default_ratelimit}") + assert response.status_code == status.HTTP_200_OK + + # Organization with custom ratelimit will be ratelimited after 1 request + response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token_with_custom_ratelimit}") + + assert response.status_code == status.HTTP_200_OK + + response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token_with_custom_ratelimit}") + + assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS + + # make sure RateLimitHeadersMixin used + assert response.has_header("RateLimit-Reset") diff --git a/engine/apps/public_api/throttlers/info_throttler.py b/engine/apps/public_api/throttlers/info_throttler.py index a48bce22f5..f2c26041c8 100644 --- a/engine/apps/public_api/throttlers/info_throttler.py +++ b/engine/apps/public_api/throttlers/info_throttler.py @@ -1,6 +1,6 @@ -from rest_framework.throttling import UserRateThrottle +from common.api_helpers.custom_rate_scoped_throttler import CustomRateUserThrottler -class InfoThrottler(UserRateThrottle): +class InfoThrottler(CustomRateUserThrottler): scope = "info" rate = "100/m" diff --git a/engine/apps/public_api/throttlers/phone_notification_throttler.py b/engine/apps/public_api/throttlers/phone_notification_throttler.py index a66e19a1f1..cb389189a5 100644 --- a/engine/apps/public_api/throttlers/phone_notification_throttler.py +++ b/engine/apps/public_api/throttlers/phone_notification_throttler.py @@ -1,6 +1,6 @@ -from rest_framework.throttling import UserRateThrottle +from common.api_helpers.custom_rate_scoped_throttler import CustomRateUserThrottler -class PhoneNotificationThrottler(UserRateThrottle): +class PhoneNotificationThrottler(CustomRateUserThrottler): scope = "phone_notification" rate = "60/m" diff --git a/engine/apps/public_api/throttlers/user_throttle.py b/engine/apps/public_api/throttlers/user_throttle.py index 7c176f2e24..8b156e46dc 100644 --- a/engine/apps/public_api/throttlers/user_throttle.py +++ b/engine/apps/public_api/throttlers/user_throttle.py @@ -1,6 +1,6 @@ -from rest_framework.throttling import UserRateThrottle +from common.api_helpers.custom_rate_scoped_throttler import CustomRateUserThrottler -class UserThrottle(UserRateThrottle): +class UserThrottle(CustomRateUserThrottler): scope = "public_api" rate = "300/m" diff --git a/engine/common/api_helpers/custom_rate_scoped_throttler.py b/engine/common/api_helpers/custom_rate_scoped_throttler.py index c8965d20ff..888d1385a8 100644 --- a/engine/common/api_helpers/custom_rate_scoped_throttler.py +++ b/engine/common/api_helpers/custom_rate_scoped_throttler.py @@ -1,55 +1,32 @@ -from rest_framework.throttling import SimpleRateThrottle +from django.conf import settings +from ratelimit.utils import _split_rate +from rest_framework.throttling import UserRateThrottle -class CustomRateScopedThrottler(SimpleRateThrottle): - """ - Abstract class to create throttlers with custom amount of seconds and custom scope. - The unique cache key will be generated by concatenating the - user id of the request, and the scope from get_scope() method. +class CustomRateUserThrottler(UserRateThrottle): + """ """ - Should not be used directly. - """ + def parse_rate(self, rate): + "Use django ratelimit format to parse rate, i.e. '30/1m', instead of '30/m'" + return _split_rate(rate) - def __init__(self): - self.scope = self.get_scope() - self.num_requests, self.duration = self.get_throttle_limits() + def allow_request(self, request, view): + # Override default rate limit, if organization id is specified in CUSTOM_RATELIMITS + custom_ratelimits = settings.CUSTOM_RATELIMITS + organization_id = str(request.user.organization_id) + if organization_id in custom_ratelimits: + self.rate = custom_ratelimits[organization_id]["public_api"] + self.num_requests, self.duration = self.parse_rate(self.rate) - def get_throttle_limits(self): - """ - :return tuple requests/seconds - """ - raise NotImplementedError + return super().allow_request(request, view) - def get_scope(self): - """ - :return ratelimit scope - """ - raise NotImplementedError - def allow_request(self, request, view): - """ - Overriden allow_request method. - The difference is that overriden method doesn't check rate property. - """ - - self.key = self.get_cache_key(request, view) - if self.key is None: - return True - - self.history = self.cache.get(self.key, []) - self.now = self.timer() - - # Drop any requests from the history which have now passed the - # throttle duration - while self.history and self.history[-1] <= self.now - self.duration: - self.history.pop() - if len(self.history) >= self.num_requests: - return self.throttle_failure() - return self.throttle_success() +class CustomRateOrganizationThrottler(CustomRateUserThrottler): + scope = "organization" def get_cache_key(self, request, view): if request.user.is_authenticated: - ident = request.user.pk + ident = request.user.organization.pk else: ident = self.get_ident(request) diff --git a/engine/common/api_helpers/custom_ratelimit.py b/engine/common/api_helpers/custom_ratelimit.py new file mode 100644 index 0000000000..e9aced3014 --- /dev/null +++ b/engine/common/api_helpers/custom_ratelimit.py @@ -0,0 +1,8 @@ +from dataclasses import dataclass + + +@dataclass +class CustomRateLimit: + integration: str + organization: str + public_api: str diff --git a/engine/settings/base.py b/engine/settings/base.py index 6450b62249..47ccb77db5 100644 --- a/engine/settings/base.py +++ b/engine/settings/base.py @@ -7,6 +7,7 @@ from celery.schedules import crontab from firebase_admin import credentials, initialize_app +from common.api_helpers.custom_ratelimit import CustomRateLimit from common.utils import getenv_boolean, getenv_integer, getenv_list VERSION = "dev-oss" @@ -964,6 +965,17 @@ class BrokerTypes: ACKNOWLEDGE_REMINDER_TASK_EXPIRY_DAYS = os.environ.get("ACKNOWLEDGE_REMINDER_TASK_EXPIRY_DAYS", default=14) +# The CUSTOM_RATELIMITS environment variable is expected to be a JSON string that defines rate limits +# for different levels (e.g., integration, organization, public API). +# Example of CUSTOM_RATELIMITS in environment variable: +# CUSTOM_RATELIMITS={"1": {"integration": "10/5m", "organization": "15/5m", "public_api": "10/5m"}} +# Where, "1" is the pk of the organization + +# Load the environment variable and parse it into a dictionary, falling back to an empty dictionary if not set. +CUSTOM_RATELIMITS: typing.Dict[str, CustomRateLimit] = json.loads(os.getenv("CUSTOM_RATELIMITS", "{}")) +# Convert the parsed JSON into a dictionary of RateLimit dataclasses +CUSTOM_RATELIMITS = {key: CustomRateLimit(**value) for key, value in CUSTOM_RATELIMITS.items()} + SYNC_V2_MAX_TASKS = getenv_integer("SYNC_V2_MAX_TASKS", 6) SYNC_V2_PERIOD_SECONDS = getenv_integer("SYNC_V2_PERIOD_SECONDS", 240) SYNC_V2_BATCH_SIZE = getenv_integer("SYNC_V2_BATCH_SIZE", 500)