Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context helpers #1

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@

@samber: I sometimes forget to update this file. Ping me on [Twitter](https://twitter.com/samuelberthe) or open an issue in case of error. We need to keep a clear changelog for easier lib upgrade.

## 1.39.0 (xxxx-xx-xx)

Adding:
- lo.ContextWith
- lo.FromContext
- lo.FromContextOr

## 1.38.1 (2023-03-20)

Improvement:
Expand Down
35 changes: 35 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,12 @@ Error handling:
- [TryCatchWithErrorValue](#trycatchwitherrorvalue)
- [ErrorsAs](#errorsas)

Supported helpers for context:

- [ContextWith](#contextwith)
- [FromContext](#fromcontext)
- [FromContextOr](#fromcontextor)

Constraints:

- Clonable
Expand Down Expand Up @@ -2851,6 +2857,35 @@ if rateLimitErr, ok := lo.ErrorsAs[*RateLimitError](err); ok {

[[play](https://go.dev/play/p/8wk5rH8UfrE)]


### ContextWith

Attach an object by type to a context:

```go
type user struct{ id, name string }

ctx = lo.ContextWith(ctx, &user{id: "42", name: "John Doe"})
```

### FromContext

Retrieve an object by type from a context:

```go
userfromContext, found := lo.FromContext[*user](ctx)
```

### FromContextOr

Retrieve an object by type from a context with a default value if not found.

Example get logger attached to context:
```go
logger := lo.FromContextOr[*slog.Logger](ctx, slog.Default())
```


## 🛩 Benchmark

We executed a simple benchmark with the a dead-simple `lo.Map` loop:
Expand Down
28 changes: 28 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package lo

import (
"context"
)

// ContextWith return a new context with the value attached by type.
func ContextWith[T any](ctx context.Context, val T) context.Context {
return context.WithValue(ctx, (*T)(nil), val)
}

// FromContext returns the entry in context using type as the context key.
func FromContext[T any](ctx context.Context) (val T, ok bool) {
val, ok = ctx.Value((*T)(nil)).(T)

return val, ok
}

// FromContextOr returns the entry in context using type as the context key otherwise return default value.
func FromContextOr[T any](ctx context.Context, def T) T {
val, ok := FromContext[T](ctx)

if ok {
return val
}

return def
}
24 changes: 24 additions & 0 deletions context_example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package lo_test

import (
"context"
"fmt"

"github.com/samber/lo"
)

// ExampleContext set and retrieve a custom user type from context
func ExampleContext() {
ctx := context.Background()

type user struct{ id, name string }

ctx = lo.ContextWith(ctx, &user{id: "42", name: "John Doe"})

userfromContext, ok := lo.FromContext[*user](ctx)
fmt.Printf("%v\n%#v", ok, userfromContext)

// Output:
// true
// &lo_test.user{id:"42", name:"John Doe"}
}
241 changes: 241 additions & 0 deletions context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
package lo_test

import (
"bytes"
"context"
"log"
"log/slog"
"net/http"
"testing"
"time"

"github.com/samber/lo"
"github.com/stretchr/testify/assert"
)

func TestContextString(t *testing.T) {
t.Parallel()

is := assert.New(t)

ctx := lo.ContextWith(context.Background(), "some string")
is.NotNil(ctx)

result, ok := lo.FromContext[string](ctx)

is.True(ok)
is.Equal("some string", result)
}

func TestContextFunc(t *testing.T) { // not comparable
t.Parallel()

is := assert.New(t)

ctx := lo.ContextWith(context.Background(), func() string { return "ok" })
is.NotNil(ctx)

result, ok := lo.FromContext[func() string](ctx)

is.True(ok)
is.Equal("ok", result())
}

func TestContextCustom(t *testing.T) {
t.Parallel()

is := assert.New(t)

type user struct {
id, name string
}

expected := user{
id: lo.RandomString(10, lo.AlphanumericCharset),
name: lo.RandomString(10, lo.AlphanumericCharset),
}

ctx := lo.ContextWith(context.Background(), expected)
is.NotNil(ctx)

result, ok := lo.FromContext[user](ctx)

is.True(ok)
is.Equal(expected.id, result.id)
is.Equal(expected.name, result.name)

// should not be updated
expected.name = lo.RandomString(10, lo.AlphanumericCharset)
is.NotEqual(expected.name, result.name)
}

func TestContextCustomPointer(t *testing.T) {
t.Parallel()

is := assert.New(t)

type user struct {
id, name string
}

expected := user{
id: lo.RandomString(10, lo.AlphanumericCharset),
name: lo.RandomString(10, lo.AlphanumericCharset),
}

ctx := lo.ContextWith(context.Background(), &expected)
is.NotNil(ctx)

result, ok := lo.FromContext[*user](ctx)

is.True(ok)
is.Equal(expected.id, result.id)
is.Equal(expected.name, result.name)

// should be updated by reference
expected.name = lo.RandomString(10, lo.AlphanumericCharset)
is.Equal(expected.name, result.name)
}

func TestContextLogger(t *testing.T) {
t.Parallel()

is := assert.New(t)

buf := new(bytes.Buffer)

logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
// Remove time from the output for predictable test output.
if a.Key == slog.TimeKey {
return slog.Attr{}
}

return a
},
}))

ctx := lo.ContextWith(context.Background(), logger.With(
slog.Group("request",
slog.String("method", http.MethodPost),
slog.String("url", "http://localhost")),
))
is.NotNil(ctx)

result, ok := lo.FromContext[*slog.Logger](ctx)

is.True(ok)
is.NotNil(result)

result.Info("testing")

is.Equal(
"level=INFO msg=testing request.method=POST request.url=http://localhost\n",
buf.String(),
)
}

func TestFromContextOrLogger(t *testing.T) {
t.Parallel()

is := assert.New(t)

buf := new(bytes.Buffer)

logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
// Remove time from the output for predictable test output.
if a.Key == slog.TimeKey {
return slog.Attr{}
}

return a
},
}))

ctx := lo.ContextWith(context.Background(), logger.With(
slog.Group("request",
slog.String("method", http.MethodPost),
slog.String("url", "http://localhost")),
))
is.NotNil(ctx)

result := lo.FromContextOr[*slog.Logger](ctx, slog.Default())

is.NotNil(result)

result.Info("testing")

is.Equal(
"level=INFO msg=testing request.method=POST request.url=http://localhost\n",
buf.String(),
)
}

func TestFromContextOrLoggerDefault(t *testing.T) {
t.Parallel()

is := assert.New(t)

logger := lo.FromContextOr[*log.Logger](context.Background(), log.Default())

is.NotNil(logger)

logger.Print("testing")
}

func TestContextMultipleType(t *testing.T) {
t.Parallel()

is := assert.New(t)

ctx := context.Background()

ctx = lo.ContextWith(ctx, "some string")
is.NotNil(ctx)

ctx = lo.ContextWith(ctx, time.Date(2023, time.September, 20, 1, 10, 20, 30, time.UTC))
is.NotNil(ctx)

ctx = lo.ContextWith(ctx, func() string { return "ok" })
is.NotNil(ctx)

ctx = lo.ContextWith(ctx, map[string]any{"a": "map"})
is.NotNil(ctx)

type user struct {
id, name string
}

loggedUser := &user{
id: lo.RandomString(10, lo.AlphanumericCharset),
name: lo.RandomString(10, lo.AlphanumericCharset),
}
ctx = lo.ContextWith(ctx, loggedUser)
is.NotNil(ctx)

resultStr, ok := lo.FromContext[string](ctx)

is.True(ok)
is.Equal("some string", resultStr)

resultTime, ok := lo.FromContext[time.Time](ctx)

is.True(ok)
is.Equal(time.Date(2023, time.September, 20, 1, 10, 20, 30, time.UTC), resultTime)

resultUser, ok := lo.FromContext[*user](ctx)

is.True(ok)
is.Equal(loggedUser, resultUser)

resultFunc, ok := lo.FromContext[func() string](ctx)

is.True(ok)
is.Equal("ok", resultFunc())

resultMap, ok := lo.FromContext[map[string]any](ctx)

is.True(ok)
is.Equal(map[string]any{"a": "map"}, resultMap)
}
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/samber/lo

go 1.18
go 1.20

//
// Dependencies are excluded from releases. Please check CI.
Expand All @@ -9,7 +9,7 @@ go 1.18
require (
github.com/stretchr/testify v1.8.0
github.com/thoas/go-funk v0.9.1
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17
golang.org/x/exp v0.0.0-20230905200255-921286631fa9
)

require (
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PK
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/thoas/go-funk v0.9.1 h1:O549iLZqPpTUQ10ykd26sZhzD+rmR5pWhuElrhbC20M=
github.com/thoas/go-funk v0.9.1/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
Expand Down