Skip to content

Commit

Permalink
A bunch of improvements (#479)
Browse files Browse the repository at this point in the history
These changes are hard to separate because they are all entangled.

1) We know that `TestReadOwnWritesGuarantee` occasionally fails (#478),
but the output of that test is interleaved with the output of the previous
test.
The interleaving is caused by an unclean shutdown: as resources are cleaned
up after the previous test, their output gets printed all over the place.
Tests don't need a soft (graceful) shutdown, so this PR introduces a hard
shutdown, which also speeds up the tests.

2) The sync worker was leaking connections. They are now closed when stream setup fails and after each stream ends.

3) `Recv()` blocks and does not observe `ctx.Done()`, so a separate
mechanism has to be introduced to handle shutdowns cleanly. This is the
same pattern that we use in `nodeCursorTracker`.

4) Create test DBs with the name of the test in them, which makes SQL
debugging easier.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Configurable timeout settings now enable a graceful shutdown process
to ensure active operations complete reliably.
- **Refactor**
- Enhanced shutdown logic with timeout parameters for improved control
over server shutdown processes.
- Optimized background processing with non-blocking error handling for
improved responsiveness.
- **Tests**
- Enhanced testing utilities with refined resource cleanup and dynamic
naming to bolster test robustness.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
mkysel authored Feb 12, 2025
1 parent b789de6 commit 35c421e
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 43 deletions.
3 changes: 2 additions & 1 deletion cmd/replication/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/Masterminds/semver/v3"
"log"
"sync"
"time"

"github.com/jessevdk/go-flags"
"github.com/xmtp/xmtpd/pkg/blockchain"
Expand Down Expand Up @@ -137,7 +138,7 @@ func main() {
log.Fatal("initializing server", zap.Error(err))
}

s.WaitForShutdown()
s.WaitForShutdown(10 * time.Second)
doneC <- true
})
<-doneC
Expand Down
8 changes: 6 additions & 2 deletions pkg/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,14 @@ func (s *ApiServer) gracefulShutdown(timeout time.Duration) {
<-ctx.Done()
}

func (s *ApiServer) Close() {
func (s *ApiServer) Close(timeout time.Duration) {
s.log.Debug("closing")
if s.grpcServer != nil {
s.gracefulShutdown(10 * time.Second)
if timeout != 0 {
s.gracefulShutdown(timeout)
} else {
s.grpcServer.Stop()
}
}
if s.grpcListener != nil {
if err := s.grpcListener.Close(); err != nil && !isErrUseOfClosedConnection(err) {
Expand Down
9 changes: 5 additions & 4 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"os/signal"
"syscall"
"time"

"github.com/Masterminds/semver/v3"
"github.com/prometheus/client_golang/prometheus"
Expand Down Expand Up @@ -260,14 +261,14 @@ func (s *ReplicationServer) Addr() net.Addr {
return s.apiServer.Addr()
}

func (s *ReplicationServer) WaitForShutdown() {
func (s *ReplicationServer) WaitForShutdown(timeout time.Duration) {
termChannel := make(chan os.Signal, 1)
signal.Notify(termChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
<-termChannel
s.Shutdown()
s.Shutdown(timeout)
}

func (s *ReplicationServer) Shutdown() {
func (s *ReplicationServer) Shutdown(timeout time.Duration) {
if s.metrics != nil {
s.metrics.Close()
}
Expand All @@ -284,7 +285,7 @@ func (s *ReplicationServer) Shutdown() {
s.indx.Close()
}
if s.apiServer != nil {
s.apiServer.Close()
s.apiServer.Close(timeout)
}

s.cancel()
Expand Down
46 changes: 27 additions & 19 deletions pkg/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ import (
"testing"
"time"

"github.com/stretchr/testify/mock"

"github.com/ethereum/go-ethereum/crypto"
"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/config"
Expand Down Expand Up @@ -118,28 +116,35 @@ func TestCreateServer(t *testing.T) {
registry.On("GetNodes").Return(nodes, nil)

nodesChan := make(chan []r.Node)
cancelOnNewFunc := func() {
close(nodesChan)
}
registry.On("OnNewNodes").
Return((<-chan []r.Node)(nodesChan), r.CancelSubscription(cancelOnNewFunc))

nodeChan := make(chan r.Node)

cancelOnChangedFunc := func() {
close(nodeChan)
}

registry.On("OnChangedNode", mock.AnythingOfType("uint32")).
Return((<-chan r.Node)(nodeChan), r.CancelSubscription(cancelOnChangedFunc))
Return((<-chan []r.Node)(nodesChan), r.CancelSubscription(func() {}))

nodeChan1 := make(chan r.Node)
nodeChan2 := make(chan r.Node)
registry.On("OnChangedNode", server1NodeID).
Return((<-chan r.Node)(nodeChan1), r.CancelSubscription(func() {
close(nodeChan1)
}))
registry.On("OnChangedNode", server2NodeID).
Return((<-chan r.Node)(nodeChan2), r.CancelSubscription(func() {
close(nodeChan2)
}))

registry.On("GetNode", server1NodeID).Return(&nodes[0], nil)
registry.On("GetNode", server2NodeID).Return(&nodes[1], nil)

registry.On("Stop").Return(nil)

server1 := NewTestServer(t, server1Port, dbs[0], registry, privateKey1)
server2 := NewTestServer(t, server2Port, dbs[1], registry, privateKey2)

require.NotEqual(t, server1.Addr(), server2.Addr())

defer func() {
server1.Shutdown(0)
server2.Shutdown(0)
}()

client1, cleanup1 := apiTestUtils.NewReplicationAPIClient(t, ctx, server1.Addr().String())
defer cleanup1()
client2, cleanup2 := apiTestUtils.NewReplicationAPIClient(t, ctx, server2.Addr().String())
Expand Down Expand Up @@ -238,13 +243,16 @@ func TestReadOwnWritesGuarantee(t *testing.T) {
registry.On("GetNodes").Return(nodes, nil)

nodesChan := make(chan []r.Node)
cancelOnNewFunc := func() {
close(nodesChan)
}
registry.On("OnNewNodes").
Return((<-chan []r.Node)(nodesChan), r.CancelSubscription(cancelOnNewFunc))
Return((<-chan []r.Node)(nodesChan), r.CancelSubscription(func() {
}))

registry.On("Stop").Return(nil)

server1 := NewTestServer(t, server1Port, dbs[0], registry, privateKey1)
defer func() {
server1.Shutdown(0)
}()

client1, cleanup1 := apiTestUtils.NewReplicationAPIClient(t, ctx, server1.Addr().String())
defer cleanup1()
Expand Down
54 changes: 39 additions & 15 deletions pkg/sync/syncWorker.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,6 @@ func (s *syncWorker) subscribeToNodeRegistration(
return
}

var conn *grpc.ClientConn
var stream *originatorStream
err = nil

// TODO(mkysel) we should eventually implement a better backoff strategy
Expand All @@ -193,15 +191,21 @@ func (s *syncWorker) subscribeToNodeRegistration(
backoff = time.Second
}

var conn *grpc.ClientConn
conn, err = s.connectToNode(*node)
if err != nil {
continue
}

var stream *originatorStream
stream, err = s.setupStream(registration.ctx, *node, conn)
if err != nil {
_ = conn.Close()
continue
}
err = s.listenToStream(stream)
_ = stream.stream.CloseSend()
_ = conn.Close()
}
}
}
Expand Down Expand Up @@ -322,20 +326,40 @@ func (s *syncWorker) setupStream(
func (s *syncWorker) listenToStream(
originatorStream *originatorStream,
) error {
for {
// Recv() is a blocking operation that can only be interrupted by cancelling ctx
envs, err := originatorStream.stream.Recv()
if err == io.EOF {
return fmt.Errorf("stream closed with EOF")
}
if err != nil {
return fmt.Errorf(
"stream closed with error: %v",
err)
recvChan := make(chan *message_api.SubscribeEnvelopesResponse)
errChan := make(chan error)

go func() {
for {
envs, err := originatorStream.stream.Recv()
if err != nil {
errChan <- err
return
}
recvChan <- envs
}
s.log.Debug("Received envelopes", zap.Any("numEnvelopes", len(envs.Envelopes)))
for _, env := range envs.Envelopes {
s.validateAndInsertEnvelope(originatorStream, env)
}()

for {
select {
case <-s.ctx.Done():
s.log.Info("Context canceled, stopping stream listener")
return nil

case envs := <-recvChan:
s.log.Debug("Received envelopes", zap.Any("numEnvelopes", len(envs.Envelopes)))
for _, env := range envs.Envelopes {
s.validateAndInsertEnvelope(originatorStream, env)
}

case err := <-errChan:
if err == io.EOF {
s.log.Info("Stream closed with EOF")
// let the caller rebuild the stream if required
return nil
}
s.log.Error("Stream closed with error", zap.Error(err))
return err
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/testutils/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func NewTestAPIServer(t *testing.T) (*api.ApiServer, *sql.DB, ApiServerMocks, fu

return svr, db, allMocks, func() {
cancel()
svr.Close()
svr.Close(0)
dbCleanup()
}
}
Expand Down
17 changes: 16 additions & 1 deletion pkg/testutils/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package testutils
import (
"context"
"database/sql"
"path/filepath"
"runtime"
"strings"
"testing"

"github.com/jackc/pgx/v5"
Expand All @@ -17,6 +20,17 @@ const (
LocalTestDBDSNSuffix = "?sslmode=disable"
)

// getCallerName returns the lower-cased, unqualified name of the function
// found depth frames up the call stack. depth follows the runtime.Caller
// convention: 0 is getCallerName itself, 1 is its caller, and so on.
//
// It returns "unknown" whenever the frame cannot be resolved to a non-empty
// name, so callers always get a usable token to embed in a larger identifier.
func getCallerName(depth int) string {
	const fallback = "unknown"

	pc, _, _, ok := runtime.Caller(depth)
	if !ok {
		return fallback
	}

	// FuncForPC returns nil when it has no function for pc.
	fn := runtime.FuncForPC(pc)
	if fn == nil {
		return fallback
	}

	// The symbol is package-qualified (e.g. "path/to/pkg.TestFoo"), so the
	// "extension" after the last dot is the bare function name.
	name := strings.TrimPrefix(filepath.Ext(fn.Name()), ".")
	if name == "" {
		return fallback
	}
	return strings.ToLower(name)
}

func openDB(t testing.TB, dsn string) (*sql.DB, string, func()) {
config, err := pgx.ParseConfig(dsn)
require.NoError(t, err)
Expand All @@ -32,7 +46,8 @@ func newCtlDB(t testing.TB) (*sql.DB, string, func()) {
}

func newInstanceDB(t testing.TB, ctx context.Context, ctlDB *sql.DB) (*sql.DB, string, func()) {
dbName := "test_" + RandomStringLower(12)
dbName := "test_" + getCallerName(3) + "_" + RandomStringLower(12)
t.Logf("creating database %s ...", dbName)
_, err := ctlDB.Exec("CREATE DATABASE " + dbName)
require.NoError(t, err)

Expand Down

0 comments on commit 35c421e

Please sign in to comment.