diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..bd28b97 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 pkg + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..0d5ff72 --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# Group + +A group provides a way manage the lifetime of a group of related goroutines. \ No newline at end of file diff --git a/example_test.go b/example_test.go new file mode 100644 index 0000000..c8e24d5 --- /dev/null +++ b/example_test.go @@ -0,0 +1,210 @@ +package group_test + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "os" + "os/signal" + "time" + + "github.com/pkg/group" +) + +type Group = group.G + +func ExampleGroup_Wait() { + // A Group's zero value is ready to use. + var g group.G + + // Add a goroutine to the group. + g.Add(func(c context.Context) error { + select { + case <-c.Done(): + return c.Err() + case <-time.After(1 * time.Second): + return errors.New("timed out") + } + }) + + // Wait for all goroutines to finish. + if err := g.Wait(); err != nil { + fmt.Println(err) + } + + // Output: timed out +} + +func ExampleGroup_Wait_with_startup_error() { + // A Group's zero value is ready to use. + var g group.G + + // Add a goroutine to the group. + g.Add(func(_ context.Context) error { + return errors.New("startup error") + }) + + // Wait for all goroutines to finish, in this case it will return the startup error. + if err := g.Wait(); err != nil { + fmt.Println(err) + } + + // Output: startup error +} + +func ExampleGroup_Wait_with_panic() { + // A Group's zero value is ready to use. + var g group.G + + // Add a goroutine to the group. + g.Add(func(c context.Context) error { + panic("boom") + }) + + // Wait for all goroutines to finish. + if err := g.Wait(); err != nil { + fmt.Println(err) + } + + // Output: panic: boom +} + +func ExampleGroup_Wait_with_shutdown() { + // A Group's zero value is ready to use. + var g group.G + + shutdown := make(chan struct{}) + + // Add a goroutine to the group. + g.Add(func(c context.Context) error { + select { + case <-c.Done(): + return errors.New("stopped") + case <-shutdown: + return errors.New("shutdown") + } + }) + + time.AfterFunc(100*time.Millisecond, func() { + close(shutdown) + }) + + // Wait for all goroutines to finish. + if err := g.Wait(); err != nil { + fmt.Println(err) + } + + // Output: shutdown +} + +func ExampleGroup_Wait_with_context_cancel() { + ctx := context.Background() + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(100*time.Millisecond)) + + // pass WithContext option to New to use the provided context. + g := group.New(group.WithContext(ctx)) + + // Add a goroutine to the group. + g.Add(func(c context.Context) error { + select { + case <-c.Done(): + return c.Err() + } + }) + + // Cancel the context. + cancel() + + // Wait for all goroutines to finish. + if err := g.Wait(); err != nil { + fmt.Println(err) + } + + // Output: context canceled +} + +func ExampleGroup_Wait_with_signal() { + ctx := context.Background() + ctx, _ = signal.NotifyContext(ctx, os.Interrupt) + + g := group.New(group.WithContext(ctx)) + + g.Add(MainHTTPServer) + g.Add(DebugHTTPServer) + g.Add(AsyncLogger) + + <-time.After(100 * time.Millisecond) + + // simulate ^C + proc, _ := os.FindProcess(os.Getpid()) + proc.Signal(os.Interrupt) + + if err := g.Wait(); err != nil { + fmt.Println(err) + } + + // Unordered output: + // async logger started + // debug http server started + // main http server started + // async logger stopped + // main http server stopped + // debug http server stopped + // context canceled +} + +func ExampleGroup_Wait_with_http_shutdown() { + ctx := context.Background() + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(100*time.Millisecond)) + defer cancel() + + g := group.New(group.WithContext(ctx)) + + g.Add(func(ctx context.Context) error { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return err + } + svr := http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "hello, world!") + })} + go func() { + svr.Serve(l) + }() + + <-ctx.Done() // wait for group to stop + + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) // five seconds graceful timeout + defer cancel() + return svr.Shutdown(shutdownCtx) + }) + + if err := g.Wait(); err != nil { + fmt.Println(err) + } + + // Output: +} +func MainHTTPServer(ctx context.Context) error { + fmt.Println("main http server started") + defer fmt.Println("main http server stopped") + <-ctx.Done() + return ctx.Err() +} + +func DebugHTTPServer(ctx context.Context) error { + fmt.Println("debug http server started") + defer fmt.Println("debug http server stopped") + <-ctx.Done() + return ctx.Err() +} + +func AsyncLogger(ctx context.Context) error { + fmt.Println("async logger started") + defer fmt.Println("async logger stopped") + <-ctx.Done() + return ctx.Err() +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..a6f926b --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/pkg/group + +go 1.23 diff --git a/group.go b/group.go new file mode 100644 index 0000000..bd1c022 --- /dev/null +++ b/group.go @@ -0,0 +1,86 @@ +// package group provides a way to manage the lifecycle of a group of goroutines. +package group + +import ( + "context" + "fmt" + "sync" +) + +// G manages the lifetime of a set of goroutines from a common context. +// The first goroutine in the group to return will cause the context to be canceled, +// terminating the remaining goroutines. +type G struct { + // ctx is the context passed to all goroutines in the group. + ctx context.Context + cancel context.CancelFunc + done sync.WaitGroup + + initOnce sync.Once + + errOnce sync.Once + err error +} + +type Option func(*G) + +// WithContext uses the provided context for the group. +func WithContext(ctx context.Context) Option { + return func(g *G) { + g.ctx = ctx + } +} + +// New creates a new group. +func New(opts ...Option) *G { + g := new(G) + for _, opt := range opts { + opt(g) + } + return g +} + +// init initializes the group. +func (g *G) init() { + if g.ctx == nil { + g.ctx = context.Background() + } + g.ctx, g.cancel = context.WithCancel(g.ctx) +} + +// add adds a new goroutine to the group. The goroutine should exit when the context +// passed to it is canceled. +func (g *G) Add(fn func(context.Context) error) { + g.initOnce.Do(g.init) + g.done.Add(1) + go func() { + defer g.done.Done() + defer g.cancel() + defer func() { + if r := recover(); r != nil { + g.errOnce.Do(func() { + if err, ok := r.(error); ok { + g.err = err + } else { + g.err = fmt.Errorf("panic: %v", r) + } + }) + } + }() + if err := fn(g.ctx); err != nil { + g.errOnce.Do(func() { g.err = err }) + } + }() +} + +// wait waits for all goroutines in the group to exit. If any of the goroutines +// fail with an error, wait will return the first error. +// Wait waits for all goroutines in the group to exit. +// If any of the goroutines fail with an error, Wait will return the first error. +func (g *G) Wait() error { + g.done.Wait() + g.errOnce.Do(func() { + // noop, required to synchronise on the errOnce mutex. + }) + return g.err +}