diff --git a/src/lib.rs b/src/lib.rs index c079d2b..26c1900 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,7 +34,7 @@ //! use std::collections::{HashMap, HashSet}; -use deadpool::managed::Object; +use deadpool::managed::{Object, PoolError}; use filter::{AndFilter, EqFilter, Filter}; use ldap3::{ log::{debug, error}, @@ -1155,7 +1155,7 @@ impl LdapClient { "objectClass".to_string(), "groupOfNames".to_string(), )); - + let user_filter = Box::new(EqFilter::from("member".to_string(), user_dn.to_string())); let mut filter = AndFilter::default(); filter.add(group_filter); @@ -1163,7 +1163,12 @@ impl LdapClient { let search = self .ldap - .search(group_ou, Scope::Subtree, filter.filter().as_str(), vec!["cn"]) + .search( + group_ou, + Scope::Subtree, + filter.filter().as_str(), + vec!["cn"], + ) .await; if let Err(error) = search { @@ -1188,11 +1193,19 @@ impl LdapClient { ))); } - let record = records.iter() - .map(|record| SearchEntry::construct(record.to_owned())).map(|se| se.attrs) - .flat_map(|att| att.get("cn").unwrap().iter().map(|x| x.to_owned()).collect::>()) + let record = records + .iter() + .map(|record| SearchEntry::construct(record.to_owned())) + .map(|se| se.attrs) + .flat_map(|att| { + att.get("cn") + .unwrap() + .iter() + .map(|x| x.to_owned()) + .collect::>() + }) .collect::>(); - + Ok(record) } } @@ -1202,7 +1215,6 @@ impl LdapClient { /// #[derive(Debug)] pub enum Error { - /// /// Error occured when performing a LDAP query Query(String, LdapError), /// No records found for the search criteria @@ -1219,6 +1231,10 @@ pub enum Error { Delete(String, LdapError), /// Error occured when mapping the search result to a struct Mapping(String), + /// Error occurred while attempting to create a LDAP connection + Connection(String, LdapError), + /// Error occurred while using the connection pool + Pool(PoolError), } #[cfg(test)] @@ -1313,6 +1329,7 @@ mod tests { let result = pool .get_connection() .await + .unwrap() .create( "bd9b91ec-7a69-4166-bf67-cc7e553b2fd9", "ou=people,dc=example,dc=com", @@ -1332,7 +1349,7 @@ mod tests { }; let pool = pool::build_connection_pool(&ldap_config).await; - let mut ldap = pool.get_connection().await; + let mut ldap = pool.get_connection().await.unwrap(); let name_filter = EqFilter::from("cn".to_string(), "Sam".to_string()); let user = ldap .search::( @@ -1358,7 +1375,7 @@ mod tests { }; let pool = pool::build_connection_pool(&ldap_config).await; - let mut ldap = pool.get_connection().await; + let mut ldap = pool.get_connection().await.unwrap(); let name_filter = EqFilter::from("cn".to_string(), "SamX".to_string()); let user = ldap .search::( @@ -1386,7 +1403,7 @@ mod tests { }; let pool = pool::build_connection_pool(&ldap_config).await; - let mut ldap = pool.get_connection().await; + let mut ldap = pool.get_connection().await.unwrap(); let name_filter = EqFilter::from("cn".to_string(), "James".to_string()); let user = ldap .search::( @@ -1414,7 +1431,7 @@ mod tests { }; let pool = pool::build_connection_pool(&ldap_config).await; - let mut ldap = pool.get_connection().await; + let mut ldap = pool.get_connection().await.unwrap(); let data = vec![ Mod::Replace("cn", HashSet::from(["Jhon_Update"])), Mod::Replace("sn", HashSet::from(["Eliet_Update"])), @@ -1440,7 +1457,7 @@ mod tests { }; let pool = pool::build_connection_pool(&ldap_config).await; - let mut ldap = pool.get_connection().await; + let mut ldap = pool.get_connection().await.unwrap(); let data = vec![ Mod::Replace("cn", HashSet::from(["Jhon_Update"])), Mod::Replace("sn", HashSet::from(["Eliet_Update"])), @@ -1471,7 +1488,7 @@ mod tests { }; let pool = pool::build_connection_pool(&ldap_config).await; - let mut ldap = pool.get_connection().await; + let mut ldap = pool.get_connection().await.unwrap(); let data = vec![ Mod::Replace("cn", HashSet::from(["David_Update"])), Mod::Replace("sn", HashSet::from(["Hanks_Update"])), @@ -1487,7 +1504,7 @@ mod tests { assert!(result.is_ok()); - let mut ldap = pool.get_connection().await; + let mut ldap = pool.get_connection().await.unwrap(); let name_filter = EqFilter::from( "uid".to_string(), "6da70e51-7897-411f-9290-649ebfcb3269".to_string(), @@ -1516,7 +1533,7 @@ mod tests { }; let pool = pool::build_connection_pool(&ldap_config).await; - let mut ldap = pool.get_connection().await; + let mut ldap = pool.get_connection().await.unwrap(); let name_filter = EqFilter::from("cn".to_string(), "James".to_string()); let result = ldap @@ -1543,7 +1560,7 @@ mod tests { }; let pool = pool::build_connection_pool(&ldap_config).await; - let mut ldap = pool.get_connection().await; + let mut ldap = pool.get_connection().await.unwrap(); let name_filter = EqFilter::from("cn".to_string(), "JamesX".to_string()); let result = ldap @@ -1570,7 +1587,7 @@ mod tests { }; let pool = pool::build_connection_pool(&ldap_config).await; - let mut ldap = pool.get_connection().await; + let mut ldap = pool.get_connection().await.unwrap(); let result = ldap .delete( @@ -1591,7 +1608,7 @@ mod tests { }; let pool = pool::build_connection_pool(&ldap_config).await; - let mut ldap = pool.get_connection().await; + let mut ldap = pool.get_connection().await.unwrap(); let result = ldap .delete( @@ -1621,6 +1638,7 @@ mod tests { let result = pool .get_connection() .await + .unwrap() .create_group("test_group", "dc=example,dc=com", "Some Description") .await; @@ -1641,12 +1659,14 @@ mod tests { let _result = pool .get_connection() .await + .unwrap() .create_group("test_group_1", "dc=example,dc=com", "Some Decription") .await; let result = pool .get_connection() .await + .unwrap() .add_users_to_group( vec![ "uid=f92f4cb2-e821-44a4-bb13-b8ebadf4ecc5,ou=people,dc=example,dc=com", @@ -1674,12 +1694,14 @@ mod tests { let _result = pool .get_connection() .await + .unwrap() .create_group("test_group_3", "dc=example,dc=com", "Some Decription 2") .await; let _result = pool .get_connection() .await + .unwrap() .add_users_to_group( vec![ "uid=f92f4cb2-e821-44a4-bb13-b8ebadf4ecc5,ou=people,dc=example,dc=com", @@ -1692,6 +1714,7 @@ mod tests { let result = pool .get_connection() .await + .unwrap() .get_members::( "cn=test_group_3,dc=example,dc=com", "dc=example,dc=com", @@ -1727,12 +1750,14 @@ mod tests { let _result = pool .get_connection() .await + .unwrap() .create_group("test_group_2", "dc=example,dc=com", "Some Decription 2") .await; let _result = pool .get_connection() .await + .unwrap() .add_users_to_group( vec![ "uid=f92f4cb2-e821-44a4-bb13-b8ebadf4ecc5,ou=people,dc=example,dc=com", @@ -1745,6 +1770,7 @@ mod tests { let result = pool .get_connection() .await + .unwrap() .remove_users_from_group( "cn=test_group_2,dc=example,dc=com", vec![ @@ -1770,6 +1796,7 @@ mod tests { let result = pool .get_connection() .await + .unwrap() .get_associtated_groups( "ou=group,dc=example,dc=com", "uid=e219fbc0-6df5-4bc3-a6ee-986843bb157e,ou=people,dc=example,dc=com", @@ -1777,7 +1804,7 @@ mod tests { .await; assert!(result.is_ok()); - assert_eq!(result.unwrap().len(),2); + assert_eq!(result.unwrap().len(), 2); } #[derive(Deserialize)] diff --git a/src/pool.rs b/src/pool.rs index 58d9f69..b941381 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -4,7 +4,7 @@ use ldap3::{Ldap, LdapConnAsync, LdapConnSettings}; use log::debug; use serde::Deserialize; -use crate::LdapClient; +use crate::{Error, LdapClient}; pub struct Manager(String, LdapConnSettings); pub type Pool = deadpool::managed::Pool; @@ -66,15 +66,16 @@ pub struct LdapPool { } impl LdapPool { - pub async fn get_connection(&self) -> LdapClient { - let mut ldap = self.pool.get().await.unwrap(); + /// Returns an existing LDAP connection from the pool or creates a new one if required. + pub async fn get_connection(&self) -> Result { + let mut ldap = self.pool.get().await.map_err(Error::Pool)?; ldap.simple_bind(self.config.bind_dn.as_str(), self.config.bind_pw.as_str()) .await - .unwrap() + .map_err(|e| Error::Connection("unable to create connection".into(), e))? .success() - .unwrap(); + .map_err(|e| Error::Connection("unable to create connection".into(), e))?; - LdapClient::from(ldap) + Ok(LdapClient::from(ldap)) } }