Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions neo4j/directrouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ type directRouter struct {
address string
}

func (r *directRouter) IsMultiServer() bool {
return false
}

func (r *directRouter) InvalidateWriter(string, string) {}

func (r *directRouter) InvalidateReader(string, string) {}
Expand Down
1 change: 1 addition & 0 deletions neo4j/driver_with_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ type sessionRouter interface {
InvalidateWriter(db string, server string)
InvalidateReader(db string, server string)
InvalidateServer(server string)
IsMultiServer() bool
}

type driverWithContext struct {
Expand Down
4 changes: 4 additions & 0 deletions neo4j/driver_with_context_testkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ func ForceRoutingTableUpdate(d DriverWithContext, database string, bookmarks []s
return errorutil.WrapError(err)
}

func RegisterDnsResolver(d DriverWithContext, hook func(address string) []string) {
d.(*driverWithContext).connector.TestKitResolver = hook
}

func GetRoutingTable(d DriverWithContext, database string) (*RoutingTable, error) {
driver := d.(*driverWithContext)
router, ok := driver.router.(*router.Router)
Expand Down
9 changes: 9 additions & 0 deletions neo4j/internal/bolt/bolt3.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,15 @@ func (b *bolt3) ServerName() string {
return b.serverName
}

func (b *bolt3) AdvertisedServerName() string {
// Advertised address not supported by this protocol version
return ""
}

func (b *bolt3) SetServerName(serverName string) {
b.serverName = serverName
}

func (b *bolt3) ServerVersion() string {
return b.serverVersion
}
Expand Down
11 changes: 10 additions & 1 deletion neo4j/internal/bolt/bolt4.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,15 @@ func (b *bolt4) ServerName() string {
return b.serverName
}

func (b *bolt4) AdvertisedServerName() string {
// Advertised address not supported by this protocol version
return ""
}

func (b *bolt4) SetServerName(serverName string) {
b.serverName = serverName
}

func (b *bolt4) ServerVersion() string {
return b.serverVersion
}
Expand Down Expand Up @@ -985,7 +994,7 @@ func (b *bolt4) GetCurrentAuth() (auth.TokenManager, iauth.Token) {
}

func (b *bolt4) Telemetry(telemetry.API, func()) {
// TELEMETRY not support by this protocol version, so we ignore it.
// TELEMETRY not supported by this protocol version, so we ignore it.
}

func (b *bolt4) helloResponseHandler(checkUtcPatch bool) responseHandler {
Expand Down
60 changes: 38 additions & 22 deletions neo4j/internal/bolt/bolt5.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,28 +93,29 @@ func (i *internalTx5) toMeta(logger log.Logger, logId string, version db.Protoco
}

type bolt5 struct {
state int
txId idb.TxHandle
streams openstreams
conn io.ReadWriteCloser
serverName string
queue messageQueue
connId string
logId string
serverVersion string
bookmark string // Last bookmark
birthDate time.Time
log log.Logger
databaseName string
err error // Last fatal error
minor int
lastQid int64 // Last seen qid
idleDate time.Time
auth map[string]any
authManager auth.TokenManager
resetAuth bool
errorListener ConnectionErrorListener
telemetryEnabled bool
state int
txId idb.TxHandle
streams openstreams
conn io.ReadWriteCloser
serverName string // Initial server name
advertisedServerName string // Preferred server name
queue messageQueue
connId string
logId string
serverVersion string
bookmark string // Last bookmark
birthDate time.Time
log log.Logger
databaseName string
err error // Last fatal error
minor int
lastQid int64 // Last seen qid
idleDate time.Time
auth map[string]any
authManager auth.TokenManager
resetAuth bool
errorListener ConnectionErrorListener
telemetryEnabled bool
}

func NewBolt5(
Expand Down Expand Up @@ -178,6 +179,14 @@ func (b *bolt5) ServerName() string {
return b.serverName
}

func (b *bolt5) AdvertisedServerName() string {
return b.advertisedServerName
}

func (b *bolt5) SetServerName(serverName string) {
b.serverName = serverName
}

func (b *bolt5) ServerVersion() string {
return b.serverVersion
}
Expand Down Expand Up @@ -989,6 +998,9 @@ func (b *bolt5) logoffResponseHandler() responseHandler {
}

func (b *bolt5) logonResponseHandler() responseHandler {
if b.Version().Major >= 5 && b.Version().Minor >= 8 {
return b.expectedSuccessHandler(b.onLogonSuccess)
}
return b.expectedSuccessHandler(onSuccessNoOp)
}

Expand Down Expand Up @@ -1127,6 +1139,10 @@ func (b *bolt5) onHelloSuccess(helloSuccess *success) {
b.initializeTelemetryHint(helloSuccess.configurationHints)
}

func (b *bolt5) onLogonSuccess(logonSuccess *success) {
b.advertisedServerName = logonSuccess.advertisedAddress
}

func (b *bolt5) onCommitSuccess(commitSuccess *success) {
if len(commitSuccess.bookmark) > 0 {
b.bookmark = commitSuccess.bookmark
Expand Down
2 changes: 1 addition & 1 deletion neo4j/internal/bolt/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type protocolVersion struct {

// Supported versions in priority order
var versions = [4]protocolVersion{
{major: 5, minor: 7, back: 7},
{major: 5, minor: 8, back: 8},
{major: 4, minor: 4, back: 2},
{major: 4, minor: 1},
{major: 3, minor: 0},
Expand Down
3 changes: 3 additions & 0 deletions neo4j/internal/bolt/hydrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type success struct {
num uint32
configurationHints map[string]any
patches []string
advertisedAddress string
}

func (s *success) String() string {
Expand Down Expand Up @@ -302,6 +303,8 @@ func (h *hydrator) success(n uint32) *success {
case "patch_bolt":
patches := h.strings()
succ.patches = patches
case "advertised_address":
succ.advertisedAddress = h.unp.String()
default:
// Unknown key, waste it
h.trash()
Expand Down
22 changes: 20 additions & 2 deletions neo4j/internal/connector/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ type Connector struct {
Network string
Config *config.Config
SupplyConnection func(context.Context, string) (net.Conn, error)
TestKitResolver func(string) []string
}

func (c Connector) Connect(
func (c *Connector) Connect(
ctx context.Context,
address string,
auth *db.ReAuthToken,
Expand Down Expand Up @@ -138,7 +139,24 @@ func (c Connector) createConnection(ctx context.Context, address string) (net.Co
dialer.KeepAlive = -1 * time.Second // Turns keep-alive off
}

return dialer.DialContext(ctx, c.Network, address)
if c.TestKitResolver == nil {
return dialer.DialContext(ctx, c.Network, address)
}

addresses := c.TestKitResolver(address)

if len(addresses) == 0 {
return nil, errors.New("TestKit DNS resolver returned no address")
}

var err error = nil
for _, address := range addresses {
con, err := dialer.DialContext(ctx, c.Network, address)
if err == nil {
return con, nil
}
}
return nil, err
}

func (c Connector) tlsConfig(serverName string) *tls.Config {
Expand Down
4 changes: 4 additions & 0 deletions neo4j/internal/db/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ type Connection interface {
Bookmark() string
// ServerName returns the name of the remote server
ServerName() string
// AdvertisedServerName returns the advertised name of the remote server.
AdvertisedServerName() string
// SetServerName updates the server name to given value.
SetServerName(serverName string)
// ServerVersion returns the server version on pattern Neo4j/1.2.3
ServerVersion() string
// IsAlive returns true if the connection is fully functional.
Expand Down
42 changes: 31 additions & 11 deletions neo4j/internal/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ type poolRouter interface {
InvalidateWriter(db string, server string)
InvalidateReader(db string, server string)
InvalidateServer(server string)
IsMultiServer() bool
}

type qitem struct {
Expand Down Expand Up @@ -303,7 +304,7 @@ func (p *Pool) tryBorrow(
if healthy {
return connection, nil
}
p.unreg(ctx, serverName, connection, itime.Now())
p.unreg(ctx, serverName, connection, itime.Now(), true)
if err != nil {
p.log.Debugf(log.Pool, p.logId, "Health check failed for %s: %s", serverName, err)
return nil, err
Expand Down Expand Up @@ -343,16 +344,18 @@ func (p *Pool) tryBorrow(
return c, nil
}

func (p *Pool) unreg(ctx context.Context, serverName string, c idb.Connection, now time.Time) {
func (p *Pool) unreg(ctx context.Context, serverName string, c idb.Connection, now time.Time, close bool) {
p.serversMut.Lock()
defer p.serversMut.Unlock()
p.unregLocked(ctx, serverName, c, now)
p.unregLocked(ctx, serverName, c, now, close)
}

func (p *Pool) unregLocked(ctx context.Context, serverName string, c idb.Connection, now time.Time) {
func (p *Pool) unregLocked(ctx context.Context, serverName string, c idb.Connection, now time.Time, close bool) {
defer func() {
// Close connection in another thread to avoid potential long blocking operation during close.
go c.Close(ctx)
if close {
go c.Close(ctx)
}
}()

server := p.servers[serverName]
Expand Down Expand Up @@ -384,16 +387,33 @@ func (p *Pool) Return(ctx context.Context, c idb.Connection) {
return
}

// Get the name of the server that the connection belongs to.
serverName := c.ServerName()
isAlive := c.IsAlive()
p.log.Debugf(log.Pool, p.logId, "Returning connection to %s {alive:%t}", serverName, isAlive)

// If the connection is dead, remove all other idle connections on the same server that older
// or of the same age as the dead connection, otherwise perform normal cleanup of old connections
maxAge := p.config.MaxConnectionLifetime
now := itime.Now()
age := now.Sub(c.Birthdate())

// Check if we have an advertised server name and if so replace connection from initial server.
// Only do this when routing is enabled.
if p.router.IsMultiServer() && c.AdvertisedServerName() != "" && c.ServerName() != c.AdvertisedServerName() {
// Remove connection from busy list of initial server.
p.unreg(ctx, c.ServerName(), c, now, false)
p.log.Debugf(log.Pool, p.logId, "Transferring connection from %s to advertised server %s", c.ServerName(), c.AdvertisedServerName())
// Update connection server name to that of the advertised address.
c.SetServerName(c.AdvertisedServerName())
// Create a fresh server.
p.serversMut.Lock()
if _, ok := p.servers[c.ServerName()]; !ok {
p.servers[c.ServerName()] = NewServer()
}
p.serversMut.Unlock()
}

// Get the name of the server that the connection belongs to
serverName := c.ServerName()
isAlive := c.IsAlive()
p.log.Debugf(log.Pool, p.logId, "Returning connection to %s {alive:%t}", serverName, isAlive)

if !isAlive {
// Since this connection has died all other connections that connected before this one
// might also be bad, remove the idle ones.
Expand All @@ -418,7 +438,7 @@ func (p *Pool) Return(ctx context.Context, c idb.Connection) {
// Fix for race condition where expired connections could be reused or closed concurrently.
// See: https://github.com/neo4j/neo4j-go-driver/issues/574
isAlive = false
p.unreg(ctx, serverName, c, now)
p.unreg(ctx, serverName, c, now, true)
p.log.Infof(log.Pool, p.logId, "Unregistering dead or too old connection to %s", serverName)
}

Expand Down
Loading