chore: move outgoing message check from status-go to go-waku (#1180)
kaichaosun authored Aug 6, 2024
1 parent 5aa1131 commit 240051b
Showing 2 changed files with 281 additions and 0 deletions.
248 changes: 248 additions & 0 deletions waku/v2/api/publish/message_check.go
@@ -0,0 +1,248 @@
package publish

import (
    "bytes"
    "context"
    "sync"
    "time"

    "github.com/ethereum/go-ethereum/common"
    "github.com/ethereum/go-ethereum/common/hexutil"
    "github.com/libp2p/go-libp2p/core/peer"
    "github.com/waku-org/go-waku/waku/v2/protocol"
    "github.com/waku-org/go-waku/waku/v2/protocol/pb"
    "github.com/waku-org/go-waku/waku/v2/protocol/store"
    "github.com/waku-org/go-waku/waku/v2/timesource"
    "go.uber.org/zap"
)

const DefaultMaxHashQueryLength = 100
const DefaultHashQueryInterval = 3 * time.Second
const DefaultMessageSentPeriod = 3 // in seconds
const DefaultMessageExpiredPerid = 10 // in seconds

type MessageSentCheckOption func(*MessageSentCheck) error

// MessageSentCheck tracks outgoing messages and checks them against a store node.
// Once a message's send time is more than `messageSentPeriod` seconds in the past, its hash is included in the next store query.
// If the message is still missing after `messageExpiredPerid` seconds, it is reported as expired.
type MessageSentCheck struct {
    messageIDs          map[string]map[common.Hash]uint32
    messageIDsMu        sync.RWMutex
    storePeerID         peer.ID
    MessageStoredChan   chan common.Hash
    MessageExpiredChan  chan common.Hash
    ctx                 context.Context
    store               *store.WakuStore
    timesource          timesource.Timesource
    logger              *zap.Logger
    maxHashQueryLength  uint64
    hashQueryInterval   time.Duration
    messageSentPeriod   uint32
    messageExpiredPerid uint32
}

// NewMessageSentCheck creates a new instance of MessageSentCheck with default parameters
func NewMessageSentCheck(ctx context.Context, store *store.WakuStore, timesource timesource.Timesource, logger *zap.Logger) *MessageSentCheck {
    return &MessageSentCheck{
        messageIDs:          make(map[string]map[common.Hash]uint32),
        messageIDsMu:        sync.RWMutex{},
        MessageStoredChan:   make(chan common.Hash, 1000),
        MessageExpiredChan:  make(chan common.Hash, 1000),
        ctx:                 ctx,
        store:               store,
        timesource:          timesource,
        logger:              logger,
        maxHashQueryLength:  DefaultMaxHashQueryLength,
        hashQueryInterval:   DefaultHashQueryInterval,
        messageSentPeriod:   DefaultMessageSentPeriod,
        messageExpiredPerid: DefaultMessageExpiredPerid,
    }
}

// WithMaxHashQueryLength sets the maximum number of message hashes to query in one request
func WithMaxHashQueryLength(count uint64) MessageSentCheckOption {
    return func(params *MessageSentCheck) error {
        params.maxHashQueryLength = count
        return nil
    }
}

// WithHashQueryInterval sets the interval to query the store node
func WithHashQueryInterval(interval time.Duration) MessageSentCheckOption {
    return func(params *MessageSentCheck) error {
        params.hashQueryInterval = interval
        return nil
    }
}

// WithMessageSentPeriod sets the delay, in seconds, before a published message is included in a store query
func WithMessageSentPeriod(period uint32) MessageSentCheckOption {
    return func(params *MessageSentCheck) error {
        params.messageSentPeriod = period
        return nil
    }
}

// WithMessageExpiredPerid sets the period, in seconds, after which a missing message is considered expired
func WithMessageExpiredPerid(period uint32) MessageSentCheckOption {
    return func(params *MessageSentCheck) error {
        params.messageExpiredPerid = period
        return nil
    }
}

// Add adds a message to the message sent check
func (m *MessageSentCheck) Add(topic string, messageID common.Hash, sentTime uint32) {
    m.messageIDsMu.Lock()
    defer m.messageIDsMu.Unlock()

    if _, ok := m.messageIDs[topic]; !ok {
        m.messageIDs[topic] = make(map[common.Hash]uint32)
    }
    m.messageIDs[topic][messageID] = sentTime
}

// DeleteByMessageIDs removes the given message IDs from the message sent check, e.g. when a message has already been acknowledged through MVDS
func (m *MessageSentCheck) DeleteByMessageIDs(messageIDs []common.Hash) {
    m.messageIDsMu.Lock()
    defer m.messageIDsMu.Unlock()

    for pubsubTopic, subMsgs := range m.messageIDs {
        for _, hash := range messageIDs {
            delete(subMsgs, hash)
            if len(subMsgs) == 0 {
                delete(m.messageIDs, pubsubTopic)
            } else {
                m.messageIDs[pubsubTopic] = subMsgs
            }
        }
    }
}

// SetStorePeerID sets the peer ID of the store node
func (m *MessageSentCheck) SetStorePeerID(peerID peer.ID) {
    m.storePeerID = peerID
}

// CheckIfMessagesStored periodically checks whether the tracked outgoing messages have been stored
func (m *MessageSentCheck) CheckIfMessagesStored() {
    ticker := time.NewTicker(m.hashQueryInterval)
    defer ticker.Stop()
    for {
        select {
        case <-m.ctx.Done():
            m.logger.Debug("stop the loop for message stored check")
            return
        case <-ticker.C:
            m.messageIDsMu.Lock()
            m.logger.Debug("running loop for messages stored check", zap.Any("messageIds", m.messageIDs))
            pubsubTopics := make([]string, 0, len(m.messageIDs))
            pubsubMessageIds := make([][]common.Hash, 0, len(m.messageIDs))
            pubsubMessageTime := make([][]uint32, 0, len(m.messageIDs))
            for pubsubTopic, subMsgs := range m.messageIDs {
                var queryMsgIds []common.Hash
                var queryMsgTime []uint32
                for msgID, sendTime := range subMsgs {
                    if uint64(len(queryMsgIds)) >= m.maxHashQueryLength {
                        break
                    }
                    // include the message if it was sent more than messageSentPeriod seconds ago
                    if uint32(m.timesource.Now().Unix()) > sendTime+m.messageSentPeriod {
                        queryMsgIds = append(queryMsgIds, msgID)
                        queryMsgTime = append(queryMsgTime, sendTime)
                    }
                }
                m.logger.Debug("store query for message hashes", zap.Any("queryMsgIds", queryMsgIds), zap.String("pubsubTopic", pubsubTopic))
                if len(queryMsgIds) > 0 {
                    pubsubTopics = append(pubsubTopics, pubsubTopic)
                    pubsubMessageIds = append(pubsubMessageIds, queryMsgIds)
                    pubsubMessageTime = append(pubsubMessageTime, queryMsgTime)
                }
            }
            m.messageIDsMu.Unlock()

            pubsubProcessedMessages := make([][]common.Hash, len(pubsubTopics))
            for i, pubsubTopic := range pubsubTopics {
                processedMessages := m.messageHashBasedQuery(m.ctx, pubsubMessageIds[i], pubsubMessageTime[i], pubsubTopic)
                pubsubProcessedMessages[i] = processedMessages
            }

            m.messageIDsMu.Lock()
            for i, pubsubTopic := range pubsubTopics {
                subMsgs, ok := m.messageIDs[pubsubTopic]
                if !ok {
                    continue
                }
                for _, hash := range pubsubProcessedMessages[i] {
                    delete(subMsgs, hash)
                    if len(subMsgs) == 0 {
                        delete(m.messageIDs, pubsubTopic)
                    } else {
                        m.messageIDs[pubsubTopic] = subMsgs
                    }
                }
            }
            m.logger.Debug("messages for next store hash query", zap.Any("messageIds", m.messageIDs))
            m.messageIDsMu.Unlock()

        }
    }
}

// messageHashBasedQuery queries the store node for the given message hashes; it signals found hashes on
// MessageStoredChan and expired ones on MessageExpiredChan, and returns the hashes processed in either way.
func (m *MessageSentCheck) messageHashBasedQuery(ctx context.Context, hashes []common.Hash, relayTime []uint32, pubsubTopic string) []common.Hash {
    selectedPeer := m.storePeerID
    if selectedPeer == "" {
        m.logger.Error("no store peer id available", zap.String("pubsubTopic", pubsubTopic))
        return []common.Hash{}
    }

    var opts []store.RequestOption
    requestID := protocol.GenerateRequestID()
    opts = append(opts, store.WithRequestID(requestID))
    opts = append(opts, store.WithPeer(selectedPeer))
    opts = append(opts, store.WithPaging(false, m.maxHashQueryLength))
    opts = append(opts, store.IncludeData(false))

    messageHashes := make([]pb.MessageHash, len(hashes))
    for i, hash := range hashes {
        messageHashes[i] = pb.ToMessageHash(hash.Bytes())
    }

    m.logger.Debug("store.queryByHash request", zap.String("requestID", hexutil.Encode(requestID)), zap.Stringer("peerID", selectedPeer), zap.Any("messageHashes", messageHashes))

    result, err := m.store.QueryByHash(ctx, messageHashes, opts...)
    if err != nil {
        m.logger.Error("store.queryByHash failed", zap.String("requestID", hexutil.Encode(requestID)), zap.Stringer("peerID", selectedPeer), zap.Error(err))
        return []common.Hash{}
    }

    m.logger.Debug("store.queryByHash result", zap.String("requestID", hexutil.Encode(requestID)), zap.Int("messages", len(result.Messages())))

    var ackHashes []common.Hash
    var missedHashes []common.Hash
    for i, hash := range hashes {
        found := false
        for _, msg := range result.Messages() {
            if bytes.Equal(msg.GetMessageHash(), hash.Bytes()) {
                found = true
                break
            }
        }

        if found {
            ackHashes = append(ackHashes, hash)
            m.MessageStoredChan <- hash
        }

        if !found && uint32(m.timesource.Now().Unix()) > relayTime[i]+m.messageExpiredPerid {
            missedHashes = append(missedHashes, hash)
            m.MessageExpiredChan <- hash
        }
    }

    m.logger.Debug("ack message hashes", zap.Any("ackHashes", ackHashes))
    m.logger.Debug("missed message hashes", zap.Any("missedHashes", missedHashes))

    return append(ackHashes, missedHashes...)
}
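
For orientation, the following is a minimal usage sketch (not part of the commit) of how a caller might wire the checker up: construct it, point it at a store peer, run the periodic check loop, and drain the two result channels. The runSentCheck helper, and the way wakuStore, ts, logger and storePeer are obtained, are illustrative assumptions; only the MessageSentCheck API itself comes from the diff above.

// Hypothetical wiring sketch, assuming it sits alongside the publish package and
// that wakuStore, ts, logger and storePeer come from the host application's node setup.
func runSentCheck(ctx context.Context, wakuStore *store.WakuStore, ts timesource.Timesource, logger *zap.Logger, storePeer peer.ID) *MessageSentCheck {
    check := NewMessageSentCheck(ctx, wakuStore, ts, logger)
    check.SetStorePeerID(storePeer)

    // Run the periodic store-query loop in the background.
    go check.CheckIfMessagesStored()

    // Drain the result channels so the checker never blocks when reporting.
    go func() {
        for {
            select {
            case <-ctx.Done():
                return
            case hash := <-check.MessageStoredChan:
                logger.Info("message confirmed by store node", zap.Stringer("hash", hash))
            case hash := <-check.MessageExpiredChan:
                logger.Info("message expired without confirmation", zap.Stringer("hash", hash))
            }
        }
    }()

    // After publishing a message, register it for tracking:
    //   check.Add(pubsubTopic, msgHash, uint32(ts.Now().Unix()))
    return check
}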
33 changes: 33 additions & 0 deletions waku/v2/api/publish/message_check_test.go
@@ -0,0 +1,33 @@
package publish

import (
    "context"
    "testing"

    "github.com/ethereum/go-ethereum/common"
    "github.com/stretchr/testify/require"
)

func TestAddAndDelete(t *testing.T) {
    ctx := context.TODO()
    messageSentCheck := NewMessageSentCheck(ctx, nil, nil, nil)

    messageSentCheck.Add("topic", [32]byte{1}, 1)
    messageSentCheck.Add("topic", [32]byte{2}, 2)
    messageSentCheck.Add("topic", [32]byte{3}, 3)
    messageSentCheck.Add("another-topic", [32]byte{4}, 4)

    require.Equal(t, uint32(1), messageSentCheck.messageIDs["topic"][[32]byte{1}])
    require.Equal(t, uint32(2), messageSentCheck.messageIDs["topic"][[32]byte{2}])
    require.Equal(t, uint32(3), messageSentCheck.messageIDs["topic"][[32]byte{3}])
    require.Equal(t, uint32(4), messageSentCheck.messageIDs["another-topic"][[32]byte{4}])

    messageSentCheck.DeleteByMessageIDs([]common.Hash{[32]byte{1}, [32]byte{2}})
    require.NotNil(t, messageSentCheck.messageIDs["topic"])
    require.Equal(t, uint32(3), messageSentCheck.messageIDs["topic"][[32]byte{3}])

    messageSentCheck.DeleteByMessageIDs([]common.Hash{[32]byte{3}})
    require.Nil(t, messageSentCheck.messageIDs["topic"])

    require.Equal(t, uint32(4), messageSentCheck.messageIDs["another-topic"][[32]byte{4}])
}
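
The test covers Add and DeleteByMessageIDs but not the functional options. Since NewMessageSentCheck in this commit does not accept variadic options, a MessageSentCheckOption is applied by invoking it on the instance after construction. A hedged sketch, with illustrative values only:

// Illustrative only: applying functional options to an existing MessageSentCheck.
check := NewMessageSentCheck(context.TODO(), nil, nil, zap.NewNop())
for _, opt := range []MessageSentCheckOption{
    WithMaxHashQueryLength(50),             // query at most 50 hashes per request
    WithHashQueryInterval(5 * time.Second), // poll the store node every 5 seconds
} {
    if err := opt(check); err != nil {
        panic(err) // the options in this commit always return nil
    }
}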
