mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-17 14:10:09 +02:00
wip
This commit is contained in:
@@ -0,0 +1,170 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
)
|
||||
|
||||
func (f *Firewall) AcceptInputThroughInterface(_ context.Context, intf string) error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
conn, err := nftables.New()
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating nftables connection: %w", err)
|
||||
}
|
||||
|
||||
table, inputChain, _, _ := setupFilterWithBaseChains(conn)
|
||||
|
||||
rule := &nftables.Rule{
|
||||
Table: table,
|
||||
Chain: inputChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte(intf + "\x00"),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
conn.AddRule(rule)
|
||||
|
||||
err = conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("flushing: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AcceptInputToPort accepts incoming traffic on the specified port, for both TCP and UDP
|
||||
// protocols, on the interface intf. If intf is empty or "*", the interface is not used as a filter.
|
||||
// If remove is true, the rule is removed instead of added. This is used for port forwarding, with
|
||||
// intf set to the VPN tunnel interface.
|
||||
func (f *Firewall) AcceptInputToPort(_ context.Context, intf string, port uint16, remove bool) error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
conn, err := nftables.New()
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating nftables connection: %w", err)
|
||||
}
|
||||
|
||||
table, inputChain, _, _ := setupFilterWithBaseChains(conn)
|
||||
portBytes := []byte{byte(port >> 8), byte(port)} //nolint:mnd
|
||||
const tcp, udp uint8 = 6, 17
|
||||
protocols := []uint8{tcp, udp}
|
||||
|
||||
for _, protocol := range protocols {
|
||||
const maxExprsLen = 7
|
||||
exprs := make([]expr.Any, 0, maxExprsLen)
|
||||
if intf != "" && intf != "*" {
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(intf + "\x00")},
|
||||
)
|
||||
}
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 9, Len: 1}, //nolint:mnd
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protocol}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2}, //nolint:mnd
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: portBytes},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
)
|
||||
|
||||
rule := &nftables.Rule{
|
||||
Table: table,
|
||||
Chain: inputChain,
|
||||
Exprs: exprs,
|
||||
}
|
||||
|
||||
if !remove {
|
||||
conn.AddRule(rule)
|
||||
f.rules = append(f.rules, rule)
|
||||
continue
|
||||
}
|
||||
err = f.deleteRule(conn, rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = conn.Flush()
|
||||
if err != nil {
|
||||
f.rules = f.rules[:len(f.rules)-len(protocols)]
|
||||
return fmt.Errorf("flushing: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Firewall) AcceptInputToSubnet(_ context.Context, intf string, subnet netip.Prefix) error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
conn, err := nftables.New()
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating nftables connection: %w", err)
|
||||
}
|
||||
|
||||
table, inputChain, _, _ := setupFilterWithBaseChains(conn)
|
||||
|
||||
const maxExprsLen = 5
|
||||
exprs := make([]expr.Any, 0, maxExprsLen)
|
||||
|
||||
if intf != "" && intf != "*" {
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(intf + "\x00")},
|
||||
)
|
||||
}
|
||||
|
||||
var payloadOffset uint32
|
||||
if subnet.Addr().Is4() {
|
||||
payloadOffset = 16
|
||||
} else {
|
||||
payloadOffset = 24
|
||||
}
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: payloadOffset,
|
||||
Len: uint32(len(subnet.Addr().AsSlice())), //nolint:gosec
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: subnet.Addr().AsSlice(),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
)
|
||||
|
||||
rule := &nftables.Rule{
|
||||
Table: table,
|
||||
Chain: inputChain,
|
||||
Exprs: exprs,
|
||||
}
|
||||
|
||||
conn.AddRule(rule)
|
||||
|
||||
err = conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("flushing: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user