feat(dns): restrict plain DNS output traffic

This commit is contained in:
Quentin McGaw
2026-02-10 16:19:08 +00:00
parent b0a75673bd
commit db947c17a8
14 changed files with 360 additions and 47 deletions
+131
View File
@@ -3,6 +3,7 @@ package firewall
import (
"context"
"fmt"
"net/netip"
"strconv"
)
@@ -81,3 +82,133 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
return nil
}
// RestrictOutputAddrPort allows outgoing traffic to a specific IP and port for both tcp and udp,
// while blocking other tcp or udp traffic to that port going to other IP addresses, both IPv4 and IPv6.
// If the port was previously allowed for another IP address, that previous allowance will be removed.
// Giving an invalid address will remove any existing restrictions for the port specified.
func (c *Config) RestrictOutputAddrPort(ctx context.Context, addrPort netip.AddrPort) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
existingIP := c.outputAddrPort[addrPort.Port()]
switch {
case existingIP == addrPort.Addr():
return nil
case !addrPort.Addr().IsValid():
// invalid address, remove any existing rules for the port
return c.removeOutputAddrPortRestriction(ctx, existingIP, addrPort.Port())
case !existingIP.IsValid():
// no previous existing address for the port
return c.insertOutputAddrPortRestriction(ctx, addrPort)
default:
// existing rule in the same IP family or different family
return c.replaceOutputAddrPortRestriction(ctx, existingIP, addrPort)
}
}
func (c *Config) removeOutputAddrPortRestriction(ctx context.Context, existingIP netip.Addr, port uint16) (err error) {
commonInstructions := []string{
fmt.Sprintf("--delete OUTPUT -p udp --dport %d -j DROP", port),
fmt.Sprintf("--delete OUTPUT -p tcp --dport %d -j DROP", port),
}
ipv4Instructions := commonInstructions
ipv6Instructions := commonInstructions
familySpecificInstructions := []string{
fmt.Sprintf("--delete OUTPUT -p udp --dport %d -d %s -j ACCEPT", port, existingIP),
fmt.Sprintf("--delete OUTPUT -p tcp --dport %d -d %s -j ACCEPT", port, existingIP),
}
if existingIP.Is4() {
ipv4Instructions = append(ipv4Instructions, familySpecificInstructions...)
} else {
ipv6Instructions = append(ipv6Instructions, familySpecificInstructions...)
}
err = c.runIPv4AndV6IptablesInstructions(ctx, ipv4Instructions, ipv6Instructions)
if err != nil {
return err
}
delete(c.outputAddrPort, port)
return nil
}
func (c *Config) insertOutputAddrPortRestriction(ctx context.Context, addrPort netip.AddrPort) (err error) {
commonInstructions := []string{
fmt.Sprintf("--insert OUTPUT -p udp --dport %d -j DROP", addrPort.Port()),
fmt.Sprintf("--insert OUTPUT -p tcp --dport %d -j DROP", addrPort.Port()),
}
ipv4Instructions := commonInstructions
ipv6Instructions := commonInstructions
familySpecificInstructions := []string{
fmt.Sprintf("--insert OUTPUT -p udp --dport %d -d %s -j ACCEPT", addrPort.Port(), addrPort.Addr()),
fmt.Sprintf("--insert OUTPUT -p tcp --dport %d -d %s -j ACCEPT", addrPort.Port(), addrPort.Addr()),
}
if addrPort.Addr().Is4() {
ipv4Instructions = append(ipv4Instructions, familySpecificInstructions...)
} else {
ipv6Instructions = append(ipv6Instructions, familySpecificInstructions...)
}
err = c.runIPv4AndV6IptablesInstructions(ctx, ipv4Instructions, ipv6Instructions)
if err != nil {
return err
}
c.outputAddrPort[addrPort.Port()] = addrPort.Addr()
return nil
}
func (c *Config) replaceOutputAddrPortRestriction(ctx context.Context,
existingIP netip.Addr, addrPort netip.AddrPort,
) (err error) {
for _, protocol := range [...]string{"udp", "tcp"} {
switch {
case existingIP.Is4() && addrPort.Addr().Is4():
oldInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), existingIP)
newInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), addrPort.Addr())
err = c.replaceIptablesRule(ctx, oldInstruction, newInstruction)
if err != nil {
return fmt.Errorf("replacing existing IPv4 rule: %w", err)
}
case existingIP.Is6() && addrPort.Addr().Is6():
oldInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), existingIP)
newInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), addrPort.Addr())
err = c.replaceIP6tablesRule(ctx, oldInstruction, newInstruction)
if err != nil {
return fmt.Errorf("replacing existing IPv6 rule: %w", err)
}
case existingIP.Is4() && addrPort.Addr().Is6():
instruction := fmt.Sprintf("--delete OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), existingIP)
err = c.runIptablesInstruction(ctx, instruction)
if err != nil {
return fmt.Errorf("removing existing IPv4 rule: %w", err)
}
instruction = fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), addrPort.Addr())
err = c.runIP6tablesInstruction(ctx, instruction)
if err != nil {
return fmt.Errorf("inserting new IPv6 rule: %w", err)
}
case existingIP.Is6() && addrPort.Addr().Is4():
instruction := fmt.Sprintf("--delete OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), existingIP)
err = c.runIP6tablesInstruction(ctx, instruction)
if err != nil {
return fmt.Errorf("removing existing IPv6 rule: %w", err)
}
instruction = fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), addrPort.Addr())
err = c.runIptablesInstruction(ctx, instruction)
if err != nil {
return fmt.Errorf("inserting new IPv4 rule: %w", err)
}
}
}
c.outputAddrPort[addrPort.Port()] = addrPort.Addr()
return nil
}