Skip to content

Commit 2be241c

Browse files
dhoardfstab
authored andcommitted
Added defensive code for scenario where thread id <= 0
Signed-off-by: Doug Hoard <[email protected]>
1 parent 2f31b96 commit 2be241c

File tree

2 files changed

+128
-11
lines changed

2 files changed

+128
-11
lines changed

simpleclient_hotspot/src/main/java/io/prometheus/client/hotspot/ThreadExports.java

+29-10
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
import io.prometheus.client.Collector;
44
import io.prometheus.client.CounterMetricFamily;
55
import io.prometheus.client.GaugeMetricFamily;
6-
import io.prometheus.client.SampleNameFilter;
76
import io.prometheus.client.Predicate;
87

98
import java.lang.management.ManagementFactory;
109
import java.lang.management.ThreadInfo;
1110
import java.lang.management.ThreadMXBean;
1211
import java.util.ArrayList;
12+
import java.util.Arrays;
1313
import java.util.Collections;
1414
import java.util.HashMap;
1515
import java.util.List;
@@ -36,13 +36,16 @@
3636
*/
3737
public class ThreadExports extends Collector {
3838

39+
public static final String UNKNOWN = "UNKNOWN";
40+
41+
public static final String JVM_THREADS_STATE = "jvm_threads_state";
42+
3943
private static final String JVM_THREADS_CURRENT = "jvm_threads_current";
4044
private static final String JVM_THREADS_DAEMON = "jvm_threads_daemon";
4145
private static final String JVM_THREADS_PEAK = "jvm_threads_peak";
4246
private static final String JVM_THREADS_STARTED_TOTAL = "jvm_threads_started_total";
4347
private static final String JVM_THREADS_DEADLOCKED = "jvm_threads_deadlocked";
4448
private static final String JVM_THREADS_DEADLOCKED_MONITOR = "jvm_threads_deadlocked_monitor";
45-
private static final String JVM_THREADS_STATE = "jvm_threads_state";
4649

4750
private final ThreadMXBean threadBean;
4851

@@ -109,35 +112,51 @@ void addThreadMetrics(List<MetricFamilySamples> sampleFamilies, Predicate<String
109112
"Current count of threads by state",
110113
Collections.singletonList("state"));
111114

112-
Map<Thread.State, Integer> threadStateCounts = getThreadStateCountMap();
113-
for (Map.Entry<Thread.State, Integer> entry : threadStateCounts.entrySet()) {
115+
Map<String, Integer> threadStateCounts = getThreadStateCountMap();
116+
for (Map.Entry<String, Integer> entry : threadStateCounts.entrySet()) {
114117
threadStateFamily.addMetric(
115-
Collections.singletonList(entry.getKey().toString()),
118+
Collections.singletonList(entry.getKey()),
116119
entry.getValue()
117120
);
118121
}
119122
sampleFamilies.add(threadStateFamily);
120123
}
121124
}
122125

123-
private Map<Thread.State, Integer> getThreadStateCountMap() {
126+
private Map<String, Integer> getThreadStateCountMap() {
127+
long[] threadIds = threadBean.getAllThreadIds();
128+
129+
// Code to remove any thread id values <= 0
130+
int writePos = 0;
131+
for (int i = 0; i < threadIds.length; i++) {
132+
if (threadIds[i] > 0) {
133+
threadIds[writePos++] = threadIds[i];
134+
}
135+
}
136+
137+
int numberOfInvalidThreadIds = threadIds.length - writePos;
138+
threadIds = Arrays.copyOf(threadIds, writePos);
139+
124140
// Get thread information without computing any stack traces
125-
ThreadInfo[] allThreads = threadBean.getThreadInfo(threadBean.getAllThreadIds(), 0);
141+
ThreadInfo[] allThreads = threadBean.getThreadInfo(threadIds, 0);
126142

127143
// Initialize the map with all thread states
128-
HashMap<Thread.State, Integer> threadCounts = new HashMap<Thread.State, Integer>();
144+
HashMap<String, Integer> threadCounts = new HashMap<String, Integer>();
129145
for (Thread.State state : Thread.State.values()) {
130-
threadCounts.put(state, 0);
146+
threadCounts.put(state.name(), 0);
131147
}
132148

133149
// Collect the actual thread counts
134150
for (ThreadInfo curThread : allThreads) {
135151
if (curThread != null) {
136152
Thread.State threadState = curThread.getThreadState();
137-
threadCounts.put(threadState, threadCounts.get(threadState) + 1);
153+
threadCounts.put(threadState.name(), threadCounts.get(threadState.name()) + 1);
138154
}
139155
}
140156

157+
// Add the thread count for invalid thread ids
158+
threadCounts.put(UNKNOWN, numberOfInvalidThreadIds);
159+
141160
return threadCounts;
142161
}
143162

simpleclient_hotspot/src/test/java/io/prometheus/client/hotspot/ThreadExportsTest.java

+99-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
package io.prometheus.client.hotspot;
22

3+
import io.prometheus.client.Collector.MetricFamilySamples;
34
import io.prometheus.client.CollectorRegistry;
45
import org.junit.Before;
56
import org.junit.Test;
67
import org.mockito.Mockito;
78

89
import java.lang.management.ThreadInfo;
910
import java.lang.management.ThreadMXBean;
10-
import java.util.Arrays;
11+
import java.util.HashMap;
12+
import java.util.List;
13+
import java.util.Map;
14+
import java.util.concurrent.CountDownLatch;
1115

1216
import static org.junit.Assert.assertEquals;
1317
import static org.mockito.Mockito.when;
@@ -97,4 +101,98 @@ public void testThreadPools() {
97101
"jvm_threads_state", STATE_LABEL, STATE_TERMINATED_LABEL),
98102
.0000001);
99103
}
104+
105+
@Test
106+
public void testInvalidThreadIds() {
107+
ThreadExports threadExports = new ThreadExports();
108+
109+
// Number of threads to create with invalid thread ids
110+
int numberOfInvalidThreadIds = 2;
111+
112+
// Get the current thread state counts
113+
Map<String, Double> expectedThreadStateCountMap = new HashMap<String, Double>();
114+
List<MetricFamilySamples> metricFamilySamplesList = threadExports.collect();
115+
for (MetricFamilySamples metricFamilySamples : metricFamilySamplesList) {
116+
if (ThreadExports.JVM_THREADS_STATE.equals(metricFamilySamples.name)) {
117+
for (MetricFamilySamples.Sample sample : metricFamilySamples.samples) {
118+
expectedThreadStateCountMap.put(ThreadExports.JVM_THREADS_STATE + "-" + sample.labelValues.get(0), sample.value);
119+
}
120+
}
121+
}
122+
123+
// Add numberOfInvalidThreadIds to the expected UNKNOWN thread state count
124+
expectedThreadStateCountMap.put(
125+
ThreadExports.JVM_THREADS_STATE + "-" + ThreadExports.UNKNOWN,
126+
expectedThreadStateCountMap.get(
127+
ThreadExports.JVM_THREADS_STATE + "-" + ThreadExports.UNKNOWN) + numberOfInvalidThreadIds);
128+
129+
final CountDownLatch countDownLatch = new CountDownLatch(numberOfInvalidThreadIds);
130+
131+
try {
132+
// Create and start threads with invalid thread ids (id=0, id=-1, etc.)
133+
for (int i = 0; i < numberOfInvalidThreadIds; i++) {
134+
new TestThread(-i, new TestRunnable(countDownLatch)).start();
135+
}
136+
137+
// Get the current thread state counts
138+
Map<String, Double> actualThreadStateCountMap = new HashMap<String, Double>();
139+
metricFamilySamplesList = threadExports.collect();
140+
for (MetricFamilySamples metricFamilySamples : metricFamilySamplesList) {
141+
if (ThreadExports.JVM_THREADS_STATE.equals(metricFamilySamples.name)) {
142+
for (MetricFamilySamples.Sample sample : metricFamilySamples.samples) {
143+
actualThreadStateCountMap.put(ThreadExports.JVM_THREADS_STATE + "-" + sample.labelValues.get(0), sample.value);
144+
}
145+
}
146+
}
147+
148+
// Assert that we have the same number of thread states
149+
assertEquals(expectedThreadStateCountMap.size(), actualThreadStateCountMap.size());
150+
151+
// Check each thread state count
152+
for (String threadState : expectedThreadStateCountMap.keySet()) {
153+
double expectedThreadStateCount = expectedThreadStateCountMap.get(threadState);
154+
double actualThreadStateCount = actualThreadStateCountMap.get(threadState);
155+
156+
// Assert the expected and actual thread count states are equal
157+
assertEquals(expectedThreadStateCount, actualThreadStateCount, 0.0);
158+
}
159+
} finally {
160+
for (int i = 0; i < numberOfInvalidThreadIds; i++) {
161+
countDownLatch.countDown();
162+
}
163+
}
164+
}
165+
166+
private class TestThread extends Thread {
167+
168+
private long id;
169+
170+
public TestThread(long id, Runnable runnable) {
171+
super(runnable);
172+
setDaemon(true);
173+
this.id = id;
174+
}
175+
176+
public long getId() {
177+
return this.id;
178+
}
179+
}
180+
181+
private class TestRunnable implements Runnable {
182+
183+
private CountDownLatch countDownLatch;
184+
185+
public TestRunnable(CountDownLatch countDownLatch) {
186+
this.countDownLatch = countDownLatch;
187+
}
188+
189+
@Override
190+
public void run() {
191+
try {
192+
countDownLatch.await();
193+
} catch (InterruptedException e) {
194+
// DO NOTHING
195+
}
196+
}
197+
}
100198
}

0 commit comments

Comments
 (0)