diff --git a/middleware/session/session.go b/middleware/session/session.go index ae5dbe3c43..1c5b22f67a 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -40,6 +40,8 @@ var sessionPool = sync.Pool{ // s := acquireSession() func acquireSession() *Session { s := sessionPool.Get().(*Session) //nolint:forcetypeassert,errcheck // We store nothing else in the pool + s.mu.Lock() + defer s.mu.Unlock() if s.data == nil { s.data = acquireData() } @@ -76,6 +78,7 @@ func (s *Session) Release() { func releaseSession(s *Session) { s.mu.Lock() + defer s.mu.Unlock() s.id = "" s.idleTimeout = 0 s.ctx = nil @@ -86,7 +89,6 @@ func releaseSession(s *Session) { if s.byteBuffer != nil { s.byteBuffer.Reset() } - s.mu.Unlock() sessionPool.Put(s) } @@ -295,6 +297,7 @@ func (s *Session) saveSession() error { } s.mu.Lock() + defer s.mu.Unlock() // Check if session has your own expiration, otherwise use default value if s.idleTimeout <= 0 { @@ -316,13 +319,7 @@ func (s *Session) saveSession() error { copy(encodedBytes, s.byteBuffer.Bytes()) // Pass copied bytes with session id to provider - if err := s.config.Storage.Set(s.id, encodedBytes, s.idleTimeout); err != nil { - return err - } - - s.mu.Unlock() - - return nil + return s.config.Storage.Set(s.id, encodedBytes, s.idleTimeout) } // Keys retrieves all keys in the current session.