Skip to content

Commit 9aa5566

Browse files
Probabilistic diff to sample partitions for running diff test
1 parent 75b70c3 commit 9aa5566

File tree

8 files changed

+154
-7
lines changed

8 files changed

+154
-7
lines changed

common/src/main/java/org/apache/cassandra/diff/JobConfiguration.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,13 @@ default SpecificTokens specificTokens() {
8787

8888
MetadataKeyspaceOptions metadataOptions();
8989

90+
/**
 * Sampling probability for the probabilistic diff, in the range (0, 1].
 * It decides what fraction of partitions are diffed; the default value is 1,
 * which means all partitions are diffed.
 *
 * @return the partition sampling probability
 */
double partitionSamplingProbability();
96+
9097
/**
9198
* Contains the options that specify the retry strategy for retrieving data at the application level.
9299
* Note that it is different than cassandra java driver's {@link com.datastax.driver.core.policies.RetryPolicy},

common/src/main/java/org/apache/cassandra/diff/YamlJobConfiguration.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ public class YamlJobConfiguration implements JobConfiguration {
4848
public String specific_tokens = null;
4949
public String disallowed_tokens = null;
5050
public RetryOptions retry_options;
51+
public double partition_sampling_probability = 1;
5152

5253
public static YamlJobConfiguration load(InputStream inputStream) {
5354
Yaml yaml = new Yaml(new CustomClassLoaderConstructor(YamlJobConfiguration.class,
@@ -103,6 +104,11 @@ public MetadataKeyspaceOptions metadataOptions() {
103104
return metadata_options;
104105
}
105106

107+
/**
 * @return the configured {@code partition_sampling_probability}
 *         (defaults to 1, i.e. diff every partition)
 */
@Override
public double partitionSamplingProbability() {
    return partition_sampling_probability;
}
111+
106112
public RetryOptions retryOptions() {
107113
return retry_options;
108114
}
@@ -130,6 +136,7 @@ public String toString() {
130136
", keyspace_tables=" + keyspace_tables +
131137
", buckets=" + buckets +
132138
", rate_limit=" + rate_limit +
139+
", partition_sampling_probability=" + partition_sampling_probability +
133140
", job_id='" + job_id + '\'' +
134141
", token_scan_fetch_size=" + token_scan_fetch_size +
135142
", partition_read_fetch_size=" + partition_read_fetch_size +

spark-job/src/main/java/org/apache/cassandra/diff/Differ.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@
2727
import java.util.Iterator;
2828
import java.util.List;
2929
import java.util.Map;
30+
import java.util.Random;
3031
import java.util.UUID;
3132
import java.util.concurrent.Callable;
3233
import java.util.function.BiConsumer;
3334
import java.util.function.Function;
35+
import java.util.function.Predicate;
3436
import java.util.stream.Collectors;
3537

3638
import com.google.common.annotations.VisibleForTesting;
@@ -63,6 +65,7 @@ public class Differ implements Serializable
6365
private final double reverseReadProbability;
6466
private final SpecificTokens specificTokens;
6567
private final RetryStrategyProvider retryStrategyProvider;
68+
private final double partitionSamplingProbability;
6669

6770
private static DiffCluster srcDiffCluster;
6871
private static DiffCluster targetDiffCluster;
@@ -103,6 +106,7 @@ public Differ(JobConfiguration config,
103106
this.reverseReadProbability = config.reverseReadProbability();
104107
this.specificTokens = config.specificTokens();
105108
this.retryStrategyProvider = retryStrategyProvider;
109+
this.partitionSamplingProbability = config.partitionSamplingProbability();
106110
synchronized (Differ.class)
107111
{
108112
/*
@@ -225,12 +229,28 @@ public RangeStats diffTable(final DiffContext context,
225229
mismatchReporter,
226230
journal,
227231
COMPARISON_EXECUTOR);
228-
229-
final RangeStats tableStats = rangeComparator.compare(sourceKeys, targetKeys, partitionTaskProvider);
232+
final Predicate<PartitionKey> partitionSamplingFunction = shouldIncludePartition(jobId, partitionSamplingProbability);
233+
final RangeStats tableStats = rangeComparator.compare(sourceKeys, targetKeys, partitionTaskProvider, partitionSamplingFunction);
230234
logger.debug("Table [{}] stats - ({})", context.table.getTable(), tableStats);
231235
return tableStats;
232236
}
233237

238+
// Returns a function which decides if we should include a partition for diffing
239+
// Uses probability for sampling.
240+
@VisibleForTesting
241+
static Predicate<PartitionKey> shouldIncludePartition(final UUID jobId, final double partitionSamplingProbability) {
242+
if (partitionSamplingProbability > 1 || partitionSamplingProbability <= 0) {
243+
logger.error("Invalid partition sampling property {}, it should be between 0 and 1", partitionSamplingProbability);
244+
throw new IllegalArgumentException("Invalid partition sampling property, it should be between 0 and 1");
245+
}
246+
if (partitionSamplingProbability == 1) {
247+
return partitionKey -> true;
248+
} else {
249+
final Random random = new Random(jobId.hashCode());
250+
return partitionKey -> random.nextDouble() <= partitionSamplingProbability;
251+
}
252+
}
253+
234254
private Iterator<Row> fetchRows(DiffContext context, PartitionKey key, boolean shouldReverse, DiffCluster.Type type) {
235255
Callable<Iterator<Row>> rows = () -> type == DiffCluster.Type.SOURCE
236256
? context.source.getPartition(context.table, key, shouldReverse)

spark-job/src/main/java/org/apache/cassandra/diff/RangeComparator.java

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import java.util.function.BiConsumer;
2828
import java.util.function.Consumer;
2929
import java.util.function.Function;
30+
import java.util.function.Predicate;
3031

3132
import com.google.common.base.Verify;
3233
import org.slf4j.Logger;
@@ -57,6 +58,22 @@ public RangeComparator(DiffContext context,
5758
public RangeStats compare(Iterator<PartitionKey> sourceKeys,
5859
Iterator<PartitionKey> targetKeys,
5960
Function<PartitionKey, PartitionComparator> partitionTaskProvider) {
61+
return compare(sourceKeys,targetKeys,partitionTaskProvider, partitionKey -> true);
62+
}
63+
64+
/**
65+
* Compares partitions in src and target clusters.
66+
*
67+
* @param sourceKeys partition keys in the source cluster
68+
* @param targetKeys partition keys in the target cluster
69+
* @param partitionTaskProvider comparison task
70+
* @param partitionSampler samples partitions based on the probability for probabilistic diff
71+
* @return stats about the diff
72+
*/
73+
public RangeStats compare(Iterator<PartitionKey> sourceKeys,
74+
Iterator<PartitionKey> targetKeys,
75+
Function<PartitionKey, PartitionComparator> partitionTaskProvider,
76+
Predicate<PartitionKey> partitionSampler) {
6077

6178
final RangeStats rangeStats = RangeStats.newStats();
6279
// We can catch this condition earlier, but it doesn't hurt to also check here
@@ -115,11 +132,16 @@ public RangeStats compare(Iterator<PartitionKey> sourceKeys,
115132

116133
BigInteger token = sourceKey.getTokenAsBigInteger();
117134
try {
118-
PartitionComparator comparisonTask = partitionTaskProvider.apply(sourceKey);
119-
comparisonExecutor.submit(comparisonTask,
120-
onSuccess(rangeStats, partitionCount, token, highestTokenSeen, mismatchReporter, journal),
121-
onError(rangeStats, token, errorReporter),
122-
phaser);
135+
// Use the partitionSampler predicate for sampling partitions: skip the partition
136+
// if the sampler returns false, otherwise run the diff on that partition
137+
if (partitionSampler.test(sourceKey)) {
138+
PartitionComparator comparisonTask = partitionTaskProvider.apply(sourceKey);
139+
comparisonExecutor.submit(comparisonTask,
140+
onSuccess(rangeStats, partitionCount, token, highestTokenSeen, mismatchReporter, journal),
141+
onError(rangeStats, token, errorReporter),
142+
phaser);
143+
}
144+
123145
} catch (Throwable t) {
124146
// Handle errors thrown when creating the comparison task. This should trap timeouts and
125147
// unavailables occurring when performing the initial query to read the full partition.

spark-job/src/test/java/org/apache/cassandra/diff/DiffJobTest.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,5 +108,10 @@ public int buckets() {
108108
public Optional<UUID> jobId() {
109109
return Optional.of(UUID.randomUUID());
110110
}
111+
112+
// This mock always diffs every partition; sampling behaviour is
// exercised separately in DifferTest.
@Override
public double partitionSamplingProbability() {
    return 1;
}
111116
}
112117
}

spark-job/src/test/java/org/apache/cassandra/diff/DifferTest.java

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,65 @@
2121

2222
import java.math.BigInteger;
2323
import java.util.Map;
24+
import java.util.UUID;
2425
import java.util.function.Function;
26+
import java.util.function.Predicate;
2527

2628
import com.google.common.base.VerifyException;
2729
import com.google.common.collect.Lists;
30+
import org.junit.Rule;
2831
import org.junit.Test;
32+
import org.junit.rules.ExpectedException;
2933

3034
import static org.junit.Assert.assertEquals;
3135
import static org.junit.Assert.assertNull;
36+
import static org.junit.Assert.assertTrue;
3237

3338
public class DifferTest {
39+
@Rule
40+
public ExpectedException expectedException = ExpectedException.none();
41+
42+
// With probability 1 the sampler must include every partition.
@Test
public void testIncludeAllPartitions() {
    final PartitionKey testKey = new RangeComparatorTest.TestPartitionKey(0);
    final UUID uuid = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
    assertTrue(Differ.shouldIncludePartition(uuid, 1).test(testKey));
}
48+
49+
// A probability outside (0, 1] must be rejected with IllegalArgumentException.
@Test
public void shouldIncludePartitionWithProbabilityInvalidProbability() {
    final PartitionKey testKey = new RangeComparatorTest.TestPartitionKey(0);
    final UUID uuid = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
    expectedException.expect(IllegalArgumentException.class);
    expectedException.expectMessage("Invalid partition sampling property, it should be between 0 and 1");
    Differ.shouldIncludePartition(uuid, -1).test(testKey);
}
57+
58+
// With probability 0.5 roughly half of 20 trials should be included.
// The seed is fixed by the job id, so the count is deterministic; the
// loose 5..15 bounds simply guard against a grossly skewed sampler.
@Test
public void shouldIncludePartitionWithProbabilityHalf() {
    final PartitionKey testKey = new RangeComparatorTest.TestPartitionKey(0);
    int count = 0;
    final UUID uuid = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
    final Predicate<PartitionKey> partitionSampler = Differ.shouldIncludePartition(uuid, 0.5);
    for (int i = 0; i < 20; i++) {
        if (partitionSampler.test(testKey)) {
            count++;
        }
    }
    assertTrue(count <= 15);
    assertTrue(count >= 5);
}
72+
73+
// Two samplers built from the same job id must produce identical
// include/skip decisions, since the Random is seeded with the job id.
@Test
public void shouldIncludePartitionShouldGenerateSameSequenceForGivenJobId() {
    final UUID uuid = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
    final PartitionKey testKey = new RangeComparatorTest.TestPartitionKey(0);
    final Predicate<PartitionKey> partitionSampler1 = Differ.shouldIncludePartition(uuid, 0.5);
    final Predicate<PartitionKey> partitionSampler2 = Differ.shouldIncludePartition(uuid, 0.5);
    for (int i = 0; i < 10; i++) {
        assertEquals(partitionSampler2.test(testKey), partitionSampler1.test(testKey));
    }
}
3483

3584
@Test(expected = VerifyException.class)
3685
public void rejectNullStartOfRange() {

spark-job/src/test/java/org/apache/cassandra/diff/RangeComparatorTest.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,38 @@ public class RangeComparatorTest {
5656
private ComparisonExecutor executor = ComparisonExecutor.newExecutor(1, new MetricRegistry());
5757
private RetryStrategyProvider mockRetryStrategyFactory = RetryStrategyProvider.create(null); // create a NoRetry provider
5858

59+
// Three-argument compare (no sampler) must behave as before: every common
// key (0..5) is compared, with one key only in source (6) and one only in
// target (7) reported as mismatches.
@Test
public void probabilisticDiffIncludeAllPartitions() {
    RangeComparator comparator = comparator(context(0L, 100L));
    RangeStats stats = comparator.compare(keys(0, 1, 2, 3, 4, 5, 6), keys(0,1, 2, 3, 4, 5, 7), this::alwaysMatch);
    assertFalse(stats.isEmpty());
    assertEquals(1, stats.getOnlyInSource());
    assertEquals(1, stats.getOnlyInTarget());
    assertEquals(6, stats.getMatchedPartitions());
    assertReported(6, MismatchType.ONLY_IN_SOURCE, mismatches);
    assertReported(7, MismatchType.ONLY_IN_TARGET, mismatches);
    assertNothingReported(errors, journal);
    assertCompared(0, 1, 2, 3, 4, 5);
}
72+
73+
// A deterministic sampler (even tokens only) must restrict the actual
// comparisons to keys 0, 2, 4 while the only-in-source/only-in-target
// bookkeeping still sees all keys, since sampling gates only the
// comparison tasks, not the key classification.
@Test
public void probabilisticDiffProbabilityHalf() {
    RangeComparator comparator = comparator(context(0L, 100L));
    RangeStats stats = comparator.compare(keys(0, 1, 2, 3, 4, 5, 6),
                                          keys(0, 1, 2, 3, 4, 5, 7),
                                          this::alwaysMatch,
                                          key -> key.getTokenAsBigInteger().intValue() % 2 == 0);
    assertFalse(stats.isEmpty());
    assertEquals(1, stats.getOnlyInSource());
    assertEquals(1, stats.getOnlyInTarget());
    assertEquals(3, stats.getMatchedPartitions());
    assertReported(6, MismatchType.ONLY_IN_SOURCE, mismatches);
    assertReported(7, MismatchType.ONLY_IN_TARGET, mismatches);
    assertNothingReported(errors, journal);
    assertCompared(0, 2, 4);
}
89+
90+
5991
@Test
6092
public void emptyRange() {
6193
RangeComparator comparator = comparator(context(100L, 100L));

spark-job/src/test/java/org/apache/cassandra/diff/SchemaTest.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ private class MockConfig extends AbstractMockJobConfiguration {
2929
public List<String> disallowedKeyspaces() {
3030
return disallowedKeyspaces;
3131
}
32+
33+
// This mock always diffs every partition; sampling behaviour is
// exercised separately in DifferTest.
@Override
public double partitionSamplingProbability() {
    return 1;
}
3237
}
3338

3439
@Test

0 commit comments

Comments
 (0)