Skip to content

Merge changes from tls-channel to prevent accidentally calling SSLEng… #1726

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ public TlsChannelImpl(
private final Lock readLock = new ReentrantLock();
private final Lock writeLock = new ReentrantLock();

private volatile boolean negotiated = false;
private boolean handshakeStarted = false;

private volatile boolean handshakeCompleted = false;

/**
* Whether a IOException was received from the underlying channel or from the {@link SSLEngine}.
Expand Down Expand Up @@ -526,22 +528,35 @@ public void handshake() throws IOException {
}

private void doHandshake(boolean force) throws IOException, EofException {
if (!force && negotiated) return;
if (!force && handshakeCompleted) {
return;
}
initLock.lock();
try {
if (invalid || shutdownSent) throw new ClosedChannelException();
if (force || !negotiated) {
engine.beginHandshake();
LOGGER.trace("Called engine.beginHandshake()");
if (force || !handshakeCompleted) {

if (!handshakeStarted) {
engine.beginHandshake();
LOGGER.trace("Called engine.beginHandshake()");

// Some engines that do not support renegotiations may be sensitive to calling
// SSLEngine.beginHandshake() more than once. This guard prevents that.
// See: https://github.com/marianobarrios/tls-channel/issues/197
handshakeStarted = true;
}

handshake(Optional.empty(), Optional.empty());

handshakeCompleted = true;

// call client code
try {
initSessionCallback.accept(engine.getSession());
Comment on lines +551 to 555
Copy link
Preview

Copilot AI Jun 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Setting handshakeCompleted to true immediately may lead to an inconsistent handshake state if the subsequent session initialization callback fails. Consider moving this assignment to after a successful execution of the callback or resetting it on failure.

Suggested change
handshakeCompleted = true;
// call client code
try {
initSessionCallback.accept(engine.getSession());
// call client code
try {
initSessionCallback.accept(engine.getSession());
handshakeCompleted = true; // set only after successful callback execution

Copilot uses AI. Check for mistakes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps a bad recommendation? However, best to confirm if initSessionCallback could be called twice and if that has any ramifications.

Copy link
Member Author

@vbabanin vbabanin Jun 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think in this case the Copilot suggestion does not apply.
Based on the TLSChannelImpl javadoc, the callback is only invoked when the TLS session is established or re-established:

"Register a callback function to be executed when the TLS session is established (or re-established)..."

public T withSessionInitCallback(Consumer<SSLSession> sessionInitCallback) {

If the callback fails, it doesn't indicate that the session was not established or in partial state. Internally, we still mark the handshake as completed to prevent re-establishing the session again on the next read or write. So the callback doesn’t control session validity - it's more of a post-handshake hook. Also, we rely on the default session callback, which is a no-op. Reference:

Given this, I believe we are okay and we can copy the logic from the upstream.

} catch (Exception e) {
LOGGER.trace("client code threw exception in session initialization callback", e);
throw new TlsChannelCallbackException("session initialization callback failed", e);
}
negotiated = true;
}
} finally {
initLock.unlock();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,52 @@

package com.mongodb.internal.connection;

import com.mongodb.ClusterFixture;
import com.mongodb.MongoSocketOpenException;
import com.mongodb.ServerAddress;
import com.mongodb.connection.SocketSettings;
import com.mongodb.connection.SslSettings;
import com.mongodb.internal.TimeoutContext;
import com.mongodb.internal.TimeoutSettings;
import org.bson.ByteBuf;
import org.bson.ByteBufNIO;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import java.io.IOException;
import java.net.ServerSocket;
import java.nio.ByteBuffer;
import java.nio.channels.InterruptedByTimeoutException;
import java.nio.channels.SocketChannel;
import java.util.Collections;
import java.util.concurrent.TimeUnit;

import static com.mongodb.ClusterFixture.getPrimaryServerDescription;
import static com.mongodb.internal.connection.OperationContext.simpleOperationContext;
import static java.lang.String.format;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

class TlsChannelStreamFunctionalTest {
private static final SslSettings SSL_SETTINGS = SslSettings.builder().enabled(true).build();
Expand Down Expand Up @@ -98,6 +114,7 @@ void shouldEstablishConnection(final int connectTimeoutMs) throws IOException, I
try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver());
MockedStatic<SocketChannel> socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class);
ServerSocket serverSocket = new ServerSocket(0, 1)) {

SingleResultSpyCaptor<SocketChannel> singleResultSpyCaptor = new SingleResultSpyCaptor<>();
socketChannelMockedStatic.when(SocketChannel::open).thenAnswer(singleResultSpyCaptor);

Expand Down Expand Up @@ -147,4 +164,35 @@ public T answer(final InvocationOnMock invocationOnMock) throws Throwable {
private static OperationContext createOperationContext(final int connectTimeoutMs) {
return simpleOperationContext(new TimeoutContext(TimeoutSettings.DEFAULT.withConnectTimeoutMS(connectTimeoutMs)));
}

@Test
@DisplayName("should not call beginHandshake more than once during TLS session establishment")
void shouldNotCallBeginHandshakeMoreThenOnceDuringTlsSessionEstablishment() throws Exception {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test case, as upstream didn't include one to cover this change.

assumeTrue(ClusterFixture.getSslSettings().isEnabled());

//given
try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver())) {

SSLContext sslContext = Mockito.spy(SSLContext.getDefault());
SingleResultSpyCaptor<SSLEngine> singleResultSpyCaptor = new SingleResultSpyCaptor<>();
when(sslContext.createSSLEngine(anyString(), anyInt())).thenAnswer(singleResultSpyCaptor);

StreamFactory streamFactory = streamFactoryFactory.create(
SocketSettings.builder().build(),
SslSettings.builder(ClusterFixture.getSslSettings())
.context(sslContext)
.build());

Stream stream = streamFactory.create(getPrimaryServerDescription().getAddress());
stream.open(ClusterFixture.OPERATION_CONTEXT);
ByteBuf wrap = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 3, 4}));

//when
stream.write(Collections.singletonList(wrap), ClusterFixture.OPERATION_CONTEXT);

//then
SECONDS.sleep(5);
verify(singleResultSpyCaptor.getResult(), times(1)).beginHandshake();
}
}
}