mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-06 20:10:11 +02:00
feat(dns): restrict plain DNS output traffic
This commit is contained in:
+1
-1
@@ -394,7 +394,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
}
|
||||
|
||||
dnsLogger := logger.New(log.SetComponent("dns"))
|
||||
dnsLooper, err := dns.NewLoop(allSettings.DNS, httpClient,
|
||||
dnsLooper, err := dns.NewLoop(allSettings.DNS, httpClient, firewallConf,
|
||||
dnsLogger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating DNS loop: %w", err)
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Info(s string)
|
||||
Warn(s string)
|
||||
Error(s string)
|
||||
}
|
||||
|
||||
type Firewall interface {
|
||||
RestrictOutputAddrPort(ctx context.Context, addrPort netip.AddrPort) (err error)
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package dns
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Info(s string)
|
||||
Warn(s string)
|
||||
Error(s string)
|
||||
}
|
||||
@@ -24,6 +24,7 @@ type Loop struct {
|
||||
localResolvers []netip.Addr
|
||||
resolvConf string
|
||||
client *http.Client
|
||||
firewall Firewall
|
||||
logger Logger
|
||||
userTrigger bool
|
||||
start <-chan struct{}
|
||||
@@ -39,7 +40,7 @@ type Loop struct {
|
||||
const defaultBackoffTime = 10 * time.Second
|
||||
|
||||
func NewLoop(settings settings.DNS,
|
||||
client *http.Client, logger Logger,
|
||||
client *http.Client, firewall Firewall, logger Logger,
|
||||
) (loop *Loop, err error) {
|
||||
start := make(chan struct{})
|
||||
running := make(chan models.LoopStatus)
|
||||
@@ -64,6 +65,7 @@ func NewLoop(settings settings.DNS,
|
||||
filter: filter,
|
||||
resolvConf: "/etc/resolv.conf",
|
||||
client: client,
|
||||
firewall: firewall,
|
||||
logger: logger,
|
||||
userTrigger: true,
|
||||
start: start,
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/dns/v2/pkg/nameserver"
|
||||
)
|
||||
|
||||
func (l *Loop) useUnencryptedDNS(fallback bool) {
|
||||
func (l *Loop) useUnencryptedDNS(ctx context.Context, fallback bool) {
|
||||
settings := l.GetSettings()
|
||||
|
||||
targetIP := settings.GetFirstPlaintextIPv4()
|
||||
@@ -20,8 +21,9 @@ func (l *Loop) useUnencryptedDNS(fallback bool) {
|
||||
|
||||
const dialTimeout = 3 * time.Second
|
||||
const defaultDNSPort = 53
|
||||
addrPort := netip.AddrPortFrom(targetIP, defaultDNSPort)
|
||||
settingsInternalDNS := nameserver.SettingsInternalDNS{
|
||||
AddrPort: netip.AddrPortFrom(targetIP, defaultDNSPort),
|
||||
AddrPort: addrPort,
|
||||
Timeout: dialTimeout,
|
||||
}
|
||||
nameserver.UseDNSInternally(settingsInternalDNS)
|
||||
@@ -34,4 +36,9 @@ func (l *Loop) useUnencryptedDNS(fallback bool) {
|
||||
if err != nil {
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
|
||||
err = l.firewall.RestrictOutputAddrPort(ctx, addrPort)
|
||||
if err != nil {
|
||||
l.logger.Error("restricting plain DNS traffic to " + targetIP.String() + ": " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
+5
-5
@@ -24,7 +24,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
"and go through your container network DNS outside the VPN tunnel!")
|
||||
} else {
|
||||
const fallback = false
|
||||
l.useUnencryptedDNS(fallback)
|
||||
l.useUnencryptedDNS(ctx, fallback)
|
||||
}
|
||||
|
||||
select {
|
||||
@@ -56,7 +56,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
|
||||
if !errors.Is(err, errUpdateBlockLists) {
|
||||
const fallback = true
|
||||
l.useUnencryptedDNS(fallback)
|
||||
l.useUnencryptedDNS(ctx, fallback)
|
||||
}
|
||||
l.logAndWait(ctx, err)
|
||||
settings = l.GetSettings()
|
||||
@@ -66,7 +66,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
settings = l.GetSettings()
|
||||
if !*settings.KeepNameserver && !*settings.ServerEnabled {
|
||||
const fallback = false
|
||||
l.useUnencryptedDNS(fallback)
|
||||
l.useUnencryptedDNS(ctx, fallback)
|
||||
}
|
||||
|
||||
l.userTrigger = false
|
||||
@@ -94,7 +94,7 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
|
||||
settings := l.GetSettings()
|
||||
if !*settings.KeepNameserver && *settings.ServerEnabled {
|
||||
const fallback = false
|
||||
l.useUnencryptedDNS(fallback)
|
||||
l.useUnencryptedDNS(ctx, fallback)
|
||||
l.stopServer()
|
||||
}
|
||||
l.stopped <- struct{}{}
|
||||
@@ -105,7 +105,7 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
|
||||
case err := <-runError: // unexpected error
|
||||
l.statusManager.SetStatus(constants.Crashed)
|
||||
const fallback = true
|
||||
l.useUnencryptedDNS(fallback)
|
||||
l.useUnencryptedDNS(ctx, fallback)
|
||||
l.logAndWait(ctx, err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -39,8 +39,9 @@ func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err erro
|
||||
|
||||
// use internal DNS server
|
||||
const defaultDNSPort = 53
|
||||
addrPort := netip.AddrPortFrom(settings.ServerAddress, defaultDNSPort)
|
||||
nameserver.UseDNSInternally(nameserver.SettingsInternalDNS{
|
||||
AddrPort: netip.AddrPortFrom(settings.ServerAddress, defaultDNSPort),
|
||||
AddrPort: addrPort,
|
||||
})
|
||||
err = nameserver.UseDNSSystemWide(nameserver.SettingsSystemDNS{
|
||||
IPs: []netip.Addr{settings.ServerAddress},
|
||||
@@ -50,6 +51,11 @@ func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err erro
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
|
||||
err = l.firewall.RestrictOutputAddrPort(ctx, addrPort)
|
||||
if err != nil {
|
||||
l.logger.Error("restricting plain DNS traffic to " + addrPort.Addr().String() + ": " + err.Error())
|
||||
}
|
||||
|
||||
err = check.WaitForDNS(ctx, check.Settings{})
|
||||
if err != nil {
|
||||
l.stopServer()
|
||||
|
||||
@@ -69,8 +69,8 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
"invalid_instruction": {
|
||||
instruction: "invalid",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing iptables command: iptables command is malformed: " +
|
||||
"fields count 1 is not even: \"invalid\"",
|
||||
errMessage: "parsing iptables command: parsing \"invalid\": " +
|
||||
"iptables command is malformed: flag \"invalid\" requires a value, but got none",
|
||||
},
|
||||
"list_error": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
|
||||
@@ -29,6 +29,7 @@ type Config struct {
|
||||
outboundSubnets []netip.Prefix
|
||||
allowedInputPorts map[uint16]map[string]struct{} // port to interfaces set mapping
|
||||
portRedirections portRedirections
|
||||
outputAddrPort map[uint16]netip.Addr
|
||||
stateMutex sync.Mutex
|
||||
}
|
||||
|
||||
@@ -52,6 +53,7 @@ func NewConfig(ctx context.Context, logger Logger,
|
||||
runner: runner,
|
||||
logger: logger,
|
||||
allowedInputPorts: make(map[uint16]map[string]struct{}),
|
||||
outputAddrPort: make(map[uint16]netip.Addr),
|
||||
ipTables: iptables,
|
||||
ip6Tables: ip6tables,
|
||||
customRulesPath: "/iptables/post-rules.txt",
|
||||
|
||||
@@ -2,6 +2,7 @@ package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
|
||||
@@ -19,3 +20,15 @@ func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction st
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
func (c *Config) runIPv4AndV6IptablesInstructions(ctx context.Context,
|
||||
ipv4Instructions, ipv6Instructions []string,
|
||||
) error {
|
||||
if err := c.runIptablesInstructions(ctx, ipv4Instructions); err != nil {
|
||||
return fmt.Errorf("running iptables instructions: %w", err)
|
||||
}
|
||||
if err := c.runIP6tablesInstructions(ctx, ipv6Instructions); err != nil {
|
||||
return fmt.Errorf("running ip6tables instructions: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
+111
-19
@@ -9,9 +9,19 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type operation uint8
|
||||
|
||||
const (
|
||||
opNone operation = iota
|
||||
opAppend
|
||||
opDelete
|
||||
opInsert
|
||||
opReplace
|
||||
)
|
||||
|
||||
type iptablesInstruction struct {
|
||||
table string // defaults to "filter", and can be "nat" for example.
|
||||
append bool
|
||||
operation operation
|
||||
chain string // for example INPUT, PREROUTING. Cannot be empty.
|
||||
target string // for example ACCEPT. Can be empty.
|
||||
protocol string // "tcp" or "udp" or "" for all protocols.
|
||||
@@ -22,6 +32,7 @@ type iptablesInstruction struct {
|
||||
destinationPort uint16 // if zero, there is no destination port
|
||||
toPorts []uint16 // if empty, there is no redirection
|
||||
ctstate []string // if empty, there is no ctstate
|
||||
lineNumber uint16 // for replace operation, the line number to replace
|
||||
}
|
||||
|
||||
func (i *iptablesInstruction) setDefaults() {
|
||||
@@ -60,6 +71,58 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
|
||||
}
|
||||
}
|
||||
|
||||
func (i *iptablesInstruction) String() string {
|
||||
var sb strings.Builder
|
||||
if i.table != "" && i.table != "filter" {
|
||||
sb.WriteString(fmt.Sprintf("-t %s ", i.table))
|
||||
}
|
||||
switch i.operation {
|
||||
case opNone:
|
||||
panic("no operation specified")
|
||||
case opAppend:
|
||||
sb.WriteString(fmt.Sprintf("--append %s ", i.chain))
|
||||
case opDelete:
|
||||
sb.WriteString(fmt.Sprintf("--delete %s ", i.chain))
|
||||
case opInsert:
|
||||
sb.WriteString(fmt.Sprintf("--insert %s ", i.chain))
|
||||
case opReplace:
|
||||
sb.WriteString(fmt.Sprintf("--replace %s %d ", i.chain, i.lineNumber))
|
||||
}
|
||||
if i.inputInterface != "" {
|
||||
sb.WriteString(fmt.Sprintf("-i %s ", i.inputInterface))
|
||||
}
|
||||
if i.outputInterface != "" {
|
||||
sb.WriteString(fmt.Sprintf("-o %s ", i.outputInterface))
|
||||
}
|
||||
if i.protocol != "" {
|
||||
sb.WriteString(fmt.Sprintf("-p %s ", i.protocol))
|
||||
}
|
||||
if i.source.IsValid() {
|
||||
sb.WriteString(fmt.Sprintf("-s %s ", i.source.String()))
|
||||
}
|
||||
if i.destination.IsValid() {
|
||||
sb.WriteString(fmt.Sprintf("-d %s ", i.destination.String()))
|
||||
}
|
||||
if i.destinationPort != 0 {
|
||||
sb.WriteString(fmt.Sprintf("--dport %d ", i.destinationPort))
|
||||
}
|
||||
if len(i.ctstate) > 0 {
|
||||
sb.WriteString(fmt.Sprintf("--ctstate %s ", strings.Join(i.ctstate, ",")))
|
||||
}
|
||||
if len(i.toPorts) > 0 {
|
||||
var portStrings []string
|
||||
for _, port := range i.toPorts {
|
||||
portStrings = append(portStrings, strconv.FormatUint(uint64(port), 10))
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("--to-ports %s ", strings.Join(portStrings, ",")))
|
||||
}
|
||||
if i.target != "" {
|
||||
sb.WriteString(fmt.Sprintf("-j %s ", i.target))
|
||||
}
|
||||
|
||||
return strings.TrimSpace(sb.String())
|
||||
}
|
||||
|
||||
// instruction can be "" which equivalent to the "*" chain rule interface.
|
||||
func networkInterfacesEqual(instruction, chainRule string) bool {
|
||||
return instruction == chainRule || (instruction == "" && chainRule == "*")
|
||||
@@ -77,34 +140,63 @@ func parseIptablesInstruction(s string) (instruction iptablesInstruction, err er
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
||||
}
|
||||
fields := strings.Fields(s)
|
||||
if len(fields)%2 != 0 {
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q",
|
||||
ErrIptablesCommandMalformed, len(fields), s)
|
||||
}
|
||||
|
||||
for i := 0; i < len(fields); i += 2 {
|
||||
key := fields[i]
|
||||
value := fields[i+1]
|
||||
err = parseInstructionFlag(key, value, &instruction)
|
||||
i := 0
|
||||
for i < len(fields) {
|
||||
consumed, err := parseInstructionFlag(fields[i:], &instruction)
|
||||
if err != nil {
|
||||
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
|
||||
}
|
||||
i += consumed
|
||||
}
|
||||
|
||||
instruction.setDefaults()
|
||||
return instruction, nil
|
||||
}
|
||||
|
||||
func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) {
|
||||
switch key {
|
||||
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
|
||||
flag := fields[0]
|
||||
|
||||
// All flags use one value after the flag, except the following:
|
||||
switch flag {
|
||||
case "-R", "--replace":
|
||||
const expected = 3
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
||||
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
|
||||
}
|
||||
consumed = expected
|
||||
default:
|
||||
const expected = 2
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
|
||||
ErrIptablesCommandMalformed, flag)
|
||||
}
|
||||
consumed = expected
|
||||
}
|
||||
value := fields[1]
|
||||
|
||||
switch flag {
|
||||
case "-t", "--table":
|
||||
instruction.table = value
|
||||
case "-D", "--delete":
|
||||
instruction.append = false
|
||||
instruction.operation = opDelete
|
||||
instruction.chain = value
|
||||
case "-A", "--append":
|
||||
instruction.append = true
|
||||
instruction.operation = opAppend
|
||||
instruction.chain = value
|
||||
case "-I", "--insert":
|
||||
instruction.operation = opInsert
|
||||
instruction.chain = value
|
||||
case "-R", "--replace":
|
||||
instruction.operation = opReplace
|
||||
instruction.chain = value
|
||||
const base, bits = 10, 16
|
||||
n, err := strconv.ParseUint(fields[2], base, bits)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing line number for --replace operation: %w", err)
|
||||
}
|
||||
instruction.lineNumber = uint16(n)
|
||||
case "-j", "--jump":
|
||||
instruction.target = value
|
||||
case "-p", "--protocol":
|
||||
@@ -117,18 +209,18 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (
|
||||
case "-s", "--source":
|
||||
instruction.source, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
}
|
||||
case "-d", "--destination":
|
||||
instruction.destination, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
}
|
||||
case "--dport":
|
||||
const base, bitLength = 10, 16
|
||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination port: %w", err)
|
||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
instruction.destinationPort = uint16(destinationPort)
|
||||
case "--ctstate":
|
||||
@@ -140,14 +232,14 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (
|
||||
const base, bitLength = 10, 16
|
||||
port, err := strconv.ParseUint(portString, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing port redirection: %w", err)
|
||||
return 0, fmt.Errorf("parsing port redirection: %w", err)
|
||||
}
|
||||
instruction.toPorts[i] = uint16(port)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key)
|
||||
return 0, fmt.Errorf("%w: unknown flag %q", ErrIptablesCommandMalformed, flag)
|
||||
}
|
||||
return nil
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
||||
|
||||
@@ -23,19 +23,19 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
"uneven_fields": {
|
||||
s: "-A",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"",
|
||||
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
|
||||
},
|
||||
"unknown_key": {
|
||||
s: "-x something",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
|
||||
errMessage: "parsing \"-x something\": iptables command is malformed: unknown flag \"-x\"",
|
||||
},
|
||||
"one_pair": {
|
||||
s: "-A INPUT",
|
||||
s: "-I INPUT",
|
||||
instruction: iptablesInstruction{
|
||||
table: "filter",
|
||||
chain: "INPUT",
|
||||
append: true,
|
||||
operation: opInsert,
|
||||
},
|
||||
},
|
||||
"instruction_A": {
|
||||
@@ -43,7 +43,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
instruction: iptablesInstruction{
|
||||
table: "filter",
|
||||
chain: "INPUT",
|
||||
append: true,
|
||||
operation: opAppend,
|
||||
inputInterface: "tun0",
|
||||
protocol: "tcp",
|
||||
source: netip.MustParsePrefix("1.2.3.4/32"),
|
||||
@@ -57,7 +57,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
instruction: iptablesInstruction{
|
||||
table: "nat",
|
||||
chain: "PREROUTING",
|
||||
append: false,
|
||||
operation: opDelete,
|
||||
inputInterface: "tun0",
|
||||
protocol: "tcp",
|
||||
destinationPort: 43716,
|
||||
|
||||
@@ -3,6 +3,7 @@ package firewall
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
@@ -81,3 +82,133 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestrictOutputAddrPort allows outgoing traffic to a specific IP and port for both tcp and udp,
|
||||
// while blocking other tcp or udp traffic to that port going to other IP addresses, both IPv4 and IPv6.
|
||||
// If the port was previously allowed for another IP address, that previous allowance will be removed.
|
||||
// Giving an invalid address will remove any existing restrictions for the port specified.
|
||||
func (c *Config) RestrictOutputAddrPort(ctx context.Context, addrPort netip.AddrPort) (err error) {
|
||||
c.stateMutex.Lock()
|
||||
defer c.stateMutex.Unlock()
|
||||
existingIP := c.outputAddrPort[addrPort.Port()]
|
||||
|
||||
switch {
|
||||
case existingIP == addrPort.Addr():
|
||||
return nil
|
||||
case !addrPort.Addr().IsValid():
|
||||
// invalid address, remove any existing rules for the port
|
||||
return c.removeOutputAddrPortRestriction(ctx, existingIP, addrPort.Port())
|
||||
case !existingIP.IsValid():
|
||||
// no previous existing address for the port
|
||||
return c.insertOutputAddrPortRestriction(ctx, addrPort)
|
||||
default:
|
||||
// existing rule in the same IP family or different family
|
||||
return c.replaceOutputAddrPortRestriction(ctx, existingIP, addrPort)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) removeOutputAddrPortRestriction(ctx context.Context, existingIP netip.Addr, port uint16) (err error) {
|
||||
commonInstructions := []string{
|
||||
fmt.Sprintf("--delete OUTPUT -p udp --dport %d -j DROP", port),
|
||||
fmt.Sprintf("--delete OUTPUT -p tcp --dport %d -j DROP", port),
|
||||
}
|
||||
ipv4Instructions := commonInstructions
|
||||
ipv6Instructions := commonInstructions
|
||||
|
||||
familySpecificInstructions := []string{
|
||||
fmt.Sprintf("--delete OUTPUT -p udp --dport %d -d %s -j ACCEPT", port, existingIP),
|
||||
fmt.Sprintf("--delete OUTPUT -p tcp --dport %d -d %s -j ACCEPT", port, existingIP),
|
||||
}
|
||||
if existingIP.Is4() {
|
||||
ipv4Instructions = append(ipv4Instructions, familySpecificInstructions...)
|
||||
} else {
|
||||
ipv6Instructions = append(ipv6Instructions, familySpecificInstructions...)
|
||||
}
|
||||
|
||||
err = c.runIPv4AndV6IptablesInstructions(ctx, ipv4Instructions, ipv6Instructions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
delete(c.outputAddrPort, port)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) insertOutputAddrPortRestriction(ctx context.Context, addrPort netip.AddrPort) (err error) {
|
||||
commonInstructions := []string{
|
||||
fmt.Sprintf("--insert OUTPUT -p udp --dport %d -j DROP", addrPort.Port()),
|
||||
fmt.Sprintf("--insert OUTPUT -p tcp --dport %d -j DROP", addrPort.Port()),
|
||||
}
|
||||
ipv4Instructions := commonInstructions
|
||||
ipv6Instructions := commonInstructions
|
||||
|
||||
familySpecificInstructions := []string{
|
||||
fmt.Sprintf("--insert OUTPUT -p udp --dport %d -d %s -j ACCEPT", addrPort.Port(), addrPort.Addr()),
|
||||
fmt.Sprintf("--insert OUTPUT -p tcp --dport %d -d %s -j ACCEPT", addrPort.Port(), addrPort.Addr()),
|
||||
}
|
||||
if addrPort.Addr().Is4() {
|
||||
ipv4Instructions = append(ipv4Instructions, familySpecificInstructions...)
|
||||
} else {
|
||||
ipv6Instructions = append(ipv6Instructions, familySpecificInstructions...)
|
||||
}
|
||||
err = c.runIPv4AndV6IptablesInstructions(ctx, ipv4Instructions, ipv6Instructions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.outputAddrPort[addrPort.Port()] = addrPort.Addr()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) replaceOutputAddrPortRestriction(ctx context.Context,
|
||||
existingIP netip.Addr, addrPort netip.AddrPort,
|
||||
) (err error) {
|
||||
for _, protocol := range [...]string{"udp", "tcp"} {
|
||||
switch {
|
||||
case existingIP.Is4() && addrPort.Addr().Is4():
|
||||
oldInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), existingIP)
|
||||
newInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), addrPort.Addr())
|
||||
err = c.replaceIptablesRule(ctx, oldInstruction, newInstruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("replacing existing IPv4 rule: %w", err)
|
||||
}
|
||||
case existingIP.Is6() && addrPort.Addr().Is6():
|
||||
oldInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), existingIP)
|
||||
newInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), addrPort.Addr())
|
||||
err = c.replaceIP6tablesRule(ctx, oldInstruction, newInstruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("replacing existing IPv6 rule: %w", err)
|
||||
}
|
||||
case existingIP.Is4() && addrPort.Addr().Is6():
|
||||
instruction := fmt.Sprintf("--delete OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), existingIP)
|
||||
err = c.runIptablesInstruction(ctx, instruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing existing IPv4 rule: %w", err)
|
||||
}
|
||||
instruction = fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), addrPort.Addr())
|
||||
err = c.runIP6tablesInstruction(ctx, instruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("inserting new IPv6 rule: %w", err)
|
||||
}
|
||||
case existingIP.Is6() && addrPort.Addr().Is4():
|
||||
instruction := fmt.Sprintf("--delete OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), existingIP)
|
||||
err = c.runIP6tablesInstruction(ctx, instruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing existing IPv6 rule: %w", err)
|
||||
}
|
||||
instruction = fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), addrPort.Addr())
|
||||
err = c.runIptablesInstruction(ctx, instruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("inserting new IPv4 rule: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
c.outputAddrPort[addrPort.Port()] = addrPort.Addr()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var errRuleNotFound = errors.New("rule not found")
|
||||
|
||||
func (c *Config) replaceIptablesRule(ctx context.Context, oldInstruction, newInstruction string) error {
|
||||
targetRule, err := parseIptablesInstruction(oldInstruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing iptables command to replace: %w", err)
|
||||
}
|
||||
|
||||
lineNumber, err := findLineNumber(ctx, c.ipTables, targetRule, c.runner, c.logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("finding to-be-replaced chain rule line number: %w", err)
|
||||
} else if lineNumber == 0 {
|
||||
return fmt.Errorf("%w: matching to-be-replaced instruction %q", errRuleNotFound, oldInstruction)
|
||||
}
|
||||
parsed, err := parseIptablesInstruction(newInstruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing replacement iptables command: %w", err)
|
||||
}
|
||||
parsed.operation = opReplace
|
||||
parsed.lineNumber = lineNumber
|
||||
return c.runIptablesInstruction(ctx, parsed.String())
|
||||
}
|
||||
|
||||
func (c *Config) replaceIP6tablesRule(ctx context.Context, oldInstruction, newInstruction string) error {
|
||||
targetRule, err := parseIptablesInstruction(oldInstruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing iptables command to replace: %w", err)
|
||||
}
|
||||
|
||||
lineNumber, err := findLineNumber(ctx, c.ip6Tables, targetRule, c.runner, c.logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("finding to-be-replaced chain rule line number: %w", err)
|
||||
} else if lineNumber == 0 {
|
||||
return fmt.Errorf("%w: matching to-be-replaced instruction %q", errRuleNotFound, oldInstruction)
|
||||
}
|
||||
parsed, err := parseIptablesInstruction(newInstruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing replacement iptables command: %w", err)
|
||||
}
|
||||
parsed.operation = opReplace
|
||||
parsed.lineNumber = lineNumber
|
||||
return c.runIP6tablesInstruction(ctx, parsed.String())
|
||||
}
|
||||
Reference in New Issue
Block a user