Skip to content

Commit 673b8d6

Browse files
merge with main
2 parents 6f12fcc + 895e14c commit 673b8d6

File tree

5 files changed

+96
-7
lines changed

5 files changed

+96
-7
lines changed

crates/pgt_statement_splitter/src/lib.rs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,9 @@ mod tests {
4545
assert_eq!(
4646
self.parse.ranges.len(),
4747
expected.len(),
48-
"Expected {} statements, got {}: {:?}",
48+
"Expected {} statements for input {}, got {}: {:?}",
4949
expected.len(),
50+
self.input,
5051
self.parse.ranges.len(),
5152
self.parse
5253
.ranges
@@ -114,10 +115,24 @@ mod tests {
114115

115116
#[test]
116117
fn grant() {
117-
Tester::from("GRANT SELECT ON TABLE \"public\".\"my_table\" TO \"my_role\";")
118-
.expect_statements(vec![
119-
"GRANT SELECT ON TABLE \"public\".\"my_table\" TO \"my_role\";",
120-
]);
118+
let stmts = vec![
119+
"GRANT SELECT ON TABLE \"public\".\"my_table\" TO \"my_role\";",
120+
"GRANT UPDATE ON TABLE \"public\".\"my_table\" TO \"my_role\";",
121+
"GRANT DELETE ON TABLE \"public\".\"my_table\" TO \"my_role\";",
122+
"GRANT INSERT ON TABLE \"public\".\"my_table\" TO \"my_role\";",
123+
"GRANT CREATE ON SCHEMA \"public\" TO \"my_role\";",
124+
"GRANT ALL PRIVILEGES ON DATABASE \"my_database\" TO \"my_role\";",
125+
"GRANT USAGE ON SCHEMA \"public\" TO \"my_role\";",
126+
"GRANT EXECUTE ON FUNCTION \"public\".\"my_function\"() TO \"my_role\";",
127+
"GRANT REFERENCES ON TABLE \"public\".\"my_table\" TO \"my_role\";",
128+
"GRANT SELECT, UPDATE ON ALL TABLES IN SCHEMA \"public\" TO \"my_role\";",
129+
"GRANT SELECT, INSERT ON public.users TO anon WITH GRANT OPION GRANTED BY owner;",
130+
"GRANT owner, admin to anon WITH ADMIN;",
131+
];
132+
133+
for stmt in stmts {
134+
Tester::from(stmt).expect_statements(vec![stmt]);
135+
}
121136
}
122137

123138
#[test]

crates/pgt_statement_splitter/src/parser/common.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ pub(crate) fn unknown(p: &mut Parser, exclude: &[SyntaxKind]) {
217217
SyntaxKind::Except,
218218
// for grant
219219
SyntaxKind::Grant,
220+
SyntaxKind::Ascii44,
220221
]
221222
.iter()
222223
.all(|x| Some(x) != prev.as_ref())
@@ -246,6 +247,7 @@ pub(crate) fn unknown(p: &mut Parser, exclude: &[SyntaxKind]) {
246247
SyntaxKind::Instead,
247248
// for grant
248249
SyntaxKind::Grant,
250+
SyntaxKind::Ascii44,
249251
]
250252
.iter()
251253
.all(|x| Some(x) != prev.as_ref())
@@ -263,6 +265,10 @@ pub(crate) fn unknown(p: &mut Parser, exclude: &[SyntaxKind]) {
263265
SyntaxKind::Check,
264266
// TIMESTAMP WITH TIME ZONE should not start a new statement
265267
SyntaxKind::Time,
268+
SyntaxKind::Grant,
269+
SyntaxKind::Admin,
270+
SyntaxKind::Inherit,
271+
SyntaxKind::Set,
266272
]
267273
.iter()
268274
.all(|x| Some(x) != next.as_ref())
@@ -271,6 +277,22 @@ pub(crate) fn unknown(p: &mut Parser, exclude: &[SyntaxKind]) {
271277
}
272278
p.advance();
273279
}
280+
281+
Some(SyntaxKind::Create) => {
282+
let prev = p.look_back().map(|t| t.kind);
283+
if [
284+
// for grant
285+
SyntaxKind::Grant,
286+
SyntaxKind::Ascii44,
287+
]
288+
.iter()
289+
.all(|x| Some(x) != prev.as_ref())
290+
{
291+
break;
292+
}
293+
294+
p.advance();
295+
}
274296
Some(_) => {
275297
break;
276298
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
GRANT CREATE ON SCHEMA public TO anon;
2+
3+
GRANT SELECT, INSERT ON public.users TO anon WITH GRANT OPTION GRANTED BY Owner;
4+
5+
GRANT read_access, write_access TO user_role
6+
WITH INHERIT TRUE
7+
GRANTED BY security_admin;
8+
9+
GRANT manager_role TO employee_role
10+
WITH ADMIN OPTION
11+
GRANTED BY admin_role;

crates/pgt_workspace/src/features/completions.rs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,24 @@ pub(crate) fn get_statement_for_completions(
4949
if count == 1 {
5050
eligible_statements.next()
5151
} else {
52-
let mut prev_stmt = None;
52+
let mut prev_stmt: Option<(StatementId, TextRange, String, Arc<tree_sitter::Tree>)> = None;
5353

5454
for current_stmt in eligible_statements {
5555
/*
5656
* If we have multiple statements, we want to make sure that we do not overlap
5757
* with the next one.
5858
*
5959
* select 1 |select 1;
60+
*
61+
* This is however ok if the current statement is a child of the previous one,
62+
* such as in CREATE FUNCTION bodies.
6063
*/
61-
if prev_stmt.is_some_and(|_| current_stmt.1.contains(position)) {
64+
if prev_stmt.is_some_and(|prev| {
65+
current_stmt.1.contains(position) && !current_stmt.0.is_child_of(&prev.0)
66+
}) {
6267
return None;
6368
}
69+
6470
prev_stmt = Some(current_stmt)
6571
}
6672

@@ -162,6 +168,30 @@ mod tests {
162168
assert_eq!(text, "select * from")
163169
}
164170

171+
#[test]
172+
fn identifies_nested_stmts() {
173+
let sql = format!(
174+
r#"
175+
create or replace function one()
176+
returns integer
177+
language sql
178+
as $$
179+
select {} from cool;
180+
$$;
181+
"#,
182+
CURSOR_POSITION
183+
);
184+
185+
let sql = sql.trim();
186+
187+
let (doc, position) = get_doc_and_pos(sql);
188+
189+
let (_, _, text, _) =
190+
get_statement_for_completions(&doc, position).expect("Expected Statement");
191+
192+
assert_eq!(text.trim(), "select from cool;")
193+
}
194+
165195
#[test]
166196
fn does_not_consider_too_far_offset() {
167197
let sql = format!("select * from {}", CURSOR_POSITION);

crates/pgt_workspace/src/workspace/server/statement_identifier.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,17 @@ impl StatementId {
6666
matches!(self, StatementId::Child(_))
6767
}
6868

69+
pub fn is_child_of(&self, maybe_parent: &StatementId) -> bool {
70+
match self {
71+
StatementId::Root(_) => false,
72+
StatementId::Child(child_root) => match maybe_parent {
73+
StatementId::Root(parent_rood) => child_root == parent_rood,
74+
// TODO: can we have multiple nested statements?
75+
StatementId::Child(_) => false,
76+
},
77+
}
78+
}
79+
6980
pub fn parent(&self) -> Option<StatementId> {
7081
match self {
7182
StatementId::Root(_) => None,

0 commit comments

Comments
 (0)