Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix goroutine leaking #103

Merged
merged 12 commits into from
Feb 15, 2024
1 change: 1 addition & 0 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ on:
pull_request:
branches:
- main
- v1
name: Test
env:
GO_TARGET_VERSION: 1.21
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ require (
github.com/pashagolub/pgxmock/v2 v2.12.0
github.com/stretchr/testify v1.8.2
go.mongodb.org/mongo-driver v1.12.2
go.uber.org/goleak v1.3.0 // indirect
go.uber.org/multierr v1.9.0
gorm.io/driver/mysql v1.5.2
gorm.io/driver/sqlite v1.5.1
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4=
go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU=
Expand Down
2 changes: 2 additions & 0 deletions gorm/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Implementation rewrites [DisableNestedTransaction](https://gorm.io/docs/gorm_config.html#DisableNestedTransaction) by [Settings.Propogation](../trm/settings.go) if it is [PropagationNested](../trm/transaction.go).

5 changes: 5 additions & 0 deletions gorm/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ func Example() {
db, err := gorm.Open(sqlite.Open("file:test.db?mode=memory"))
checkErr(err)

sqlDB, err := db.DB()
checkErr(err)

defer sqlDB.Close()

// Migrate the schema
checkErr(db.AutoMigrate(&userRow{}))

Expand Down
5 changes: 4 additions & 1 deletion gorm/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@ import (
"github.com/avito-tech/go-transaction-manager/trm"
)

// NewDefaultFactory creates default trm.Transaction(sqlx.Tx).
// NewDefaultFactory creates default trm.Transaction(gorm.DB).
// Factory rewrites DisableNestedTransaction in gorm.Config with Propagation in trm.Settings.
func NewDefaultFactory(db *gorm.DB) trm.TrFactory {
return func(ctx context.Context, trms trm.Settings) (context.Context, trm.Transaction, error) {
s, _ := trms.(Settings)

db.Config.DisableNestedTransaction = trms.Propagation() != trm.PropagationNested

return NewTransaction(ctx, s.TxOpts(), db)
}
}
14 changes: 14 additions & 0 deletions gorm/goroutine_leak_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//go:build go1.20
// +build go1.20

package gorm

import (
"testing"

"go.uber.org/goleak"
)

func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
2 changes: 1 addition & 1 deletion gorm/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func NewSettings(trms trm.Settings, oo ...Opt) (Settings, error) {
return *s, nil
}

// MustSettings returns Settings if err is nil and panics otherwise.
// MustSettings returns Settings if stopByErr is nil and panics otherwise.
func MustSettings(trms trm.Settings, oo ...Opt) Settings {
s, err := NewSettings(trms, oo...)
if err != nil {
Expand Down
98 changes: 69 additions & 29 deletions gorm/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,19 @@ import (
"database/sql"
"errors"
"sync"
"sync/atomic"

"gorm.io/gorm"

"github.com/avito-tech/go-transaction-manager/trm"
"github.com/avito-tech/go-transaction-manager/trm/drivers"
)

var errRollbackTx = errors.New("rollback transaction")

// Transaction is trm.Transaction for sqlx.Tx.
type Transaction struct {
tx *gorm.DB
err chan error
isActive int64
tx *gorm.DB
txMutex sync.Mutex
isClosed *drivers.IsClosed
isClosedClosure *drivers.IsClosed
}

// NewTransaction creates trm.Transaction for sqlx.Tx.
Expand All @@ -31,7 +30,12 @@ func NewTransaction(
opts *sql.TxOptions,
db *gorm.DB,
) (context.Context, *Transaction, error) {
tr := &Transaction{isActive: 1, err: make(chan error), tx: nil}
t := &Transaction{
tx: nil,
txMutex: sync.Mutex{},
isClosed: drivers.NewIsClosed(),
isClosedClosure: drivers.NewIsClosed(),
}

var err error

Expand All @@ -40,16 +44,30 @@ func NewTransaction(

go func() {
db = db.WithContext(ctx)
// Used closure to avoid implementing nested transactions.
err = db.Transaction(func(tx *gorm.DB) error {
tr.tx = tx
t.tx = tx

wg.Done()

return <-tr.err
<-t.isClosedClosure.Closed()

return t.isClosedClosure.Err()
}, opts)

if tr.tx != nil {
tr.err <- err
t.txMutex.Lock()
defer t.txMutex.Unlock()
tx := t.tx

if tx != nil {
// Return error from transaction rollback
// Error from commit returns from db.Transaction closure
if errors.Is(err, drivers.ErrRollbackTr) &&
tx.Error != nil {
err = t.tx.Error
}

t.isClosed.CloseWithCause(err)
} else {
wg.Done()
}
Expand All @@ -61,19 +79,22 @@ func NewTransaction(
return ctx, nil, err
}

go tr.awaitDone(ctx)
go t.awaitDone(ctx)

return ctx, tr, nil
return ctx, t, nil
}

func (t *Transaction) awaitDone(ctx context.Context) {
if ctx.Done() == nil {
return
}

<-ctx.Done()

t.deactivate()
select {
case <-ctx.Done():
// Rollback will be called by context.Err()
t.isClosedClosure.Close()
case <-t.isClosed.Closed():
}
}

// Transaction returns the real transaction sqlx.Tx.
Expand All @@ -84,38 +105,57 @@ func (t *Transaction) Transaction() interface{} {

// Begin nested transaction by save point.
func (t *Transaction) Begin(ctx context.Context, s trm.Settings) (context.Context, trm.Transaction, error) {
t.txMutex.Lock()
defer t.txMutex.Unlock()

return NewDefaultFactory(t.tx)(ctx, s)
}

// Commit closes the trm.Transaction.
func (t *Transaction) Commit(_ context.Context) error {
defer t.deactivate()
select {
case <-t.isClosed.Closed():
t.txMutex.Lock()
defer t.txMutex.Unlock()

return t.tx.Commit().Error
default:
t.isClosedClosure.Close()

t.err <- nil
<-t.isClosed.Closed()

return <-t.err
return t.isClosed.Err()
}
}

// Rollback the trm.Transaction.
func (t *Transaction) Rollback(_ context.Context) error {
defer t.deactivate()
select {
case <-t.isClosed.Closed():
t.txMutex.Lock()
defer t.txMutex.Unlock()

t.err <- errRollbackTx
return t.tx.Rollback().Error
default:
t.isClosedClosure.CloseWithCause(drivers.ErrRollbackTr)

err := <-t.err
<-t.isClosed.Closed()

if errors.Is(err, errRollbackTx) {
return nil
}
err := t.isClosed.Err()
if errors.Is(err, drivers.ErrRollbackTr) {
return nil
}

return err
return err
}
}

// IsActive returns true if the transaction started but not committed or rolled back.
func (t *Transaction) IsActive() bool {
return atomic.LoadInt64(&t.isActive) == 1
return t.isClosed.IsActive()
}

func (t *Transaction) deactivate() {
atomic.SwapInt64(&t.isActive, 0)
// Closed returns a channel that's closed when transaction committed or rolled back.
func (t *Transaction) Closed() <-chan struct{} {
return t.isClosed.Closed()
}
Loading
Loading