Skip to content

Commit

Permalink
Get rid of unnecessary string allocations in Noise C++ bindings
Browse files Browse the repository at this point in the history
The ByteToString helper allocated a new string every time, but in every
use case, a view into the data is all that we needed. So instead, add
coercions to absl::string_view.

Also turns BytesFromString into a proper constructor.

Change-Id: Id09e356ff1f6e84da95d8001bff3cba4d01bc799
  • Loading branch information
jblebrun committed Feb 3, 2025
1 parent 06fa8db commit dce33de
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 53 deletions.
10 changes: 4 additions & 6 deletions cc/oak_session/client_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ absl::Status ClientSession::PutIncomingMessage(
const v1::SessionResponse& response) {
const std::string response_bytes = response.SerializeAsString();
bindings::Error* error = bindings::client_put_incoming_message(
rust_session_, bindings::BytesFromString(response_bytes));
rust_session_, bindings::Bytes(response_bytes));
return ErrorIntoStatus(error);
}

Expand All @@ -70,7 +70,7 @@ ClientSession::GetOutgoingMessage() {
}

v1::SessionRequest request;
if (!request.ParseFromString(BytesToString(*result.result))) {
if (!request.ParseFromString(*result.result)) {
return absl::InternalError(
"Failed to parse GetoutoingMessage result bytes as SessionRequest");
}
Expand All @@ -82,8 +82,7 @@ ClientSession::GetOutgoingMessage() {
absl::Status ClientSession::Write(
const v1::PlaintextMessage& unencrypted_request) {
bindings::Error* error = bindings::client_write(
rust_session_,
bindings::BytesFromString(unencrypted_request.SerializeAsString()));
rust_session_, bindings::Bytes(unencrypted_request.SerializeAsString()));

return ErrorIntoStatus(error);
}
Expand All @@ -99,8 +98,7 @@ absl::StatusOr<std::optional<v1::PlaintextMessage>> ClientSession::Read() {
}

v1::PlaintextMessage plaintext_message_result;
if (!plaintext_message_result.ParseFromString(
bindings::BytesToString(*result.result))) {
if (!plaintext_message_result.ParseFromString(*result.result)) {
return absl::InternalError(
"Failed to parse client_read result bytes as PlaintextMessage");
}
Expand Down
2 changes: 1 addition & 1 deletion cc/oak_session/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ SessionConfigBuilder::SessionConfigBuilder(AttestationType attestation_type,
if (builder_result.error != nullptr) {
LOG(FATAL) << absl::StrCat(
"Failed to create builder: ",
bindings::BytesToString(builder_result.error->message));
static_cast<absl::string_view>(builder_result.error->message));
}

builder_ = builder_result.result;
Expand Down
13 changes: 6 additions & 7 deletions cc/oak_session/oak_session_bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,19 @@

namespace oak::session::bindings {

Bytes BytesFromString(absl::string_view bytes) {
return Bytes{.data = bytes.data(), .len = bytes.size()};
}
Bytes::Bytes(absl::string_view bytes) : data(bytes.data()), len(bytes.size()) {}

std::string BytesToString(Bytes bytes) {
return std::string(bytes.data, bytes.len);
std::ostream& operator<<(std::ostream& stream, const Bytes& bytes) {
stream << absl::string_view(bytes.data, bytes.len);
return stream;
}

absl::Status ErrorIntoStatus(bindings::Error* error) {
if (error == nullptr) {
return absl::OkStatus();
}
absl::Status status = absl::Status(absl::StatusCode::kInternal,
bindings::BytesToString(error->message));
absl::Status status =
absl::Status(absl::StatusCode::kInternal, error->message);
free_error(error);
return status;
}
Expand Down
20 changes: 12 additions & 8 deletions cc/oak_session/oak_session_bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,23 @@ class ClientSession;

// A struct holding a sequence of Bytes.
// Corresponds to Bytes struct in oak_session/ffi/types.rs
//
// There are no specified ownership requirements for the type. A function that
// accepts or returns this type should document expectations around ownership
// and de-allocation requirements.
struct Bytes {
const char* data;
uint64_t len;
};

// Create a `Bytes` instance wrapping the provided string data.
// The lifetime of the created bytes is determined by the lifetime
// of the data backing the string_view.
Bytes BytesFromString(absl::string_view bytes);
// Create a `Bytes` instance wrapping the provided string data.
// The lifetime of the created bytes is determined by the lifetime
// of the data backing the string_view.
explicit Bytes(absl::string_view data);

operator absl::string_view() { return absl::string_view(data, len); }
};

// Create a new string wrapping the Bytes object.
// The Bytes data will be copied into a new string.
std::string BytesToString(Bytes bytes);
std::ostream& operator<<(std::ostream& stream, const Bytes& bytes);

// Corresponds to Error struct in oak_session/ffi/types.rs
struct Error {
Expand Down
34 changes: 15 additions & 19 deletions cc/oak_session/oak_session_bindings_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ void DoHandshake(ServerSession* server_session, ClientSession* client_session) {
// We could just past init.result directly, but let's ensure that the request
// successfully goes through the ser/deser properly.
SessionRequest request;
ASSERT_TRUE(request.ParseFromString(BytesToString(*init.result)));
ASSERT_TRUE(request.ParseFromString(*init.result));
std::string request_reserialized;
ASSERT_TRUE(request.SerializeToString(&request_reserialized));
Bytes request_bytes = BytesFromString(request_reserialized);
Bytes request_bytes = Bytes(request_reserialized);
free_bytes(init.result);

ASSERT_THAT(server_put_incoming_message(server_session, request_bytes),
Expand All @@ -51,13 +51,13 @@ void DoHandshake(ServerSession* server_session, ClientSession* client_session) {
ASSERT_THAT(init_resp, IsResult());

SessionResponse response;
ASSERT_TRUE(response.ParseFromString(BytesToString(*init_resp.result)));
ASSERT_TRUE(response.ParseFromString(*init_resp.result));
free_bytes(init_resp.result);
std::string response_reserialized;
ASSERT_TRUE(response.SerializeToString(&response_reserialized));
ASSERT_THAT(client_put_incoming_message(
client_session, BytesFromString(response_reserialized)),
NoError());
ASSERT_THAT(
client_put_incoming_message(client_session, Bytes(response_reserialized)),
NoError());

ASSERT_TRUE(server_is_open(server_session));
ASSERT_TRUE(client_is_open(client_session));
Expand All @@ -68,7 +68,7 @@ SessionConfig* TestConfig() {
HANDSHAKE_TYPE_NOISE_NN);
if (result.error != nullptr) {
LOG(FATAL) << "Failed to create session config builder"
<< BytesToString(result.error->message);
<< result.error->message;
}

return session_config_builder_build(result.result);
Expand Down Expand Up @@ -144,10 +144,9 @@ TEST(OakSessionBindingsTest, TestClientEncryptServerDecrypt) {

v1::PlaintextMessage plaintext_message_out;
plaintext_message_out.set_plaintext("Hello Client To Server");
ASSERT_THAT(
client_write(client_session,
BytesFromString(plaintext_message_out.SerializeAsString())),
NoError());
ASSERT_THAT(client_write(client_session,
Bytes(plaintext_message_out.SerializeAsString())),
NoError());

ErrorOrBytes client_out = client_get_outgoing_message(client_session);
ASSERT_THAT(client_out, IsResult());
Expand All @@ -160,8 +159,7 @@ TEST(OakSessionBindingsTest, TestClientEncryptServerDecrypt) {
ASSERT_THAT(server_in, IsResult());

v1::PlaintextMessage plaintext_message_in;
ASSERT_TRUE(
plaintext_message_in.ParseFromString(BytesToString(*server_in.result)));
ASSERT_TRUE(plaintext_message_in.ParseFromString(*server_in.result));
EXPECT_THAT(plaintext_message_in.plaintext(),
Eq(plaintext_message_out.plaintext()));
free_bytes(server_in.result);
Expand All @@ -182,10 +180,9 @@ TEST(OakSessionBindingsTest, TestServerEncryptClientDecrypt) {

v1::PlaintextMessage plaintext_message_out;
plaintext_message_out.set_plaintext("Hello Server to Client");
ASSERT_THAT(
server_write(server_session,
BytesFromString(plaintext_message_out.SerializeAsString())),
NoError());
ASSERT_THAT(server_write(server_session,
Bytes(plaintext_message_out.SerializeAsString())),
NoError());

ErrorOrBytes server_out = server_get_outgoing_message(server_session);
ASSERT_THAT(server_out, IsResult());
Expand All @@ -198,8 +195,7 @@ TEST(OakSessionBindingsTest, TestServerEncryptClientDecrypt) {
ASSERT_THAT(client_in, IsResult());

v1::PlaintextMessage plaintext_message_in;
ASSERT_TRUE(
plaintext_message_in.ParseFromString(BytesToString(*client_in.result)));
ASSERT_TRUE(plaintext_message_in.ParseFromString(*client_in.result));
ASSERT_EQ(plaintext_message_in.plaintext(),
plaintext_message_out.plaintext());
free_bytes(client_in.result);
Expand Down
10 changes: 4 additions & 6 deletions cc/oak_session/server_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ absl::Status ServerSession::PutIncomingMessage(
const v1::SessionRequest& request) {
const std::string request_bytes = request.SerializeAsString();
bindings::Error* error = bindings::server_put_incoming_message(
rust_session_, bindings::BytesFromString(request_bytes));
rust_session_, bindings::Bytes(request_bytes));
return bindings::ErrorIntoStatus(error);
}

Expand All @@ -69,7 +69,7 @@ ServerSession::GetOutgoingMessage() {
}

v1::SessionResponse response;
if (!response.ParseFromString(BytesToString(*result.result))) {
if (!response.ParseFromString(*result.result)) {
return absl::InternalError(
"Failed to parse GetOutoingMessage result bytes as SessionResponse");
}
Expand All @@ -81,8 +81,7 @@ ServerSession::GetOutgoingMessage() {
absl::Status ServerSession::Write(
const v1::PlaintextMessage& unencrypted_request) {
bindings::Error* error = bindings::server_write(
rust_session_,
bindings::BytesFromString(unencrypted_request.SerializeAsString()));
rust_session_, bindings::Bytes(unencrypted_request.SerializeAsString()));

return bindings::ErrorIntoStatus(error);
}
Expand All @@ -99,8 +98,7 @@ absl::StatusOr<std::optional<v1::PlaintextMessage>> ServerSession::Read() {

// Copy into new result string so we can free the bytes.
v1::PlaintextMessage plaintext_message_result;
if (!plaintext_message_result.ParseFromString(
bindings::BytesToString(*result.result))) {
if (!plaintext_message_result.ParseFromString(*result.result)) {
return absl::InternalError(
"Failed to parse server_read result bytes as PlaintextMessage");
}
Expand Down
9 changes: 3 additions & 6 deletions cc/oak_session/testing/matchers.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
// Works with ErrorOrBytes, ErrorOrClientSession, ErrorOrServerSession.
MATCHER(IsResult, "Contains result and no error") {
if (arg.error != nullptr) {
*result_listener << "Expected no error, but have: "
<< BytesToString(arg.error->message);
*result_listener << "Expected no error, but have: " << arg.error->message;
return false;
}
if (arg.result == nullptr) {
Expand All @@ -37,8 +36,7 @@ MATCHER(IsResult, "Contains result and no error") {
// A matcher that verifies that ErrorOr* types contain an error.
MATCHER(IsError, "Contains error and no result") {
if (arg.result != nullptr) {
*result_listener << "Expected no result, but have: "
<< BytesToString(*arg.result);
*result_listener << "Expected no result, but have: " << *arg.result;
return false;
}
if (arg.error == nullptr) {
Expand All @@ -52,8 +50,7 @@ MATCHER(IsError, "Contains error and no result") {
// A matcher that verifies that an Error* is null.
MATCHER(NoError, "") {
if (arg != nullptr) {
*result_listener << "Expected non-null error, but got: "
<< BytesToString(arg->message);
*result_listener << "Expected non-null error, but got: " << arg->message;
return false;
}
return true;
Expand Down

0 comments on commit dce33de

Please sign in to comment.