diff --git a/engine/apps/grafana_plugin/tasks/sync_v2.py b/engine/apps/grafana_plugin/tasks/sync_v2.py index 433eb4b2a8..1a479aa776 100644 --- a/engine/apps/grafana_plugin/tasks/sync_v2.py +++ b/engine/apps/grafana_plugin/tasks/sync_v2.py @@ -1,7 +1,7 @@ import logging from celery.utils.log import get_task_logger -from django.utils import timezone +from django.conf import settings from apps.grafana_plugin.helpers.client import GrafanaAPIClient from apps.grafana_plugin.helpers.gcom import get_active_instance_ids @@ -12,10 +12,6 @@ logger.setLevel(logging.DEBUG) -SYNC_PERIOD = timezone.timedelta(minutes=4) -SYNC_BATCH_SIZE = 500 - - @shared_dedicated_queue_retry_task(autoretry_for=(Exception,), retry_backoff=True, max_retries=0) def start_sync_organizations_v2(): organization_qs = Organization.objects.all() @@ -30,20 +26,22 @@ def start_sync_organizations_v2(): logger.info(f"Found {len(organization_qs)} active organizations") batch = [] + batch_index = 0 + task_countdown_seconds = 0 for org in organization_qs: if GrafanaAPIClient.validate_grafana_token_format(org.api_token): batch.append(org.pk) - if len(batch) == SYNC_BATCH_SIZE: - sync_organizations_v2.apply_async( - (batch,), - ) + if len(batch) == settings.SYNC_V2_BATCH_SIZE: + sync_organizations_v2.apply_async((batch,), countdown=task_countdown_seconds) batch = [] + batch_index += 1 + if batch_index == settings.SYNC_V2_MAX_TASKS: + batch_index = 0 + task_countdown_seconds += settings.SYNC_V2_PERIOD_SECONDS else: logger.info(f"Skipping stack_slug={org.stack_slug}, api_token format is invalid or not set") if batch: - sync_organizations_v2.apply_async( - (batch,), - ) + sync_organizations_v2.apply_async((batch,), countdown=task_countdown_seconds) @shared_dedicated_queue_retry_task(autoretry_for=(Exception,), retry_backoff=True, max_retries=0) diff --git a/engine/apps/grafana_plugin/tests/test_sync_v2.py b/engine/apps/grafana_plugin/tests/test_sync_v2.py index 1915285b18..704ff9a3dd 100644 --- a/engine/apps/grafana_plugin/tests/test_sync_v2.py +++ b/engine/apps/grafana_plugin/tests/test_sync_v2.py @@ -1,7 +1,7 @@ import gzip import json from dataclasses import asdict -from unittest.mock import patch +from unittest.mock import call, patch import pytest from django.urls import reverse @@ -159,3 +159,34 @@ def test_sync_team_serialization(test_team, validation_pass): except ValidationError as e: validation_error = e assert (validation_error is None) == validation_pass + + +@pytest.mark.django_db +def test_sync_batch_tasks(make_organization, settings): + settings.SYNC_V2_MAX_TASKS = 2 + settings.SYNC_V2_PERIOD_SECONDS = 10 + settings.SYNC_V2_BATCH_SIZE = 2 + + for _ in range(9): + make_organization(api_token="glsa_abcdefghijklmnopqrstuvwxyz") + + expected_calls = [ + call(size=2, countdown=0), + call(size=2, countdown=0), + call(size=2, countdown=10), + call(size=2, countdown=10), + call(size=1, countdown=20), + ] + with patch("apps.grafana_plugin.tasks.sync_v2.sync_organizations_v2.apply_async", return_value=None) as mock_sync: + start_sync_organizations_v2() + + def check_call(actual, expected): + return ( + len(actual.args[0][0]) == expected.kwargs["size"] + and actual.kwargs["countdown"] == expected.kwargs["countdown"] + ) + + for actual_call, expected_call in zip(mock_sync.call_args_list, expected_calls): + assert check_call(actual_call, expected_call) + + assert mock_sync.call_count == len(expected_calls) diff --git a/engine/settings/base.py b/engine/settings/base.py index d2bac3ff72..788f958fb8 100644 --- a/engine/settings/base.py +++ b/engine/settings/base.py @@ -963,3 +963,7 @@ class BrokerTypes: DETACHED_INTEGRATIONS_SERVER = getenv_boolean("DETACHED_INTEGRATIONS_SERVER", default=False) ACKNOWLEDGE_REMINDER_TASK_EXPIRY_DAYS = os.environ.get("ACKNOWLEDGE_REMINDER_TASK_EXPIRY_DAYS", default=14) + +SYNC_V2_MAX_TASKS = getenv_integer("SYNC_V2_MAX_TASKS", 10) +SYNC_V2_PERIOD_SECONDS = getenv_integer("SYNC_V2_PERIOD_SECONDS", 300) +SYNC_V2_BATCH_SIZE = getenv_integer("SYNC_V2_BATCH_SIZE", 500)