Skip to content

Commit 9cfbae8

Browse files
committed
fix: refactor
1 parent b04d847 commit 9cfbae8

File tree

1 file changed

+39
-25
lines changed

1 file changed

+39
-25
lines changed

backend/src/main/java/ch/xxx/aidoclibchat/usecase/service/TableService.java

+39-25
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
import java.util.Map;
2020
import java.util.Optional;
2121
import java.util.Set;
22-
import java.util.concurrent.atomic.AtomicReference;
2322
import java.util.function.Predicate;
2423
import java.util.stream.Collectors;
2524
import java.util.stream.Stream;
2625

2726
import org.slf4j.Logger;
2827
import org.slf4j.LoggerFactory;
2928
import org.springframework.ai.chat.client.ChatClient;
29+
import org.springframework.ai.chat.client.ChatClient.Builder;
3030
import org.springframework.ai.chat.messages.Message;
3131
import org.springframework.ai.chat.messages.UserMessage;
3232
import org.springframework.ai.chat.model.ChatResponse;
@@ -37,7 +37,6 @@
3737
import org.springframework.jdbc.core.JdbcTemplate;
3838
import org.springframework.jdbc.support.rowset.SqlRowSet;
3939
import org.springframework.stereotype.Service;
40-
import org.springframework.ai.chat.client.ChatClient.Builder;
4140

4241
import ch.xxx.aidoclibchat.domain.client.ImportClient;
4342
import ch.xxx.aidoclibchat.domain.common.MetaData;
@@ -99,6 +98,10 @@ Pay attention to use date('now') function to get the current date, if the questi
9998
@Value("${spring.profiles.active:}")
10099
private String activeProfile;
101100

101+
record MyTableData(String joinColumn, String joinTable, String columnValue, List<TableNameSchema> tableRecords,
102+
TableColumnNames tableColumnNames) {
103+
}
104+
102105
public TableService(ImportClient importClient, ImportService importService, Builder builder,
103106
JdbcTemplate jdbcTemplate, TableMetadataRepository tableMetadataRepository,
104107
DocumentVsRepository documentVsRepository) {
@@ -149,37 +152,48 @@ private Prompt createPrompt(SearchDto searchDto, EmbeddingContainer documentCont
149152
List<TableNameSchema> tableRecords = this.tableMetadataRepository
150153
.findByTableNameIn(tableColumnNames.tableNames()).stream()
151154
.map(tableMetaData -> new TableNameSchema(tableMetaData.getTableName(), tableMetaData.getTableDdl()))
152-
.collect(Collectors.toList());
153-
final AtomicReference<String> joinColumn = new AtomicReference<String>("");
154-
final AtomicReference<String> joinTable = new AtomicReference<String>("");
155-
final AtomicReference<String> columnValue = new AtomicReference<String>("");
156-
sortedRowDocs.stream().filter(myDoc -> minRowDistance <= MAX_ROW_DISTANCE)
155+
.collect(Collectors.toList());
156+
var result = sortedRowDocs.stream().filter(myDoc -> minRowDistance <= MAX_ROW_DISTANCE)
157157
.filter(myRowDoc -> tableRecords.stream()
158158
.filter(myRecord -> myRecord.name().equals(myRowDoc.getMetadata().get(MetaData.TABLE_NAME)))
159159
.findFirst().isEmpty())
160-
.findFirst().ifPresent(myRowDoc -> {
161-
joinTable.set(((String) myRowDoc.getMetadata().get(MetaData.TABLE_NAME)));
162-
joinColumn.set(((String) myRowDoc.getMetadata().get(MetaData.DATANAME)));
163-
tableColumnNames.columnNames().add(((String) myRowDoc.getMetadata().get(MetaData.DATANAME)));
164-
columnValue.set(myRowDoc.getText());
165-
this.tableMetadataRepository
166-
.findByTableNameIn(List.of(((String) myRowDoc.getMetadata().get(MetaData.TABLE_NAME))))
167-
.stream()
168-
.map(myTableMetadata -> new TableNameSchema(myTableMetadata.getTableName(),
169-
myTableMetadata.getTableDdl()))
170-
.findFirst().ifPresent(myRecord -> tableRecords.add(myRecord));
171-
});
172-
var messages = this.createMessages(searchDto, minRowDistance, tableColumnNames, tableRecords, joinColumn,
173-
joinTable, columnValue);
160+
.findFirst().map(myRowDoc -> createTableData(tableColumnNames, tableRecords, myRowDoc))
161+
.orElseThrow();
162+
var messages = this.createMessages(searchDto, minRowDistance, result.tableColumnNames(), result.tableRecords(), result.joinColumn(),
163+
result.joinTable(), result.columnValue());
174164
Prompt prompt = new Prompt(messages);
175165
// LOGGER.info("Prompt: {}", prompt.getContents());
176166
return prompt;
177167
}
178168

169+
private MyTableData createTableData(TableColumnNames tableColumnNames, List<TableNameSchema> tableRecords,
170+
Document myRowDoc) {
171+
tableColumnNames.columnNames().add(((String) myRowDoc.getMetadata().get(MetaData.DATANAME)));
172+
return findTable(myRowDoc).map(myRecord -> {
173+
tableRecords.add(myRecord);
174+
return createMyTableResult(tableColumnNames, tableRecords, myRowDoc);
175+
}).orElse(createMyTableResult(tableColumnNames, tableRecords, myRowDoc));
176+
}
177+
178+
private MyTableData createMyTableResult(TableColumnNames tableColumnNames, List<TableNameSchema> tableRecords,
179+
Document myRowDoc) {
180+
return new MyTableData(((String) myRowDoc.getMetadata().get(MetaData.DATANAME)),
181+
((String) myRowDoc.getMetadata().get(MetaData.TABLE_NAME)), myRowDoc.getText(), tableRecords,
182+
tableColumnNames);
183+
}
184+
185+
private Optional<TableNameSchema> findTable(Document myRowDoc) {
186+
return this.tableMetadataRepository
187+
.findByTableNameIn(List.of(((String) myRowDoc.getMetadata().get(MetaData.TABLE_NAME)))).stream()
188+
.map(myTableMetadata -> new TableNameSchema(myTableMetadata.getTableName(),
189+
myTableMetadata.getTableDdl()))
190+
.findFirst();
191+
}
192+
179193
private List<Message> createMessages(SearchDto searchDto, final Float minRowDistance,
180194
TableColumnNames tableColumnNames, List<TableNameSchema> tableRecords,
181-
final AtomicReference<String> joinColumn, final AtomicReference<String> joinTable,
182-
final AtomicReference<String> columnValue) {
195+
final String joinColumn, final String joinTable,
196+
final String columnValue) {
183197
SystemPromptTemplate systemPromptTemplate = this.activeProfile.contains("ollama")
184198
? new SystemPromptTemplate(minRowDistance > MAX_ROW_DISTANCE ? String.format(this.ollamaPrompt, "")
185199
: String.format(this.ollamaPrompt, columnMatch))
@@ -188,8 +202,8 @@ private List<Message> createMessages(SearchDto searchDto, final Float minRowDist
188202
Message systemMessage = systemPromptTemplate.createMessage(
189203
Map.of("columns", tableColumnNames.columnNames().stream().collect(Collectors.joining(",")), "schemas",
190204
tableRecords.stream().map(myRecord -> myRecord.schema()).collect(Collectors.joining(";")),
191-
"prompt", searchDto.getSearchString(), "joinColumn", joinColumn.get(), "joinTable",
192-
joinTable.get(), "columnValue", columnValue.get()));
205+
"prompt", searchDto.getSearchString(), "joinColumn", joinColumn, "joinTable",
206+
joinTable, "columnValue", columnValue));
193207
UserMessage userMessage = this.activeProfile.contains("ollama") ? new UserMessage(systemMessage.getText())
194208
: new UserMessage(searchDto.getSearchString());
195209
return List.of(systemMessage, userMessage);

0 commit comments

Comments
 (0)