diff --git a/storage.go b/storage.go
index 6f20076..1f36b28 100644
--- a/storage.go
+++ b/storage.go
@@ -1528,3 +1528,155 @@ func (s *PersistentSlabStorage) getAllChildReferences(slab Slab) (
 
 	return references, brokenReferences, nil
 }
+
+// BatchPreload decodes and caches slabs of given ids in parallel.
+// This is useful for storage health or data validation in migration programs.
+func (s *PersistentSlabStorage) BatchPreload(ids []StorageID, numWorkers int) error {
+	if len(ids) == 0 {
+		return nil
+	}
+
+	// Use 11 for min slab count for parallel decoding because micro benchmarks showed
+	// performance regression for <= 10 slabs when decoding slabs in parallel.
+	const minCountForBatchPreload = 11
+	if len(ids) < minCountForBatchPreload {
+
+		for _, id := range ids {
+			// Retrieve slab data from base storage.
+			data, ok, err := s.baseStorage.Retrieve(id)
+			if err != nil {
+				// Wrap err as external error (if needed) because err is returned by BaseStorage interface.
+				return wrapErrorfAsExternalErrorIfNeeded(err, fmt.Sprintf("failed to retrieve slab %s", id))
+			}
+			if !ok {
+				continue
+			}
+
+			slab, err := DecodeSlab(id, data, s.cborDecMode, s.DecodeStorable, s.DecodeTypeInfo)
+			if err != nil {
+				// err is already categorized by DecodeSlab().
+				return err
+			}
+
+			// save decoded slab to cache
+			s.cache[id] = slab
+		}
+
+		return nil
+	}
+
+	type slabToBeDecoded struct {
+		slabID StorageID
+		data   []byte
+	}
+
+	type decodedSlab struct {
+		slabID StorageID
+		slab   Slab
+		err    error
+	}
+
+	// Define decoder (worker) to decode slabs in parallel
+	decoder := func(wg *sync.WaitGroup, done <-chan struct{}, jobs <-chan slabToBeDecoded, results chan<- decodedSlab) {
+		defer wg.Done()
+
+		for slabData := range jobs {
+			// Check if goroutine is signaled to stop before proceeding.
+			select {
+			case <-done:
+				return
+			default:
+			}
+
+			id := slabData.slabID
+			data := slabData.data
+
+			slab, err := DecodeSlab(id, data, s.cborDecMode, s.DecodeStorable, s.DecodeTypeInfo)
+			// err is already categorized by DecodeSlab().
+			results <- decodedSlab{
+				slabID: id,
+				slab:   slab,
+				err:    err,
+			}
+		}
+	}
+
+	if numWorkers > len(ids) {
+		numWorkers = len(ids)
+	}
+
+	var wg sync.WaitGroup
+
+	// Construct done signal channel
+	done := make(chan struct{})
+
+	// Construct job queue
+	jobs := make(chan slabToBeDecoded, len(ids))
+
+	// Construct result queue
+	results := make(chan decodedSlab, len(ids))
+
+	defer func() {
+		// This ensures that all goroutines are stopped before output channel is closed.
+
+		// Wait for all goroutines to finish
+		wg.Wait()
+
+		// Close output channel
+		close(results)
+	}()
+
+	// Preallocate cache map if empty
+	if len(s.cache) == 0 {
+		s.cache = make(map[StorageID]Slab, len(ids))
+	}
+
+	// Launch workers
+	wg.Add(numWorkers)
+	for i := 0; i < numWorkers; i++ {
+		go decoder(&wg, done, jobs, results)
+	}
+
+	// Send jobs
+	jobCount := 0
+	{
+		// Need to close input channel (jobs) here because
+		// if there isn't any job in jobs channel,
+		// done is never processed inside loop "for slabData := range jobs".
+		defer close(jobs)
+
+		for _, id := range ids {
+			// Retrieve slab data from base storage.
+			data, ok, err := s.baseStorage.Retrieve(id)
+			if err != nil {
+				// Closing done channel signals goroutines to stop.
+				close(done)
+				// Wrap err as external error (if needed) because err is returned by BaseStorage interface.
+				return wrapErrorfAsExternalErrorIfNeeded(err, fmt.Sprintf("failed to retrieve slab %s", id))
+			}
+			if !ok {
+				continue
+			}
+
+			jobs <- slabToBeDecoded{id, data}
+			jobCount++
+		}
+	}
+
+	// Process results
+	for i := 0; i < jobCount; i++ {
+		result := <-results
+
+		if result.err != nil {
+			// Closing done channel signals goroutines to stop.
+			close(done)
+			// result.err is already categorized by DecodeSlab().
+			return result.err
+		}
+
+		// save decoded slab to cache
+		s.cache[result.slabID] = result.slab
+	}
+
+	return nil
+}
diff --git a/storage_bench_test.go b/storage_bench_test.go
index 910126c..adc3253 100644
--- a/storage_bench_test.go
+++ b/storage_bench_test.go
@@ -132,3 +132,122 @@ func BenchmarkStorageNondeterministicFastCommit(b *testing.B) {
 	benchmarkNondeterministicFastCommit(b, fixedSeed, 100_000)
 	benchmarkNondeterministicFastCommit(b, fixedSeed, 1_000_000)
 }
+
+func benchmarkRetrieve(b *testing.B, seed int64, numberOfSlabs int) {
+
+	r := rand.New(rand.NewSource(seed))
+
+	encMode, err := cbor.EncOptions{}.EncMode()
+	require.NoError(b, err)
+
+	decMode, err := cbor.DecOptions{}.DecMode()
+	require.NoError(b, err)
+
+	encodedSlabs := make(map[StorageID][]byte)
+	ids := make([]StorageID, 0, numberOfSlabs)
+	for i := 0; i < numberOfSlabs; i++ {
+		addr := generateRandomAddress(r)
+
+		var index StorageIndex
+		binary.BigEndian.PutUint64(index[:], uint64(i))
+
+		id := StorageID{addr, index}
+
+		slab := generateLargeSlab(id)
+
+		data, err := Encode(slab, encMode)
+		require.NoError(b, err)
+
+		encodedSlabs[id] = data
+		ids = append(ids, id)
+	}
+
+	b.Run(strconv.Itoa(numberOfSlabs), func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			b.StopTimer()
+
+			baseStorage := NewInMemBaseStorageFromMap(encodedSlabs)
+			storage := NewPersistentSlabStorage(baseStorage, encMode, decMode, decodeStorable, decodeTypeInfo)
+
+			b.StartTimer()
+
+			for _, id := range ids {
+				_, found, err := storage.Retrieve(id)
+				require.True(b, found)
+				require.NoError(b, err)
+			}
+		}
+	})
+}
+
+func benchmarkBatchPreload(b *testing.B, seed int64, numberOfSlabs int) {
+
+	r := rand.New(rand.NewSource(seed))
+
+	encMode, err := cbor.EncOptions{}.EncMode()
+	require.NoError(b, err)
+
+	decMode, err := cbor.DecOptions{}.DecMode()
+	require.NoError(b, err)
+
+	encodedSlabs := make(map[StorageID][]byte)
+	ids := make([]StorageID, 0, numberOfSlabs)
+	for i := 0; i < numberOfSlabs; i++ {
+		addr := generateRandomAddress(r)
+
+		var index StorageIndex
+		binary.BigEndian.PutUint64(index[:], uint64(i))
+
+		id := StorageID{addr, index}
+
+		slab := generateLargeSlab(id)
+
+		data, err := Encode(slab, encMode)
+		require.NoError(b, err)
+
+		encodedSlabs[id] = data
+		ids = append(ids, id)
+	}
+
+	b.Run(strconv.Itoa(numberOfSlabs), func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			b.StopTimer()
+
+			baseStorage := NewInMemBaseStorageFromMap(encodedSlabs)
+			storage := NewPersistentSlabStorage(baseStorage, encMode, decMode, decodeStorable, decodeTypeInfo)
+
+			b.StartTimer()
+
+			err = storage.BatchPreload(ids, runtime.NumCPU())
+			require.NoError(b, err)
+
+			for _, id := range ids {
+				_, found, err := storage.Retrieve(id)
+				require.True(b, found)
+				require.NoError(b, err)
+			}
+		}
+	})
+}
+
+func BenchmarkStorageRetrieve(b *testing.B) {
+	fixedSeed := int64(1234567) // intentionally use fixed constant rather than time, etc.
+
+	benchmarkRetrieve(b, fixedSeed, 10)
+	benchmarkRetrieve(b, fixedSeed, 100)
+	benchmarkRetrieve(b, fixedSeed, 1_000)
+	benchmarkRetrieve(b, fixedSeed, 10_000)
+	benchmarkRetrieve(b, fixedSeed, 100_000)
+	benchmarkRetrieve(b, fixedSeed, 1_000_000)
+}
+
+func BenchmarkStorageBatchPreload(b *testing.B) {
+	fixedSeed := int64(1234567) // intentionally use fixed constant rather than time, etc.
+
+	benchmarkBatchPreload(b, fixedSeed, 10)
+	benchmarkBatchPreload(b, fixedSeed, 100)
+	benchmarkBatchPreload(b, fixedSeed, 1_000)
+	benchmarkBatchPreload(b, fixedSeed, 10_000)
+	benchmarkBatchPreload(b, fixedSeed, 100_000)
+	benchmarkBatchPreload(b, fixedSeed, 1_000_000)
+}
diff --git a/storage_test.go b/storage_test.go
index bfee615..672413b 100644
--- a/storage_test.go
+++ b/storage_test.go
@@ -19,6 +19,7 @@
 package atree
 
 import (
+	"encoding/binary"
 	"errors"
 	"math/rand"
 	"runtime"
@@ -4932,3 +4933,172 @@ func testStorageNondeterministicFastCommit(t *testing.T, numberOfAccounts int, n
 		require.Nil(t, storedValue)
 	}
 }
+
+func TestStorageBatchPreload(t *testing.T) {
+	t.Run("0 slab", func(t *testing.T) {
+		numberOfAccounts := 0
+		numberOfSlabsPerAccount := 0
+		testStorageBatchPreload(t, numberOfAccounts, numberOfSlabsPerAccount)
+	})
+
+	t.Run("1 slab", func(t *testing.T) {
+		numberOfAccounts := 1
+		numberOfSlabsPerAccount := 1
+		testStorageBatchPreload(t, numberOfAccounts, numberOfSlabsPerAccount)
+	})
+
+	t.Run("10 slab", func(t *testing.T) {
+		numberOfAccounts := 1
+		numberOfSlabsPerAccount := 10
+		testStorageBatchPreload(t, numberOfAccounts, numberOfSlabsPerAccount)
+	})
+
+	t.Run("100 slabs", func(t *testing.T) {
+		numberOfAccounts := 10
+		numberOfSlabsPerAccount := 10
+		testStorageBatchPreload(t, numberOfAccounts, numberOfSlabsPerAccount)
+	})
+
+	t.Run("10_000 slabs", func(t *testing.T) {
+		numberOfAccounts := 10
+		numberOfSlabsPerAccount := 1_000
+		testStorageBatchPreload(t, numberOfAccounts, numberOfSlabsPerAccount)
+	})
+}
+
+func testStorageBatchPreload(t *testing.T, numberOfAccounts int, numberOfSlabsPerAccount int) {
+
+	indexesByAddress := make(map[Address]uint64)
+
+	generateSlabID := func(address Address) StorageID {
+		nextIndex := indexesByAddress[address] + 1
+
+		var idx StorageIndex
+		binary.BigEndian.PutUint64(idx[:], nextIndex)
+
+		indexesByAddress[address] = nextIndex
+
+		return NewStorageID(address, idx)
+	}
+
+	encMode, err := cbor.EncOptions{}.EncMode()
+	require.NoError(t, err)
+
+	decMode, err := cbor.DecOptions{}.DecMode()
+	require.NoError(t, err)
+
+	r := newRand(t)
+
+	encodedSlabs := make(map[StorageID][]byte)
+
+	// Generate and encode slabs
+	for i := 0; i < numberOfAccounts; i++ {
+
+		addr := generateRandomAddress(r)
+
+		for j := 0; j < numberOfSlabsPerAccount; j++ {
+
+			slabID := generateSlabID(addr)
+
+			slab := generateRandomSlab(slabID, r)
+
+			encodedSlabs[slabID], err = Encode(slab, encMode)
+			require.NoError(t, err)
+		}
+	}
+
+	baseStorage := NewInMemBaseStorageFromMap(encodedSlabs)
+	storage := NewPersistentSlabStorage(baseStorage, encMode, decMode, decodeStorable, decodeTypeInfo)
+
+	ids := make([]StorageID, 0, len(encodedSlabs))
+	for id := range encodedSlabs {
+		ids = append(ids, id)
+	}
+
+	// Batch preload slabs from base storage
+	err = storage.BatchPreload(ids, runtime.NumCPU())
+	require.NoError(t, err)
+	require.Equal(t, len(encodedSlabs), len(storage.cache))
+	require.Equal(t, 0, len(storage.deltas))
+
+	// Compare encoded data
+	for id, data := range encodedSlabs {
+		cachedData, err := Encode(storage.cache[id], encMode)
+		require.NoError(t, err)
+
+		require.Equal(t, cachedData, data)
+	}
+}
+
+func TestStorageBatchPreloadNotFoundSlabs(t *testing.T) {
+
+	encMode, err := cbor.EncOptions{}.EncMode()
+	require.NoError(t, err)
+
+	decMode, err := cbor.DecOptions{}.DecMode()
+	require.NoError(t, err)
+
+	r := newRand(t)
+
+	t.Run("empty storage", func(t *testing.T) {
+		const numberOfSlabs = 10
+
+		ids := make([]StorageID, numberOfSlabs)
+		for i := 0; i < numberOfSlabs; i++ {
+			var index StorageIndex
+			binary.BigEndian.PutUint64(index[:], uint64(i))
+
+			ids[i] = NewStorageID(generateRandomAddress(r), index)
+		}
+
+		baseStorage := NewInMemBaseStorage()
+		storage := NewPersistentSlabStorage(baseStorage, encMode, decMode, decodeStorable, decodeTypeInfo)
+
+		err := storage.BatchPreload(ids, runtime.NumCPU())
+		require.NoError(t, err)
+
+		require.Equal(t, 0, len(storage.cache))
+		require.Equal(t, 0, len(storage.deltas))
+	})
+
+	t.Run("non-empty storage", func(t *testing.T) {
+		const numberOfSlabs = 10
+
+		ids := make([]StorageID, numberOfSlabs)
+		encodedSlabs := make(map[StorageID][]byte)
+
+		for i := 0; i < numberOfSlabs; i++ {
+			var index StorageIndex
+			binary.BigEndian.PutUint64(index[:], uint64(i))
+
+			id := NewStorageID(generateRandomAddress(r), index)
+
+			slab := generateRandomSlab(id, r)
+
+			encodedSlabs[id], err = Encode(slab, encMode)
+			require.NoError(t, err)
+
+			ids[i] = id
+		}
+
+		// Append a slab ID that doesn't exist in storage.
+		ids = append(ids, NewStorageID(generateRandomAddress(r), StorageIndex{numberOfSlabs}))
+
+		baseStorage := NewInMemBaseStorageFromMap(encodedSlabs)
+		storage := NewPersistentSlabStorage(baseStorage, encMode, decMode, decodeStorable, decodeTypeInfo)
+
+		err := storage.BatchPreload(ids, runtime.NumCPU())
+		require.NoError(t, err)
+
+		require.Equal(t, len(encodedSlabs), len(storage.cache))
+		require.Equal(t, 0, len(storage.deltas))
+
+		// Compare encoded data
+		for id, data := range encodedSlabs {
+			cachedData, err := Encode(storage.cache[id], encMode)
+			require.NoError(t, err)
+
+			require.Equal(t, cachedData, data)
+		}
+	})
+}