Skip to content

Commit 3ea196f

Browse files
authored
RUST-1433 Propagate original error for some labeled retry errors (mongodb#903)
1 parent 2692d43 commit 3ea196f

File tree

5 files changed

+152
-14
lines changed

5 files changed

+152
-14
lines changed

src/client/executor.rs

+9-5
Original file line numberDiff line numberDiff line change
@@ -410,9 +410,10 @@ impl Client {
410410
drop(server);
411411

412412
if let Some(r) = retry {
413-
if err.is_server_error()
413+
if (err.is_server_error()
414414
|| err.is_read_retryable()
415-
|| err.is_write_retryable()
415+
|| err.is_write_retryable())
416+
&& !err.contains_label("NoWritesPerformed")
416417
{
417418
return Err(err);
418419
} else {
@@ -606,7 +607,8 @@ impl Client {
606607
connection: connection_info.clone(),
607608
service_id,
608609
})
609-
});
610+
})
611+
.await;
610612

611613
let start_time = Instant::now();
612614
let command_result = match connection.send_raw_command(raw_cmd, request_id).await {
@@ -706,7 +708,8 @@ impl Client {
706708
connection: connection_info.clone(),
707709
service_id,
708710
})
709-
});
711+
})
712+
.await;
710713

711714
if let Some(ref mut session) = session {
712715
if err.is_network_error() {
@@ -735,7 +738,8 @@ impl Client {
735738
connection: connection_info.clone(),
736739
service_id,
737740
})
738-
});
741+
})
742+
.await;
739743

740744
#[cfg(feature = "in-use-encryption-unstable")]
741745
let response = {

src/client/mod.rs

+37-5
Original file line numberDiff line numberDiff line change
@@ -218,16 +218,42 @@ impl Client {
218218
.map_or(false, |cs| cs.exec().has_mongocryptd_client())
219219
}
220220

221+
fn test_command_event_channel(&self) -> Option<&options::TestEventSender> {
222+
#[cfg(test)]
223+
{
224+
self.inner
225+
.options
226+
.test_options
227+
.as_ref()
228+
.and_then(|t| t.async_event_listener.as_ref())
229+
}
230+
#[cfg(not(test))]
231+
{
232+
None
233+
}
234+
}
235+
221236
#[cfg(not(feature = "tracing-unstable"))]
222-
pub(crate) fn emit_command_event(&self, generate_event: impl FnOnce() -> CommandEvent) {
223-
if let Some(ref handler) = self.inner.options.command_event_handler {
224-
let event = generate_event();
237+
pub(crate) async fn emit_command_event(&self, generate_event: impl FnOnce() -> CommandEvent) {
238+
let handler = self.inner.options.command_event_handler.as_ref();
239+
let test_channel = self.test_command_event_channel();
240+
if handler.is_none() && test_channel.is_none() {
241+
return;
242+
}
243+
244+
let event = generate_event();
245+
if let Some(tx) = test_channel {
246+
let (msg, ack) = crate::runtime::AcknowledgedMessage::package(event.clone());
247+
let _ = tx.send(msg).await;
248+
ack.wait_for_acknowledgment().await;
249+
}
250+
if let Some(handler) = handler {
225251
handle_command_event(handler.as_ref(), event);
226252
}
227253
}
228254

229255
#[cfg(feature = "tracing-unstable")]
230-
pub(crate) fn emit_command_event(&self, generate_event: impl FnOnce() -> CommandEvent) {
256+
pub(crate) async fn emit_command_event(&self, generate_event: impl FnOnce() -> CommandEvent) {
231257
let tracing_emitter = if trace_or_log_enabled!(
232258
target: COMMAND_TRACING_EVENT_TARGET,
233259
TracingOrLogLevel::Debug
@@ -240,11 +266,17 @@ impl Client {
240266
None
241267
};
242268
let apm_event_handler = self.inner.options.command_event_handler.as_ref();
243-
if !(tracing_emitter.is_some() || apm_event_handler.is_some()) {
269+
let test_channel = self.test_command_event_channel();
270+
if !(tracing_emitter.is_some() || apm_event_handler.is_some() || test_channel.is_some()) {
244271
return;
245272
}
246273

247274
let event = generate_event();
275+
if let Some(tx) = test_channel {
276+
let (msg, ack) = crate::runtime::AcknowledgedMessage::package(event.clone());
277+
let _ = tx.send(msg).await;
278+
ack.wait_for_acknowledgment().await;
279+
}
248280
if let (Some(event_handler), Some(ref tracing_emitter)) =
249281
(apm_event_handler, &tracing_emitter)
250282
{

src/client/options/mod.rs

+7
Original file line numberDiff line numberDiff line change
@@ -587,8 +587,15 @@ pub(crate) struct TestOptions {
587587

588588
/// Mock response for `SrvPollingMonitor::lookup_hosts`.
589589
pub(crate) mock_lookup_hosts: Option<Result<LookupHosts>>,
590+
591+
/// Async-capable command event listener.
592+
pub(crate) async_event_listener: Option<TestEventSender>,
590593
}
591594

595+
pub(crate) type TestEventSender = tokio::sync::mpsc::Sender<
596+
crate::runtime::AcknowledgedMessage<crate::event::command::CommandEvent>,
597+
>;
598+
592599
fn default_hosts() -> Vec<ServerAddress> {
593600
vec![ServerAddress::default()]
594601
}

src/test/spec/retryable_writes/mod.rs

+97-3
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ mod test_file;
22

33
use std::{sync::Arc, time::Duration};
44

5+
use bson::Bson;
56
use futures::stream::TryStreamExt;
67
use semver::VersionReq;
7-
use tokio::sync::{RwLockReadGuard, RwLockWriteGuard};
8+
use tokio::sync::{Mutex, RwLockReadGuard, RwLockWriteGuard};
89

910
use test_file::{TestFile, TestResult};
1011

@@ -13,11 +14,11 @@ use crate::{
1314
error::{ErrorKind, Result, RETRYABLE_WRITE_ERROR},
1415
event::{
1516
cmap::{CmapEvent, CmapEventHandler, ConnectionCheckoutFailedReason},
16-
command::CommandEventHandler,
17+
command::{CommandEvent, CommandEventHandler},
1718
},
1819
options::{ClientOptions, FindOptions, InsertManyOptions},
1920
runtime,
20-
runtime::AsyncJoinHandle,
21+
runtime::{spawn, AcknowledgedMessage, AsyncJoinHandle},
2122
sdam::MIN_HEARTBEAT_FREQUENCY,
2223
test::{
2324
assert_matches,
@@ -35,6 +36,7 @@ use crate::{
3536
CLIENT_OPTIONS,
3637
LOCK,
3738
},
39+
Client,
3840
};
3941

4042
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
@@ -499,3 +501,95 @@ async fn retry_write_pool_cleared() {
499501

500502
assert_eq!(handler.get_command_started_events(&["insert"]).len(), 3);
501503
}
504+
505+
/// Prose test from retryable writes spec verifying that the original error is returned after
506+
/// encountering a WriteConcernError with a RetryableWriteError label.
507+
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
508+
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
509+
async fn retry_write_retryable_write_error() {
510+
let _guard: RwLockWriteGuard<()> = LOCK.run_exclusively().await;
511+
512+
let mut client_options = CLIENT_OPTIONS.get().await.clone();
513+
client_options.retry_writes = Some(true);
514+
let (event_tx, event_rx) = tokio::sync::mpsc::channel::<AcknowledgedMessage<CommandEvent>>(1);
515+
// The listener needs to be active on client startup, but also needs a handle to the client
516+
// itself for the trigger action.
517+
let listener_client: Arc<Mutex<Option<TestClient>>> = Arc::new(Mutex::new(None));
518+
// Set up event listener
519+
let (fp_tx, mut fp_rx) = tokio::sync::mpsc::unbounded_channel();
520+
{
521+
let client = listener_client.clone();
522+
let mut event_rx = event_rx;
523+
let fp_tx = fp_tx.clone();
524+
// Spawn a task to watch the event channel
525+
spawn(async move {
526+
while let Some(msg) = event_rx.recv().await {
527+
if let CommandEvent::Succeeded(ev) = &*msg {
528+
if let Some(Bson::Document(wc_err)) = ev.reply.get("writeConcernError") {
529+
if ev.command_name == "insert" && wc_err.get_i32("code") == Ok(91) {
530+
// Spawn a new task so events continue to process
531+
let client = client.clone();
532+
let fp_tx = fp_tx.clone();
533+
spawn(async move {
534+
// Enable the failpoint.
535+
let fp_guard = {
536+
let client = client.lock().await;
537+
FailPoint::fail_command(
538+
&["insert"],
539+
FailPointMode::Times(1),
540+
FailCommandOptions::builder()
541+
.error_code(10107)
542+
.error_labels(vec![
543+
"RetryableWriteError".to_string(),
544+
"NoWritesPerformed".to_string(),
545+
])
546+
.build(),
547+
)
548+
.enable(client.as_ref().unwrap(), None)
549+
.await
550+
.unwrap()
551+
};
552+
fp_tx.send(fp_guard).unwrap();
553+
// Defer acknowledging the message until the failpoint has been set
554+
// up so the retry hits it.
555+
msg.acknowledge(());
556+
});
557+
}
558+
}
559+
}
560+
}
561+
});
562+
}
563+
client_options.test_options_mut().async_event_listener = Some(event_tx);
564+
let client = Client::test_builder().options(client_options).build().await;
565+
*listener_client.lock().await = Some(client.clone());
566+
567+
if !client.is_replica_set() || client.server_version_lt(6, 0) {
568+
log_uncaptured("skipping retry_write_retryable_write_error: invalid topology");
569+
return;
570+
}
571+
572+
let _fp_guard = FailPoint::fail_command(
573+
&["insert"],
574+
FailPointMode::Times(1),
575+
FailCommandOptions::builder()
576+
.write_concern_error(doc! {
577+
"code": 91,
578+
"errorLabels": ["RetryableWriteError"],
579+
})
580+
.build(),
581+
)
582+
.enable(&client, None)
583+
.await
584+
.unwrap();
585+
586+
let result = client
587+
.database("test")
588+
.collection::<Document>("test")
589+
.insert_one(doc! { "hello": "there" }, None)
590+
.await;
591+
assert_eq!(result.unwrap_err().code(), Some(91));
592+
593+
// Consume failpoint guard.
594+
let _ = fp_rx.recv().await;
595+
}

src/test/util/failpoint.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ impl FailPoint {
6464
}
6565
}
6666

67+
#[derive(Debug)]
6768
pub struct FailPointGuard {
6869
client: Client,
6970
failpoint_name: String,
@@ -103,7 +104,7 @@ pub enum FailPointMode {
103104
}
104105

105106
#[serde_with::skip_serializing_none]
106-
#[derive(Debug, TypedBuilder, Serialize)]
107+
#[derive(Debug, Default, TypedBuilder, Serialize)]
107108
#[builder(field_defaults(default, setter(into)))]
108109
#[serde(rename_all = "camelCase")]
109110
pub struct FailCommandOptions {

0 commit comments

Comments
 (0)