mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-06 20:10:11 +02:00
feat(dns): restrict plain DNS output traffic
This commit is contained in:
+1
-1
@@ -394,7 +394,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
}
|
}
|
||||||
|
|
||||||
dnsLogger := logger.New(log.SetComponent("dns"))
|
dnsLogger := logger.New(log.SetComponent("dns"))
|
||||||
dnsLooper, err := dns.NewLoop(allSettings.DNS, httpClient,
|
dnsLooper, err := dns.NewLoop(allSettings.DNS, httpClient, firewallConf,
|
||||||
dnsLogger)
|
dnsLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("creating DNS loop: %w", err)
|
return fmt.Errorf("creating DNS loop: %w", err)
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Logger interface {
|
||||||
|
Debug(s string)
|
||||||
|
Info(s string)
|
||||||
|
Warn(s string)
|
||||||
|
Error(s string)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Firewall interface {
|
||||||
|
RestrictOutputAddrPort(ctx context.Context, addrPort netip.AddrPort) (err error)
|
||||||
|
}
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
type Logger interface {
|
|
||||||
Debug(s string)
|
|
||||||
Info(s string)
|
|
||||||
Warn(s string)
|
|
||||||
Error(s string)
|
|
||||||
}
|
|
||||||
@@ -24,6 +24,7 @@ type Loop struct {
|
|||||||
localResolvers []netip.Addr
|
localResolvers []netip.Addr
|
||||||
resolvConf string
|
resolvConf string
|
||||||
client *http.Client
|
client *http.Client
|
||||||
|
firewall Firewall
|
||||||
logger Logger
|
logger Logger
|
||||||
userTrigger bool
|
userTrigger bool
|
||||||
start <-chan struct{}
|
start <-chan struct{}
|
||||||
@@ -39,7 +40,7 @@ type Loop struct {
|
|||||||
const defaultBackoffTime = 10 * time.Second
|
const defaultBackoffTime = 10 * time.Second
|
||||||
|
|
||||||
func NewLoop(settings settings.DNS,
|
func NewLoop(settings settings.DNS,
|
||||||
client *http.Client, logger Logger,
|
client *http.Client, firewall Firewall, logger Logger,
|
||||||
) (loop *Loop, err error) {
|
) (loop *Loop, err error) {
|
||||||
start := make(chan struct{})
|
start := make(chan struct{})
|
||||||
running := make(chan models.LoopStatus)
|
running := make(chan models.LoopStatus)
|
||||||
@@ -64,6 +65,7 @@ func NewLoop(settings settings.DNS,
|
|||||||
filter: filter,
|
filter: filter,
|
||||||
resolvConf: "/etc/resolv.conf",
|
resolvConf: "/etc/resolv.conf",
|
||||||
client: client,
|
client: client,
|
||||||
|
firewall: firewall,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
userTrigger: true,
|
userTrigger: true,
|
||||||
start: start,
|
start: start,
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/dns/v2/pkg/nameserver"
|
"github.com/qdm12/dns/v2/pkg/nameserver"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (l *Loop) useUnencryptedDNS(fallback bool) {
|
func (l *Loop) useUnencryptedDNS(ctx context.Context, fallback bool) {
|
||||||
settings := l.GetSettings()
|
settings := l.GetSettings()
|
||||||
|
|
||||||
targetIP := settings.GetFirstPlaintextIPv4()
|
targetIP := settings.GetFirstPlaintextIPv4()
|
||||||
@@ -20,8 +21,9 @@ func (l *Loop) useUnencryptedDNS(fallback bool) {
|
|||||||
|
|
||||||
const dialTimeout = 3 * time.Second
|
const dialTimeout = 3 * time.Second
|
||||||
const defaultDNSPort = 53
|
const defaultDNSPort = 53
|
||||||
|
addrPort := netip.AddrPortFrom(targetIP, defaultDNSPort)
|
||||||
settingsInternalDNS := nameserver.SettingsInternalDNS{
|
settingsInternalDNS := nameserver.SettingsInternalDNS{
|
||||||
AddrPort: netip.AddrPortFrom(targetIP, defaultDNSPort),
|
AddrPort: addrPort,
|
||||||
Timeout: dialTimeout,
|
Timeout: dialTimeout,
|
||||||
}
|
}
|
||||||
nameserver.UseDNSInternally(settingsInternalDNS)
|
nameserver.UseDNSInternally(settingsInternalDNS)
|
||||||
@@ -34,4 +36,9 @@ func (l *Loop) useUnencryptedDNS(fallback bool) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
l.logger.Error(err.Error())
|
l.logger.Error(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = l.firewall.RestrictOutputAddrPort(ctx, addrPort)
|
||||||
|
if err != nil {
|
||||||
|
l.logger.Error("restricting plain DNS traffic to " + targetIP.String() + ": " + err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+5
-5
@@ -24,7 +24,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
|||||||
"and go through your container network DNS outside the VPN tunnel!")
|
"and go through your container network DNS outside the VPN tunnel!")
|
||||||
} else {
|
} else {
|
||||||
const fallback = false
|
const fallback = false
|
||||||
l.useUnencryptedDNS(fallback)
|
l.useUnencryptedDNS(ctx, fallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -56,7 +56,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
|||||||
|
|
||||||
if !errors.Is(err, errUpdateBlockLists) {
|
if !errors.Is(err, errUpdateBlockLists) {
|
||||||
const fallback = true
|
const fallback = true
|
||||||
l.useUnencryptedDNS(fallback)
|
l.useUnencryptedDNS(ctx, fallback)
|
||||||
}
|
}
|
||||||
l.logAndWait(ctx, err)
|
l.logAndWait(ctx, err)
|
||||||
settings = l.GetSettings()
|
settings = l.GetSettings()
|
||||||
@@ -66,7 +66,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
|||||||
settings = l.GetSettings()
|
settings = l.GetSettings()
|
||||||
if !*settings.KeepNameserver && !*settings.ServerEnabled {
|
if !*settings.KeepNameserver && !*settings.ServerEnabled {
|
||||||
const fallback = false
|
const fallback = false
|
||||||
l.useUnencryptedDNS(fallback)
|
l.useUnencryptedDNS(ctx, fallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
l.userTrigger = false
|
l.userTrigger = false
|
||||||
@@ -94,7 +94,7 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
|
|||||||
settings := l.GetSettings()
|
settings := l.GetSettings()
|
||||||
if !*settings.KeepNameserver && *settings.ServerEnabled {
|
if !*settings.KeepNameserver && *settings.ServerEnabled {
|
||||||
const fallback = false
|
const fallback = false
|
||||||
l.useUnencryptedDNS(fallback)
|
l.useUnencryptedDNS(ctx, fallback)
|
||||||
l.stopServer()
|
l.stopServer()
|
||||||
}
|
}
|
||||||
l.stopped <- struct{}{}
|
l.stopped <- struct{}{}
|
||||||
@@ -105,7 +105,7 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
|
|||||||
case err := <-runError: // unexpected error
|
case err := <-runError: // unexpected error
|
||||||
l.statusManager.SetStatus(constants.Crashed)
|
l.statusManager.SetStatus(constants.Crashed)
|
||||||
const fallback = true
|
const fallback = true
|
||||||
l.useUnencryptedDNS(fallback)
|
l.useUnencryptedDNS(ctx, fallback)
|
||||||
l.logAndWait(ctx, err)
|
l.logAndWait(ctx, err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,8 +39,9 @@ func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err erro
|
|||||||
|
|
||||||
// use internal DNS server
|
// use internal DNS server
|
||||||
const defaultDNSPort = 53
|
const defaultDNSPort = 53
|
||||||
|
addrPort := netip.AddrPortFrom(settings.ServerAddress, defaultDNSPort)
|
||||||
nameserver.UseDNSInternally(nameserver.SettingsInternalDNS{
|
nameserver.UseDNSInternally(nameserver.SettingsInternalDNS{
|
||||||
AddrPort: netip.AddrPortFrom(settings.ServerAddress, defaultDNSPort),
|
AddrPort: addrPort,
|
||||||
})
|
})
|
||||||
err = nameserver.UseDNSSystemWide(nameserver.SettingsSystemDNS{
|
err = nameserver.UseDNSSystemWide(nameserver.SettingsSystemDNS{
|
||||||
IPs: []netip.Addr{settings.ServerAddress},
|
IPs: []netip.Addr{settings.ServerAddress},
|
||||||
@@ -50,6 +51,11 @@ func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err erro
|
|||||||
l.logger.Error(err.Error())
|
l.logger.Error(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = l.firewall.RestrictOutputAddrPort(ctx, addrPort)
|
||||||
|
if err != nil {
|
||||||
|
l.logger.Error("restricting plain DNS traffic to " + addrPort.Addr().String() + ": " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
err = check.WaitForDNS(ctx, check.Settings{})
|
err = check.WaitForDNS(ctx, check.Settings{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.stopServer()
|
l.stopServer()
|
||||||
|
|||||||
@@ -69,8 +69,8 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
|||||||
"invalid_instruction": {
|
"invalid_instruction": {
|
||||||
instruction: "invalid",
|
instruction: "invalid",
|
||||||
errWrapped: ErrIptablesCommandMalformed,
|
errWrapped: ErrIptablesCommandMalformed,
|
||||||
errMessage: "parsing iptables command: iptables command is malformed: " +
|
errMessage: "parsing iptables command: parsing \"invalid\": " +
|
||||||
"fields count 1 is not even: \"invalid\"",
|
"iptables command is malformed: flag \"invalid\" requires a value, but got none",
|
||||||
},
|
},
|
||||||
"list_error": {
|
"list_error": {
|
||||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ type Config struct {
|
|||||||
outboundSubnets []netip.Prefix
|
outboundSubnets []netip.Prefix
|
||||||
allowedInputPorts map[uint16]map[string]struct{} // port to interfaces set mapping
|
allowedInputPorts map[uint16]map[string]struct{} // port to interfaces set mapping
|
||||||
portRedirections portRedirections
|
portRedirections portRedirections
|
||||||
|
outputAddrPort map[uint16]netip.Addr
|
||||||
stateMutex sync.Mutex
|
stateMutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,6 +53,7 @@ func NewConfig(ctx context.Context, logger Logger,
|
|||||||
runner: runner,
|
runner: runner,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
allowedInputPorts: make(map[uint16]map[string]struct{}),
|
allowedInputPorts: make(map[uint16]map[string]struct{}),
|
||||||
|
outputAddrPort: make(map[uint16]netip.Addr),
|
||||||
ipTables: iptables,
|
ipTables: iptables,
|
||||||
ip6Tables: ip6tables,
|
ip6Tables: ip6tables,
|
||||||
customRulesPath: "/iptables/post-rules.txt",
|
customRulesPath: "/iptables/post-rules.txt",
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package firewall
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
|
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
|
||||||
@@ -19,3 +20,15 @@ func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction st
|
|||||||
}
|
}
|
||||||
return c.runIP6tablesInstruction(ctx, instruction)
|
return c.runIP6tablesInstruction(ctx, instruction)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Config) runIPv4AndV6IptablesInstructions(ctx context.Context,
|
||||||
|
ipv4Instructions, ipv6Instructions []string,
|
||||||
|
) error {
|
||||||
|
if err := c.runIptablesInstructions(ctx, ipv4Instructions); err != nil {
|
||||||
|
return fmt.Errorf("running iptables instructions: %w", err)
|
||||||
|
}
|
||||||
|
if err := c.runIP6tablesInstructions(ctx, ipv6Instructions); err != nil {
|
||||||
|
return fmt.Errorf("running ip6tables instructions: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
+111
-19
@@ -9,9 +9,19 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type operation uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
opNone operation = iota
|
||||||
|
opAppend
|
||||||
|
opDelete
|
||||||
|
opInsert
|
||||||
|
opReplace
|
||||||
|
)
|
||||||
|
|
||||||
type iptablesInstruction struct {
|
type iptablesInstruction struct {
|
||||||
table string // defaults to "filter", and can be "nat" for example.
|
table string // defaults to "filter", and can be "nat" for example.
|
||||||
append bool
|
operation operation
|
||||||
chain string // for example INPUT, PREROUTING. Cannot be empty.
|
chain string // for example INPUT, PREROUTING. Cannot be empty.
|
||||||
target string // for example ACCEPT. Can be empty.
|
target string // for example ACCEPT. Can be empty.
|
||||||
protocol string // "tcp" or "udp" or "" for all protocols.
|
protocol string // "tcp" or "udp" or "" for all protocols.
|
||||||
@@ -22,6 +32,7 @@ type iptablesInstruction struct {
|
|||||||
destinationPort uint16 // if zero, there is no destination port
|
destinationPort uint16 // if zero, there is no destination port
|
||||||
toPorts []uint16 // if empty, there is no redirection
|
toPorts []uint16 // if empty, there is no redirection
|
||||||
ctstate []string // if empty, there is no ctstate
|
ctstate []string // if empty, there is no ctstate
|
||||||
|
lineNumber uint16 // for replace operation, the line number to replace
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iptablesInstruction) setDefaults() {
|
func (i *iptablesInstruction) setDefaults() {
|
||||||
@@ -60,6 +71,58 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *iptablesInstruction) String() string {
|
||||||
|
var sb strings.Builder
|
||||||
|
if i.table != "" && i.table != "filter" {
|
||||||
|
sb.WriteString(fmt.Sprintf("-t %s ", i.table))
|
||||||
|
}
|
||||||
|
switch i.operation {
|
||||||
|
case opNone:
|
||||||
|
panic("no operation specified")
|
||||||
|
case opAppend:
|
||||||
|
sb.WriteString(fmt.Sprintf("--append %s ", i.chain))
|
||||||
|
case opDelete:
|
||||||
|
sb.WriteString(fmt.Sprintf("--delete %s ", i.chain))
|
||||||
|
case opInsert:
|
||||||
|
sb.WriteString(fmt.Sprintf("--insert %s ", i.chain))
|
||||||
|
case opReplace:
|
||||||
|
sb.WriteString(fmt.Sprintf("--replace %s %d ", i.chain, i.lineNumber))
|
||||||
|
}
|
||||||
|
if i.inputInterface != "" {
|
||||||
|
sb.WriteString(fmt.Sprintf("-i %s ", i.inputInterface))
|
||||||
|
}
|
||||||
|
if i.outputInterface != "" {
|
||||||
|
sb.WriteString(fmt.Sprintf("-o %s ", i.outputInterface))
|
||||||
|
}
|
||||||
|
if i.protocol != "" {
|
||||||
|
sb.WriteString(fmt.Sprintf("-p %s ", i.protocol))
|
||||||
|
}
|
||||||
|
if i.source.IsValid() {
|
||||||
|
sb.WriteString(fmt.Sprintf("-s %s ", i.source.String()))
|
||||||
|
}
|
||||||
|
if i.destination.IsValid() {
|
||||||
|
sb.WriteString(fmt.Sprintf("-d %s ", i.destination.String()))
|
||||||
|
}
|
||||||
|
if i.destinationPort != 0 {
|
||||||
|
sb.WriteString(fmt.Sprintf("--dport %d ", i.destinationPort))
|
||||||
|
}
|
||||||
|
if len(i.ctstate) > 0 {
|
||||||
|
sb.WriteString(fmt.Sprintf("--ctstate %s ", strings.Join(i.ctstate, ",")))
|
||||||
|
}
|
||||||
|
if len(i.toPorts) > 0 {
|
||||||
|
var portStrings []string
|
||||||
|
for _, port := range i.toPorts {
|
||||||
|
portStrings = append(portStrings, strconv.FormatUint(uint64(port), 10))
|
||||||
|
}
|
||||||
|
sb.WriteString(fmt.Sprintf("--to-ports %s ", strings.Join(portStrings, ",")))
|
||||||
|
}
|
||||||
|
if i.target != "" {
|
||||||
|
sb.WriteString(fmt.Sprintf("-j %s ", i.target))
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimSpace(sb.String())
|
||||||
|
}
|
||||||
|
|
||||||
// instruction can be "" which equivalent to the "*" chain rule interface.
|
// instruction can be "" which equivalent to the "*" chain rule interface.
|
||||||
func networkInterfacesEqual(instruction, chainRule string) bool {
|
func networkInterfacesEqual(instruction, chainRule string) bool {
|
||||||
return instruction == chainRule || (instruction == "" && chainRule == "*")
|
return instruction == chainRule || (instruction == "" && chainRule == "*")
|
||||||
@@ -77,34 +140,63 @@ func parseIptablesInstruction(s string) (instruction iptablesInstruction, err er
|
|||||||
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
||||||
}
|
}
|
||||||
fields := strings.Fields(s)
|
fields := strings.Fields(s)
|
||||||
if len(fields)%2 != 0 {
|
|
||||||
return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q",
|
|
||||||
ErrIptablesCommandMalformed, len(fields), s)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < len(fields); i += 2 {
|
i := 0
|
||||||
key := fields[i]
|
for i < len(fields) {
|
||||||
value := fields[i+1]
|
consumed, err := parseInstructionFlag(fields[i:], &instruction)
|
||||||
err = parseInstructionFlag(key, value, &instruction)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
|
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
|
||||||
}
|
}
|
||||||
|
i += consumed
|
||||||
}
|
}
|
||||||
|
|
||||||
instruction.setDefaults()
|
instruction.setDefaults()
|
||||||
return instruction, nil
|
return instruction, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) {
|
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
|
||||||
switch key {
|
flag := fields[0]
|
||||||
|
|
||||||
|
// All flags use one value after the flag, except the following:
|
||||||
|
switch flag {
|
||||||
|
case "-R", "--replace":
|
||||||
|
const expected = 3
|
||||||
|
if len(fields) < expected {
|
||||||
|
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
||||||
|
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
|
||||||
|
}
|
||||||
|
consumed = expected
|
||||||
|
default:
|
||||||
|
const expected = 2
|
||||||
|
if len(fields) < expected {
|
||||||
|
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
|
||||||
|
ErrIptablesCommandMalformed, flag)
|
||||||
|
}
|
||||||
|
consumed = expected
|
||||||
|
}
|
||||||
|
value := fields[1]
|
||||||
|
|
||||||
|
switch flag {
|
||||||
case "-t", "--table":
|
case "-t", "--table":
|
||||||
instruction.table = value
|
instruction.table = value
|
||||||
case "-D", "--delete":
|
case "-D", "--delete":
|
||||||
instruction.append = false
|
instruction.operation = opDelete
|
||||||
instruction.chain = value
|
instruction.chain = value
|
||||||
case "-A", "--append":
|
case "-A", "--append":
|
||||||
instruction.append = true
|
instruction.operation = opAppend
|
||||||
instruction.chain = value
|
instruction.chain = value
|
||||||
|
case "-I", "--insert":
|
||||||
|
instruction.operation = opInsert
|
||||||
|
instruction.chain = value
|
||||||
|
case "-R", "--replace":
|
||||||
|
instruction.operation = opReplace
|
||||||
|
instruction.chain = value
|
||||||
|
const base, bits = 10, 16
|
||||||
|
n, err := strconv.ParseUint(fields[2], base, bits)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("parsing line number for --replace operation: %w", err)
|
||||||
|
}
|
||||||
|
instruction.lineNumber = uint16(n)
|
||||||
case "-j", "--jump":
|
case "-j", "--jump":
|
||||||
instruction.target = value
|
instruction.target = value
|
||||||
case "-p", "--protocol":
|
case "-p", "--protocol":
|
||||||
@@ -117,18 +209,18 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (
|
|||||||
case "-s", "--source":
|
case "-s", "--source":
|
||||||
instruction.source, err = parseIPPrefix(value)
|
instruction.source, err = parseIPPrefix(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing source IP CIDR: %w", err)
|
return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||||
}
|
}
|
||||||
case "-d", "--destination":
|
case "-d", "--destination":
|
||||||
instruction.destination, err = parseIPPrefix(value)
|
instruction.destination, err = parseIPPrefix(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing destination IP CIDR: %w", err)
|
return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||||
}
|
}
|
||||||
case "--dport":
|
case "--dport":
|
||||||
const base, bitLength = 10, 16
|
const base, bitLength = 10, 16
|
||||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing destination port: %w", err)
|
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||||
}
|
}
|
||||||
instruction.destinationPort = uint16(destinationPort)
|
instruction.destinationPort = uint16(destinationPort)
|
||||||
case "--ctstate":
|
case "--ctstate":
|
||||||
@@ -140,14 +232,14 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (
|
|||||||
const base, bitLength = 10, 16
|
const base, bitLength = 10, 16
|
||||||
port, err := strconv.ParseUint(portString, base, bitLength)
|
port, err := strconv.ParseUint(portString, base, bitLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing port redirection: %w", err)
|
return 0, fmt.Errorf("parsing port redirection: %w", err)
|
||||||
}
|
}
|
||||||
instruction.toPorts[i] = uint16(port)
|
instruction.toPorts[i] = uint16(port)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key)
|
return 0, fmt.Errorf("%w: unknown flag %q", ErrIptablesCommandMalformed, flag)
|
||||||
}
|
}
|
||||||
return nil
|
return consumed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
||||||
|
|||||||
@@ -23,19 +23,19 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
|||||||
"uneven_fields": {
|
"uneven_fields": {
|
||||||
s: "-A",
|
s: "-A",
|
||||||
errWrapped: ErrIptablesCommandMalformed,
|
errWrapped: ErrIptablesCommandMalformed,
|
||||||
errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"",
|
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
|
||||||
},
|
},
|
||||||
"unknown_key": {
|
"unknown_key": {
|
||||||
s: "-x something",
|
s: "-x something",
|
||||||
errWrapped: ErrIptablesCommandMalformed,
|
errWrapped: ErrIptablesCommandMalformed,
|
||||||
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
|
errMessage: "parsing \"-x something\": iptables command is malformed: unknown flag \"-x\"",
|
||||||
},
|
},
|
||||||
"one_pair": {
|
"one_pair": {
|
||||||
s: "-A INPUT",
|
s: "-I INPUT",
|
||||||
instruction: iptablesInstruction{
|
instruction: iptablesInstruction{
|
||||||
table: "filter",
|
table: "filter",
|
||||||
chain: "INPUT",
|
chain: "INPUT",
|
||||||
append: true,
|
operation: opInsert,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"instruction_A": {
|
"instruction_A": {
|
||||||
@@ -43,7 +43,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
|||||||
instruction: iptablesInstruction{
|
instruction: iptablesInstruction{
|
||||||
table: "filter",
|
table: "filter",
|
||||||
chain: "INPUT",
|
chain: "INPUT",
|
||||||
append: true,
|
operation: opAppend,
|
||||||
inputInterface: "tun0",
|
inputInterface: "tun0",
|
||||||
protocol: "tcp",
|
protocol: "tcp",
|
||||||
source: netip.MustParsePrefix("1.2.3.4/32"),
|
source: netip.MustParsePrefix("1.2.3.4/32"),
|
||||||
@@ -57,7 +57,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
|||||||
instruction: iptablesInstruction{
|
instruction: iptablesInstruction{
|
||||||
table: "nat",
|
table: "nat",
|
||||||
chain: "PREROUTING",
|
chain: "PREROUTING",
|
||||||
append: false,
|
operation: opDelete,
|
||||||
inputInterface: "tun0",
|
inputInterface: "tun0",
|
||||||
protocol: "tcp",
|
protocol: "tcp",
|
||||||
destinationPort: 43716,
|
destinationPort: 43716,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package firewall
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -81,3 +82,133 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RestrictOutputAddrPort allows outgoing traffic to a specific IP and port for both tcp and udp,
|
||||||
|
// while blocking other tcp or udp traffic to that port going to other IP addresses, both IPv4 and IPv6.
|
||||||
|
// If the port was previously allowed for another IP address, that previous allowance will be removed.
|
||||||
|
// Giving an invalid address will remove any existing restrictions for the port specified.
|
||||||
|
func (c *Config) RestrictOutputAddrPort(ctx context.Context, addrPort netip.AddrPort) (err error) {
|
||||||
|
c.stateMutex.Lock()
|
||||||
|
defer c.stateMutex.Unlock()
|
||||||
|
existingIP := c.outputAddrPort[addrPort.Port()]
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case existingIP == addrPort.Addr():
|
||||||
|
return nil
|
||||||
|
case !addrPort.Addr().IsValid():
|
||||||
|
// invalid address, remove any existing rules for the port
|
||||||
|
return c.removeOutputAddrPortRestriction(ctx, existingIP, addrPort.Port())
|
||||||
|
case !existingIP.IsValid():
|
||||||
|
// no previous existing address for the port
|
||||||
|
return c.insertOutputAddrPortRestriction(ctx, addrPort)
|
||||||
|
default:
|
||||||
|
// existing rule in the same IP family or different family
|
||||||
|
return c.replaceOutputAddrPortRestriction(ctx, existingIP, addrPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) removeOutputAddrPortRestriction(ctx context.Context, existingIP netip.Addr, port uint16) (err error) {
|
||||||
|
commonInstructions := []string{
|
||||||
|
fmt.Sprintf("--delete OUTPUT -p udp --dport %d -j DROP", port),
|
||||||
|
fmt.Sprintf("--delete OUTPUT -p tcp --dport %d -j DROP", port),
|
||||||
|
}
|
||||||
|
ipv4Instructions := commonInstructions
|
||||||
|
ipv6Instructions := commonInstructions
|
||||||
|
|
||||||
|
familySpecificInstructions := []string{
|
||||||
|
fmt.Sprintf("--delete OUTPUT -p udp --dport %d -d %s -j ACCEPT", port, existingIP),
|
||||||
|
fmt.Sprintf("--delete OUTPUT -p tcp --dport %d -d %s -j ACCEPT", port, existingIP),
|
||||||
|
}
|
||||||
|
if existingIP.Is4() {
|
||||||
|
ipv4Instructions = append(ipv4Instructions, familySpecificInstructions...)
|
||||||
|
} else {
|
||||||
|
ipv6Instructions = append(ipv6Instructions, familySpecificInstructions...)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.runIPv4AndV6IptablesInstructions(ctx, ipv4Instructions, ipv6Instructions)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
delete(c.outputAddrPort, port)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) insertOutputAddrPortRestriction(ctx context.Context, addrPort netip.AddrPort) (err error) {
|
||||||
|
commonInstructions := []string{
|
||||||
|
fmt.Sprintf("--insert OUTPUT -p udp --dport %d -j DROP", addrPort.Port()),
|
||||||
|
fmt.Sprintf("--insert OUTPUT -p tcp --dport %d -j DROP", addrPort.Port()),
|
||||||
|
}
|
||||||
|
ipv4Instructions := commonInstructions
|
||||||
|
ipv6Instructions := commonInstructions
|
||||||
|
|
||||||
|
familySpecificInstructions := []string{
|
||||||
|
fmt.Sprintf("--insert OUTPUT -p udp --dport %d -d %s -j ACCEPT", addrPort.Port(), addrPort.Addr()),
|
||||||
|
fmt.Sprintf("--insert OUTPUT -p tcp --dport %d -d %s -j ACCEPT", addrPort.Port(), addrPort.Addr()),
|
||||||
|
}
|
||||||
|
if addrPort.Addr().Is4() {
|
||||||
|
ipv4Instructions = append(ipv4Instructions, familySpecificInstructions...)
|
||||||
|
} else {
|
||||||
|
ipv6Instructions = append(ipv6Instructions, familySpecificInstructions...)
|
||||||
|
}
|
||||||
|
err = c.runIPv4AndV6IptablesInstructions(ctx, ipv4Instructions, ipv6Instructions)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.outputAddrPort[addrPort.Port()] = addrPort.Addr()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) replaceOutputAddrPortRestriction(ctx context.Context,
|
||||||
|
existingIP netip.Addr, addrPort netip.AddrPort,
|
||||||
|
) (err error) {
|
||||||
|
for _, protocol := range [...]string{"udp", "tcp"} {
|
||||||
|
switch {
|
||||||
|
case existingIP.Is4() && addrPort.Addr().Is4():
|
||||||
|
oldInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||||
|
protocol, addrPort.Port(), existingIP)
|
||||||
|
newInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||||
|
protocol, addrPort.Port(), addrPort.Addr())
|
||||||
|
err = c.replaceIptablesRule(ctx, oldInstruction, newInstruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("replacing existing IPv4 rule: %w", err)
|
||||||
|
}
|
||||||
|
case existingIP.Is6() && addrPort.Addr().Is6():
|
||||||
|
oldInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||||
|
protocol, addrPort.Port(), existingIP)
|
||||||
|
newInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||||
|
protocol, addrPort.Port(), addrPort.Addr())
|
||||||
|
err = c.replaceIP6tablesRule(ctx, oldInstruction, newInstruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("replacing existing IPv6 rule: %w", err)
|
||||||
|
}
|
||||||
|
case existingIP.Is4() && addrPort.Addr().Is6():
|
||||||
|
instruction := fmt.Sprintf("--delete OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||||
|
protocol, addrPort.Port(), existingIP)
|
||||||
|
err = c.runIptablesInstruction(ctx, instruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("removing existing IPv4 rule: %w", err)
|
||||||
|
}
|
||||||
|
instruction = fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||||
|
protocol, addrPort.Port(), addrPort.Addr())
|
||||||
|
err = c.runIP6tablesInstruction(ctx, instruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("inserting new IPv6 rule: %w", err)
|
||||||
|
}
|
||||||
|
case existingIP.Is6() && addrPort.Addr().Is4():
|
||||||
|
instruction := fmt.Sprintf("--delete OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||||
|
protocol, addrPort.Port(), existingIP)
|
||||||
|
err = c.runIP6tablesInstruction(ctx, instruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("removing existing IPv6 rule: %w", err)
|
||||||
|
}
|
||||||
|
instruction = fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||||
|
protocol, addrPort.Port(), addrPort.Addr())
|
||||||
|
err = c.runIptablesInstruction(ctx, instruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("inserting new IPv4 rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.outputAddrPort[addrPort.Port()] = addrPort.Addr()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,51 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errRuleNotFound = errors.New("rule not found")
|
||||||
|
|
||||||
|
func (c *Config) replaceIptablesRule(ctx context.Context, oldInstruction, newInstruction string) error {
|
||||||
|
targetRule, err := parseIptablesInstruction(oldInstruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing iptables command to replace: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lineNumber, err := findLineNumber(ctx, c.ipTables, targetRule, c.runner, c.logger)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("finding to-be-replaced chain rule line number: %w", err)
|
||||||
|
} else if lineNumber == 0 {
|
||||||
|
return fmt.Errorf("%w: matching to-be-replaced instruction %q", errRuleNotFound, oldInstruction)
|
||||||
|
}
|
||||||
|
parsed, err := parseIptablesInstruction(newInstruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing replacement iptables command: %w", err)
|
||||||
|
}
|
||||||
|
parsed.operation = opReplace
|
||||||
|
parsed.lineNumber = lineNumber
|
||||||
|
return c.runIptablesInstruction(ctx, parsed.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) replaceIP6tablesRule(ctx context.Context, oldInstruction, newInstruction string) error {
|
||||||
|
targetRule, err := parseIptablesInstruction(oldInstruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing iptables command to replace: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lineNumber, err := findLineNumber(ctx, c.ip6Tables, targetRule, c.runner, c.logger)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("finding to-be-replaced chain rule line number: %w", err)
|
||||||
|
} else if lineNumber == 0 {
|
||||||
|
return fmt.Errorf("%w: matching to-be-replaced instruction %q", errRuleNotFound, oldInstruction)
|
||||||
|
}
|
||||||
|
parsed, err := parseIptablesInstruction(newInstruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing replacement iptables command: %w", err)
|
||||||
|
}
|
||||||
|
parsed.operation = opReplace
|
||||||
|
parsed.lineNumber = lineNumber
|
||||||
|
return c.runIP6tablesInstruction(ctx, parsed.String())
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user