diff --git a/go.mod b/go.mod
index b278b776..9bf50b31 100644
--- a/go.mod
+++ b/go.mod
@@ -3,7 +3,8 @@ module github.com/tus/tusd/v2
 // Specify the Go version needed for the Heroku deployment
 // See https://github.com/heroku/heroku-buildpack-go#go-module-specifics
 // +heroku goVersion go1.22
-go 1.21.0
+go 1.22.1
+
 toolchain go1.22.7
 
 require (
diff --git a/pkg/handler/composer.go b/pkg/handler/composer.go
index abea0a69..c8f96a11 100644
--- a/pkg/handler/composer.go
+++ b/pkg/handler/composer.go
@@ -14,6 +14,8 @@ type StoreComposer struct {
 	Concater           ConcaterDataStore
 	UsesLengthDeferrer bool
 	LengthDeferrer     LengthDeferrerDataStore
+	ContentServer      ContentServerDataStore
+	UsesContentServer  bool
 }
 
 // NewStoreComposer creates a new and empty store composer.
@@ -85,3 +87,7 @@ func (store *StoreComposer) UseLengthDeferrer(ext LengthDeferrerDataStore) {
 	store.UsesLengthDeferrer = ext != nil
 	store.LengthDeferrer = ext
 }
+func (store *StoreComposer) UseContentServer(ext ContentServerDataStore) {
+	store.UsesContentServer = ext != nil
+	store.ContentServer = ext
+}
diff --git a/pkg/handler/datastore.go b/pkg/handler/datastore.go
index 54c828c6..4a3314a6 100644
--- a/pkg/handler/datastore.go
+++ b/pkg/handler/datastore.go
@@ -3,6 +3,7 @@ package handler
 import (
 	"context"
 	"io"
+	"net/http"
 )
 
 type MetaData map[string]string
@@ -121,6 +122,16 @@ type DataStore interface {
 	GetUpload(ctx context.Context, id string) (upload Upload, err error)
 }
 
+// ServableUpload defines the method for serving content directly
+type ServableUpload interface {
+	ServeContent(ctx context.Context, w http.ResponseWriter, r *http.Request) error
+}
+
+// ContentServerDataStore is the interface for data stores that can serve content directly
+type ContentServerDataStore interface {
+	AsServableUpload(upload Upload) ServableUpload
+}
+
 type TerminatableUpload interface {
 	// Terminate an upload so any further requests to the upload resource will
 	// return the ErrNotFound error.
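
Not part of the diff, for context: a minimal sketch of how another store could satisfy the new ContentServerDataStore and ServableUpload interfaces from datastore.go above. The diskStore and diskUpload types and the path field are hypothetical; a store that can expose its data as an *os.File (or any io.ReadSeeker) can simply delegate to http.ServeContent, which already implements Range and conditional-request handling.

package example

import (
	"context"
	"net/http"
	"os"

	"github.com/tus/tusd/v2/pkg/handler"
)

// diskStore and diskUpload are hypothetical types used only to illustrate the
// new interfaces; they are not part of this change.
type diskStore struct{}

type diskUpload struct {
	handler.Upload
	path string
}

// AsServableUpload satisfies handler.ContentServerDataStore.
func (store diskStore) AsServableUpload(upload handler.Upload) handler.ServableUpload {
	return upload.(*diskUpload)
}

// ServeContent satisfies handler.ServableUpload by delegating to
// http.ServeContent, which handles Range and conditional requests itself.
func (u *diskUpload) ServeContent(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
	f, err := os.Open(u.path)
	if err != nil {
		return err
	}
	defer f.Close()

	fi, err := f.Stat()
	if err != nil {
		return err
	}
	http.ServeContent(w, r, fi.Name(), fi.ModTime(), f)
	return nil
}

A store wired this way is then registered on the composer through the UseContentServer method added in composer.go above.
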
diff --git a/pkg/handler/unrouted_handler.go b/pkg/handler/unrouted_handler.go
index 44cd7c2c..6b7cd5e0 100644
--- a/pkg/handler/unrouted_handler.go
+++ b/pkg/handler/unrouted_handler.go
@@ -1013,6 +1013,17 @@ func (handler *UnroutedHandler) GetFile(w http.ResponseWriter, r *http.Request)
 		return
 	}
 
+	// If the data store implements ContentServerDataStore, use the ServableUpload interface
+	if handler.composer.UsesContentServer {
+		servableUpload := handler.composer.ContentServer.AsServableUpload(upload)
+		err = servableUpload.ServeContent(c, w, r)
+		if err != nil {
+			handler.sendError(c, err)
+		}
+		return
+	}
+
+	// Fall back to the existing GetReader implementation if ContentServerDataStore is not implemented
 	contentType, contentDisposition := filterContentType(info)
 	resp := HTTPResponse{
 		StatusCode: http.StatusOK,
diff --git a/pkg/s3store/s3store.go b/pkg/s3store/s3store.go
index 7b5acaff..4734c365 100644
--- a/pkg/s3store/s3store.go
+++ b/pkg/s3store/s3store.go
@@ -79,6 +79,7 @@ import (
 	"net/http"
 	"os"
 	"regexp"
+	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -376,6 +377,81 @@ func (store S3Store) AsConcatableUpload(upload handler.Upload) handler.ConcatableUpload {
 	return upload.(*s3Upload)
 }
 
+func (store S3Store) AsServableUpload(upload handler.Upload) handler.ServableUpload {
+	return upload.(*s3Upload)
+}
+
+func (su *s3Upload) ServeContent(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
+	// Get file info
+	info, err := su.GetInfo(ctx)
+	if err != nil {
+		return err
+	}
+
+	// Prepare GetObject input
+	input := &s3.GetObjectInput{
+		Bucket: aws.String(su.store.Bucket),
+		Key:    su.store.keyWithPrefix(su.objectId),
+	}
+
+	// Handle range requests
+	rangeHeader := r.Header.Get("Range")
+	if rangeHeader != "" {
+		if err := su.handleRangeRequest(ctx, w, r, info, input, rangeHeader); err != nil {
+			return err
+		}
+		return nil
+	}
+
+	// For non-range requests, serve the entire file
+	result, err := su.store.Service.GetObject(ctx, input)
+	if err != nil {
+		return err
+	}
+	defer result.Body.Close()
+
+	// Set headers
+	w.Header().Set("Content-Length", strconv.FormatInt(info.Size, 10))
+	w.Header().Set("Content-Type", info.MetaData["filetype"])
+	w.Header().Set("ETag", *result.ETag)
+
+	// Stream the content
+	_, err = io.Copy(w, result.Body)
+	return err
+}
+
+func (su *s3Upload) handleRangeRequest(ctx context.Context, w http.ResponseWriter, _ *http.Request, info handler.FileInfo, input *s3.GetObjectInput, rangeHeader string) error {
+	ranges, err := parseRange(rangeHeader, info.Size)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusRequestedRangeNotSatisfiable)
+		return err
+	}
+
+	if len(ranges) > 1 {
+		return fmt.Errorf("multiple ranges are not supported")
+	}
+
+	// Set the range in the GetObject input
+	input.Range = aws.String(fmt.Sprintf("bytes=%d-%d", ranges[0].start, ranges[0].end))
+
+	result, err := su.store.Service.GetObject(ctx, input)
+	if err != nil {
+		return err
+	}
+	defer result.Body.Close()
+
+	// Set headers for partial content
+	w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", ranges[0].start, ranges[0].end, info.Size))
+	w.Header().Set("Content-Length", strconv.FormatInt(ranges[0].end-ranges[0].start+1, 10))
+	w.Header().Set("Content-Type", info.MetaData["filetype"])
+	w.Header().Set("ETag", *result.ETag)
+	w.WriteHeader(http.StatusPartialContent)
+
+	// Stream the content
+	_, err = io.Copy(w, result.Body)
+	return err
+}
+
 func (upload *s3Upload) writeInfo(ctx context.Context, info handler.FileInfo) error {
 	store := upload.store
 
@@ -1249,3 +1325,65 @@ func (store S3Store) releaseUploadSemaphore() {
 	store.uploadSemaphore.Release()
 	store.uploadSemaphoreDemandMetric.Dec()
 }
+
+// Helper function to parse range header
+func parseRange(rangeHeader string, size int64) ([]struct{ start, end int64 }, error) {
+	if rangeHeader == "" {
+		return nil, fmt.Errorf("empty range header")
+	}
+
+	const b = "bytes="
+	if !strings.HasPrefix(rangeHeader, b) {
+		return nil, fmt.Errorf("invalid range header format")
+	}
+
+	var ranges []struct{ start, end int64 }
+	for _, ra := range strings.Split(rangeHeader[len(b):], ",") {
+		ra = strings.TrimSpace(ra)
+		if ra == "" {
+			continue
+		}
+		i := strings.Index(ra, "-")
+		if i < 0 {
+			return nil, fmt.Errorf("invalid range format")
+		}
+		start, end := strings.TrimSpace(ra[:i]), strings.TrimSpace(ra[i+1:])
+		var r struct{ start, end int64 }
+		if start == "" {
+			// suffix-byte-range-spec, like "-100"
+			n, err := strconv.ParseInt(end, 10, 64)
+			if err != nil {
+				return nil, fmt.Errorf("invalid range format")
+			}
+			if n > size {
+				n = size
+			}
+			r.start = size - n
+			r.end = size - 1
+		} else {
+			i, err := strconv.ParseInt(start, 10, 64)
+			if err != nil {
+				return nil, fmt.Errorf("invalid range format")
+			}
+			if i >= size {
+				return nil, fmt.Errorf("range out of bounds")
+			}
+			r.start = i
+			if end == "" {
+				// byte-range-spec, like "100-"
+				r.end = size - 1
+			} else {
+				i, err := strconv.ParseInt(end, 10, 64)
+				if err != nil || i >= size || i < r.start {
+					return nil, fmt.Errorf("invalid range format")
+				}
+				r.end = i
+			}
+		}
+		ranges = append(ranges, r)
+	}
+	if len(ranges) == 0 {
+		return nil, fmt.Errorf("no valid ranges")
+	}
+	return ranges, nil
+}
diff --git a/pkg/s3store/s3store_test.go b/pkg/s3store/s3store_test.go
index e69ba2e3..771a5d5a 100644
--- a/pkg/s3store/s3store_test.go
+++ b/pkg/s3store/s3store_test.go
@@ -3,8 +3,11 @@ package s3store
 import (
 	"bytes"
 	"context"
+	"errors"
 	"fmt"
 	"io"
+	"net/http"
+	"net/http/httptest"
 	"os"
 	"strings"
 	"testing"
@@ -1468,3 +1471,281 @@ func TestWriteChunkCleansUpTempFiles(t *testing.T) {
 	assert.Nil(err)
 	assert.Equal(len(files), 0)
 }
+
+func TestS3StoreAsContentServerDataStore(t *testing.T) {
+	mockCtrl := gomock.NewController(t)
+	defer mockCtrl.Finish()
+	assert := assert.New(t)
+
+	s3obj := NewMockS3API(mockCtrl)
+	store := New("bucket", s3obj)
+
+	upload := &s3Upload{
+		store:       &store,
+		info:        &handler.FileInfo{},
+		objectId:    "uploadId",
+		multipartId: "multipartId",
+	}
+
+	servableUpload := store.AsServableUpload(upload)
+	assert.NotNil(servableUpload)
+	assert.IsType(&s3Upload{}, servableUpload)
+}
+
+func TestS3ServableUploadServeContent(t *testing.T) {
+	mockCtrl := gomock.NewController(t)
+	defer mockCtrl.Finish()
+	assert := assert.New(t)
+
+	s3obj := NewMockS3API(mockCtrl)
+	store := New("bucket", s3obj)
+
+	upload := &s3Upload{
+		store:       &store,
+		info:        &handler.FileInfo{Size: 100, Offset: 100, MetaData: map[string]string{"filetype": "text/plain"}},
+		objectId:    "uploadId",
+		multipartId: "multipartId",
+	}
+
+	s3obj.EXPECT().GetObject(gomock.Any(), &s3.GetObjectInput{
+		Bucket: aws.String("bucket"),
+		Key:    aws.String("uploadId"),
+	}).Return(&s3.GetObjectOutput{
+		Body:          io.NopCloser(strings.NewReader("test content")),
+		ContentLength: aws.Int64(100),
+		ETag:          aws.String("etag123"),
+	}, nil)
+
+	servableUpload := store.AsServableUpload(upload)
+
+	w := httptest.NewRecorder()
+	r := httptest.NewRequest("GET", "/", nil)
+
+	err := servableUpload.ServeContent(context.Background(), w, r)
+	assert.Nil(err)
+
+	assert.Equal(http.StatusOK, w.Code)
assert.Equal("100", w.Header().Get("Content-Length")) + assert.Equal("text/plain", w.Header().Get("Content-Type")) + assert.Equal("etag123", w.Header().Get("ETag")) + assert.Equal("test content", w.Body.String()) +} + +func TestS3ServableUploadServeContentWithRange(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + assert := assert.New(t) + + s3obj := NewMockS3API(mockCtrl) + store := New("bucket", s3obj) + + upload := &s3Upload{ + store: &store, + info: &handler.FileInfo{Size: 100, Offset: 100, MetaData: map[string]string{"filetype": "text/plain"}}, + objectId: "uploadId", + multipartId: "multipartId", + } + + s3obj.EXPECT().GetObject(gomock.Any(), &s3.GetObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("uploadId"), + Range: aws.String("bytes=10-19"), + }).Return(&s3.GetObjectOutput{ + Body: io.NopCloser(strings.NewReader("0123456789")), + ContentLength: aws.Int64(10), + ETag: aws.String("etag123"), + }, nil) + + servableUpload := store.AsServableUpload(upload) + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Range", "bytes=10-19") + + err := servableUpload.ServeContent(context.Background(), w, r) + assert.Nil(err) + + assert.Equal(http.StatusPartialContent, w.Code) + assert.Equal("10", w.Header().Get("Content-Length")) + assert.Equal("text/plain", w.Header().Get("Content-Type")) + assert.Equal("etag123", w.Header().Get("ETag")) + assert.Equal("bytes 10-19/100", w.Header().Get("Content-Range")) + assert.Equal("0123456789", w.Body.String()) +} + +func TestS3ServableUploadServeContentError(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + assert := assert.New(t) + + s3obj := NewMockS3API(mockCtrl) + store := New("bucket", s3obj) + + upload := &s3Upload{ + store: &store, + info: &handler.FileInfo{Size: 100, Offset: 100, MetaData: map[string]string{"filetype": "text/plain"}}, + objectId: "uploadId", + multipartId: "multipartId", + } + + expectedError := errors.New("S3 error") + s3obj.EXPECT().GetObject(gomock.Any(), gomock.Any()).Return(nil, expectedError) + + servableUpload := store.AsServableUpload(upload) + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + + err := servableUpload.ServeContent(context.Background(), w, r) + assert.Equal(expectedError, err) +} + +func TestParseRange(t *testing.T) { + tests := []struct { + name string + rangeHeader string + size int64 + expected []struct{ start, end int64 } + expectedErr string + }{ + { + name: "Empty range header", + rangeHeader: "", + size: 100, + expectedErr: "empty range header", + }, + { + name: "Invalid range header format", + rangeHeader: "invalid=0-10", + size: 100, + expectedErr: "invalid range header format", + }, + { + name: "Single valid range", + rangeHeader: "bytes=0-50", + size: 100, + expected: []struct{ start, end int64 }{{0, 50}}, + }, + { + name: "Multiple valid ranges", + rangeHeader: "bytes=0-50,60-70,80-", + size: 100, + expected: []struct{ start, end int64 }{{0, 50}, {60, 70}, {80, 99}}, + }, + { + name: "Suffix range", + rangeHeader: "bytes=-30", + size: 100, + expected: []struct{ start, end int64 }{{70, 99}}, + }, + { + name: "Suffix range larger than file", + rangeHeader: "bytes=-150", + size: 100, + expected: []struct{ start, end int64 }{{0, 99}}, + }, + { + name: "Invalid range format", + rangeHeader: "bytes=invalid-50", + size: 100, + expectedErr: "invalid range format", + }, + { + name: "Range out of bounds", + rangeHeader: "bytes=150-200", + size: 100, + expectedErr: 
"range out of bounds", + }, + { + name: "End smaller than start", + rangeHeader: "bytes=50-40", + size: 100, + expectedErr: "invalid range format", + }, + { + name: "No valid ranges", + rangeHeader: "bytes=", + size: 100, + expectedErr: "no valid ranges", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ranges, err := parseRange(tt.rangeHeader, tt.size) + + if tt.expectedErr != "" { + assert.EqualError(t, err, tt.expectedErr) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, ranges) + } + }) + } +} + +func TestParseRangeEdgeCases(t *testing.T) { + tests := []struct { + name string + rangeHeader string + size int64 + expected []struct{ start, end int64 } + expectedErr string + }{ + { + name: "Zero size file", + rangeHeader: "bytes=0-10", + size: 0, + expectedErr: "range out of bounds", + }, + { + name: "Single byte file", + rangeHeader: "bytes=0-0", + size: 1, + expected: []struct{ start, end int64 }{{0, 0}}, + }, + { + name: "Very large file", + rangeHeader: "bytes=9223372036854775806-", + size: 9223372036854775807, // max int64 + expected: []struct{ start, end int64 }{{9223372036854775806, 9223372036854775806}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ranges, err := parseRange(tt.rangeHeader, tt.size) + + if tt.expectedErr != "" { + assert.EqualError(t, err, tt.expectedErr) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, ranges) + } + }) + } +} + +func TestParseRangeWhitespace(t *testing.T) { + tests := []struct { + name string + rangeHeader string + size int64 + expected []struct{ start, end int64 } + }{ + { + name: "Whitespace in range", + rangeHeader: "bytes= 0-50 , 60-70 , 80- ", + size: 100, + expected: []struct{ start, end int64 }{{0, 50}, {60, 70}, {80, 99}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ranges, err := parseRange(tt.rangeHeader, tt.size) + assert.NoError(t, err) + assert.Equal(t, tt.expected, ranges) + }) + } +}