Skip to content

feat(completions): complete in WITH CHECK and USING clauses #422

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: feat/to-role
Choose a base branch
from
175 changes: 121 additions & 54 deletions crates/pgt_completions/src/context/base_parser.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::iter::Peekable;

use pgt_text_size::{TextRange, TextSize};
use std::iter::Peekable;

pub(crate) struct TokenNavigator {
tokens: Peekable<std::vec::IntoIter<WordWithIndex>>,
Expand Down Expand Up @@ -101,73 +100,139 @@ impl WordWithIndex {
}
}

/// Note: A policy name within quotation marks will be considered a single word.
pub(crate) fn sql_to_words(sql: &str) -> Result<Vec<WordWithIndex>, String> {
let mut words = vec![];

let mut start_of_word: Option<usize> = None;
let mut current_word = String::new();
let mut in_quotation_marks = false;

for (current_position, current_char) in sql.char_indices() {
if (current_char.is_ascii_whitespace() || current_char == ';')
&& !current_word.is_empty()
&& start_of_word.is_some()
&& !in_quotation_marks
{
words.push(WordWithIndex {
word: current_word,
start: start_of_word.unwrap(),
end: current_position,
});

current_word = String::new();
start_of_word = None;
} else if (current_char.is_ascii_whitespace() || current_char == ';')
&& current_word.is_empty()
{
// do nothing
} else if current_char == '"' && start_of_word.is_none() {
in_quotation_marks = true;
current_word.push(current_char);
start_of_word = Some(current_position);
} else if current_char == '"' && start_of_word.is_some() {
current_word.push(current_char);
in_quotation_marks = false;
} else if start_of_word.is_some() {
current_word.push(current_char)
pub(crate) struct SubStatementParser {
start_of_word: Option<usize>,
current_word: String,
in_quotation_marks: bool,
is_fn_call: bool,
words: Vec<WordWithIndex>,
}

impl SubStatementParser {
pub(crate) fn parse(sql: &str) -> Result<Vec<WordWithIndex>, String> {
let mut parser = SubStatementParser {
start_of_word: None,
current_word: String::new(),
in_quotation_marks: false,
is_fn_call: false,
words: vec![],
};

parser.collect_words(sql);

if parser.in_quotation_marks {
Err("String was not closed properly.".into())
} else {
start_of_word = Some(current_position);
current_word.push(current_char);
Ok(parser.words)
}
}

if let Some(start_of_word) = start_of_word {
if !current_word.is_empty() {
words.push(WordWithIndex {
word: current_word,
start: start_of_word,
end: sql.len(),
});
pub fn collect_words(&mut self, sql: &str) {
for (pos, c) in sql.char_indices() {
match c {
'"' => {
if !self.has_started_word() {
self.in_quotation_marks = true;
self.add_char(c);
self.start_word(pos);
} else {
self.in_quotation_marks = false;
self.add_char(c);
}
}

'(' => {
if !self.has_started_word() {
self.push_char_as_word(c, pos);
} else {
self.add_char(c);
self.is_fn_call = true;
}
}

')' => {
if self.is_fn_call {
self.add_char(c);
self.is_fn_call = false;
} else {
if self.has_started_word() {
self.push_word(pos);
}
self.push_char_as_word(c, pos);
}
}

_ => {
if c.is_ascii_whitespace() || c == ';' {
if self.in_quotation_marks {
self.add_char(c);
} else if !self.is_empty() && self.has_started_word() {
self.push_word(pos);
}
} else if self.has_started_word() {
self.add_char(c);
} else {
self.start_word(pos);
self.add_char(c)
}
}
}
}

if self.has_started_word() && !self.is_empty() {
self.push_word(sql.len())
}
}

if in_quotation_marks {
Err("String was not closed properly.".into())
} else {
Ok(words)
fn is_empty(&self) -> bool {
self.current_word.is_empty()
}

fn add_char(&mut self, c: char) {
self.current_word.push(c)
}

fn start_word(&mut self, pos: usize) {
self.start_of_word = Some(pos);
}

fn has_started_word(&self) -> bool {
self.start_of_word.is_some()
}

fn push_char_as_word(&mut self, c: char, pos: usize) {
self.words.push(WordWithIndex {
word: String::from(c),
start: pos,
end: pos + 1,
});
}

fn push_word(&mut self, current_position: usize) {
self.words.push(WordWithIndex {
word: self.current_word.clone(),
start: self.start_of_word.unwrap(),
end: current_position,
});
self.current_word = String::new();
self.start_of_word = None;
}
}

/// Note: A policy name within quotation marks will be considered a single word.
pub(crate) fn sql_to_words(sql: &str) -> Result<Vec<WordWithIndex>, String> {
SubStatementParser::parse(sql)
}

#[cfg(test)]
mod tests {
use crate::context::base_parser::{WordWithIndex, sql_to_words};
use crate::context::base_parser::{SubStatementParser, WordWithIndex, sql_to_words};

#[test]
fn determines_positions_correctly() {
let query = "\ncreate policy \"my cool pol\"\n\ton auth.users\n\tas permissive\n\tfor select\n\t\tto public\n\t\tusing (true);".to_string();
let query = "\ncreate policy \"my cool pol\"\n\ton auth.users\n\tas permissive\n\tfor select\n\t\tto public\n\t\tusing (auth.uid());".to_string();

let words = sql_to_words(query.as_str()).unwrap();
let words = SubStatementParser::parse(query.as_str()).unwrap();

assert_eq!(words[0], to_word("create", 1, 7));
assert_eq!(words[1], to_word("policy", 8, 14));
Expand All @@ -181,7 +246,9 @@ mod tests {
assert_eq!(words[9], to_word("to", 73, 75));
assert_eq!(words[10], to_word("public", 78, 84));
assert_eq!(words[11], to_word("using", 87, 92));
assert_eq!(words[12], to_word("(true)", 93, 99));
assert_eq!(words[12], to_word("(", 93, 94));
assert_eq!(words[13], to_word("auth.uid()", 94, 104));
assert_eq!(words[14], to_word(")", 104, 105));
}

#[test]
Expand Down
27 changes: 25 additions & 2 deletions crates/pgt_completions/src/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ pub enum WrappingClause<'a> {
SetStatement,
AlterRole,
DropRole,

/// `PolicyCheck` refers to either the `WITH CHECK` or the `USING` clause
/// in a policy statement.
/// ```sql
/// CREATE POLICY "my pol" ON PUBLIC.USERS
/// FOR SELECT
/// USING (...) -- this one!
/// ```
PolicyCheck,
}

#[derive(PartialEq, Eq, Hash, Debug, Clone)]
Expand Down Expand Up @@ -78,6 +87,7 @@ pub(crate) enum NodeUnderCursor<'a> {
text: NodeText,
range: TextRange,
kind: String,
previous_node_kind: Option<String>,
},
}

Expand Down Expand Up @@ -222,6 +232,7 @@ impl<'a> CompletionContext<'a> {
text: revoke_context.node_text.into(),
range: revoke_context.node_range,
kind: revoke_context.node_kind.clone(),
previous_node_kind: None,
});

if revoke_context.node_kind == "revoke_table" {
Expand Down Expand Up @@ -249,6 +260,7 @@ impl<'a> CompletionContext<'a> {
text: grant_context.node_text.into(),
range: grant_context.node_range,
kind: grant_context.node_kind.clone(),
previous_node_kind: None,
});

if grant_context.node_kind == "grant_table" {
Expand Down Expand Up @@ -276,6 +288,7 @@ impl<'a> CompletionContext<'a> {
text: policy_context.node_text.into(),
range: policy_context.node_range,
kind: policy_context.node_kind.clone(),
previous_node_kind: Some(policy_context.previous_node_kind),
});

if policy_context.node_kind == "policy_table" {
Expand All @@ -295,7 +308,13 @@ impl<'a> CompletionContext<'a> {
}
"policy_role" => Some(WrappingClause::ToRoleAssignment),
"policy_table" => Some(WrappingClause::From),
_ => None,
_ => {
if policy_context.in_check_or_using_clause {
Some(WrappingClause::PolicyCheck)
} else {
None
}
}
};
}

Expand Down Expand Up @@ -785,7 +804,11 @@ impl<'a> CompletionContext<'a> {
.is_some_and(|sib| kinds.contains(&sib.kind()))
}

NodeUnderCursor::CustomNode { .. } => false,
NodeUnderCursor::CustomNode {
previous_node_kind, ..
} => previous_node_kind
.as_ref()
.is_some_and(|k| kinds.contains(&k.as_str())),
}
})
}
Expand Down
Loading
Loading