Skip to content

Commit 42e65cb

Browse files
committed
- Make sure primary keys are supported types
- Fix TODO nodeId check, move out of transpile into builder - Create SupportedPrimaryKeyType enum for PK type matching
1 parent 6a4ce3e commit 42e65cb

File tree

4 files changed

+107
-49
lines changed

4 files changed

+107
-49
lines changed

src/builder.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,6 +1022,16 @@ pub struct NodeIdInstance {
10221022
pub values: Vec<serde_json::Value>,
10231023
}
10241024

1025+
impl NodeIdInstance {
1026+
pub fn validate(&self, table: &Table) -> Result<(), String> {
1027+
// Validate that nodeId belongs to the table being queried
1028+
if self.schema_name != table.schema || self.table_name != table.name {
1029+
return Err("nodeId belongs to a different collection".to_string());
1030+
}
1031+
Ok(())
1032+
}
1033+
}
1034+
10251035
#[derive(Clone, Debug)]
10261036
pub struct NodeIdBuilder {
10271037
pub alias: String,

src/graphql.rs

Lines changed: 49 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,55 +1249,58 @@ impl ___Type for QueryType {
12491249
f.push(collection_entrypoint);
12501250

12511251
// Add single record query by primary key if the table has a primary key
1252+
// and the primary key types are supported (int, bigint, uuid, string)
12521253
if let Some(primary_key) = table.primary_key() {
1253-
let node_type = NodeType {
1254-
table: Arc::clone(table),
1255-
fkey: None,
1256-
reverse_reference: None,
1257-
schema: Arc::clone(&self.schema),
1258-
};
1259-
1260-
// Create arguments for each primary key column
1261-
let mut pk_args = Vec::new();
1262-
for col_name in &primary_key.column_names {
1263-
if let Some(col) = table.columns.iter().find(|c| &c.name == col_name) {
1264-
let col_type = sql_column_to_graphql_type(col, &self.schema)
1265-
.ok_or_else(|| {
1266-
format!(
1267-
"Could not determine GraphQL type for column {}",
1268-
col_name
1269-
)
1270-
})
1271-
.unwrap_or_else(|_| __Type::Scalar(Scalar::String(None)));
1272-
1273-
// Use graphql_column_field_name to convert snake_case to camelCase if needed
1274-
let arg_name = self.schema.graphql_column_field_name(col);
1275-
1276-
pk_args.push(__InputValue {
1277-
name_: arg_name,
1278-
type_: __Type::NonNull(NonNullType {
1279-
type_: Box::new(col_type),
1280-
}),
1281-
description: Some(format!("The record's `{}` value", col_name)),
1282-
default_value: None,
1283-
sql_type: Some(NodeSQLType::Column(Arc::clone(col))),
1284-
});
1254+
if table.has_supported_pk_types_for_by_pk() {
1255+
let node_type = NodeType {
1256+
table: Arc::clone(table),
1257+
fkey: None,
1258+
reverse_reference: None,
1259+
schema: Arc::clone(&self.schema),
1260+
};
1261+
1262+
// Create arguments for each primary key column
1263+
let mut pk_args = Vec::new();
1264+
for col_name in &primary_key.column_names {
1265+
if let Some(col) = table.columns.iter().find(|c| &c.name == col_name) {
1266+
let col_type = sql_column_to_graphql_type(col, &self.schema)
1267+
.ok_or_else(|| {
1268+
format!(
1269+
"Could not determine GraphQL type for column {}",
1270+
col_name
1271+
)
1272+
})
1273+
.unwrap_or_else(|_| __Type::Scalar(Scalar::String(None)));
1274+
1275+
// Use graphql_column_field_name to convert snake_case to camelCase if needed
1276+
let arg_name = self.schema.graphql_column_field_name(col);
1277+
1278+
pk_args.push(__InputValue {
1279+
name_: arg_name,
1280+
type_: __Type::NonNull(NonNullType {
1281+
type_: Box::new(col_type),
1282+
}),
1283+
description: Some(format!("The record's `{}` value", col_name)),
1284+
default_value: None,
1285+
sql_type: Some(NodeSQLType::Column(Arc::clone(col))),
1286+
});
1287+
}
12851288
}
1286-
}
12871289

1288-
let pk_entrypoint = __Field {
1289-
name_: format!("{}ByPk", lowercase_first_letter(table_base_type_name)),
1290-
type_: __Type::Node(node_type),
1291-
args: pk_args,
1292-
description: Some(format!(
1293-
"Retrieve a record of type `{}` by its primary key",
1294-
table_base_type_name
1295-
)),
1296-
deprecation_reason: None,
1297-
sql_type: None,
1298-
};
1290+
let pk_entrypoint = __Field {
1291+
name_: format!("{}ByPk", lowercase_first_letter(table_base_type_name)),
1292+
type_: __Type::Node(node_type),
1293+
args: pk_args,
1294+
description: Some(format!(
1295+
"Retrieve a record of type `{}` by its primary key",
1296+
table_base_type_name
1297+
)),
1298+
deprecation_reason: None,
1299+
sql_type: None,
1300+
};
12991301

1300-
f.push(pk_entrypoint);
1302+
f.push(pk_entrypoint);
1303+
}
13011304
}
13021305
}
13031306
}
@@ -3485,7 +3488,7 @@ impl FromStr for FilterOp {
34853488
"contains" => Ok(Self::Contains),
34863489
"containedBy" => Ok(Self::ContainedBy),
34873490
"overlaps" => Ok(Self::Overlap),
3488-
_ => Err("Invalid filter operation".to_string()),
3491+
other => Err(format!("Invalid filter operation: {}", other)),
34893492
}
34903493
}
34913494
}

src/sql_types.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,18 @@ impl Table {
576576
.collect::<Vec<&Arc<Column>>>()
577577
}
578578

579+
pub fn has_supported_pk_types_for_by_pk(&self) -> bool {
580+
let pk_columns = self.primary_key_columns();
581+
if pk_columns.is_empty() {
582+
return false;
583+
}
584+
585+
// Check that all primary key columns have supported types
586+
pk_columns.iter().all(|col| {
587+
SupportedPrimaryKeyType::from_type_name(&col.type_name).is_some()
588+
})
589+
}
590+
579591
pub fn is_any_column_selectable(&self) -> bool {
580592
self.columns.iter().any(|x| x.permissions.is_selectable)
581593
}
@@ -597,6 +609,41 @@ impl Table {
597609
}
598610
}
599611

612+
#[derive(Debug, PartialEq)]
613+
pub enum SupportedPrimaryKeyType {
614+
// Integer types
615+
Int, // int, int4, integer
616+
BigInt, // bigint, int8
617+
SmallInt, // smallint, int2
618+
// String types
619+
Text, // text
620+
VarChar, // varchar
621+
Char, // char, bpchar
622+
CiText, // citext
623+
// UUID
624+
UUID, // uuid
625+
}
626+
627+
impl SupportedPrimaryKeyType {
628+
fn from_type_name(type_name: &str) -> Option<Self> {
629+
match type_name {
630+
// Integer types
631+
"int" | "int4" | "integer" => Some(Self::Int),
632+
"bigint" | "int8" => Some(Self::BigInt),
633+
"smallint" | "int2" => Some(Self::SmallInt),
634+
// String types
635+
"text" => Some(Self::Text),
636+
"varchar" => Some(Self::VarChar),
637+
"char" | "bpchar" => Some(Self::Char),
638+
"citext" => Some(Self::CiText),
639+
// UUID
640+
"uuid" => Some(Self::UUID),
641+
// Any other type is not supported
642+
_ => None,
643+
}
644+
}
645+
}
646+
600647
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
601648
pub struct SchemaDirectives {
602649
// @graphql({"inflect_names": true})

src/transpile.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,9 +1589,7 @@ impl NodeIdInstance {
15891589
param_context: &mut ParamContext,
15901590
) -> Result<String, String> {
15911591
// Validate that nodeId belongs to the table being queried
1592-
if self.schema_name != table.schema || self.table_name != table.name {
1593-
return Err("nodeId belongs to a different collection".to_string());
1594-
}
1592+
self.validate(table)?;
15951593

15961594
let pkey = table
15971595
.primary_key()

0 commit comments

Comments
 (0)