diff --git a/go.mod b/go.mod index a4dde2cfa84a..8c46a45ea222 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.22 require ( github.com/Pallinder/go-randomdata v1.2.0 github.com/PuerkitoBio/goquery v1.9.1 + github.com/avast/retry-go v3.0.0+incompatible github.com/aws/aws-sdk-go v1.51.16 github.com/aws/karpenter-provider-aws/tools/kompat v0.0.0-20240410220356-6b868db24881 github.com/awslabs/amazon-eks-ami/nodeadm v0.0.0-20240229193347-cfab22a10647 @@ -37,7 +38,6 @@ require ( contrib.go.opencensus.io/exporter/prometheus v0.4.2 // indirect github.com/Masterminds/semver/v3 v3.2.1 // indirect github.com/andybalholm/cascadia v1.3.2 // indirect - github.com/avast/retry-go v3.0.0+incompatible // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/blendle/zapdriver v1.3.1 // indirect diff --git a/pkg/controllers/controllers.go b/pkg/controllers/controllers.go index e428418123f1..ac24191fd0be 100644 --- a/pkg/controllers/controllers.go +++ b/pkg/controllers/controllers.go @@ -58,7 +58,9 @@ func NewControllers(ctx context.Context, sess *session.Session, clk clock.Clock, controllerspricing.NewController(pricingProvider), } if options.FromContext(ctx).InterruptionQueue != "" { - controllers = append(controllers, interruption.NewController(kubeClient, clk, recorder, lo.Must(sqs.NewProvider(ctx, servicesqs.New(sess), options.FromContext(ctx).InterruptionQueue)), unavailableOfferings)) + sqsapi := servicesqs.New(sess) + out := lo.Must(sqsapi.GetQueueUrlWithContext(ctx, &servicesqs.GetQueueUrlInput{QueueName: lo.ToPtr(options.FromContext(ctx).InterruptionQueue)})) + controllers = append(controllers, interruption.NewController(kubeClient, clk, recorder, lo.Must(sqs.NewDefaultProvider(sqsapi, lo.FromPtr(out.QueueUrl))), unavailableOfferings)) } return controllers } diff --git a/pkg/controllers/interruption/interruption_benchmark_test.go b/pkg/controllers/interruption/interruption_benchmark_test.go index 585fd22596c7..6a9a591db8de 100644 --- a/pkg/controllers/interruption/interruption_benchmark_test.go +++ b/pkg/controllers/interruption/interruption_benchmark_test.go @@ -24,6 +24,7 @@ import ( "testing" "time" + "github.com/avast/retry-go" "github.com/aws/aws-sdk-go/aws" awsclient "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/endpoints" @@ -95,14 +96,15 @@ func benchmarkNotificationController(b *testing.B, messageCount int) { } }() - providers := newProviders(env.Context, env.Client) - if err := providers.makeInfrastructure(ctx); err != nil { + providers := newProviders(ctx, env.Client) + queueURL, err := providers.makeInfrastructure(ctx) + if err != nil { b.Fatalf("standing up infrastructure, %v", err) } // Cleanup the infrastructure after the coretest completes defer func() { if err := retry.Do(func() error { - return providers.cleanupInfrastructure(ctx) + return providers.cleanupInfrastructure(queueURL) }); err != nil { b.Fatalf("deleting infrastructure, %v", err) } @@ -174,31 +176,29 @@ func newProviders(ctx context.Context, kubeClient client.Client) providerSet { ), )) sqsAPI := servicesqs.New(sess) + out := lo.Must(sqsAPI.GetQueueUrlWithContext(ctx, &servicesqs.GetQueueUrlInput{QueueName: lo.ToPtr(options.FromContext(ctx).InterruptionQueue)})) return providerSet{ kubeClient: kubeClient, sqsAPI: sqsAPI, - sqsProvider: sqs.NewProvider(ctx, sqsAPI, "test-cluster"), + sqsProvider: lo.Must(sqs.NewDefaultProvider(sqsAPI, lo.FromPtr(out.QueueUrl))), } } -func (p *providerSet) makeInfrastructure(ctx context.Context) error { - if _, err := p.sqsAPI.CreateQueueWithContext(ctx, &servicesqs.CreateQueueInput{ - QueueName: lo.ToPtr(options.FromContext(ctx).InterruptionQueueName), +func (p *providerSet) makeInfrastructure(ctx context.Context) (string, error) { + out, err := p.sqsAPI.CreateQueueWithContext(ctx, &servicesqs.CreateQueueInput{ + QueueName: lo.ToPtr(options.FromContext(ctx).InterruptionQueue), Attributes: map[string]*string{ servicesqs.QueueAttributeNameMessageRetentionPeriod: aws.String("1200"), // 20 minutes for this test }, - }); err != nil { - return fmt.Errorf("creating servicesqs queue, %w", err) + }) + if err != nil { + return "", fmt.Errorf("creating servicesqs queue, %w", err) } - return nil + return lo.FromPtr(out.QueueUrl), nil } -func (p *providerSet) cleanupInfrastructure(ctx context.Context) error { - queueURL, err := p.sqsProvider.DiscoverQueueURL(ctx) - if err != nil { - return fmt.Errorf("discovering queue url for deletion, %w", err) - } - if _, err = p.sqsAPI.DeleteQueueWithContext(ctx, &servicesqs.DeleteQueueInput{ +func (p *providerSet) cleanupInfrastructure(queueURL string) error { + if _, err := p.sqsAPI.DeleteQueueWithContext(ctx, &servicesqs.DeleteQueueInput{ QueueUrl: lo.ToPtr(queueURL), }); err != nil { return fmt.Errorf("deleting servicesqs queue, %w", err) @@ -220,11 +220,11 @@ func (p *providerSet) monitorMessagesProcessed(ctx context.Context, eventRecorde totalProcessed := 0 go func() { for totalProcessed < expectedProcessed { - totalProcessed = eventRecorder.Calls(events.InstanceStopping(coretest.Node()).Reason) + - eventRecorder.Calls(events.InstanceTerminating(coretest.Node()).Reason) + - eventRecorder.Calls(events.InstanceUnhealthy(coretest.Node()).Reason) + - eventRecorder.Calls(events.InstanceRebalanceRecommendation(coretest.Node()).Reason) + - eventRecorder.Calls(events.InstanceSpotInterrupted(coretest.Node()).Reason) + totalProcessed = eventRecorder.Calls(events.Stopping(coretest.Node(), coretest.NodeClaim())[0].Reason) + + eventRecorder.Calls(events.Stopping(coretest.Node(), coretest.NodeClaim())[0].Reason) + + eventRecorder.Calls(events.Unhealthy(coretest.Node(), coretest.NodeClaim())[0].Reason) + + eventRecorder.Calls(events.RebalanceRecommendation(coretest.Node(), coretest.NodeClaim())[0].Reason) + + eventRecorder.Calls(events.SpotInterrupted(coretest.Node(), coretest.NodeClaim())[0].Reason) logging.FromContext(ctx).With("processed-message-count", totalProcessed).Infof("processed messages from the queue") time.Sleep(time.Second) } diff --git a/pkg/controllers/interruption/suite_test.go b/pkg/controllers/interruption/suite_test.go index d0248494f0d5..52c2b0972add 100644 --- a/pkg/controllers/interruption/suite_test.go +++ b/pkg/controllers/interruption/suite_test.go @@ -82,7 +82,7 @@ var _ = BeforeSuite(func() { fakeClock = &clock.FakeClock{} unavailableOfferingsCache = awscache.NewUnavailableOfferings() sqsapi = &fake.SQSAPI{} - sqsProvider = lo.Must(sqs.NewProvider(ctx, sqsapi, "test-cluster")) + sqsProvider = lo.Must(sqs.NewDefaultProvider(sqsapi, fmt.Sprintf("https://sqs.%s.amazonaws.com/%s/test-cluster", fake.DefaultRegion, fake.DefaultAccount))) controller = interruption.NewController(env.Client, fakeClock, events.NewRecorder(&record.FakeRecorder{}), sqsProvider, unavailableOfferingsCache) }) diff --git a/pkg/providers/sqs/sqs.go b/pkg/providers/sqs/sqs.go index 731143333811..e1687c94c6a7 100644 --- a/pkg/providers/sqs/sqs.go +++ b/pkg/providers/sqs/sqs.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/sqs" @@ -34,26 +35,19 @@ type Provider interface { type DefaultProvider struct { client sqsiface.SQSAPI - name string - url string + queueURL string } -func NewProvider(ctx context.Context, client sqsiface.SQSAPI, queueName string) (*DefaultProvider, error) { - ret, err := client.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{ - QueueName: aws.String(queueName), - }) - if err != nil { - return nil, fmt.Errorf("fetching queue url, %w", err) - } +func NewDefaultProvider(client sqsiface.SQSAPI, queueURL string) (*DefaultProvider, error) { return &DefaultProvider{ - client: client, - name: queueName, - url: aws.StringValue(ret.QueueUrl), + client: client, + queueURL: queueURL, }, nil } func (p *DefaultProvider) Name() string { - return p.name + ss := strings.Split(p.queueURL, "/") + return ss[len(ss)-1] } func (p *DefaultProvider) GetSQSMessages(ctx context.Context) ([]*sqs.Message, error) { @@ -67,7 +61,7 @@ func (p *DefaultProvider) GetSQSMessages(ctx context.Context) ([]*sqs.Message, e MessageAttributeNames: []*string{ aws.String(sqs.QueueAttributeNameAll), }, - QueueUrl: aws.String(p.url), + QueueUrl: aws.String(p.queueURL), } result, err := p.client.ReceiveMessageWithContext(ctx, input) @@ -85,7 +79,7 @@ func (p *DefaultProvider) SendMessage(ctx context.Context, body interface{}) (st } input := &sqs.SendMessageInput{ MessageBody: aws.String(string(raw)), - QueueUrl: aws.String(p.url), + QueueUrl: aws.String(p.queueURL), } result, err := p.client.SendMessageWithContext(ctx, input) if err != nil { @@ -96,7 +90,7 @@ func (p *DefaultProvider) SendMessage(ctx context.Context, body interface{}) (st func (p *DefaultProvider) DeleteSQSMessage(ctx context.Context, msg *sqs.Message) error { input := &sqs.DeleteMessageInput{ - QueueUrl: aws.String(p.url), + QueueUrl: aws.String(p.queueURL), ReceiptHandle: msg.ReceiptHandle, } diff --git a/test/pkg/environment/aws/environment.go b/test/pkg/environment/aws/environment.go index eff7dd2e9306..e874cf325ce1 100644 --- a/test/pkg/environment/aws/environment.go +++ b/test/pkg/environment/aws/environment.go @@ -119,7 +119,9 @@ func NewEnvironment(t *testing.T) *Environment { } // Initialize the provider only if the INTERRUPTION_QUEUE environment variable is defined if v, ok := os.LookupEnv("INTERRUPTION_QUEUE"); ok { - awsEnv.SQSProvider = lo.Must(sqs.NewProvider(env.Context, servicesqs.New(session), v)) + sqsapi := servicesqs.New(session) + out := lo.Must(sqsapi.GetQueueUrlWithContext(env.Context, &servicesqs.GetQueueUrlInput{QueueName: aws.String(v)})) + awsEnv.SQSProvider = lo.Must(sqs.NewDefaultProvider(sqsapi, lo.FromPtr(out.QueueUrl))) } return awsEnv }