@@ -30,16 +30,16 @@ private enum RewriteAction
3030 public record RewriteContext (
3131 AnalysisType AnalysisType ,
3232 SyntaxNode Expression ,
33- SyntaxNode RootNode ,
33+ IEnumerable < SyntaxNode > RootNodes ,
3434 SemanticModel SemanticModel ,
3535 TypesProcessor TypesProcessor ,
3636 ConstantsMapper ConstantsMapper )
3737 {
38- public static RewriteContext Builders ( SyntaxNode Expression , SyntaxNode RootNode , SemanticModel SemanticModel , TypesProcessor TypesProcessor ) =>
39- new ( AnalysisType . Builders , Expression , RootNode , SemanticModel , TypesProcessor , new ( ) ) ;
38+ public static RewriteContext Builders ( SyntaxNode Expression , IEnumerable < SyntaxNode > RootNodes , SemanticModel SemanticModel , TypesProcessor TypesProcessor ) =>
39+ new ( AnalysisType . Builders , Expression , RootNodes , SemanticModel , TypesProcessor , new ( ) ) ;
4040
4141 public static RewriteContext Linq ( SyntaxNode Expression , SyntaxNode RootNode , SemanticModel SemanticModel , TypesProcessor TypesProcessor ) =>
42- new ( AnalysisType . Linq , Expression , RootNode , SemanticModel , TypesProcessor , new ( ) ) ;
42+ new ( AnalysisType . Linq , Expression , new SyntaxNode [ ] { RootNode } , SemanticModel , TypesProcessor , new ( ) ) ;
4343 }
4444
4545 private record RewriteResult (
@@ -56,42 +56,96 @@ public RewriteResult(SyntaxNode NodeToReplace, SyntaxNode NewNode) :
5656 public static RewriteResult Invalid { get ; } = new ( RewriteAction . Invalid , null , null ) ;
5757 }
5858
59- public static ( SyntaxNode RewrittenLinqExpression , ConstantsMapper ConstantsMapper ) RewriteExpression ( RewriteContext rewriteContext )
59+ private static bool RewriteRootNodes ( RewriteContext rewriteContext , HashSet < SyntaxNode > nodesProcessed , Dictionary < SyntaxNode , SyntaxNode > nodesRemapping )
6060 {
61- var nodesProcessed = new HashSet < SyntaxNode > ( ) ;
62- var nodesRemapping = new Dictionary < SyntaxNode , SyntaxNode > ( ) ;
63- var expressionNode = rewriteContext . Expression ;
64- var rootNode = rewriteContext . RootNode ;
61+ var rootNodes = rewriteContext . RootNodes ;
62+ var typesProcessor = rewriteContext . TypesProcessor ;
6563
66- // Register literals
67- foreach ( var literalSyntax in expressionNode . DescendantNodes ( ) . OfType < LiteralExpressionSyntax > ( ) )
64+ switch ( rewriteContext . AnalysisType )
6865 {
69- rewriteContext . ConstantsMapper . RegisterLiteral ( literalSyntax ) ;
66+ case AnalysisType . Builders :
67+ {
68+ foreach ( var rootNode in rootNodes )
69+ {
70+ var rootType = rewriteContext . SemanticModel . GetTypeInfo ( rootNode ) . Type as INamedTypeSymbol ;
71+ if ( rootType . IsSupportedIMongoCollection ( ) )
72+ {
73+ nodesProcessed . Add ( rootNode ) ;
74+ nodesRemapping . Add ( rootNode , SyntaxFactory . IdentifierName ( MqlGeneratorSyntaxElements . Builders . CollectionName ) ) ;
75+ }
76+ else if ( rootType . IsBuilder ( ) )
77+ {
78+ var rewrittenTypeArguments = new List < TypeSyntax > ( ) ;
79+
80+ foreach ( var rootTypeArgument in rootType . TypeArguments )
81+ {
82+ var remappedType = typesProcessor . GetTypeSymbolToGeneratedTypeMapping ( rootTypeArgument ) ;
83+
84+ if ( remappedType == null )
85+ {
86+ return false ;
87+ }
88+
89+ rewrittenTypeArguments . Add ( SyntaxFactory . ParseTypeName ( remappedType ) ) ;
90+ }
91+
92+ var buildersGenericType = SyntaxFactory . GenericName ( MqlGeneratorSyntaxElements . Builders . BuildersName ) . WithTypeArgumentList (
93+ SyntaxFactory . TypeArgumentList ( SyntaxFactory . SeparatedList ( rewrittenTypeArguments ) ) ) ;
94+
95+ var buildersDefinitionNode = SyntaxFactory . MemberAccessExpression (
96+ SyntaxKind . SimpleMemberAccessExpression ,
97+ buildersGenericType ,
98+ SyntaxFactory . IdentifierName ( rootType . GetBuilderDefinitionName ( ) ) ) ;
99+
100+ nodesProcessed . Add ( rootNode ) ;
101+ nodesRemapping . Add ( rootNode , buildersDefinitionNode ) ;
102+ }
103+ }
104+
105+ break ;
106+ }
107+ case AnalysisType . Linq :
108+ {
109+ foreach ( var rootNode in rootNodes )
110+ {
111+ nodesProcessed . Add ( rootNode ) ;
112+ nodesRemapping . Add ( rootNode , SyntaxFactory . IdentifierName ( MqlGeneratorSyntaxElements . Linq . QueryableName ) ) ;
113+ }
114+
115+ break ;
116+ }
117+ default :
118+ throw new ArgumentOutOfRangeException ( nameof ( rewriteContext . AnalysisType ) , rewriteContext . AnalysisType , "Unsupported analysis type" ) ;
70119 }
71- rewriteContext . ConstantsMapper . FinalizeLiteralsRegistration ( ) ;
120+
121+ return true ;
122+ }
123+
124+ private static bool RewriteIdentifiers ( RewriteContext rewriteContext , HashSet < SyntaxNode > nodesProcessed , Dictionary < SyntaxNode , SyntaxNode > nodesRemapping )
125+ {
126+ var expressionNode = rewriteContext . Expression ;
127+ var rootNodes = rewriteContext . RootNodes ;
128+ var typesProcessor = rewriteContext . TypesProcessor ;
72129
73130 // Set analysis specific parameters
74131 var processGenerics = false ;
75132 var removeFluentParameters = false ;
76133 IdentifierNameSyntax [ ] lambdaAndQueryIdentifiers = null ;
77- SyntaxNode rootNodeRemapped ;
78134 IEnumerable < SimpleNameSyntax > expressionDescendants ;
79135
80136 switch ( rewriteContext . AnalysisType )
81137 {
82138 case AnalysisType . Builders :
83139 {
84140 expressionDescendants = expressionNode . DescendantNodesWithSkipList < SimpleNameSyntax > ( nodesProcessed ) ;
85- rootNodeRemapped = SyntaxFactory . IdentifierName ( MqlGeneratorSyntaxElements . Builders . CollectionName ) ;
86-
87141 processGenerics = true ;
88142 removeFluentParameters = true ;
89143 break ;
90144 }
91145 case AnalysisType . Linq :
92146 {
93147 lambdaAndQueryIdentifiers = expressionNode
94- . DescendantNodes ( n => n != rootNode )
148+ . DescendantNodes ( n => ! rootNodes . Contains ( n ) )
95149 . OfType < IdentifierNameSyntax > ( )
96150 . Where ( identifierNode =>
97151 {
@@ -100,23 +154,16 @@ public static (SyntaxNode RewrittenLinqExpression, ConstantsMapper ConstantsMapp
100154 } )
101155 . ToArray ( ) ;
102156
103- expressionDescendants = expressionNode . DescendantNodes ( n => n != rootNode ) . OfType < IdentifierNameSyntax > ( ) ;
104- rootNodeRemapped = SyntaxFactory . IdentifierName ( MqlGeneratorSyntaxElements . Linq . QueryableName ) ;
157+ expressionDescendants = expressionNode . DescendantNodes ( n => ! rootNodes . Contains ( n ) ) . OfType < IdentifierNameSyntax > ( ) ;
105158 break ;
106159 }
107160 default :
108161 throw new ArgumentOutOfRangeException ( nameof ( rewriteContext . AnalysisType ) , rewriteContext . AnalysisType , "Unsupported analysis type" ) ;
109162 }
110163
111- if ( rootNode != null )
112- {
113- nodesProcessed . Add ( rootNode ) ;
114- nodesRemapping . Add ( rootNode , rootNodeRemapped ) ;
115- }
116-
117164 foreach ( var identifierNode in expressionDescendants )
118165 {
119- if ( identifierNode == rootNode ||
166+ if ( rootNodes . Contains ( identifierNode ) ||
120167 ! identifierNode . IsLeaf ( ) ||
121168 nodesProcessed . Any ( e => e . Contains ( identifierNode ) ) )
122169 {
@@ -132,7 +179,7 @@ public static (SyntaxNode RewrittenLinqExpression, ConstantsMapper ConstantsMapp
132179 var symbolInfo = rewriteContext . SemanticModel . GetSymbolInfo ( nodeToHandle ) ;
133180 if ( symbolInfo . Symbol == null )
134181 {
135- return default ;
182+ return false ;
136183 }
137184
138185 var typeInfo = rewriteContext . SemanticModel . GetTypeInfo ( nodeToHandle ) ;
@@ -158,20 +205,43 @@ SymbolKind.Parameter or
158205 {
159206 if ( rewriteResult . NodeToReplace != nodeToHandle )
160207 {
161- // not need to process NodeToReplace
162208 nodesProcessed . Add ( rewriteResult . NodeToReplace ) ;
163209 }
164210
165211 nodesRemapping [ rewriteResult . NodeToReplace ] = rewriteResult . NewNode ;
166212 break ;
167213 }
168214 case RewriteAction . Invalid :
169- return default ;
215+ {
216+ return false ;
217+ }
170218 default :
171219 continue ;
172220 }
173221 }
174222
223+ return true ;
224+ }
225+
226+ public static ( SyntaxNode RewrittenLinqExpression , ConstantsMapper ConstantsMapper ) RewriteExpression ( RewriteContext rewriteContext )
227+ {
228+ var nodesProcessed = new HashSet < SyntaxNode > ( ) ;
229+ var nodesRemapping = new Dictionary < SyntaxNode , SyntaxNode > ( ) ;
230+ var expressionNode = rewriteContext . Expression ;
231+
232+ // Register literals
233+ foreach ( var literalSyntax in expressionNode . DescendantNodes ( ) . OfType < LiteralExpressionSyntax > ( ) )
234+ {
235+ rewriteContext . ConstantsMapper . RegisterLiteral ( literalSyntax ) ;
236+ }
237+ rewriteContext . ConstantsMapper . FinalizeLiteralsRegistration ( ) ;
238+
239+ if ( ! RewriteRootNodes ( rewriteContext , nodesProcessed , nodesRemapping ) ||
240+ ! RewriteIdentifiers ( rewriteContext , nodesProcessed , nodesRemapping ) )
241+ {
242+ return default ;
243+ }
244+
175245 var result = expressionNode . ReplaceNodes ( nodesRemapping . Keys , ( n , _ ) => nodesRemapping [ n ] ) ;
176246
177247 return ( result , rewriteContext . ConstantsMapper ) ;
0 commit comments