3
3
import com .yannbriancon .exception .NPlusOneQueriesException ;
4
4
import org .hibernate .EmptyInterceptor ;
5
5
import org .hibernate .Transaction ;
6
- import org .hibernate .proxy .HibernateProxy ;
7
6
import org .slf4j .Logger ;
8
7
import org .slf4j .LoggerFactory ;
9
8
import org .springframework .boot .context .properties .EnableConfigurationProperties ;
10
9
import org .springframework .stereotype .Component ;
11
10
11
+ import javax .persistence .EntityManager ;
12
12
import java .io .Serializable ;
13
13
import java .util .HashMap ;
14
14
import java .util .HashSet ;
@@ -24,22 +24,44 @@ public class HibernateQueryInterceptor extends EmptyInterceptor {
24
24
private transient ThreadLocal <Long > threadQueryCount = new ThreadLocal <>();
25
25
26
26
private transient ThreadLocal <Set <String >> threadPreviouslyLoadedEntities =
27
- ThreadLocal .withInitial (new EmptySetSupplier ());
27
+ ThreadLocal .withInitial (new EmptySetSupplier <> ());
28
28
29
- private transient ThreadLocal <Map <String , String >> threadProxyMethodEntityMapping =
30
- ThreadLocal .withInitial (new EmptyMapSupplier ());
29
+ private transient ThreadLocal <Map <String , SelectQueriesInfo >> threadSelectQueriesInfoPerProxyMethod =
30
+ ThreadLocal .withInitial (new EmptyMapSupplier <> ());
31
31
32
32
private static final Logger LOGGER = LoggerFactory .getLogger (HibernateQueryInterceptor .class );
33
33
34
34
private final HibernateQueryInterceptorProperties hibernateQueryInterceptorProperties ;
35
35
36
- private final String HIBERNATE_PROXY_PREFIX = "org.hibernate.proxy" ;
37
- private final String PROXY_METHOD_PREFIX = "com.sun.proxy" ;
36
+ private static final String HIBERNATE_PROXY_PREFIX = "org.hibernate.proxy" ;
37
+ private static final String PROXY_METHOD_PREFIX = "com.sun.proxy" ;
38
38
39
- public HibernateQueryInterceptor (HibernateQueryInterceptorProperties hibernateQueryInterceptorProperties ) {
39
+ public HibernateQueryInterceptor (
40
+ HibernateQueryInterceptorProperties hibernateQueryInterceptorProperties
41
+ ) {
40
42
this .hibernateQueryInterceptorProperties = hibernateQueryInterceptorProperties ;
41
43
}
42
44
45
+ /**
46
+ * Reset the N+1 query detection state
47
+ */
48
+ private void resetNPlusOneQueryDetectionState () {
49
+ threadPreviouslyLoadedEntities .set (new HashSet <>());
50
+ threadSelectQueriesInfoPerProxyMethod .set (new HashMap <>());
51
+ }
52
+
53
+ /**
54
+ * Clear the Hibernate Session and reset the N+1 query detection state
55
+ * <p>
56
+ * Clearing the Hibernate Session is necessary to detect N+1 queries in tests as they would be in production.
57
+ * Otherwise, every objects created in the setup of the tests would already be loaded in the Session and would
58
+ * hide potential N+1 queries.
59
+ */
60
+ public void clearNPlusOneQuerySession (EntityManager entityManager ) {
61
+ entityManager .clear ();
62
+ this .resetNPlusOneQueryDetectionState ();
63
+ }
64
+
43
65
/**
44
66
* Start or reset the query count to 0 for the considered thread
45
67
*/
@@ -55,13 +77,17 @@ public Long getQueryCount() {
55
77
}
56
78
57
79
/**
58
- * Increment the query count for the considered thread for each new statement if the count has been initialized
80
+ * Detect the N+1 queries by keeping the history of sql queries generated per proxy method.
81
+ * Increment the query count for the considered thread for each new statement if the count has been initialized.
59
82
*
60
83
* @param sql Query to be executed
61
84
* @return Query to be executed
62
85
*/
63
86
@ Override
64
87
public String onPrepareStatement (String sql ) {
88
+ if (hibernateQueryInterceptorProperties .isnPlusOneDetectionEnabled ()) {
89
+ updateSelectQueriesInfoPerProxyMethod (sql );
90
+ }
65
91
Long count = threadQueryCount .get ();
66
92
if (count != null ) {
67
93
threadQueryCount .set (count + 1 );
@@ -77,34 +103,20 @@ public String onPrepareStatement(String sql) {
77
103
*/
78
104
@ Override
79
105
public void afterTransactionCompletion (Transaction tx ) {
80
- threadPreviouslyLoadedEntities .set (new HashSet <>());
81
- threadProxyMethodEntityMapping .set (new HashMap <>());
106
+ this .resetNPlusOneQueryDetectionState ();
82
107
}
83
108
84
109
/**
85
- * Detect the N+1 queries by checking if two calls were made to getEntity for the same instance
86
- * <p>
87
- * The first call is made with the instance filled with a {@link HibernateProxy}
88
- * and the second is made after a query was executed to fetch the data in the Entity
110
+ * Detect the N+1 queries by keeping the history of the entities previously gotten.
89
111
*
90
112
* @param entityName Name of the entity to get
91
113
* @param id Id of the entity to get
92
114
*/
93
115
@ Override
94
116
public Object getEntity (String entityName , Serializable id ) {
95
117
if (hibernateQueryInterceptorProperties .isnPlusOneDetectionEnabled ()) {
96
- detectNPlusOneQueriesOfMissingQueryEagerFetching (entityName , id );
97
- detectNPlusOneQueriesOfMissingEntityFieldLazyFetching (entityName , id );
98
- }
99
-
100
- Set <String > previouslyLoadedEntities = threadPreviouslyLoadedEntities .get ();
101
-
102
- if (previouslyLoadedEntities .contains (entityName + id )) {
103
- previouslyLoadedEntities .remove (entityName + id );
104
- threadPreviouslyLoadedEntities .set (previouslyLoadedEntities );
105
- } else {
106
- previouslyLoadedEntities .add (entityName + id );
107
- threadPreviouslyLoadedEntities .set (previouslyLoadedEntities );
118
+ detectNPlusOneQueriesFromMissingEagerFetchingOnAQuery (entityName , id );
119
+ detectNPlusOneQueriesFromClassFieldEagerFetching (entityName );
108
120
}
109
121
110
122
return null ;
@@ -113,23 +125,24 @@ public Object getEntity(String entityName, Serializable id) {
113
125
/**
114
126
* Detect the N+1 queries caused by a missing eager fetching configuration on a query with a lazy loaded field
115
127
* <p>
116
- * <p>
117
128
* Detection checks:
118
129
* - The getEntity was called twice for the couple (entity, id)
119
- * <p>
120
130
* - There is an occurrence of hibernate proxy followed by entity class in the stackTraceElements
121
131
* Avoid detecting calls to queries like findById and queries with eager fetching on some entity fields
122
132
*
123
133
* @param entityName Name of the entity
124
134
* @param id Id of the entity objecy
125
- * @return Boolean telling whether N+1 queries were detected or not
126
135
*/
127
- private boolean detectNPlusOneQueriesOfMissingQueryEagerFetching (String entityName , Serializable id ) {
136
+ private void detectNPlusOneQueriesFromMissingEagerFetchingOnAQuery (String entityName , Serializable id ) {
128
137
Set <String > previouslyLoadedEntities = threadPreviouslyLoadedEntities .get ();
129
138
130
139
if (!previouslyLoadedEntities .contains (entityName + id )) {
131
- return false ;
140
+ previouslyLoadedEntities .add (entityName + id );
141
+ threadPreviouslyLoadedEntities .set (previouslyLoadedEntities );
142
+ return ;
132
143
}
144
+ previouslyLoadedEntities .remove (entityName + id );
145
+ threadPreviouslyLoadedEntities .set (previouslyLoadedEntities );
133
146
134
147
// Detect N+1 queries by searching for newest occurrence of Hibernate proxy followed by entity class in stack
135
148
// elements
@@ -147,70 +160,97 @@ private boolean detectNPlusOneQueriesOfMissingQueryEagerFetching(String entityNa
147
160
}
148
161
149
162
if (originStackTraceElement == null ) {
150
- return false ;
163
+ return ;
151
164
}
152
165
153
166
String errorMessage = "N+1 queries detected on a getter of the entity " + entityName +
154
167
"\n at " + originStackTraceElement .toString () +
155
168
"\n Hint: Missing Eager fetching configuration on the query that fetched the object of " +
156
169
"type " + entityName + "\n " ;
157
170
logDetectedNPlusOneQueries (errorMessage );
171
+ }
158
172
159
- return true ;
173
+ /**
174
+ * Update the select queries info per proxy method to be able to detect potential N+1 queries
175
+ * due to Eager Fetching on a field of a class
176
+ * <p>
177
+ * Checks:
178
+ * - Detect queries that would not fit the N+1 queries problem, non select queries, and remove the entry
179
+ * - Detect multiple calls to same proxy method and reset the entry to avoid false positive
180
+ * - Detect select queries that could be potential N+1 queries and increment the count
181
+ */
182
+ private void updateSelectQueriesInfoPerProxyMethod (String sql ) {
183
+ Optional <String > optionalProxyMethodName = getProxyMethodName ();
184
+ if (!optionalProxyMethodName .isPresent ()) {
185
+ return ;
186
+ }
187
+ String proxyMethodName = optionalProxyMethodName .get ();
188
+
189
+ boolean isSelectQuery = sql .toLowerCase ().startsWith ("select" );
190
+
191
+ Map <String , SelectQueriesInfo > selectQueriesInfoPerProxyMethod = threadSelectQueriesInfoPerProxyMethod .get ();
192
+
193
+ // The N+1 queries problem is only related to select queries
194
+ // So we remove the entry when detecting non select query for the proxy method
195
+ if (!isSelectQuery ) {
196
+ selectQueriesInfoPerProxyMethod .remove (proxyMethodName );
197
+ threadSelectQueriesInfoPerProxyMethod .set (selectQueriesInfoPerProxyMethod );
198
+ return ;
199
+ }
200
+
201
+ SelectQueriesInfo selectQueriesInfo = selectQueriesInfoPerProxyMethod .get (proxyMethodName );
202
+
203
+ // Handle several calls to the same proxy method by resetting the SelectQueriesInfo
204
+ // when the initial select query is detected
205
+ if (selectQueriesInfo == null || selectQueriesInfo .getInitialSelectQuery ().equals (sql )) {
206
+ selectQueriesInfoPerProxyMethod .put (proxyMethodName , new SelectQueriesInfo (sql ));
207
+ threadSelectQueriesInfoPerProxyMethod .set (selectQueriesInfoPerProxyMethod );
208
+ return ;
209
+ }
210
+
211
+ selectQueriesInfoPerProxyMethod .put (proxyMethodName , selectQueriesInfo .incrementSelectQueriesCount ());
212
+ threadSelectQueriesInfoPerProxyMethod .set (selectQueriesInfoPerProxyMethod );
160
213
}
161
214
162
215
/**
163
216
* Detect the N+1 queries caused by a missing lazy fetching configuration on an entity field
164
217
* <p>
165
- * Detection checks:
166
- * - The getEntity was called twice for the couple (entity, id)
167
- * <p>
168
- * - The query that triggered the fetching of the entity object was first called for a different entity
169
- * Avoid detecting calls to queries like findById
170
- *
171
- * @param entityName Name of the entity
172
- * @param id Id of the entity objecy
173
- * @return Boolean telling whether N+1 queries were detected or not
218
+ * Detection checks that several select queries were generated from the same proxy method
174
219
*/
175
- private boolean detectNPlusOneQueriesOfMissingEntityFieldLazyFetching (String entityName , Serializable id ) {
220
+ private void detectNPlusOneQueriesFromClassFieldEagerFetching (String entityName ) {
176
221
Optional <String > optionalProxyMethodName = getProxyMethodName ();
177
222
if (!optionalProxyMethodName .isPresent ()) {
178
- return false ;
223
+ return ;
179
224
}
180
225
String proxyMethodName = optionalProxyMethodName .get ();
181
226
182
- Set <String > previouslyLoadedEntities = threadPreviouslyLoadedEntities .get ();
183
- Map <String , String > proxyMethodEntityMapping = threadProxyMethodEntityMapping .get ();
184
-
185
- boolean nPlusOneQueriesDetected = false ;
186
- if (
187
- previouslyLoadedEntities .contains (entityName + id )
188
- && proxyMethodEntityMapping .containsKey (proxyMethodName )
189
- && !proxyMethodEntityMapping .get (proxyMethodName ).equals (entityName )
190
- ) {
191
- nPlusOneQueriesDetected = true ;
192
-
193
- String errorMessage = "N+1 queries detected on a query for the entity " + entityName ;
194
-
195
- // Find origin of the N+1 queries in client package
196
- // by getting oldest occurrence of proxy method in stack elements
197
- StackTraceElement [] stackTraceElements = Thread .currentThread ().getStackTrace ();
198
-
199
- for (int i = stackTraceElements .length - 1 ; i >= 1 ; i --) {
200
- if (stackTraceElements [i - 1 ].getClassName ().indexOf (PROXY_METHOD_PREFIX ) == 0 ) {
201
- errorMessage += "\n at " + stackTraceElements [i ].toString ();
202
- break ;
203
- }
204
- }
227
+ Map <String , SelectQueriesInfo > selectQueriesInfoPerProxyMethod = threadSelectQueriesInfoPerProxyMethod .get ();
228
+ SelectQueriesInfo selectQueriesInfo = selectQueriesInfoPerProxyMethod .get (proxyMethodName );
229
+ if (selectQueriesInfo == null || selectQueriesInfo .getSelectQueriesCount () < 2 ) {
230
+ return ;
231
+ }
205
232
206
- errorMessage += "\n Hint: Missing Lazy fetching configuration on a field of one of the entities " +
207
- "fetched in the query\n " ;
233
+ // Reset the count to 1 to log a message once per additional query
234
+ selectQueriesInfoPerProxyMethod .put (proxyMethodName , selectQueriesInfo .resetSelectQueriesCount ());
235
+ threadSelectQueriesInfoPerProxyMethod .set (selectQueriesInfoPerProxyMethod );
208
236
209
- logDetectedNPlusOneQueries (errorMessage );
237
+ String errorMessage = "N+1 queries detected with eager fetching on the entity " + entityName ;
238
+
239
+ // Find origin of the N+1 queries in client package
240
+ // by getting oldest occurrence of proxy method in stack elements
241
+ StackTraceElement [] stackTraceElements = Thread .currentThread ().getStackTrace ();
242
+
243
+ for (int i = stackTraceElements .length - 1 ; i >= 1 ; i --) {
244
+ if (stackTraceElements [i - 1 ].getClassName ().indexOf (PROXY_METHOD_PREFIX ) == 0 ) {
245
+ errorMessage += "\n at " + stackTraceElements [i ].toString ();
246
+ break ;
247
+ }
210
248
}
211
249
212
- proxyMethodEntityMapping .putIfAbsent (proxyMethodName , entityName );
213
- return nPlusOneQueriesDetected ;
250
+ errorMessage += "\n Hint: Missing Lazy fetching configuration on a field of type " + entityName + " of " +
251
+ "one of the entities fetched in the query\n " ;
252
+
253
+ logDetectedNPlusOneQueries (errorMessage );
214
254
}
215
255
216
256
/**
@@ -224,7 +264,7 @@ private Optional<String> getProxyMethodName() {
224
264
for (int i = stackTraceElements .length - 1 ; i >= 0 ; i --) {
225
265
StackTraceElement stackTraceElement = stackTraceElements [i ];
226
266
227
- if (stackTraceElement .getClassName ().indexOf ("com.sun.proxy" ) == 0 ) {
267
+ if (stackTraceElement .getClassName ().indexOf (PROXY_METHOD_PREFIX ) == 0 ) {
228
268
return Optional .of (stackTraceElement .getClassName () + stackTraceElement .getMethodName ());
229
269
}
230
270
}
@@ -249,19 +289,19 @@ private void logDetectedNPlusOneQueries(String errorMessage) {
249
289
LOGGER .error (errorMessage );
250
290
break ;
251
291
default :
252
- throw new NPlusOneQueriesException (errorMessage , new Exception ( new Throwable ()) );
292
+ throw new NPlusOneQueriesException (errorMessage );
253
293
}
254
294
}
255
295
}
256
296
257
- class EmptySetSupplier implements Supplier <Set <String >> {
258
- public Set <String > get () {
297
+ class EmptySetSupplier < T > implements Supplier <Set <T >> {
298
+ public Set <T > get () {
259
299
return new HashSet <>();
260
300
}
261
301
}
262
302
263
- class EmptyMapSupplier implements Supplier <Map <String , String >> {
264
- public Map <String , String > get () {
303
+ class EmptyMapSupplier < T > implements Supplier <Map <String , T >> {
304
+ public Map <String , T > get () {
265
305
return new HashMap <>();
266
306
}
267
307
}
0 commit comments