|
23 | 23 | import java.util.ArrayList;
|
24 | 24 | import java.util.Arrays;
|
25 | 25 | import java.util.Collections;
|
26 |
| -import java.util.LinkedHashMap; |
27 | 26 | import java.util.List;
|
28 | 27 | import java.util.Map;
|
29 | 28 | import java.util.stream.Collectors;
|
30 |
| -import org.apache.polaris.core.entity.PolarisBaseEntity; |
31 | 29 | import org.apache.polaris.core.entity.PolarisEntityCore;
|
32 | 30 | import org.apache.polaris.core.entity.PolarisEntityId;
|
33 |
| -import org.apache.polaris.core.entity.PolarisEntityType; |
34 |
| -import org.apache.polaris.core.policy.PolicyEntity; |
35 |
| -import org.apache.polaris.persistence.relational.jdbc.models.Converter; |
36 | 31 | import org.apache.polaris.persistence.relational.jdbc.models.ModelEntity;
|
37 | 32 | import org.apache.polaris.persistence.relational.jdbc.models.ModelGrantRecord;
|
38 |
| -import org.apache.polaris.persistence.relational.jdbc.models.ModelPolicyMappingRecord; |
39 |
| -import org.apache.polaris.persistence.relational.jdbc.models.ModelPrincipalAuthenticationData; |
40 |
| -import org.slf4j.Logger; |
41 |
| -import org.slf4j.LoggerFactory; |
42 | 33 |
|
| 34 | +/** |
| 35 | + * Utility class to generate parameterized SQL queries (SELECT, INSERT, UPDATE, DELETE). Ensures |
| 36 | + * consistent SQL generation and protects against injection by managing parameters separately. |
| 37 | + */ |
43 | 38 | public class QueryGenerator {
|
44 |
| - private static final Logger log = LoggerFactory.getLogger(QueryGenerator.class); |
45 | 39 | private final DatabaseType databaseType;
|
46 | 40 |
|
47 | 41 | public QueryGenerator(DatabaseType databaseType) {
|
48 | 42 | this.databaseType = databaseType;
|
49 | 43 | }
|
50 | 44 |
|
| 45 | + /** |
| 46 | + * @return The configured database type. |
| 47 | + */ |
51 | 48 | public DatabaseType getDatabaseType() {
|
52 | 49 | return databaseType;
|
53 | 50 | }
|
54 | 51 |
|
55 |
| - public static class PreparedQuery { |
56 |
| - private final String sql; |
57 |
| - private final List<Object> parameters; |
58 |
| - |
59 |
| - public PreparedQuery(String sql, List<Object> parameters) { |
60 |
| - this.sql = sql; |
61 |
| - this.parameters = parameters; |
62 |
| - } |
63 |
| - |
64 |
| - public String getSql() { |
65 |
| - return sql; |
66 |
| - } |
67 |
| - |
68 |
| - public List<Object> getParameters() { |
69 |
| - return parameters; |
70 |
| - } |
71 |
| - } |
72 |
| - |
73 |
| - public <T> PreparedQuery generateSelectQuery( |
74 |
| - @Nonnull Converter<T> entity, @Nonnull Map<String, Object> whereClause) { |
75 |
| - |
76 |
| - String tableName = getTableName(entity.getClass()); |
77 |
| - Map<String, Object> objectMap = entity.toMap(databaseType); |
78 |
| - checkInvalidColumns(objectMap, whereClause); |
79 |
| - |
80 |
| - String columns = String.join(", ", objectMap.keySet()); |
81 |
| - PreparedQuery whereClauseQuery = generateWhereClause(whereClause); |
82 |
| - String sql = "SELECT " + columns + " FROM " + tableName + whereClauseQuery.getSql(); |
83 |
| - |
84 |
| - return new PreparedQuery(sql, whereClauseQuery.getParameters()); |
| 52 | + /** A container for the SQL string and the ordered parameter values. */ |
| 53 | + public record PreparedQuery(String sql, List<Object> parameters) {} |
| 54 | + |
| 55 | + /** |
| 56 | + * Generates a SELECT query with projection and filtering. |
| 57 | + * |
| 58 | + * @param projections List of columns to retrieve. |
| 59 | + * @param tableName Target table name. |
| 60 | + * @param whereClause Column-value pairs used in WHERE filtering. |
| 61 | + * @return A parameterized SELECT query. |
| 62 | + * @throws IllegalArgumentException if any whereClause column isn't in projections. |
| 63 | + */ |
| 64 | + public PreparedQuery generateSelectQuery( |
| 65 | + @Nonnull List<String> projections, |
| 66 | + @Nonnull String tableName, |
| 67 | + @Nonnull Map<String, Object> whereClause) { |
| 68 | + checkInvalidColumns(projections, whereClause); |
| 69 | + PreparedQuery where = generateWhereClause(whereClause); |
| 70 | + PreparedQuery query = generateSelectQuery(projections, tableName, where.sql()); |
| 71 | + return new PreparedQuery(query.sql(), where.parameters()); |
85 | 72 | }
|
86 | 73 |
|
| 74 | + /** |
| 75 | + * Builds a DELETE query to remove grant records for a given entity. |
| 76 | + * |
| 77 | + * @param entity The target entity (either grantee or securable). |
| 78 | + * @param realmId The associated realm. |
| 79 | + * @return A DELETE query removing all grants for this entity. |
| 80 | + */ |
87 | 81 | public PreparedQuery generateDeleteQueryForEntityGrantRecords(
|
88 | 82 | @Nonnull PolarisEntityCore entity, @Nonnull String realmId) {
|
89 | 83 | String where =
|
90 |
| - " WHERE (grantee_id = ? AND grantee_catalog_id = ? OR securable_id = ? AND securable_catalog_id = ?) AND realm_id = ?"; |
| 84 | + " WHERE ((grantee_id = ? AND grantee_catalog_id = ?) OR (securable_id = ? AND securable_catalog_id = ?)) AND realm_id = ?"; |
91 | 85 | List<Object> params =
|
92 | 86 | Arrays.asList(
|
93 | 87 | entity.getId(), entity.getCatalogId(), entity.getId(), entity.getCatalogId(), realmId);
|
94 |
| - |
95 |
| - return new PreparedQuery("DELETE FROM " + getTableName(ModelGrantRecord.class) + where, params); |
96 |
| - } |
97 |
| - |
98 |
| - public PreparedQuery generateDeleteQueryForEntityPolicyMappingRecords( |
99 |
| - @Nonnull PolarisBaseEntity entity, @Nonnull String realmId) { |
100 |
| - Map<String, Object> queryParams = new LinkedHashMap<>(); |
101 |
| - if (entity.getType() == PolarisEntityType.POLICY) { |
102 |
| - PolicyEntity policyEntity = PolicyEntity.of(entity); |
103 |
| - queryParams.put("policy_type_code", policyEntity.getPolicyTypeCode()); |
104 |
| - queryParams.put("policy_catalog_id", policyEntity.getCatalogId()); |
105 |
| - queryParams.put("policy_id", policyEntity.getId()); |
106 |
| - } else { |
107 |
| - queryParams.put("target_catalog_id", entity.getCatalogId()); |
108 |
| - queryParams.put("target_id", entity.getId()); |
109 |
| - } |
110 |
| - queryParams.put("realm_id", realmId); |
111 |
| - |
112 |
| - return generateDeleteQuery(new ModelPolicyMappingRecord(), queryParams); |
| 88 | + return new PreparedQuery( |
| 89 | + "DELETE FROM " + getFullyQualifiedTableName(ModelGrantRecord.TABLE_NAME) + where, params); |
113 | 90 | }
|
114 | 91 |
|
| 92 | + /** |
| 93 | + * Builds a SELECT query using a list of entity ID pairs (catalog_id, id). |
| 94 | + * |
| 95 | + * @param realmId Realm to filter by. |
| 96 | + * @param entityIds List of PolarisEntityId pairs. |
| 97 | + * @return SELECT query to retrieve matching entities. |
| 98 | + * @throws IllegalArgumentException if entityIds is empty. |
| 99 | + */ |
115 | 100 | public PreparedQuery generateSelectQueryWithEntityIds(
|
116 | 101 | @Nonnull String realmId, @Nonnull List<PolarisEntityId> entityIds) {
|
117 | 102 | if (entityIds.isEmpty()) {
|
118 | 103 | throw new IllegalArgumentException("Empty entity ids");
|
119 | 104 | }
|
120 |
| - |
121 | 105 | String placeholders = entityIds.stream().map(e -> "(?, ?)").collect(Collectors.joining(", "));
|
122 | 106 | List<Object> params = new ArrayList<>();
|
123 | 107 | for (PolarisEntityId id : entityIds) {
|
124 | 108 | params.add(id.getCatalogId());
|
125 | 109 | params.add(id.getId());
|
126 | 110 | }
|
127 | 111 | params.add(realmId);
|
128 |
| - |
129 | 112 | String where = " WHERE (catalog_id, id) IN (" + placeholders + ") AND realm_id = ?";
|
130 |
| - return new PreparedQuery(generateSelectQuery(new ModelEntity(), where).getSql(), params); |
| 113 | + return new PreparedQuery( |
| 114 | + generateSelectQuery(ModelEntity.ALL_COLUMNS, ModelEntity.TABLE_NAME, where).sql(), params); |
131 | 115 | }
|
132 | 116 |
|
| 117 | + /** |
| 118 | + * Generates an INSERT query for a given table. |
| 119 | + * |
| 120 | + * @param allColumns Columns to insert values into. |
| 121 | + * @param tableName Target table name. |
| 122 | + * @param values Values for each column (must match order of columns). |
| 123 | + * @param realmId Realm value to append. |
| 124 | + * @return INSERT query with value bindings. |
| 125 | + */ |
133 | 126 | public <T> PreparedQuery generateInsertQuery(
|
134 |
| - @Nonnull Converter<T> entity, @Nonnull String realmId) { |
135 |
| - String tableName = getTableName(entity.getClass()); |
136 |
| - Map<String, Object> obj = entity.toMap(databaseType); |
137 |
| - List<String> columnNames = new ArrayList<>(obj.keySet()); |
138 |
| - List<Object> parameters = new ArrayList<>(obj.values()); |
139 |
| - |
140 |
| - columnNames.add("realm_id"); |
141 |
| - parameters.add(realmId); |
142 |
| - |
143 |
| - String columns = String.join(", ", columnNames); |
144 |
| - String placeholders = columnNames.stream().map(c -> "?").collect(Collectors.joining(", ")); |
145 |
| - |
146 |
| - String sql = "INSERT INTO " + tableName + " (" + columns + ") VALUES (" + placeholders + ")"; |
147 |
| - return new PreparedQuery(sql, parameters); |
| 127 | + @Nonnull List<String> allColumns, |
| 128 | + @Nonnull String tableName, |
| 129 | + List<Object> values, |
| 130 | + String realmId) { |
| 131 | + List<String> finalColumns = new ArrayList<>(allColumns); |
| 132 | + List<Object> finalValues = new ArrayList<>(values); |
| 133 | + finalColumns.add("realm_id"); |
| 134 | + finalValues.add(realmId); |
| 135 | + String columns = String.join(", ", finalColumns); |
| 136 | + String placeholders = finalColumns.stream().map(c -> "?").collect(Collectors.joining(", ")); |
| 137 | + String sql = |
| 138 | + "INSERT INTO " |
| 139 | + + getFullyQualifiedTableName(tableName) |
| 140 | + + " (" |
| 141 | + + columns |
| 142 | + + ") VALUES (" |
| 143 | + + placeholders |
| 144 | + + ")"; |
| 145 | + return new PreparedQuery(sql, finalValues); |
148 | 146 | }
|
149 | 147 |
|
| 148 | + /** |
| 149 | + * Builds an UPDATE query. |
| 150 | + * |
| 151 | + * @param allColumns Columns to update. |
| 152 | + * @param tableName Target table. |
| 153 | + * @param values New values (must match columns in order). |
| 154 | + * @param whereClause Conditions for filtering rows to update. |
| 155 | + * @return UPDATE query with parameter values. |
| 156 | + */ |
150 | 157 | public <T> PreparedQuery generateUpdateQuery(
|
151 |
| - @Nonnull Converter<T> entity, @Nonnull Map<String, Object> whereClause) { |
152 |
| - String tableName = getTableName(entity.getClass()); |
153 |
| - Map<String, Object> obj = entity.toMap(databaseType); |
154 |
| - checkInvalidColumns(obj, whereClause); |
155 |
| - |
156 |
| - List<String> setClauses = new ArrayList<>(); |
157 |
| - List<Object> parameters = new ArrayList<>(); |
158 |
| - |
159 |
| - for (Map.Entry<String, Object> entry : obj.entrySet()) { |
160 |
| - setClauses.add(entry.getKey() + " = ?"); |
161 |
| - parameters.add(entry.getValue()); |
162 |
| - } |
163 |
| - |
164 |
| - List<String> whereConditions = new ArrayList<>(); |
165 |
| - for (Map.Entry<String, Object> entry : whereClause.entrySet()) { |
166 |
| - whereConditions.add(entry.getKey() + " = ?"); |
167 |
| - parameters.add(entry.getValue()); |
168 |
| - } |
169 |
| - |
170 |
| - String sql = "UPDATE " + tableName + " SET " + String.join(", ", setClauses); |
171 |
| - if (!whereConditions.isEmpty()) { |
172 |
| - sql += " WHERE " + String.join(" AND ", whereConditions); |
173 |
| - } |
174 |
| - |
175 |
| - return new PreparedQuery(sql, parameters); |
| 158 | + @Nonnull List<String> allColumns, |
| 159 | + @Nonnull String tableName, |
| 160 | + @Nonnull List<Object> values, |
| 161 | + @Nonnull Map<String, Object> whereClause) { |
| 162 | + checkInvalidColumns(allColumns, whereClause); |
| 163 | + List<Object> bindingParams = new ArrayList<>(values); |
| 164 | + PreparedQuery where = generateWhereClause(whereClause); |
| 165 | + String setClause = allColumns.stream().map(c -> c + " = ?").collect(Collectors.joining(", ")); |
| 166 | + String sql = |
| 167 | + "UPDATE " + getFullyQualifiedTableName(tableName) + " SET " + setClause + where.sql(); |
| 168 | + bindingParams.addAll(where.parameters()); |
| 169 | + return new PreparedQuery(sql, bindingParams); |
176 | 170 | }
|
177 | 171 |
|
178 |
| - public <T> PreparedQuery generateDeleteQuery( |
179 |
| - @Nonnull Converter<T> entity, @Nonnull Map<String, Object> whereClause) { |
180 |
| - checkInvalidColumns(entity.toMap(databaseType), whereClause); |
181 |
| - PreparedQuery preparedQuery = generateWhereClause(whereClause); |
| 172 | + /** |
| 173 | + * Builds a DELETE query with the given conditions. |
| 174 | + * |
| 175 | + * @param tableColumns List of valid table columns. |
| 176 | + * @param tableName Target table. |
| 177 | + * @param whereClause Column-value filters. |
| 178 | + * @return DELETE query with parameter bindings. |
| 179 | + */ |
| 180 | + public PreparedQuery generateDeleteQuery( |
| 181 | + @Nonnull List<String> tableColumns, |
| 182 | + @Nonnull String tableName, |
| 183 | + @Nonnull Map<String, Object> whereClause) { |
| 184 | + checkInvalidColumns(tableColumns, whereClause); |
| 185 | + PreparedQuery where = generateWhereClause(whereClause); |
182 | 186 | return new PreparedQuery(
|
183 |
| - "DELETE FROM " + getTableName(entity.getClass()) + preparedQuery.getSql(), |
184 |
| - preparedQuery.getParameters()); |
185 |
| - } |
186 |
| - |
187 |
| - public PreparedQuery generateDeleteAll(@Nonnull Class<?> entityClass, @Nonnull String realmId) { |
188 |
| - String sql = "DELETE FROM " + getTableName(entityClass) + " WHERE 1 = 1 AND realm_id = ?"; |
189 |
| - return new PreparedQuery(sql, List.of(realmId)); |
| 187 | + "DELETE FROM " + getFullyQualifiedTableName(tableName) + where.sql(), where.parameters()); |
190 | 188 | }
|
191 | 189 |
|
192 | 190 | @VisibleForTesting
|
193 |
| - <T> PreparedQuery generateSelectQuery(@Nonnull Converter<T> entity, @Nonnull String filter) { |
194 |
| - String tableName = getTableName(entity.getClass()); |
195 |
| - Map<String, Object> objectMap = entity.toMap(databaseType); |
196 |
| - String columns = String.join(", ", objectMap.keySet()); |
197 |
| - String sql = "SELECT " + columns + " FROM " + tableName + filter; |
| 191 | + PreparedQuery generateSelectQuery( |
| 192 | + @Nonnull List<String> columnNames, @Nonnull String tableName, @Nonnull String filter) { |
| 193 | + String sql = |
| 194 | + "SELECT " |
| 195 | + + String.join(", ", columnNames) |
| 196 | + + " FROM " |
| 197 | + + getFullyQualifiedTableName(tableName) |
| 198 | + + filter; |
198 | 199 | return new PreparedQuery(sql, Collections.emptyList());
|
199 | 200 | }
|
200 | 201 |
|
201 | 202 | @VisibleForTesting
|
202 | 203 | PreparedQuery generateWhereClause(@Nonnull Map<String, Object> whereClause) {
|
203 | 204 | List<String> conditions = new ArrayList<>();
|
204 | 205 | List<Object> parameters = new ArrayList<>();
|
205 |
| - |
206 | 206 | for (Map.Entry<String, Object> entry : whereClause.entrySet()) {
|
207 | 207 | conditions.add(entry.getKey() + " = ?");
|
208 | 208 | parameters.add(entry.getValue());
|
209 | 209 | }
|
210 |
| - |
211 | 210 | String clause = conditions.isEmpty() ? "" : " WHERE " + String.join(" AND ", conditions);
|
212 | 211 | return new PreparedQuery(clause, parameters);
|
213 | 212 | }
|
214 | 213 |
|
215 |
| - @VisibleForTesting |
216 |
| - public static String getTableName(@Nonnull Class<?> entityClass) { |
217 |
| - String tableName; |
218 |
| - if (entityClass.equals(ModelEntity.class)) { |
219 |
| - tableName = "ENTITIES"; |
220 |
| - } else if (entityClass.equals(ModelGrantRecord.class)) { |
221 |
| - tableName = "GRANT_RECORDS"; |
222 |
| - } else if (entityClass.equals(ModelPrincipalAuthenticationData.class)) { |
223 |
| - tableName = "PRINCIPAL_AUTHENTICATION_DATA"; |
224 |
| - } else if (entityClass.equals(ModelPolicyMappingRecord.class)) { |
225 |
| - tableName = "POLICY_MAPPING_RECORD"; |
226 |
| - } else { |
227 |
| - throw new IllegalArgumentException("Unsupported entity class: " + entityClass.getName()); |
| 214 | + /** Validates that WHERE clause columns exist in the given list of valid columns. */ |
| 215 | + private void checkInvalidColumns(List<String> entity, Map<String, Object> whereClause) { |
| 216 | + if (!whereClause.isEmpty()) { |
| 217 | + for (String key : whereClause.keySet()) { |
| 218 | + if (!entity.contains(key) && !key.equals("realm_id")) { |
| 219 | + throw new IllegalArgumentException("Invalid query column: " + key); |
| 220 | + } |
| 221 | + } |
228 | 222 | }
|
229 |
| - |
230 |
| - return "POLARIS_SCHEMA." + tableName; |
231 | 223 | }
|
232 | 224 |
|
233 |
| - private void checkInvalidColumns(Map<String, Object> entity, Map<String, Object> whereClause) { |
234 |
| - List<String> allColumns = new ArrayList<>(entity.keySet()); |
235 |
| - allColumns.add("realm_id"); |
236 |
| - if (!allColumns.containsAll(whereClause.keySet())) { |
237 |
| - throw new IllegalArgumentException("Invalid query " + whereClause.keySet()); |
238 |
| - } |
| 225 | + private String getFullyQualifiedTableName(String tableName) { |
| 226 | + // TODO: make schema name configurable. |
| 227 | + return "POLARIS_SCHEMA." + tableName; |
239 | 228 | }
|
240 | 229 | }
|
0 commit comments