Skip to content

Commit d7a02a4

Browse files
authored
Merge pull request scylladb#1142 from Lorak-mmk/batch-bugfix
Make RawBatchValuesIteratorAdapter length equal to its internal BatchValuesIter length
2 parents 4d80aa1 + 559583f commit d7a02a4

File tree

3 files changed

+86
-8
lines changed

3 files changed

+86
-8
lines changed

scylla-cql/src/types/serialize/raw_batch.rs

+11-8
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,7 @@ pub trait RawBatchValuesIterator<'a> {
5252
where
5353
Self: Sized,
5454
{
55-
let mut count = 0;
56-
while self.skip_next().is_some() {
57-
count += 1;
58-
}
59-
count
55+
std::iter::from_fn(|| self.skip_next()).count()
6056
}
6157
}
6258

@@ -145,19 +141,26 @@ where
145141
{
146142
#[inline]
147143
fn serialize_next(&mut self, writer: &mut RowWriter) -> Option<Result<(), SerializationError>> {
148-
let ctx = self.contexts.next()?;
144+
// We do `unwrap_or` because we want the iterator length to be the same
145+
// as the amount of values. Limiting to length of the amount of
146+
// statements (contexts) causes the caller to not be able to correctly
147+
// detect that amount of statements and values is different.
148+
let ctx = self
149+
.contexts
150+
.next()
151+
.unwrap_or(RowSerializationContext::empty());
149152
self.batch_values_iterator.serialize_next(&ctx, writer)
150153
}
151154

152155
fn is_empty_next(&mut self) -> Option<bool> {
153-
self.contexts.next()?;
156+
let _ = self.contexts.next();
154157
let ret = self.batch_values_iterator.is_empty_next()?;
155158
Some(ret)
156159
}
157160

158161
#[inline]
159162
fn skip_next(&mut self) -> Option<()> {
160-
self.contexts.next()?;
163+
let _ = self.contexts.next();
161164
self.batch_values_iterator.skip_next()?;
162165
Some(())
163166
}

scylla/tests/integration/batch.rs

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
use scylla::batch::Batch;
2+
use scylla::batch::BatchType;
3+
use scylla::frame::frame_errors::BatchSerializationError;
4+
use scylla::frame::frame_errors::CqlRequestSerializationError;
5+
use scylla::query::Query;
6+
use scylla::transport::errors::QueryError;
7+
8+
use crate::utils::create_new_session_builder;
9+
use crate::utils::setup_tracing;
10+
use crate::utils::unique_keyspace_name;
11+
use crate::utils::PerformDDL;
12+
13+
use assert_matches::assert_matches;
14+
15+
#[tokio::test]
16+
#[ntest::timeout(60000)]
17+
async fn batch_statements_and_values_mismatch_detected() {
18+
setup_tracing();
19+
let session = create_new_session_builder().build().await.unwrap();
20+
let ks = unique_keyspace_name();
21+
session.ddl(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks)).await.unwrap();
22+
session.use_keyspace(ks, false).await.unwrap();
23+
session
24+
.ddl("CREATE TABLE IF NOT EXISTS batch_serialization_test (p int PRIMARY KEY, val int)")
25+
.await
26+
.unwrap();
27+
28+
let mut batch = Batch::new(BatchType::Logged);
29+
let stmt = session
30+
.prepare("INSERT INTO batch_serialization_test (p, val) VALUES (?, ?)")
31+
.await
32+
.unwrap();
33+
batch.append_statement(stmt.clone());
34+
batch.append_statement(Query::new(
35+
"INSERT INTO batch_serialization_test (p, val) VALUES (3, 4)",
36+
));
37+
batch.append_statement(stmt);
38+
39+
// Subtest 1: counts are correct
40+
{
41+
session.batch(&batch, &((1, 2), (), (5, 6))).await.unwrap();
42+
}
43+
44+
// Subtest 2: not enough values
45+
{
46+
let err = session.batch(&batch, &((1, 2), ())).await.unwrap_err();
47+
assert_matches!(
48+
err,
49+
QueryError::CqlRequestSerialization(CqlRequestSerializationError::BatchSerialization(
50+
BatchSerializationError::ValuesAndStatementsLengthMismatch {
51+
n_value_lists: 2,
52+
n_statements: 3
53+
}
54+
))
55+
)
56+
}
57+
58+
// Subtest 3: too many values
59+
{
60+
let err = session
61+
.batch(&batch, &((1, 2), (), (5, 6), (7, 8)))
62+
.await
63+
.unwrap_err();
64+
assert_matches!(
65+
err,
66+
QueryError::CqlRequestSerialization(CqlRequestSerializationError::BatchSerialization(
67+
BatchSerializationError::ValuesAndStatementsLengthMismatch {
68+
n_value_lists: 4,
69+
n_statements: 3
70+
}
71+
))
72+
)
73+
}
74+
}

scylla/tests/integration/main.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
mod authenticate;
2+
mod batch;
23
mod consistency;
34
mod cql_collections;
45
mod cql_types;

0 commit comments

Comments
 (0)