diff --git a/pubnub/dtos.py b/pubnub/dtos.py index ae0220b0..5a1ececc 100644 --- a/pubnub/dtos.py +++ b/pubnub/dtos.py @@ -10,6 +10,12 @@ def __init__(self, channels=None, channel_groups=None, presence_enabled=None, ti self.presence_enabled = presence_enabled self.timetoken = timetoken + @property + def channels_with_pressence(self): + if not self.presence_enabled: + return self.channels + return [*self.channels] + [ch + '-pnpres' for ch in self.channels] + class UnsubscribeOperation(object): def __init__(self, channels=None, channel_groups=None): diff --git a/pubnub/event_engine/manage_effects.py b/pubnub/event_engine/manage_effects.py index 3522d209..45f4fe01 100644 --- a/pubnub/event_engine/manage_effects.py +++ b/pubnub/event_engine/manage_effects.py @@ -4,6 +4,7 @@ from typing import Optional, Union from pubnub.endpoints.presence.heartbeat import Heartbeat +from pubnub.endpoints.presence.leave import Leave from pubnub.endpoints.pubsub.subscribe import Subscribe from pubnub.enums import PNReconnectionPolicy from pubnub.exceptions import PubNubException @@ -20,6 +21,7 @@ class ManagedEffect: effect: Union[effects.PNManageableEffect, effects.PNCancelEffect] stop_event = None logger: logging.Logger + task: asyncio.Task def set_pn(self, pubnub: PubNub): self.pubnub = pubnub @@ -42,6 +44,8 @@ def stop(self): if self.stop_event: self.logger.debug(f'stop_event({id(self.stop_event)}).set() called on {self.__class__.__name__}') self.stop_event.set() + if hasattr(self, 'task') and isinstance(self.task, asyncio.Task) and not self.task.cancelled(): + self.task.cancel() def get_new_stop_event(self): event = asyncio.Event() @@ -60,9 +64,9 @@ def run(self): loop: asyncio.AbstractEventLoop = self.pubnub.event_loop coro = self.handshake_async(channels=channels, groups=groups, timetoken=tt, stop_event=self.stop_event) if loop.is_running(): - loop.create_task(coro) + self.task = loop.create_task(coro) else: - loop.run_until_complete(coro) + self.task = loop.run_until_complete(coro) else: # TODO: the synchronous way pass @@ -98,9 +102,9 @@ def run(self): loop: asyncio.AbstractEventLoop = self.pubnub.event_loop coro = self.receive_messages_async(channels, groups, timetoken, region) if loop.is_running(): - loop.create_task(coro) + self.task = loop.create_task(coro) else: - loop.run_until_complete(coro) + self.task = loop.run_until_complete(coro) else: # TODO: the synchronous way pass @@ -181,9 +185,9 @@ def run(self): loop: asyncio.AbstractEventLoop = self.pubnub.event_loop coro = self.delayed_reconnect_async(delay, attempts) if loop.is_running(): - self.delayed_reconnect_coro = loop.create_task(coro) + self.task = loop.create_task(coro) else: - self.delayed_reconnect_coro = loop.run_until_complete(coro) + self.task = loop.run_until_complete(coro) else: # TODO: the synchronous way pass @@ -218,9 +222,9 @@ def stop(self): if self.stop_event: self.logger.debug(f'stop_event({id(self.stop_event)}).set() called on {self.__class__.__name__}') self.stop_event.set() - if self.delayed_reconnect_coro: + if self.task: try: - self.delayed_reconnect_coro.cancel() + self.task.cancel() except asyncio.exceptions.CancelledError: pass @@ -276,9 +280,9 @@ def run(self): loop: asyncio.AbstractEventLoop = self.pubnub.event_loop coro = self.heartbeat(channels=channels, groups=groups, stop_event=self.stop_event) if loop.is_running(): - loop.create_task(coro) + self.task = loop.create_task(coro) else: - loop.run_until_complete(coro) + self.task = loop.run_until_complete(coro) async def heartbeat(self, channels, groups, stop_event): request = Heartbeat(self.pubnub).channels(channels).channel_groups(groups).cancellation_event(stop_event) @@ -286,14 +290,14 @@ async def heartbeat(self, channels, groups, stop_event): if heartbeat.status.error: self.logger.warning(f'Heartbeat failed: {heartbeat.status.error_data.__dict__}') - self.event_engine.trigger(events.HeartbeatFailureEvent(heartbeat.status.error_data, 1)) + self.event_engine.trigger(events.HeartbeatFailureEvent(channels=channels, groups=groups, + reason=heartbeat.status.error_data, attempt=1)) else: self.event_engine.trigger(events.HeartbeatSuccessEvent(channels=channels, groups=groups)) class ManagedHeartbeatWaitEffect(ManagedEffect): - def __init__(self, pubnub_instance, event_engine_instance, - effect: Union[effects.PNManageableEffect, effects.PNCancelEffect]) -> None: + def __init__(self, pubnub_instance, event_engine_instance, effect: effects.HeartbeatWaitEffect) -> None: super().__init__(pubnub_instance, event_engine_instance, effect) self.heartbeat_interval = pubnub_instance.config.heartbeat_interval @@ -301,13 +305,13 @@ def run(self): if hasattr(self.pubnub, 'event_loop'): self.stop_event = self.get_new_stop_event() loop: asyncio.AbstractEventLoop = self.pubnub.event_loop - coro = self.heartbeat_wait(self.heartbeat_interval, stop_event=self.stop_event) + coroutine = self.heartbeat_wait(self.heartbeat_interval, stop_event=self.stop_event) if loop.is_running(): - loop.create_task(coro) + self.task = loop.create_task(coroutine) else: - loop.run_until_complete(coro) + self.task = loop.run_until_complete(coroutine) - async def heartbeat(self, wait_time: int, stop_event): + async def heartbeat_wait(self, wait_time: int, stop_event): try: await asyncio.sleep(wait_time) self.event_engine.trigger(events.HeartbeatTimesUpEvent(channels=self.effect.channels, @@ -317,7 +321,24 @@ async def heartbeat(self, wait_time: int, stop_event): class ManagedHeartbeatLeaveEffect(ManagedEffect): - pass + def run(self): + channels = self.effect.channels + groups = self.effect.groups + if hasattr(self.pubnub, 'event_loop'): + self.stop_event = self.get_new_stop_event() + loop: asyncio.AbstractEventLoop = self.pubnub.event_loop + coro = self.leave(channels=channels, groups=groups, stop_event=self.stop_event) + if loop.is_running(): + self.task = loop.create_task(coro) + else: + self.task = loop.run_until_complete(coro) + + async def leave(self, channels, groups, stop_event): + leave_request = Leave(self.pubnub).channels(channels).channel_groups(groups).cancellation_event(stop_event) + leave = await leave_request.future() + + if leave.status.error: + self.logger.warning(f'Heartbeat failed: {leave.status.error_data.__dict__}') class ManagedHeartbeatDelayedHeartbeatEffect(ManagedEffect): @@ -331,6 +352,9 @@ class ManagedEffectFactory: effects.HandshakeReconnectEffect.__name__: ManagedHandshakeReconnectEffect, effects.ReceiveReconnectEffect.__name__: ManagedReceiveReconnectEffect, effects.HeartbeatEffect.__name__: ManagedHeartbeatEffect, + effects.HeartbeatWaitEffect.__name__: ManagedHeartbeatWaitEffect, + effects.HeartbeatDelayedEffect.__name__: ManagedHeartbeatDelayedHeartbeatEffect, + effects.HeartbeatLeaveEffect.__name__: ManagedHeartbeatLeaveEffect, } def __init__(self, pubnub_instance, event_engine_instance) -> None: @@ -339,7 +363,7 @@ def __init__(self, pubnub_instance, event_engine_instance) -> None: def create(self, effect: ManagedEffect): if effect.__class__.__name__ not in self._managed_effects: - raise PubNubException(errormsg="Unhandled manage effect") + raise PubNubException(errormsg=f"Unhandled managed effect: {effect.__class__.__name__}") return self._managed_effects[effect.__class__.__name__](self._pubnub, self._event_engine, effect) diff --git a/pubnub/event_engine/models/effects.py b/pubnub/event_engine/models/effects.py index b9607776..4e503399 100644 --- a/pubnub/event_engine/models/effects.py +++ b/pubnub/event_engine/models/effects.py @@ -109,6 +109,7 @@ def __init__(self, channels: Union[None, List[str]] = None, groups: Union[None, class HeartbeatWaitEffect(PNManageableEffect): def __init__(self, time) -> None: + self.wait_time = time super().__init__() @@ -117,13 +118,14 @@ class HeartbeatCancelWaitEffect(PNCancelEffect): class HeartbeatLeaveEffect(PNManageableEffect): - def __init__(self) -> None: + def __init__(self, channels: Union[None, List[str]] = None, groups: Union[None, List[str]] = None) -> None: super().__init__() + self.channels = channels + self.groups = groups class HeartbeatDelayedEffect(PNManageableEffect): - def __init__(self) -> None: - super().__init__() + pass class HeartbeatCancelDelayedEffect(PNCancelEffect): diff --git a/pubnub/event_engine/models/states.py b/pubnub/event_engine/models/states.py index d0940d13..c420367c 100644 --- a/pubnub/event_engine/models/states.py +++ b/pubnub/event_engine/models/states.py @@ -566,6 +566,8 @@ def __init__(self, context: PNContext) -> None: } def joined(self, event: events.HeartbeatJoinedEvent, context: PNContext) -> PNTransition: + self._context.channels = event.channels + self._context.groups = event.groups self._context.update(context) return PNTransition( @@ -644,7 +646,7 @@ def left(self, event: events.HeartbeatLeftEvent, context: PNContext) -> PNTransi return PNTransition( state=HeartbeatingState, context=self._context, - effect=effects.HeartbeatLeaveEffect() + effect=effects.HeartbeatLeaveEffect(channels=self._context.channels, groups=self._context.groups) ) def reconnect(self, event: events.HeartbeatReconnectEvent, context: PNContext) -> PNTransition: @@ -661,7 +663,7 @@ def disconnect(self, event: events.HeartbeatDisconnectEvent, context: PNContext) return PNTransition( state=HeartbeatStoppedState, context=self._context, - effect=effects.HeartbeatLeaveEffect() + effect=effects.HeartbeatLeaveEffect(channels=self._context.channels, groups=self._context.groups) ) def left_all(self, event: events.HeartbeatLeftAllEvent, context: PNContext) -> PNTransition: @@ -670,7 +672,7 @@ def left_all(self, event: events.HeartbeatLeftAllEvent, context: PNContext) -> P return PNTransition( state=HeartbeatInactiveState, context=self._context, - effect=effects.HeartbeatLeaveEffect() + effect=effects.HeartbeatLeaveEffect(channels=self._context.channels, groups=self._context.groups) ) @@ -705,7 +707,7 @@ def disconnect(self, event: events.HeartbeatDisconnectEvent, context: PNContext) return PNTransition( state=HeartbeatStoppedState, context=self._context, - effect=effects.HeartbeatLeaveEffect() + effect=effects.HeartbeatLeaveEffect(channels=self._context.channels, groups=self._context.groups) ) def left_all(self, event: events.HeartbeatLeftAllEvent, context: PNContext) -> PNTransition: @@ -714,7 +716,7 @@ def left_all(self, event: events.HeartbeatLeftAllEvent, context: PNContext) -> P return PNTransition( state=HeartbeatInactiveState, context=self._context, - effect=effects.HeartbeatLeaveEffect() + effect=effects.HeartbeatLeaveEffect(channels=self._context.channels, groups=self._context.groups) ) def joined(self, event: events.HeartbeatJoinedEvent, context: PNContext) -> PNTransition: @@ -731,7 +733,7 @@ def left(self, event: events.HeartbeatLeftEvent, context: PNContext) -> PNTransi return PNTransition( state=HeartbeatingState, context=self._context, - effect=effects.HeartbeatLeaveEffect() + effect=effects.HeartbeatLeaveEffect(channels=self._context.channels, groups=self._context.groups) ) def success(self, event: events.HeartbeatSuccessEvent, context: PNContext) -> PNTransition: @@ -760,10 +762,9 @@ def on_enter(self, context: PNContext): super().on_enter(self._context) return effects.HeartbeatWaitEffect(self._context) - def on_exit(self, context: PNContext): - self._context.update(context) - super().on_exit(self._context) - return effects.HeartbeatCancelWaitEffect(self._context) + def on_exit(self): + super().on_exit() + return effects.HeartbeatCancelWaitEffect() def disconnect(self, event: events.HeartbeatDisconnectEvent, context: PNContext) -> PNTransition: self._context.update(context) @@ -771,7 +772,7 @@ def disconnect(self, event: events.HeartbeatDisconnectEvent, context: PNContext) return PNTransition( state=HeartbeatStoppedState, context=self._context, - effect=effects.HeartbeatLeaveEffect() + effect=effects.HeartbeatLeaveEffect(channels=self._context.channels, groups=self._context.groups) ) def left_all(self, event: events.HeartbeatLeftAllEvent, context: PNContext) -> PNTransition: @@ -780,7 +781,7 @@ def left_all(self, event: events.HeartbeatLeftAllEvent, context: PNContext) -> P return PNTransition( state=HeartbeatInactiveState, context=self._context, - effect=effects.HeartbeatLeaveEffect() + effect=effects.HeartbeatLeaveEffect(channels=self._context.channels, groups=self._context.groups) ) def joined(self, event: events.HeartbeatJoinedEvent, context: PNContext) -> PNTransition: @@ -797,7 +798,7 @@ def left(self, event: events.HeartbeatLeftEvent, context: PNContext) -> PNTransi return PNTransition( state=HeartbeatingState, context=self._context, - effect=effects.HeartbeatLeaveEffect() + effect=effects.HeartbeatLeaveEffect(channels=self._context.channels, groups=self._context.groups) ) def times_up(self, event: events.HeartbeatTimesUpEvent, context: PNContext) -> PNTransition: @@ -827,10 +828,9 @@ def on_enter(self, context: PNContext): super().on_enter(self._context) return effects.HeartbeatDelayedEffect(self._context) - def on_exit(self, context: PNContext): - self._context.update(context) - super().on_exit(self._context) - return effects.HeartbeatCancelDelayedEffect(self._context) + def on_exit(self): + super().on_exit() + return effects.HeartbeatCancelDelayedEffect() def failure(self, event: events.HeartbeatFailureEvent, context: PNContext) -> PNTransition: self._context.update(context) @@ -854,7 +854,7 @@ def left(self, event: events.HeartbeatLeftEvent, context: PNContext) -> PNTransi return PNTransition( state=HeartbeatingState, context=self._context, - effect=effects.HeartbeatLeaveEffect() + effect=effects.HeartbeatLeaveEffect(channels=self._context.channels, groups=self._context.groups) ) def success(self, event: events.HeartbeatSuccessEvent, context: PNContext) -> PNTransition: @@ -879,7 +879,7 @@ def disconnect(self, event: events.HeartbeatDisconnectEvent, context: PNContext) return PNTransition( state=HeartbeatStoppedState, context=self._context, - effect=effects.HeartbeatLeaveEffect() + effect=effects.HeartbeatLeaveEffect(channels=self._context.channels, groups=self._context.groups) ) def left_all(self, event: events.HeartbeatLeftAllEvent, context: PNContext) -> PNTransition: @@ -888,5 +888,5 @@ def left_all(self, event: events.HeartbeatLeftAllEvent, context: PNContext) -> P return PNTransition( state=HeartbeatInactiveState, context=self._context, - effect=effects.HeartbeatLeaveEffect() + effect=effects.HeartbeatLeaveEffect(channels=self._context.channels, groups=self._context.groups) ) diff --git a/pubnub/event_engine/statemachine.py b/pubnub/event_engine/statemachine.py index 3b5bbe35..81e0a85a 100644 --- a/pubnub/event_engine/statemachine.py +++ b/pubnub/event_engine/statemachine.py @@ -22,6 +22,7 @@ def __init__(self, initial_state: states.PNState, dispatcher_class: Optional[Dis dispatcher_class = Dispatcher self._dispatcher = dispatcher_class(self) self._enabled = True + self._name = name self.logger = logging.getLogger("pubnub" if not name else f"pubnub.{name}") def __del__(self): @@ -88,3 +89,7 @@ def dispatch_effects(self): def stop(self): self._enabled = False + + @property + def name(self): + return self._name diff --git a/pubnub/pubnub_asyncio.py b/pubnub/pubnub_asyncio.py index b145452a..0080e4c3 100644 --- a/pubnub/pubnub_asyncio.py +++ b/pubnub/pubnub_asyncio.py @@ -56,10 +56,6 @@ def __init__(self, config, custom_event_loop=None, subscription_manager=None): self._telemetry_manager = AsyncioTelemetryManager() - def __del__(self): - if self.event_loop.is_running(): - self.event_loop.create_task(self.close_session()) - async def close_pending_tasks(self, tasks): await asyncio.gather(*tasks) await asyncio.sleep(0.1) @@ -90,6 +86,7 @@ async def stop(self): await self.close_session() if self._subscription_manager: self._subscription_manager.stop() + await self.close_session() def sdk_platform(self): return "-Asyncio" @@ -559,8 +556,9 @@ class EventEngineSubscriptionManager(SubscriptionManager): def __init__(self, pubnub_instance): self.event_engine = StateMachine(states.UnsubscribedState, name="subscribe") - # self.presence_engine = StateMachine(states.HeartbeatInactiveState, name="presence") + self.presence_engine = StateMachine(states.HeartbeatInactiveState, name="presence") self.event_engine.get_dispatcher().set_pn(pubnub_instance) + self.presence_engine.get_dispatcher().set_pn(pubnub_instance) self.loop = asyncio.new_event_loop() super().__init__(pubnub_instance) @@ -574,27 +572,27 @@ def adapt_subscribe_builder(self, subscribe_operation: SubscribeOperation): if subscribe_operation.timetoken: subscription_event = events.SubscriptionRestoredEvent( - channels=subscribe_operation.channels, + channels=subscribe_operation.channels_with_pressence, groups=subscribe_operation.channel_groups, timetoken=subscribe_operation.timetoken ) else: subscription_event = events.SubscriptionChangedEvent( - channels=subscribe_operation.channels, + channels=subscribe_operation.channels_with_pressence, groups=subscribe_operation.channel_groups ) self.event_engine.trigger(subscription_event) - # self.presence_engine.trigger(events.HeartbeatJoinedEvent( - # channels=subscribe_operation.channels, - # groups=subscribe_operation.channel_groups - # )) + self.presence_engine.trigger(events.HeartbeatJoinedEvent( + channels=subscribe_operation.channels, + groups=subscribe_operation.channel_groups + )) def adapt_unsubscribe_builder(self, unsubscribe_operation): if not isinstance(unsubscribe_operation, UnsubscribeOperation): raise PubNubException('Invalid Unsubscribe Operation') event = events.SubscriptionChangedEvent(None, None) self.event_engine.trigger(event) - # self.presence_engine.trigger(events.HeartbeatLeftAllEvent()) + self.presence_engine.trigger(events.HeartbeatLeftAllEvent()) class AsyncioSubscribeMessageWorker(SubscribeMessageWorker): diff --git a/tests/acceptance/subscribe/steps/then_steps.py b/tests/acceptance/subscribe/steps/then_steps.py index 6efbe585..943a4ae8 100644 --- a/tests/acceptance/subscribe/steps/then_steps.py +++ b/tests/acceptance/subscribe/steps/then_steps.py @@ -43,16 +43,11 @@ def parse_log_line(line: str): lambda line: line.startswith('Triggered event') or line.startswith('Invoke effect'), context.log_stream.getvalue().splitlines() ))] - try: - for index, expected in enumerate(context.table): - logged_type, logged_name = normalized_log[index] - expected_type, expected_name = expected - assert expected_type == logged_type, f'on line {index + 1} => {expected_type} != {logged_type}' - assert expected_name == logged_name, f'on line {index + 1} => {expected_name} != {logged_name}' - except Exception as e: - import ipdb - ipdb.set_trace() - raise e + for index, expected in enumerate(context.table): + logged_type, logged_name = normalized_log[index] + expected_type, expected_name = expected + assert expected_type == logged_type, f'on line {index + 1} => {expected_type} != {logged_type}' + assert expected_name == logged_name, f'on line {index + 1} => {expected_name} != {logged_name}' @then("I receive an error in my subscribe response") @@ -75,25 +70,52 @@ async def step_impl(context: PNContext): """ -@then(u'I wait {wait_time} seconds') +@then("I wait '{wait_time}' seconds") @async_run_until_complete async def step_impl(context: PNContext, wait_time: str): await busypie.wait() \ .at_most(int(wait_time)) \ .poll_delay(1) \ - .poll_interval(1) + .poll_interval(1) \ + .until_async(lambda: True) @then(u'I observe the following Events and Invocations of the Presence EE') @async_run_until_complete async def step_impl(context): - pass + def parse_log_line(line: str): + line_type = 'event' if line.startswith('Triggered event') else 'invocation' + m = re.search('([A-Za-z])+(Event|Effect)', line) + name = m.group(0).replace('Effect', '').replace('Event', '') + name = name.replace('Effect', '').replace('Event', '') + name = re.sub(r'([A-Z])', r'_\1', name).upper().lstrip('_') + name = name.replace('HEARTBEAT_JOIN', 'JOIN').replace('HEARTBEAT_WAIT', 'WAIT') + return (line_type, name) + + normalized_log = [parse_log_line(log_line) for log_line in list(filter( + lambda line: line.startswith('Triggered event') or line.startswith('Invoke effect'), + context.log_stream.getvalue().splitlines() + ))] + + try: + for index, expected in enumerate(context.table): + logged_type, logged_name = normalized_log[index] + expected_type, expected_name = expected + assert expected_type == logged_type, f'on line {index + 1} => {expected_type} != {logged_type}' + assert expected_name == logged_name, f'on line {index + 1} => {expected_name} != {logged_name}' + except Exception: + import ipdb + ipdb.set_trace() @then(u'I wait for getting Presence joined events') @async_run_until_complete async def step_impl(context: PNContext): - pass + await busypie.wait() \ + .at_most(15) \ + .poll_delay(3) \ + .poll_interval(1) \ + .until_async(lambda: True) @then(u'I receive an error in my heartbeat response') diff --git a/tests/acceptance/subscribe/steps/when_steps.py b/tests/acceptance/subscribe/steps/when_steps.py index ef9cbdd9..63f4ffab 100644 --- a/tests/acceptance/subscribe/steps/when_steps.py +++ b/tests/acceptance/subscribe/steps/when_steps.py @@ -1,16 +1,15 @@ from behave import when +from behave.api.async_step import async_run_until_complete from tests.acceptance.subscribe.environment import PNContext, AcceptanceCallback @when('I subscribe') def step_impl(context: PNContext): - print(f'WHEN I subscribe {id(context.pubnub)}') context.pubnub.subscribe().channels('foo').execute() @when('I subscribe with timetoken {timetoken}') def step_impl(context: PNContext, timetoken: str): # noqa F811 - print(f'WHEN I subscribe with TT {id(context.pubnub)}') callback = AcceptanceCallback() context.pubnub.add_listener(callback) context.pubnub.subscribe().channels('foo').with_timetoken(int(timetoken)).execute() @@ -21,11 +20,13 @@ def step_impl(context: PNContext, timetoken: str): # noqa F811 """ -@when(u'I join {channel1}, {channel2}, {channel3} channels') -def step_impl(context, channel1, channel2, channel3): +@when("I join '{channel1}', '{channel2}', '{channel3}' channels") +@async_run_until_complete +async def step_impl(context, channel1, channel2, channel3): context.pubnub.subscribe().channels([channel1, channel2, channel3]).execute() -@when(u'I join {channel1}, {channel2}, {channel3} channels with presence') -def step_impl(context, channel1, channel2, channel3): +@when("I join '{channel1}', '{channel2}', '{channel3}' channels with presence") +@async_run_until_complete +async def step_impl(context, channel1, channel2, channel3): context.pubnub.subscribe().channels([channel1, channel2, channel3]).with_presence().execute() diff --git a/tests/functional/event_engine/test_managed_effect.py b/tests/functional/event_engine/test_managed_effect.py index 26c46530..04c55e8e 100644 --- a/tests/functional/event_engine/test_managed_effect.py +++ b/tests/functional/event_engine/test_managed_effect.py @@ -1,10 +1,16 @@ +import pytest +import asyncio + from unittest.mock import patch from pubnub.enums import PNReconnectionPolicy from pubnub.event_engine import manage_effects from pubnub.event_engine.models import effects from pubnub.event_engine.dispatcher import Dispatcher +from pubnub.event_engine.models import states from pubnub.event_engine.models.states import UnsubscribedState from pubnub.event_engine.statemachine import StateMachine +from pubnub.pubnub_asyncio import PubNubAsyncio +from tests.helper import pnconf_env_copy class FakeConfig: @@ -82,3 +88,15 @@ def test_dispatch_stop_receive_reconnect_effect(): dispatcher.dispatch_effect(effects.ReceiveReconnectEffect(['chan'])) dispatcher.dispatch_effect(effects.CancelReceiveReconnectEffect()) mocked_stop.assert_called() + + +@pytest.mark.asyncio +async def test_cancel_effect(): + pubnub = PubNubAsyncio(pnconf_env_copy()) + event_engine = StateMachine(states.HeartbeatInactiveState, name="presence") + managed_effects_factory = manage_effects.ManagedEffectFactory(pubnub, event_engine) + managed_wait_effect = managed_effects_factory.create(effect=effects.HeartbeatWaitEffect(10)) + managed_wait_effect.run() + await asyncio.sleep(1) + managed_wait_effect.stop() + await pubnub.stop() diff --git a/tests/functional/event_engine/test_subscribe.py b/tests/functional/event_engine/test_subscribe.py index 37fbaf50..588c60e8 100644 --- a/tests/functional/event_engine/test_subscribe.py +++ b/tests/functional/event_engine/test_subscribe.py @@ -62,6 +62,7 @@ async def test_subscribe(): message_callback.assert_called() pubnub.unsubscribe_all() pubnub._subscription_manager.stop() + await pubnub.stop() async def delayed_publish(channel, message, delay): @@ -84,6 +85,7 @@ async def test_handshaking(): assert pubnub._subscription_manager.event_engine.get_state_name() == states.ReceivingState.__name__ status_callback.assert_called() pubnub._subscription_manager.stop() + await pubnub.stop() @pytest.mark.asyncio @@ -112,7 +114,7 @@ def is_state(state): assert pubnub._subscription_manager.event_engine.get_state_name() == states.HandshakeFailedState.__name__ pubnub._subscription_manager.stop() - await pubnub.close_session() + await pubnub.stop() @pytest.mark.asyncio @@ -141,3 +143,4 @@ def is_state(state): .until_async(lambda: is_state(states.HandshakeReconnectingState.__name__)) assert pubnub._subscription_manager.event_engine.get_state_name() == states.HandshakeReconnectingState.__name__ pubnub._subscription_manager.stop() + await pubnub.stop()