1616
1717package com .palantir .javaformat .java ;
1818
19+ import static java .lang .Math .max ;
20+ import static java .nio .charset .StandardCharsets .UTF_8 ;
21+
1922import com .google .common .base .CharMatcher ;
2023import com .google .common .collect .HashMultimap ;
24+ import com .google .common .collect .ImmutableList ;
2125import com .google .common .collect .Multimap ;
2226import com .google .common .collect .Range ;
2327import com .google .common .collect .RangeMap ;
24- import com .google .common .collect .RangeSet ;
2528import com .google .common .collect .TreeRangeMap ;
26- import com .google .common .collect .TreeRangeSet ;
2729import com .palantir .javaformat .Newlines ;
2830import com .sun .source .doctree .DocCommentTree ;
2931import com .sun .source .doctree .ReferenceTree ;
3638import com .sun .source .util .TreePathScanner ;
3739import com .sun .source .util .TreeScanner ;
3840import com .sun .tools .javac .api .JavacTrees ;
41+ import com .sun .tools .javac .file .JavacFileManager ;
42+ import com .sun .tools .javac .parser .ParserFactory ;
3943import com .sun .tools .javac .tree .DCTree ;
4044import com .sun .tools .javac .tree .DCTree .DCReference ;
4145import com .sun .tools .javac .tree .JCTree ;
4246import com .sun .tools .javac .tree .JCTree .JCCompilationUnit ;
4347import com .sun .tools .javac .tree .JCTree .JCFieldAccess ;
44- import com .sun .tools .javac .tree .JCTree .JCIdent ;
4548import com .sun .tools .javac .tree .JCTree .JCImport ;
4649import com .sun .tools .javac .util .Context ;
50+ import com .sun .tools .javac .util .Log ;
4751import com .sun .tools .javac .util .Options ;
52+ import java .io .IOException ;
4853import java .lang .reflect .Method ;
54+ import java .net .URI ;
4955import java .util .LinkedHashSet ;
5056import java .util .List ;
5157import java .util .Map ;
5258import java .util .Set ;
59+ import javax .annotation .Nullable ;
60+ import javax .tools .DiagnosticCollector ;
61+ import javax .tools .DiagnosticListener ;
62+ import javax .tools .JavaFileObject ;
63+ import javax .tools .SimpleJavaFileObject ;
64+ import javax .tools .StandardLocation ;
5365
5466/**
5567 * Removes unused imports from a source file. Imports that are only used in javadoc are also removed, and the references
@@ -76,15 +88,12 @@ public class RemoveUnusedImports {
7688 private static final class UnusedImportScanner extends TreePathScanner <Void , Void > {
7789
7890 private final Set <String > usedNames = new LinkedHashSet <>();
79-
8091 private final Multimap <String , Range <Integer >> usedInJavadoc = HashMultimap .create ();
81-
82- final JavacTrees trees ;
83- final DocTreeScanner docTreeSymbolScanner ;
92+ private final DocTreeScanner docTreeSymbolScanner = new DocTreeScanner ();
93+ private final JavacTrees trees ;
8494
8595 private UnusedImportScanner (JavacTrees trees ) {
8696 this .trees = trees ;
87- docTreeSymbolScanner = new DocTreeScanner ();
8897 }
8998
9099 /** Skip the imports themselves when checking for usage. */
@@ -202,21 +211,49 @@ public Void visitIdentifier(IdentifierTree node, Void aVoid) {
202211 }
203212 }
204213
205- public static String removeUnusedImports (final String contents ) throws FormatterException {
214+ public static String removeUnusedImports (final String contents ) {
206215 Context context = new Context ();
207216 JCCompilationUnit unit = parse (context , contents );
208- if (unit == null ) {
209- // error handling is done during formatting
210- return contents ;
211- }
212217 UnusedImportScanner scanner = new UnusedImportScanner (JavacTrees .instance (context ));
213218 scanner .scan (unit , null );
214- return applyReplacements (contents , buildReplacements (contents , unit , scanner .usedNames , scanner .usedInJavadoc ));
219+ String s = applyReplacements (
220+ contents , buildReplacements (contents , unit , scanner .usedNames , scanner .usedInJavadoc ));
221+
222+ // Normalize newlines while preserving important blank lines
223+ String sep = Newlines .guessLineSeparator (contents );
224+
225+ // Ensure exactly one blank line after package declaration
226+ s = s .replaceAll ("(?m)^(package .+)" + sep + "\\ s+" + sep , "$1" + sep + sep );
227+
228+ // Ensure exactly one blank line between last import and class declaration
229+ s = s .replaceAll ("(?m)^(import .+)" + sep + "\\ s+" + sep + "(?=class|interface|enum|record)" , "$1" + sep + sep );
230+
231+ // Remove multiple blank lines elsewhere in imports section
232+ s = s .replaceAll ("(?m)^(import .+)" + sep + "\\ s+" + sep + "(?=import)" , "$1" + sep );
233+
234+ return s ;
215235 }
216236
217- private static JCCompilationUnit parse (Context context , String javaInput ) throws FormatterException {
237+ private static JCCompilationUnit parse (Context context , String javaInput ) {
238+ context .put (DiagnosticListener .class , new DiagnosticCollector <JavaFileObject >());
218239 Options .instance (context ).put ("allowStringFolding" , "false" );
219- return Formatter .parseJcCompilationUnit (context , javaInput );
240+ try (JavacFileManager fileManager = new JavacFileManager (context , true , UTF_8 )) {
241+ fileManager .setLocation (StandardLocation .PLATFORM_CLASS_PATH , ImmutableList .of ());
242+ } catch (IOException e ) {
243+ throw new RuntimeException (e );
244+ }
245+ SimpleJavaFileObject source = new SimpleJavaFileObject (URI .create ("source" ), JavaFileObject .Kind .SOURCE ) {
246+ @ Override
247+ public CharSequence getCharContent (boolean ignoreEncodingErrors ) {
248+ return javaInput ;
249+ }
250+ };
251+ Log .instance (context ).useSource (source );
252+ JCCompilationUnit unit = ParserFactory .instance (context )
253+ .newParser (javaInput , true , true , true )
254+ .parseCompilationUnit ();
255+ unit .sourcefile = source ;
256+ return unit ;
220257 }
221258
222259 /** Construct replacements to fix unused imports. */
@@ -226,70 +263,94 @@ private static RangeMap<Integer, String> buildReplacements(
226263 Set <String > usedNames ,
227264 Multimap <String , Range <Integer >> usedInJavadoc ) {
228265 RangeMap <Integer , String > replacements = TreeRangeMap .create ();
229- for (JCImport importTree : unit .getImports ()) {
230- String simpleName = getSimpleName (importTree );
231- if (!isUnused (unit , usedNames , usedInJavadoc , importTree , simpleName )) {
232- continue ;
233- }
234- // delete the import
235- int endPosition = importTree .getEndPosition (unit .endPositions );
236- endPosition = Math .max (CharMatcher .isNot (' ' ).indexIn (contents , endPosition ), endPosition );
237- String sep = Newlines .guessLineSeparator (contents );
238- if (endPosition + sep .length () < contents .length ()
239- && contents .subSequence (endPosition , endPosition + sep .length ())
240- .toString ()
241- .equals (sep )) {
266+ int size = unit .getImports ().size ();
267+ unit .getImports ().stream ()
268+ .filter (importTree -> isUnused (
269+ unit ,
270+ usedNames ,
271+ usedInJavadoc ,
272+ importTree ,
273+ getQualifiedIdentifier (importTree ).getIdentifier ().toString ()))
274+ .forEach (importTree -> replacements .put (
275+ Range .closedOpen (
276+ importTree .getStartPosition (),
277+ calculateEndPosition (
278+ contents ,
279+ importTree ,
280+ unit ,
281+ Newlines .guessLineSeparator (contents ),
282+ size ,
283+ size > 0 ? unit .getImports ().get (size - 1 ) : null )),
284+ "" ));
285+
286+ return replacements ;
287+ }
288+
289+ private static int calculateEndPosition (
290+ String contents ,
291+ JCTree importTree ,
292+ JCCompilationUnit unit ,
293+ String sep ,
294+ int size ,
295+ @ Nullable JCTree lastImport ) {
296+ int endPosition = importTree .getEndPosition (unit .endPositions );
297+ endPosition = max (CharMatcher .isNot (' ' ).indexIn (contents , endPosition ), endPosition );
298+ if (endPosition + sep .length () < contents .length ()
299+ && contents .subSequence (endPosition , endPosition + sep .length ())
300+ .toString ()
301+ .equals (sep )) {
302+ endPosition += sep .length ();
303+ }
304+ if ((size == 1 || importTree != lastImport ) && !checkForEmptyLineAfter (contents , endPosition , sep )) {
305+ while (endPosition + sep .length () <= contents .length ()
306+ && contents .regionMatches (endPosition , sep , 0 , sep .length ())) {
242307 endPosition += sep .length ();
243308 }
244- replacements .put (Range .closedOpen (importTree .getStartPosition (), endPosition ), "" );
245309 }
246- return replacements ;
310+ return endPosition ;
247311 }
248312
249- private static String getSimpleName (ImportTree importTree ) {
250- return importTree .getQualifiedIdentifier () instanceof JCIdent
251- ? ((JCIdent ) importTree .getQualifiedIdentifier ()).getName ().toString ()
252- : ((JCFieldAccess ) importTree .getQualifiedIdentifier ())
253- .getIdentifier ()
254- .toString ();
313+ private static boolean checkForEmptyLineAfter (String contents , int endPosition , String sep ) {
314+ return endPosition + sep .length () * 2 <= contents .length ()
315+ && contents .substring (endPosition , endPosition + sep .length () * 2 )
316+ .equals (sep + sep );
255317 }
256318
257319 private static boolean isUnused (
258320 JCCompilationUnit unit ,
259321 Set <String > usedNames ,
260322 Multimap <String , Range <Integer >> usedInJavadoc ,
261- ImportTree importTree ,
323+ JCTree importTree ,
262324 String simpleName ) {
263- String qualifier = ((JCFieldAccess ) importTree .getQualifiedIdentifier ())
264- .getExpression ()
265- .toString ();
325+ JCFieldAccess qualifiedIdentifier = getQualifiedIdentifier (importTree );
326+ String qualifier = qualifiedIdentifier .getExpression ().toString ();
266327 if (qualifier .equals ("java.lang" )) {
267328 return true ;
268329 }
330+ if (usedNames .contains (simpleName )) {
331+ return false ;
332+ }
269333 if (unit .getPackageName () != null && unit .getPackageName ().toString ().equals (qualifier )) {
270334 return true ;
271335 }
272- if (importTree .getQualifiedIdentifier () instanceof JCFieldAccess
273- && ((JCFieldAccess ) importTree .getQualifiedIdentifier ())
274- .getIdentifier ()
275- .contentEquals ("*" )) {
336+ if (qualifiedIdentifier .getIdentifier ().contentEquals ("*" ) && !((JCImport ) importTree ).isStatic ()) {
276337 return false ;
277338 }
339+ return !usedInJavadoc .containsKey (simpleName );
340+ }
278341
279- if (usedNames .contains (simpleName )) {
280- return false ;
342+ private static JCFieldAccess getQualifiedIdentifier (JCTree importTree ) {
343+ // Use reflection because the return type is JCTree in some versions and JCFieldAccess in others
344+ try {
345+ return (JCFieldAccess )
346+ JCImport .class .getMethod ("getQualifiedIdentifier" ).invoke (importTree );
347+ } catch (ReflectiveOperationException e ) {
348+ throw new RuntimeException (e );
281349 }
282- if (usedInJavadoc .containsKey (simpleName )) {
283- return false ;
284- }
285- return true ;
286350 }
287351
288352 /** Applies the replacements to the given source, and re-format any edited javadoc. */
289353 private static String applyReplacements (String source , RangeMap <Integer , String > replacements ) {
290- // save non-empty fixed ranges for reformatting after fixes are applied
291- RangeSet <Integer > fixedRanges = TreeRangeSet .create ();
292-
293354 // Apply the fixes in increasing order, adjusting ranges to account for
294355 // earlier fixes that change the length of the source. The output ranges are
295356 // needed so we can reformat fixed regions, otherwise the fixes could just
@@ -300,12 +361,7 @@ private static String applyReplacements(String source, RangeMap<Integer, String>
300361 replacements .asMapOfRanges ().entrySet ()) {
301362 Range <Integer > range = replacement .getKey ();
302363 String replaceWith = replacement .getValue ();
303- int start = offset + range .lowerEndpoint ();
304- int end = offset + range .upperEndpoint ();
305- sb .replace (start , end , replaceWith );
306- if (!replaceWith .isEmpty ()) {
307- fixedRanges .add (Range .closedOpen (start , end ));
308- }
364+ sb .replace (offset + range .lowerEndpoint (), offset + range .upperEndpoint (), replaceWith );
309365 offset += replaceWith .length () - (range .upperEndpoint () - range .lowerEndpoint ());
310366 }
311367 return sb .toString ();
0 commit comments