Skip to content

Commit

Permalink
Use own nftables fork
Browse files Browse the repository at this point in the history
  • Loading branch information
hack3ric committed Nov 15, 2024
1 parent 0f39394 commit 968961d
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 65 deletions.
36 changes: 12 additions & 24 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ postcard = { version = "1.0.10", default-features = false, features = [
"use-std",
] }
replace_with = "0.1.7"
serde = { version = "1.0.210", features = ["derive"] }
serde = { version = "1.0.215", features = ["derive"] }
serde_json = "1.0.132"
smallvec = { version = "1.13.2", features = [
"union",
Expand All @@ -42,8 +42,9 @@ tokio = { version = "1.38.0", features = [
tokio-util = "0.7.11"

[target.'cfg(target_os = "linux")'.dependencies]
nftables = "0.5.0"
nftables-async = "0.1.1"
nftables = { git = "https://github.com/hack3ric/rust-nftables", branch = "for-flow", features = [
"tokio",
] }
rtnetlink = { git = "https://github.com/hack3ric/rust-rtnetlink", branch = "for-flow" }

[dev-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion src/kernel/linux/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl Kernel for Linux {
let nftables = NftablesReq {
objects: rules
.into_iter()
.map(|x| NfObject::CmdObject(NfCmd::Add(self.nft.make_new_rule(x, Some(handle)))))
.map(|x| NfObject::CmdObject(NfCmd::Add(self.nft.make_new_rule(x.into(), Some(handle)))))
.collect(),
};

Expand Down
78 changes: 41 additions & 37 deletions src/kernel/linux/nft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::net::{Afi, IpPrefix};
use crate::util::{Intersect, TruthTable};
use nftables::batch::Batch;
use nftables::expr::Expression::{Number as NUM, String as STRING};
use nftables_async::{apply_ruleset, get_current_ruleset_raw};
use nftables::helper::{apply_ruleset_async, get_current_ruleset_raw_async};
use nftables::schema::Nftables as NftablesReq;
use nftables::{expr, schema, stmt, types};
use num::Integer;
Expand Down Expand Up @@ -40,19 +40,19 @@ impl Nftables {
let mut batch = Batch::new();
batch.add(schema::NfListObject::Table(schema::Table {
family: types::NfFamily::INet,
name: table.to_string(),
name: table.clone(),
..Default::default()
}));
batch.add(schema::NfListObject::Chain(schema::Chain {
family: types::NfFamily::INet,
table: table.to_string(),
name: chain.to_string(),
table: table.clone(),
name: chain.clone(),
_type: hooked.then_some(types::NfChainType::Filter),
hook: hooked.then_some(types::NfHook::Input),
prio: hooked.then_some(priority),
..Default::default()
}));
apply_ruleset(&batch.to_nftables(), None, None).await?;
apply_ruleset_async(&batch.to_nftables(), None, None).await?;
Ok(Self { table, chain, armed: true })
}

Expand All @@ -61,52 +61,56 @@ impl Nftables {
let mut batch = Batch::new();
batch.delete(schema::NfListObject::Chain(schema::Chain {
family: types::NfFamily::INet,
table: self.table.to_string(),
name: self.chain.to_string(),
table: self.table.clone(),
name: self.chain.clone(),
..Default::default()
}));
Ok(apply_ruleset(&batch.to_nftables(), None, None).await?)
Ok(apply_ruleset_async(&batch.to_nftables(), None, None).await?)
}

fn exit_sync(&self) -> Result<()> {
let mut batch = Batch::new();
batch.delete(schema::NfListObject::Chain(schema::Chain {
family: types::NfFamily::INet,
table: self.table.to_string(),
name: self.chain.to_string(),
table: self.table.clone(),
name: self.chain.clone(),
..Default::default()
}));
Ok(nftables::helper::apply_ruleset(&batch.to_nftables(), None, None)?)
}

pub fn make_new_rule(&self, stmts: Vec<stmt::Statement>, comment: Option<impl ToString>) -> schema::NfListObject {
pub fn make_new_rule(
&self,
stmts: Cow<'static, [stmt::Statement]>,
comment: Option<impl ToString>,
) -> schema::NfListObject {
schema::NfListObject::Rule(schema::Rule {
family: types::NfFamily::INet,
table: self.table.to_string(),
chain: self.chain.to_string(),
table: self.table.clone(),
chain: self.chain.clone(),
expr: stmts,
comment: comment.map(|x| x.to_string()),
comment: comment.map(|x| x.to_string().into()),
..Default::default()
})
}

pub fn make_rule_handle(&self, handle: u32) -> schema::NfListObject {
schema::NfListObject::Rule(schema::Rule {
family: types::NfFamily::INet,
table: self.table.to_string(),
chain: self.chain.to_string(),
table: self.table.clone(),
chain: self.chain.clone(),
handle: Some(handle),
..Default::default()
})
}

pub async fn get_current_ruleset_raw(&self) -> Result<String> {
let args = vec!["-n", "-s", "list", "chain", "inet", &self.table, &self.chain];
Ok(get_current_ruleset_raw(None, Some(args)).await?)
let args = ["-n", "-s", "list", "chain", "inet", &self.table, &self.chain];
Ok(get_current_ruleset_raw_async(None, Some(&args)).await?)
}

pub async fn apply_ruleset(&self, n: &NftablesReq) -> Result<()> {
Ok(apply_ruleset(n, None, None).await?)
Ok(apply_ruleset_async(n, None, None).await?)
}
}

Expand Down Expand Up @@ -206,10 +210,10 @@ impl Component {
}
smallvec_inline![smallvec_inline![make_match(
tt.inv.then_some(stmt::Operator::NEQ).unwrap_or(stmt::Operator::EQ),
expr::Expression::BinaryOperation(expr::BinaryOperation::AND(
Box::new(make_payload_field("tcp", "flags")),
Box::new(NUM(tt.mask as u32)),
)),
expr::Expression::BinaryOperation(Box::new(expr::BinaryOperation::AND(
make_payload_field("tcp", "flags"),
NUM(tt.mask as u32),
))),
expr::Expression::Named(expr::NamedExpression::Set(
(tt.truth.iter().copied())
.map(|x| expr::SetItem::Element(NUM(x as u32)))
Expand Down Expand Up @@ -355,7 +359,7 @@ impl TrafficFilterAction {
value: NUM(dscp.into()),
})],
RedirectToIp { ip, copy: true } => smallvec_inline![stmt::Statement::Dup(stmt::Dup {
addr: STRING(ip.to_string()),
addr: STRING(ip.to_string().into()),
dev: None,
})],
RedirectToIp { ip, copy: false } => {
Expand Down Expand Up @@ -384,11 +388,11 @@ fn make_match(op: stmt::Operator, left: expr::Expression, right: expr::Expressio
stmt::Statement::Match(stmt::Match { left, right, op })
}

fn make_limit(over: bool, rate: f32, unit: impl ToString, per: impl ToString) -> stmt::Statement {
fn make_limit(over: bool, rate: f32, unit: &'static str, per: &'static str) -> stmt::Statement {
stmt::Statement::Limit(stmt::Limit {
rate: rate.round() as u32,
rate_unit: Some(unit.to_string()),
per: Some(per.to_string()),
rate_unit: Some(unit.into()),
per: Some(per.into()),
burst: None,
burst_unit: None,
inv: Some(over),
Expand All @@ -401,31 +405,31 @@ fn make_payload_raw(base: expr::PayloadBase, offset: u32, len: u32) -> expr::Exp
)))
}

fn make_payload_field(protocol: impl ToString, field: impl ToString) -> expr::Expression {
fn make_payload_field(protocol: &'static str, field: &'static str) -> expr::Expression {
expr::Expression::Named(expr::NamedExpression::Payload(expr::Payload::PayloadField(
expr::PayloadField { protocol: protocol.to_string(), field: field.to_string() },
expr::PayloadField { protocol: protocol.into(), field: field.into() },
)))
}

fn make_meta(key: expr::MetaKey) -> expr::Expression {
expr::Expression::Named(expr::NamedExpression::Meta(expr::Meta { key }))
}

fn make_exthdr(name: impl ToString, field: impl ToString, offset: u32) -> expr::Expression {
fn make_exthdr(name: &'static str, field: &'static str, offset: u32) -> expr::Expression {
expr::Expression::Named(expr::NamedExpression::Exthdr(expr::Exthdr {
name: name.to_string(),
field: field.to_string(),
name: name.into(),
field: field.into(),
offset,
}))
}

fn prefix_stmt(field: impl ToString, prefix: IpPrefix) -> Option<stmt::Statement> {
fn prefix_stmt(field: &'static str, prefix: IpPrefix) -> Option<stmt::Statement> {
(prefix.len() != 0).then(|| {
make_match(
stmt::Operator::EQ,
make_payload_field(if prefix.afi() == Afi::Ipv4 { "ip" } else { "ip6" }, field),
expr::Expression::Named(expr::NamedExpression::Prefix(expr::Prefix {
addr: Box::new(STRING(format!("{}", prefix.prefix()))),
addr: Box::new(STRING(format!("{}", prefix.prefix()).into())),
len: prefix.len().into(),
})),
)
Expand Down Expand Up @@ -510,9 +514,9 @@ fn range_stmt(left: expr::Expression, ops: &Ops<Numeric>, max: u64) -> Result<Op
let (start, end) = x.into_inner();
// HACK: Does nftables itself support 64-bit integers? We shrink it for now.
// But most of the matching expressions is smaller than 32 bits anyway.
let expr = (start == end)
.then_some(NUM(start as u32))
.unwrap_or_else(|| expr::Expression::Range(expr::Range { range: vec![NUM(start as u32), NUM(end as u32)] }));
let expr = (start == end).then_some(NUM(start as u32)).unwrap_or_else(|| {
expr::Expression::Range(Box::new(expr::Range { range: [NUM(start as u32), NUM(end as u32)] }))
});
expr::SetItem::Element(expr)
})
.collect();
Expand Down

0 comments on commit 968961d

Please sign in to comment.