Skip to content

Commit

Permalink
Generalize ffi tests handshake helper
Browse files Browse the repository at this point in the history
Change-Id: I1d6983449267565a148a6a74fdd4fb94eccaf54f
  • Loading branch information
jblebrun committed Feb 6, 2025
1 parent d56d5c9 commit 590fd95
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 60 deletions.
29 changes: 18 additions & 11 deletions cc/oak_session/client_server_session_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,24 @@ SessionConfig* TestConfig() {
}

void DoHandshake(ClientSession& client_session, ServerSession& server_session) {
absl::StatusOr<std::optional<SessionRequest>> init =
client_session.GetOutgoingMessage();
ASSERT_THAT(init, IsOk());
ASSERT_THAT(*init, Ne(std::nullopt));
ASSERT_THAT(server_session.PutIncomingMessage(**init), IsOk());

absl::StatusOr<std::optional<SessionResponse>> init_resp =
server_session.GetOutgoingMessage();
ASSERT_THAT(init_resp, IsOk());
ASSERT_THAT(*init_resp, Ne(std::nullopt));
ASSERT_THAT(client_session.PutIncomingMessage(**init_resp), IsOk());
while (!client_session.IsOpen() && !server_session.IsOpen()) {
if (!client_session.IsOpen()) {
absl::StatusOr<std::optional<SessionRequest>> init =
client_session.GetOutgoingMessage();
ASSERT_THAT(init, IsOk());
ASSERT_THAT(*init, Ne(std::nullopt));
ASSERT_THAT(server_session.PutIncomingMessage(**init), IsOk());
}

if (!server_session.IsOpen()) {
absl::StatusOr<std::optional<SessionResponse>> init_resp =
server_session.GetOutgoingMessage();
ASSERT_THAT(init_resp, IsOk());
if (*init_resp != std::nullopt) {
ASSERT_THAT(client_session.PutIncomingMessage(**init_resp), IsOk());
}
}
}

EXPECT_THAT(client_session.IsOpen(), Eq(true));
EXPECT_THAT(server_session.IsOpen(), Eq(true));
Expand Down
60 changes: 34 additions & 26 deletions cc/oak_session/oak_session_bindings_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,32 +32,40 @@ using ::oak::session::v1::SessionResponse;
using ::testing::Eq;

void DoHandshake(ServerSession* server_session, ClientSession* client_session) {
ErrorOrRustBytes init = client_get_outgoing_message(client_session);
ASSERT_THAT(init, IsResult());

// We could just past init.result directly, but let's ensure that the request
// successfully goes through the ser/deser properly.
SessionRequest request;
ASSERT_TRUE(request.ParseFromString(*init.result));
std::string request_reserialized;
ASSERT_TRUE(request.SerializeToString(&request_reserialized));
BytesView request_bytes = BytesView(request_reserialized);
free_rust_bytes(init.result);

ASSERT_THAT(server_put_incoming_message(server_session, request_bytes),
NoError());

ErrorOrRustBytes init_resp = server_get_outgoing_message(server_session);
ASSERT_THAT(init_resp, IsResult());

SessionResponse response;
ASSERT_TRUE(response.ParseFromString(*init_resp.result));
free_rust_bytes(init_resp.result);
std::string response_reserialized;
ASSERT_TRUE(response.SerializeToString(&response_reserialized));
ASSERT_THAT(client_put_incoming_message(client_session,
BytesView(response_reserialized)),
NoError());
while (!client_is_open(client_session) && !server_is_open(server_session)) {
if (!client_is_open(client_session)) {
ErrorOrRustBytes init = client_get_outgoing_message(client_session);
ASSERT_THAT(init, IsResult());

// We could just past init.result directly, but let's ensure that the
// request successfully goes through the ser/deser properly.
SessionRequest request;
ASSERT_TRUE(request.ParseFromString(*init.result));
std::string request_reserialized;
ASSERT_TRUE(request.SerializeToString(&request_reserialized));
BytesView request_bytes = BytesView(request_reserialized);
free_rust_bytes(init.result);

ASSERT_THAT(server_put_incoming_message(server_session, request_bytes),
NoError());
}

if (!server_is_open(server_session)) {
ErrorOrRustBytes init_resp = server_get_outgoing_message(server_session);
ASSERT_THAT(init_resp.error, NoError());

if (init_resp.result != nullptr) {
SessionResponse response;
ASSERT_TRUE(response.ParseFromString(*init_resp.result));
free_rust_bytes(init_resp.result);
std::string response_reserialized;
ASSERT_TRUE(response.SerializeToString(&response_reserialized));
ASSERT_THAT(client_put_incoming_message(
client_session, BytesView(response_reserialized)),
NoError());
}
}
}

ASSERT_TRUE(server_is_open(server_session));
ASSERT_TRUE(client_is_open(client_session));
Expand Down
57 changes: 34 additions & 23 deletions oak_session/ffi/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
use oak_proto_rust::oak::session::v1::{SessionRequest, SessionResponse};
use oak_session::session::{ClientSession, ServerSession};
use oak_session::{
session::{ClientSession, ServerSession},
Session,
};
use oak_session_ffi_client_session as client_ffi;
use oak_session_ffi_config as config_ffi;
use oak_session_ffi_server_session as server_ffi;
Expand Down Expand Up @@ -115,27 +117,36 @@ unsafe fn do_handshake(
client_session_ptr: *mut ClientSession,
server_session_ptr: *mut ServerSession,
) {
// Handshake
let init = client_ffi::client_get_outgoing_message(client_session_ptr);
assert_no_error!(init.error);
let incoming_slice = (*init.result).as_slice();
let _ = SessionRequest::decode(incoming_slice).expect("couldn't decode init request");

let result =
server_ffi::server_put_incoming_message(server_session_ptr, (*init.result).as_bytes_view());
assert_no_error!(result);
unsafe { oak_session_ffi_types::free_rust_bytes(init.result) };

let init_resp = server_ffi::server_get_outgoing_message(server_session_ptr);
assert_no_error!(init_resp.error);
let init_resp_slice = (*init_resp.result).as_slice();
let _ = SessionResponse::decode(init_resp_slice).expect("couldn't decode init response");
let put_result = client_ffi::client_put_incoming_message(
client_session_ptr,
(*init_resp.result).as_bytes_view(),
);
unsafe { oak_session_ffi_types::free_rust_bytes(init_resp.result) };
assert_no_error!(put_result);
while !(*client_session_ptr).is_open() && !(*server_session_ptr).is_open() {
if !(*client_session_ptr).is_open() {
let init = client_ffi::client_get_outgoing_message(client_session_ptr);
assert_no_error!(init.error);
let incoming_slice = (*init.result).as_slice();
let _ = SessionRequest::decode(incoming_slice).expect("couldn't decode init request");
let result = server_ffi::server_put_incoming_message(
server_session_ptr,
(*init.result).as_bytes_view(),
);
assert_no_error!(result);
unsafe { oak_session_ffi_types::free_rust_bytes(init.result) };
}

if !(*server_session_ptr).is_open() {
let init_resp = server_ffi::server_get_outgoing_message(server_session_ptr);
assert_no_error!(init_resp.error);
if !init_resp.result.is_null() {
let init_resp_slice = (*init_resp.result).as_slice();
let _ = SessionResponse::decode(init_resp_slice)
.expect("couldn't decode init response");
let put_result = client_ffi::client_put_incoming_message(
client_session_ptr,
(*init_resp.result).as_bytes_view(),
);
assert_no_error!(put_result);
unsafe { oak_session_ffi_types::free_rust_bytes(init_resp.result) };
}
}
}
}

fn create_test_session_config() -> *mut oak_session::config::SessionConfig {
Expand Down

0 comments on commit 590fd95

Please sign in to comment.