-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: move outgoing message check from status-go to go-waku (#1180)
- Loading branch information
1 parent
5aa1131
commit 240051b
Showing
2 changed files
with
281 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,248 @@ | ||
package publish | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"sync" | ||
"time" | ||
|
||
"github.com/ethereum/go-ethereum/common" | ||
"github.com/ethereum/go-ethereum/common/hexutil" | ||
"github.com/libp2p/go-libp2p/core/peer" | ||
"github.com/waku-org/go-waku/waku/v2/protocol" | ||
"github.com/waku-org/go-waku/waku/v2/protocol/pb" | ||
"github.com/waku-org/go-waku/waku/v2/protocol/store" | ||
"github.com/waku-org/go-waku/waku/v2/timesource" | ||
"go.uber.org/zap" | ||
) | ||
|
||
const DefaultMaxHashQueryLength = 100 | ||
const DefaultHashQueryInterval = 3 * time.Second | ||
const DefaultMessageSentPeriod = 3 // in seconds | ||
const DefaultMessageExpiredPerid = 10 // in seconds | ||
|
||
type MessageSentCheckOption func(*MessageSentCheck) error | ||
|
||
// MessageSentCheck tracks the outgoing messages and check against store node | ||
// if the message sent time has passed the `messageSentPeriod`, the message id will be includes for the next query | ||
// if the message keeps missing after `messageExpiredPerid`, the message id will be expired | ||
type MessageSentCheck struct { | ||
messageIDs map[string]map[common.Hash]uint32 | ||
messageIDsMu sync.RWMutex | ||
storePeerID peer.ID | ||
MessageStoredChan chan common.Hash | ||
MessageExpiredChan chan common.Hash | ||
ctx context.Context | ||
store *store.WakuStore | ||
timesource timesource.Timesource | ||
logger *zap.Logger | ||
maxHashQueryLength uint64 | ||
hashQueryInterval time.Duration | ||
messageSentPeriod uint32 | ||
messageExpiredPerid uint32 | ||
} | ||
|
||
// NewMessageSentCheck creates a new instance of MessageSentCheck with default parameters | ||
func NewMessageSentCheck(ctx context.Context, store *store.WakuStore, timesource timesource.Timesource, logger *zap.Logger) *MessageSentCheck { | ||
return &MessageSentCheck{ | ||
messageIDs: make(map[string]map[common.Hash]uint32), | ||
messageIDsMu: sync.RWMutex{}, | ||
MessageStoredChan: make(chan common.Hash, 1000), | ||
MessageExpiredChan: make(chan common.Hash, 1000), | ||
ctx: ctx, | ||
store: store, | ||
timesource: timesource, | ||
logger: logger, | ||
maxHashQueryLength: DefaultMaxHashQueryLength, | ||
hashQueryInterval: DefaultHashQueryInterval, | ||
messageSentPeriod: DefaultMessageSentPeriod, | ||
messageExpiredPerid: DefaultMessageExpiredPerid, | ||
} | ||
} | ||
|
||
// WithMaxHashQueryLength sets the maximum number of message hashes to query in one request | ||
func WithMaxHashQueryLength(count uint64) MessageSentCheckOption { | ||
return func(params *MessageSentCheck) error { | ||
params.maxHashQueryLength = count | ||
return nil | ||
} | ||
} | ||
|
||
// WithHashQueryInterval sets the interval to query the store node | ||
func WithHashQueryInterval(interval time.Duration) MessageSentCheckOption { | ||
return func(params *MessageSentCheck) error { | ||
params.hashQueryInterval = interval | ||
return nil | ||
} | ||
} | ||
|
||
// WithMessageSentPeriod sets the delay period to query the store node after message is published | ||
func WithMessageSentPeriod(period uint32) MessageSentCheckOption { | ||
return func(params *MessageSentCheck) error { | ||
params.messageSentPeriod = period | ||
return nil | ||
} | ||
} | ||
|
||
// WithMessageExpiredPerid sets the period that a message is considered expired | ||
func WithMessageExpiredPerid(period uint32) MessageSentCheckOption { | ||
return func(params *MessageSentCheck) error { | ||
params.messageExpiredPerid = period | ||
return nil | ||
} | ||
} | ||
|
||
// Add adds a message for message sent check | ||
func (m *MessageSentCheck) Add(topic string, messageID common.Hash, sentTime uint32) { | ||
m.messageIDsMu.Lock() | ||
defer m.messageIDsMu.Unlock() | ||
|
||
if _, ok := m.messageIDs[topic]; !ok { | ||
m.messageIDs[topic] = make(map[common.Hash]uint32) | ||
} | ||
m.messageIDs[topic][messageID] = sentTime | ||
} | ||
|
||
// DeleteByMessageIDs deletes the message ids from the message sent check, used by scenarios like message acked with MVDS | ||
func (m *MessageSentCheck) DeleteByMessageIDs(messageIDs []common.Hash) { | ||
m.messageIDsMu.Lock() | ||
defer m.messageIDsMu.Unlock() | ||
|
||
for pubsubTopic, subMsgs := range m.messageIDs { | ||
for _, hash := range messageIDs { | ||
delete(subMsgs, hash) | ||
if len(subMsgs) == 0 { | ||
delete(m.messageIDs, pubsubTopic) | ||
} else { | ||
m.messageIDs[pubsubTopic] = subMsgs | ||
} | ||
} | ||
} | ||
} | ||
|
||
// SetStorePeerID sets the peer id of store node | ||
func (m *MessageSentCheck) SetStorePeerID(peerID peer.ID) { | ||
m.storePeerID = peerID | ||
} | ||
|
||
// CheckIfMessagesStored checks if the tracked outgoing messages are stored periodically | ||
func (m *MessageSentCheck) CheckIfMessagesStored() { | ||
ticker := time.NewTicker(m.hashQueryInterval) | ||
defer ticker.Stop() | ||
for { | ||
select { | ||
case <-m.ctx.Done(): | ||
m.logger.Debug("stop the look for message stored check") | ||
return | ||
case <-ticker.C: | ||
m.messageIDsMu.Lock() | ||
m.logger.Debug("running loop for messages stored check", zap.Any("messageIds", m.messageIDs)) | ||
pubsubTopics := make([]string, 0, len(m.messageIDs)) | ||
pubsubMessageIds := make([][]common.Hash, 0, len(m.messageIDs)) | ||
pubsubMessageTime := make([][]uint32, 0, len(m.messageIDs)) | ||
for pubsubTopic, subMsgs := range m.messageIDs { | ||
var queryMsgIds []common.Hash | ||
var queryMsgTime []uint32 | ||
for msgID, sendTime := range subMsgs { | ||
if uint64(len(queryMsgIds)) >= m.maxHashQueryLength { | ||
break | ||
} | ||
// message is sent 5 seconds ago, check if it's stored | ||
if uint32(m.timesource.Now().Unix()) > sendTime+m.messageSentPeriod { | ||
queryMsgIds = append(queryMsgIds, msgID) | ||
queryMsgTime = append(queryMsgTime, sendTime) | ||
} | ||
} | ||
m.logger.Debug("store query for message hashes", zap.Any("queryMsgIds", queryMsgIds), zap.String("pubsubTopic", pubsubTopic)) | ||
if len(queryMsgIds) > 0 { | ||
pubsubTopics = append(pubsubTopics, pubsubTopic) | ||
pubsubMessageIds = append(pubsubMessageIds, queryMsgIds) | ||
pubsubMessageTime = append(pubsubMessageTime, queryMsgTime) | ||
} | ||
} | ||
m.messageIDsMu.Unlock() | ||
|
||
pubsubProcessedMessages := make([][]common.Hash, len(pubsubTopics)) | ||
for i, pubsubTopic := range pubsubTopics { | ||
processedMessages := m.messageHashBasedQuery(m.ctx, pubsubMessageIds[i], pubsubMessageTime[i], pubsubTopic) | ||
pubsubProcessedMessages[i] = processedMessages | ||
} | ||
|
||
m.messageIDsMu.Lock() | ||
for i, pubsubTopic := range pubsubTopics { | ||
subMsgs, ok := m.messageIDs[pubsubTopic] | ||
if !ok { | ||
continue | ||
} | ||
for _, hash := range pubsubProcessedMessages[i] { | ||
delete(subMsgs, hash) | ||
if len(subMsgs) == 0 { | ||
delete(m.messageIDs, pubsubTopic) | ||
} else { | ||
m.messageIDs[pubsubTopic] = subMsgs | ||
} | ||
} | ||
} | ||
m.logger.Debug("messages for next store hash query", zap.Any("messageIds", m.messageIDs)) | ||
m.messageIDsMu.Unlock() | ||
|
||
} | ||
} | ||
} | ||
|
||
func (m *MessageSentCheck) messageHashBasedQuery(ctx context.Context, hashes []common.Hash, relayTime []uint32, pubsubTopic string) []common.Hash { | ||
selectedPeer := m.storePeerID | ||
if selectedPeer == "" { | ||
m.logger.Error("no store peer id available", zap.String("pubsubTopic", pubsubTopic)) | ||
return []common.Hash{} | ||
} | ||
|
||
var opts []store.RequestOption | ||
requestID := protocol.GenerateRequestID() | ||
opts = append(opts, store.WithRequestID(requestID)) | ||
opts = append(opts, store.WithPeer(selectedPeer)) | ||
opts = append(opts, store.WithPaging(false, m.maxHashQueryLength)) | ||
opts = append(opts, store.IncludeData(false)) | ||
|
||
messageHashes := make([]pb.MessageHash, len(hashes)) | ||
for i, hash := range hashes { | ||
messageHashes[i] = pb.ToMessageHash(hash.Bytes()) | ||
} | ||
|
||
m.logger.Debug("store.queryByHash request", zap.String("requestID", hexutil.Encode(requestID)), zap.Stringer("peerID", selectedPeer), zap.Any("messageHashes", messageHashes)) | ||
|
||
result, err := m.store.QueryByHash(ctx, messageHashes, opts...) | ||
if err != nil { | ||
m.logger.Error("store.queryByHash failed", zap.String("requestID", hexutil.Encode(requestID)), zap.Stringer("peerID", selectedPeer), zap.Error(err)) | ||
return []common.Hash{} | ||
} | ||
|
||
m.logger.Debug("store.queryByHash result", zap.String("requestID", hexutil.Encode(requestID)), zap.Int("messages", len(result.Messages()))) | ||
|
||
var ackHashes []common.Hash | ||
var missedHashes []common.Hash | ||
for i, hash := range hashes { | ||
found := false | ||
for _, msg := range result.Messages() { | ||
if bytes.Equal(msg.GetMessageHash(), hash.Bytes()) { | ||
found = true | ||
break | ||
} | ||
} | ||
|
||
if found { | ||
ackHashes = append(ackHashes, hash) | ||
m.MessageStoredChan <- hash | ||
} | ||
|
||
if !found && uint32(m.timesource.Now().Unix()) > relayTime[i]+m.messageExpiredPerid { | ||
missedHashes = append(missedHashes, hash) | ||
m.MessageExpiredChan <- hash | ||
} | ||
} | ||
|
||
m.logger.Debug("ack message hashes", zap.Any("ackHashes", ackHashes)) | ||
m.logger.Debug("missed message hashes", zap.Any("missedHashes", missedHashes)) | ||
|
||
return append(ackHashes, missedHashes...) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
package publish | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
"github.com/ethereum/go-ethereum/common" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestAddAndDelete(t *testing.T) { | ||
ctx := context.TODO() | ||
messageSentCheck := NewMessageSentCheck(ctx, nil, nil, nil) | ||
|
||
messageSentCheck.Add("topic", [32]byte{1}, 1) | ||
messageSentCheck.Add("topic", [32]byte{2}, 2) | ||
messageSentCheck.Add("topic", [32]byte{3}, 3) | ||
messageSentCheck.Add("another-topic", [32]byte{4}, 4) | ||
|
||
require.Equal(t, uint32(1), messageSentCheck.messageIDs["topic"][[32]byte{1}]) | ||
require.Equal(t, uint32(2), messageSentCheck.messageIDs["topic"][[32]byte{2}]) | ||
require.Equal(t, uint32(3), messageSentCheck.messageIDs["topic"][[32]byte{3}]) | ||
require.Equal(t, uint32(4), messageSentCheck.messageIDs["another-topic"][[32]byte{4}]) | ||
|
||
messageSentCheck.DeleteByMessageIDs([]common.Hash{[32]byte{1}, [32]byte{2}}) | ||
require.NotNil(t, messageSentCheck.messageIDs["topic"]) | ||
require.Equal(t, uint32(3), messageSentCheck.messageIDs["topic"][[32]byte{3}]) | ||
|
||
messageSentCheck.DeleteByMessageIDs([]common.Hash{[32]byte{3}}) | ||
require.Nil(t, messageSentCheck.messageIDs["topic"]) | ||
|
||
require.Equal(t, uint32(4), messageSentCheck.messageIDs["another-topic"][[32]byte{4}]) | ||
} |