19
19
import java .util .Map ;
20
20
import java .util .Optional ;
21
21
import java .util .Set ;
22
- import java .util .concurrent .atomic .AtomicReference ;
23
22
import java .util .function .Predicate ;
24
23
import java .util .stream .Collectors ;
25
24
import java .util .stream .Stream ;
26
25
27
26
import org .slf4j .Logger ;
28
27
import org .slf4j .LoggerFactory ;
29
28
import org .springframework .ai .chat .client .ChatClient ;
29
+ import org .springframework .ai .chat .client .ChatClient .Builder ;
30
30
import org .springframework .ai .chat .messages .Message ;
31
31
import org .springframework .ai .chat .messages .UserMessage ;
32
32
import org .springframework .ai .chat .model .ChatResponse ;
37
37
import org .springframework .jdbc .core .JdbcTemplate ;
38
38
import org .springframework .jdbc .support .rowset .SqlRowSet ;
39
39
import org .springframework .stereotype .Service ;
40
- import org .springframework .ai .chat .client .ChatClient .Builder ;
41
40
42
41
import ch .xxx .aidoclibchat .domain .client .ImportClient ;
43
42
import ch .xxx .aidoclibchat .domain .common .MetaData ;
@@ -99,6 +98,10 @@ Pay attention to use date('now') function to get the current date, if the questi
99
98
@ Value ("${spring.profiles.active:}" )
100
99
private String activeProfile ;
101
100
101
+ record MyTableData (String joinColumn , String joinTable , String columnValue , List <TableNameSchema > tableRecords ,
102
+ TableColumnNames tableColumnNames ) {
103
+ }
104
+
102
105
public TableService (ImportClient importClient , ImportService importService , Builder builder ,
103
106
JdbcTemplate jdbcTemplate , TableMetadataRepository tableMetadataRepository ,
104
107
DocumentVsRepository documentVsRepository ) {
@@ -149,37 +152,48 @@ private Prompt createPrompt(SearchDto searchDto, EmbeddingContainer documentCont
149
152
List <TableNameSchema > tableRecords = this .tableMetadataRepository
150
153
.findByTableNameIn (tableColumnNames .tableNames ()).stream ()
151
154
.map (tableMetaData -> new TableNameSchema (tableMetaData .getTableName (), tableMetaData .getTableDdl ()))
152
- .collect (Collectors .toList ());
153
- final AtomicReference <String > joinColumn = new AtomicReference <String >("" );
154
- final AtomicReference <String > joinTable = new AtomicReference <String >("" );
155
- final AtomicReference <String > columnValue = new AtomicReference <String >("" );
156
- sortedRowDocs .stream ().filter (myDoc -> minRowDistance <= MAX_ROW_DISTANCE )
155
+ .collect (Collectors .toList ());
156
+ var result = sortedRowDocs .stream ().filter (myDoc -> minRowDistance <= MAX_ROW_DISTANCE )
157
157
.filter (myRowDoc -> tableRecords .stream ()
158
158
.filter (myRecord -> myRecord .name ().equals (myRowDoc .getMetadata ().get (MetaData .TABLE_NAME )))
159
159
.findFirst ().isEmpty ())
160
- .findFirst ().ifPresent (myRowDoc -> {
161
- joinTable .set (((String ) myRowDoc .getMetadata ().get (MetaData .TABLE_NAME )));
162
- joinColumn .set (((String ) myRowDoc .getMetadata ().get (MetaData .DATANAME )));
163
- tableColumnNames .columnNames ().add (((String ) myRowDoc .getMetadata ().get (MetaData .DATANAME )));
164
- columnValue .set (myRowDoc .getText ());
165
- this .tableMetadataRepository
166
- .findByTableNameIn (List .of (((String ) myRowDoc .getMetadata ().get (MetaData .TABLE_NAME ))))
167
- .stream ()
168
- .map (myTableMetadata -> new TableNameSchema (myTableMetadata .getTableName (),
169
- myTableMetadata .getTableDdl ()))
170
- .findFirst ().ifPresent (myRecord -> tableRecords .add (myRecord ));
171
- });
172
- var messages = this .createMessages (searchDto , minRowDistance , tableColumnNames , tableRecords , joinColumn ,
173
- joinTable , columnValue );
160
+ .findFirst ().map (myRowDoc -> createTableData (tableColumnNames , tableRecords , myRowDoc ))
161
+ .orElseThrow ();
162
+ var messages = this .createMessages (searchDto , minRowDistance , result .tableColumnNames (), result .tableRecords (), result .joinColumn (),
163
+ result .joinTable (), result .columnValue ());
174
164
Prompt prompt = new Prompt (messages );
175
165
// LOGGER.info("Prompt: {}", prompt.getContents());
176
166
return prompt ;
177
167
}
178
168
169
+ private MyTableData createTableData (TableColumnNames tableColumnNames , List <TableNameSchema > tableRecords ,
170
+ Document myRowDoc ) {
171
+ tableColumnNames .columnNames ().add (((String ) myRowDoc .getMetadata ().get (MetaData .DATANAME )));
172
+ return findTable (myRowDoc ).map (myRecord -> {
173
+ tableRecords .add (myRecord );
174
+ return createMyTableResult (tableColumnNames , tableRecords , myRowDoc );
175
+ }).orElse (createMyTableResult (tableColumnNames , tableRecords , myRowDoc ));
176
+ }
177
+
178
+ private MyTableData createMyTableResult (TableColumnNames tableColumnNames , List <TableNameSchema > tableRecords ,
179
+ Document myRowDoc ) {
180
+ return new MyTableData (((String ) myRowDoc .getMetadata ().get (MetaData .DATANAME )),
181
+ ((String ) myRowDoc .getMetadata ().get (MetaData .TABLE_NAME )), myRowDoc .getText (), tableRecords ,
182
+ tableColumnNames );
183
+ }
184
+
185
+ private Optional <TableNameSchema > findTable (Document myRowDoc ) {
186
+ return this .tableMetadataRepository
187
+ .findByTableNameIn (List .of (((String ) myRowDoc .getMetadata ().get (MetaData .TABLE_NAME )))).stream ()
188
+ .map (myTableMetadata -> new TableNameSchema (myTableMetadata .getTableName (),
189
+ myTableMetadata .getTableDdl ()))
190
+ .findFirst ();
191
+ }
192
+
179
193
private List <Message > createMessages (SearchDto searchDto , final Float minRowDistance ,
180
194
TableColumnNames tableColumnNames , List <TableNameSchema > tableRecords ,
181
- final AtomicReference < String > joinColumn , final AtomicReference < String > joinTable ,
182
- final AtomicReference < String > columnValue ) {
195
+ final String joinColumn , final String joinTable ,
196
+ final String columnValue ) {
183
197
SystemPromptTemplate systemPromptTemplate = this .activeProfile .contains ("ollama" )
184
198
? new SystemPromptTemplate (minRowDistance > MAX_ROW_DISTANCE ? String .format (this .ollamaPrompt , "" )
185
199
: String .format (this .ollamaPrompt , columnMatch ))
@@ -188,8 +202,8 @@ private List<Message> createMessages(SearchDto searchDto, final Float minRowDist
188
202
Message systemMessage = systemPromptTemplate .createMessage (
189
203
Map .of ("columns" , tableColumnNames .columnNames ().stream ().collect (Collectors .joining ("," )), "schemas" ,
190
204
tableRecords .stream ().map (myRecord -> myRecord .schema ()).collect (Collectors .joining (";" )),
191
- "prompt" , searchDto .getSearchString (), "joinColumn" , joinColumn . get () , "joinTable" ,
192
- joinTable . get () , "columnValue" , columnValue . get () ));
205
+ "prompt" , searchDto .getSearchString (), "joinColumn" , joinColumn , "joinTable" ,
206
+ joinTable , "columnValue" , columnValue ));
193
207
UserMessage userMessage = this .activeProfile .contains ("ollama" ) ? new UserMessage (systemMessage .getText ())
194
208
: new UserMessage (searchDto .getSearchString ());
195
209
return List .of (systemMessage , userMessage );
0 commit comments