diff --git a/hcloud/action.go b/hcloud/action.go index 7100cf77..9b5d65e1 100644 --- a/hcloud/action.go +++ b/hcloud/action.go @@ -120,173 +120,6 @@ func (c *ActionClient) AllWithOpts(ctx context.Context, opts ActionListOpts) ([] return c.action.All(ctx, opts) } -// WatchOverallProgress watches several actions' progress until they complete -// with success or error. This watching happens in a goroutine and updates are -// provided through the two returned channels: -// -// - The first channel receives percentage updates of the progress, based on -// the number of completed versus total watched actions. The return value -// is an int between 0 and 100. -// - The second channel returned receives errors for actions that did not -// complete successfully, as well as any errors that happened while -// querying the API. -// -// By default, the method keeps watching until all actions have finished -// processing. If you want to be able to cancel the method or configure a -// timeout, use the [context.Context]. Once the method has stopped watching, -// both returned channels are closed. -// -// WatchOverallProgress uses the [WithPollBackoffFunc] of the [Client] to wait -// until sending the next request. -func (c *ActionClient) WatchOverallProgress(ctx context.Context, actions []*Action) (<-chan int, <-chan error) { - errCh := make(chan error, len(actions)) - progressCh := make(chan int) - - go func() { - defer close(errCh) - defer close(progressCh) - - completedIDs := make([]int, 0, len(actions)) - watchIDs := make(map[int]struct{}, len(actions)) - - for _, action := range actions { - watchIDs[action.ID] = struct{}{} - } - - retries := 0 - previousProgress := 0 - - for { - select { - case <-ctx.Done(): - errCh <- ctx.Err() - return - case <-time.After(c.action.client.pollBackoffFunc(retries)): - retries++ - } - - opts := ActionListOpts{} - for watchID := range watchIDs { - opts.ID = append(opts.ID, watchID) - } - - as, err := c.AllWithOpts(ctx, opts) - if err != nil { - errCh <- err - return - } - if len(as) == 0 { - // No actions returned for the provided IDs, they do not exist in the API. - // We need to catch and fail early for this, otherwise the loop will continue - // indefinitely. - errCh <- fmt.Errorf("failed to wait for actions: remaining actions (%v) are not returned from API", opts.ID) - return - } - - progress := 0 - for _, a := range as { - switch a.Status { - case ActionStatusRunning: - progress += a.Progress - case ActionStatusSuccess: - delete(watchIDs, a.ID) - completedIDs = append(completedIDs, a.ID) - case ActionStatusError: - delete(watchIDs, a.ID) - completedIDs = append(completedIDs, a.ID) - errCh <- fmt.Errorf("action %d failed: %w", a.ID, a.Error()) - } - } - - progress += len(completedIDs) * 100 - if progress != 0 && progress != previousProgress { - sendProgress(progressCh, progress/len(actions)) - previousProgress = progress - } - - if len(watchIDs) == 0 { - return - } - } - }() - - return progressCh, errCh -} - -// WatchProgress watches one action's progress until it completes with success -// or error. This watching happens in a goroutine and updates are provided -// through the two returned channels: -// -// - The first channel receives percentage updates of the progress, based on -// the progress percentage indicated by the API. The return value is an int -// between 0 and 100. -// - The second channel receives any errors that happened while querying the -// API, as well as the error of the action if it did not complete -// successfully, or nil if it did. -// -// By default, the method keeps watching until the action has finished -// processing. If you want to be able to cancel the method or configure a -// timeout, use the [context.Context]. Once the method has stopped watching, -// both returned channels are closed. -// -// WatchProgress uses the [WithPollBackoffFunc] of the [Client] to wait until -// sending the next request. -func (c *ActionClient) WatchProgress(ctx context.Context, action *Action) (<-chan int, <-chan error) { - errCh := make(chan error, 1) - progressCh := make(chan int) - - go func() { - defer close(errCh) - defer close(progressCh) - - retries := 0 - - for { - select { - case <-ctx.Done(): - errCh <- ctx.Err() - return - case <-time.After(c.action.client.pollBackoffFunc(retries)): - retries++ - } - - a, _, err := c.GetByID(ctx, action.ID) - if err != nil { - errCh <- err - return - } - if a == nil { - errCh <- fmt.Errorf("failed to wait for action %d: action not returned from API", action.ID) - return - } - - switch a.Status { - case ActionStatusRunning: - sendProgress(progressCh, a.Progress) - case ActionStatusSuccess: - sendProgress(progressCh, 100) - errCh <- nil - return - case ActionStatusError: - errCh <- a.Error() - return - } - } - }() - - return progressCh, errCh -} - -// sendProgress allows the user to only read from the error channel and ignore any progress updates. -func sendProgress(progressCh chan int, p int) { - select { - case progressCh <- p: - break - default: - break - } -} - // ResourceActionClient is a client for the actions API exposed by the resource. type ResourceActionClient struct { resource string diff --git a/hcloud/action_test.go b/hcloud/action_test.go index 275ea5b4..e032383b 100644 --- a/hcloud/action_test.go +++ b/hcloud/action_test.go @@ -3,10 +3,7 @@ package hcloud import ( "context" "encoding/json" - "errors" "net/http" - "reflect" - "strings" "testing" "time" @@ -320,355 +317,3 @@ func TestResourceActionClientAll(t *testing.T) { t.Errorf("unexpected actions") } } - -func TestActionClientWatchOverallProgress(t *testing.T) { - t.Parallel() - env := newTestEnv() - defer env.Teardown() - - callCount := 0 - - env.Mux.HandleFunc("/actions", func(w http.ResponseWriter, r *http.Request) { - callCount++ - var actions []schema.Action - - switch callCount { - case 1: - actions = []schema.Action{ - { - ID: 1, - Status: "running", - Progress: 50, - }, - { - ID: 2, - Status: "running", - Progress: 50, - }, - } - case 2: - actions = []schema.Action{ - { - ID: 1, - Status: "running", - Progress: 75, - }, - { - ID: 2, - Status: "error", - Progress: 100, - Error: &schema.ActionError{ - Code: "action_failed", - Message: "action failed", - }, - }, - } - case 3: - actions = []schema.Action{ - { - ID: 1, - Status: "success", - Progress: 100, - }, - } - default: - t.Errorf("unexpected number of calls to the test server: %v", callCount) - } - - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(struct { - Actions []schema.Action `json:"actions"` - Meta schema.Meta `json:"meta"` - }{ - Actions: actions, - Meta: schema.Meta{ - Pagination: &schema.MetaPagination{ - Page: 1, - LastPage: 1, - PerPage: len(actions), - TotalEntries: len(actions), - }, - }, - }) - }) - - actions := []*Action{ - { - ID: 1, - Status: ActionStatusRunning, - }, - { - ID: 2, - Status: ActionStatusRunning, - }, - } - - ctx := context.Background() - progressCh, errCh := env.Client.Action.WatchOverallProgress(ctx, actions) - progressUpdates := []int{} - errs := []error{} - - moreProgress, moreErrors := true, true - - for moreProgress || moreErrors { - var progress int - var err error - - select { - case progress, moreProgress = <-progressCh: - if moreProgress { - progressUpdates = append(progressUpdates, progress) - } - case err, moreErrors = <-errCh: - if moreErrors { - errs = append(errs, err) - } - } - } - - if len(errs) != 1 { - t.Fatalf("expected to receive one error: %v", errs) - } - - err := errs[0] - - if e, ok := errors.Unwrap(err).(ActionError); !ok || e.Code != "action_failed" { - t.Fatalf("expected hcloud.Error, but got: %#v", err) - } - - expectedProgressUpdates := []int{50, 100} - if !reflect.DeepEqual(progressUpdates, expectedProgressUpdates) { - t.Fatalf("expected progresses %v but received %v", expectedProgressUpdates, progressUpdates) - } -} - -func TestActionClientWatchOverallProgressInvalidID(t *testing.T) { - env := newTestEnv() - defer env.Teardown() - - callCount := 0 - - env.Mux.HandleFunc("/actions", func(w http.ResponseWriter, r *http.Request) { - callCount++ - var actions []schema.Action - - switch callCount { - case 1: - default: - t.Errorf("unexpected number of calls to the test server: %v", callCount) - } - - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(struct { - Actions []schema.Action `json:"actions"` - Meta schema.Meta `json:"meta"` - }{ - Actions: actions, - Meta: schema.Meta{ - Pagination: &schema.MetaPagination{ - Page: 1, - LastPage: 1, - PerPage: len(actions), - TotalEntries: len(actions), - }, - }, - }) - }) - - actions := []*Action{ - { - ID: 1, - Status: ActionStatusRunning, - }, - } - - ctx := context.Background() - progressCh, errCh := env.Client.Action.WatchOverallProgress(ctx, actions) - progressUpdates := []int{} - errs := []error{} - - moreProgress, moreErrors := true, true - - for moreProgress || moreErrors { - var progress int - var err error - - select { - case progress, moreProgress = <-progressCh: - if moreProgress { - progressUpdates = append(progressUpdates, progress) - } - case err, moreErrors = <-errCh: - if moreErrors { - errs = append(errs, err) - } - } - } - - if len(errs) != 1 { - t.Fatalf("expected to receive one error: %v", errs) - } - - err := errs[0] - - if !strings.HasPrefix(err.Error(), "failed to wait for actions") { - t.Fatalf("expected failed to wait for actions error, but got: %#v", err) - } - - expectedProgressUpdates := []int{} - if !reflect.DeepEqual(progressUpdates, expectedProgressUpdates) { - t.Fatalf("expected progresses %v but received %v", expectedProgressUpdates, progressUpdates) - } -} - -func TestActionClientWatchProgress(t *testing.T) { - env := newTestEnv() - defer env.Teardown() - - callCount := 0 - - env.Mux.HandleFunc("/actions/1", func(w http.ResponseWriter, r *http.Request) { - callCount++ - w.Header().Set("Content-Type", "application/json") - switch callCount { - case 1: - _ = json.NewEncoder(w).Encode(schema.ActionGetResponse{ - Action: schema.Action{ - ID: 1, - Status: "running", - Progress: 50, - }, - }) - case 2: - w.WriteHeader(http.StatusConflict) - _ = json.NewEncoder(w).Encode(schema.ErrorResponse{ - Error: schema.Error{ - Code: string(ErrorCodeConflict), - Message: "conflict", - }, - }) - return - case 3: - _ = json.NewEncoder(w).Encode(schema.ActionGetResponse{ - Action: schema.Action{ - ID: 1, - Status: "error", - Progress: 100, - Error: &schema.ActionError{ - Code: "action_failed", - Message: "action failed", - }, - }, - }) - default: - t.Errorf("unexpected number of calls to the test server: %v", callCount) - } - }) - action := &Action{ - ID: 1, - Status: ActionStatusRunning, - Progress: 0, - } - - ctx := context.Background() - progressCh, errCh := env.Client.Action.WatchProgress(ctx, action) - var ( - progressUpdates []int - err error - ) - -loop: - for { - select { - case progress := <-progressCh: - progressUpdates = append(progressUpdates, progress) - case err = <-errCh: - break loop - } - } - - if err == nil { - t.Fatal("expected an error") - } - if e, ok := err.(ActionError); !ok || e.Code != "action_failed" { - t.Fatalf("expected hcloud.Error, but got: %#v", err) - } - if len(progressUpdates) != 1 || progressUpdates[0] != 50 { - t.Fatalf("unexpected progress updates: %v", progressUpdates) - } -} - -func TestActionClientWatchProgressError(t *testing.T) { - env := newTestEnv() - defer env.Teardown() - - env.Mux.HandleFunc("/actions/1", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnprocessableEntity) - _ = json.NewEncoder(w).Encode(schema.ErrorResponse{ - Error: schema.Error{ - Code: string(ErrorCodeServiceError), - Message: "service error", - }, - }) - }) - - action := &Action{ID: 1} - ctx := context.Background() - _, errCh := env.Client.Action.WatchProgress(ctx, action) - if err := <-errCh; err == nil { - t.Fatal("expected an error") - } -} - -func TestActionClientWatchProgressInvalidID(t *testing.T) { - env := newTestEnv() - defer env.Teardown() - - callCount := 0 - - env.Mux.HandleFunc("/actions/1", func(w http.ResponseWriter, r *http.Request) { - callCount++ - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusNotFound) - switch callCount { - case 1: - _ = json.NewEncoder(w).Encode(schema.ErrorResponse{ - Error: schema.Error{ - Code: string(ErrorCodeNotFound), - Message: "action with ID '1' not found", - Details: nil, - }, - }) - default: - t.Errorf("unexpected number of calls to the test server: %v", callCount) - } - }) - action := &Action{ - ID: 1, - } - - ctx := context.Background() - progressCh, errCh := env.Client.Action.WatchProgress(ctx, action) - var ( - progressUpdates []int - err error - ) - -loop: - for { - select { - case progress := <-progressCh: - progressUpdates = append(progressUpdates, progress) - case err = <-errCh: - break loop - } - } - - if !strings.HasPrefix(err.Error(), "failed to wait for action") { - t.Fatalf("expected failed to wait for action error, but got: %#v", err) - } - if len(progressUpdates) != 0 { - t.Fatalf("unexpected progress updates: %v", progressUpdates) - } -} diff --git a/hcloud/action_watch.go b/hcloud/action_watch.go new file mode 100644 index 00000000..28332470 --- /dev/null +++ b/hcloud/action_watch.go @@ -0,0 +1,173 @@ +package hcloud + +import ( + "context" + "fmt" + "time" +) + +// WatchOverallProgress watches several actions' progress until they complete +// with success or error. This watching happens in a goroutine and updates are +// provided through the two returned channels: +// +// - The first channel receives percentage updates of the progress, based on +// the number of completed versus total watched actions. The return value +// is an int between 0 and 100. +// - The second channel returned receives errors for actions that did not +// complete successfully, as well as any errors that happened while +// querying the API. +// +// By default, the method keeps watching until all actions have finished +// processing. If you want to be able to cancel the method or configure a +// timeout, use the [context.Context]. Once the method has stopped watching, +// both returned channels are closed. +// +// WatchOverallProgress uses the [WithPollBackoffFunc] of the [Client] to wait +// until sending the next request. +func (c *ActionClient) WatchOverallProgress(ctx context.Context, actions []*Action) (<-chan int, <-chan error) { + errCh := make(chan error, len(actions)) + progressCh := make(chan int) + + go func() { + defer close(errCh) + defer close(progressCh) + + completedIDs := make([]int, 0, len(actions)) + watchIDs := make(map[int]struct{}, len(actions)) + for _, action := range actions { + watchIDs[action.ID] = struct{}{} + } + + retries := 0 + previousProgress := 0 + + for { + select { + case <-ctx.Done(): + errCh <- ctx.Err() + return + case <-time.After(c.action.client.pollBackoffFunc(retries)): + retries++ + } + + opts := ActionListOpts{} + for watchID := range watchIDs { + opts.ID = append(opts.ID, watchID) + } + + as, err := c.AllWithOpts(ctx, opts) + if err != nil { + errCh <- err + return + } + if len(as) == 0 { + // No actions returned for the provided IDs, they do not exist in the API. + // We need to catch and fail early for this, otherwise the loop will continue + // indefinitely. + errCh <- fmt.Errorf("failed to wait for actions: remaining actions (%v) are not returned from API", opts.ID) + return + } + + progress := 0 + for _, a := range as { + switch a.Status { + case ActionStatusRunning: + progress += a.Progress + case ActionStatusSuccess: + delete(watchIDs, a.ID) + completedIDs = append(completedIDs, a.ID) + case ActionStatusError: + delete(watchIDs, a.ID) + completedIDs = append(completedIDs, a.ID) + errCh <- fmt.Errorf("action %d failed: %w", a.ID, a.Error()) + } + } + + progress += len(completedIDs) * 100 + if progress != 0 && progress != previousProgress { + sendProgress(progressCh, progress/len(actions)) + previousProgress = progress + } + + if len(watchIDs) == 0 { + return + } + } + }() + + return progressCh, errCh +} + +// WatchProgress watches one action's progress until it completes with success +// or error. This watching happens in a goroutine and updates are provided +// through the two returned channels: +// +// - The first channel receives percentage updates of the progress, based on +// the progress percentage indicated by the API. The return value is an int +// between 0 and 100. +// - The second channel receives any errors that happened while querying the +// API, as well as the error of the action if it did not complete +// successfully, or nil if it did. +// +// By default, the method keeps watching until the action has finished +// processing. If you want to be able to cancel the method or configure a +// timeout, use the [context.Context]. Once the method has stopped watching, +// both returned channels are closed. +// +// WatchProgress uses the [WithPollBackoffFunc] of the [Client] to wait until +// sending the next request. +func (c *ActionClient) WatchProgress(ctx context.Context, action *Action) (<-chan int, <-chan error) { + errCh := make(chan error, 1) + progressCh := make(chan int) + + go func() { + defer close(errCh) + defer close(progressCh) + + retries := 0 + + for { + select { + case <-ctx.Done(): + errCh <- ctx.Err() + return + case <-time.After(c.action.client.pollBackoffFunc(retries)): + retries++ + } + + a, _, err := c.GetByID(ctx, action.ID) + if err != nil { + errCh <- err + return + } + if a == nil { + errCh <- fmt.Errorf("failed to wait for action %d: action not returned from API", action.ID) + return + } + + switch a.Status { + case ActionStatusRunning: + sendProgress(progressCh, a.Progress) + case ActionStatusSuccess: + sendProgress(progressCh, 100) + errCh <- nil + return + case ActionStatusError: + errCh <- a.Error() + return + } + } + }() + + return progressCh, errCh +} + +// sendProgress allows the user to only read from the error channel and ignore any progress updates. +func sendProgress(progressCh chan int, p int) { + select { + case progressCh <- p: + break + default: + break + } +} diff --git a/hcloud/action_watch_test.go b/hcloud/action_watch_test.go new file mode 100644 index 00000000..c08d91da --- /dev/null +++ b/hcloud/action_watch_test.go @@ -0,0 +1,365 @@ +package hcloud + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "reflect" + "strings" + "testing" + + "github.com/hetznercloud/hcloud-go/hcloud/schema" +) + +func TestActionClientWatchOverallProgress(t *testing.T) { + t.Parallel() + env := newTestEnv() + defer env.Teardown() + + callCount := 0 + + env.Mux.HandleFunc("/actions", func(w http.ResponseWriter, r *http.Request) { + callCount++ + var actions []schema.Action + + switch callCount { + case 1: + actions = []schema.Action{ + { + ID: 1, + Status: "running", + Progress: 50, + }, + { + ID: 2, + Status: "running", + Progress: 50, + }, + } + case 2: + actions = []schema.Action{ + { + ID: 1, + Status: "running", + Progress: 75, + }, + { + ID: 2, + Status: "error", + Progress: 100, + Error: &schema.ActionError{ + Code: "action_failed", + Message: "action failed", + }, + }, + } + case 3: + actions = []schema.Action{ + { + ID: 1, + Status: "success", + Progress: 100, + }, + } + default: + t.Errorf("unexpected number of calls to the test server: %v", callCount) + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(struct { + Actions []schema.Action `json:"actions"` + Meta schema.Meta `json:"meta"` + }{ + Actions: actions, + Meta: schema.Meta{ + Pagination: &schema.MetaPagination{ + Page: 1, + LastPage: 1, + PerPage: len(actions), + TotalEntries: len(actions), + }, + }, + }) + }) + + actions := []*Action{ + { + ID: 1, + Status: ActionStatusRunning, + }, + { + ID: 2, + Status: ActionStatusRunning, + }, + } + + ctx := context.Background() + progressCh, errCh := env.Client.Action.WatchOverallProgress(ctx, actions) + progressUpdates := []int{} + errs := []error{} + + moreProgress, moreErrors := true, true + + for moreProgress || moreErrors { + var progress int + var err error + + select { + case progress, moreProgress = <-progressCh: + if moreProgress { + progressUpdates = append(progressUpdates, progress) + } + case err, moreErrors = <-errCh: + if moreErrors { + errs = append(errs, err) + } + } + } + + if len(errs) != 1 { + t.Fatalf("expected to receive one error: %v", errs) + } + + err := errs[0] + + if e, ok := errors.Unwrap(err).(ActionError); !ok || e.Code != "action_failed" { + t.Fatalf("expected hcloud.Error, but got: %#v", err) + } + + expectedProgressUpdates := []int{50, 100} + if !reflect.DeepEqual(progressUpdates, expectedProgressUpdates) { + t.Fatalf("expected progresses %v but received %v", expectedProgressUpdates, progressUpdates) + } +} + +func TestActionClientWatchOverallProgressInvalidID(t *testing.T) { + env := newTestEnv() + defer env.Teardown() + + callCount := 0 + + env.Mux.HandleFunc("/actions", func(w http.ResponseWriter, r *http.Request) { + callCount++ + var actions []schema.Action + + switch callCount { + case 1: + default: + t.Errorf("unexpected number of calls to the test server: %v", callCount) + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(struct { + Actions []schema.Action `json:"actions"` + Meta schema.Meta `json:"meta"` + }{ + Actions: actions, + Meta: schema.Meta{ + Pagination: &schema.MetaPagination{ + Page: 1, + LastPage: 1, + PerPage: len(actions), + TotalEntries: len(actions), + }, + }, + }) + }) + + actions := []*Action{ + { + ID: 1, + Status: ActionStatusRunning, + }, + } + + ctx := context.Background() + progressCh, errCh := env.Client.Action.WatchOverallProgress(ctx, actions) + progressUpdates := []int{} + errs := []error{} + + moreProgress, moreErrors := true, true + + for moreProgress || moreErrors { + var progress int + var err error + + select { + case progress, moreProgress = <-progressCh: + if moreProgress { + progressUpdates = append(progressUpdates, progress) + } + case err, moreErrors = <-errCh: + if moreErrors { + errs = append(errs, err) + } + } + } + + if len(errs) != 1 { + t.Fatalf("expected to receive one error: %v", errs) + } + + err := errs[0] + + if !strings.HasPrefix(err.Error(), "failed to wait for actions") { + t.Fatalf("expected failed to wait for actions error, but got: %#v", err) + } + + expectedProgressUpdates := []int{} + if !reflect.DeepEqual(progressUpdates, expectedProgressUpdates) { + t.Fatalf("expected progresses %v but received %v", expectedProgressUpdates, progressUpdates) + } +} + +func TestActionClientWatchProgress(t *testing.T) { + env := newTestEnv() + defer env.Teardown() + + callCount := 0 + + env.Mux.HandleFunc("/actions/1", func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + switch callCount { + case 1: + _ = json.NewEncoder(w).Encode(schema.ActionGetResponse{ + Action: schema.Action{ + ID: 1, + Status: "running", + Progress: 50, + }, + }) + case 2: + w.WriteHeader(http.StatusConflict) + _ = json.NewEncoder(w).Encode(schema.ErrorResponse{ + Error: schema.Error{ + Code: string(ErrorCodeConflict), + Message: "conflict", + }, + }) + return + case 3: + _ = json.NewEncoder(w).Encode(schema.ActionGetResponse{ + Action: schema.Action{ + ID: 1, + Status: "error", + Progress: 100, + Error: &schema.ActionError{ + Code: "action_failed", + Message: "action failed", + }, + }, + }) + default: + t.Errorf("unexpected number of calls to the test server: %v", callCount) + } + }) + action := &Action{ + ID: 1, + Status: ActionStatusRunning, + Progress: 0, + } + + ctx := context.Background() + progressCh, errCh := env.Client.Action.WatchProgress(ctx, action) + var ( + progressUpdates []int + err error + ) + +loop: + for { + select { + case progress := <-progressCh: + progressUpdates = append(progressUpdates, progress) + case err = <-errCh: + break loop + } + } + + if err == nil { + t.Fatal("expected an error") + } + if e, ok := err.(ActionError); !ok || e.Code != "action_failed" { + t.Fatalf("expected hcloud.Error, but got: %#v", err) + } + if len(progressUpdates) != 1 || progressUpdates[0] != 50 { + t.Fatalf("unexpected progress updates: %v", progressUpdates) + } +} + +func TestActionClientWatchProgressError(t *testing.T) { + env := newTestEnv() + defer env.Teardown() + + env.Mux.HandleFunc("/actions/1", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnprocessableEntity) + _ = json.NewEncoder(w).Encode(schema.ErrorResponse{ + Error: schema.Error{ + Code: string(ErrorCodeServiceError), + Message: "service error", + }, + }) + }) + + action := &Action{ID: 1} + ctx := context.Background() + _, errCh := env.Client.Action.WatchProgress(ctx, action) + if err := <-errCh; err == nil { + t.Fatal("expected an error") + } +} + +func TestActionClientWatchProgressInvalidID(t *testing.T) { + env := newTestEnv() + defer env.Teardown() + + callCount := 0 + + env.Mux.HandleFunc("/actions/1", func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotFound) + switch callCount { + case 1: + _ = json.NewEncoder(w).Encode(schema.ErrorResponse{ + Error: schema.Error{ + Code: string(ErrorCodeNotFound), + Message: "action with ID '1' not found", + Details: nil, + }, + }) + default: + t.Errorf("unexpected number of calls to the test server: %v", callCount) + } + }) + action := &Action{ + ID: 1, + } + + ctx := context.Background() + progressCh, errCh := env.Client.Action.WatchProgress(ctx, action) + var ( + progressUpdates []int + err error + ) + +loop: + for { + select { + case progress := <-progressCh: + progressUpdates = append(progressUpdates, progress) + case err = <-errCh: + break loop + } + } + + if !strings.HasPrefix(err.Error(), "failed to wait for action") { + t.Fatalf("expected failed to wait for action error, but got: %#v", err) + } + if len(progressUpdates) != 0 { + t.Fatalf("unexpected progress updates: %v", progressUpdates) + } +}