Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: update state when connection is hangup #1841

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions packages/core/src/modules/connections/ConnectionEvents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import type { BaseEvent } from '../../agent/Events'

export enum ConnectionEventTypes {
ConnectionStateChanged = 'ConnectionStateChanged',
ConnectionDidRotated = 'ConnectionDidRotated',
}

export interface ConnectionStateChangedEvent extends BaseEvent {
Expand All @@ -13,3 +14,19 @@ export interface ConnectionStateChangedEvent extends BaseEvent {
previousState: DidExchangeState | null
}
}

export interface ConnectionDidRotatedEvent extends BaseEvent {
type: typeof ConnectionEventTypes.ConnectionDidRotated
payload: {
connectionRecord: ConnectionRecord

ourDid?: {
from: string
to: string
}
theirDid?: {
from: string
to: string
}
}
}
50 changes: 47 additions & 3 deletions packages/core/src/modules/connections/__tests__/did-rotate.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
/* eslint-disable @typescript-eslint/no-non-null-assertion */

import type { ConnectionRecord } from '../repository'

import { ReplaySubject, first, firstValueFrom, timeout } from 'rxjs'

import { MessageSender } from '../../..//agent/MessageSender'
Expand All @@ -11,6 +9,8 @@ import {
makeConnection,
waitForAgentMessageProcessedEvent,
waitForBasicMessage,
waitForConnectionRecord,
waitForDidRotate,
} from '../../../../tests/helpers'
import { Agent } from '../../../agent/Agent'
import { getOutboundMessageContext } from '../../../agent/getOutboundMessageContext'
Expand All @@ -20,6 +20,8 @@ import { BasicMessage } from '../../basic-messages'
import { createPeerDidDocumentFromServices } from '../../dids'
import { ConnectionsModule } from '../ConnectionsModule'
import { DidRotateProblemReportMessage, HangupMessage, DidRotateAckMessage } from '../messages'
import { DidExchangeState } from '../models'
import { ConnectionRecord } from '../repository'

import { InMemoryDidRegistry } from './InMemoryDidRegistry'

Expand Down Expand Up @@ -233,11 +235,33 @@ describe('Rotation E2E tests', () => {
didDocument,
})

const waitForAllDidRotate = Promise.all([waitForDidRotate(aliceAgent, {}), waitForDidRotate(bobAgent, {})])

// Do did rotate
await aliceAgent.connections.rotate({ connectionId: aliceBobConnection!.id, toDid: did })

// Wait for acknowledge
await waitForAgentMessageProcessedEvent(aliceAgent, { messageType: DidRotateAckMessage.type.messageTypeUri })
const [firstRotate, secondRotate] = await waitForAllDidRotate

const preRotateDid = aliceBobConnection!.did
expect(firstRotate).toEqual({
connectionRecord: expect.any(ConnectionRecord),
ourDid: {
from: preRotateDid,
to: did,
},
theirDid: undefined,
})

expect(secondRotate).toEqual({
connectionRecord: expect.any(ConnectionRecord),
ourDid: undefined,
theirDid: {
from: preRotateDid,
to: did,
},
})

// Send message to previous did
await bobAgent.dependencyManager.resolve(MessageSender).sendMessage(messageToPreviousDid)
Expand Down Expand Up @@ -323,13 +347,33 @@ describe('Rotation E2E tests', () => {
connectionRecord: bobAliceConnection!.clone(),
})

const connectionsAbandoned = Promise.all([
waitForConnectionRecord(aliceAgent, {
state: DidExchangeState.Abandoned,
threadId: aliceBobConnection?.threadId,
}),
waitForConnectionRecord(bobAgent, {
state: DidExchangeState.Abandoned,
threadId: aliceBobConnection?.threadId,
}),
])
await aliceAgent.connections.hangup({ connectionId: aliceBobConnection!.id })

// Wait for hangup
await waitForAgentMessageProcessedEvent(bobAgent, {
messageType: HangupMessage.type.messageTypeUri,
})

const [aliceAbandoned, bobAbandoned] = await connectionsAbandoned
expect(aliceAbandoned).toMatchObject({
state: DidExchangeState.Abandoned,
errorMessage: 'Connection hangup by us',
})
expect(bobAbandoned).toMatchObject({
state: DidExchangeState.Abandoned,
errorMessage: 'Connection hangup by other party',
})

// If Bob attempts to send a message to Alice after they received the hangup, framework should reject it
expect(bobAgent.basicMessages.sendMessage(bobAliceConnection!.id, 'Message after hangup')).rejects.toThrowError()

Expand Down Expand Up @@ -358,7 +402,7 @@ describe('Rotation E2E tests', () => {
await aliceAgent.connections.hangup({ connectionId: aliceBobConnection!.id, deleteAfterHangup: true })

// Verify that alice connection has been effectively deleted
expect(aliceAgent.connections.getById(aliceBobConnection!.id)).rejects.toThrowError(RecordNotFoundError)
expect(aliceAgent.connections.getById(aliceBobConnection!.id)).rejects.toThrow(RecordNotFoundError)

// Wait for hangup
await waitForAgentMessageProcessedEvent(bobAgent, {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import type { Routing } from './ConnectionService'
import type { AgentContext } from '../../../agent'
import type { InboundMessageContext } from '../../../agent/models/InboundMessageContext'
import type { ConnectionDidRotatedEvent, ConnectionStateChangedEvent } from '../ConnectionEvents'
import type { ConnectionRecord } from '../repository/ConnectionRecord'

import { EventEmitter } from '../../../agent/EventEmitter'
import { OutboundMessageContext } from '../../../agent/models'
import { InjectionSymbols } from '../../../constants'
import { CredoError } from '../../../error'
Expand All @@ -18,8 +20,10 @@ import {
isValidPeerDid,
} from '../../dids'
import { getMediationRecordForDidDocument } from '../../routing/services/helpers'
import { ConnectionEventTypes } from '../ConnectionEvents'
import { ConnectionsModuleConfig } from '../ConnectionsModuleConfig'
import { DidRotateMessage, DidRotateAckMessage, DidRotateProblemReportMessage, HangupMessage } from '../messages'
import { DidExchangeState } from '../models'
import { ConnectionMetadataKeys } from '../repository/ConnectionMetadataTypes'

import { ConnectionService } from './ConnectionService'
Expand All @@ -29,10 +33,16 @@ import { createPeerDidFromServices, getDidDocumentForCreatedDid, routingToServic
export class DidRotateService {
private didResolverService: DidResolverService
private logger: Logger
private eventEmitter: EventEmitter

public constructor(didResolverService: DidResolverService, @inject(InjectionSymbols.Logger) logger: Logger) {
public constructor(
didResolverService: DidResolverService,
@inject(InjectionSymbols.Logger) logger: Logger,
eventEmitter: EventEmitter
) {
this.didResolverService = didResolverService
this.logger = logger
this.eventEmitter = eventEmitter
}

public async createRotate(
Expand Down Expand Up @@ -95,9 +105,13 @@ export class DidRotateService {
connection.previousDids = [...connection.previousDids, connection.did]
}

const previousState = connection.state
connection.did = undefined
connection.state = DidExchangeState.Abandoned
connection.errorMessage = 'Connection hangup by us'

await agentContext.dependencyManager.resolve(ConnectionService).update(agentContext, connection)
this.emitStateChangedEvent(agentContext, connection, previousState)

return message
}
Expand All @@ -119,9 +133,13 @@ export class DidRotateService {
connection.previousTheirDids = [...connection.previousTheirDids, connection.theirDid]
}

const previousState = connection.state
connection.theirDid = undefined
connection.state = DidExchangeState.Abandoned
connection.errorMessage = 'Connection hangup by other party'

await agentContext.dependencyManager.resolve(ConnectionService).update(agentContext, connection)
this.emitStateChangedEvent(agentContext, connection, previousState)
}

/**
Expand Down Expand Up @@ -197,9 +215,13 @@ export class DidRotateService {
connection.previousTheirDids = [...connection.previousTheirDids, connection.theirDid]
}

const previousTheirDid = connection.theirDid
connection.theirDid = newDid

await agentContext.dependencyManager.resolve(ConnectionService).update(agentContext, connection)
this.emitDidRotatedEvent(agentContext, connection, {
previousTheirDid,
})

return outboundMessageContext
}
Expand All @@ -225,11 +247,15 @@ export class DidRotateService {
// Store previous did in order to still accept out-of-order messages that arrived later using it
if (connection.did) connection.previousDids = [...connection.previousDids, connection.did]

const previousOurDid = connection.did
connection.did = didRotateMetadata.did
connection.mediatorId = didRotateMetadata.mediatorId
connection.metadata.delete(ConnectionMetadataKeys.DidRotate)

await agentContext.dependencyManager.resolve(ConnectionService).update(agentContext, connection)
this.emitDidRotatedEvent(agentContext, connection, {
previousOurDid,
})
}

/**
Expand Down Expand Up @@ -271,4 +297,49 @@ export class DidRotateService {

await agentContext.dependencyManager.resolve(ConnectionService).update(agentContext, connection)
}

private emitDidRotatedEvent(
agentContext: AgentContext,
connectionRecord: ConnectionRecord,
{ previousOurDid, previousTheirDid }: { previousOurDid?: string; previousTheirDid?: string }
) {
this.eventEmitter.emit<ConnectionDidRotatedEvent>(agentContext, {
type: ConnectionEventTypes.ConnectionDidRotated,
payload: {
// Connection record in event should be static
connectionRecord: connectionRecord.clone(),

ourDid:
previousOurDid && connectionRecord.did
? {
from: previousOurDid,
to: connectionRecord.did,
}
: undefined,

theirDid:
previousTheirDid && connectionRecord.theirDid
? {
from: previousTheirDid,
to: connectionRecord.theirDid,
}
: undefined,
},
})
}

private emitStateChangedEvent(
agentContext: AgentContext,
connectionRecord: ConnectionRecord,
previousState: DidExchangeState | null
) {
this.eventEmitter.emit<ConnectionStateChangedEvent>(agentContext, {
type: ConnectionEventTypes.ConnectionStateChanged,
payload: {
// Connection record in event should be static
connectionRecord: connectionRecord.clone(),
previousState,
},
})
}
}
57 changes: 52 additions & 5 deletions packages/core/tests/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
AgentMessageProcessedEvent,
RevocationNotificationReceivedEvent,
KeyDidCreateOptions,
ConnectionDidRotatedEvent,
} from '../src'
import type { AgentModulesInput, EmptyModuleMap } from '../src/agent/AgentModules'
import type { TrustPingReceivedEvent, TrustPingResponseReceivedEvent } from '../src/modules/connections/TrustPingEvents'
Expand All @@ -28,7 +29,7 @@
import { readFileSync } from 'fs'
import path from 'path'
import { lastValueFrom, firstValueFrom, ReplaySubject } from 'rxjs'
import { catchError, filter, map, take, timeout } from 'rxjs/operators'
import { catchError, filter, map, take, tap, timeout } from 'rxjs/operators'

Check warning on line 32 in packages/core/tests/helpers.ts

View workflow job for this annotation

GitHub Actions / Validate

'tap' is defined but never used

import { InMemoryWalletModule } from '../../../tests/InMemoryWalletModule'
import { agentDependencies } from '../../node/src'
Expand Down Expand Up @@ -231,6 +232,8 @@
e.type === CredentialEventTypes.CredentialStateChanged
const isConnectionStateChangedEvent = (e: BaseEvent): e is ConnectionStateChangedEvent =>
e.type === ConnectionEventTypes.ConnectionStateChanged
const isConnectionDidRotatedEvent = (e: BaseEvent): e is ConnectionDidRotatedEvent =>
e.type === ConnectionEventTypes.ConnectionDidRotated
const isTrustPingReceivedEvent = (e: BaseEvent): e is TrustPingReceivedEvent =>
e.type === TrustPingEventTypes.TrustPingReceivedEvent
const isTrustPingResponseReceivedEvent = (e: BaseEvent): e is TrustPingResponseReceivedEvent =>
Expand Down Expand Up @@ -455,6 +458,38 @@
return waitForCredentialRecordSubject(observable, options)
}

export function waitForDidRotateSubject(
subject: ReplaySubject<BaseEvent> | Observable<BaseEvent>,
{
threadId,
state,
timeoutMs = 15000, // sign and store credential in W3c credential protocols take several seconds
}: {
threadId?: string
state?: DidExchangeState
previousState?: DidExchangeState | null
timeoutMs?: number
}
) {
const observable = subject instanceof ReplaySubject ? subject.asObservable() : subject

return firstValueFrom(
observable.pipe(
filter(isConnectionDidRotatedEvent),
filter((e) => threadId === undefined || e.payload.connectionRecord.threadId === threadId),
filter((e) => state === undefined || e.payload.connectionRecord.state === state),
timeout(timeoutMs),
catchError(() => {
throw new Error(`ConnectionDidRotated event not emitted within specified timeout: {
threadId: ${threadId},
state: ${state}
}`)
}),
map((e) => e.payload)
)
)
}

export function waitForConnectionRecordSubject(
subject: ReplaySubject<BaseEvent> | Observable<BaseEvent>,
{
Expand All @@ -480,10 +515,10 @@
timeout(timeoutMs),
catchError(() => {
throw new Error(`ConnectionStateChanged event not emitted within specified timeout: {
previousState: ${previousState},
threadId: ${threadId},
state: ${state}
}`)
previousState: ${previousState},
threadId: ${threadId},
state: ${state}
}`)
}),
map((e) => e.payload.connectionRecord)
)
Expand All @@ -503,6 +538,18 @@
return waitForConnectionRecordSubject(observable, options)
}

export async function waitForDidRotate(
agent: Agent,
options: {
threadId?: string
state?: DidExchangeState
timeoutMs?: number
}
) {
const observable = agent.events.observable<ConnectionDidRotatedEvent>(ConnectionEventTypes.ConnectionDidRotated)
return waitForDidRotateSubject(observable, options)
}

export async function waitForBasicMessage(
agent: Agent,
{ content, connectionId }: { content?: string; connectionId?: string }
Expand Down
Loading