diff --git a/acceptor.go b/acceptor.go index 9700fe0..dce5683 100644 --- a/acceptor.go +++ b/acceptor.go @@ -20,7 +20,10 @@ const ( var RouterRoles = map[string]any{ //nolint:gochecknoglobals "dealer": map[string]any{ - "features": map[string]any{}, + "features": map[string]any{ + FeatureProgressiveCallInvocations: true, + FeatureCallCancelling: true, + }, }, "broker": map[string]any{ "features": map[string]any{}, diff --git a/dealer.go b/dealer.go index 9a6b07a..00d7577 100644 --- a/dealer.go +++ b/dealer.go @@ -12,6 +12,12 @@ const ( OptionProgress = "progress" ) +const ( + FeatureProgressiveCallInvocations = "progressive_call_invocations" + FeatureProgressiveCallResults = "progressive_call_results" + FeatureCallCancelling = "call_canceling" +) + type PendingInvocation struct { RequestID int64 CallerID int64 @@ -27,11 +33,17 @@ type Registration struct { InvocationPolicy string } +type CallMap struct { + CallerID int64 + CallID int64 +} + type Dealer struct { sessions map[int64]*SessionDetails registrationsByProcedure map[string]*Registration registrationsBySession map[int64]map[int64]*Registration pendingCalls map[int64]*PendingInvocation + invocationIDbyCall map[CallMap]int64 idGen *SessionScopeIDGenerator sync.Mutex @@ -43,6 +55,7 @@ func NewDealer() *Dealer { registrationsByProcedure: make(map[string]*Registration), registrationsBySession: make(map[int64]map[int64]*Registration), pendingCalls: make(map[int64]*PendingInvocation), + invocationIDbyCall: make(map[CallMap]int64), idGen: &SessionScopeIDGenerator{}, } } @@ -113,13 +126,19 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message break } receiveProgress, _ := call.Options()[OptionReceiveProgress].(bool) - - invocationID := d.idGen.NextID() - d.pendingCalls[invocationID] = &PendingInvocation{ - RequestID: call.RequestID(), - CallerID: sessionID, - CalleeID: callee, - ReceiveProgress: receiveProgress, + progress, _ := call.Options()[OptionProgress].(bool) + + invocationID, ok := d.invocationIDbyCall[CallMap{CallerID: sessionID, CallID: call.RequestID()}] + if !ok || !progress { + invocationID = d.idGen.NextID() + d.pendingCalls[invocationID] = &PendingInvocation{ + RequestID: call.RequestID(), + CallerID: sessionID, + CalleeID: callee, + ReceiveProgress: receiveProgress, + Progress: progress, + } + d.invocationIDbyCall[CallMap{CallerID: sessionID, CallID: call.RequestID()}] = invocationID } var invocation *messages.Invocation @@ -127,7 +146,13 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message invocation = messages.NewInvocationBinary(invocationID, regs.ID, nil, call.Payload(), call.PayloadSerializer()) } else { - details := map[string]any{OptionReceiveProgress: receiveProgress} + details := map[string]any{} + if receiveProgress { + details[OptionReceiveProgress] = receiveProgress + } + if progress { + details[OptionProgress] = progress + } invocation = messages.NewInvocation(invocationID, regs.ID, details, call.Args(), call.KwArgs()) } diff --git a/dealer_test.go b/dealer_test.go index 0dd7122..2c47237 100644 --- a/dealer_test.go +++ b/dealer_test.go @@ -165,3 +165,47 @@ func TestProgressiveCallResults(t *testing.T) { progress, _ := result.Details()[wampproto.OptionReceiveProgress].(bool) require.False(t, progress) } + +func TestProgressiveCallInvocations(t *testing.T) { + dealer := wampproto.NewDealer() + + callee := wampproto.NewSessionDetails(1, "realm", "authid", "anonymous", false) + caller := wampproto.NewSessionDetails(2, "realm", "authid", "anonymous", false) + + err := dealer.AddSession(callee) + require.NoError(t, err) + err = dealer.AddSession(caller) + require.NoError(t, err) + + register := messages.NewRegister(3, nil, "foo.bar") + _, err = dealer.ReceiveMessage(callee.ID(), register) + require.NoError(t, err) + + call := messages.NewCall(4, map[string]any{wampproto.OptionProgress: true}, "foo.bar", []any{}, nil) + messageWithRecipient, err := dealer.ReceiveMessage(callee.ID(), call) + require.NoError(t, err) + require.Equal(t, callee.ID(), messageWithRecipient.Recipient) + + invMessage := messageWithRecipient.Message.(*messages.Invocation) + require.True(t, invMessage.Details()[wampproto.OptionProgress].(bool)) + + invRequestID := invMessage.RequestID() + for i := 0; i < 10; i++ { + call = messages.NewCall(4, map[string]any{wampproto.OptionProgress: true}, "foo.bar", []any{}, nil) + messageWithRecipient, err = dealer.ReceiveMessage(callee.ID(), call) + require.NoError(t, err) + + invMessage = messageWithRecipient.Message.(*messages.Invocation) + require.True(t, invMessage.Details()[wampproto.OptionProgress].(bool)) + require.Equal(t, invRequestID, invMessage.RequestID()) + } + + finalCall := messages.NewCall(4, map[string]any{}, "foo.bar", []any{}, nil) + messageWithRecipient, err = dealer.ReceiveMessage(callee.ID(), finalCall) + require.NoError(t, err) + require.Equal(t, callee.ID(), messageWithRecipient.Recipient) + + invocation := messageWithRecipient.Message.(*messages.Invocation) + inProgress, _ := invocation.Details()[wampproto.OptionProgress].(bool) + require.False(t, inProgress) +} diff --git a/joiner.go b/joiner.go index 7c6e1df..7cc2dff 100644 --- a/joiner.go +++ b/joiner.go @@ -11,12 +11,15 @@ import ( var ClientRoles = map[string]any{ //nolint:gochecknoglobals "caller": map[string]any{ - "features": map[string]any{}, + "features": map[string]any{ + FeatureProgressiveCallInvocations: true, + }, }, "callee": map[string]any{ "features": map[string]any{ - "progressive_call_results": true, - "call_canceling": true, + FeatureProgressiveCallInvocations: true, + FeatureProgressiveCallResults: true, + FeatureCallCancelling: true, }, }, "publisher": map[string]any{