mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-07 04:20:12 +02:00
51 lines
1.1 KiB
Go
51 lines
1.1 KiB
Go
package nftables
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/google/nftables"
|
|
)
|
|
|
|
// ErrPolicyUnknown is the sentinel error wrapped and returned by
// SetBaseChainsPolicy when the requested policy is neither "accept"
// nor "drop" (compared case-insensitively).
var ErrPolicyUnknown = errors.New("unknown policy")
|
|
|
|
// SetBaseChainsPolicy sets the policy of all the base chains (INPUT, FORWARD, or OUTPUT)
|
|
// for the filter table to the given policy (accept or drop).
|
|
func (f *Firewall) SetBaseChainsPolicy(_ context.Context, policy string) error {
|
|
f.mutex.Lock()
|
|
defer f.mutex.Unlock()
|
|
|
|
var chainPolicy nftables.ChainPolicy
|
|
switch strings.ToLower(policy) {
|
|
case "accept":
|
|
chainPolicy = nftables.ChainPolicyAccept
|
|
case "drop":
|
|
chainPolicy = nftables.ChainPolicyDrop
|
|
default:
|
|
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
|
|
}
|
|
|
|
conn, err := nftables.New()
|
|
if err != nil {
|
|
return fmt.Errorf("creating nftables connection: %w", err)
|
|
}
|
|
|
|
_, inputChain, forwardChain, outputChain := setupFilterWithBaseChains(conn)
|
|
inputChain.Policy = &chainPolicy
|
|
forwardChain.Policy = &chainPolicy
|
|
outputChain.Policy = &chainPolicy
|
|
|
|
conn.AddChain(inputChain)
|
|
conn.AddChain(forwardChain)
|
|
conn.AddChain(outputChain)
|
|
|
|
err = conn.Flush()
|
|
if err != nil {
|
|
return fmt.Errorf("flushing nftables changes: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|