From 2f824768917d4c84a281e825b468a5704caa8fa0 Mon Sep 17 00:00:00 2001 From: renanbastos93 Date: Thu, 30 May 2024 18:07:48 -0300 Subject: [PATCH 1/7] =?UTF-8?q?fix(=F0=9F=A9=B9):=20error=20handling=20in?= =?UTF-8?q?=20CSRF=20token=20storage=20retrieval?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Resolved TODO for error handling to ensure code quality in retrieving CSRF token data from storage. --- middleware/csrf/csrf.go | 19 ++++++++-------- middleware/csrf/session_manager.go | 18 ++++++++------- middleware/csrf/storage_manager.go | 36 ++++++++++++++++++++++-------- 3 files changed, 47 insertions(+), 26 deletions(-) diff --git a/middleware/csrf/csrf.go b/middleware/csrf/csrf.go index a04d85cb2f..66b1154439 100644 --- a/middleware/csrf/csrf.go +++ b/middleware/csrf/csrf.go @@ -18,8 +18,11 @@ var ( ErrRefererNoMatch = errors.New("referer does not match host and is not a trusted origin") ErrOriginInvalid = errors.New("origin invalid") ErrOriginNoMatch = errors.New("origin does not match host and is not a trusted origin") - errOriginNotFound = errors.New("origin not supplied or is null") // internal error, will not be returned to the user - dummyValue = []byte{'+'} + ErrNotGetStorage = errors.New("csrf storage not found data") + ErrNotSetStorage = errors.New("csrf storage not set data") + + errOriginNotFound = errors.New("origin not supplied or is null") // internal error, will not be returned to the user + dummyValue = []byte{'+'} ) // Handler for CSRF middleware @@ -102,10 +105,9 @@ func New(config ...Config) fiber.Handler { switch c.Method() { case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace: cookieToken := c.Cookies(cfg.CookieName) - if cookieToken != "" { - raw := getRawFromStorage(c, cookieToken, cfg, sessionManager, storageManager) - + // In this case, handling error doesn't make sense because we have validations after the switch. + raw, _ := getRawFromStorage(c, cookieToken, cfg, sessionManager, storageManager) if raw != nil { token = cookieToken // Token is valid, safe to set it } @@ -148,9 +150,8 @@ func New(config ...Config) fiber.Handler { return cfg.ErrorHandler(c, ErrTokenInvalid) } - raw := getRawFromStorage(c, extractedToken, cfg, sessionManager, storageManager) - - if raw == nil { + raw, err := getRawFromStorage(c, extractedToken, cfg, sessionManager, storageManager) + if err != nil || raw == nil { // If token is not in storage, expire the cookie expireCSRFCookie(c, cfg) // and return an error @@ -209,7 +210,7 @@ func HandlerFromContext(c fiber.Ctx) *Handler { // getRawFromStorage returns the raw value from the storage for the given token // returns nil if the token does not exist, is expired or is invalid -func getRawFromStorage(c fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) []byte { +func getRawFromStorage(c fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) (raw []byte, err error) { if cfg.Session != nil { return sessionManager.getRaw(c, token, dummyValue) } diff --git a/middleware/csrf/session_manager.go b/middleware/csrf/session_manager.go index 87172eb838..e7731e3385 100644 --- a/middleware/csrf/session_manager.go +++ b/middleware/csrf/session_manager.go @@ -26,20 +26,22 @@ func newSessionManager(s *session.Store, k string) *sessionManager { } // get token from session -func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) []byte { +func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) (rawToken []byte, err error) { sess, err := m.session.Get(c) if err != nil { - return nil + return nil, err } + token, ok := sess.Get(m.key).(Token) - if ok { - if token.Expiration.Before(time.Now()) || key != token.Key || !compareTokens(raw, token.Raw) { - return nil - } - return token.Raw + if !ok { + return nil, ErrTokenInvalid + } + + if token.Expiration.Before(time.Now()) || key != token.Key || !compareTokens(raw, token.Raw) { + return nil, ErrTokenInvalid } - return nil + return token.Raw, nil } // set token in session diff --git a/middleware/csrf/storage_manager.go b/middleware/csrf/storage_manager.go index b6d7f0160d..47ccbb3e0e 100644 --- a/middleware/csrf/storage_manager.go +++ b/middleware/csrf/storage_manager.go @@ -1,6 +1,7 @@ package csrf import ( + "fmt" "sync" "time" @@ -40,31 +41,48 @@ func newStorageManager(storage fiber.Storage) *storageManager { } // get raw data from storage or memory -func (m *storageManager) getRaw(key string) []byte { - var raw []byte +func (m *storageManager) getRaw(key string) (raw []byte, err error) { if m.storage != nil { - raw, _ = m.storage.Get(key) //nolint:errcheck // TODO: Do not ignore error + raw, err = m.storage.Get(key) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrNotGetStorage, err.Error()) + } } else { - raw, _ = m.memory.Get(key).([]byte) //nolint:errcheck // TODO: Do not ignore error + var ok bool + raw, ok = m.memory.Get(key).([]byte) + if !ok { + return nil, ErrNotGetStorage + } } - return raw + + return raw, nil } // set data to storage or memory -func (m *storageManager) setRaw(key string, raw []byte, exp time.Duration) { +func (m *storageManager) setRaw(key string, raw []byte, exp time.Duration) (err error) { if m.storage != nil { - _ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Do not ignore error + err = m.storage.Set(key, raw, exp) + if err != nil { + return fmt.Errorf("%w: %s", ErrNotSetStorage, err.Error()) + } } else { // the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here m.memory.Set(utils.CopyString(key), raw, exp) } + + return nil } // delete data from storage or memory -func (m *storageManager) delRaw(key string) { +func (m *storageManager) delRaw(key string) (err error) { if m.storage != nil { - _ = m.storage.Delete(key) //nolint:errcheck // TODO: Do not ignore error + err = m.storage.Delete(key) + if err != nil { + return fmt.Errorf("%w: %s", ErrNotSetStorage, err.Error()) + } } else { m.memory.Delete(key) } + + return nil } From 750c7f518f3eb81fcb32c884bb9442ce5f13b0b5 Mon Sep 17 00:00:00 2001 From: renanbastos93 Date: Thu, 30 May 2024 18:17:42 -0300 Subject: [PATCH 2/7] chore: fix lint --- middleware/csrf/csrf.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/middleware/csrf/csrf.go b/middleware/csrf/csrf.go index 66b1154439..e6fc4fa30d 100644 --- a/middleware/csrf/csrf.go +++ b/middleware/csrf/csrf.go @@ -107,8 +107,8 @@ func New(config ...Config) fiber.Handler { cookieToken := c.Cookies(cfg.CookieName) if cookieToken != "" { // In this case, handling error doesn't make sense because we have validations after the switch. - raw, _ := getRawFromStorage(c, cookieToken, cfg, sessionManager, storageManager) - if raw != nil { + raw, err := getRawFromStorage(c, cookieToken, cfg, sessionManager, storageManager) + if raw != nil && err == nil { token = cookieToken // Token is valid, safe to set it } } From a53d2065f03b945f8663d449340914a3e855dbe1 Mon Sep 17 00:00:00 2001 From: renanbastos93 Date: Thu, 30 May 2024 18:25:14 -0300 Subject: [PATCH 3/7] fix: review changes --- middleware/csrf/csrf.go | 2 +- middleware/csrf/session_manager.go | 2 +- middleware/csrf/storage_manager.go | 25 ++++++++++++++----------- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/middleware/csrf/csrf.go b/middleware/csrf/csrf.go index e6fc4fa30d..9a9e577926 100644 --- a/middleware/csrf/csrf.go +++ b/middleware/csrf/csrf.go @@ -210,7 +210,7 @@ func HandlerFromContext(c fiber.Ctx) *Handler { // getRawFromStorage returns the raw value from the storage for the given token // returns nil if the token does not exist, is expired or is invalid -func getRawFromStorage(c fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) (raw []byte, err error) { +func getRawFromStorage(c fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) ([]byte, error) { if cfg.Session != nil { return sessionManager.getRaw(c, token, dummyValue) } diff --git a/middleware/csrf/session_manager.go b/middleware/csrf/session_manager.go index e7731e3385..3682219f22 100644 --- a/middleware/csrf/session_manager.go +++ b/middleware/csrf/session_manager.go @@ -26,7 +26,7 @@ func newSessionManager(s *session.Store, k string) *sessionManager { } // get token from session -func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) (rawToken []byte, err error) { +func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) ([]byte, error) { sess, err := m.session.Get(c) if err != nil { return nil, err diff --git a/middleware/csrf/storage_manager.go b/middleware/csrf/storage_manager.go index 47ccbb3e0e..743b01b8aa 100644 --- a/middleware/csrf/storage_manager.go +++ b/middleware/csrf/storage_manager.go @@ -7,6 +7,7 @@ import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/memory" + "github.com/gofiber/fiber/v3/log" "github.com/gofiber/utils/v2" ) @@ -41,7 +42,11 @@ func newStorageManager(storage fiber.Storage) *storageManager { } // get raw data from storage or memory -func (m *storageManager) getRaw(key string) (raw []byte, err error) { +func (m *storageManager) getRaw(key string) ([]byte, error) { + var ( + raw []byte + err error + ) if m.storage != nil { raw, err = m.storage.Get(key) if err != nil { @@ -59,30 +64,28 @@ func (m *storageManager) getRaw(key string) (raw []byte, err error) { } // set data to storage or memory -func (m *storageManager) setRaw(key string, raw []byte, exp time.Duration) (err error) { +func (m *storageManager) setRaw(key string, raw []byte, exp time.Duration) { if m.storage != nil { - err = m.storage.Set(key, raw, exp) + err := m.storage.Set(key, raw, exp) if err != nil { - return fmt.Errorf("%w: %s", ErrNotSetStorage, err.Error()) + log.Warnf("csrf: failed to save session in storage: %s", err.Error()) + return } } else { // the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here m.memory.Set(utils.CopyString(key), raw, exp) } - - return nil } // delete data from storage or memory -func (m *storageManager) delRaw(key string) (err error) { +func (m *storageManager) delRaw(key string) { if m.storage != nil { - err = m.storage.Delete(key) + err := m.storage.Delete(key) if err != nil { - return fmt.Errorf("%w: %s", ErrNotSetStorage, err.Error()) + log.Warnf("csrf: failed to delete session in storage: %s", err.Error()) + return } } else { m.memory.Delete(key) } - - return nil } From aae5a6c76f9883456750c4bcf2ba65f64a08e47c Mon Sep 17 00:00:00 2001 From: renanbastos93 Date: Thu, 30 May 2024 18:30:26 -0300 Subject: [PATCH 4/7] chore: more fixes by review --- middleware/csrf/csrf.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/middleware/csrf/csrf.go b/middleware/csrf/csrf.go index 9a9e577926..39f410cbdb 100644 --- a/middleware/csrf/csrf.go +++ b/middleware/csrf/csrf.go @@ -8,6 +8,7 @@ import ( "time" "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/log" ) var ( @@ -18,8 +19,7 @@ var ( ErrRefererNoMatch = errors.New("referer does not match host and is not a trusted origin") ErrOriginInvalid = errors.New("origin invalid") ErrOriginNoMatch = errors.New("origin does not match host and is not a trusted origin") - ErrNotGetStorage = errors.New("csrf storage not found data") - ErrNotSetStorage = errors.New("csrf storage not set data") + ErrNotGetStorage = errors.New("unable to retrieve data from CSRF storage") errOriginNotFound = errors.New("origin not supplied or is null") // internal error, will not be returned to the user dummyValue = []byte{'+'} @@ -107,8 +107,8 @@ func New(config ...Config) fiber.Handler { cookieToken := c.Cookies(cfg.CookieName) if cookieToken != "" { // In this case, handling error doesn't make sense because we have validations after the switch. - raw, err := getRawFromStorage(c, cookieToken, cfg, sessionManager, storageManager) - if raw != nil && err == nil { + raw, _ := getRawFromStorage(c, cookieToken, cfg, sessionManager, storageManager) //nolint:errcheck + if raw != nil { token = cookieToken // Token is valid, safe to set it } } @@ -152,6 +152,8 @@ func New(config ...Config) fiber.Handler { raw, err := getRawFromStorage(c, extractedToken, cfg, sessionManager, storageManager) if err != nil || raw == nil { + log.Error("Failed to retrieve CSRF token: ", err) + // If token is not in storage, expire the cookie expireCSRFCookie(c, cfg) // and return an error From 97be135d742cda1d8382d644a1a4ed0b64fa12c1 Mon Sep 17 00:00:00 2001 From: renanbastos93 Date: Thu, 30 May 2024 18:37:22 -0300 Subject: [PATCH 5/7] chore: added nocheck in specific case and added log in getRaw when not found session --- middleware/csrf/csrf.go | 2 +- middleware/csrf/session_manager.go | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/middleware/csrf/csrf.go b/middleware/csrf/csrf.go index 39f410cbdb..aee06f2677 100644 --- a/middleware/csrf/csrf.go +++ b/middleware/csrf/csrf.go @@ -107,7 +107,7 @@ func New(config ...Config) fiber.Handler { cookieToken := c.Cookies(cfg.CookieName) if cookieToken != "" { // In this case, handling error doesn't make sense because we have validations after the switch. - raw, _ := getRawFromStorage(c, cookieToken, cfg, sessionManager, storageManager) //nolint:errcheck + raw, _ := getRawFromStorage(c, cookieToken, cfg, sessionManager, storageManager) //nolint:errcheck //the details are in the comment above if raw != nil { token = cookieToken // Token is valid, safe to set it } diff --git a/middleware/csrf/session_manager.go b/middleware/csrf/session_manager.go index 3682219f22..67530601d4 100644 --- a/middleware/csrf/session_manager.go +++ b/middleware/csrf/session_manager.go @@ -29,7 +29,8 @@ func newSessionManager(s *session.Store, k string) *sessionManager { func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) ([]byte, error) { sess, err := m.session.Get(c) if err != nil { - return nil, err + log.Warn("csrf: failed to get session: ", err) + return nil, ErrTokenNotFound } token, ok := sess.Get(m.key).(Token) From 98bbb40398c5b43bd06793b58da5b4ec8f50bf9f Mon Sep 17 00:00:00 2001 From: renanbastos93 Date: Mon, 8 Jul 2024 14:29:34 -0300 Subject: [PATCH 6/7] wip --- middleware/csrf/csrf.go | 33 ++++++++------- middleware/csrf/csrf_test.go | 66 +++++++++++++++++++++++++++++- middleware/csrf/session_manager.go | 7 +++- middleware/session/store.go | 26 +++++++++--- 4 files changed, 109 insertions(+), 23 deletions(-) diff --git a/middleware/csrf/csrf.go b/middleware/csrf/csrf.go index aee06f2677..0609b5d744 100644 --- a/middleware/csrf/csrf.go +++ b/middleware/csrf/csrf.go @@ -8,18 +8,17 @@ import ( "time" "github.com/gofiber/fiber/v3" - "github.com/gofiber/fiber/v3/log" ) var ( - ErrTokenNotFound = errors.New("csrf token not found") - ErrTokenInvalid = errors.New("csrf token invalid") - ErrRefererNotFound = errors.New("referer not supplied") - ErrRefererInvalid = errors.New("referer invalid") - ErrRefererNoMatch = errors.New("referer does not match host and is not a trusted origin") - ErrOriginInvalid = errors.New("origin invalid") - ErrOriginNoMatch = errors.New("origin does not match host and is not a trusted origin") - ErrNotGetStorage = errors.New("unable to retrieve data from CSRF storage") + ErrTokenNotFound = errors.New("csrf token not found") + ErrTokenInvalid = errors.New("csrf token invalid") + ErrRefererNotFound = errors.New("referer not supplied") + ErrRefererInvalid = errors.New("referer invalid") + ErrRefererNoMatch = errors.New("referer does not match host and is not a trusted origin") + ErrOriginInvalid = errors.New("origin invalid") + ErrOriginNoMatch = errors.New("origin does not match host and is not a trusted origin") + ErrStorageRetrievalFailed = errors.New("unable to retrieve data from CSRF storage") errOriginNotFound = errors.New("origin not supplied or is null") // internal error, will not be returned to the user dummyValue = []byte{'+'} @@ -106,8 +105,11 @@ func New(config ...Config) fiber.Handler { case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace: cookieToken := c.Cookies(cfg.CookieName) if cookieToken != "" { - // In this case, handling error doesn't make sense because we have validations after the switch. - raw, _ := getRawFromStorage(c, cookieToken, cfg, sessionManager, storageManager) //nolint:errcheck //the details are in the comment above + raw, err := getRawFromStorage(c, cookieToken, cfg, sessionManager, storageManager) + if err != nil { + println("hereee+" + err.Error()) + return cfg.ErrorHandler(c, err) + } if raw != nil { token = cookieToken // Token is valid, safe to set it } @@ -151,14 +153,17 @@ func New(config ...Config) fiber.Handler { } raw, err := getRawFromStorage(c, extractedToken, cfg, sessionManager, storageManager) - if err != nil || raw == nil { - log.Error("Failed to retrieve CSRF token: ", err) + if err != nil { + + return cfg.ErrorHandler(c, err) + } else if raw == nil { // If token is not in storage, expire the cookie expireCSRFCookie(c, cfg) // and return an error - return cfg.ErrorHandler(c, ErrTokenNotFound) + return cfg.ErrorHandler(c, ErrTokenInvalid) } + if cfg.SingleUseToken { // If token is single use, delete it from storage deleteTokenFromStorage(c, extractedToken, cfg, sessionManager, storageManager) diff --git a/middleware/csrf/csrf_test.go b/middleware/csrf/csrf_test.go index e6c2ce8a58..5c72aec9bd 100644 --- a/middleware/csrf/csrf_test.go +++ b/middleware/csrf/csrf_test.go @@ -1,6 +1,7 @@ package csrf import ( + "fmt" "net/http/httptest" "strings" "testing" @@ -1263,7 +1264,6 @@ func Test_CSRF_Cookie_Injection_Exploit(t *testing.T) { ctx.Request.SetRequestURI("/") h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) - token = strings.Split(strings.Split(token, ";")[0], "=")[1] // Exploit CSRF token we just injected ctx.Request.Reset() @@ -1509,3 +1509,67 @@ func Test_CSRF_FromContextMethods_Invalid(t *testing.T) { require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) } + +type mockStorage struct{} + +func (m *mockStorage) Get(key string) ([]byte, error) { + return nil, fmt.Errorf("not found") +} + +func (m *mockStorage) Set(key string, val []byte, exp time.Duration) error { + return nil +} + +func (m *mockStorage) Delete(key string) error { + return nil +} + +func (m *mockStorage) Reset() error { + return nil +} + +func (m *mockStorage) Close() error { + return nil +} + +func Test_NotGetTokenInSessionStorage(t *testing.T) { + t.Parallel() + + errHandler := func(c fiber.Ctx, err error) error { + require.Equal(t, ErrNotGetStorage.Error(), err.Error()) + return c.Status(419).Send([]byte(err.Error())) + } + + // &session.Store{}.Storage.Set(ConfigDefault.CookieName, "fiber", 300) + + app := fiber.New() + app.Use(New(Config{ + ErrorHandler: errHandler, + Session: &session.Store{ + Config: session.Config{ + Storage: &mockStorage{}, + KeyGenerator: ConfigDefault.KeyGenerator, + KeyLookup: ConfigDefault.KeyLookup, + Expiration: ConfigDefault.Expiration, + CookieSameSite: "Lax", + }, + }, + })) + + app.Post("/", func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.SetCookie(ConfigDefault.CookieName, "fiber") + h(ctx) + + require.Equal(t, 419, ctx.Response.StatusCode()) + require.Equal(t, "invalid CSRF token", string(ctx.Response.Body())) + +} diff --git a/middleware/csrf/session_manager.go b/middleware/csrf/session_manager.go index 67530601d4..dd36ed545c 100644 --- a/middleware/csrf/session_manager.go +++ b/middleware/csrf/session_manager.go @@ -1,6 +1,7 @@ package csrf import ( + "fmt" "time" "github.com/gofiber/fiber/v3" @@ -29,11 +30,13 @@ func newSessionManager(s *session.Store, k string) *sessionManager { func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) ([]byte, error) { sess, err := m.session.Get(c) if err != nil { - log.Warn("csrf: failed to get session: ", err) - return nil, ErrTokenNotFound + return nil, ErrNotGetStorage } + fmt.Println("key: ", sess) + token, ok := sess.Get(m.key).(Token) + fmt.Println("key: ", token, ok) if !ok { return nil, ErrTokenInvalid } diff --git a/middleware/session/store.go b/middleware/session/store.go index 05fba8e233..af72f471c2 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -86,12 +86,26 @@ func (s *Store) Get(c fiber.Ctx) (*Session, error) { sess.id = id sess.fresh = fresh - // Decode session data if found - if rawData != nil { - sess.data.Lock() - defer sess.data.Unlock() - if err := sess.decodeSessionData(rawData); err != nil { - return nil, fmt.Errorf("failed to decode session data: %w", err) + // Fetch existing data + if loadData { + raw, err := s.Storage.Get(id) + // Unmarshal if we found data + switch { + case err != nil: + return nil, err + + case raw != nil: + mux.Lock() + defer mux.Unlock() + sess.byteBuffer.Write(raw) + encCache := gob.NewDecoder(sess.byteBuffer) + err := encCache.Decode(&sess.data.Data) + if err != nil { + return nil, fmt.Errorf("failed to decode session data: %w", err) + } + default: + // both raw and err is nil, which means id is not in the storage + sess.fresh = true } } From 41266f37f2c6af2ef179a960fb080b299f16f959 Mon Sep 17 00:00:00 2001 From: renanbastos93 Date: Mon, 8 Jul 2024 14:38:47 -0300 Subject: [PATCH 7/7] chore: resolve conflicts --- middleware/csrf/csrf_test.go | 2 +- middleware/csrf/session_manager.go | 2 +- middleware/csrf/storage_manager.go | 4 ++-- middleware/session/store.go | 26 ++++++-------------------- 4 files changed, 10 insertions(+), 24 deletions(-) diff --git a/middleware/csrf/csrf_test.go b/middleware/csrf/csrf_test.go index 5c72aec9bd..2dafec97d5 100644 --- a/middleware/csrf/csrf_test.go +++ b/middleware/csrf/csrf_test.go @@ -1536,7 +1536,7 @@ func Test_NotGetTokenInSessionStorage(t *testing.T) { t.Parallel() errHandler := func(c fiber.Ctx, err error) error { - require.Equal(t, ErrNotGetStorage.Error(), err.Error()) + require.Equal(t, ErrStorageRetrievalFailed.Error(), err.Error()) return c.Status(419).Send([]byte(err.Error())) } diff --git a/middleware/csrf/session_manager.go b/middleware/csrf/session_manager.go index dd36ed545c..5d56121eaf 100644 --- a/middleware/csrf/session_manager.go +++ b/middleware/csrf/session_manager.go @@ -30,7 +30,7 @@ func newSessionManager(s *session.Store, k string) *sessionManager { func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) ([]byte, error) { sess, err := m.session.Get(c) if err != nil { - return nil, ErrNotGetStorage + return nil, ErrStorageRetrievalFailed } fmt.Println("key: ", sess) diff --git a/middleware/csrf/storage_manager.go b/middleware/csrf/storage_manager.go index 743b01b8aa..c3b1012f00 100644 --- a/middleware/csrf/storage_manager.go +++ b/middleware/csrf/storage_manager.go @@ -50,13 +50,13 @@ func (m *storageManager) getRaw(key string) ([]byte, error) { if m.storage != nil { raw, err = m.storage.Get(key) if err != nil { - return nil, fmt.Errorf("%w: %s", ErrNotGetStorage, err.Error()) + return nil, fmt.Errorf("%w: %s", ErrStorageRetrievalFailed, err.Error()) } } else { var ok bool raw, ok = m.memory.Get(key).([]byte) if !ok { - return nil, ErrNotGetStorage + return nil, ErrStorageRetrievalFailed } } diff --git a/middleware/session/store.go b/middleware/session/store.go index af72f471c2..05fba8e233 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -86,26 +86,12 @@ func (s *Store) Get(c fiber.Ctx) (*Session, error) { sess.id = id sess.fresh = fresh - // Fetch existing data - if loadData { - raw, err := s.Storage.Get(id) - // Unmarshal if we found data - switch { - case err != nil: - return nil, err - - case raw != nil: - mux.Lock() - defer mux.Unlock() - sess.byteBuffer.Write(raw) - encCache := gob.NewDecoder(sess.byteBuffer) - err := encCache.Decode(&sess.data.Data) - if err != nil { - return nil, fmt.Errorf("failed to decode session data: %w", err) - } - default: - // both raw and err is nil, which means id is not in the storage - sess.fresh = true + // Decode session data if found + if rawData != nil { + sess.data.Lock() + defer sess.data.Unlock() + if err := sess.decodeSessionData(rawData); err != nil { + return nil, fmt.Errorf("failed to decode session data: %w", err) } }