Skip to content

Commit a4af991

Browse files
committed
rewrite OnSubscribeRefCount to handle synchronous source, add null check to OperatorMulticast
1 parent e63a4cb commit a4af991

File tree

3 files changed

+125
-169
lines changed

3 files changed

+125
-169
lines changed

rxjava-core/src/main/java/rx/internal/operators/OnSubscribeRefCount.java

Lines changed: 73 additions & 168 deletions
Original file line numberDiff line numberDiff line change
@@ -15,195 +15,100 @@
1515
*/
1616
package rx.internal.operators;
1717

18-
import java.util.ArrayList;
19-
import java.util.List;
20-
import java.util.Map;
21-
import java.util.WeakHashMap;
18+
import java.util.concurrent.atomic.AtomicInteger;
19+
import java.util.concurrent.locks.ReadWriteLock;
20+
import java.util.concurrent.locks.ReentrantReadWriteLock;
21+
2222
import rx.Observable.OnSubscribe;
2323
import rx.Subscriber;
2424
import rx.Subscription;
2525
import rx.functions.Action0;
26+
import rx.functions.Action1;
2627
import rx.observables.ConnectableObservable;
2728
import rx.subscriptions.Subscriptions;
2829

2930
/**
30-
* Returns an observable sequence that stays connected to the source as long
31-
* as there is at least one subscription to the observable sequence.
32-
* @param <T> the value type
31+
* Returns an observable sequence that stays connected to the source as long as
32+
* there is at least one subscription to the observable sequence.
33+
*
34+
* @param <T>
35+
* the value type
3336
*/
3437
public final class OnSubscribeRefCount<T> implements OnSubscribe<T> {
35-
final ConnectableObservable<? extends T> source;
36-
final Object guard;
37-
/** Guarded by guard. */
38-
int index;
39-
/** Guarded by guard. */
40-
boolean emitting;
41-
/** Guarded by guard. If true, indicates a connection request, false indicates a disconnect request. */
42-
List<Token> queue;
43-
/** Manipulated while in the serialized section. */
44-
int count;
45-
/** Manipulated while in the serialized section. */
46-
Subscription connection;
47-
/** Manipulated while in the serialized section. */
48-
final Map<Token, Object> connectionStatus;
49-
/** Occupied indicator. */
50-
private static final Object OCCUPIED = new Object();
38+
39+
private ConnectableObservable<? extends T> source;
40+
private volatile Subscription baseSubscription;
41+
private AtomicInteger subscriptionCount = new AtomicInteger(0);
42+
43+
/**
44+
* Ensures that subscribers wait for the first subscription to be assigned
45+
* to baseSubcription before being subscribed themselves.
46+
*/
47+
private final ReadWriteLock lock = new ReentrantReadWriteLock();
48+
49+
/**
50+
* Constructor.
51+
*
52+
* @param source
53+
* observable to apply ref count to
54+
*/
5155
public OnSubscribeRefCount(ConnectableObservable<? extends T> source) {
5256
this.source = source;
53-
this.guard = new Object();
54-
this.connectionStatus = new WeakHashMap<Token, Object>();
5557
}
5658

5759
@Override
58-
public void call(Subscriber<? super T> t1) {
59-
int id;
60-
synchronized (guard) {
61-
id = ++index;
62-
}
63-
final Token t = new Token(id);
64-
t1.add(Subscriptions.create(new Action0() {
65-
@Override
66-
public void call() {
67-
disconnect(t);
68-
}
69-
}));
70-
source.unsafeSubscribe(t1);
71-
connect(t);
72-
}
73-
private void connect(Token id) {
74-
List<Token> localQueue;
75-
synchronized (guard) {
76-
if (emitting) {
77-
if (queue == null) {
78-
queue = new ArrayList<Token>();
60+
public void call(final Subscriber<? super T> subscriber) {
61+
62+
// ensure secondary subscriptions wait for baseSubscription to be set by
63+
// first subscription
64+
lock.writeLock().lock();
65+
66+
if (subscriptionCount.incrementAndGet() == 1) {
67+
// need to use this overload of connect to ensure that
68+
// baseSubscription is set in the case that source is a synchronous
69+
// Observable
70+
source.connect(new Action1<Subscription>() {
71+
@Override
72+
public void call(Subscription subscription) {
73+
baseSubscription = subscription;
74+
75+
// handle unsubscribing from the base subscription
76+
subscriber.add(disconnect());
77+
78+
// ready to subscribe to source so do it
79+
source.unsafeSubscribe(subscriber);
80+
81+
// release the write lock
82+
lock.writeLock().unlock();
7983
}
80-
queue.add(id);
81-
return;
82-
}
84+
});
85+
} else {
86+
// release the write lock
87+
lock.writeLock().unlock();
88+
89+
// wait till baseSubscription set
90+
lock.readLock().lock();
91+
92+
// handle unsubscribing from the base subscription
93+
subscriber.add(disconnect());
8394

84-
localQueue = queue;
85-
queue = null;
86-
emitting = true;
87-
}
88-
boolean once = true;
89-
do {
90-
drain(localQueue);
91-
if (once) {
92-
once = false;
93-
doConnect(id);
94-
}
95-
synchronized (guard) {
96-
localQueue = queue;
97-
queue = null;
98-
if (localQueue == null) {
99-
emitting = false;
100-
return;
101-
}
102-
}
103-
} while (true);
104-
}
105-
private void disconnect(Token id) {
106-
List<Token> localQueue;
107-
synchronized (guard) {
108-
if (emitting) {
109-
if (queue == null) {
110-
queue = new ArrayList<Token>();
111-
}
112-
queue.add(id.toDisconnect()); // negative value indicates disconnect
113-
return;
114-
}
95+
// ready to subscribe to source so do it
96+
source.unsafeSubscribe(subscriber);
11597

116-
localQueue = queue;
117-
queue = null;
118-
emitting = true;
119-
}
120-
boolean once = true;
121-
do {
122-
drain(localQueue);
123-
if (once) {
124-
once = false;
125-
doDisconnect(id);
126-
}
127-
synchronized (guard) {
128-
localQueue = queue;
129-
queue = null;
130-
if (localQueue == null) {
131-
emitting = false;
132-
return;
133-
}
134-
}
135-
} while (true);
136-
}
137-
private void drain(List<Token> localQueue) {
138-
if (localQueue == null) {
139-
return;
140-
}
141-
int n = localQueue.size();
142-
for (int i = 0; i < n; i++) {
143-
Token id = localQueue.get(i);
144-
if (id.isDisconnect()) {
145-
doDisconnect(id);
146-
} else {
147-
doConnect(id);
148-
}
149-
}
150-
}
151-
private void doConnect(Token id) {
152-
// this method is called only once per id
153-
// if add succeeds, id was not yet disconnected
154-
if (connectionStatus.put(id, OCCUPIED) == null) {
155-
if (count++ == 0) {
156-
connection = source.connect();
157-
}
158-
} else {
159-
// connection exists due to disconnect, just remove
160-
connectionStatus.remove(id);
161-
}
162-
}
163-
private void doDisconnect(Token id) {
164-
// this method is called only once per id
165-
// if remove succeeds, id was connected
166-
if (connectionStatus.remove(id) != null) {
167-
if (--count == 0) {
168-
connection.unsubscribe();
169-
connection = null;
170-
}
171-
} else {
172-
// mark id as if connected
173-
connectionStatus.put(id, OCCUPIED);
174-
}
175-
}
176-
/** Token that represens a connection request or a disconnection request. */
177-
private static final class Token {
178-
final int id;
179-
public Token(int id) {
180-
this.id = id;
98+
//release the read lock
99+
lock.readLock().unlock();
181100
}
182101

183-
@Override
184-
public boolean equals(Object obj) {
185-
if (obj == null) {
186-
return false;
187-
}
188-
if (obj.getClass() != getClass()) {
189-
return false;
190-
}
191-
int other = ((Token)obj).id;
192-
return id == other || -id == other;
193-
}
102+
}
194103

195-
@Override
196-
public int hashCode() {
197-
return id < 0 ? -id : id;
198-
}
199-
public boolean isDisconnect() {
200-
return id < 0;
201-
}
202-
public Token toDisconnect() {
203-
if (id < 0) {
204-
return this;
104+
private Subscription disconnect() {
105+
return Subscriptions.create(new Action0() {
106+
@Override
107+
public void call() {
108+
if (subscriptionCount.decrementAndGet() == 0) {
109+
baseSubscription.unsubscribe();
110+
}
205111
}
206-
return new Token(-id);
207-
}
112+
});
208113
}
209114
}

rxjava-core/src/main/java/rx/internal/operators/OperatorMulticast.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ public void call() {
137137
}));
138138

139139
// now that everything is hooked up let's subscribe
140-
source.unsafeSubscribe(subscription);
140+
if (subscription!=null)
141+
source.unsafeSubscribe(subscription);
141142
}
142143
}
143144
}

rxjava-core/src/test/java/rx/RefCountTests.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package rx;
1717

1818
import static org.junit.Assert.assertEquals;
19+
import static org.junit.Assert.assertTrue;
1920
import static org.mockito.Matchers.any;
2021
import static org.mockito.Mockito.inOrder;
2122
import static org.mockito.Mockito.mock;
@@ -25,6 +26,7 @@
2526
import java.util.ArrayList;
2627
import java.util.Arrays;
2728
import java.util.List;
29+
import java.util.concurrent.CountDownLatch;
2830
import java.util.concurrent.TimeUnit;
2931
import java.util.concurrent.atomic.AtomicInteger;
3032

@@ -34,11 +36,14 @@
3436
import org.mockito.MockitoAnnotations;
3537

3638
import rx.Observable.OnSubscribe;
39+
import rx.Observable.Operator;
3740
import rx.functions.Action0;
3841
import rx.functions.Action1;
3942
import rx.functions.Func2;
43+
import rx.observables.ConnectableObservable;
4044
import rx.observers.Subscribers;
4145
import rx.observers.TestSubscriber;
46+
import rx.schedulers.Schedulers;
4247
import rx.schedulers.TestScheduler;
4348
import rx.subjects.ReplaySubject;
4449
import rx.subscriptions.Subscriptions;
@@ -237,4 +242,49 @@ public Integer call(Integer t1, Integer t2) {
237242
ts2.assertNoErrors();
238243
ts2.assertReceivedOnNext(Arrays.asList(30));
239244
}
245+
246+
@Test
247+
public void testRefCountUnsubscribeForSynchronousSource() throws InterruptedException {
248+
final CountDownLatch latch = new CountDownLatch(1);
249+
Observable<Long> o = synchronousInterval().lift(detectUnsubscription(latch));
250+
Subscriber<Long> sub = Subscribers.empty();
251+
o.publish().refCount().subscribeOn(Schedulers.computation()).subscribe(sub);
252+
Thread.sleep(100);
253+
sub.unsubscribe();
254+
assertTrue(latch.await(3, TimeUnit.SECONDS));
255+
}
256+
257+
@Test
258+
public void testSubscribeToPublishWithAlreadyUnsubscribedSubscriber() {
259+
Subscriber<Object> sub = Subscribers.empty();
260+
sub.unsubscribe();
261+
ConnectableObservable<Object> o = Observable.empty().publish();
262+
o.subscribe(sub);
263+
o.connect();
264+
}
265+
266+
private Operator<Long, Long> detectUnsubscription(final CountDownLatch latch) {
267+
return new Operator<Long,Long>(){
268+
@Override
269+
public Subscriber<? super Long> call(Subscriber<? super Long> subscriber) {
270+
latch.countDown();
271+
return Subscribers.from(subscriber);
272+
}};
273+
}
274+
275+
private Observable<Long> synchronousInterval() {
276+
return Observable.create(new OnSubscribe<Long>() {
277+
278+
@Override
279+
public void call(Subscriber<? super Long> subscriber) {
280+
while (!subscriber.isUnsubscribed()) {
281+
try {
282+
Thread.sleep(100);
283+
} catch (InterruptedException e) {
284+
}
285+
subscriber.onNext(1L);
286+
}
287+
}});
288+
}
289+
240290
}

0 commit comments

Comments
 (0)