From db947c17a87148ace93a3a242b8528f76264ab3f Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 10 Feb 2026 16:19:08 +0000 Subject: [PATCH] feat(dns): restrict plain DNS output traffic --- cmd/gluetun/main.go | 2 +- internal/dns/interfaces.go | 17 ++++ internal/dns/logger.go | 8 -- internal/dns/loop.go | 4 +- internal/dns/plaintext.go | 11 ++- internal/dns/run.go | 10 +-- internal/dns/setup.go | 8 +- internal/firewall/delete_test.go | 4 +- internal/firewall/firewall.go | 2 + internal/firewall/iptablesmix.go | 13 +++ internal/firewall/parse.go | 130 +++++++++++++++++++++++++----- internal/firewall/parse_test.go | 16 ++-- internal/firewall/ports.go | 131 +++++++++++++++++++++++++++++++ internal/firewall/replace.go | 51 ++++++++++++ 14 files changed, 360 insertions(+), 47 deletions(-) create mode 100644 internal/dns/interfaces.go delete mode 100644 internal/dns/logger.go create mode 100644 internal/firewall/replace.go diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 206294a8..99c7884e 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -394,7 +394,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, } dnsLogger := logger.New(log.SetComponent("dns")) - dnsLooper, err := dns.NewLoop(allSettings.DNS, httpClient, + dnsLooper, err := dns.NewLoop(allSettings.DNS, httpClient, firewallConf, dnsLogger) if err != nil { return fmt.Errorf("creating DNS loop: %w", err) diff --git a/internal/dns/interfaces.go b/internal/dns/interfaces.go new file mode 100644 index 00000000..b0423330 --- /dev/null +++ b/internal/dns/interfaces.go @@ -0,0 +1,17 @@ +package dns + +import ( + "context" + "net/netip" +) + +type Logger interface { + Debug(s string) + Info(s string) + Warn(s string) + Error(s string) +} + +type Firewall interface { + RestrictOutputAddrPort(ctx context.Context, addrPort netip.AddrPort) (err error) +} diff --git a/internal/dns/logger.go b/internal/dns/logger.go deleted file mode 100644 index e661a13d..00000000 --- a/internal/dns/logger.go +++ /dev/null @@ -1,8 +0,0 @@ -package dns - -type Logger interface { - Debug(s string) - Info(s string) - Warn(s string) - Error(s string) -} diff --git a/internal/dns/loop.go b/internal/dns/loop.go index a8f3fa1a..ffe3ba98 100644 --- a/internal/dns/loop.go +++ b/internal/dns/loop.go @@ -24,6 +24,7 @@ type Loop struct { localResolvers []netip.Addr resolvConf string client *http.Client + firewall Firewall logger Logger userTrigger bool start <-chan struct{} @@ -39,7 +40,7 @@ type Loop struct { const defaultBackoffTime = 10 * time.Second func NewLoop(settings settings.DNS, - client *http.Client, logger Logger, + client *http.Client, firewall Firewall, logger Logger, ) (loop *Loop, err error) { start := make(chan struct{}) running := make(chan models.LoopStatus) @@ -64,6 +65,7 @@ func NewLoop(settings settings.DNS, filter: filter, resolvConf: "/etc/resolv.conf", client: client, + firewall: firewall, logger: logger, userTrigger: true, start: start, diff --git a/internal/dns/plaintext.go b/internal/dns/plaintext.go index 728a4b99..a8983ae6 100644 --- a/internal/dns/plaintext.go +++ b/internal/dns/plaintext.go @@ -1,13 +1,14 @@ package dns import ( + "context" "net/netip" "time" "github.com/qdm12/dns/v2/pkg/nameserver" ) -func (l *Loop) useUnencryptedDNS(fallback bool) { +func (l *Loop) useUnencryptedDNS(ctx context.Context, fallback bool) { settings := l.GetSettings() targetIP := settings.GetFirstPlaintextIPv4() @@ -20,8 +21,9 @@ func (l *Loop) useUnencryptedDNS(fallback bool) { const dialTimeout = 3 * time.Second const defaultDNSPort = 53 + addrPort := netip.AddrPortFrom(targetIP, defaultDNSPort) settingsInternalDNS := nameserver.SettingsInternalDNS{ - AddrPort: netip.AddrPortFrom(targetIP, defaultDNSPort), + AddrPort: addrPort, Timeout: dialTimeout, } nameserver.UseDNSInternally(settingsInternalDNS) @@ -34,4 +36,9 @@ func (l *Loop) useUnencryptedDNS(fallback bool) { if err != nil { l.logger.Error(err.Error()) } + + err = l.firewall.RestrictOutputAddrPort(ctx, addrPort) + if err != nil { + l.logger.Error("restricting plain DNS traffic to " + targetIP.String() + ": " + err.Error()) + } } diff --git a/internal/dns/run.go b/internal/dns/run.go index 2fec05a9..749f8013 100644 --- a/internal/dns/run.go +++ b/internal/dns/run.go @@ -24,7 +24,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { "and go through your container network DNS outside the VPN tunnel!") } else { const fallback = false - l.useUnencryptedDNS(fallback) + l.useUnencryptedDNS(ctx, fallback) } select { @@ -56,7 +56,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { if !errors.Is(err, errUpdateBlockLists) { const fallback = true - l.useUnencryptedDNS(fallback) + l.useUnencryptedDNS(ctx, fallback) } l.logAndWait(ctx, err) settings = l.GetSettings() @@ -66,7 +66,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { settings = l.GetSettings() if !*settings.KeepNameserver && !*settings.ServerEnabled { const fallback = false - l.useUnencryptedDNS(fallback) + l.useUnencryptedDNS(ctx, fallback) } l.userTrigger = false @@ -94,7 +94,7 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo settings := l.GetSettings() if !*settings.KeepNameserver && *settings.ServerEnabled { const fallback = false - l.useUnencryptedDNS(fallback) + l.useUnencryptedDNS(ctx, fallback) l.stopServer() } l.stopped <- struct{}{} @@ -105,7 +105,7 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo case err := <-runError: // unexpected error l.statusManager.SetStatus(constants.Crashed) const fallback = true - l.useUnencryptedDNS(fallback) + l.useUnencryptedDNS(ctx, fallback) l.logAndWait(ctx, err) return false } diff --git a/internal/dns/setup.go b/internal/dns/setup.go index f93004a7..bcffdf4c 100644 --- a/internal/dns/setup.go +++ b/internal/dns/setup.go @@ -39,8 +39,9 @@ func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err erro // use internal DNS server const defaultDNSPort = 53 + addrPort := netip.AddrPortFrom(settings.ServerAddress, defaultDNSPort) nameserver.UseDNSInternally(nameserver.SettingsInternalDNS{ - AddrPort: netip.AddrPortFrom(settings.ServerAddress, defaultDNSPort), + AddrPort: addrPort, }) err = nameserver.UseDNSSystemWide(nameserver.SettingsSystemDNS{ IPs: []netip.Addr{settings.ServerAddress}, @@ -50,6 +51,11 @@ func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err erro l.logger.Error(err.Error()) } + err = l.firewall.RestrictOutputAddrPort(ctx, addrPort) + if err != nil { + l.logger.Error("restricting plain DNS traffic to " + addrPort.Addr().String() + ": " + err.Error()) + } + err = check.WaitForDNS(ctx, check.Settings{}) if err != nil { l.stopServer() diff --git a/internal/firewall/delete_test.go b/internal/firewall/delete_test.go index 1f6b5ceb..a0f5a6fa 100644 --- a/internal/firewall/delete_test.go +++ b/internal/firewall/delete_test.go @@ -69,8 +69,8 @@ func Test_deleteIPTablesRule(t *testing.T) { "invalid_instruction": { instruction: "invalid", errWrapped: ErrIptablesCommandMalformed, - errMessage: "parsing iptables command: iptables command is malformed: " + - "fields count 1 is not even: \"invalid\"", + errMessage: "parsing iptables command: parsing \"invalid\": " + + "iptables command is malformed: flag \"invalid\" requires a value, but got none", }, "list_error": { instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678", diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index f1618443..8b4b66c2 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -29,6 +29,7 @@ type Config struct { outboundSubnets []netip.Prefix allowedInputPorts map[uint16]map[string]struct{} // port to interfaces set mapping portRedirections portRedirections + outputAddrPort map[uint16]netip.Addr stateMutex sync.Mutex } @@ -52,6 +53,7 @@ func NewConfig(ctx context.Context, logger Logger, runner: runner, logger: logger, allowedInputPorts: make(map[uint16]map[string]struct{}), + outputAddrPort: make(map[uint16]netip.Addr), ipTables: iptables, ip6Tables: ip6tables, customRulesPath: "/iptables/post-rules.txt", diff --git a/internal/firewall/iptablesmix.go b/internal/firewall/iptablesmix.go index 8d45c737..d713dc37 100644 --- a/internal/firewall/iptablesmix.go +++ b/internal/firewall/iptablesmix.go @@ -2,6 +2,7 @@ package firewall import ( "context" + "fmt" ) func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error { @@ -19,3 +20,15 @@ func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction st } return c.runIP6tablesInstruction(ctx, instruction) } + +func (c *Config) runIPv4AndV6IptablesInstructions(ctx context.Context, + ipv4Instructions, ipv6Instructions []string, +) error { + if err := c.runIptablesInstructions(ctx, ipv4Instructions); err != nil { + return fmt.Errorf("running iptables instructions: %w", err) + } + if err := c.runIP6tablesInstructions(ctx, ipv6Instructions); err != nil { + return fmt.Errorf("running ip6tables instructions: %w", err) + } + return nil +} diff --git a/internal/firewall/parse.go b/internal/firewall/parse.go index d2d046cb..9f21d4ad 100644 --- a/internal/firewall/parse.go +++ b/internal/firewall/parse.go @@ -9,9 +9,19 @@ import ( "strings" ) +type operation uint8 + +const ( + opNone operation = iota + opAppend + opDelete + opInsert + opReplace +) + type iptablesInstruction struct { table string // defaults to "filter", and can be "nat" for example. - append bool + operation operation chain string // for example INPUT, PREROUTING. Cannot be empty. target string // for example ACCEPT. Can be empty. protocol string // "tcp" or "udp" or "" for all protocols. @@ -22,6 +32,7 @@ type iptablesInstruction struct { destinationPort uint16 // if zero, there is no destination port toPorts []uint16 // if empty, there is no redirection ctstate []string // if empty, there is no ctstate + lineNumber uint16 // for replace operation, the line number to replace } func (i *iptablesInstruction) setDefaults() { @@ -60,6 +71,58 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) ( } } +func (i *iptablesInstruction) String() string { + var sb strings.Builder + if i.table != "" && i.table != "filter" { + sb.WriteString(fmt.Sprintf("-t %s ", i.table)) + } + switch i.operation { + case opNone: + panic("no operation specified") + case opAppend: + sb.WriteString(fmt.Sprintf("--append %s ", i.chain)) + case opDelete: + sb.WriteString(fmt.Sprintf("--delete %s ", i.chain)) + case opInsert: + sb.WriteString(fmt.Sprintf("--insert %s ", i.chain)) + case opReplace: + sb.WriteString(fmt.Sprintf("--replace %s %d ", i.chain, i.lineNumber)) + } + if i.inputInterface != "" { + sb.WriteString(fmt.Sprintf("-i %s ", i.inputInterface)) + } + if i.outputInterface != "" { + sb.WriteString(fmt.Sprintf("-o %s ", i.outputInterface)) + } + if i.protocol != "" { + sb.WriteString(fmt.Sprintf("-p %s ", i.protocol)) + } + if i.source.IsValid() { + sb.WriteString(fmt.Sprintf("-s %s ", i.source.String())) + } + if i.destination.IsValid() { + sb.WriteString(fmt.Sprintf("-d %s ", i.destination.String())) + } + if i.destinationPort != 0 { + sb.WriteString(fmt.Sprintf("--dport %d ", i.destinationPort)) + } + if len(i.ctstate) > 0 { + sb.WriteString(fmt.Sprintf("--ctstate %s ", strings.Join(i.ctstate, ","))) + } + if len(i.toPorts) > 0 { + var portStrings []string + for _, port := range i.toPorts { + portStrings = append(portStrings, strconv.FormatUint(uint64(port), 10)) + } + sb.WriteString(fmt.Sprintf("--to-ports %s ", strings.Join(portStrings, ","))) + } + if i.target != "" { + sb.WriteString(fmt.Sprintf("-j %s ", i.target)) + } + + return strings.TrimSpace(sb.String()) +} + // instruction can be "" which equivalent to the "*" chain rule interface. func networkInterfacesEqual(instruction, chainRule string) bool { return instruction == chainRule || (instruction == "" && chainRule == "*") @@ -77,34 +140,63 @@ func parseIptablesInstruction(s string) (instruction iptablesInstruction, err er return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed) } fields := strings.Fields(s) - if len(fields)%2 != 0 { - return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q", - ErrIptablesCommandMalformed, len(fields), s) - } - for i := 0; i < len(fields); i += 2 { - key := fields[i] - value := fields[i+1] - err = parseInstructionFlag(key, value, &instruction) + i := 0 + for i < len(fields) { + consumed, err := parseInstructionFlag(fields[i:], &instruction) if err != nil { return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err) } + i += consumed } instruction.setDefaults() return instruction, nil } -func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) { - switch key { +func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) { + flag := fields[0] + + // All flags use one value after the flag, except the following: + switch flag { + case "-R", "--replace": + const expected = 3 + if len(fields) < expected { + return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s", + ErrIptablesCommandMalformed, flag, strings.Join(fields, " ")) + } + consumed = expected + default: + const expected = 2 + if len(fields) < expected { + return 0, fmt.Errorf("%w: flag %q requires a value, but got none", + ErrIptablesCommandMalformed, flag) + } + consumed = expected + } + value := fields[1] + + switch flag { case "-t", "--table": instruction.table = value case "-D", "--delete": - instruction.append = false + instruction.operation = opDelete instruction.chain = value case "-A", "--append": - instruction.append = true + instruction.operation = opAppend instruction.chain = value + case "-I", "--insert": + instruction.operation = opInsert + instruction.chain = value + case "-R", "--replace": + instruction.operation = opReplace + instruction.chain = value + const base, bits = 10, 16 + n, err := strconv.ParseUint(fields[2], base, bits) + if err != nil { + return 0, fmt.Errorf("parsing line number for --replace operation: %w", err) + } + instruction.lineNumber = uint16(n) case "-j", "--jump": instruction.target = value case "-p", "--protocol": @@ -117,18 +209,18 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) ( case "-s", "--source": instruction.source, err = parseIPPrefix(value) if err != nil { - return fmt.Errorf("parsing source IP CIDR: %w", err) + return 0, fmt.Errorf("parsing source IP CIDR: %w", err) } case "-d", "--destination": instruction.destination, err = parseIPPrefix(value) if err != nil { - return fmt.Errorf("parsing destination IP CIDR: %w", err) + return 0, fmt.Errorf("parsing destination IP CIDR: %w", err) } case "--dport": const base, bitLength = 10, 16 destinationPort, err := strconv.ParseUint(value, base, bitLength) if err != nil { - return fmt.Errorf("parsing destination port: %w", err) + return 0, fmt.Errorf("parsing destination port: %w", err) } instruction.destinationPort = uint16(destinationPort) case "--ctstate": @@ -140,14 +232,14 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) ( const base, bitLength = 10, 16 port, err := strconv.ParseUint(portString, base, bitLength) if err != nil { - return fmt.Errorf("parsing port redirection: %w", err) + return 0, fmt.Errorf("parsing port redirection: %w", err) } instruction.toPorts[i] = uint16(port) } default: - return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key) + return 0, fmt.Errorf("%w: unknown flag %q", ErrIptablesCommandMalformed, flag) } - return nil + return consumed, nil } func parseIPPrefix(value string) (prefix netip.Prefix, err error) { diff --git a/internal/firewall/parse_test.go b/internal/firewall/parse_test.go index ae07bc6b..d56f98b7 100644 --- a/internal/firewall/parse_test.go +++ b/internal/firewall/parse_test.go @@ -23,19 +23,19 @@ func Test_parseIptablesInstruction(t *testing.T) { "uneven_fields": { s: "-A", errWrapped: ErrIptablesCommandMalformed, - errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"", + errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none", }, "unknown_key": { s: "-x something", errWrapped: ErrIptablesCommandMalformed, - errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"", + errMessage: "parsing \"-x something\": iptables command is malformed: unknown flag \"-x\"", }, "one_pair": { - s: "-A INPUT", + s: "-I INPUT", instruction: iptablesInstruction{ - table: "filter", - chain: "INPUT", - append: true, + table: "filter", + chain: "INPUT", + operation: opInsert, }, }, "instruction_A": { @@ -43,7 +43,7 @@ func Test_parseIptablesInstruction(t *testing.T) { instruction: iptablesInstruction{ table: "filter", chain: "INPUT", - append: true, + operation: opAppend, inputInterface: "tun0", protocol: "tcp", source: netip.MustParsePrefix("1.2.3.4/32"), @@ -57,7 +57,7 @@ func Test_parseIptablesInstruction(t *testing.T) { instruction: iptablesInstruction{ table: "nat", chain: "PREROUTING", - append: false, + operation: opDelete, inputInterface: "tun0", protocol: "tcp", destinationPort: 43716, diff --git a/internal/firewall/ports.go b/internal/firewall/ports.go index 6f5867ee..cfbd6bca 100644 --- a/internal/firewall/ports.go +++ b/internal/firewall/ports.go @@ -3,6 +3,7 @@ package firewall import ( "context" "fmt" + "net/netip" "strconv" ) @@ -81,3 +82,133 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error) return nil } + +// RestrictOutputAddrPort allows outgoing traffic to a specific IP and port for both tcp and udp, +// while blocking other tcp or udp traffic to that port going to other IP addresses, both IPv4 and IPv6. +// If the port was previously allowed for another IP address, that previous allowance will be removed. +// Giving an invalid address will remove any existing restrictions for the port specified. +func (c *Config) RestrictOutputAddrPort(ctx context.Context, addrPort netip.AddrPort) (err error) { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + existingIP := c.outputAddrPort[addrPort.Port()] + + switch { + case existingIP == addrPort.Addr(): + return nil + case !addrPort.Addr().IsValid(): + // invalid address, remove any existing rules for the port + return c.removeOutputAddrPortRestriction(ctx, existingIP, addrPort.Port()) + case !existingIP.IsValid(): + // no previous existing address for the port + return c.insertOutputAddrPortRestriction(ctx, addrPort) + default: + // existing rule in the same IP family or different family + return c.replaceOutputAddrPortRestriction(ctx, existingIP, addrPort) + } +} + +func (c *Config) removeOutputAddrPortRestriction(ctx context.Context, existingIP netip.Addr, port uint16) (err error) { + commonInstructions := []string{ + fmt.Sprintf("--delete OUTPUT -p udp --dport %d -j DROP", port), + fmt.Sprintf("--delete OUTPUT -p tcp --dport %d -j DROP", port), + } + ipv4Instructions := commonInstructions + ipv6Instructions := commonInstructions + + familySpecificInstructions := []string{ + fmt.Sprintf("--delete OUTPUT -p udp --dport %d -d %s -j ACCEPT", port, existingIP), + fmt.Sprintf("--delete OUTPUT -p tcp --dport %d -d %s -j ACCEPT", port, existingIP), + } + if existingIP.Is4() { + ipv4Instructions = append(ipv4Instructions, familySpecificInstructions...) + } else { + ipv6Instructions = append(ipv6Instructions, familySpecificInstructions...) + } + + err = c.runIPv4AndV6IptablesInstructions(ctx, ipv4Instructions, ipv6Instructions) + if err != nil { + return err + } + delete(c.outputAddrPort, port) + return nil +} + +func (c *Config) insertOutputAddrPortRestriction(ctx context.Context, addrPort netip.AddrPort) (err error) { + commonInstructions := []string{ + fmt.Sprintf("--insert OUTPUT -p udp --dport %d -j DROP", addrPort.Port()), + fmt.Sprintf("--insert OUTPUT -p tcp --dport %d -j DROP", addrPort.Port()), + } + ipv4Instructions := commonInstructions + ipv6Instructions := commonInstructions + + familySpecificInstructions := []string{ + fmt.Sprintf("--insert OUTPUT -p udp --dport %d -d %s -j ACCEPT", addrPort.Port(), addrPort.Addr()), + fmt.Sprintf("--insert OUTPUT -p tcp --dport %d -d %s -j ACCEPT", addrPort.Port(), addrPort.Addr()), + } + if addrPort.Addr().Is4() { + ipv4Instructions = append(ipv4Instructions, familySpecificInstructions...) + } else { + ipv6Instructions = append(ipv6Instructions, familySpecificInstructions...) + } + err = c.runIPv4AndV6IptablesInstructions(ctx, ipv4Instructions, ipv6Instructions) + if err != nil { + return err + } + c.outputAddrPort[addrPort.Port()] = addrPort.Addr() + return nil +} + +func (c *Config) replaceOutputAddrPortRestriction(ctx context.Context, + existingIP netip.Addr, addrPort netip.AddrPort, +) (err error) { + for _, protocol := range [...]string{"udp", "tcp"} { + switch { + case existingIP.Is4() && addrPort.Addr().Is4(): + oldInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT", + protocol, addrPort.Port(), existingIP) + newInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT", + protocol, addrPort.Port(), addrPort.Addr()) + err = c.replaceIptablesRule(ctx, oldInstruction, newInstruction) + if err != nil { + return fmt.Errorf("replacing existing IPv4 rule: %w", err) + } + case existingIP.Is6() && addrPort.Addr().Is6(): + oldInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT", + protocol, addrPort.Port(), existingIP) + newInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT", + protocol, addrPort.Port(), addrPort.Addr()) + err = c.replaceIP6tablesRule(ctx, oldInstruction, newInstruction) + if err != nil { + return fmt.Errorf("replacing existing IPv6 rule: %w", err) + } + case existingIP.Is4() && addrPort.Addr().Is6(): + instruction := fmt.Sprintf("--delete OUTPUT -p %s --dport %d -d %s -j ACCEPT", + protocol, addrPort.Port(), existingIP) + err = c.runIptablesInstruction(ctx, instruction) + if err != nil { + return fmt.Errorf("removing existing IPv4 rule: %w", err) + } + instruction = fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT", + protocol, addrPort.Port(), addrPort.Addr()) + err = c.runIP6tablesInstruction(ctx, instruction) + if err != nil { + return fmt.Errorf("inserting new IPv6 rule: %w", err) + } + case existingIP.Is6() && addrPort.Addr().Is4(): + instruction := fmt.Sprintf("--delete OUTPUT -p %s --dport %d -d %s -j ACCEPT", + protocol, addrPort.Port(), existingIP) + err = c.runIP6tablesInstruction(ctx, instruction) + if err != nil { + return fmt.Errorf("removing existing IPv6 rule: %w", err) + } + instruction = fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT", + protocol, addrPort.Port(), addrPort.Addr()) + err = c.runIptablesInstruction(ctx, instruction) + if err != nil { + return fmt.Errorf("inserting new IPv4 rule: %w", err) + } + } + } + c.outputAddrPort[addrPort.Port()] = addrPort.Addr() + return nil +} diff --git a/internal/firewall/replace.go b/internal/firewall/replace.go new file mode 100644 index 00000000..ac3577f5 --- /dev/null +++ b/internal/firewall/replace.go @@ -0,0 +1,51 @@ +package firewall + +import ( + "context" + "errors" + "fmt" +) + +var errRuleNotFound = errors.New("rule not found") + +func (c *Config) replaceIptablesRule(ctx context.Context, oldInstruction, newInstruction string) error { + targetRule, err := parseIptablesInstruction(oldInstruction) + if err != nil { + return fmt.Errorf("parsing iptables command to replace: %w", err) + } + + lineNumber, err := findLineNumber(ctx, c.ipTables, targetRule, c.runner, c.logger) + if err != nil { + return fmt.Errorf("finding to-be-replaced chain rule line number: %w", err) + } else if lineNumber == 0 { + return fmt.Errorf("%w: matching to-be-replaced instruction %q", errRuleNotFound, oldInstruction) + } + parsed, err := parseIptablesInstruction(newInstruction) + if err != nil { + return fmt.Errorf("parsing replacement iptables command: %w", err) + } + parsed.operation = opReplace + parsed.lineNumber = lineNumber + return c.runIptablesInstruction(ctx, parsed.String()) +} + +func (c *Config) replaceIP6tablesRule(ctx context.Context, oldInstruction, newInstruction string) error { + targetRule, err := parseIptablesInstruction(oldInstruction) + if err != nil { + return fmt.Errorf("parsing iptables command to replace: %w", err) + } + + lineNumber, err := findLineNumber(ctx, c.ip6Tables, targetRule, c.runner, c.logger) + if err != nil { + return fmt.Errorf("finding to-be-replaced chain rule line number: %w", err) + } else if lineNumber == 0 { + return fmt.Errorf("%w: matching to-be-replaced instruction %q", errRuleNotFound, oldInstruction) + } + parsed, err := parseIptablesInstruction(newInstruction) + if err != nil { + return fmt.Errorf("parsing replacement iptables command: %w", err) + } + parsed.operation = opReplace + parsed.lineNumber = lineNumber + return c.runIP6tablesInstruction(ctx, parsed.String()) +}