46 changes: 29 additions & 17 deletions google/cloud/storage/internal/async/writer_connection_resumed.cc
@@ -191,19 +191,22 @@ class AsyncWriterConnectionResumedState
}

void WriteLoop(std::unique_lock<std::mutex> lk) {
if (state_ != State::kIdle) return;

// Determine if there's data left to write *before* potentially finalizing.
writing_ = write_offset_ < resend_buffer_.size();
auto const has_data = write_offset_ < resend_buffer_.size();

// If we are writing data, continue doing so.
if (writing_) {
if (has_data) {
state_ = State::kWriting;
// Still data to write, determine the next chunk.
auto const n = resend_buffer_.size() - write_offset_;
auto payload = resend_buffer_.Subcord(write_offset_, n);
if (flush_) return FlushStep(std::move(lk), std::move(payload));
return WriteStep(std::move(lk), std::move(payload));
}

// No data left to write (writing_ is false).
// No data left to write.
// Check if we need to finalize (only if not already writing data AND not
// already finalizing).
if (finalize_ && !finalizing_) {
@@ -212,22 +215,24 @@ class AsyncWriterConnectionResumedState
}
// If not finalizing, check if an empty flush is needed.
if (flush_) {
state_ = State::kWriting;
// Pass empty payload to FlushStep
return FlushStep(std::move(lk), absl::Cord{});
}

// No data to write, not finalizing, not flushing. The loop can stop.
// writing_ is already false.
state_ = State::kIdle;
}

// FinalizeStep is now called only when all data in resend_buffer_ is written.
void FinalizeStep(std::unique_lock<std::mutex> lk) {
// Check *under lock* if we are already finalizing.
if (finalizing_) {
if (finalizing_ || state_ != State::kIdle) {
// If another thread initiated FinalizeStep concurrently, just return.
return;
}
// Mark that we are starting the finalization process.
state_ = State::kWriting;
finalizing_ = true;
auto impl = Impl(lk);
lk.unlock();
@@ -263,8 +268,10 @@ class AsyncWriterConnectionResumedState
auto impl = Impl(lk);
lk.unlock();
impl->Query().then([this, result, w = WeakFromThis()](auto f) {
SetFlushed(std::unique_lock<std::mutex>(mu_), std::move(result));
if (auto self = w.lock()) return self->OnQuery(f.get());
auto self = w.lock();
if (!self) return;
self->OnQuery(f.get());
self->SetFlushed(std::unique_lock<std::mutex>(self->mu_), std::move(result));
});
}

@@ -303,7 +310,8 @@ class AsyncWriterConnectionResumedState
write_offset_ -= static_cast<std::size_t>(n);
// If the buffer is small enough, collect all the handlers to notify them.
auto const handlers = ClearHandlersIfEmpty(lk);
WriteLoop(std::move(lk));
state_ = State::kIdle;
StartWriting(std::move(lk));
// The notifications are deferred until the lock is released, as they might
// call back and try to acquire the lock.
for (auto const& h : handlers) {
@@ -325,7 +333,8 @@ class AsyncWriterConnectionResumedState
if (!result.ok()) return Resume(std::move(result));
std::unique_lock<std::mutex> lk(mu_);
write_offset_ += write_size;
return WriteLoop(std::move(lk));
state_ = State::kIdle;
return StartWriting(std::move(lk));
}

void Resume(Status const& s) {
@@ -353,10 +362,12 @@ class AsyncWriterConnectionResumedState
bool was_finalizing;
{
std::unique_lock<std::mutex> lk(mu_);
if (state_ == State::kResuming) return;
was_finalizing = finalizing_;
if (!s.ok() && cancelled_) {
return SetError(std::move(lk), std::move(s));
}
state_ = State::kResuming;
}
// Pass the original status `s` and `was_finalizing` to the callback.
factory_(std::move(request))
@@ -427,7 +438,7 @@ class AsyncWriterConnectionResumedState
void SetFinalized(std::unique_lock<std::mutex> lk,
google::storage::v2::Object object) {
resend_buffer_.Clear();
writing_ = false;
state_ = State::kIdle;
finalize_ = false;
finalizing_ = false; // Reset finalizing flag
flush_ = false;
@@ -471,15 +482,11 @@ class AsyncWriterConnectionResumedState
// lock.
for (auto& h : handlers) h->Execute(Status{});
flushed.set_value(result);
// Restart the write loop ONLY if we are not already finalizing.
// If finalizing_ is true, the completion will be handled by OnFinalize.
std::unique_lock<std::mutex> loop_lk(mu_);
if (!finalizing_) WriteLoop(std::move(loop_lk));
}

void SetError(std::unique_lock<std::mutex> lk, Status const& status) {
resume_status_ = status;
writing_ = false;
state_ = State::kIdle;
finalize_ = false;
finalizing_ = false; // Reset finalizing flag
flush_ = false;
@@ -602,7 +609,12 @@ class AsyncWriterConnectionResumedState
std::vector<std::unique_ptr<BufferShrinkHandler>> flush_handlers_;

// Tracks whether the write loop is idle, actively writing, or resuming.
bool writing_ = false;
enum class State {
kIdle,
kWriting,
kResuming,
};
State state_ = State::kIdle;

// True if cancelled, in which case any RPC failures are final.
bool cancelled_ = false;
@@ -717,4 +729,4 @@ MakeWriterConnectionResumed(
GOOGLE_CLOUD_CPP_INLINE_NAMESPACE_END
} // namespace storage_internal
} // namespace cloud
} // namespace google
} // namespace google
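For readers skimming the diff, here is a minimal, self-contained sketch of the single-writer state machine this change introduces. The class name `WriteLoopSketch` and its members are illustrative only, not part of the PR; the real code drives asynchronous Write()/Flush() calls through futures and also handles the kResuming state, errors, flushes, and finalization, all of which are omitted here.

#include <cstddef>
#include <iostream>
#include <mutex>
#include <string>
#include <utility>

// Illustrative only: a three-state write loop in which only a transition out
// of kIdle may start a new write, so two callers can never drive the
// underlying stream at the same time. kResuming mirrors the diff but is not
// exercised in this sketch.
class WriteLoopSketch {
 public:
  void Append(std::string data) {
    std::unique_lock<std::mutex> lk(mu_);
    buffer_ += data;
    StartWriting(std::move(lk));
  }

  // Called when the (simulated) asynchronous write of `n` bytes completes.
  void OnWriteDone(std::size_t n) {
    std::unique_lock<std::mutex> lk(mu_);
    buffer_.erase(0, n);
    state_ = State::kIdle;        // The in-flight write finished ...
    StartWriting(std::move(lk));  // ... so the loop may pick up more data.
  }

 private:
  enum class State { kIdle, kWriting, kResuming };

  void StartWriting(std::unique_lock<std::mutex> lk) {
    if (state_ != State::kIdle) return;  // Another write owns the stream.
    if (buffer_.empty()) return;         // Nothing to send; stay idle.
    state_ = State::kWriting;
    auto const n = buffer_.size();
    lk.unlock();
    // The real code issues an asynchronous Write()/Flush() here; the sketch
    // simulates an immediate completion.
    std::cout << "writing " << n << " bytes\n";
    OnWriteDone(n);
  }

  std::mutex mu_;
  State state_ = State::kIdle;
  std::string buffer_;
};

int main() {
  WriteLoopSketch w;
  w.Append("hello ");
  w.Append("world");
  return 0;
}

The key property, mirrored by the diff, is that completion handlers reset the state to kIdle and then re-enter the loop, so a new write starts only after the previous one has finished.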
google/cloud/storage/internal/async/writer_connection_resumed_test.cc
@@ -22,6 +22,8 @@
#include "google/cloud/testing_util/status_matchers.h"
#include <google/storage/v2/storage.pb.h>
#include <gmock/gmock.h>
#include <chrono>
#include <thread>

namespace google {
namespace cloud {
@@ -170,7 +172,7 @@ TEST(WriterConnectionResumed, FlushEmpty) {
auto mock = std::make_unique<MockAsyncWriterConnection>();
EXPECT_CALL(*mock, PersistedState)
.WillRepeatedly(Return(MakePersistedState(0)));
EXPECT_CALL(*mock, Flush).WillOnce([&](auto const& p) {
EXPECT_CALL(*mock, Flush).WillRepeatedly([&](auto const& p) {
EXPECT_TRUE(p.payload().empty());
return sequencer.PushBack("Flush").then([](auto f) {
if (!f.get()) return TransientError();
@@ -214,13 +216,21 @@ TEST(WriteConnectionResumed, FlushNonEmpty) {

EXPECT_CALL(*mock, PersistedState)
.WillRepeatedly(Return(MakePersistedState(0)));
EXPECT_CALL(*mock, Flush).WillOnce([&](auto const& p) {
EXPECT_EQ(p.payload(), payload.payload());
return sequencer.PushBack("Flush").then([](auto f) {
if (!f.get()) return TransientError();
return Status{};
});
});
EXPECT_CALL(*mock, Flush)
.WillOnce([&](auto const& p) {
EXPECT_EQ(p.payload(), payload.payload());
return sequencer.PushBack("Flush").then([](auto f) {
if (!f.get()) return TransientError();
return Status{};
});
})
.WillOnce([&](auto const& p) {
EXPECT_TRUE(p.payload().empty());
return sequencer.PushBack("Flush").then([](auto f) {
if (!f.get()) return TransientError();
return Status{};
});
});
EXPECT_CALL(*mock, Query).WillOnce([&]() {
return sequencer.PushBack("Query").then(
[](auto f) -> StatusOr<std::int64_t> {
Expand Down Expand Up @@ -394,6 +404,83 @@ TEST(WriteConnectionResumed, ResumeUsesAppendObjectSpecFromInitialRequest) {
"projects/_/buckets/test-bucket");
}

TEST(WriteConnectionResumed, NoConcurrentWritesWhenFlushAndWriteRace) {
AsyncSequencer<bool> sequencer;
auto mock = std::make_unique<MockAsyncWriterConnection>();
auto initial_request = google::storage::v2::BidiWriteObjectRequest{};
auto first_response = google::storage::v2::BidiWriteObjectResponse{};

EXPECT_CALL(*mock, PersistedState)
.WillRepeatedly(Return(MakePersistedState(0)));
EXPECT_CALL(*mock, Flush(_)).WillRepeatedly([&](auto) {
return sequencer.PushBack("Flush").then([](auto f) {
if (!f.get()) return TransientError();
return Status{};
});
});
EXPECT_CALL(*mock, Query).WillOnce([&]() {
return sequencer.PushBack("Query").then([](auto f) -> StatusOr<std::int64_t> {
if (!f.get()) return TransientError();
return 0;
});
});

  // Make Write() detect concurrent invocations. If two writes run
  // concurrently, the compare_exchange fails and the test fails.
std::atomic<bool> in_write{false};
  EXPECT_CALL(*mock, Write(_)).WillRepeatedly([&](auto) {
    bool expected = false;
    EXPECT_TRUE(in_write.compare_exchange_strong(expected, true));
    // Simulate some work that allows a concurrent write to attempt to run.
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
    in_write.store(false);
    return make_ready_future(Status{});
  });

MockFactory mock_factory;
EXPECT_CALL(mock_factory, Call).Times(0);

auto connection = MakeWriterConnectionResumed(
mock_factory.AsStdFunction(), std::move(mock), initial_request, nullptr,
first_response, Options{});

// Start a flush which will call impl->Flush() and block.
auto flush_future = connection->Flush({});
  // Allow the Flush to complete; this schedules a Query (which remains
  // blocked until we pop it).
auto next = sequencer.PopFrontWithName();
EXPECT_EQ(next.second, "Flush");
next.first.set_value(true);

// Immediately perform a user Write after the flush completed but before
// Query completes. This can race with the OnQuery-driven write.
auto write_future = connection->Write(TestPayload(1024));

// Now allow the Query to complete; OnQuery may schedule a write.
next = sequencer.PopFrontWithName();
EXPECT_EQ(next.second, "Query");
next.first.set_value(true);

  // Wait for both futures to complete, with a timeout to avoid an indefinite
  // hang if the loop stalls.
  auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(2);
  while (!write_future.is_ready() &&
         std::chrono::steady_clock::now() < deadline) {
    std::this_thread::sleep_for(std::chrono::milliseconds(5));
  }
  deadline = std::chrono::steady_clock::now() + std::chrono::seconds(2);
  while (!flush_future.is_ready() &&
         std::chrono::steady_clock::now() < deadline) {
    std::this_thread::sleep_for(std::chrono::milliseconds(5));
  }

ASSERT_TRUE(write_future.is_ready());
ASSERT_TRUE(flush_future.is_ready());

// Both futures should complete successfully.
EXPECT_THAT(write_future.get(), StatusIs(StatusCode::kOk));
EXPECT_THAT(flush_future.get(), StatusIs(StatusCode::kOk));
}


TEST(WriteConnectionResumed, WriteHandleAssignmentAfterResume) {
struct {
bool use_write_object_spec;
Expand Down Expand Up @@ -463,4 +550,4 @@ TEST(WriteConnectionResumed, WriteHandleAssignmentAfterResume) {
GOOGLE_CLOUD_CPP_INLINE_NAMESPACE_END
} // namespace storage_internal
} // namespace cloud
} // namespace google
} // namespace google
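A side note on the test's compare_exchange_strong guard: flipping an atomic flag false-to-true on entry and back on exit is a generic way to assert that a code path is never entered concurrently. The standalone sketch below illustrates the idea; `ConcurrencyDetector` and the mutex-serialized workers are hypothetical and not part of the PR, and the assert fires (as the mock's EXPECT_TRUE would) if the mutex is removed and two threads overlap.

#include <atomic>
#include <cassert>
#include <chrono>
#include <mutex>
#include <thread>
#include <vector>

// Flags overlapping invocations of a section that is supposed to be
// serialized: the first caller flips the flag false -> true; any concurrent
// caller fails the compare-exchange and the assertion fires.
class ConcurrencyDetector {
 public:
  void Enter() {
    bool expected = false;
    assert(in_use_.compare_exchange_strong(expected, true) &&
           "concurrent invocation detected");
  }
  void Exit() { in_use_.store(false); }

 private:
  std::atomic<bool> in_use_{false};
};

int main() {
  ConcurrencyDetector detector;
  std::mutex serialize;  // Remove this mutex and the assertion fires.
  std::vector<std::thread> workers;
  for (int i = 0; i != 4; ++i) {
    workers.emplace_back([&] {
      for (int j = 0; j != 100; ++j) {
        std::lock_guard<std::mutex> lk(serialize);
        detector.Enter();
        // Simulate work inside the (supposedly) serialized section.
        std::this_thread::sleep_for(std::chrono::microseconds(10));
        detector.Exit();
      }
    });
  }
  for (auto& t : workers) t.join();
  return 0;
}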