handler, s3store: Fix data race problems #1199

Merged · 6 commits · Oct 4, 2024
4 changes: 2 additions & 2 deletions .github/workflows/continuous-integration.yaml
@@ -49,8 +49,8 @@ jobs:
       -
         name: Test code
         run: |
-          go test ./pkg/...
-          go test ./internal/...
+          go test -race ./pkg/...
+          go test -race ./internal/...
         shell: bash
 
       -
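
The -race flag enables Go's race detector, which instruments memory accesses at run time and reports unsynchronized concurrent reads and writes, which is exactly the class of bug this PR fixes. Below is a minimal, hypothetical test (not taken from this PR) of the kind that go test -race fails:

// racedemo_test.go: a hypothetical test that 'go test -race' reports.
// The counter is written by one goroutine and read by another without
// any synchronization.
package racedemo

import "testing"

func TestRace(t *testing.T) {
	counter := 0
	done := make(chan struct{})
	go func() {
		counter++ // unsynchronized write
		close(done)
	}()
	t.Log(counter) // unsynchronized read, concurrent with the write above
	<-done
}
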
2 changes: 1 addition & 1 deletion go.mod
@@ -30,6 +30,7 @@ require (
 	github.com/vimeo/go-util v1.4.1
 	golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df
 	golang.org/x/net v0.29.0
+	golang.org/x/sync v0.8.0
 	google.golang.org/api v0.199.0
 	google.golang.org/grpc v1.67.0
 	google.golang.org/protobuf v1.34.2
@@ -95,7 +96,6 @@ require (
 	go.opentelemetry.io/otel/trace v1.29.0 // indirect
 	golang.org/x/crypto v0.27.0 // indirect
 	golang.org/x/oauth2 v0.23.0 // indirect
-	golang.org/x/sync v0.8.0 // indirect
 	golang.org/x/sys v0.25.0 // indirect
 	golang.org/x/text v0.18.0 // indirect
 	golang.org/x/time v0.6.0 // indirect
25 changes: 20 additions & 5 deletions pkg/handler/body_reader.go
@@ -7,6 +7,7 @@ import (
 	"net/http"
 	"os"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"time"
 )
@@ -28,8 +29,11 @@ type bodyReader struct {
 	bytesCounter int64
 	ctx          *httpContext
 	reader       io.ReadCloser
-	err          error
 	onReadDone   func()
+
+	// lock protects concurrent access to err.
+	lock sync.RWMutex
+	err  error
 }
 
 func newBodyReader(c *httpContext, maxSize int64) *bodyReader {
@@ -41,7 +45,10 @@ func newBodyReader(c *httpContext, maxSize int64) *bodyReader {
 }
 
 func (r *bodyReader) Read(b []byte) (int, error) {
-	if r.err != nil {
+	r.lock.RLock()
+	hasErrored := r.err != nil
+	r.lock.RUnlock()
+	if hasErrored {
 		return 0, io.EOF
 	}
@@ -99,28 +106,36 @@ func (r *bodyReader) Read(b []byte) (int, error)
 
 		// Other errors are stored for retrieval with hasError, but are not returned
 		// to the consumer. We do not overwrite an error if it has been set already.
+		r.lock.Lock()
 		if r.err == nil {
 			r.err = err
 		}
+		r.lock.Unlock()
 	}
 
 	return n, nil
 }
 
-func (r bodyReader) hasError() error {
-	if r.err == io.EOF {
+func (r *bodyReader) hasError() error {
+	r.lock.RLock()
+	err := r.err
+	r.lock.RUnlock()
+
+	if err == io.EOF {
 		return nil
 	}
 
-	return r.err
+	return err
 }
 
 func (r *bodyReader) bytesRead() int64 {
 	return atomic.LoadInt64(&r.bytesCounter)
 }
 
 func (r *bodyReader) closeWithError(err error) {
+	r.lock.Lock()
 	r.err = err
+	r.lock.Unlock()
 
 	// SetReadDeadline with the current time causes concurrent reads to the body to time out,
 	// so the body will be closed sooner with less delay.
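
The changes above follow the standard read-write-mutex pattern for a shared error field: writers take the exclusive lock, readers take the shared lock. Here is a self-contained sketch of that pattern; the guardedErr type and its methods are illustrative, not part of tusd's API:

package main

import (
	"errors"
	"fmt"
	"sync"
)

// guardedErr mirrors the pattern used in bodyReader: an RWMutex protecting
// an error field that several goroutines may set and inspect.
type guardedErr struct {
	mu  sync.RWMutex
	err error
}

// set stores err, keeping the first error if one was already recorded,
// just like the Read path above.
func (g *guardedErr) set(err error) {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.err == nil {
		g.err = err
	}
}

// get returns the recorded error under a read lock, like hasError above.
func (g *guardedErr) get() error {
	g.mu.RLock()
	defer g.mu.RUnlock()
	return g.err
}

func main() {
	var g guardedErr
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			g.set(errors.New("boom"))
		}()
	}
	wg.Wait()
	fmt.Println(g.get()) // prints "boom"; passes cleanly under -race
}
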
1 change: 1 addition & 0 deletions pkg/s3store/multi_error.go
@@ -4,6 +4,7 @@ import (
 	"errors"
 )
 
+// TODO: Replace with errors.Join
 func newMultiError(errs []error) error {
 	message := "Multiple errors occurred:\n"
 	for _, err := range errs {
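
The TODO refers to errors.Join from the standard library (available since Go 1.20), which wraps several errors into one and joins their messages with newlines. A sketch of how newMultiError could eventually be replaced:

package main

import (
	"errors"
	"fmt"
)

var errPart1 = errors.New("upload part 1 failed")

func main() {
	// errors.Join ignores nil values and returns nil if all inputs are nil,
	// so callers no longer need to check len(errs) themselves.
	err := errors.Join(errPart1, nil, errors.New("upload part 2 failed"))
	fmt.Println(err)
	// Output:
	// upload part 1 failed
	// upload part 2 failed

	// Unlike newMultiError's flattened string, the joined error keeps the
	// originals inspectable.
	fmt.Println(errors.Is(err, errPart1)) // true
}
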
90 changes: 45 additions & 45 deletions pkg/s3store/s3store.go
@@ -88,6 +88,7 @@ import (
 	"github.com/tus/tusd/v2/internal/uid"
 	"github.com/tus/tusd/v2/pkg/handler"
 	"golang.org/x/exp/slices"
+	"golang.org/x/sync/errgroup"
 
 	"github.com/aws/aws-sdk-go-v2/aws"
 	"github.com/aws/aws-sdk-go-v2/service/s3"
@@ -469,8 +470,7 @@ func (upload *s3Upload) uploadParts(ctx context.Context, offset int64, src io.Re
 	}()
 	go partProducer.produce(producerCtx, optimalPartSize)
 
-	var wg sync.WaitGroup
-	var uploadErr error
+	var eg errgroup.Group
 
 	for {
 		// We acquire the semaphore before starting the goroutine to avoid
@@ -497,10 +497,8 @@ func (upload *s3Upload) uploadParts(ctx context.Context, offset int64, src io.Re
 			}
 			upload.parts = append(upload.parts, part)
 
-			wg.Add(1)
-			go func(file io.ReadSeeker, part *s3Part, closePart func() error) {
+			eg.Go(func() error {
 				defer upload.store.releaseUploadSemaphore()
-				defer wg.Done()
 
 				t := time.Now()
 				uploadPartInput := &s3.UploadPartInput{
@@ -509,39 +507,46 @@ func (upload *s3Upload) uploadParts(ctx context.Context, offset int64, src io.Re
 					UploadId:   aws.String(upload.multipartId),
 					PartNumber: aws.Int32(part.number),
 				}
-				etag, err := upload.putPartForUpload(ctx, uploadPartInput, file, part.size)
+				etag, err := upload.putPartForUpload(ctx, uploadPartInput, partfile, part.size)
 				store.observeRequestDuration(t, metricUploadPart)
-				if err != nil {
-					uploadErr = err
-				} else {
+				if err == nil {
 					part.etag = etag
 				}
-				if cerr := closePart(); cerr != nil && uploadErr == nil {
-					uploadErr = cerr
+
+				cerr := closePart()
+				if err != nil {
+					return err
 				}
-			}(partfile, part, closePart)
+				if cerr != nil {
+					return cerr
+				}
+				return nil
+			})
 		} else {
-			wg.Add(1)
-			go func(file io.ReadSeeker, closePart func() error) {
+			eg.Go(func() error {
 				defer upload.store.releaseUploadSemaphore()
-				defer wg.Done()
 
-				if err := store.putIncompletePartForUpload(ctx, upload.objectId, file); err != nil {
-					uploadErr = err
+				err := store.putIncompletePartForUpload(ctx, upload.objectId, partfile)
+				if err == nil {
+					upload.incompletePartSize = partsize
 				}
-				if cerr := closePart(); cerr != nil && uploadErr == nil {
-					uploadErr = cerr
+
+				cerr := closePart()
+				if err != nil {
+					return err
 				}
-				upload.incompletePartSize = partsize
-			}(partfile, closePart)
+				if cerr != nil {
+					return cerr
+				}
+				return nil
+			})
 		}
 
 		bytesUploaded += partsize
 		nextPartNum += 1
 	}
 
-	wg.Wait()
-
+	uploadErr := eg.Wait()
 	if uploadErr != nil {
 		return 0, uploadErr
 	}
@@ -969,47 +974,42 @@ func (upload *s3Upload) concatUsingDownload(ctx context.Context, partialUploads
 func (upload *s3Upload) concatUsingMultipart(ctx context.Context, partialUploads []handler.Upload) error {
 	store := upload.store
 
-	numPartialUploads := len(partialUploads)
-	errs := make([]error, 0, numPartialUploads)
+	upload.parts = make([]*s3Part, len(partialUploads))
 
 	// Copy partial uploads concurrently
-	var wg sync.WaitGroup
-	wg.Add(numPartialUploads)
+	var eg errgroup.Group
 	for i, partialUpload := range partialUploads {
-
 		// Part numbers must be in the range of 1 to 10000, inclusive. Since
 		// slice indexes start at 0, we add 1 to ensure that i >= 1.
 		partNumber := int32(i + 1)
 		partialS3Upload := partialUpload.(*s3Upload)
 
-		upload.parts = append(upload.parts, &s3Part{
-			number: partNumber,
-			size:   -1,
-			etag:   "",
-		})
-
-		go func(partNumber int32, sourceObject string) {
-			defer wg.Done()
-
+		eg.Go(func() error {
 			res, err := store.Service.UploadPartCopy(ctx, &s3.UploadPartCopyInput{
 				Bucket:     aws.String(store.Bucket),
 				Key:        store.keyWithPrefix(upload.objectId),
 				UploadId:   aws.String(upload.multipartId),
 				PartNumber: aws.Int32(partNumber),
-				CopySource: aws.String(store.Bucket + "/" + *store.keyWithPrefix(sourceObject)),
+				CopySource: aws.String(store.Bucket + "/" + *store.keyWithPrefix(partialS3Upload.objectId)),
 			})
 			if err != nil {
-				errs = append(errs, err)
-				return
+				return err
 			}
 
-			upload.parts[partNumber-1].etag = *res.CopyPartResult.ETag
-		}(partNumber, partialS3Upload.objectId)
-	}
+			upload.parts[partNumber-1] = &s3Part{
+				number: partNumber,
+				size:   -1, // -1 is fine here because FinishUpload does not need this info.
+				etag:   *res.CopyPartResult.ETag,
+			}
 
-	wg.Wait()
+			return nil
+		})
+	}
 
-	if len(errs) > 0 {
-		return newMultiError(errs)
+	err := eg.Wait()
+	if err != nil {
+		return err
 	}
 
 	return upload.FinishUpload(ctx)
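
The refactor in this file is the standard errgroup pattern: eg.Go runs each closure in its own goroutine, and eg.Wait blocks until all of them return, yielding the first non-nil error. That removes the unsynchronized writes to the shared uploadErr and errs variables that the race detector flagged. A self-contained sketch of the pattern follows; the work done inside the goroutines is hypothetical:

package main

import (
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	var eg errgroup.Group
	parts := make([]string, 3)

	for i := 0; i < 3; i++ {
		i := i // pin the loop variable (required before Go 1.22)
		eg.Go(func() error {
			// Each goroutine writes only to its own slot, the same trick
			// concatUsingMultipart uses with upload.parts[partNumber-1],
			// so no extra locking is needed.
			parts[i] = fmt.Sprintf("part %d copied", i+1)
			return nil
		})
	}

	// Wait returns the first non-nil error from any goroutine, replacing
	// the old WaitGroup plus a shared error variable.
	if err := eg.Wait(); err != nil {
		fmt.Println("upload failed:", err)
		return
	}
	fmt.Println(parts)
}
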