Skip to content

Commit de1b481

Browse files
Merge pull request #1573 from benjchristensen/parallel-backpressure
Backpressure: parallel
2 parents e128ffe + 71bb31e commit de1b481

File tree

2 files changed

+79
-67
lines changed

2 files changed

+79
-67
lines changed

rxjava-core/src/main/java/rx/internal/operators/OperatorParallel.java

Lines changed: 72 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,16 @@
1515
*/
1616
package rx.internal.operators;
1717

18-
import java.util.concurrent.atomic.AtomicReference;
18+
import java.util.concurrent.atomic.AtomicBoolean;
19+
import java.util.concurrent.atomic.AtomicLong;
1920

2021
import rx.Observable;
22+
import rx.Observable.OnSubscribe;
2123
import rx.Observable.Operator;
24+
import rx.Producer;
2225
import rx.Scheduler;
2326
import rx.Subscriber;
2427
import rx.functions.Func1;
25-
import rx.subjects.Subject;
2628

2729
/**
2830
* Identifies unit of work that can be executed in parallel on a given Scheduler.
@@ -41,85 +43,99 @@ public OperatorParallel(Func1<Observable<T>, Observable<R>> f, Scheduler schedul
4143

4244
@Override
4345
public Subscriber<? super T> call(final Subscriber<? super R> child) {
44-
45-
@SuppressWarnings("unchecked")
46-
final UnicastPassThruSubject<T>[] subjects = new UnicastPassThruSubject[degreeOfParallelism];
4746
@SuppressWarnings("unchecked")
4847
final Observable<R>[] os = new Observable[degreeOfParallelism];
49-
for (int i = 0; i < subjects.length; i++) {
50-
subjects[i] = UnicastPassThruSubject.<T> create();
51-
os[i] = f.call(subjects[i].observeOn(scheduler));
52-
}
53-
54-
// subscribe BEFORE receiving data so everything is hooked up
55-
Observable.merge(os).unsafeSubscribe(child);
56-
57-
return new Subscriber<T>(child) {
58-
59-
int index = 0; // trust that we receive data synchronously
60-
61-
@Override
62-
public void onCompleted() {
63-
for (UnicastPassThruSubject<T> s : subjects) {
64-
s.onCompleted();
65-
}
66-
}
67-
68-
@Override
69-
public void onError(Throwable e) {
70-
// bypass the subjects and immediately terminate
71-
child.onError(e);
72-
}
48+
@SuppressWarnings("unchecked")
49+
final Subscriber<? super T>[] ss = new Subscriber[degreeOfParallelism];
50+
final ParentSubscriber subscriber = new ParentSubscriber(child, ss);
51+
for (int i = 0; i < os.length; i++) {
52+
final int index = i;
53+
Observable<T> o = Observable.create(new OnSubscribe<T>() {
7354

74-
@Override
75-
public void onNext(T t) {
76-
// round-robin subjects
77-
subjects[index++].onNext(t);
78-
if (index >= degreeOfParallelism) {
79-
index = 0;
55+
@Override
56+
public void call(Subscriber<? super T> inner) {
57+
ss[index] = inner;
58+
child.add(inner); // unsubscribe chain
59+
inner.setProducer(new Producer() {
60+
61+
@Override
62+
public void request(long n) {
63+
// as we receive requests from observeOn propagate upstream to the parent Subscriber
64+
subscriber.requestMore(n);
65+
}
66+
67+
});
8068
}
81-
}
8269

83-
};
70+
});
71+
os[i] = f.call(o.observeOn(scheduler));
72+
}
8473

74+
// subscribe BEFORE receiving data so everything is hooked up
75+
Observable.merge(os).unsafeSubscribe(child);
76+
return subscriber;
8577
}
8678

87-
private static class UnicastPassThruSubject<T> extends Subject<T, T> {
88-
89-
private static <T> UnicastPassThruSubject<T> create() {
90-
final AtomicReference<Subscriber<? super T>> subscriber = new AtomicReference<Subscriber<? super T>>();
91-
return new UnicastPassThruSubject<T>(subscriber, new OnSubscribe<T>() {
79+
private class ParentSubscriber extends Subscriber<T> {
9280

93-
@Override
94-
public void call(Subscriber<? super T> s) {
95-
subscriber.set(s);
96-
}
97-
98-
});
81+
final Subscriber<? super R> child;
82+
final Subscriber<? super T>[] ss;
83+
int index = 0;
84+
final AtomicLong initialRequest = new AtomicLong();
85+
final AtomicBoolean started = new AtomicBoolean();
9986

87+
private ParentSubscriber(Subscriber<? super R> child, Subscriber<? super T>[] ss) {
88+
super(child);
89+
this.child = child;
90+
this.ss = ss;
10091
}
10192

102-
private final AtomicReference<Subscriber<? super T>> subscriber;
93+
public void requestMore(long n) {
94+
if (started.get()) {
95+
request(n);
96+
} else {
97+
initialRequest.addAndGet(n);
98+
}
99+
}
103100

104-
protected UnicastPassThruSubject(AtomicReference<Subscriber<? super T>> subscriber, OnSubscribe<T> onSubscribe) {
105-
super(onSubscribe);
106-
this.subscriber = subscriber;
101+
@Override
102+
public void onStart() {
103+
if (started.compareAndSet(false, true)) {
104+
// if no request via requestMore has been sent yet, we start with 0 (rather than default Long.MAX_VALUE).
105+
request(initialRequest.get());
106+
}
107107
}
108108

109109
@Override
110110
public void onCompleted() {
111-
subscriber.get().onCompleted();
111+
for (Subscriber<? super T> s : ss) {
112+
s.onCompleted();
113+
}
112114
}
113115

114116
@Override
115117
public void onError(Throwable e) {
116-
subscriber.get().onError(e);
118+
child.onError(e);
117119
}
118120

119121
@Override
120122
public void onNext(T t) {
121-
subscriber.get().onNext(t);
123+
/*
124+
* There is a possible bug here ... we could get a MissingBackpressureException
125+
* if the processing on each of the threads is unbalanced. In other words, if 1 of the
126+
* observeOn queues has space, but another is full, this could try emitting to one that
127+
* is full and get a MissingBackpressureException.
128+
*
129+
* To solve that we'd need to check the outstanding request per Subscriber, which will
130+
* need a more complicated mechanism to expose a type that has both the requested + the
131+
* Subscriber to emit to.
132+
*/
133+
ss[index++].onNext(t);
134+
if (index >= degreeOfParallelism) {
135+
index = 0;
136+
}
122137
}
123138

124-
}
139+
};
140+
125141
}

rxjava-core/src/test/java/rx/internal/operators/OperatorParallelTest.java

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import java.util.concurrent.TimeUnit;
2222
import java.util.concurrent.atomic.AtomicInteger;
2323

24-
import org.junit.Ignore;
2524
import org.junit.Test;
2625

2726
import rx.Observable;
@@ -57,7 +56,7 @@ public Integer[] call(Integer t) {
5756
// TODO why is this exception not being thrown?
5857
throw new RuntimeException(e);
5958
}
60-
// System.out.println("V: " + t + " Thread: " + Thread.currentThread());
59+
// System.out.println("V: " + t + " Thread: " + Thread.currentThread());
6160
innerCount.incrementAndGet();
6261
return new Integer[] { t, t * 99 };
6362
}
@@ -112,8 +111,6 @@ public void call(String v) {
112111
assertEquals(NUM, count.get());
113112
}
114113

115-
// parallel does not support backpressure right now
116-
@Ignore
117114
@Test
118115
public void testBackpressureViaOuterObserveOn() {
119116
final AtomicInteger emitted = new AtomicInteger();
@@ -148,12 +145,10 @@ public String call(Integer t) {
148145
ts.awaitTerminalEvent();
149146
ts.assertNoErrors();
150147
System.out.println("testBackpressureViaObserveOn emitted => " + emitted.get());
151-
assertTrue(emitted.get() < 2000 + RxRingBuffer.SIZE); // should have no more than the buffer size beyond the 2000 in take
152-
assertEquals(2000, ts.getOnNextEvents().size());
148+
assertTrue(emitted.get() < 20000 + (RxRingBuffer.SIZE * Schedulers.computation().parallelism())); // should have no more than the buffer size beyond the 20000 in take
149+
assertEquals(20000, ts.getOnNextEvents().size());
153150
}
154151

155-
// parallel does not support backpressure right now
156-
@Ignore
157152
@Test
158153
public void testBackpressureOnInnerObserveOn() {
159154
final AtomicInteger emitted = new AtomicInteger();
@@ -188,8 +183,8 @@ public String call(Integer t) {
188183
ts.awaitTerminalEvent();
189184
ts.assertNoErrors();
190185
System.out.println("testBackpressureViaObserveOn emitted => " + emitted.get());
191-
assertTrue(emitted.get() < 20000 + RxRingBuffer.SIZE); // should have no more than the buffer size beyond the 2000 in take
192-
assertEquals(2000, ts.getOnNextEvents().size());
186+
assertTrue(emitted.get() < 20000 + (RxRingBuffer.SIZE * Schedulers.computation().parallelism())); // should have no more than the buffer size beyond the 20000 in take
187+
assertEquals(20000, ts.getOnNextEvents().size());
193188
}
194189

195190
@Test(timeout = 10000)
@@ -226,7 +221,8 @@ public String call(Integer t) {
226221
ts.awaitTerminalEvent();
227222
ts.assertNoErrors();
228223
System.out.println("emitted: " + emitted.get());
229-
assertEquals(2000, emitted.get()); // no async, so should be perfect
224+
// we allow buffering inside each parallel Observable
225+
assertEquals(RxRingBuffer.SIZE * Schedulers.computation().parallelism(), emitted.get()); // no async, so should be perfect
230226
assertEquals(2000, ts.getOnNextEvents().size());
231227
}
232228
}

0 commit comments

Comments
 (0)