From ceaabf622568a4734c364beb5c5a407e7e9840f0 Mon Sep 17 00:00:00 2001 From: Greg Weber Date: Wed, 11 Dec 2024 11:57:44 -0600 Subject: [PATCH] configure go routine launching --- README.md | 16 ++++++++++++- concurrent.go | 56 +++++++++++++++++++++++++++++++++++----------- concurrent_test.go | 45 +++++++++++++++++++++++++++++++++++++ group.go | 18 ++++++++++----- 4 files changed, 115 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 9cc3219..0c8cde8 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,23 @@ Go library to run code concurrently +## Running multiple Go routines concurrency + * GoN - run N go routines concurrently -* GoEach - run a go routine per slice member +* GoEach - run a go routine for each array element * NewGroupContext - Similar to x/sync/errgroup but catches panics and returns all errors + +It is possible to instrument how the go routines are launched or launch them in serial for debugging. +See: + +* GoSerial - running in serial for debugging +* GoRoutine - create your own go routine launcher +* GoRoutine.GoN(...) +* GoEachRoutine(...)(GoRoutine) +* group.SetGoRoutine(GoRoutine) + +## Concurrency helpers + * UnboundedChan * ChannelMerge * TrySend diff --git a/concurrent.go b/concurrent.go index 3623421..2046d8a 100644 --- a/concurrent.go +++ b/concurrent.go @@ -13,19 +13,7 @@ import ( // If there are no errors, the slice will be nil. // To combine the errors as a single error, use errors.Join. func GoN(n int, fn func(int) error) []error { - errs := make([]error, n) - var wg sync.WaitGroup - for i := 0; i < n; i++ { - i := i - wg.Add(1) - go recovery.GoHandler(func(err error) { errs[i] = err }, func() error { - defer wg.Done() - errs[i] = fn(i) - return nil - }) - } - wg.Wait() - return errors.Joins(errs...) + return GoConcurrent().GoN(n, fn) } // GoEach runs a go routine for each item in an Array. @@ -41,6 +29,48 @@ func GoEach[T any](all []T, fn func(T) error) []error { }) } +func GoConcurrent() GoRoutine { + return GoRoutine(func(work func()) { go work() }) +} + +func GoSerial() GoRoutine { + return GoRoutine(func(work func()) { work() }) +} + +// GoRoutine allows for inserting hooks before launching Go routines +// GoSerial() allows for running in serial for debugging +type GoRoutine func(func()) + +func (gr GoRoutine) GoN(n int, fn func(int) error) []error { + errs := make([]error, n) + var wg sync.WaitGroup + for i := 0; i < n; i++ { + i := i + wg.Add(1) + gr(func() { + recovery.GoHandler(func(err error) { errs[i] = err }, func() error { + defer wg.Done() + errs[i] = fn(i) + return nil + }) + }) + } + wg.Wait() + return errors.Joins(errs...) +} + +// GoEach but with a configurable GoRoutine. +// GoEach uses generics, so it cannot be called directly as a method. +// Instead, apply the GoEach arguments first, than apply the GoRoutine to the resulting function. +func GoEachRoutine[T any](all []T, work func(T) error) func(gr GoRoutine) []error { + return func(gr GoRoutine) []error { + return gr.GoN(len(all), func(n int) error { + item := all[n] + return work(item) + }) + } +} + // Merge multiple channels together. // From this article: https://go.dev/blog/pipelines func ChannelMerge[T any](cs ...<-chan T) <-chan T { diff --git a/concurrent_test.go b/concurrent_test.go index a1e0760..eb42446 100644 --- a/concurrent_test.go +++ b/concurrent_test.go @@ -31,6 +31,29 @@ func TestGoN(t *testing.T) { must.True(t, tracked[0]) } +func TestGoNSerials(t *testing.T) { + var err []error + gr := concurrent.GoSerial() + workNone := func(_ int) error { return nil } + err = gr.GoN(0, workNone) + must.Nil(t, err) + err = gr.GoN(2, workNone) + must.Nil(t, err) + + tracked := make([]bool, 10) + workTracked := func(i int) error { tracked[i] = true; return nil } + err = gr.GoN(0, workTracked) + must.Nil(t, err) + must.False(t, tracked[0]) + + tracked = make([]bool, 10) + err = gr.GoN(2, workTracked) + must.Nil(t, err) + must.False(t, tracked[2]) + must.True(t, tracked[1]) + must.True(t, tracked[0]) +} + func TestGoEach(t *testing.T) { var err []error tracked := make([]bool, 10) @@ -52,6 +75,28 @@ func TestGoEach(t *testing.T) { must.True(t, tracked[0]) } +func TestGoEachSerial(t *testing.T) { + var err []error + tracked := make([]bool, 10) + workNone := func(_ bool) error { return nil } + gr := concurrent.GoSerial() + err = concurrent.GoEachRoutine(tracked, workNone)(gr) + must.Nil(t, err) + + workTracked := func(_ bool) error { tracked[0] = true; return nil } + err = concurrent.GoEachRoutine(tracked, workTracked)(gr) + must.Nil(t, err) + must.False(t, tracked[1]) + must.True(t, tracked[0]) + + workTracked = func(_ bool) error { tracked[1] = true; return nil } + err = concurrent.GoEachRoutine(tracked, workTracked)(gr) + must.Nil(t, err) + must.False(t, tracked[2]) + must.True(t, tracked[1]) + must.True(t, tracked[0]) +} + func TestChannelMerge(t *testing.T) { { c1 := make(chan error) diff --git a/group.go b/group.go index 31b6895..f0a0f78 100644 --- a/group.go +++ b/group.go @@ -52,10 +52,11 @@ type token struct{} // * panics in the functions that are ran are recovered and converted to errors. // Must be constructed with NewGroupContext type group struct { - errChan UnboundedChan[error] - wg sync.WaitGroup - cancel func(error) - sem chan token + errChan UnboundedChan[error] + wg sync.WaitGroup + cancel func(error) + sem chan token + goRoutine GoRoutine } func (g *group) do(fn func() error) { @@ -94,11 +95,16 @@ func (g *group) Wait() []error { func NewGroupContext(ctx context.Context) (*group, context.Context) { ctx, cancel := context.WithCancelCause(ctx) return &group{ - cancel: cancel, - errChan: NewUnboundedChan[error](), + cancel: cancel, + errChan: NewUnboundedChan[error](), + goRoutine: GoConcurrent(), }, ctx } +func (g *group) SetGoRoutine(gr GoRoutine) { + g.goRoutine = gr +} + func (g *group) Go(fn func() error) { if g.sem != nil { g.sem <- token{}