Skip to content

Commit a9ac119

Browse files
authored
feat(lsp): invalidate schema cache (#588)
this is useful if you iterate quickly on your schema
1 parent 376654e commit a9ac119

File tree

8 files changed

+241
-4
lines changed

8 files changed

+241
-4
lines changed

crates/pgls_lsp/src/handlers/code_actions.rs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,21 @@ pub fn get_actions(
5353
.map(|reason| CodeActionDisabled { reason }),
5454
..Default::default()
5555
}),
56+
CommandActionCategory::InvalidateSchemaCache => Some(CodeAction {
57+
title: title.clone(),
58+
kind: Some(lsp_types::CodeActionKind::EMPTY),
59+
command: Some({
60+
Command {
61+
title: title.clone(),
62+
command: command_id,
63+
arguments: None,
64+
}
65+
}),
66+
disabled: action
67+
.disabled_reason
68+
.map(|reason| CodeActionDisabled { reason }),
69+
..Default::default()
70+
}),
5671
}
5772
}
5873

@@ -68,7 +83,8 @@ pub fn get_actions(
6883

6984
pub fn command_id(command: &CommandActionCategory) -> String {
7085
match command {
71-
CommandActionCategory::ExecuteStatement(_) => "pgt.executeStatement".into(),
86+
CommandActionCategory::ExecuteStatement(_) => "pgls.executeStatement".into(),
87+
CommandActionCategory::InvalidateSchemaCache => "pgls.invalidateSchemaCache".into(),
7288
}
7389
}
7490

@@ -80,7 +96,7 @@ pub async fn execute_command(
8096
let command = params.command;
8197

8298
match command.as_str() {
83-
"pgt.executeStatement" => {
99+
"pgls.executeStatement" => {
84100
let statement_id = serde_json::from_value::<pgls_workspace::workspace::StatementId>(
85101
params.arguments[0].clone(),
86102
)?;
@@ -105,7 +121,16 @@ pub async fn execute_command(
105121

106122
Ok(None)
107123
}
124+
"pgls.invalidateSchemaCache" => {
125+
session.workspace.invalidate_schema_cache(true)?;
108126

127+
session
128+
.client
129+
.show_message(MessageType::INFO, "Schema cache invalidated")
130+
.await;
131+
132+
Ok(None)
133+
}
109134
any => Err(anyhow!(format!("Unknown command: {}", any))),
110135
}
111136
}

crates/pgls_lsp/src/server.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,7 @@ impl ServerFactory {
461461
workspace_method!(builder, get_completions);
462462
workspace_method!(builder, register_project_folder);
463463
workspace_method!(builder, unregister_project_folder);
464+
workspace_method!(builder, invalidate_schema_cache);
464465

465466
let (service, socket) = builder.finish();
466467
ServerConnection { socket, service }

crates/pgls_lsp/tests/server.rs

Lines changed: 159 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -916,7 +916,7 @@ async fn test_execute_statement(test_db: PgPool) -> Result<()> {
916916
.find_map(|action_or_cmd| match action_or_cmd {
917917
lsp::CodeActionOrCommand::CodeAction(code_action) => {
918918
let command = code_action.command.as_ref();
919-
if command.is_some_and(|cmd| &cmd.command == "pgt.executeStatement") {
919+
if command.is_some_and(|cmd| &cmd.command == "pgls.executeStatement") {
920920
let command = command.unwrap();
921921
let arguments = command.arguments.as_ref().unwrap().clone();
922922
Some((command.command.clone(), arguments))
@@ -952,6 +952,164 @@ async fn test_execute_statement(test_db: PgPool) -> Result<()> {
952952
Ok(())
953953
}
954954

955+
#[sqlx::test(migrator = "pgls_test_utils::MIGRATIONS")]
956+
async fn test_invalidate_schema_cache(test_db: PgPool) -> Result<()> {
957+
let factory = ServerFactory::default();
958+
let mut fs = MemoryFileSystem::default();
959+
960+
let database = test_db
961+
.connect_options()
962+
.get_database()
963+
.unwrap()
964+
.to_string();
965+
let host = test_db.connect_options().get_host().to_string();
966+
967+
// Setup: Create a table with only id column (no name column yet)
968+
let setup = r#"
969+
create table public.users (
970+
id serial primary key
971+
);
972+
"#;
973+
974+
test_db
975+
.execute(setup)
976+
.await
977+
.expect("Failed to setup test database");
978+
979+
let mut conf = PartialConfiguration::init();
980+
conf.merge_with(PartialConfiguration {
981+
db: Some(PartialDatabaseConfiguration {
982+
database: Some(database),
983+
host: Some(host),
984+
..Default::default()
985+
}),
986+
..Default::default()
987+
});
988+
989+
fs.insert(
990+
url!("postgres-language-server.jsonc")
991+
.to_file_path()
992+
.unwrap(),
993+
serde_json::to_string_pretty(&conf).unwrap(),
994+
);
995+
996+
let (service, client) = factory
997+
.create_with_fs(None, DynRef::Owned(Box::new(fs)))
998+
.into_inner();
999+
1000+
let (stream, sink) = client.split();
1001+
let mut server = Server::new(service);
1002+
1003+
let (sender, _receiver) = channel(CHANNEL_BUFFER_SIZE);
1004+
let reader = tokio::spawn(client_handler(stream, sink, sender));
1005+
1006+
server.initialize().await?;
1007+
server.initialized().await?;
1008+
1009+
server.load_configuration().await?;
1010+
1011+
// Open a document to get completions from
1012+
let doc_content = "select from public.users;";
1013+
server.open_document(doc_content).await?;
1014+
1015+
// Get completions before adding the column - 'name' should NOT be present
1016+
let completions_before = server
1017+
.get_completion(CompletionParams {
1018+
work_done_progress_params: WorkDoneProgressParams::default(),
1019+
partial_result_params: PartialResultParams::default(),
1020+
context: None,
1021+
text_document_position: TextDocumentPositionParams {
1022+
text_document: TextDocumentIdentifier {
1023+
uri: url!("document.sql"),
1024+
},
1025+
position: Position {
1026+
line: 0,
1027+
character: 7,
1028+
},
1029+
},
1030+
})
1031+
.await?
1032+
.unwrap();
1033+
1034+
let items_before = match completions_before {
1035+
CompletionResponse::Array(ref a) => a,
1036+
CompletionResponse::List(ref l) => &l.items,
1037+
};
1038+
1039+
let has_name_before = items_before.iter().any(|item| {
1040+
item.label == "name"
1041+
&& item.label_details.as_ref().is_some_and(|d| {
1042+
d.description
1043+
.as_ref()
1044+
.is_some_and(|desc| desc.contains("public.users"))
1045+
})
1046+
});
1047+
1048+
assert!(
1049+
!has_name_before,
1050+
"Column 'name' should not be in completions before it's added to the table"
1051+
);
1052+
1053+
// Add the missing column to the database
1054+
let alter_table = r#"
1055+
alter table public.users
1056+
add column name text;
1057+
"#;
1058+
1059+
test_db
1060+
.execute(alter_table)
1061+
.await
1062+
.expect("Failed to add column to table");
1063+
1064+
// Invalidate the schema cache (all = false for current connection only)
1065+
server
1066+
.request::<bool, ()>("pgt/invalidate_schema_cache", "_invalidate_cache", false)
1067+
.await?;
1068+
1069+
// Get completions after invalidating cache - 'name' should NOW be present
1070+
let completions_after = server
1071+
.get_completion(CompletionParams {
1072+
work_done_progress_params: WorkDoneProgressParams::default(),
1073+
partial_result_params: PartialResultParams::default(),
1074+
context: None,
1075+
text_document_position: TextDocumentPositionParams {
1076+
text_document: TextDocumentIdentifier {
1077+
uri: url!("document.sql"),
1078+
},
1079+
position: Position {
1080+
line: 0,
1081+
character: 7,
1082+
},
1083+
},
1084+
})
1085+
.await?
1086+
.unwrap();
1087+
1088+
let items_after = match completions_after {
1089+
CompletionResponse::Array(ref a) => a,
1090+
CompletionResponse::List(ref l) => &l.items,
1091+
};
1092+
1093+
let has_name_after = items_after.iter().any(|item| {
1094+
item.label == "name"
1095+
&& item.label_details.as_ref().is_some_and(|d| {
1096+
d.description
1097+
.as_ref()
1098+
.is_some_and(|desc| desc.contains("public.users"))
1099+
})
1100+
});
1101+
1102+
assert!(
1103+
has_name_after,
1104+
"Column 'name' should be in completions after schema cache invalidation"
1105+
);
1106+
1107+
server.shutdown().await?;
1108+
reader.abort();
1109+
1110+
Ok(())
1111+
}
1112+
9551113
#[sqlx::test(migrator = "pgls_test_utils::MIGRATIONS")]
9561114
async fn test_issue_281(test_db: PgPool) -> Result<()> {
9571115
let factory = ServerFactory::default();

crates/pgls_workspace/src/features/code_actions.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ pub struct CommandAction {
4848
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
4949
pub enum CommandActionCategory {
5050
ExecuteStatement(StatementId),
51+
InvalidateSchemaCache,
5152
}
5253

5354
#[derive(Debug, serde::Serialize, serde::Deserialize)]

crates/pgls_workspace/src/workspace.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,14 @@ pub trait Workspace: Send + Sync + RefUnwindSafe {
158158
&self,
159159
params: ExecuteStatementParams,
160160
) -> Result<ExecuteStatementResult, WorkspaceError>;
161+
162+
/// Invalidate the schema cache.
163+
///
164+
/// # Arguments
165+
/// * `all` - If true, clears all cached schemas. If false, clears only the current connection's cache.
166+
///
167+
/// The schema will be reloaded lazily on the next operation that requires it.
168+
fn invalidate_schema_cache(&self, all: bool) -> Result<(), WorkspaceError>;
161169
}
162170

163171
/// Convenience function for constructing a server instance of [Workspace]

crates/pgls_workspace/src/workspace/client.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,4 +168,8 @@ where
168168
) -> Result<crate::features::on_hover::OnHoverResult, WorkspaceError> {
169169
self.request("pgt/on_hover", params)
170170
}
171+
172+
fn invalidate_schema_cache(&self, all: bool) -> Result<(), WorkspaceError> {
173+
self.request("pgt/invalidate_schema_cache", all)
174+
}
171175
}

crates/pgls_workspace/src/workspace/server.rs

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ impl Workspace for WorkspaceServer {
358358
None => Some("Statement execution not allowed against database.".into()),
359359
};
360360

361-
let actions = parser
361+
let mut actions: Vec<CodeAction> = parser
362362
.iter_with_filter(
363363
DefaultMapper,
364364
CursorPositionFilter::new(params.cursor_position),
@@ -379,6 +379,20 @@ impl Workspace for WorkspaceServer {
379379
})
380380
.collect();
381381

382+
let invalidate_disabled_reason = if self.get_current_connection().is_some() {
383+
None
384+
} else {
385+
Some("No database connection available.".into())
386+
};
387+
388+
actions.push(CodeAction {
389+
title: "Invalidate Schema Cache".into(),
390+
kind: CodeActionKind::Command(CommandAction {
391+
category: CommandActionCategory::InvalidateSchemaCache,
392+
}),
393+
disabled_reason: invalidate_disabled_reason,
394+
});
395+
382396
Ok(CodeActionsResult { actions })
383397
}
384398

@@ -424,6 +438,19 @@ impl Workspace for WorkspaceServer {
424438
})
425439
}
426440

441+
fn invalidate_schema_cache(&self, all: bool) -> Result<(), WorkspaceError> {
442+
if all {
443+
self.schema_cache.clear_all();
444+
} else {
445+
// Only clear current connection if one exists
446+
if let Some(pool) = self.get_current_connection() {
447+
self.schema_cache.clear(&pool);
448+
}
449+
// If no connection, nothing to clear - just return Ok
450+
}
451+
Ok(())
452+
}
453+
427454
#[ignored_path(path=&params.path)]
428455
fn pull_diagnostics(
429456
&self,

crates/pgls_workspace/src/workspace/server/schema_cache_manager.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,17 @@ impl SchemaCacheManager {
4646
schemas.insert(key, schema_cache.clone());
4747
Ok(schema_cache)
4848
}
49+
50+
/// Clear the schema cache for a specific connection
51+
pub fn clear(&self, pool: &PgPool) {
52+
let key: ConnectionKey = pool.into();
53+
let mut schemas = self.schemas.write().unwrap();
54+
schemas.remove(&key);
55+
}
56+
57+
/// Clear all schema caches
58+
pub fn clear_all(&self) {
59+
let mut schemas = self.schemas.write().unwrap();
60+
schemas.clear();
61+
}
4962
}

0 commit comments

Comments
 (0)