// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Simple example of a catalog/schema implementation.
//!
//! This example requires the repository's git submodules to be initialized,
//! since it reads data from the `parquet-testing` and `testing` submodules.
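//!
//! It registers a custom `CatalogList`, a `CatalogProvider`, and two
//! `SchemaProvider`s (one per file format), then queries the resulting
//! tables and drops one via DDL. (Assuming this file lives at
//! `datafusion-examples/examples/catalog.rs`, it can be run with
//! `cargo run --example catalog`.)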
use async_trait::async_trait;
use datafusion::{
    arrow::util::pretty,
    catalog::{
        catalog::{CatalogList, CatalogProvider},
        schema::SchemaProvider,
    },
    datasource::{
        file_format::{csv::CsvFormat, parquet::ParquetFormat, FileFormat},
        listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl},
        TableProvider,
    },
    error::Result,
    execution::context::SessionState,
    prelude::SessionContext,
};
use std::{
    any::Any,
    collections::HashMap,
    path::{Path, PathBuf},
    sync::{Arc, RwLock},
};

#[tokio::main]
async fn main() -> Result<()> {
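    // locate the test-data submodules relative to this crate's manifest dir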
    let repo_dir = std::fs::canonicalize(
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            // parent dir of datafusion-examples = repo root
            .join(".."),
    )
    .unwrap();
    let mut ctx = SessionContext::new();
    let state = ctx.state();
    let catlist = Arc::new(CustomCatalogList::new());
    // replace the context's default catalog list (a `MemoryCatalogList`) with
    // our custom implementation; each context has exactly one catalog list
    ctx.register_catalog_list(catlist.clone());

    // initialize our catalog and schemas
    let catalog = DirCatalog::new();
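    // each `DirSchema` below exposes every file with the given extension in
    // `dir` as a table named after the file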
    let parquet_schema = DirSchema::create(
        &state,
        DirSchemaOpts {
            format: Arc::new(ParquetFormat::default()),
            dir: &repo_dir.join("parquet-testing").join("data"),
            ext: "parquet",
        },
    )
    .await?;
    let csv_schema = DirSchema::create(
        &state,
        DirSchemaOpts {
            format: Arc::new(CsvFormat::default()),
            dir: &repo_dir.join("testing").join("data").join("csv"),
            ext: "csv",
        },
    )
    .await?;
    // register the schemas into the catalog
    catalog.register_schema("parquet", parquet_schema.clone())?;
    catalog.register_schema("csv", csv_schema.clone())?;
    // register our catalog in the context
    ctx.register_catalog("dircat", Arc::new(catalog));
    {
        // the catalog was passed down into our custom catalog list since we
        // overrode the context's default
        let catalogs = catlist.catalogs.read().unwrap();
        assert!(catalogs.contains_key("dircat"));
    };
    // take the first five (an arbitrary number) keys from our schema's hashmap.
    // in our `DirSchema`, a table's name is its key in the hashmap, so every
    // key in the hashmap is now queryable in the DataFusion context.
    let parquet_tables = {
        let tables = parquet_schema.tables.read().unwrap();
        tables.keys().take(5).cloned().collect::<Vec<_>>()
    };
    for table in parquet_tables {
        println!("querying table {table} from parquet schema");
        let df = ctx
            .sql(&format!("select * from dircat.parquet.\"{table}\" "))
            .await?
            .limit(0, Some(5))?;
        let result = df.collect().await;
        match result {
            Ok(batches) => {
                pretty::print_batches(&batches).unwrap();
            }
            Err(e) => {
                println!("table '{table}' query failed due to {e}");
            }
        }
    }
    let table_to_drop = {
        let parquet_tables = parquet_schema.tables.read().unwrap();
        parquet_tables.keys().next().unwrap().to_owned()
    };
    // DDL example: DataFusion plans the DROP TABLE statement and routes it to
    // our schema's `deregister_table` implementation
    let df = ctx
        .sql(&format!("DROP TABLE dircat.parquet.\"{table_to_drop}\""))
        .await?;
    df.collect().await?;
    let parquet_tables = parquet_schema.tables.read().unwrap();
    // the table is gone from our schema: DataFusion called `deregister_table`
    assert!(!parquet_tables.contains_key(&table_to_drop));
    Ok(())
}

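/// Options used to build a [`DirSchema`] over one directory and file format.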
struct DirSchemaOpts<'a> {
    ext: &'a str,
    dir: &'a Path,
    format: Arc<dyn FileFormat>,
}

/// Schema where every file with extension `ext` in a given `dir` is a table.
struct DirSchema {
    ext: String,
    tables: RwLock<HashMap<String, Arc<dyn TableProvider>>>,
}

impl DirSchema {
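    /// Scans `dir` for files ending in `ext` and registers each one as a
    /// [`ListingTable`] keyed by its file name.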
    async fn create(state: &SessionState, opts: DirSchemaOpts<'_>) -> Result<Arc<Self>> {
        let DirSchemaOpts { ext, dir, format } = opts;
        let mut tables = HashMap::new();
        let listdir = std::fs::read_dir(dir).unwrap();
        for res in listdir {
            let entry = res.unwrap();
            let filename = entry.file_name().to_str().unwrap().to_string();
            if !filename.ends_with(ext) {
                continue;
            }

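            // expose the file as a `ListingTable`, inferring its Arrow schema
            // from the file contents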
            let table_path = ListingTableUrl::parse(entry.path().to_str().unwrap())?;
            let opts = ListingOptions::new(format.clone());
            let conf = ListingTableConfig::new(table_path)
                .with_listing_options(opts)
                .infer_schema(state)
                .await?;
            let table = ListingTable::try_new(conf)?;
            tables.insert(filename, Arc::new(table) as Arc<dyn TableProvider>);
        }
        Ok(Arc::new(Self {
            tables: RwLock::new(tables),
            ext: ext.to_string(),
        }))
    }

    #[allow(unused)]
    fn name(&self) -> &str {
        &self.ext
    }
}

#[async_trait]
impl SchemaProvider for DirSchema {
    fn as_any(&self) -> &dyn Any {
        self
    }

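    /// Returns the name of every table in this schema (here, the file names).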
    fn table_names(&self) -> Vec<String> {
        let tables = self.tables.read().unwrap();
        tables.keys().cloned().collect::<Vec<_>>()
    }

    async fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> {
        let tables = self.tables.read().unwrap();
        tables.get(name).cloned()
    }

    fn table_exist(&self, name: &str) -> bool {
        let tables = self.tables.read().unwrap();
        tables.contains_key(name)
    }

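    /// Adds a table to this schema; DataFusion routes table registration
    /// (e.g. `CREATE EXTERNAL TABLE`) through this method.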
    fn register_table(
        &self,
        name: String,
        table: Arc<dyn TableProvider>,
    ) -> Result<Option<Arc<dyn TableProvider>>> {
        let mut tables = self.tables.write().unwrap();
        println!("adding table {name}");
        tables.insert(name, table.clone());
        Ok(Some(table))
    }

    /// Removes an existing table from this schema and returns it, or
    /// `Ok(None)` if no table of that name exists.
    fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
        let mut tables = self.tables.write().unwrap();
        println!("dropping table {name}");
        Ok(tables.remove(name))
    }
}

/// Catalog holding multiple schemas.
struct DirCatalog {
    schemas: RwLock<HashMap<String, Arc<dyn SchemaProvider>>>,
}

impl DirCatalog {
    fn new() -> Self {
        Self {
            schemas: RwLock::new(HashMap::new()),
        }
    }
}

impl CatalogProvider for DirCatalog {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn register_schema(
        &self,
        name: &str,
        schema: Arc<dyn SchemaProvider>,
    ) -> Result<Option<Arc<dyn SchemaProvider>>> {
        let mut schema_map = self.schemas.write().unwrap();
        schema_map.insert(name.to_owned(), schema.clone());
        Ok(Some(schema))
    }

    fn schema_names(&self) -> Vec<String> {
        let schemas = self.schemas.read().unwrap();
        schemas.keys().cloned().collect()
    }

    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
        let schemas = self.schemas.read().unwrap();
        schemas.get(name).cloned()
    }
}

/// A catalog list holds multiple catalogs; each context has exactly one.
struct CustomCatalogList {
    catalogs: RwLock<HashMap<String, Arc<dyn CatalogProvider>>>,
}

impl CustomCatalogList {
    fn new() -> Self {
        Self {
            catalogs: RwLock::new(HashMap::new()),
        }
    }
}

impl CatalogList for CustomCatalogList {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn register_catalog(
        &self,
        name: String,
        catalog: Arc<dyn CatalogProvider>,
    ) -> Option<Arc<dyn CatalogProvider>> {
        let mut cats = self.catalogs.write().unwrap();
        cats.insert(name, catalog.clone());
        Some(catalog)
    }

    /// Retrieves the list of available catalog names.
    fn catalog_names(&self) -> Vec<String> {
        let cats = self.catalogs.read().unwrap();
        cats.keys().cloned().collect()
    }

    /// Retrieves a specific catalog by name, provided it exists.
    fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
        let cats = self.catalogs.read().unwrap();
        cats.get(name).cloned()
    }
}