From 9494fb353ff114a68f7406b454455cc5bf4768af Mon Sep 17 00:00:00 2001 From: Amogh Rathore Date: Wed, 13 Nov 2024 11:19:11 -0800 Subject: [PATCH] Implement reference counting of volume mounts in amazon-ecs-volume-plugin (#4425) --- ecs-init/volumes/ecs_volume_plugin.go | 13 +- ecs-init/volumes/ecs_volume_plugin_test.go | 164 +++++++++++++++------ ecs-init/volumes/state_manager.go | 12 +- ecs-init/volumes/state_manager_test.go | 29 ++++ ecs-init/volumes/types/types.go | 18 ++- ecs-init/volumes/types/types_test.go | 18 ++- 6 files changed, 196 insertions(+), 58 deletions(-) diff --git a/ecs-init/volumes/ecs_volume_plugin.go b/ecs-init/volumes/ecs_volume_plugin.go index 5268a68637c..b047a5ee0f1 100644 --- a/ecs-init/volumes/ecs_volume_plugin.go +++ b/ecs-init/volumes/ecs_volume_plugin.go @@ -70,6 +70,17 @@ func (a *AmazonECSVolumePlugin) LoadState() error { if oldState.Volumes == nil { return nil } + + // Reset volume mount reference count. This is for backwards-compatibility with old + // state file format which did not have reference counting of volume mounts. + for _, vol := range oldState.Volumes { + for mountId, count := range vol.Mounts { + if count == 0 { + vol.Mounts[mountId] = 1 + } + } + } + for volName, vol := range oldState.Volumes { voldriver, err := a.getVolumeDriver(vol.Type) if err != nil { @@ -146,7 +157,7 @@ func (a *AmazonECSVolumePlugin) Create(r *volume.CreateRequest) error { Path: target, Options: r.Options, CreatedAt: time.Now().Format(time.RFC3339Nano), - Mounts: map[string]*string{}, + Mounts: map[string]int{}, } // record the volume information a.volumes[r.Name] = vol diff --git a/ecs-init/volumes/ecs_volume_plugin_test.go b/ecs-init/volumes/ecs_volume_plugin_test.go index 28bb8899b64..ca7b9805b01 100644 --- a/ecs-init/volumes/ecs_volume_plugin_test.go +++ b/ecs-init/volumes/ecs_volume_plugin_test.go @@ -22,6 +22,7 @@ import ( "github.com/docker/go-plugins-helpers/volume" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // TestVolumeDriver implements VolumeDriver interface for testing @@ -611,35 +612,95 @@ func TestCapabilities(t *testing.T) { } func TestPluginLoadState(t *testing.T) { - plugin := &AmazonECSVolumePlugin{ - volumeDrivers: map[string]driver.VolumeDriver{ - "efs": NewECSVolumeDriver(), + tcs := []struct { + name string + stateFileContents string + pluginAssertions func(*testing.T, *AmazonECSVolumePlugin) + }{ + { + name: "backwards compatibility with state format without reference counting of mounts", + stateFileContents: ` + { + "volumes": { + "efsVolume": { + "type":"efs", + "path":"/var/lib/ecs/volumes/efsVolume", + "options": {"device":"fs-123","o":"tls","type":"efs"}, + "mounts": {"id1": null} + } + } + }`, + pluginAssertions: func(t *testing.T, plugin *AmazonECSVolumePlugin) { + assert.Len(t, plugin.volumes, 1) + vol, ok := plugin.volumes["efsVolume"] + assert.True(t, ok) + assert.Equal(t, "efs", vol.Type) + assert.Equal(t, VolumeMountPathPrefix+"efsVolume", vol.Path) + vols := plugin.state.VolState.Volumes + assert.Len(t, vols, 1) + volInfo, ok := vols["efsVolume"] + require.True(t, ok) + assert.Equal(t, "efs", volInfo.Type) + assert.Equal(t, VolumeMountPathPrefix+"efsVolume", volInfo.Path) + + // Test for backwards compatibility of old state format following implementation of + // reference counting of volume mounts null value for mount IDs should be converted to 1. + assert.Equal(t, map[string]int{"id1": 1}, vols["efsVolume"].Mounts) + }, + }, + { + name: "current state format", + stateFileContents: ` + { + "volumes": { + "efsVolume": { + "type":"efs", + "path":"/var/lib/ecs/volumes/efsVolume", + "options": {"device":"fs-123","o":"tls","type":"efs"}, + "mounts": {"id1": 1, "id2": 2} + } + } + }`, + pluginAssertions: func(t *testing.T, plugin *AmazonECSVolumePlugin) { + assert.Len(t, plugin.volumes, 1) + vol, ok := plugin.volumes["efsVolume"] + assert.True(t, ok) + assert.Equal(t, "efs", vol.Type) + assert.Equal(t, VolumeMountPathPrefix+"efsVolume", vol.Path) + vols := plugin.state.VolState.Volumes + assert.Len(t, vols, 1) + volInfo, ok := vols["efsVolume"] + require.True(t, ok) + assert.Equal(t, "efs", volInfo.Type) + assert.Equal(t, VolumeMountPathPrefix+"efsVolume", volInfo.Path) + assert.Equal(t, map[string]int{"id1": 1, "id2": 2}, vols["efsVolume"].Mounts) + }, }, - volumes: make(map[string]*types.Volume), - state: NewStateManager(), - } - fileExists = func(path string) bool { - return true } - readStateFile = func() ([]byte, error) { - return []byte(`{"volumes":{"efsVolume":{"type":"efs","path":"/var/lib/ecs/volumes/efsVolume","options":{"device":"fs-123","o":"tls","type":"efs"}}}}`), nil + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + plugin := &AmazonECSVolumePlugin{ + volumeDrivers: map[string]driver.VolumeDriver{ + "efs": NewECSVolumeDriver(), + }, + volumes: make(map[string]*types.Volume), + state: NewStateManager(), + } + fileExists = func(path string) bool { + return true + } + readStateFile = func() ([]byte, error) { + return []byte(tc.stateFileContents), nil + } + defer func() { + fileExists = checkFile + readStateFile = readFile + }() + assert.NoError(t, plugin.LoadState(), "expected no error when loading state") + tc.pluginAssertions(t, plugin) + }) } - defer func() { - fileExists = checkFile - readStateFile = readFile - }() - assert.NoError(t, plugin.LoadState(), "expected no error when loading state") - assert.Len(t, plugin.volumes, 1) - vol, ok := plugin.volumes["efsVolume"] - assert.True(t, ok) - assert.Equal(t, "efs", vol.Type) - assert.Equal(t, VolumeMountPathPrefix+"efsVolume", vol.Path) - vols := plugin.state.VolState.Volumes - assert.Len(t, vols, 1) - volInfo, ok := vols["efsVolume"] - assert.True(t, ok) - assert.Equal(t, "efs", volInfo.Type) - assert.Equal(t, VolumeMountPathPrefix+"efsVolume", volInfo.Path) } func TestPluginNoStateFile(t *testing.T) { @@ -758,7 +819,7 @@ func TestPluginMount(t *testing.T) { req: &volume.MountRequest{Name: volName, ID: reqMountID}, expectedResponse: &volume.MountResponse{Mountpoint: volPath}, assertPluginState: func(t *testing.T, plugin *AmazonECSVolumePlugin) { - mounts := map[string]*string{reqMountID: nil} + mounts := map[string]int{reqMountID: 1} assert.Equal(t, map[string]*types.Volume{ volName: {Path: volPath, Options: volOpts, Mounts: mounts}, @@ -778,13 +839,13 @@ func TestPluginMount(t *testing.T) { pluginVolumes: map[string]*types.Volume{ volName: { Path: volPath, - Mounts: map[string]*string{"someMount": nil}, + Mounts: map[string]int{"someMount": 1}, }, }, req: &volume.MountRequest{Name: volName, ID: reqMountID}, expectedResponse: &volume.MountResponse{Mountpoint: volPath}, assertPluginState: func(t *testing.T, plugin *AmazonECSVolumePlugin) { - mounts := map[string]*string{reqMountID: nil, "someMount": nil} + mounts := map[string]int{reqMountID: 1, "someMount": 1} assert.Equal(t, map[string]*types.Volume{volName: {Path: volPath, Mounts: mounts}}, plugin.volumes) @@ -826,24 +887,24 @@ func TestPluginMount(t *testing.T) { assert.Equal(t, map[string]*types.Volume{volName: {Path: volPath}}, plugin.volumes) assert.Equal(t, &VolumeState{Volumes: map[string]*VolumeInfo{ - volName: {Path: volPath, Mounts: map[string]*string{}}, + volName: {Path: volPath, Mounts: map[string]int{}}, }}, plugin.state.VolState) }, }, { - name: "duplicate mount is a no-op", + name: "duplicate mount increments mount reference count", pluginVolumes: map[string]*types.Volume{ volName: { Path: volPath, - Mounts: map[string]*string{reqMountID: nil}, + Mounts: map[string]int{reqMountID: 1}, Options: volOpts, }, }, req: &volume.MountRequest{Name: volName, ID: reqMountID}, expectedResponse: &volume.MountResponse{Mountpoint: volPath}, assertPluginState: func(t *testing.T, plugin *AmazonECSVolumePlugin) { - mounts := map[string]*string{reqMountID: nil} + mounts := map[string]int{reqMountID: 2} assert.Equal(t, map[string]*types.Volume{ volName: {Path: volPath, Options: volOpts, Mounts: mounts}, @@ -870,7 +931,7 @@ func TestPluginMount(t *testing.T) { expectedError: "mount failed due to an error while saving state: some error", assertPluginState: func(t *testing.T, plugin *AmazonECSVolumePlugin) { // No mounts expected on the volume - mounts := map[string]*string{} + mounts := map[string]int{} assert.Equal(t, map[string]*types.Volume{volName: {Path: volPath, Mounts: mounts}}, plugin.volumes) @@ -963,11 +1024,11 @@ func TestPluginUnmount(t *testing.T) { d.EXPECT().Remove(&driver.RemoveRequest{Name: volName}).Return(nil) }, pluginVolumes: map[string]*types.Volume{ - volName: {Path: volPath, Mounts: map[string]*string{reqMountID: nil}}, + volName: {Path: volPath, Mounts: map[string]int{reqMountID: 1}}, }, req: &volume.UnmountRequest{Name: volName, ID: reqMountID}, assertPluginState: func(t *testing.T, plugin *AmazonECSVolumePlugin) { - mounts := map[string]*string{} + mounts := map[string]int{} assert.Equal(t, map[string]*types.Volume{volName: {Path: volPath, Mounts: mounts}}, plugin.volumes) @@ -983,12 +1044,33 @@ func TestPluginUnmount(t *testing.T) { pluginVolumes: map[string]*types.Volume{ volName: { Path: volPath, - Mounts: map[string]*string{"someMount": nil, reqMountID: nil}, + Mounts: map[string]int{"someMount": 1, reqMountID: 1}, + }, + }, + req: &volume.UnmountRequest{Name: volName, ID: reqMountID}, + assertPluginState: func(t *testing.T, plugin *AmazonECSVolumePlugin) { + mounts := map[string]int{"someMount": 1} + assert.Equal(t, + map[string]*types.Volume{volName: {Path: volPath, Mounts: mounts}}, + plugin.volumes) + assert.Equal(t, + &VolumeState{ + Volumes: map[string]*VolumeInfo{volName: {Path: volPath, Mounts: mounts}}, + }, + plugin.state.VolState) + }, + }, + { + name: "mount reference count decrements", + pluginVolumes: map[string]*types.Volume{ + volName: { + Path: volPath, + Mounts: map[string]int{reqMountID: 2}, }, }, req: &volume.UnmountRequest{Name: volName, ID: reqMountID}, assertPluginState: func(t *testing.T, plugin *AmazonECSVolumePlugin) { - mounts := map[string]*string{"someMount": nil} + mounts := map[string]int{reqMountID: 1} assert.Equal(t, map[string]*types.Volume{volName: {Path: volPath, Mounts: mounts}}, plugin.volumes) @@ -1023,13 +1105,13 @@ func TestPluginUnmount(t *testing.T) { Return(errors.New("some error")) }, pluginVolumes: map[string]*types.Volume{ - volName: {Path: volPath, Mounts: map[string]*string{reqMountID: nil}}, + volName: {Path: volPath, Mounts: map[string]int{reqMountID: 1}}, }, req: &volume.UnmountRequest{Name: volName, ID: reqMountID}, expectedError: "failed to unmount volume volume: some error", assertPluginState: func(t *testing.T, plugin *AmazonECSVolumePlugin) { // Mount should not exist in the plugin state - mounts := map[string]*string{} + mounts := map[string]int{} assert.Equal(t, map[string]*types.Volume{volName: {Path: volPath, Mounts: mounts}}, plugin.volumes) @@ -1040,7 +1122,7 @@ func TestPluginUnmount(t *testing.T) { pluginVolumes: map[string]*types.Volume{volName: {Path: volPath, Options: volOpts}}, req: &volume.UnmountRequest{Name: volName, ID: reqMountID}, assertPluginState: func(t *testing.T, plugin *AmazonECSVolumePlugin) { - mounts := map[string]*string{} + mounts := map[string]int{} assert.Equal(t, map[string]*types.Volume{ volName: {Path: volPath, Mounts: nil, Options: volOpts}, diff --git a/ecs-init/volumes/state_manager.go b/ecs-init/volumes/state_manager.go index 1b387279c74..95ccce395b9 100644 --- a/ecs-init/volumes/state_manager.go +++ b/ecs-init/volumes/state_manager.go @@ -47,11 +47,11 @@ type VolumeState struct { // VolumeInfo contains the information of managed volumes type VolumeInfo struct { - Type string `json:"type,omitempty"` - Path string `json:"path,omitempty"` - Options map[string]string `json:"options,omitempty"` - CreatedAt string `json:"createdAt,omitempty"` - Mounts map[string]*string `json:"mounts,omitempty"` + Type string `json:"type,omitempty"` + Path string `json:"path,omitempty"` + Options map[string]string `json:"options,omitempty"` + CreatedAt string `json:"createdAt,omitempty"` + Mounts map[string]int `json:"mounts,omitempty"` } // NewStateManager initializes the state manager of volume plugin @@ -65,7 +65,7 @@ func NewStateManager() *StateManager { func (s *StateManager) recordVolume(volName string, vol *types.Volume) error { // Copy the mounts so that the map is not shared - mountsCopy := map[string]*string{} + mountsCopy := map[string]int{} for k, v := range vol.Mounts { mountsCopy[k] = v } diff --git a/ecs-init/volumes/state_manager_test.go b/ecs-init/volumes/state_manager_test.go index 62af4bf955e..63773553b47 100644 --- a/ecs-init/volumes/state_manager_test.go +++ b/ecs-init/volumes/state_manager_test.go @@ -18,6 +18,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSaveStateSuccess(t *testing.T) { @@ -85,6 +86,34 @@ func TestLoadStateSuccess(t *testing.T) { assert.NoError(t, s.load(oldState)) } +// Tests backwards compatibility of state loading. +// A change to state format has been introduced with reference counting of volume mounts. +func TestLoadStateFromOldFormat(t *testing.T) { + s := NewStateManager() + oldState := &VolumeState{} + readStateFile = func() ([]byte, error) { + return []byte(` + { + "volumes": { + "efsVolume": { + "type":"efs", + "mounts": {"id1": null} + } + } + }`), nil + } + defer func() { + readStateFile = readFile + }() + require.NoError(t, s.load(oldState)) + assert.Equal(t, + &VolumeState{Volumes: map[string]*VolumeInfo{"efsVolume": { + Type: "efs", + Mounts: map[string]int{"id1": 0}, + }}}, + oldState) +} + func TestLoadInvalidState(t *testing.T) { s := NewStateManager() oldState := &VolumeState{} diff --git a/ecs-init/volumes/types/types.go b/ecs-init/volumes/types/types.go index e3e81ab2130..2dc6b20acf0 100644 --- a/ecs-init/volumes/types/types.go +++ b/ecs-init/volumes/types/types.go @@ -19,16 +19,16 @@ type Volume struct { Path string Options map[string]string CreatedAt string - Mounts map[string]*string + Mounts map[string]int } // Adds a new mount to the volume. // This method is not thread-safe, caller is responsible for holding any locks on the volume. func (v *Volume) AddMount(mountID string) { if v.Mounts == nil { - v.Mounts = map[string]*string{} + v.Mounts = map[string]int{} } - v.Mounts[mountID] = nil + v.Mounts[mountID] += 1 } // Removes a mount from the volume. @@ -36,6 +36,14 @@ func (v *Volume) AddMount(mountID string) { // Returns a bool indicating whether the mountID was found in mounts or not. func (v *Volume) RemoveMount(mountID string) bool { _, exists := v.Mounts[mountID] - delete(v.Mounts, mountID) - return exists + if !exists { + return false + } + + v.Mounts[mountID] -= 1 + if v.Mounts[mountID] <= 0 { + delete(v.Mounts, mountID) + } + + return true } diff --git a/ecs-init/volumes/types/types_test.go b/ecs-init/volumes/types/types_test.go index 7ffdb260fa3..10655ab759a 100644 --- a/ecs-init/volumes/types/types_test.go +++ b/ecs-init/volumes/types/types_test.go @@ -23,18 +23,19 @@ func TestAddMount(t *testing.T) { t.Run("new map is created when Mounts is nil", func(t *testing.T) { v := &Volume{} v.AddMount("id") - assert.Equal(t, map[string]*string{"id": nil}, v.Mounts) + assert.Equal(t, map[string]int{"id": 1}, v.Mounts) }) t.Run("second mount", func(t *testing.T) { v := &Volume{} v.AddMount("id") v.AddMount("id2") - assert.Equal(t, map[string]*string{"id": nil, "id2": nil}, v.Mounts) + assert.Equal(t, map[string]int{"id": 1, "id2": 1}, v.Mounts) }) - t.Run("mount already exists", func(t *testing.T) { + t.Run("mount reference count is incremented if a mount already exists", func(t *testing.T) { v := &Volume{} v.AddMount("id") - assert.Equal(t, map[string]*string{"id": nil}, v.Mounts) + v.AddMount("id") + assert.Equal(t, map[string]int{"id": 2}, v.Mounts) }) } @@ -44,7 +45,14 @@ func TestRemoveMount(t *testing.T) { assert.False(t, v.RemoveMount("id")) assert.Empty(t, v.Mounts) }) - t.Run("mount should be removed if it exists", func(t *testing.T) { + t.Run("mount reference count is decremented", func(t *testing.T) { + v := &Volume{} + v.AddMount("id") + v.AddMount("id") + assert.True(t, v.RemoveMount("id")) + assert.Equal(t, map[string]int{"id": 1}, v.Mounts) + }) + t.Run("mount should be removed if it exists and mount reference count is 1", func(t *testing.T) { v := &Volume{} v.AddMount("id") assert.True(t, v.RemoveMount("id"))