diff --git a/src/catalog.rs b/src/catalog.rs index ffe48b5f..837e460a 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -45,6 +45,7 @@ pub enum Error { CollectionAlreadyExists { name: String }, FunctionAlreadyExists { name: String }, FunctionDeserializationError { reason: String }, + FunctionNotFound { names: String }, // Creating a table in / dropping the staging schema UsedStagingSchema, SqlxError(sqlx::Error), @@ -128,11 +129,13 @@ impl From for DataFusionError { Error::FunctionAlreadyExists { name } => { DataFusionError::Plan(format!("Function {name:?} already exists")) } + Error::FunctionNotFound { names } => { + DataFusionError::Plan(format!("Function {names:?} not found")) + } Error::UsedStagingSchema => DataFusionError::Plan( "The staging schema can only be referenced via CREATE EXTERNAL TABLE" .to_string(), ), - // Miscellaneous sqlx error. We want to log it but it's not worth showing to the user. Error::SqlxError(e) => DataFusionError::Internal(format!( "Internal SQL error: {:?}", @@ -235,6 +238,13 @@ pub trait FunctionCatalog: Sync + Send { &self, database_id: DatabaseId, ) -> Result>; + + async fn drop_function( + &self, + database_id: DatabaseId, + if_exists: bool, + func_names: &[String], + ) -> Result<()>; } #[derive(Clone)] @@ -680,4 +690,28 @@ impl FunctionCatalog for DefaultCatalog { }) .collect::>>() } + + async fn drop_function( + &self, + database_id: DatabaseId, + if_exists: bool, + func_names: &[String], + ) -> Result<()> { + match self.repository.drop_function(database_id, func_names).await { + Ok(id) => Ok(id), + Err(RepositoryError::FKConstraintViolation(_)) => { + Err(Error::DatabaseDoesNotExist { id: database_id }) + } + Err(RepositoryError::SqlxError(sqlx::error::Error::RowNotFound)) => { + if if_exists { + Ok(()) + } else { + Err(Error::FunctionNotFound { + names: func_names.join(", "), + }) + } + } + Err(e) => Err(Self::to_sqlx_error(e)), + } + } } diff --git a/src/context.rs b/src/context.rs index fcfd9195..73e0b2c0 100644 --- a/src/context.rs +++ b/src/context.rs @@ -111,8 +111,8 @@ use crate::{ catalog::{FunctionCatalog, TableCatalog}, data_types::DatabaseId, nodes::{ - CreateFunction, CreateTable, DropSchema, RenameTable, SeafowlExtensionNode, - Vacuum, + CreateFunction, CreateTable, DropFunction, DropSchema, RenameTable, + SeafowlExtensionNode, Vacuum, }, schema::Schema as SeafowlSchema, version::TableVersionProcessor, @@ -1216,6 +1216,21 @@ impl SeafowlContext for DefaultSeafowlContext { })), })) } + Statement::DropFunction{ + if_exists, + func_desc, + option: _ + } => { + let func_names: Vec = + func_desc.iter().map(|desc| desc.name.to_string()).collect(); + Ok(LogicalPlan::Extension(Extension { + node: Arc::new(SeafowlExtensionNode::DropFunction(DropFunction { + if_exists, + func_names, + output_schema: Arc::new(DFSchema::empty()), + })) + })) + } _ => Err(Error::NotImplemented(format!( "Unsupported SQL statement: {s:?}" ))), @@ -1732,6 +1747,16 @@ impl SeafowlContext for DefaultSeafowlContext { Ok(make_dummy_exec()) } + SeafowlExtensionNode::DropFunction(DropFunction { + if_exists, + func_names, + output_schema: _, + }) => { + self.function_catalog + .drop_function(self.database_id, *if_exists, func_names) + .await?; + Ok(make_dummy_exec()) + } SeafowlExtensionNode::RenameTable(RenameTable { old_name, new_name, @@ -2598,4 +2623,66 @@ mod tests { "Internal error: Error initializing WASM + MessagePack UDF \"invalidfn\": Internal(\"Error loading WASM module: failed to parse WebAssembly module")); Ok(()) } + + #[tokio::test] + async fn test_drop_function() -> Result<()> { + let sf_context = in_memory_context().await; + + let err = sf_context + .plan_query(r#"DROP FUNCTION nonexistentfunction"#) + .await + .unwrap_err(); + + assert_eq!( + err.to_string(), + "Error during planning: Function \"nonexistentfunction\" not found" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_drop_function_if_exists() -> Result<()> { + let sf_context = in_memory_context().await; + + let plan = sf_context + .plan_query(r#"DROP FUNCTION IF EXISTS nonexistentfunction"#) + .await; + assert!(plan.is_ok()); + Ok(()) + } + + #[tokio::test] + async fn test_create_and_drop_two_functions() -> Result<()> { + let sf_context = in_memory_context().await; + + let create_function_stmt = r#"CREATE FUNCTION sintau AS ' + { + "entrypoint": "sintau", + "language": "wasm", + "input_types": ["int"], + "return_type": "int", + "data": "AGFzbQEAAAABDQJgAX0BfWADfX9/AX0DBQQAAAABBQQBAUREBxgDBnNpbnRhdQAABGV4cDIAAQRsb2cyAAIKjgEEKQECfUMAAAA/IgIgACAAjpMiACACk4siAZMgAZZBAEEYEAMgAiAAk5gLGQAgACAAjiIAk0EYQSwQA7wgAKhBF3RqvgslAQF/IAC8IgFBF3ZB/wBrsiABQQl0s0MAAIBPlUEsQcQAEAOSCyIBAX0DQCADIACUIAEqAgCSIQMgAUEEaiIBIAJrDQALIAMLC0oBAEEAC0Q/x2FC2eATQUuqKsJzsqY9QAHJQH6V0DZv+V88kPJTPSJndz6sZjE/HQCAP/clMD0D/T++F6bRPkzcNL/Tgrg//IiKNwBqBG5hbWUBHwQABnNpbnRhdQEEZXhwMgIEbG9nMgMIZXZhbHBvbHkCNwQAAwABeAECeDECBGhhbGYBAQABeAICAAF4AQJ4aQMEAAF4AQVzdGFydAIDZW5kAwZyZXN1bHQDCQEDAQAEbG9vcA==" + }';"#; + + let create_function_stmt2 = r#"CREATE FUNCTION sintau2 AS ' + { + "entrypoint": "sintau", + "language": "wasm", + "input_types": ["int"], + "return_type": "int", + "data": "AGFzbQEAAAABDQJgAX0BfWADfX9/AX0DBQQAAAABBQQBAUREBxgDBnNpbnRhdQAABGV4cDIAAQRsb2cyAAIKjgEEKQECfUMAAAA/IgIgACAAjpMiACACk4siAZMgAZZBAEEYEAMgAiAAk5gLGQAgACAAjiIAk0EYQSwQA7wgAKhBF3RqvgslAQF/IAC8IgFBF3ZB/wBrsiABQQl0s0MAAIBPlUEsQcQAEAOSCyIBAX0DQCADIACUIAEqAgCSIQMgAUEEaiIBIAJrDQALIAMLC0oBAEEAC0Q/x2FC2eATQUuqKsJzsqY9QAHJQH6V0DZv+V88kPJTPSJndz6sZjE/HQCAP/clMD0D/T++F6bRPkzcNL/Tgrg//IiKNwBqBG5hbWUBHwQABnNpbnRhdQEEZXhwMgIEbG9nMgMIZXZhbHBvbHkCNwQAAwABeAECeDECBGhhbGYBAQABeAICAAF4AQJ4aQMEAAF4AQVzdGFydAIDZW5kAwZyZXN1bHQDCQEDAQAEbG9vcA==" + }';"#; + + // Create two functions in two separate passes + sf_context.plan_query(create_function_stmt).await?; + sf_context.plan_query(create_function_stmt2).await?; + + // Test dropping both functions in one pass + let plan = sf_context + .plan_query(r#"DROP FUNCTION sintau, sintau2"#) + .await; + assert!(plan.is_ok()); + Ok(()) + } } diff --git a/src/nodes.rs b/src/nodes.rs index ff32abf4..c5edead5 100644 --- a/src/nodes.rs +++ b/src/nodes.rs @@ -31,6 +31,14 @@ pub struct CreateFunction { pub output_schema: DFSchemaRef, } +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub struct DropFunction { + pub if_exists: bool, + pub func_names: Vec, + /// Dummy result schema for the plan (empty) + pub output_schema: DFSchemaRef, +} + #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct RenameTable { /// Old name @@ -63,6 +71,7 @@ pub struct Vacuum { pub enum SeafowlExtensionNode { CreateTable(CreateTable), CreateFunction(CreateFunction), + DropFunction(DropFunction), RenameTable(RenameTable), DropSchema(DropSchema), Vacuum(Vacuum), @@ -100,6 +109,9 @@ impl UserDefinedLogicalNode for SeafowlExtensionNode { output_schema, .. }) => output_schema, + SeafowlExtensionNode::DropFunction(DropFunction { + output_schema, .. + }) => output_schema, SeafowlExtensionNode::RenameTable(RenameTable { output_schema, .. }) => { output_schema } @@ -125,6 +137,10 @@ impl UserDefinedLogicalNode for SeafowlExtensionNode { SeafowlExtensionNode::CreateFunction(CreateFunction { name, .. }) => { write!(f, "CreateFunction: {name}") } + SeafowlExtensionNode::DropFunction(DropFunction { func_names, .. }) => { + let names_str = func_names.join(", "); + write!(f, "DropFunction: {names_str}") + } SeafowlExtensionNode::RenameTable(RenameTable { old_name, new_name, .. }) => { diff --git a/src/repository/default.rs b/src/repository/default.rs index 881437de..d01e931d 100644 --- a/src/repository/default.rs +++ b/src/repository/default.rs @@ -439,6 +439,34 @@ impl Repository for $repo { Ok(functions) } + + async fn drop_function( + &self, + database_id: DatabaseId, + func_names: &[String], + ) -> Result<(), Error> { + let query = format!( + r#" + DELETE FROM "function" + WHERE database_id = $1 + AND name IN ({}) + RETURNING id; + "#, + func_names.iter().map(|_| "$2").collect::>().join(", ") + ); + + let mut query_builder = sqlx::query(&query).bind(database_id); + for func_name in func_names { + query_builder = query_builder.bind(func_name); + } + query_builder + .fetch_one(&self.executor) + .await + .map_err($repo::interpret_error)?; + + Ok(()) + } + // Drop table/collection/database // In these methods, return the ID back so that we get an error if the diff --git a/src/repository/interface.rs b/src/repository/interface.rs index 48025397..33f9d3fc 100644 --- a/src/repository/interface.rs +++ b/src/repository/interface.rs @@ -171,6 +171,12 @@ pub trait Repository: Send + Sync + Debug { database_id: DatabaseId, ) -> Result, Error>; + async fn drop_function( + &self, + database_id: DatabaseId, + func_names: &[String], + ) -> Result<(), Error>; + async fn drop_table(&self, table_id: TableId) -> Result<(), Error>; async fn drop_collection(&self, collection_id: CollectionId) -> Result<(), Error>;