Skip to content

Commit

Permalink
Merge pull request #6 from ChorusOne/improve-app-stopping
Browse files Browse the repository at this point in the history
Improve application cleanup on stopping
  • Loading branch information
mksh authored Sep 25, 2024
2 parents cf941a5 + 2a0dff5 commit 88678f3
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 55 deletions.
9 changes: 6 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
.PHONY: lint deps test
.PHONY: lint deps test test-verbose

all: lint test

lint: deps
pipenv run flake8 --ignore=E501 .
pipenv run flake8 --ignore=E501,W503 .
pipenv run mypy --strict .
pipenv run black --check .

deps:
pipenv sync --dev

test: deps
pipenv run pytest tests.py
pipenv run pytest tests.py -sv

test-verbose: deps
pipenv run pytest tests.py -v -o log_cli=true --capture=fd --show-capture=stderr --log-level=DEBUG
101 changes: 78 additions & 23 deletions ssv_cluster_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import asyncio
import copy
import enum
import functools
import json
import logging
import pathlib
Expand All @@ -12,7 +13,7 @@
import furl # type: ignore[import-untyped]
from prometheus_async import aio
from prometheus_client import Gauge
from pydantic import AfterValidator, BaseModel, computed_field
from pydantic import AfterValidator, BaseModel, ConfigDict, computed_field
from pydantic_settings import BaseSettings
from web3 import Web3
from web3.contract import AsyncContract
Expand Down Expand Up @@ -42,11 +43,29 @@ def ssv_network_views_contract(self) -> str:
raise RuntimeError("Can not derive SSV network views address for network")


# ###################
# ####################
# Aiohttp & web3 apps
def get_application() -> web.Application:
async def start_exporter_app(app: web.Application) -> None:
exporter: SSVClusterExporter = app[exporter_app_key]
# Reuse client session for web3 and ssv api
await exporter.ethereum_rpc.provider.cache_async_session(exporter.session) # type: ignore[attr-defined]
# Acquire data once to verify it's working
await exporter.tick()
# Spawn long-running process
exporter.start()


async def stop_exporter_app(app: web.Application) -> None:
exporter: SSVClusterExporter = app[exporter_app_key]
await exporter.stop()


def get_application(exporter: "SSVClusterExporter") -> web.Application:
app = web.Application()
app[exporter_app_key] = exporter
app.router.add_get("/metrics", aio.web.server_stats)
app.on_startup.append(start_exporter_app)
app.on_shutdown.append(stop_exporter_app)
return app


Expand Down Expand Up @@ -209,10 +228,9 @@ class SSVNetworkProperties(BaseModel):
class SSVNetworkContract(BaseModel):
"""A facade for web3 contract data retrieval for network wide values."""

network_views: AsyncContract
model_config = ConfigDict(arbitrary_types_allowed=True)

class Config:
arbitrary_types_allowed = True
network_views: AsyncContract

async def fetch_network_fee(self) -> int:
return int(await self.network_views.functions.getNetworkFee().call())
Expand Down Expand Up @@ -243,11 +261,11 @@ async def fetch_all(self) -> SSVNetworkProperties:
class SSVClusterContract(BaseModel):
"""A facade for web3 contract data retrieval for clusters."""

model_config = ConfigDict(arbitrary_types_allowed=True)

network_views: AsyncContract
clusters: set[SSVCluster]

class Config:
arbitrary_types_allowed = True
loop: asyncio.AbstractEventLoop

def contract_call_args(self, cluster: SSVCluster) -> SSVNetworkViewsCallArgs:
return (
Expand Down Expand Up @@ -277,13 +295,13 @@ async def get_cluster_burn_rate(self, cluster: SSVCluster) -> None:
async def fetch_balances(self) -> None:
futs = []
for cluster in self.clusters:
futs.append(asyncio.create_task(self.get_cluster_balance(cluster)))
futs.append(self.loop.create_task(self.get_cluster_balance(cluster)))
await asyncio.gather(*futs)

async def fetch_burn_rates(self) -> None:
futs = []
for cluster in self.clusters:
futs.append(asyncio.create_task(self.get_cluster_burn_rate(cluster)))
futs.append(self.loop.create_task(self.get_cluster_burn_rate(cluster)))
await asyncio.gather(*futs)

async def fetch_all(self) -> None:
Expand All @@ -304,13 +322,39 @@ class SSVClusterExporter(BaseSettings):
ethereum_rpc: Web3RpcClient
base_ssv_url: furl.furl = furl.furl("https://api.ssv.network/api/v4/")

session: client.ClientSession
loop: asyncio.AbstractEventLoop

# Stopping
stopping: bool = False
stopped: asyncio.Event = asyncio.Event()

@computed_field # type: ignore
@property
@functools.cached_property
def network_views(self) -> AsyncContract:
return get_ssv_network_views_contract(self.ethereum_rpc, self.network) # type: ignore[arg-type]

@computed_field # type: ignore
@functools.cached_property
def session(self) -> client.ClientSession:
return client.ClientSession(loop=self.loop)

def on_runner_task_done(self, *args: typing.Any) -> None:
self.stopped.set()

def start(self) -> None:
self._runner_task = self.loop.create_task(self.run())
# Raise event when task is stopped
self._runner_task.add_done_callback(self.on_runner_task_done)

async def stop(self) -> None:
logger.info("Gracefully shutting down application")
self.stopping = True
self._runner_task.cancel()
if not self.stopped.is_set():
await self.stopped.wait()
await self.session.close()
logger.info("Stopped components, will exit")

async def sleep(self) -> None:
await asyncio.sleep(self.interval_ms / 1000)

Expand Down Expand Up @@ -390,11 +434,11 @@ async def fetch_clusters_info(self) -> list[SSVCluster]:

for owner_config in self.owners:
futs.append(
asyncio.create_task(self.get_owner_clusters(owner_config.address))
self.loop.create_task(self.get_owner_clusters(owner_config.address))
)
for cluster_config in self.clusters:
futs.append(
asyncio.create_task(self.get_cluster_by_id(cluster_config.cluster_id))
self.loop.create_task(self.get_cluster_by_id(cluster_config.cluster_id))
)

responses = await asyncio.gather(*futs)
Expand Down Expand Up @@ -434,7 +478,9 @@ async def clusters_updates(self) -> None:
"""Run cluster-specific metrics update."""
clusters = set(await self.fetch_clusters_info())
latest_metric_fetcher = SSVClusterContract(
network_views=self.network_views, clusters=clusters
network_views=self.network_views,
clusters=clusters,
loop=self.loop,
)
await latest_metric_fetcher.fetch_all()
self.update_clusters_metrics(*clusters)
Expand All @@ -454,15 +500,23 @@ async def tick(self) -> None:
)
except Exception:
logger.exception("Failed to update cluster details")
if self.stopping:
await self.session.close()

async def loop(self) -> None:
async def run(self) -> None:
"""Infinite loop that spawns checker tasks."""
while True:
asyncio.ensure_future(self.tick())
while not self.stopping:
self.loop.create_task(self.tick())
await self.sleep()
self.stopped.set()


# Aiohttp app key for exporter component
exporter_app_key: web.AppKey[SSVClusterExporter] = web.AppKey(
"exporter", SSVClusterExporter
)


# #############
# Entry point
def main() -> None:
Expand All @@ -476,7 +530,7 @@ def main() -> None:
asyncio.set_event_loop(loop)
try:
config_data = yaml.safe_load(config_text)
config_data["session"] = client.ClientSession(loop=loop)
config_data["loop"] = loop
exporter = SSVClusterExporter(**config_data)
except yaml.error.YAMLError:
logger.exception("Invalid config YAML")
Expand All @@ -485,9 +539,10 @@ def main() -> None:
logger.exception("Invalid config data")
exit(2)
else:
app = get_application()
loop.create_task(exporter.loop())
web.run_app(app, host=args.host, port=args.port, loop=loop)
app = get_application(exporter)
web.run_app(
app, host=args.host, port=args.port, loop=loop, handler_cancellation=True
)


if __name__ == "__main__":
Expand Down
59 changes: 30 additions & 29 deletions tests.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from collections.abc import AsyncGenerator
import socket
import typing
Expand All @@ -18,17 +19,16 @@ def find_free_port() -> int:

@pytest_asyncio.fixture
async def metrics_server(exporter_data: typing.Any) -> AsyncGenerator[str, None]:
exporter_data["session"] = client.ClientSession()
exporter_data["loop"] = asyncio.get_event_loop()
exporter = ssv_cluster_exporter.SSVClusterExporter(**exporter_data)
port = find_free_port()
# Acquire data once
await exporter.tick()
app = ssv_cluster_exporter.get_application()
app = ssv_cluster_exporter.get_application(exporter)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "localhost", port)
await site.start()
yield f"http://localhost:{port}"
await runner.shutdown()
await site.stop()


Expand Down Expand Up @@ -59,32 +59,33 @@ async def metrics_server(exporter_data: typing.Any) -> AsyncGenerator[str, None]
],
)
async def test_metrics(metrics_server: str) -> None:
session = client.ClientSession()
response = await session.get(f"{metrics_server}/metrics")
assert response.status == 200
matched_metrics = set()
for metric in text_string_to_metric_families(await response.text()):
if metric.name.startswith("ssv_cluster"):
sample = metric.samples[0]
assert (
sample.labels["cluster_id"]
== "0xde12c5ce1bc895c3ed8b81afcbbb55b3efff7ae9ebac5dbd2ebac3bd29474c09" # noqa: W503
)
assert sample.labels["id"] == "1278541"
assert sample.labels["network"] == "holesky"
assert sample.labels["operators"] == "1092,1093,1094,1095"
assert (
sample.labels["owner"] == "0xD4BB555d3B0D7fF17c606161B44E372689C14F4B"
)
matched_metrics.add(metric.name)
elif metric.name in (
"ssv_network_fee",
"ssv_minimum_liquidation_collateral",
"ssv_liquidation_threshold_period",
):
sample = metric.samples[0]
assert sample.labels["network"] == "holesky"
matched_metrics.add(metric.name)
async with client.ClientSession() as session:
response = await session.get(f"{metrics_server}/metrics")
assert response.status == 200
for metric in text_string_to_metric_families(await response.text()):
if metric.name.startswith("ssv_cluster"):
sample = metric.samples[0]
assert (
sample.labels["cluster_id"]
== "0xde12c5ce1bc895c3ed8b81afcbbb55b3efff7ae9ebac5dbd2ebac3bd29474c09" # noqa: W503
)
assert sample.labels["id"] == "1278541"
assert sample.labels["network"] == "holesky"
assert sample.labels["operators"] == "1092,1093,1094,1095"
assert (
sample.labels["owner"]
== "0xD4BB555d3B0D7fF17c606161B44E372689C14F4B"
)
matched_metrics.add(metric.name)
elif metric.name in (
"ssv_network_fee",
"ssv_minimum_liquidation_collateral",
"ssv_liquidation_threshold_period",
):
sample = metric.samples[0]
assert sample.labels["network"] == "holesky"
matched_metrics.add(metric.name)
assert matched_metrics == {
"ssv_cluster_validators_count",
"ssv_cluster_balance",
Expand Down

0 comments on commit 88678f3

Please sign in to comment.