diff --git a/cell/lifecycle.go b/cell/lifecycle.go index cbe41ee..a41f75e 100644 --- a/cell/lifecycle.go +++ b/cell/lifecycle.go @@ -75,6 +75,19 @@ type augmentedHook struct { moduleID FullModuleID } +func NewDefaultLifecycle(hooks []HookInterface, numStarted int, logThreshold time.Duration) *DefaultLifecycle { + h := make([]augmentedHook, 0, len(hooks)) + for _, hook := range hooks { + h = append(h, augmentedHook{hook, nil}) + } + return &DefaultLifecycle{ + mu: sync.Mutex{}, + hooks: h, + numStarted: numStarted, + LogThreshold: logThreshold, + } +} + func (lc *DefaultLifecycle) Append(hook HookInterface) { lc.mu.Lock() defer lc.mu.Unlock() @@ -92,7 +105,7 @@ func (lc *DefaultLifecycle) Start(log *slog.Logger, ctx context.Context) error { ctx, cancel := context.WithCancel(ctx) defer cancel() - for _, hook := range lc.hooks { + for i, hook := range lc.hooks { fnName, exists := getHookFuncName(hook, true) if !exists { @@ -102,6 +115,13 @@ func (lc *DefaultLifecycle) Start(log *slog.Logger, ctx context.Context) error { } l := log.With("function", fnName) + + // Do not attempt to start already started hooks. + if i < lc.numStarted { + l.Error("Hook appears to be running. Skipping") + continue + } + l.Debug("Executing start hook") t0 := time.Now() if err := hook.Start(ctx); err != nil { diff --git a/cell/lifecycle_test.go b/cell/lifecycle_test.go index cc28aca..7a37f26 100644 --- a/cell/lifecycle_test.go +++ b/cell/lifecycle_test.go @@ -8,6 +8,7 @@ import ( "errors" "log/slog" "testing" + "time" "github.com/stretchr/testify/assert" @@ -15,41 +16,56 @@ import ( ) var ( - started, stopped int - errLifecycle = errors.New("nope") - - goodHook = cell.Hook{ - OnStart: func(cell.HookContext) error { - started++ - return nil - }, - OnStop: func(cell.HookContext) error { - stopped++ - return nil - }, - } - badStartHook = cell.Hook{ OnStart: func(cell.HookContext) error { return errLifecycle }, } - badStopHook = cell.Hook{ - OnStart: func(cell.HookContext) error { - started++ - return nil - }, - OnStop: func(cell.HookContext) error { - return errLifecycle - }, + nilHook = cell.Hook{OnStart: nil, OnStop: nil} +) + +func TestNewDefaultLifecycle(t *testing.T) { + var started, stopped int + goodHook := cell.Hook{ + OnStart: func(cell.HookContext) error { started++; return nil }, + OnStop: func(cell.HookContext) error { stopped++; return nil }, } - nilHook = cell.Hook{nil, nil} -) + log := slog.Default() + lc := cell.NewDefaultLifecycle([]cell.HookInterface{goodHook}, 0, time.Second) + + err := lc.Start(log, context.TODO()) + assert.NoError(t, err, "expected Start to succeed") + err = lc.Stop(log, context.TODO()) + assert.NoError(t, err, "expected Stop to succeed") + + assert.Equal(t, 1, started) + assert.Equal(t, 1, stopped) + + // Construct already started hooks + started = 0 + stopped = 0 + + lc = cell.NewDefaultLifecycle([]cell.HookInterface{goodHook}, 1, time.Second) + err = lc.Stop(log, context.TODO()) + assert.NoError(t, err, "expected Stop to succeed") + + assert.Equal(t, 0, started) + assert.Equal(t, 1, stopped) +} func TestLifecycle(t *testing.T) { + var started, stopped int + goodHook := cell.Hook{ + OnStart: func(cell.HookContext) error { started++; return nil }, + OnStop: func(cell.HookContext) error { stopped++; return nil }, + } + badStopHook := cell.Hook{ + OnStart: func(cell.HookContext) error { started++; return nil }, + OnStop: func(cell.HookContext) error { return errLifecycle }, + } log := slog.Default() var lc cell.DefaultLifecycle