Skip to content

Commit 8c2986d

Browse files
Merge pull request #1695 from davidmoten/refcount-1688
rewrite OnSubscribeRefCount to handle synchronous source
2 parents ab85374 + b8da4a9 commit 8c2986d

File tree

3 files changed

+150
-166
lines changed

3 files changed

+150
-166
lines changed

rxjava-core/src/main/java/rx/internal/operators/OnSubscribeRefCount.java

Lines changed: 93 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -15,195 +15,123 @@
1515
*/
1616
package rx.internal.operators;
1717

18-
import java.util.ArrayList;
19-
import java.util.List;
20-
import java.util.Map;
21-
import java.util.WeakHashMap;
18+
import java.util.concurrent.atomic.AtomicBoolean;
19+
import java.util.concurrent.atomic.AtomicInteger;
20+
import java.util.concurrent.locks.ReentrantLock;
21+
2222
import rx.Observable.OnSubscribe;
2323
import rx.Subscriber;
2424
import rx.Subscription;
2525
import rx.functions.Action0;
26+
import rx.functions.Action1;
2627
import rx.observables.ConnectableObservable;
28+
import rx.subscriptions.CompositeSubscription;
2729
import rx.subscriptions.Subscriptions;
2830

2931
/**
30-
* Returns an observable sequence that stays connected to the source as long
31-
* as there is at least one subscription to the observable sequence.
32-
* @param <T> the value type
32+
* Returns an observable sequence that stays connected to the source as long as
33+
* there is at least one subscription to the observable sequence.
34+
*
35+
* @param <T>
36+
* the value type
3337
*/
3438
public final class OnSubscribeRefCount<T> implements OnSubscribe<T> {
35-
final ConnectableObservable<? extends T> source;
36-
final Object guard;
37-
/** Guarded by guard. */
38-
int index;
39-
/** Guarded by guard. */
40-
boolean emitting;
41-
/** Guarded by guard. If true, indicates a connection request, false indicates a disconnect request. */
42-
List<Token> queue;
43-
/** Manipulated while in the serialized section. */
44-
int count;
45-
/** Manipulated while in the serialized section. */
46-
Subscription connection;
47-
/** Manipulated while in the serialized section. */
48-
final Map<Token, Object> connectionStatus;
49-
/** Occupied indicator. */
50-
private static final Object OCCUPIED = new Object();
39+
40+
private final ConnectableObservable<? extends T> source;
41+
private volatile CompositeSubscription baseSubscription = new CompositeSubscription();
42+
private final AtomicInteger subscriptionCount = new AtomicInteger(0);
43+
44+
/**
45+
* Use this lock for every subscription and disconnect action.
46+
*/
47+
private final ReentrantLock lock = new ReentrantLock();
48+
49+
/**
50+
* Constructor.
51+
*
52+
* @param source
53+
* observable to apply ref count to
54+
*/
5155
public OnSubscribeRefCount(ConnectableObservable<? extends T> source) {
5256
this.source = source;
53-
this.guard = new Object();
54-
this.connectionStatus = new WeakHashMap<Token, Object>();
5557
}
5658

5759
@Override
58-
public void call(Subscriber<? super T> t1) {
59-
int id;
60-
synchronized (guard) {
61-
id = ++index;
62-
}
63-
final Token t = new Token(id);
64-
t1.add(Subscriptions.create(new Action0() {
65-
@Override
66-
public void call() {
67-
disconnect(t);
68-
}
69-
}));
70-
source.unsafeSubscribe(t1);
71-
connect(t);
72-
}
73-
private void connect(Token id) {
74-
List<Token> localQueue;
75-
synchronized (guard) {
76-
if (emitting) {
77-
if (queue == null) {
78-
queue = new ArrayList<Token>();
79-
}
80-
queue.add(id);
81-
return;
82-
}
83-
84-
localQueue = queue;
85-
queue = null;
86-
emitting = true;
87-
}
88-
boolean once = true;
89-
do {
90-
drain(localQueue);
91-
if (once) {
92-
once = false;
93-
doConnect(id);
94-
}
95-
synchronized (guard) {
96-
localQueue = queue;
97-
queue = null;
98-
if (localQueue == null) {
99-
emitting = false;
100-
return;
101-
}
102-
}
103-
} while (true);
104-
}
105-
private void disconnect(Token id) {
106-
List<Token> localQueue;
107-
synchronized (guard) {
108-
if (emitting) {
109-
if (queue == null) {
110-
queue = new ArrayList<Token>();
111-
}
112-
queue.add(id.toDisconnect()); // negative value indicates disconnect
113-
return;
114-
}
115-
116-
localQueue = queue;
117-
queue = null;
118-
emitting = true;
119-
}
120-
boolean once = true;
121-
do {
122-
drain(localQueue);
123-
if (once) {
124-
once = false;
125-
doDisconnect(id);
126-
}
127-
synchronized (guard) {
128-
localQueue = queue;
129-
queue = null;
130-
if (localQueue == null) {
131-
emitting = false;
132-
return;
60+
public void call(final Subscriber<? super T> subscriber) {
61+
62+
lock.lock();
63+
if (subscriptionCount.incrementAndGet() == 1) {
64+
65+
final AtomicBoolean writeLocked = new AtomicBoolean(true);
66+
67+
try {
68+
// need to use this overload of connect to ensure that
69+
// baseSubscription is set in the case that source is a
70+
// synchronous Observable
71+
source.connect(onSubscribe(subscriber, writeLocked));
72+
} finally {
73+
// need to cover the case where the source is subscribed to
74+
// outside of this class thus preventing the above Action1
75+
// being called
76+
if (writeLocked.get()) {
77+
// Action1 was not called
78+
lock.unlock();
13379
}
13480
}
135-
} while (true);
136-
}
137-
private void drain(List<Token> localQueue) {
138-
if (localQueue == null) {
139-
return;
140-
}
141-
int n = localQueue.size();
142-
for (int i = 0; i < n; i++) {
143-
Token id = localQueue.get(i);
144-
if (id.isDisconnect()) {
145-
doDisconnect(id);
146-
} else {
147-
doConnect(id);
148-
}
149-
}
150-
}
151-
private void doConnect(Token id) {
152-
// this method is called only once per id
153-
// if add succeeds, id was not yet disconnected
154-
if (connectionStatus.put(id, OCCUPIED) == null) {
155-
if (count++ == 0) {
156-
connection = source.connect();
157-
}
15881
} else {
159-
// connection exists due to disconnect, just remove
160-
connectionStatus.remove(id);
161-
}
162-
}
163-
private void doDisconnect(Token id) {
164-
// this method is called only once per id
165-
// if remove succeeds, id was connected
166-
if (connectionStatus.remove(id) != null) {
167-
if (--count == 0) {
168-
connection.unsubscribe();
169-
connection = null;
82+
try {
83+
// handle unsubscribing from the base subscription
84+
subscriber.add(disconnect());
85+
86+
// ready to subscribe to source so do it
87+
source.unsafeSubscribe(subscriber);
88+
} finally {
89+
// release the read lock
90+
lock.unlock();
17091
}
171-
} else {
172-
// mark id as if connected
173-
connectionStatus.put(id, OCCUPIED);
17492
}
93+
17594
}
176-
/** Token that represens a connection request or a disconnection request. */
177-
private static final class Token {
178-
final int id;
179-
public Token(int id) {
180-
this.id = id;
181-
}
18295

183-
@Override
184-
public boolean equals(Object obj) {
185-
if (obj == null) {
186-
return false;
187-
}
188-
if (obj.getClass() != getClass()) {
189-
return false;
96+
private Action1<Subscription> onSubscribe(final Subscriber<? super T> subscriber,
97+
final AtomicBoolean writeLocked) {
98+
return new Action1<Subscription>() {
99+
@Override
100+
public void call(Subscription subscription) {
101+
102+
try {
103+
baseSubscription.add(subscription);
104+
105+
// handle unsubscribing from the base subscription
106+
subscriber.add(disconnect());
107+
108+
// ready to subscribe to source so do it
109+
source.unsafeSubscribe(subscriber);
110+
} finally {
111+
// release the write lock
112+
lock.unlock();
113+
writeLocked.set(false);
114+
}
190115
}
191-
int other = ((Token)obj).id;
192-
return id == other || -id == other;
193-
}
116+
};
117+
}
194118

195-
@Override
196-
public int hashCode() {
197-
return id < 0 ? -id : id;
198-
}
199-
public boolean isDisconnect() {
200-
return id < 0;
201-
}
202-
public Token toDisconnect() {
203-
if (id < 0) {
204-
return this;
119+
private Subscription disconnect() {
120+
return Subscriptions.create(new Action0() {
121+
@Override
122+
public void call() {
123+
lock.lock();
124+
try {
125+
if (subscriptionCount.decrementAndGet() == 0) {
126+
baseSubscription.unsubscribe();
127+
// need a new baseSubscription because once
128+
// unsubscribed stays that way
129+
baseSubscription = new CompositeSubscription();
130+
}
131+
} finally {
132+
lock.unlock();
133+
}
205134
}
206-
return new Token(-id);
207-
}
135+
});
208136
}
209137
}

rxjava-core/src/main/java/rx/internal/operators/OperatorMulticast.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,13 @@ public void call() {
137137
}));
138138

139139
// now that everything is hooked up let's subscribe
140-
source.unsafeSubscribe(subscription);
140+
// as long as the subscription is not null
141+
boolean subscriptionIsNull;
142+
synchronized(guard) {
143+
subscriptionIsNull = subscription == null;
144+
}
145+
if (!subscriptionIsNull)
146+
source.unsafeSubscribe(subscription);
141147
}
142148
}
143149
}

rxjava-core/src/test/java/rx/RefCountTests.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package rx;
1717

1818
import static org.junit.Assert.assertEquals;
19+
import static org.junit.Assert.assertTrue;
1920
import static org.mockito.Matchers.any;
2021
import static org.mockito.Mockito.inOrder;
2122
import static org.mockito.Mockito.mock;
@@ -25,6 +26,7 @@
2526
import java.util.ArrayList;
2627
import java.util.Arrays;
2728
import java.util.List;
29+
import java.util.concurrent.CountDownLatch;
2830
import java.util.concurrent.TimeUnit;
2931
import java.util.concurrent.atomic.AtomicInteger;
3032

@@ -34,11 +36,14 @@
3436
import org.mockito.MockitoAnnotations;
3537

3638
import rx.Observable.OnSubscribe;
39+
import rx.Observable.Operator;
3740
import rx.functions.Action0;
3841
import rx.functions.Action1;
3942
import rx.functions.Func2;
43+
import rx.observables.ConnectableObservable;
4044
import rx.observers.Subscribers;
4145
import rx.observers.TestSubscriber;
46+
import rx.schedulers.Schedulers;
4247
import rx.schedulers.TestScheduler;
4348
import rx.subjects.ReplaySubject;
4449
import rx.subscriptions.Subscriptions;
@@ -237,4 +242,49 @@ public Integer call(Integer t1, Integer t2) {
237242
ts2.assertNoErrors();
238243
ts2.assertReceivedOnNext(Arrays.asList(30));
239244
}
245+
246+
@Test
247+
public void testRefCountUnsubscribeForSynchronousSource() throws InterruptedException {
248+
final CountDownLatch latch = new CountDownLatch(1);
249+
Observable<Long> o = synchronousInterval().lift(detectUnsubscription(latch));
250+
Subscriber<Long> sub = Subscribers.empty();
251+
o.publish().refCount().subscribeOn(Schedulers.computation()).subscribe(sub);
252+
Thread.sleep(100);
253+
sub.unsubscribe();
254+
assertTrue(latch.await(3, TimeUnit.SECONDS));
255+
}
256+
257+
@Test
258+
public void testSubscribeToPublishWithAlreadyUnsubscribedSubscriber() {
259+
Subscriber<Object> sub = Subscribers.empty();
260+
sub.unsubscribe();
261+
ConnectableObservable<Object> o = Observable.empty().publish();
262+
o.subscribe(sub);
263+
o.connect();
264+
}
265+
266+
private Operator<Long, Long> detectUnsubscription(final CountDownLatch latch) {
267+
return new Operator<Long,Long>(){
268+
@Override
269+
public Subscriber<? super Long> call(Subscriber<? super Long> subscriber) {
270+
latch.countDown();
271+
return Subscribers.from(subscriber);
272+
}};
273+
}
274+
275+
private Observable<Long> synchronousInterval() {
276+
return Observable.create(new OnSubscribe<Long>() {
277+
278+
@Override
279+
public void call(Subscriber<? super Long> subscriber) {
280+
while (!subscriber.isUnsubscribed()) {
281+
try {
282+
Thread.sleep(100);
283+
} catch (InterruptedException e) {
284+
}
285+
subscriber.onNext(1L);
286+
}
287+
}});
288+
}
289+
240290
}

0 commit comments

Comments
 (0)