Compare commits

..

1 Commits

Author SHA1 Message Date
Quentin McGaw db947c17a8 feat(dns): restrict plain DNS output traffic 2026-02-10 16:19:08 +00:00
151 changed files with 4707 additions and 6641 deletions
+1 -1
View File
@@ -1,2 +1,2 @@
FROM ghcr.io/qdm12/godevcontainer:v0.21-alpine
RUN apk add wireguard-tools htop openssl tcpdump iptables nftables
RUN apk add wireguard-tools htop openssl
-1
View File
@@ -45,7 +45,6 @@ jobs:
level: error
exclude: |
./internal/storage/servers.json
./golangci.yml
*.md
- name: Linting
+1 -2
View File
@@ -22,7 +22,6 @@ linters:
- "^disabled$"
# Firewall and routing strings
- "^(ACCEPT|DROP)$"
- "^--append$"
- "^--delete$"
- "^all$"
- "^(tcp|udp)$"
@@ -48,7 +47,7 @@ linters:
path: internal\/server\/.+\.go
- linters:
- ireturn
text: returns interface \(golang\.org\/x\/sys\/unix\.Sockaddr\)
text: returns interface \(github\.com\/vishvananda\/netlink\.Link\)
- linters:
- ireturn
path: internal\/openvpn\/pkcs8\/descbc\.go
+2 -5
View File
@@ -13,7 +13,7 @@ FROM --platform=${BUILDPLATFORM} ghcr.io/qdm12/binpot:mockgen-${MOCKGEN_VERSION}
FROM --platform=${BUILDPLATFORM} golang:${GO_VERSION}-alpine${GO_ALPINE_VERSION} AS base
COPY --from=xcputranslate /xcputranslate /usr/local/bin/xcputranslate
# Note: findutils needed to have xargs support `-d` flag for mocks stage.
RUN apk --update add git g++ findutils iptables
RUN apk --update add git g++ findutils
ENV CGO_ENABLED=0
COPY --from=golangci-lint /bin /go/bin/golangci-lint
COPY --from=mockgen /bin /go/bin/mockgen
@@ -110,11 +110,8 @@ ENV VPN_SERVICE_PROVIDER=pia \
WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL=0 \
WIREGUARD_ADDRESSES= \
WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
WIREGUARD_MTU= \
WIREGUARD_MTU=1320 \
WIREGUARD_IMPLEMENTATION=auto \
# PMTUD
PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8 \
PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443,1.1.1.1:53,8.8.8.8:53,[2606:4700:4700::1111]:53,[2001:4860:4860::8888]:53,[2606:4700:4700::1111]:443,[2001:4860:4860::8888]:443 \
# VPN server filtering
SERVER_REGIONS= \
SERVER_COUNTRIES= \
+1 -1
View File
@@ -58,7 +58,7 @@ Lightweight swiss-army-knife-like VPN client to multiple VPN service providers
## Features
- Based on Alpine 3.22 for a small Docker image of 41.1MB
- Supports: **AirVPN**, **Cyberghost**, **ExpressVPN**, **FastestVPN**, **Giganews**, **HideMyAss**, **IPVanish**, **IVPN**, **Mullvad** (Wireguard only), **NordVPN**, **Perfect Privacy**, **Privado**, **Private Internet Access**, **PrivateVPN**, **ProtonVPN**, **PureVPN**, **SlickVPN**, **Surfshark**, **TorGuard**, **VPNSecure.me**, **VPNUnlimited**, **Vyprvpn**, **Windscribe** servers
- Supports: **AirVPN**, **Cyberghost**, **ExpressVPN**, **FastestVPN**, **Giganews**, **HideMyAss**, **IPVanish**, **IVPN**, **Mullvad**, **NordVPN**, **Perfect Privacy**, **Privado**, **Private Internet Access**, **PrivateVPN**, **ProtonVPN**, **PureVPN**, **SlickVPN**, **Surfshark**, **TorGuard**, **VPNSecure.me**, **VPNUnlimited**, **Vyprvpn**, **Windscribe** servers
- Supports OpenVPN for all providers listed
- Supports Wireguard both kernelspace and userspace
- For **AirVPN**, **FastestVPN**, **Ivpn**, **Mullvad**, **NordVPN**, **Perfect privacy**, **ProtonVPN**, **Surfshark** and **Windscribe**
+5 -10
View File
@@ -168,7 +168,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
defer fmt.Println(gluetunLogo)
announcementExp, err := time.Parse(time.RFC3339, "2026-04-01T00:00:00Z")
announcementExp, err := time.Parse(time.RFC3339, "2024-12-01T00:00:00Z")
if err != nil {
return err
}
@@ -179,7 +179,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
Version: buildInfo.Version,
Commit: buildInfo.Commit,
Created: buildInfo.Created,
Announcement: "All control server routes are now private by default",
Announcement: "All control server routes will become private by default after the v3.41.0 release",
AnnounceExp: announcementExp,
// Sponsor information
PaypalUser: "qmcgaw",
@@ -237,10 +237,6 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
if err != nil {
return err
}
err = netLinker.FlushConntrack()
if err != nil {
logger.Warnf("flushing conntrack failed: %s", err)
}
}
// TODO run this in a loop or in openvpn to reload from file without restarting
@@ -268,7 +264,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
puid, pgid := int(*allSettings.System.PUID), int(*allSettings.System.PGID)
const clientTimeout = 35 * time.Second
const clientTimeout = 15 * time.Second
httpClient := &http.Client{Timeout: clientTimeout}
// Create configurators
alpineConf := alpine.New()
@@ -283,7 +279,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
err = printVersions(ctx, logger, []printVersionElement{
{name: "Alpine", getVersion: alpineConf.Version},
{name: "OpenVPN", getVersion: ovpnVersion},
{name: "Firewall", getVersion: firewallConf.Version},
{name: "IPtables", getVersion: firewallConf.Version},
})
if err != nil {
return err
@@ -398,7 +394,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
}
dnsLogger := logger.New(log.SetComponent("dns"))
dnsLooper, err := dns.NewLoop(allSettings.DNS, httpClient,
dnsLooper, err := dns.NewLoop(allSettings.DNS, httpClient, firewallConf,
dnsLogger)
if err != nil {
return fmt.Errorf("creating DNS loop: %w", err)
@@ -560,7 +556,6 @@ type netLinker interface {
Linker
IsWireguardSupported() (ok bool, err error)
IsIPv6Supported() (ok bool, err error)
FlushConntrack() error
PatchLoggerLevel(level log.Level)
}
+4 -5
View File
@@ -4,17 +4,15 @@ go 1.25.0
require (
github.com/ProtonMail/go-srp v0.0.7
github.com/breml/rootcerts v0.3.4
github.com/breml/rootcerts v0.3.3
github.com/fatih/color v1.18.0
github.com/golang/mock v1.6.0
github.com/google/nftables v0.3.0
github.com/jsimonetti/rtnetlink v1.4.2
github.com/klauspost/compress v1.18.1
github.com/klauspost/pgzip v1.2.6
github.com/mdlayher/genetlink v1.3.2
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42
github.com/pelletier/go-toml/v2 v2.2.4
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260216151239-36b3306f2205
github.com/qdm12/dns/v2 v2.0.0-rc10
github.com/qdm12/gosettings v0.4.4
github.com/qdm12/goshutdown v0.3.0
github.com/qdm12/gosplash v0.2.0
@@ -22,7 +20,6 @@ require (
github.com/qdm12/log v0.1.0
github.com/qdm12/ss-server v0.6.0
github.com/stretchr/testify v1.11.1
github.com/ti-mo/netfilter v0.5.3
github.com/ulikunitz/xz v0.5.15
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
@@ -43,8 +40,10 @@ require (
github.com/cronokirby/saferith v0.33.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.5.1 // indirect
github.com/miekg/dns v1.1.62 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
+8 -12
View File
@@ -8,8 +8,8 @@ github.com/ProtonMail/go-srp v0.0.7 h1:Sos3Qk+th4tQR64vsxGIxYpN3rdnG9Wf9K4ZloC1J
github.com/ProtonMail/go-srp v0.0.7/go.mod h1:giCp+7qRnMIcCvI6V6U3S1lDDXDQYx2ewJ6F/9wdlJk=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/breml/rootcerts v0.3.4 h1:9i7WNl/ctd9OEAOaTfLy//Wrlfxq/tRQ7v4okYFN9Ys=
github.com/breml/rootcerts v0.3.4/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXBhsegfXw=
github.com/breml/rootcerts v0.3.3 h1://GnaRtQ/9BY2+GtMk2wtWxVdCRysiaPr5/xBwl7NKw=
github.com/breml/rootcerts v0.3.3/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXBhsegfXw=
github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
@@ -30,8 +30,8 @@ github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/nftables v0.3.0 h1:bkyZ0cbpVeMHXOrtlFc8ISmfVqq5gPJukoYieyVmITg=
github.com/google/nftables v0.3.0/go.mod h1:BCp9FsrbF1Fn/Yu6CLUc9GGZFw/+hsxfluNXXmxBfRM=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/jsimonetti/rtnetlink v1.4.2 h1:Df9w9TZ3npHTyDn0Ev9e1uzmN2odmXd0QX+J5GTEn90=
github.com/jsimonetti/rtnetlink v1.4.2/go.mod h1:92s6LJdE+1iOrw+F2/RO7LYI2Qd8pPpFNNUYW06gcoM=
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
@@ -49,8 +49,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg=
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
@@ -73,8 +73,8 @@ github.com/prometheus/common v0.60.1 h1:FUas6GcOw66yB/73KC+BOZoFJmbo/1pojoILArPA
github.com/prometheus/common v0.60.1/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260216151239-36b3306f2205 h1:0ycKUDQ50cYb2QpeyGcEnvVs9HJmC9jsb/XZNC1z28c=
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260216151239-36b3306f2205/go.mod h1:98foWgXJZ+g8gJIuO+fdO+oWpFei5WShMFTeN4Im2lE=
github.com/qdm12/dns/v2 v2.0.0-rc10 h1:IyeNEYXfhBsaE1dwxx5eAqdAz1HS98dT+8c7xoKODa0=
github.com/qdm12/dns/v2 v2.0.0-rc10/go.mod h1:98foWgXJZ+g8gJIuO+fdO+oWpFei5WShMFTeN4Im2lE=
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 h1:TRGpCU1l0lNwtogEUSs5U+RFceYxkAJUmrGabno7J5c=
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978/go.mod h1:D1Po4CRQLYjccnAR2JsVlN1sBMgQrcNLONbvyuzcdTg=
github.com/qdm12/gosettings v0.4.4 h1:SM6tOZDf6k8qbjWU8KWyBF4mWIixfsKCfh9DGRLHlj4=
@@ -95,12 +95,8 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/ti-mo/netfilter v0.5.3 h1:ikzduvnaUMwre5bhbNwWOd6bjqLMVb33vv0XXbK0xGQ=
github.com/ti-mo/netfilter v0.5.3/go.mod h1:08SyBCg6hu1qyQk4s3DjjJKNrm3RTb32nm6AzyT972E=
github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk=
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
@@ -23,9 +23,7 @@ type DNSBlacklist struct {
AddBlockedIPs []netip.Addr
AddBlockedIPPrefixes []netip.Prefix
// RebindingProtectionExemptHostnames is a list of hostnames
// exempt from DNS rebinding protection. It can contain parent
// domains which are of the form "*.example.com". Note the wildcard
// can only be used at the start of the hostname.
// exempt from DNS rebinding protection.
RebindingProtectionExemptHostnames []string
}
@@ -57,9 +55,6 @@ func (b DNSBlacklist) validate() (err error) {
}
for _, host := range b.RebindingProtectionExemptHostnames {
if len(host) > 2 && host[:2] == "*." {
host = host[2:]
}
if !hostRegex.MatchString(host) {
return fmt.Errorf("%w: %s", ErrRebindingProtectionExemptHostNotValid, host)
}
-111
View File
@@ -1,111 +0,0 @@
package settings
import (
"errors"
"fmt"
"net/netip"
"github.com/qdm12/gosettings"
"github.com/qdm12/gosettings/reader"
"github.com/qdm12/gotree"
)
// PMTUD contains settings to configure Path MTU Discovery.
type PMTUD struct {
// ICMPAddresses is the redundancy list of addresses to use
// for ICMP path MTU discovery. Each address MUST handle ICMP
// packets for PMTUD to work.
// It cannot be nil in the internal state.
ICMPAddresses []netip.Addr `json:"icmp_addresses"`
// TCPAddresses is the redundancy list of addresses to use
// for TCP path MTU discovery. Each address MUST have a listening
// TCP server on the port specified.
// It cannot be nil in the internal state.
TCPAddresses []netip.AddrPort `json:"tcp_addresses"`
}
var (
ErrPMTUDICMPAddressNotValid = errors.New("PMTUD ICMP address is not valid")
ErrPMTUDTCPAddressNotValid = errors.New("PMTUD TCP address is not valid")
)
// Validate validates PMTUD settings.
func (p PMTUD) validate() (err error) {
for i, addr := range p.ICMPAddresses {
if !addr.IsValid() {
return fmt.Errorf("%w: at index %d", ErrPMTUDICMPAddressNotValid, i)
}
}
for i, addr := range p.TCPAddresses {
if !addr.IsValid() {
return fmt.Errorf("%w: at index %d", ErrPMTUDTCPAddressNotValid, i)
}
}
return nil
}
func (p *PMTUD) copy() (copied PMTUD) {
return PMTUD{
ICMPAddresses: gosettings.CopySlice(p.ICMPAddresses),
TCPAddresses: gosettings.CopySlice(p.TCPAddresses),
}
}
func (p *PMTUD) overrideWith(other PMTUD) {
p.ICMPAddresses = gosettings.OverrideWithSlice(p.ICMPAddresses, other.ICMPAddresses)
p.TCPAddresses = gosettings.OverrideWithSlice(p.TCPAddresses, other.TCPAddresses)
}
func (p *PMTUD) setDefaults() {
defaultICMPAddresses := []netip.Addr{
netip.AddrFrom4([4]byte{1, 1, 1, 1}),
netip.AddrFrom4([4]byte{8, 8, 8, 8}),
}
p.ICMPAddresses = gosettings.DefaultSlice(p.ICMPAddresses, defaultICMPAddresses)
const dnsPort, tlsPort = 53, 443
defaultTCPAddresses := []netip.AddrPort{
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), dnsPort),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), dnsPort),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), tlsPort),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), tlsPort),
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), dnsPort),
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), dnsPort),
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), tlsPort),
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), tlsPort),
}
p.TCPAddresses = gosettings.DefaultSlice(p.TCPAddresses, defaultTCPAddresses)
}
func (p PMTUD) String() string {
return p.toLinesNode().String()
}
func (p PMTUD) toLinesNode() (node *gotree.Node) {
node = gotree.New("Path MTU discovery:")
icmpAddrNode := node.Append("ICMP addresses:")
for _, addr := range p.ICMPAddresses {
icmpAddrNode.Append(addr.String())
}
tcpAddrNode := node.Append("TCP addresses:")
for _, addr := range p.TCPAddresses {
tcpAddrNode.Append(addr.String())
}
return node
}
func (p *PMTUD) read(r *reader.Reader) (err error) {
p.ICMPAddresses, err = r.CSVNetipAddresses("PMTUD_ICMP_ADDRESSES")
if err != nil {
return err
}
p.TCPAddresses, err = r.CSVNetipAddrPorts("PMTUD_TCP_ADDRESSES")
if err != nil {
return err
}
return nil
}
+4 -7
View File
@@ -2,8 +2,6 @@ package settings
import (
"fmt"
"slices"
"sort"
"strings"
"github.com/qdm12/gluetun/internal/constants/providers"
@@ -33,11 +31,6 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
if vpnType == vpn.OpenVPN {
validNames = providers.AllWithCustom()
validNames = append(validNames, "pia") // Retro-compatibility
// Remove Mullvad since it no longer supports OpenVPN as of January 15th, 2026
mullvadIndex := slices.Index(validNames, providers.Mullvad)
validNames[mullvadIndex], validNames[len(validNames)-1] = validNames[len(validNames)-1], validNames[mullvadIndex]
validNames = validNames[:len(validNames)-1]
sort.Strings(validNames)
} else { // Wireguard
validNames = []string{
providers.Airvpn,
@@ -55,6 +48,10 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
return fmt.Errorf("%w for Wireguard: %w", ErrVPNProviderNameNotValid, err)
}
if p.Name == providers.Mullvad && vpnType == vpn.OpenVPN {
warner.Warn("https://mullvad.net/en/blog/removing-openvpn-15th-january-2026")
}
err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner)
if err != nil {
return fmt.Errorf("server selection: %w", err)
@@ -29,27 +29,14 @@ func Test_Settings_String(t *testing.T) {
| | └── OpenVPN server selection settings:
| | ├── Protocol: UDP
| | └── Private Internet Access encryption preset: strong
| ── OpenVPN settings:
| | ├── OpenVPN version: 2.6
| | ├── User: [not set]
| | ├── Password: [not set]
| | ├── Private Internet Access encryption preset: strong
| | ├── Network interface: tun0
| | ├── Run OpenVPN as: root
| | └── Verbosity level: 1
| └── Path MTU discovery:
| ├── ICMP addresses:
| | ├── 1.1.1.1
| | └── 8.8.8.8
| └── TCP addresses:
| ├── 1.1.1.1:53
| ├── 8.8.8.8:53
| ├── 1.1.1.1:443
| ├── 8.8.8.8:443
| ├── [2606:4700:4700::1111]:53
| ├── [2001:4860:4860::8888]:53
| ├── [2606:4700:4700::1111]:443
| └── [2001:4860:4860::8888]:443
| ── OpenVPN settings:
| ├── OpenVPN version: 2.6
| ├── User: [not set]
| ├── Password: [not set]
| ├── Private Internet Access encryption preset: strong
| ├── Network interface: tun0
| ├── Run OpenVPN as: root
| └── Verbosity level: 1
├── DNS settings:
| ├── Keep existing nameserver(s): no
| ├── DNS server address to use: 127.0.0.1
-15
View File
@@ -18,7 +18,6 @@ type VPN struct {
Provider Provider `json:"provider"`
OpenVPN OpenVPN `json:"openvpn"`
Wireguard Wireguard `json:"wireguard"`
PMTUD PMTUD `json:"pmtud"`
}
// TODO v4 remove pointer for receiver (because of Surfshark).
@@ -46,11 +45,6 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo
}
}
err = v.PMTUD.validate()
if err != nil {
return fmt.Errorf("PMTUD settings: %w", err)
}
return nil
}
@@ -60,7 +54,6 @@ func (v *VPN) Copy() (copied VPN) {
Provider: v.Provider.copy(),
OpenVPN: v.OpenVPN.copy(),
Wireguard: v.Wireguard.copy(),
PMTUD: v.PMTUD.copy(),
}
}
@@ -69,7 +62,6 @@ func (v *VPN) OverrideWith(other VPN) {
v.Provider.overrideWith(other.Provider)
v.OpenVPN.overrideWith(other.OpenVPN)
v.Wireguard.overrideWith(other.Wireguard)
v.PMTUD.overrideWith(other.PMTUD)
}
func (v *VPN) setDefaults() {
@@ -77,7 +69,6 @@ func (v *VPN) setDefaults() {
v.Provider.setDefaults()
v.OpenVPN.setDefaults(v.Provider.Name)
v.Wireguard.setDefaults(v.Provider.Name)
v.PMTUD.setDefaults()
}
func (v VPN) String() string {
@@ -94,7 +85,6 @@ func (v VPN) toLinesNode() (node *gotree.Node) {
} else {
node.AppendNode(v.Wireguard.toLinesNode())
}
node.AppendNode(v.PMTUD.toLinesNode())
return node
}
@@ -117,10 +107,5 @@ func (v *VPN) read(r *reader.Reader) (err error) {
return fmt.Errorf("wireguard: %w", err)
}
err = v.PMTUD.read(r)
if err != nil {
return fmt.Errorf("PMTUD: %w", err)
}
return nil
}
+15 -10
View File
@@ -38,9 +38,15 @@ type Wireguard struct {
Interface string `json:"interface"`
PersistentKeepaliveInterval *time.Duration `json:"persistent_keep_alive_interval"`
// Maximum Transmission Unit (MTU) of the Wireguard interface.
// It cannot be nil in the internal state, and defaults to
// 0 indicating to use PMTUD.
MTU *uint32 `json:"mtu"`
// It cannot be zero in the internal state, and defaults to
// 1320. Note it is not the wireguard-go MTU default of 1420
// because this impacts bandwidth a lot on some VPN providers,
// see https://github.com/qdm12/gluetun/issues/1650.
// It has been lowered to 1320 following quite a bit of
// investigation in the issue:
// https://github.com/qdm12/gluetun/issues/2533.
// Note this should now be replaced with the PMTUD feature.
MTU uint32 `json:"mtu"`
// Implementation is the Wireguard implementation to use.
// It can be "auto", "userspace" or "kernelspace".
// It defaults to "auto" and cannot be the empty string
@@ -189,7 +195,8 @@ func (w *Wireguard) setDefaults(vpnProvider string) {
w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs)
w.PersistentKeepaliveInterval = gosettings.DefaultPointer(w.PersistentKeepaliveInterval, 0)
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
w.MTU = gosettings.DefaultPointer(w.MTU, 0)
const defaultMTU = 1320
w.MTU = gosettings.DefaultComparable(w.MTU, defaultMTU)
w.Implementation = gosettings.DefaultComparable(w.Implementation, "auto")
}
@@ -225,11 +232,7 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
}
interfaceNode := node.Appendf("Network interface: %s", w.Interface)
if *w.MTU == 0 {
interfaceNode.Append("MTU: use path MTU discovery")
} else {
interfaceNode.Appendf("MTU: %d", *w.MTU)
}
interfaceNode.Appendf("MTU: %d", w.MTU)
if w.Implementation != "auto" {
node.Appendf("Implementation: %s", w.Implementation)
@@ -270,9 +273,11 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
return err
}
w.MTU, err = r.Uint32Ptr("WIREGUARD_MTU")
mtuPtr, err := r.Uint32Ptr("WIREGUARD_MTU")
if err != nil {
return err
} else if mtuPtr != nil {
w.MTU = *mtuPtr
}
return nil
}
+17
View File
@@ -0,0 +1,17 @@
package dns
import (
"context"
"net/netip"
)
type Logger interface {
Debug(s string)
Info(s string)
Warn(s string)
Error(s string)
}
type Firewall interface {
RestrictOutputAddrPort(ctx context.Context, addrPort netip.AddrPort) (err error)
}
-8
View File
@@ -1,8 +0,0 @@
package dns
type Logger interface {
Debug(s string)
Info(s string)
Warn(s string)
Error(s string)
}
+3 -1
View File
@@ -24,6 +24,7 @@ type Loop struct {
localResolvers []netip.Addr
resolvConf string
client *http.Client
firewall Firewall
logger Logger
userTrigger bool
start <-chan struct{}
@@ -39,7 +40,7 @@ type Loop struct {
const defaultBackoffTime = 10 * time.Second
func NewLoop(settings settings.DNS,
client *http.Client, logger Logger,
client *http.Client, firewall Firewall, logger Logger,
) (loop *Loop, err error) {
start := make(chan struct{})
running := make(chan models.LoopStatus)
@@ -64,6 +65,7 @@ func NewLoop(settings settings.DNS,
filter: filter,
resolvConf: "/etc/resolv.conf",
client: client,
firewall: firewall,
logger: logger,
userTrigger: true,
start: start,
+9 -2
View File
@@ -1,13 +1,14 @@
package dns
import (
"context"
"net/netip"
"time"
"github.com/qdm12/dns/v2/pkg/nameserver"
)
func (l *Loop) useUnencryptedDNS(fallback bool) {
func (l *Loop) useUnencryptedDNS(ctx context.Context, fallback bool) {
settings := l.GetSettings()
targetIP := settings.GetFirstPlaintextIPv4()
@@ -20,8 +21,9 @@ func (l *Loop) useUnencryptedDNS(fallback bool) {
const dialTimeout = 3 * time.Second
const defaultDNSPort = 53
addrPort := netip.AddrPortFrom(targetIP, defaultDNSPort)
settingsInternalDNS := nameserver.SettingsInternalDNS{
AddrPort: netip.AddrPortFrom(targetIP, defaultDNSPort),
AddrPort: addrPort,
Timeout: dialTimeout,
}
nameserver.UseDNSInternally(settingsInternalDNS)
@@ -34,4 +36,9 @@ func (l *Loop) useUnencryptedDNS(fallback bool) {
if err != nil {
l.logger.Error(err.Error())
}
err = l.firewall.RestrictOutputAddrPort(ctx, addrPort)
if err != nil {
l.logger.Error("restricting plain DNS traffic to " + targetIP.String() + ": " + err.Error())
}
}
+11 -10
View File
@@ -2,6 +2,7 @@ package dns
import (
"context"
"errors"
"github.com/qdm12/dns/v2/pkg/nameserver"
"github.com/qdm12/gluetun/internal/constants"
@@ -23,7 +24,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
"and go through your container network DNS outside the VPN tunnel!")
} else {
const fallback = false
l.useUnencryptedDNS(fallback)
l.useUnencryptedDNS(ctx, fallback)
}
select {
@@ -43,12 +44,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
runError, err = l.setupServer(ctx)
if err == nil {
l.backoffTime = defaultBackoffTime
l.logger.Info("ready and using DNS server at address " + settings.ServerAddress.String())
err = l.updateFiles(ctx, settings)
if err != nil {
l.logger.Warn("downloading block lists failed, skipping: " + err.Error())
}
l.logger.Info("ready")
break
}
@@ -57,6 +53,11 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
if ctx.Err() != nil {
return
}
if !errors.Is(err, errUpdateBlockLists) {
const fallback = true
l.useUnencryptedDNS(ctx, fallback)
}
l.logAndWait(ctx, err)
settings = l.GetSettings()
}
@@ -65,7 +66,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
settings = l.GetSettings()
if !*settings.KeepNameserver && !*settings.ServerEnabled {
const fallback = false
l.useUnencryptedDNS(fallback)
l.useUnencryptedDNS(ctx, fallback)
}
l.userTrigger = false
@@ -93,7 +94,7 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
settings := l.GetSettings()
if !*settings.KeepNameserver && *settings.ServerEnabled {
const fallback = false
l.useUnencryptedDNS(fallback)
l.useUnencryptedDNS(ctx, fallback)
l.stopServer()
}
l.stopped <- struct{}{}
@@ -104,7 +105,7 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
case err := <-runError: // unexpected error
l.statusManager.SetStatus(constants.Crashed)
const fallback = true
l.useUnencryptedDNS(fallback)
l.useUnencryptedDNS(ctx, fallback)
l.logAndWait(ctx, err)
return false
}
+14 -7
View File
@@ -2,24 +2,25 @@ package dns
import (
"context"
"errors"
"fmt"
"net/netip"
"github.com/qdm12/dns/v2/pkg/check"
"github.com/qdm12/dns/v2/pkg/middlewares/filter/update"
"github.com/qdm12/dns/v2/pkg/nameserver"
"github.com/qdm12/dns/v2/pkg/server"
)
var errUpdateBlockLists = errors.New("cannot update filter block lists")
func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err error) {
settings := l.GetSettings()
var updateSettings update.Settings
updateSettings.SetRebindingProtectionExempt(settings.Blacklist.RebindingProtectionExemptHostnames)
err = l.filter.Update(updateSettings)
err = l.updateFiles(ctx)
if err != nil {
return nil, fmt.Errorf("updating filter for rebinding protection: %w", err)
return nil, fmt.Errorf("%w: %w", errUpdateBlockLists, err)
}
settings := l.GetSettings()
serverSettings, err := buildServerSettings(settings, l.filter, l.localResolvers, l.logger)
if err != nil {
return nil, fmt.Errorf("building server settings: %w", err)
@@ -38,8 +39,9 @@ func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err erro
// use internal DNS server
const defaultDNSPort = 53
addrPort := netip.AddrPortFrom(settings.ServerAddress, defaultDNSPort)
nameserver.UseDNSInternally(nameserver.SettingsInternalDNS{
AddrPort: netip.AddrPortFrom(settings.ServerAddress, defaultDNSPort),
AddrPort: addrPort,
})
err = nameserver.UseDNSSystemWide(nameserver.SettingsSystemDNS{
IPs: []netip.Addr{settings.ServerAddress},
@@ -49,6 +51,11 @@ func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err erro
l.logger.Error(err.Error())
}
err = l.firewall.RestrictOutputAddrPort(ctx, addrPort)
if err != nil {
l.logger.Error("restricting plain DNS traffic to " + addrPort.Addr().String() + ": " + err.Error())
}
err = check.WaitForDNS(ctx, check.Settings{})
if err != nil {
l.stopServer()
+15 -4
View File
@@ -28,12 +28,23 @@ func (l *Loop) RunRestartTicker(ctx context.Context, done chan<- struct{}) {
return
case <-timer.C:
lastTick = l.timeNow()
settings := l.GetSettings()
if l.GetStatus() == constants.Running {
if err := l.updateFiles(ctx, settings); err != nil {
l.logger.Warn("updating block lists failed, skipping: " + err.Error())
status := l.GetStatus()
if status == constants.Running {
if err := l.updateFiles(ctx); err != nil {
l.statusManager.SetStatus(constants.Crashed)
l.logger.Error(err.Error())
l.logger.Warn("skipping DNS server restart due to failed files update")
settings := l.GetSettings()
timer.Reset(*settings.UpdatePeriod)
continue
}
}
_, _ = l.statusManager.ApplyStatus(ctx, constants.Stopped)
_, _ = l.statusManager.ApplyStatus(ctx, constants.Running)
settings := l.GetSettings()
timer.Reset(*settings.UpdatePeriod)
case <-l.updateTicker:
if !timer.Stop() {
+4 -2
View File
@@ -6,10 +6,11 @@ import (
"github.com/qdm12/dns/v2/pkg/blockbuilder"
"github.com/qdm12/dns/v2/pkg/middlewares/filter/update"
"github.com/qdm12/gluetun/internal/configuration/settings"
)
func (l *Loop) updateFiles(ctx context.Context, settings settings.DNS) (err error) {
func (l *Loop) updateFiles(ctx context.Context) (err error) {
settings := l.GetSettings()
l.logger.Info("downloading hostnames and IP block lists")
blacklistSettings := settings.Blacklist.ToBlockBuilderSettings(l.client)
@@ -36,6 +37,7 @@ func (l *Loop) updateFiles(ctx context.Context, settings settings.DNS) (err erro
IPPrefixes: result.BlockedIPPrefixes,
}
updateSettings.BlockHostnames(result.BlockedHostnames)
updateSettings.SetRebindingProtectionExempt(settings.Blacklist.RebindingProtectionExemptHostnames)
err = l.filter.Update(updateSettings)
if err != nil {
return fmt.Errorf("updating filter: %w", err)
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"fmt"
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"context"
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"context"
+58 -28
View File
@@ -22,7 +22,9 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
if !enabled {
c.logger.Info("disabling...")
c.restore(ctx)
if err = c.disable(ctx); err != nil {
return fmt.Errorf("disabling firewall: %w", err)
}
c.enabled = false
c.logger.Info("disabled successfully")
return nil
@@ -39,33 +41,64 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
return nil
}
func (c *Config) enable(ctx context.Context) (err error) {
c.restore, err = c.impl.SaveAndRestore(ctx)
if err != nil {
return fmt.Errorf("saving firewall rules: %w", err)
func (c *Config) disable(ctx context.Context) (err error) {
if err = c.clearAllRules(ctx); err != nil {
return fmt.Errorf("clearing all rules: %w", err)
}
if err = c.setIPv4AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("setting ipv4 policies: %w", err)
}
if err = c.setIPv6AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("setting ipv6 policies: %w", err)
}
defer func() {
if err != nil {
c.restore(context.Background())
}
}()
const remove = true
err = c.redirectPorts(ctx, remove)
if err != nil {
return fmt.Errorf("removing port redirections: %w", err)
}
if err = c.impl.SetBaseChainsPolicy(ctx, "DROP"); err != nil {
return nil
}
// To use in defered call when enabling the firewall.
func (c *Config) fallbackToDisabled(ctx context.Context) {
if ctx.Err() != nil {
return
}
if err := c.disable(ctx); err != nil {
c.logger.Error("failed reversing firewall changes: " + err.Error())
}
}
func (c *Config) enable(ctx context.Context) (err error) {
touched := false
if err = c.setIPv4AllPolicies(ctx, "DROP"); err != nil {
return err
}
touched = true
// Loopback traffic
if err = c.impl.AcceptInputThroughInterface(ctx, "lo"); err != nil {
if err = c.setIPv6AllPolicies(ctx, "DROP"); err != nil {
return err
}
const remove = false
if err = c.impl.AcceptOutputThroughInterface(ctx, "lo", remove); err != nil {
defer func() {
if touched && err != nil {
c.fallbackToDisabled(ctx)
}
}()
// Loopback traffic
if err = c.acceptInputThroughInterface(ctx, "lo", remove); err != nil {
return err
}
if err = c.acceptOutputThroughInterface(ctx, "lo", remove); err != nil {
return err
}
if err = c.impl.AcceptEstablishedRelatedTraffic(ctx); err != nil {
if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil {
return err
}
@@ -75,9 +108,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
localInterfaces := make(map[string]struct{}, len(c.localNetworks))
for _, network := range c.localNetworks {
err = c.impl.AcceptOutputFromIPToSubnet(ctx,
network.InterfaceName, network.IP, network.IPNet, remove)
if err != nil {
if err := c.acceptOutputFromIPToSubnet(ctx, network.InterfaceName, network.IP, network.IPNet, remove); err != nil {
return err
}
@@ -86,7 +117,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
continue
}
localInterfaces[network.InterfaceName] = struct{}{}
err = c.impl.AcceptIpv6MulticastOutput(ctx, network.InterfaceName)
err = c.acceptIpv6MulticastOutput(ctx, network.InterfaceName, remove)
if err != nil {
return fmt.Errorf("accepting IPv6 multicast output: %w", err)
}
@@ -99,7 +130,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
// Allows packets from any IP address to go through eth0 / local network
// to reach Gluetun.
for _, network := range c.localNetworks {
if err := c.impl.AcceptInputToSubnet(ctx, network.InterfaceName, network.IPNet); err != nil {
if err := c.acceptInputToSubnet(ctx, network.InterfaceName, network.IPNet, remove); err != nil {
return err
}
}
@@ -108,12 +139,12 @@ func (c *Config) enable(ctx context.Context) (err error) {
return err
}
err = c.redirectPorts(ctx)
err = c.redirectPorts(ctx, remove)
if err != nil {
return fmt.Errorf("redirecting ports: %w", err)
}
if err := c.impl.RunUserPostRules(ctx, c.customRulesPath); err != nil {
if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil {
return fmt.Errorf("running user defined post firewall rules: %w", err)
}
@@ -133,7 +164,7 @@ func (c *Config) allowVPNIP(ctx context.Context) (err error) {
continue
}
interfacesSeen[defaultRoute.NetInterface] = struct{}{}
err = c.impl.AcceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove)
err = c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove)
if err != nil {
return fmt.Errorf("accepting output traffic through VPN: %w", err)
}
@@ -155,7 +186,7 @@ func (c *Config) allowOutboundSubnets(ctx context.Context) (err error) {
firewallUpdated = true
const remove = false
err := c.impl.AcceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
defaultRoute.AssignedIP, subnet, remove)
if err != nil {
return err
@@ -173,7 +204,7 @@ func (c *Config) allowInputPorts(ctx context.Context) (err error) {
for port, netInterfaces := range c.allowedInputPorts {
for netInterface := range netInterfaces {
const remove = false
err = c.impl.AcceptInputToPort(ctx, netInterface, port, remove)
err = c.acceptInputToPort(ctx, netInterface, port, remove)
if err != nil {
return fmt.Errorf("accepting input port %d on interface %s: %w",
port, netInterface, err)
@@ -183,10 +214,9 @@ func (c *Config) allowInputPorts(ctx context.Context) (err error) {
return nil
}
func (c *Config) redirectPorts(ctx context.Context) (err error) {
func (c *Config) redirectPorts(ctx context.Context, remove bool) (err error) {
for _, portRedirection := range c.portRedirections {
const remove = false
err = c.impl.RedirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort,
err = c.redirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort,
portRedirection.destinationPort, remove)
if err != nil {
return err
+23 -15
View File
@@ -2,33 +2,34 @@ package firewall
import (
"context"
"fmt"
"net/netip"
"sync"
"github.com/qdm12/gluetun/internal/firewall/iptables"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/routing"
)
type Config struct {
runner CmdRunner
logger Logger
defaultRoutes []routing.DefaultRoute
localNetworks []routing.LocalNetwork
runner CmdRunner
logger Logger
iptablesMutex sync.Mutex
ip6tablesMutex sync.Mutex
defaultRoutes []routing.DefaultRoute
localNetworks []routing.LocalNetwork
// Fixed
impl firewallImpl
// Fixed state
ipTables string
ip6Tables string
customRulesPath string
// State
enabled bool
restore func(context.Context)
vpnConnection models.Connection
vpnIntf string
outboundSubnets []netip.Prefix
allowedInputPorts map[uint16]map[string]struct{} // port to interfaces set mapping
portRedirections portRedirections
outputAddrPort map[uint16]netip.Addr
stateMutex sync.Mutex
}
@@ -38,19 +39,26 @@ func NewConfig(ctx context.Context, logger Logger,
runner CmdRunner, defaultRoutes []routing.DefaultRoute,
localNetworks []routing.LocalNetwork,
) (config *Config, err error) {
impl, err := iptables.New(ctx, runner, logger)
iptables, err := checkIptablesSupport(ctx, runner, "iptables", "iptables-nft", "iptables-legacy")
if err != nil {
return nil, fmt.Errorf("creating iptables firewall: %w", err)
return nil, err
}
ip6tables, err := findIP6tablesSupported(ctx, runner)
if err != nil {
return nil, err
}
return &Config{
runner: runner,
logger: logger,
allowedInputPorts: make(map[uint16]map[string]struct{}),
outputAddrPort: make(map[uint16]netip.Addr),
ipTables: iptables,
ip6Tables: ip6tables,
customRulesPath: "/iptables/post-rules.txt",
// Obtained from routing
defaultRoutes: defaultRoutes,
localNetworks: localNetworks,
impl: impl,
customRulesPath: "/iptables/post-rules.txt",
defaultRoutes: defaultRoutes,
localNetworks: localNetworks,
}, nil
}
+1 -28
View File
@@ -1,12 +1,6 @@
package firewall
import (
"context"
"net/netip"
"os/exec"
"github.com/qdm12/gluetun/internal/models"
)
import "os/exec"
type CmdRunner interface {
Run(cmd *exec.Cmd) (output string, err error)
@@ -18,24 +12,3 @@ type Logger interface {
Warn(s string)
Error(s string)
}
type firewallImpl interface { //nolint:interfacebloat
SaveAndRestore(ctx context.Context) (restore func(context.Context), err error)
AcceptEstablishedRelatedTraffic(ctx context.Context) error
AcceptInputThroughInterface(ctx context.Context, intf string) error
AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error
AcceptInputToSubnet(ctx context.Context, intf string, subnet netip.Prefix) error
AcceptIpv6MulticastOutput(ctx context.Context, intf string) error
AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr,
subnet netip.Prefix, remove bool) error
AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error
AcceptOutputTrafficToVPN(ctx context.Context, intf string,
connection models.Connection, remove bool) error
RedirectPort(ctx context.Context, intf string, sourcePort,
destinationPort uint16, remove bool) error
RunUserPostRules(ctx context.Context, customRulesPath string) error
SetBaseChainsPolicy(ctx context.Context, policy string) error
TempDropOutputTCPRST(ctx context.Context, src, dst netip.AddrPort, excludeMark int) (
revert func(ctx context.Context) error, err error)
Version(ctx context.Context) (version string, err error)
}
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"context"
@@ -14,8 +14,8 @@ import (
func findIP6tablesSupported(ctx context.Context, runner CmdRunner) (
ip6tablesPath string, err error,
) {
ip6tablesPath, err = checkIptablesSupport(ctx, runner, "ip6tables", "ip6tables-legacy")
if errors.Is(err, ErrNotSupported) {
ip6tablesPath, err = checkIptablesSupport(ctx, runner, "ip6tables", "ip6tables-nft", "ip6tables-legacy")
if errors.Is(err, ErrIPTablesNotSupported) {
return "", nil
} else if err != nil {
return "", err
@@ -24,23 +24,8 @@ func findIP6tablesSupported(ctx context.Context, runner CmdRunner) (
}
func (c *Config) runIP6tablesInstructions(ctx context.Context, instructions []string) error {
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv6(ctx)
if err != nil {
return err
}
err = c.runIP6tablesInstructionsNoSave(ctx, instructions)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIP6tablesInstructionsNoSave(ctx context.Context, instructions []string) error {
for _, instruction := range instructions {
if err := c.runIP6tablesInstructionNoSave(ctx, instruction); err != nil {
if err := c.runIP6tablesInstruction(ctx, instruction); err != nil {
return err
}
}
@@ -48,24 +33,11 @@ func (c *Config) runIP6tablesInstructionsNoSave(ctx context.Context, instruction
}
func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string) error {
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv6(ctx)
if err != nil {
return err
}
err = c.runIP6tablesInstructionNoSave(ctx, instruction)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction string) error {
if c.ip6Tables == "" {
return nil
}
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()
if isDeleteMatchInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ip6Tables, instruction,
@@ -81,3 +53,18 @@ func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction
}
return nil
}
var ErrPolicyNotValid = errors.New("policy is not valid")
func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP":
default:
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
}
return c.runIP6tablesInstructions(ctx, []string{
"--policy INPUT " + policy,
"--policy OUTPUT " + policy,
"--policy FORWARD " + policy,
})
}
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"context"
@@ -26,6 +26,22 @@ func appendOrDelete(remove bool) string {
return "--append"
}
// flipRule changes an append rule in a delete rule or a delete rule into an
// append rule.
func flipRule(rule string) string {
switch {
case strings.HasPrefix(rule, "-A"):
return strings.Replace(rule, "-A", "-D", 1)
case strings.HasPrefix(rule, "--append"):
return strings.Replace(rule, "--append", "-D", 1)
case strings.HasPrefix(rule, "-D"):
return strings.Replace(rule, "-D", "-A", 1)
case strings.HasPrefix(rule, "--delete"):
return strings.Replace(rule, "--delete", "-A", 1)
}
return rule
}
// Version obtains the version of the installed iptables.
func (c *Config) Version(ctx context.Context) (string, error) {
cmd := exec.CommandContext(ctx, c.ipTables, "--version") //nolint:gosec
@@ -38,28 +54,12 @@ func (c *Config) Version(ctx context.Context) (string, error) {
if len(words) < minWords {
return "", fmt.Errorf("%w: %s", ErrIPTablesVersionTooShort, output)
}
return "iptables " + words[1], nil
return words[1], nil
}
func (c *Config) runIptablesInstructions(ctx context.Context, instructions []string) error {
c.iptablesMutex.Lock()
defer c.iptablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv4(ctx)
if err != nil {
return err
}
err = c.runIptablesInstructionsNoSave(ctx, instructions)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIptablesInstructionsNoSave(ctx context.Context, instructions []string) error {
for _, instruction := range instructions {
if err := c.runIptablesInstructionNoSave(ctx, instruction); err != nil {
if err := c.runIptablesInstruction(ctx, instruction); err != nil {
return err
}
}
@@ -70,19 +70,6 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
c.iptablesMutex.Lock() // only one iptables command at once
defer c.iptablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv4(ctx)
if err != nil {
return err
}
err = c.runIptablesInstructionNoSave(ctx, instruction)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIptablesInstructionNoSave(ctx context.Context, instruction string) error {
if isDeleteMatchInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ipTables, instruction,
c.runner, c.logger)
@@ -98,33 +85,42 @@ func (c *Config) runIptablesInstructionNoSave(ctx context.Context, instruction s
return nil
}
func (c *Config) SetBaseChainsPolicy(ctx context.Context, policy string) error {
policy = strings.ToUpper(policy)
func (c *Config) clearAllRules(ctx context.Context) error {
return c.runMixedIptablesInstructions(ctx, []string{
"--flush", // flush all chains
"--delete-chain", // delete all chains
})
}
func (c *Config) setIPv4AllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP":
default:
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
}
return c.runMixedIptablesInstructions(ctx, []string{
return c.runIptablesInstructions(ctx, []string{
"--policy INPUT " + policy,
"--policy OUTPUT " + policy,
"--policy FORWARD " + policy,
})
}
func (c *Config) AcceptInputThroughInterface(ctx context.Context, intf string) error {
func (c *Config) acceptInputThroughInterface(ctx context.Context, intf string, remove bool) error {
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
"--append INPUT -i %s -j ACCEPT", intf))
"%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf,
))
}
func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, destination netip.Prefix) error {
func (c *Config) acceptInputToSubnet(ctx context.Context, intf string,
destination netip.Prefix, remove bool,
) error {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("--append INPUT %s -d %s -j ACCEPT",
interfaceFlag, destination.String())
instruction := fmt.Sprintf("%s INPUT %s -d %s -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destination.String())
if destination.Addr().Is4() {
return c.runIptablesInstruction(ctx, instruction)
@@ -135,20 +131,20 @@ func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, destinati
return c.runIP6tablesInstruction(ctx, instruction)
}
func (c *Config) AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error {
func (c *Config) acceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error {
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
"%s OUTPUT -o %s -j ACCEPT", appendOrDelete(remove), intf,
))
}
func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context) error {
func (c *Config) acceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error {
return c.runMixedIptablesInstructions(ctx, []string{
"--append OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
"--append INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
})
}
func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context,
func (c *Config) acceptOutputTrafficToVPN(ctx context.Context,
defaultInterface string, connection models.Connection, remove bool,
) error {
protocol := connection.Protocol
@@ -166,11 +162,8 @@ func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context,
return c.runIP6tablesInstruction(ctx, instruction)
}
// AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet
// on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
// If remove is true, the rule is removed instead of added.
// Thanks to @npawelek.
func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
func (c *Config) acceptOutputFromIPToSubnet(ctx context.Context,
intf string, sourceIP netip.Addr, destinationSubnet netip.Prefix, remove bool,
) error {
doIPv4 := sourceIP.Is4() && destinationSubnet.Addr().Is4()
@@ -191,24 +184,21 @@ func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
return c.runIP6tablesInstruction(ctx, instruction)
}
// AcceptIpv6MulticastOutput accepts outgoing traffic to the IPv6 multicast address
// ff02::1:ff00:0/104, which is used for NDP (Neighbor Discovery Protocol) to resolve
// IPv6 addresses to MAC addresses. If intf is empty, it is set to "*" which means
// all interfaces. If remove is true, the rule is removed instead of added.
func (c *Config) AcceptIpv6MulticastOutput(ctx context.Context, intf string) error {
// NDP uses multicast address (theres no broadcast in IPv6 like ARP uses in IPv4).
func (c *Config) acceptIpv6MulticastOutput(ctx context.Context,
intf string, remove bool,
) error {
interfaceFlag := "-o " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("--append OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT", interfaceFlag)
instruction := fmt.Sprintf("%s OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT",
appendOrDelete(remove), interfaceFlag)
return c.runIP6tablesInstruction(ctx, instruction)
}
// AcceptInputToPort accepts incoming traffic on the specified port, for both TCP and UDP
// protocols, on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
// If remove is true, the rule is removed instead of added. This is used for port forwarding, with
// intf set to the VPN tunnel interface.
func (c *Config) AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error {
// Used for port forwarding, with intf set to tun.
func (c *Config) acceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
@@ -219,12 +209,8 @@ func (c *Config) AcceptInputToPort(ctx context.Context, intf string, port uint16
})
}
// RedirectPort redirects incoming traffic on the specified source port to the
// specified destination port, for both TCP and UDP protocols, on the interface intf.
// If intf is empty, it is set to "*" which means all interfaces. If remove is true,
// the redirection is removed instead of added. This is used for VPN server side
// port forwarding, with intf set to the VPN tunnel interface.
func (c *Config) RedirectPort(ctx context.Context, intf string,
// Used for VPN server side port forwarding, with intf set to the VPN tunnel interface.
func (c *Config) redirectPort(ctx context.Context, intf string,
sourcePort, destinationPort uint16, remove bool,
) (err error) {
interfaceFlag := "-i " + intf
@@ -232,17 +218,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
interfaceFlag = ""
}
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
err = c.runIptablesInstructionsNoSave(ctx, []string{
err = c.runIptablesInstructions(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
@@ -253,12 +229,11 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
appendOrDelete(remove), interfaceFlag, destinationPort),
})
if err != nil {
restore(ctx)
return fmt.Errorf("redirecting IPv4 source port %d to destination port %d on interface %s: %w",
sourcePort, destinationPort, intf, err)
}
err = c.runIP6tablesInstructionsNoSave(ctx, []string{
err = c.runIP6tablesInstructions(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
@@ -269,7 +244,6 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
appendOrDelete(remove), interfaceFlag, destinationPort),
})
if err != nil {
restore(ctx) // just in case
errMessage := err.Error()
if strings.Contains(errMessage, "can't initialize ip6tables table `nat': Table does not exist") {
if !remove {
@@ -283,7 +257,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
return nil
}
func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
func (c *Config) runUserPostRules(ctx context.Context, filepath string, remove bool) error {
file, err := os.OpenFile(filepath, os.O_RDONLY, 0)
if os.IsNotExist(err) {
return nil
@@ -299,17 +273,16 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
return err
}
lines := strings.Split(string(b), "\n")
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
successfulRules := []string{}
defer func() {
// transaction-like rollback
if err == nil || ctx.Err() != nil {
return
}
for _, rule := range successfulRules {
_ = c.runIptablesInstruction(ctx, flipRule(rule))
}
}()
for _, line := range lines {
var ipv4 bool
var rule string
@@ -336,18 +309,23 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
continue
}
if remove {
rule = flipRule(rule)
}
switch {
case ipv4:
err = c.runIptablesInstructionNoSave(ctx, rule)
err = c.runIptablesInstruction(ctx, rule)
case c.ip6Tables == "":
err = fmt.Errorf("running user ip6tables rule: %w", ErrNeedIP6Tables)
default: // ipv6
err = c.runIP6tablesInstructionNoSave(ctx, rule)
err = c.runIP6tablesInstruction(ctx, rule)
}
if err != nil {
restore(ctx)
return err
}
successfulRules = append(successfulRules, rule)
}
return nil
}
-85
View File
@@ -1,85 +0,0 @@
package iptables
import (
"context"
"fmt"
"os/exec"
"strings"
)
// SaveAndRestore saves the current iptables and ip6tables rules and
// returns a restore function that can be called to restore the saved rules.
func (c *Config) SaveAndRestore(ctx context.Context) (restore func(context.Context), err error) {
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
return c.saveAndRestore(ctx)
}
// callers MUST always lock both the [Config] iptablesMutex and the ip6tablesMutex
// before calling this function. Note the restore function does not interact with mutexes
// so the caller must make sure the mutexes are locked when calling the restore function.
func (c *Config) saveAndRestore(ctx context.Context) (restore func(context.Context), err error) {
restoreIPv4, err := c.saveAndRestoreIPv4(ctx)
if err != nil {
return nil, err
}
restoreIPv6, err := c.saveAndRestoreIPv6(ctx)
if err != nil {
return nil, err
}
restore = func(ctx context.Context) {
restoreIPv4(ctx)
if restoreIPv6 != nil {
restoreIPv6(ctx)
}
}
return restore, nil
}
// Callers of saveAndRestoreIPv4 MUST always lock the [Config] iptablesMutex
// before calling this function.
func (c *Config) saveAndRestoreIPv4(ctx context.Context) (restore func(context.Context), err error) {
cmd := exec.CommandContext(ctx, c.ipTables+"-save") //nolint:gosec
data, err := c.runner.Run(cmd)
if err != nil {
return nil, fmt.Errorf("saving IPv4 iptables: %w", err)
}
restore = func(ctx context.Context) {
cmd := exec.CommandContext(ctx, c.ipTables+"-restore") //nolint:gosec
cmd.Stdin = strings.NewReader(data)
output, err := c.runner.Run(cmd)
if err != nil {
c.logger.Warn(fmt.Sprintf("restoring IPv4 iptables failed: %v: %s", err, output))
}
}
return restore, nil
}
// Callers of saveAndRestoreIPv6 MUST always lock the [Config] ip6tablesMutex
// before calling this function.
func (c *Config) saveAndRestoreIPv6(ctx context.Context) (restore func(context.Context), err error) {
if c.ip6Tables == "" {
return nil, nil //nolint:nilnil
}
cmd := exec.CommandContext(ctx, c.ip6Tables+"-save") //nolint:gosec
data, err := c.runner.Run(cmd)
if err != nil {
return nil, fmt.Errorf("saving IPv6 iptables: %w", err)
}
restore = func(ctx context.Context) {
cmd = exec.CommandContext(ctx, c.ip6Tables+"-restore") //nolint:gosec
cmd.Stdin = strings.NewReader(data)
output, err := c.runner.Run(cmd)
if err != nil {
c.logger.Warn(fmt.Sprintf("restoring IPv6 iptables failed: %v: %s", err, output))
}
}
return restore, nil
}
-51
View File
@@ -1,51 +0,0 @@
package iptables
import (
"context"
"sync"
"github.com/qdm12/gluetun/internal/mod"
)
type Config struct {
runner CmdRunner
logger Logger
iptablesMutex sync.Mutex
ip6tablesMutex sync.Mutex
// Fixed state
ipTables string
ip6Tables string
nftables bool
xtMark bool
}
func New(ctx context.Context, runner CmdRunner, logger Logger) (*Config, error) {
iptables, err := checkIptablesSupport(ctx, runner, "iptables", "iptables-nft", "iptables-legacy")
if err != nil {
return nil, err
}
ip6tables, err := findIP6tablesSupported(ctx, runner)
if err != nil {
return nil, err
}
modules := map[string]bool{
"xt_mark": false,
"nf_tables": false,
}
for module := range modules {
err := mod.Probe(module)
modules[module] = err == nil
}
return &Config{
runner: runner,
logger: logger,
ipTables: iptables,
ip6Tables: ip6tables,
nftables: modules["nf_tables"],
xtMark: modules["xt_mark"],
}, nil
}
-14
View File
@@ -1,14 +0,0 @@
package iptables
import "os/exec"
type CmdRunner interface {
Run(cmd *exec.Cmd) (output string, err error)
}
type Logger interface {
Debug(s string)
Info(s string)
Warn(s string)
Error(s string)
}
-49
View File
@@ -1,49 +0,0 @@
package iptables
import (
"context"
)
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
for _, instruction := range instructions {
if err := c.runMixedIptablesInstructionNoSave(ctx, instruction); err != nil {
restore(ctx)
return err
}
}
return nil
}
func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction string) error {
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
err = c.runMixedIptablesInstructionNoSave(ctx, instruction)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runMixedIptablesInstructionNoSave(ctx context.Context, instruction string) error {
if err := c.runIptablesInstructionNoSave(ctx, instruction); err != nil {
return err
}
return c.runIP6tablesInstructionNoSave(ctx, instruction)
}
-96
View File
@@ -1,96 +0,0 @@
package iptables
import (
"context"
"errors"
"fmt"
"net/netip"
)
type tcpFlags struct {
mask []tcpFlag
comparison []tcpFlag
}
type tcpFlag uint8
const (
tcpFlagFIN tcpFlag = 1 << iota
tcpFlagSYN
tcpFlagRST
tcpFlagPSH
tcpFlagACK
tcpFlagURG
tcpFlagECE
tcpFlagCWR
)
func (f tcpFlag) String() string {
switch f {
case tcpFlagFIN:
return "FIN"
case tcpFlagSYN:
return "SYN"
case tcpFlagRST:
return "RST"
case tcpFlagPSH:
return "PSH"
case tcpFlagACK:
return "ACK"
case tcpFlagURG:
return "URG"
case tcpFlagECE:
return "ECE"
case tcpFlagCWR:
return "CWR"
default:
panic(fmt.Sprintf("%s: %d", errTCPFlagUnknown, f))
}
}
var errTCPFlagUnknown = errors.New("unknown TCP flag")
func parseTCPFlag(s string) (tcpFlag, error) {
allFlags := []tcpFlag{
tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH,
tcpFlagACK, tcpFlagURG, tcpFlagECE, tcpFlagCWR,
}
for _, flag := range allFlags {
if s == fmt.Sprintf("%#02x", uint8(flag)) || s == flag.String() {
return flag, nil
}
}
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
}
var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so")
// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port,
// for any TCP packets not marked with the excludeMark given.
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
// by sending a TCP RST packet, although we want to handle the connection manually.
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
src, dst netip.AddrPort, excludeMark int) (
revert func(ctx context.Context) error, err error,
) {
if !c.nftables && !c.xtMark {
return nil, fmt.Errorf("%w", ErrMarkMatchModuleMissing)
}
const template = "%s OUTPUT -p tcp -s %s --sport %d -d %s --dport %d " +
"--tcp-flags RST RST -m mark ! --mark %d -j DROP" //nolint:dupword
instruction := fmt.Sprintf(template, "--append", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark)
revertInstruction := fmt.Sprintf(template, "--delete", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark)
run := c.runIptablesInstruction
if dst.Addr().Is6() {
run = c.runIP6tablesInstruction
}
revert = func(ctx context.Context) error {
return run(ctx, revertInstruction)
}
err = run(ctx, instruction)
if err != nil {
return nil, fmt.Errorf("running instruction: %w", err)
}
return revert, nil
}
+34
View File
@@ -0,0 +1,34 @@
package firewall
import (
"context"
"fmt"
)
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
for _, instruction := range instructions {
if err := c.runMixedIptablesInstruction(ctx, instruction); err != nil {
return err
}
}
return nil
}
func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction string) error {
if err := c.runIptablesInstruction(ctx, instruction); err != nil {
return err
}
return c.runIP6tablesInstruction(ctx, instruction)
}
func (c *Config) runIPv4AndV6IptablesInstructions(ctx context.Context,
ipv4Instructions, ipv6Instructions []string,
) error {
if err := c.runIptablesInstructions(ctx, ipv4Instructions); err != nil {
return fmt.Errorf("running iptables instructions: %w", err)
}
if err := c.runIP6tablesInstructions(ctx, ipv6Instructions); err != nil {
return fmt.Errorf("running ip6tables instructions: %w", err)
}
return nil
}
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"errors"
@@ -26,18 +26,10 @@ type chainRule struct {
inputInterface string // input interface, for example "tun0" or "*""
outputInterface string // output interface, for example "eth0" or "*""
source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid.
sourcePort uint16 // Not specified if set to zero.
destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid.
destinationPort uint16 // Not specified if set to zero.
redirPorts []uint16 // Not specified if empty.
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
tcpFlags tcpFlags
mark mark
}
type mark struct {
invert bool
value uint
}
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
@@ -249,23 +241,19 @@ func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err err
}
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
i := 0
for i < len(optionalFields) {
switch optionalFields[i] {
case "udp":
for i := 0; i < len(optionalFields); i++ {
key := optionalFields[i]
switch key {
case "tcp", "udp":
i++
consumed, err := parseUDPOptional(optionalFields[i:], rule)
value := optionalFields[i]
value = strings.TrimPrefix(value, "dpt:")
const base, bitLength = 10, 16
destinationPort, err := strconv.ParseUint(value, base, bitLength)
if err != nil {
return fmt.Errorf("parsing UDP optional fields: %w", err)
return fmt.Errorf("parsing destination port %q: %w", value, err)
}
i += consumed
case "tcp":
i++
consumed, err := parseTCPOptional(optionalFields[i:], rule)
if err != nil {
return fmt.Errorf("parsing TCP optional fields: %w", err)
}
i += consumed
rule.destinationPort = uint16(destinationPort)
case "redir":
i++
switch optionalFields[i] {
@@ -276,136 +264,20 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
return fmt.Errorf("parsing redirection ports: %w", err)
}
rule.redirPorts = ports
i++
default:
return fmt.Errorf("%w: unexpected %q after redir",
ErrChainRuleMalformed, optionalFields[1])
return fmt.Errorf("%w: unexpected optional field: %s",
ErrChainRuleMalformed, optionalFields[i])
}
case "ctstate":
i++
rule.ctstate = strings.Split(optionalFields[i], ",")
i++
case "mark":
i++
mark, consumed, err := parseMark(optionalFields[i:])
if err != nil {
return fmt.Errorf("parsing mark: %w", err)
}
rule.mark = mark
i += consumed
default:
return fmt.Errorf("%w: unexpected optional field: %s",
ErrChainRuleMalformed, optionalFields[i])
return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key)
}
}
return nil
}
var errUDPOptionalUnknown = errors.New("unknown UDP optional field")
func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
for _, value := range optionalFields {
if !strings.ContainsRune(value, ':') {
// no longer a UDP-associated option
return consumed, nil
}
switch {
case strings.HasPrefix(value, "dpt:"):
rule.destinationPort, err = parseDestinationPort(value)
if err != nil {
return 0, fmt.Errorf("parsing destination port: %w", err)
}
consumed++
case strings.HasPrefix(value, "spt:"):
rule.sourcePort, err = parseSourcePort(value)
if err != nil {
return 0, fmt.Errorf("parsing source port: %w", err)
}
consumed++
default:
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
}
}
return consumed, nil
}
var errTCPOptionalUnknown = errors.New("unknown TCP optional field")
func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
for _, value := range optionalFields {
if !strings.ContainsRune(value, ':') {
// no longer a TCP-associated option
return consumed, nil
}
switch {
case strings.HasPrefix(value, "dpt:"):
rule.destinationPort, err = parseDestinationPort(value)
if err != nil {
return 0, fmt.Errorf("parsing destination port: %w", err)
}
consumed++
case strings.HasPrefix(value, "spt:"):
rule.sourcePort, err = parseSourcePort(value)
if err != nil {
return 0, fmt.Errorf("parsing source port: %w", err)
}
consumed++
case strings.HasPrefix(value, "flags:"):
rule.tcpFlags, err = parseTCPFlags(value)
if err != nil {
return 0, fmt.Errorf("parsing TCP flags: %w", err)
}
consumed++
default:
return 0, fmt.Errorf("%w: %s", errTCPOptionalUnknown, value)
}
}
return consumed, nil
}
func parseDestinationPort(value string) (port uint16, err error) {
value = strings.TrimPrefix(value, "dpt:")
return parsePort(value)
}
func parseSourcePort(value string) (port uint16, err error) {
value = strings.TrimPrefix(value, "spt:")
return parsePort(value)
}
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
func parseTCPFlags(value string) (tcpFlags, error) {
value = strings.TrimPrefix(value, "flags:")
fields := strings.Split(value, "/")
const expectedFields = 2
if len(fields) != expectedFields {
return tcpFlags{}, fmt.Errorf("%w: expected format 'flags:<mask>/<comparison>' in %q",
errTCPFlagsMalformed, value)
}
maskFlags := strings.Split(fields[0], ",")
mask := make([]tcpFlag, len(maskFlags))
var err error
for i, maskFlag := range maskFlags {
mask[i], err = parseTCPFlag(maskFlag)
if err != nil {
return tcpFlags{}, fmt.Errorf("parsing TCP mask flags: %w", err)
}
}
comparisonFlags := strings.Split(fields[1], ",")
comparison := make([]tcpFlag, len(comparisonFlags))
for i, comparisonFlag := range comparisonFlags {
comparison[i], err = parseTCPFlag(comparisonFlag)
if err != nil {
return tcpFlags{}, fmt.Errorf("parsing TCP comparison flags: %w", err)
}
}
return tcpFlags{
mask: mask,
comparison: comparison,
}, nil
}
func parsePortsCSV(s string) (ports []uint16, err error) {
if s == "" {
return nil, nil
@@ -414,40 +286,16 @@ func parsePortsCSV(s string) (ports []uint16, err error) {
fields := strings.Split(s, ",")
ports = make([]uint16, len(fields))
for i, field := range fields {
ports[i], err = parsePort(field)
const base, bitLength = 10, 16
port, err := strconv.ParseUint(field, base, bitLength)
if err != nil {
return nil, err
return nil, fmt.Errorf("parsing port %q: %w", field, err)
}
ports[i] = uint16(port)
}
return ports, nil
}
var errMarkValueMalformed = errors.New("mark value is malformed")
func parseMark(optionalFields []string) (m mark, consumed int, err error) {
switch optionalFields[consumed] {
case "match":
consumed++
if optionalFields[consumed] == "!" {
m.invert = true
consumed++
}
const base = 0 // auto-detect
const bits = 32
value, err := strconv.ParseUint(optionalFields[consumed], base, bits)
if err != nil {
return mark{}, 0, fmt.Errorf("%w: %s", errMarkValueMalformed, optionalFields[consumed])
}
m.value = uint(value)
consumed++
default:
return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s",
ErrChainRuleMalformed, optionalFields[consumed])
}
return m, consumed, nil
}
var ErrLineNumberIsZero = errors.New("line number is zero")
func parseLineNumber(s string) (n uint16, err error) {
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"net/netip"
@@ -1,3 +1,3 @@
package iptables
package firewall
//go:generate mockgen -destination=mocks_test.go -package $GOPACKAGE . CmdRunner,Logger
@@ -1,8 +1,8 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/firewall/iptables (interfaces: CmdRunner,Logger)
// Source: github.com/qdm12/gluetun/internal/firewall (interfaces: CmdRunner,Logger)
// Package iptables is a generated GoMock package.
package iptables
// Package firewall is a generated GoMock package.
package firewall
import (
exec "os/exec"
-99
View File
@@ -1,99 +0,0 @@
package nftables
import (
"context"
"fmt"
"github.com/google/nftables"
)
// SaveAndRestore saves the current nftables tree and returns a restore function that
// can be called to restore the saved tree.
func (f *Firewall) SaveAndRestore(_ context.Context) (restore func(context.Context), err error) {
f.mutex.Lock()
defer f.mutex.Unlock()
conn, err := nftables.New()
if err != nil {
return nil, fmt.Errorf("creating nftables connection: %w", err)
}
tables, err := saveTables(conn)
if err != nil {
return nil, fmt.Errorf("saving nftables state: %w", err)
}
return func(_ context.Context) {
conn, err := nftables.New()
if err != nil {
f.logger.Warnf("creating nftables connection for restore: %s", err)
return
}
err = restoreTables(conn, tables)
if err != nil {
f.logger.Warnf("restoring nftables state: %s", err)
}
}, nil
}
type savedTable struct {
table *nftables.Table
chains []savedChain
}
type savedChain struct {
chain *nftables.Chain
rules []*nftables.Rule
}
func saveTables(conn *nftables.Conn) ([]savedTable, error) {
tables, err := conn.ListTables()
if err != nil {
return nil, err
}
savedTables := make([]savedTable, len(tables))
for i, table := range tables {
savedTables[i].table = table
chains, err := conn.ListChains()
if err != nil {
return nil, err
}
for _, chain := range chains {
if chain.Table.Name != table.Name ||
chain.Table.Family != table.Family {
continue
}
rules, err := conn.GetRules(table, chain)
if err != nil {
return nil, fmt.Errorf("getting rules for chain %s in table %s: %w", chain.Name, table.Name, err)
}
savedChain := savedChain{chain: chain, rules: rules}
savedTables[i].chains = append(savedTables[i].chains, savedChain)
}
}
return savedTables, nil
}
func restoreTables(conn *nftables.Conn, savedTables []savedTable) error {
conn.FlushRuleset()
for _, savedTable := range savedTables {
table := conn.AddTable(savedTable.table)
for _, savedChain := range savedTable.chains {
// Make the [nftables.Chain.Table] points to the new [nftables.Table]
// created in this connection.
savedChain.chain.Table = table
savedChain.chain = conn.AddChain(savedChain.chain)
for _, rule := range savedChain.rules {
rule.Table = table
rule.Chain = savedChain.chain
conn.AddRule(rule)
}
}
}
return conn.Flush()
}
-50
View File
@@ -1,50 +0,0 @@
package nftables
import (
"context"
"errors"
"fmt"
"strings"
"github.com/google/nftables"
)
var ErrPolicyUnknown = errors.New("unknown policy")
// SetBaseChainsPolicy sets the policy of all the base chains (INPUT, FORWARD, or OUTPUT)
// for the filter table to the given policy (accept or drop).
func (f *Firewall) SetBaseChainsPolicy(_ context.Context, policy string) error {
f.mutex.Lock()
defer f.mutex.Unlock()
var chainPolicy nftables.ChainPolicy
switch strings.ToLower(policy) {
case "accept":
chainPolicy = nftables.ChainPolicyAccept
case "drop":
chainPolicy = nftables.ChainPolicyDrop
default:
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
}
conn, err := nftables.New()
if err != nil {
return fmt.Errorf("creating nftables connection: %w", err)
}
_, inputChain, forwardChain, outputChain := setupFilterWithBaseChains(conn)
inputChain.Policy = &chainPolicy
forwardChain.Policy = &chainPolicy
outputChain.Policy = &chainPolicy
conn.AddChain(inputChain)
conn.AddChain(forwardChain)
conn.AddChain(outputChain)
err = conn.Flush()
if err != nil {
return fmt.Errorf("flushing nftables changes: %w", err)
}
return nil
}
-61
View File
@@ -1,61 +0,0 @@
package nftables
import (
"context"
"fmt"
"github.com/google/nftables"
"github.com/google/nftables/expr"
)
func (f *Firewall) AcceptEstablishedRelatedTraffic(_ context.Context) error {
f.mutex.Lock()
defer f.mutex.Unlock()
conn, err := nftables.New()
if err != nil {
return fmt.Errorf("creating nftables connection: %w", err)
}
table, inputChain, _, outputChain := setupFilterWithBaseChains(conn)
ctStateExprs := []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4, //nolint:mnd
Mask: []byte{byte(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), 0x00, 0x00, 0x00},
Xor: []byte{0x00, 0x00, 0x00, 0x00},
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0x00, 0x00, 0x00, 0x00},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
conn.AddRule(&nftables.Rule{
Table: table,
Chain: inputChain,
Exprs: ctStateExprs,
})
conn.AddRule(&nftables.Rule{
Table: table,
Chain: outputChain,
Exprs: ctStateExprs,
})
if err := conn.Flush(); err != nil {
return fmt.Errorf("flushing: %w", err)
}
return nil
}
-27
View File
@@ -1,27 +0,0 @@
package nftables
import (
"errors"
"fmt"
"reflect"
"github.com/google/nftables"
)
var errRuleToDeleteNotFound = errors.New("rule not found for removal")
func (f *Firewall) deleteRule(conn *nftables.Conn, rule *nftables.Rule) error {
for i, existing := range f.rules {
if !reflect.DeepEqual(existing, rule) {
continue
}
err := conn.DelRule(existing)
if err != nil {
return fmt.Errorf("deleting rule: %w", err)
}
f.rules[i], f.rules[len(f.rules)-1] = f.rules[len(f.rules)-1], f.rules[i]
f.rules = f.rules[:len(f.rules)-1]
return nil
}
return fmt.Errorf("%w: %#v", errRuleToDeleteNotFound, rule)
}
-38
View File
@@ -1,38 +0,0 @@
package nftables
import "github.com/google/nftables"
func setupFilterWithBaseChains(conn *nftables.Conn) (table *nftables.Table,
inputChain, forwardChain, outputChain *nftables.Chain,
) {
table = conn.AddTable(&nftables.Table{
Family: nftables.TableFamilyINet,
Name: "filter",
})
inputChain = conn.AddChain(&nftables.Chain{
Name: "input",
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
})
forwardChain = conn.AddChain(&nftables.Chain{
Name: "forward",
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
})
outputChain = conn.AddChain(&nftables.Chain{
Name: "output",
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityFilter,
})
return table, inputChain, forwardChain, outputChain
}
-22
View File
@@ -1,22 +0,0 @@
package nftables
import (
"sync"
"github.com/google/nftables"
)
type Firewall struct {
logger Logger
// rules are only rules added and tracked for later removal.
// Not all rules added are tracked for removal.
rules []*nftables.Rule
mutex sync.Mutex
}
func New(logger Logger) *Firewall {
return &Firewall{
logger: logger,
}
}
-170
View File
@@ -1,170 +0,0 @@
package nftables
import (
"context"
"fmt"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/expr"
)
func (f *Firewall) AcceptInputThroughInterface(_ context.Context, intf string) error {
f.mutex.Lock()
defer f.mutex.Unlock()
conn, err := nftables.New()
if err != nil {
return fmt.Errorf("creating nftables connection: %w", err)
}
table, inputChain, _, _ := setupFilterWithBaseChains(conn)
rule := &nftables.Rule{
Table: table,
Chain: inputChain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte(intf + "\x00"),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
}
conn.AddRule(rule)
err = conn.Flush()
if err != nil {
return fmt.Errorf("flushing: %w", err)
}
return nil
}
// AcceptInputToPort accepts incoming traffic on the specified port, for both TCP and UDP
// protocols, on the interface intf. If intf is empty or "*", the interface is not used as a filter.
// If remove is true, the rule is removed instead of added. This is used for port forwarding, with
// intf set to the VPN tunnel interface.
func (f *Firewall) AcceptInputToPort(_ context.Context, intf string, port uint16, remove bool) error {
f.mutex.Lock()
defer f.mutex.Unlock()
conn, err := nftables.New()
if err != nil {
return fmt.Errorf("creating nftables connection: %w", err)
}
table, inputChain, _, _ := setupFilterWithBaseChains(conn)
portBytes := []byte{byte(port >> 8), byte(port)} //nolint:mnd
const tcp, udp uint8 = 6, 17
protocols := []uint8{tcp, udp}
for _, protocol := range protocols {
const maxExprsLen = 7
exprs := make([]expr.Any, 0, maxExprsLen)
if intf != "" && intf != "*" {
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(intf + "\x00")},
)
}
exprs = append(exprs,
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 9, Len: 1}, //nolint:mnd
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protocol}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2}, //nolint:mnd
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: portBytes},
&expr.Verdict{Kind: expr.VerdictAccept},
)
rule := &nftables.Rule{
Table: table,
Chain: inputChain,
Exprs: exprs,
}
if !remove {
conn.AddRule(rule)
f.rules = append(f.rules, rule)
continue
}
err = f.deleteRule(conn, rule)
if err != nil {
return fmt.Errorf("deleting rule: %w", err)
}
}
err = conn.Flush()
if err != nil {
f.rules = f.rules[:len(f.rules)-len(protocols)]
return fmt.Errorf("flushing: %w", err)
}
return nil
}
func (f *Firewall) AcceptInputToSubnet(_ context.Context, intf string, subnet netip.Prefix) error {
f.mutex.Lock()
defer f.mutex.Unlock()
conn, err := nftables.New()
if err != nil {
return fmt.Errorf("creating nftables connection: %w", err)
}
table, inputChain, _, _ := setupFilterWithBaseChains(conn)
const maxExprsLen = 5
exprs := make([]expr.Any, 0, maxExprsLen)
if intf != "" && intf != "*" {
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(intf + "\x00")},
)
}
var payloadOffset uint32
if subnet.Addr().Is4() {
payloadOffset = 16
} else {
payloadOffset = 24
}
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: payloadOffset,
Len: uint32(len(subnet.Addr().AsSlice())), //nolint:gosec
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: subnet.Addr().AsSlice(),
},
&expr.Verdict{Kind: expr.VerdictAccept},
)
rule := &nftables.Rule{
Table: table,
Chain: inputChain,
Exprs: exprs,
}
conn.AddRule(rule)
err = conn.Flush()
if err != nil {
return fmt.Errorf("flushing: %w", err)
}
return nil
}
-5
View File
@@ -1,5 +0,0 @@
package nftables
type Logger interface {
Warnf(format string, args ...any)
}
-78
View File
@@ -1,78 +0,0 @@
package nftables
import (
"context"
"fmt"
"github.com/google/nftables"
"github.com/google/nftables/expr"
)
func (f *Firewall) AcceptIpv6MulticastOutput(_ context.Context, intf string) error {
f.mutex.Lock()
defer f.mutex.Unlock()
conn, err := nftables.New()
if err != nil {
return fmt.Errorf("creating nftables connection: %w", err)
}
table, _, _, outputChain := setupFilterWithBaseChains(conn)
const maxExprsLen = 6
exprs := make([]expr.Any, 0, maxExprsLen)
if intf != "" && intf != "*" {
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(intf + "\x00")},
)
}
// ff02::1:ff00:0/104 mask is 13 bytes of 0xff
mask := []byte{
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00,
} //nolint:mnd
addr := []byte{
0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x01, 0xff, 0x00, 0x00, 0x00,
} //nolint:mnd
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 24, // IPv6 Destination Address offset //nolint:mnd
Len: 16, //nolint:mnd
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 16, //nolint:mnd
Mask: mask,
Xor: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //nolint:mnd
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: addr,
},
&expr.Verdict{Kind: expr.VerdictAccept},
)
rule := &nftables.Rule{
Table: table,
Chain: outputChain,
Exprs: exprs,
}
conn.AddRule(rule)
err = conn.Flush()
if err != nil {
return fmt.Errorf("flushing: %w", err)
}
return nil
}
-12
View File
@@ -1,12 +0,0 @@
package nftables
import "github.com/google/nftables"
func IsSupported() bool {
conn, err := nftables.New()
if err != nil {
return false
}
_, err = conn.ListTable("filter")
return err == nil
}
+2 -2
View File
@@ -48,7 +48,7 @@ func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []netip.Pref
}
firewallUpdated = true
err := c.impl.AcceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
defaultRoute.AssignedIP, subNet, remove)
if err != nil {
c.logger.Error("cannot remove outdated outbound subnet: " + err.Error())
@@ -77,7 +77,7 @@ func (c *Config) addOutboundSubnets(ctx context.Context, subnets []netip.Prefix)
}
firewallUpdated = true
err := c.impl.AcceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
defaultRoute.AssignedIP, subnet, remove)
if err != nil {
return err
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"errors"
@@ -9,22 +9,30 @@ import (
"strings"
)
type operation uint8
const (
opNone operation = iota
opAppend
opDelete
opInsert
opReplace
)
type iptablesInstruction struct {
table string // defaults to "filter", and can be "nat" for example.
append bool
operation operation
chain string // for example INPUT, PREROUTING. Cannot be empty.
target string // for example ACCEPT. Can be empty.
protocol string // "tcp" or "udp" or "" for all protocols.
inputInterface string // for example "tun0" or "" for any interface.
outputInterface string // for example "tun0" or "" for any interface.
source netip.Prefix // if not valid, then it is unspecified.
sourcePort uint16 // if zero, there is no source port
destination netip.Prefix // if not valid, then it is unspecified.
destinationPort uint16 // if zero, there is no destination port
toPorts []uint16 // if empty, there is no redirection
ctstate []string // if empty, there is no ctstate
tcpFlags tcpFlags
mark mark
lineNumber uint16 // for replace operation, the line number to replace
}
func (i *iptablesInstruction) setDefaults() {
@@ -46,8 +54,6 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
return false
case i.destinationPort != rule.destinationPort:
return false
case i.sourcePort != rule.sourcePort:
return false
case !slices.Equal(i.toPorts, rule.redirPorts):
return false
case !slices.Equal(i.ctstate, rule.ctstate):
@@ -60,16 +66,63 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
return false
case !ipPrefixesEqual(i.destination, rule.destination):
return false
case !slices.Equal(i.tcpFlags.mask, rule.tcpFlags.mask) ||
!slices.Equal(i.tcpFlags.comparison, rule.tcpFlags.comparison):
return false
case i.mark != rule.mark:
return false
default:
return true
}
}
func (i *iptablesInstruction) String() string {
var sb strings.Builder
if i.table != "" && i.table != "filter" {
sb.WriteString(fmt.Sprintf("-t %s ", i.table))
}
switch i.operation {
case opNone:
panic("no operation specified")
case opAppend:
sb.WriteString(fmt.Sprintf("--append %s ", i.chain))
case opDelete:
sb.WriteString(fmt.Sprintf("--delete %s ", i.chain))
case opInsert:
sb.WriteString(fmt.Sprintf("--insert %s ", i.chain))
case opReplace:
sb.WriteString(fmt.Sprintf("--replace %s %d ", i.chain, i.lineNumber))
}
if i.inputInterface != "" {
sb.WriteString(fmt.Sprintf("-i %s ", i.inputInterface))
}
if i.outputInterface != "" {
sb.WriteString(fmt.Sprintf("-o %s ", i.outputInterface))
}
if i.protocol != "" {
sb.WriteString(fmt.Sprintf("-p %s ", i.protocol))
}
if i.source.IsValid() {
sb.WriteString(fmt.Sprintf("-s %s ", i.source.String()))
}
if i.destination.IsValid() {
sb.WriteString(fmt.Sprintf("-d %s ", i.destination.String()))
}
if i.destinationPort != 0 {
sb.WriteString(fmt.Sprintf("--dport %d ", i.destinationPort))
}
if len(i.ctstate) > 0 {
sb.WriteString(fmt.Sprintf("--ctstate %s ", strings.Join(i.ctstate, ",")))
}
if len(i.toPorts) > 0 {
var portStrings []string
for _, port := range i.toPorts {
portStrings = append(portStrings, strconv.FormatUint(uint64(port), 10))
}
sb.WriteString(fmt.Sprintf("--to-ports %s ", strings.Join(portStrings, ",")))
}
if i.target != "" {
sb.WriteString(fmt.Sprintf("-j %s ", i.target))
}
return strings.TrimSpace(sb.String())
}
// instruction can be "" which equivalent to the "*" chain rule interface.
func networkInterfacesEqual(instruction, chainRule string) bool {
return instruction == chainRule || (instruction == "" && chainRule == "*")
@@ -102,39 +155,53 @@ func parseIptablesInstruction(s string) (instruction iptablesInstruction, err er
}
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
consumed, err = preCheckInstructionFields(fields)
if err != nil {
return 0, err
}
flag := fields[0]
// All flags use one value after the flag, except the following:
switch flag {
case "-R", "--replace":
const expected = 3
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
}
consumed = expected
default:
const expected = 2
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
ErrIptablesCommandMalformed, flag)
}
consumed = expected
}
value := fields[1]
switch flag {
case "-t", "--table":
instruction.table = value
case "-D", "--delete":
instruction.append = false
instruction.operation = opDelete
instruction.chain = value
case "-A", "--append":
instruction.append = true
instruction.operation = opAppend
instruction.chain = value
case "-I", "--insert":
instruction.operation = opInsert
instruction.chain = value
case "-R", "--replace":
instruction.operation = opReplace
instruction.chain = value
const base, bits = 10, 16
n, err := strconv.ParseUint(fields[2], base, bits)
if err != nil {
return 0, fmt.Errorf("parsing line number for --replace operation: %w", err)
}
instruction.lineNumber = uint16(n)
case "-j", "--jump":
instruction.target = value
case "-p", "--protocol":
instruction.protocol = value
case "-m", "--match":
consumed, err = parseMatchModule(fields, instruction)
if err != nil {
return 0, fmt.Errorf("parsing match module: %w", err)
}
case "--mark":
const base = 0 // auto-detect
const bits = 32
value, err := strconv.ParseUint(value, base, bits)
if err != nil {
return 0, fmt.Errorf("parsing mark value %q: %w", fields[2], err)
}
instruction.mark.value = uint(value)
case "-m", "--match": // ignore match
case "-i", "--in-interface":
instruction.inputInterface = value
case "-o", "--out-interface":
@@ -144,61 +211,37 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
if err != nil {
return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
}
case "--sport":
instruction.sourcePort, err = parsePort(value)
if err != nil {
return 0, fmt.Errorf("parsing source port: %w", err)
}
case "-d", "--destination":
instruction.destination, err = parseIPPrefix(value)
if err != nil {
return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
}
case "--dport":
instruction.destinationPort, err = parsePort(value)
const base, bitLength = 10, 16
destinationPort, err := strconv.ParseUint(value, base, bitLength)
if err != nil {
return 0, fmt.Errorf("parsing destination port: %w", err)
}
instruction.destinationPort = uint16(destinationPort)
case "--ctstate":
instruction.ctstate = strings.Split(value, ",")
case "--to-ports":
instruction.toPorts, err = parseToPorts(value)
if err != nil {
return 0, fmt.Errorf("parsing port redirection: %w", err)
}
case "--tcp-flags":
mask, comparison := value, fields[2]
instruction.tcpFlags, err = parseTCPFlags(mask + "/" + comparison)
if err != nil {
return 0, fmt.Errorf("parsing TCP flags: %w", err)
portStrings := strings.Split(value, ",")
instruction.toPorts = make([]uint16, len(portStrings))
for i, portString := range portStrings {
const base, bitLength = 10, 16
port, err := strconv.ParseUint(portString, base, bitLength)
if err != nil {
return 0, fmt.Errorf("parsing port redirection: %w", err)
}
instruction.toPorts[i] = uint16(port)
}
default:
return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag)
return 0, fmt.Errorf("%w: unknown flag %q", ErrIptablesCommandMalformed, flag)
}
return consumed, nil
}
func preCheckInstructionFields(fields []string) (consumed int, err error) {
flag := fields[0]
// All flags use one value after the flag, except the following:
switch flag {
case "--tcp-flags": // -m can have 1 or 2 values
const expected = 3
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
}
return expected, nil
default:
const expected = 2
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
ErrIptablesCommandMalformed, flag)
}
return expected, nil
}
}
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
slashIndex := strings.Index(value, "/")
if slashIndex >= 0 {
@@ -211,52 +254,3 @@ func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
}
return netip.PrefixFrom(ip, ip.BitLen()), nil
}
func parsePort(value string) (port uint16, err error) {
const base, bitLength = 10, 16
portValue, err := strconv.ParseUint(value, base, bitLength)
if err != nil {
return 0, err
}
return uint16(portValue), nil
}
func parseMatchModule(fields []string, instruction *iptablesInstruction) (
consumed int, err error,
) {
_ = fields[consumed] // -m or --match flag already detected
consumed++
switch fields[consumed] {
case "tcp", "udp":
consumed++
// for now ignore the protocol match since it's auto-loaded
// when parsing the -p/--protocol flag, and we don't need to
// parse it twice.
case "mark":
consumed++
switch fields[consumed] {
case "!":
consumed++
instruction.mark.invert = true
default:
return consumed, fmt.Errorf("%w: unsupported match mark with value: %s",
ErrIptablesCommandMalformed, fields[2])
}
default:
return 0, fmt.Errorf("%w: unknown match value: %s",
ErrIptablesCommandMalformed, fields[consumed])
}
return consumed, nil
}
func parseToPorts(value string) (toPorts []uint16, err error) {
portStrings := strings.Split(value, ",")
toPorts = make([]uint16, len(portStrings))
for i, portString := range portStrings {
toPorts[i], err = parsePort(portString)
if err != nil {
return nil, err
}
}
return toPorts, nil
}
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"net/netip"
@@ -28,14 +28,14 @@ func Test_parseIptablesInstruction(t *testing.T) {
"unknown_key": {
s: "-x something",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
errMessage: "parsing \"-x something\": iptables command is malformed: unknown flag \"-x\"",
},
"one_pair": {
s: "-A INPUT",
s: "-I INPUT",
instruction: iptablesInstruction{
table: "filter",
chain: "INPUT",
append: true,
table: "filter",
chain: "INPUT",
operation: opInsert,
},
},
"instruction_A": {
@@ -43,7 +43,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
instruction: iptablesInstruction{
table: "filter",
chain: "INPUT",
append: true,
operation: opAppend,
inputInterface: "tun0",
protocol: "tcp",
source: netip.MustParsePrefix("1.2.3.4/32"),
@@ -57,7 +57,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
instruction: iptablesInstruction{
table: "nat",
chain: "PREROUTING",
append: false,
operation: opDelete,
inputInterface: "tun0",
protocol: "tcp",
destinationPort: 43716,
+133 -2
View File
@@ -3,6 +3,7 @@ package firewall
import (
"context"
"fmt"
"net/netip"
"strconv"
)
@@ -35,7 +36,7 @@ func (c *Config) SetAllowedPort(ctx context.Context, port uint16, intf string) (
c.logger.Info("setting allowed input port " + fmt.Sprint(port) + " through interface " + intf + "...")
const remove = false
if err := c.impl.AcceptInputToPort(ctx, intf, port, remove); err != nil {
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
return fmt.Errorf("allowing input to port %d through interface %s: %w",
port, intf, err)
}
@@ -68,7 +69,7 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
const remove = true
for netInterface := range interfacesSet {
err := c.impl.AcceptInputToPort(ctx, netInterface, port, remove)
err := c.acceptInputToPort(ctx, netInterface, port, remove)
if err != nil {
return fmt.Errorf("removing allowed port %d on interface %s: %w",
port, netInterface, err)
@@ -81,3 +82,133 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
return nil
}
// RestrictOutputAddrPort allows outgoing traffic to a specific IP and port for both tcp and udp,
// while blocking other tcp or udp traffic to that port going to other IP addresses, both IPv4 and IPv6.
// If the port was previously allowed for another IP address, that previous allowance will be removed.
// Giving an invalid address will remove any existing restrictions for the port specified.
func (c *Config) RestrictOutputAddrPort(ctx context.Context, addrPort netip.AddrPort) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
existingIP := c.outputAddrPort[addrPort.Port()]
switch {
case existingIP == addrPort.Addr():
return nil
case !addrPort.Addr().IsValid():
// invalid address, remove any existing rules for the port
return c.removeOutputAddrPortRestriction(ctx, existingIP, addrPort.Port())
case !existingIP.IsValid():
// no previous existing address for the port
return c.insertOutputAddrPortRestriction(ctx, addrPort)
default:
// existing rule in the same IP family or different family
return c.replaceOutputAddrPortRestriction(ctx, existingIP, addrPort)
}
}
func (c *Config) removeOutputAddrPortRestriction(ctx context.Context, existingIP netip.Addr, port uint16) (err error) {
commonInstructions := []string{
fmt.Sprintf("--delete OUTPUT -p udp --dport %d -j DROP", port),
fmt.Sprintf("--delete OUTPUT -p tcp --dport %d -j DROP", port),
}
ipv4Instructions := commonInstructions
ipv6Instructions := commonInstructions
familySpecificInstructions := []string{
fmt.Sprintf("--delete OUTPUT -p udp --dport %d -d %s -j ACCEPT", port, existingIP),
fmt.Sprintf("--delete OUTPUT -p tcp --dport %d -d %s -j ACCEPT", port, existingIP),
}
if existingIP.Is4() {
ipv4Instructions = append(ipv4Instructions, familySpecificInstructions...)
} else {
ipv6Instructions = append(ipv6Instructions, familySpecificInstructions...)
}
err = c.runIPv4AndV6IptablesInstructions(ctx, ipv4Instructions, ipv6Instructions)
if err != nil {
return err
}
delete(c.outputAddrPort, port)
return nil
}
func (c *Config) insertOutputAddrPortRestriction(ctx context.Context, addrPort netip.AddrPort) (err error) {
commonInstructions := []string{
fmt.Sprintf("--insert OUTPUT -p udp --dport %d -j DROP", addrPort.Port()),
fmt.Sprintf("--insert OUTPUT -p tcp --dport %d -j DROP", addrPort.Port()),
}
ipv4Instructions := commonInstructions
ipv6Instructions := commonInstructions
familySpecificInstructions := []string{
fmt.Sprintf("--insert OUTPUT -p udp --dport %d -d %s -j ACCEPT", addrPort.Port(), addrPort.Addr()),
fmt.Sprintf("--insert OUTPUT -p tcp --dport %d -d %s -j ACCEPT", addrPort.Port(), addrPort.Addr()),
}
if addrPort.Addr().Is4() {
ipv4Instructions = append(ipv4Instructions, familySpecificInstructions...)
} else {
ipv6Instructions = append(ipv6Instructions, familySpecificInstructions...)
}
err = c.runIPv4AndV6IptablesInstructions(ctx, ipv4Instructions, ipv6Instructions)
if err != nil {
return err
}
c.outputAddrPort[addrPort.Port()] = addrPort.Addr()
return nil
}
func (c *Config) replaceOutputAddrPortRestriction(ctx context.Context,
existingIP netip.Addr, addrPort netip.AddrPort,
) (err error) {
for _, protocol := range [...]string{"udp", "tcp"} {
switch {
case existingIP.Is4() && addrPort.Addr().Is4():
oldInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), existingIP)
newInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), addrPort.Addr())
err = c.replaceIptablesRule(ctx, oldInstruction, newInstruction)
if err != nil {
return fmt.Errorf("replacing existing IPv4 rule: %w", err)
}
case existingIP.Is6() && addrPort.Addr().Is6():
oldInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), existingIP)
newInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), addrPort.Addr())
err = c.replaceIP6tablesRule(ctx, oldInstruction, newInstruction)
if err != nil {
return fmt.Errorf("replacing existing IPv6 rule: %w", err)
}
case existingIP.Is4() && addrPort.Addr().Is6():
instruction := fmt.Sprintf("--delete OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), existingIP)
err = c.runIptablesInstruction(ctx, instruction)
if err != nil {
return fmt.Errorf("removing existing IPv4 rule: %w", err)
}
instruction = fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), addrPort.Addr())
err = c.runIP6tablesInstruction(ctx, instruction)
if err != nil {
return fmt.Errorf("inserting new IPv6 rule: %w", err)
}
case existingIP.Is6() && addrPort.Addr().Is4():
instruction := fmt.Sprintf("--delete OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), existingIP)
err = c.runIP6tablesInstruction(ctx, instruction)
if err != nil {
return fmt.Errorf("removing existing IPv6 rule: %w", err)
}
instruction = fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), addrPort.Addr())
err = c.runIptablesInstruction(ctx, instruction)
if err != nil {
return fmt.Errorf("inserting new IPv4 rule: %w", err)
}
}
}
c.outputAddrPort[addrPort.Port()] = addrPort.Addr()
return nil
}
+2 -2
View File
@@ -50,7 +50,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string, sourcePort,
return nil
case conflict != nil:
const remove = true
err = c.impl.RedirectPort(ctx, conflict.interfaceName, conflict.sourcePort,
err = c.redirectPort(ctx, conflict.interfaceName, conflict.sourcePort,
conflict.destinationPort, remove)
if err != nil {
return fmt.Errorf("removing conflicting redirection: %w", err)
@@ -60,7 +60,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string, sourcePort,
}
const remove = false
err = c.impl.RedirectPort(ctx, intf, sourcePort, destinationPort, remove)
err = c.redirectPort(ctx, intf, sourcePort, destinationPort, remove)
if err != nil {
return fmt.Errorf("redirecting port: %w", err)
}
+51
View File
@@ -0,0 +1,51 @@
package firewall
import (
"context"
"errors"
"fmt"
)
var errRuleNotFound = errors.New("rule not found")
func (c *Config) replaceIptablesRule(ctx context.Context, oldInstruction, newInstruction string) error {
targetRule, err := parseIptablesInstruction(oldInstruction)
if err != nil {
return fmt.Errorf("parsing iptables command to replace: %w", err)
}
lineNumber, err := findLineNumber(ctx, c.ipTables, targetRule, c.runner, c.logger)
if err != nil {
return fmt.Errorf("finding to-be-replaced chain rule line number: %w", err)
} else if lineNumber == 0 {
return fmt.Errorf("%w: matching to-be-replaced instruction %q", errRuleNotFound, oldInstruction)
}
parsed, err := parseIptablesInstruction(newInstruction)
if err != nil {
return fmt.Errorf("parsing replacement iptables command: %w", err)
}
parsed.operation = opReplace
parsed.lineNumber = lineNumber
return c.runIptablesInstruction(ctx, parsed.String())
}
func (c *Config) replaceIP6tablesRule(ctx context.Context, oldInstruction, newInstruction string) error {
targetRule, err := parseIptablesInstruction(oldInstruction)
if err != nil {
return fmt.Errorf("parsing iptables command to replace: %w", err)
}
lineNumber, err := findLineNumber(ctx, c.ip6Tables, targetRule, c.runner, c.logger)
if err != nil {
return fmt.Errorf("finding to-be-replaced chain rule line number: %w", err)
} else if lineNumber == 0 {
return fmt.Errorf("%w: matching to-be-replaced instruction %q", errRuleNotFound, oldInstruction)
}
parsed, err := parseIptablesInstruction(newInstruction)
if err != nil {
return fmt.Errorf("parsing replacement iptables command: %w", err)
}
parsed.operation = opReplace
parsed.lineNumber = lineNumber
return c.runIP6tablesInstruction(ctx, parsed.String())
}
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"context"
@@ -11,10 +11,10 @@ import (
)
var (
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
ErrInputPolicyNotFound = errors.New("input policy not found")
ErrNotSupported = errors.New("no iptables supported found")
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
ErrInputPolicyNotFound = errors.New("input policy not found")
ErrIPTablesNotSupported = errors.New("no iptables supported found")
)
func checkIptablesSupport(ctx context.Context, runner CmdRunner,
@@ -57,7 +57,7 @@ func checkIptablesSupport(ctx context.Context, runner CmdRunner,
}
return "", fmt.Errorf("%w: errors encountered are: %s",
ErrNotSupported, strings.Join(allUnsupportedMessages, "; "))
ErrIPTablesNotSupported, strings.Join(allUnsupportedMessages, "; "))
}
func testIptablesPath(ctx context.Context, path string,
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"context"
@@ -101,7 +101,7 @@ func Test_checkIptablesSupport(t *testing.T) {
return runner
},
iptablesPathsToTry: []string{"path1", "path2"},
errSentinel: ErrNotSupported,
errSentinel: ErrIPTablesNotSupported,
errMessage: "no iptables supported found: " +
"errors encountered are: " +
"path1: output 1 (exit code 4); " +
+4 -4
View File
@@ -28,7 +28,7 @@ func (c *Config) SetVPNConnection(ctx context.Context,
remove := true
if c.vpnConnection.IP.IsValid() {
for _, defaultRoute := range c.defaultRoutes {
if err := c.impl.AcceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove); err != nil {
if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove); err != nil {
c.logger.Error("cannot remove outdated VPN connection rule: " + err.Error())
}
}
@@ -36,7 +36,7 @@ func (c *Config) SetVPNConnection(ctx context.Context,
c.vpnConnection = models.Connection{}
if c.vpnIntf != "" {
if err = c.impl.AcceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
c.logger.Error("cannot remove outdated VPN interface rule: " + err.Error())
}
}
@@ -45,13 +45,13 @@ func (c *Config) SetVPNConnection(ctx context.Context,
remove = false
for _, defaultRoute := range c.defaultRoutes {
if err := c.impl.AcceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, connection, remove); err != nil {
if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, connection, remove); err != nil {
return fmt.Errorf("allowing output traffic through VPN connection: %w", err)
}
}
c.vpnConnection = connection
if err = c.impl.AcceptOutputThroughInterface(ctx, vpnIntf, remove); err != nil {
if err = c.acceptOutputThroughInterface(ctx, vpnIntf, remove); err != nil {
return fmt.Errorf("accepting output traffic through interface %s: %w", vpnIntf, err)
}
c.vpnIntf = vpnIntf
-21
View File
@@ -1,21 +0,0 @@
package firewall
import (
"context"
"net/netip"
)
func (c *Config) Version(ctx context.Context) (version string, err error) {
return c.impl.Version(ctx)
}
// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port,
// for any TCP packets not marked with the excludeMark given.
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
// by sending a TCP RST packet, although we want to handle the connection manually.
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
src, dst netip.AddrPort, excludeMark int) (
revert func(ctx context.Context) error, err error,
) {
return c.impl.TempDropOutputTCPRST(ctx, src, dst, excludeMark)
}
+11 -37
View File
@@ -23,7 +23,6 @@ type Checker struct {
logger Logger
icmpTargetIPs []netip.Addr
smallCheckType string
startupOnFail bool
configMutex sync.Mutex
icmpNotPermitted *bool
@@ -46,43 +45,26 @@ func NewChecker(logger Logger) *Checker {
}
}
// SetConfig sets the following:
// - TCP+TLS dial addresses
// - ICMP echo IP addresses to target
// - the desired small check type (dns or icmp)
// - whether to startup the periodic checks if the startup check fails.
// SetConfig sets the TCP+TLS dial addresses, the ICMP echo IP address
// to target and the desired small check type (dns or icmp).
// This function MUST be called before calling [Checker.Start].
func (c *Checker) SetConfig(tlsDialAddrs []string, icmpTargets []netip.Addr,
smallCheckType string, startupOnFail bool,
smallCheckType string,
) {
c.configMutex.Lock()
defer c.configMutex.Unlock()
c.tlsDialAddrs = tlsDialAddrs
c.icmpTargetIPs = icmpTargets
c.smallCheckType = smallCheckType
c.startupOnFail = startupOnFail
}
// Start starts the [Checker] which behaves differently according to its
// internal field startupOnFail, which is set by calling [Checker.SetConfig].
//
// By default, startupOnFail should be false and the behavior is as follows:
// A blocking 6s-timed TCP+TLS check is performed first. If it fails,
// an error is returned and the [Checker] is not started.
// On success, it starts the periodic checks in a separate goroutine, returning
// the runError error channel and a nil error.
//
// If startupOnFail is true, the behavior is as follows:
// A blocking 6s-timed TCP+TLS check is performed first. If it fails,
// the error is sent to the runError channel, but no error is returned
// and the [Checker] continues to start the periodic checks in a separate goroutine, returning
// the runError error channel and a nil error.
//
// The periodic checks consist in:
// Start starts the checker by first running a blocking 6s-timed TCP+TLS check,
// and, on success, starts the periodic checks in a separate goroutine:
// - a "small" ICMP echo check every minute
// - a "full" TCP+TLS check every 5 minutes
//
// The [Checker] has to be ultimately stopped by calling [Checker.Stop].
// It returns a channel `runError` that receives an error (nil or not) when a periodic check is performed.
// It returns an error if the initial TCP+TLS check fails.
// The Checker has to be ultimately stopped by calling [Checker.Stop].
func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error) {
if len(c.tlsDialAddrs) == 0 || len(c.icmpTargetIPs) == 0 || c.smallCheckType == "" {
panic("call Checker.SetConfig with non empty values before Checker.Start")
@@ -94,19 +76,9 @@ func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error)
}
c.echoer.Reset()
// runErrorCh MUST be buffered in the case startupOnFail is true, and
// a startup error was encountered, to avoid blocking the startup
// goroutine when sending the error, especially since the caller may
// not be ready to receive from the channel yet.
runErrorCh := make(chan error, 1)
runError = runErrorCh
err = c.startupCheck(ctx)
if err != nil {
err = fmt.Errorf("startup check: %w", err)
if !c.startupOnFail {
return nil, err
}
runErrorCh <- err
return nil, fmt.Errorf("startup check: %w", err)
}
ready := make(chan struct{})
@@ -118,6 +90,8 @@ func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error)
smallCheckTimer := time.NewTimer(smallCheckPeriod)
const fullCheckPeriod = 5 * time.Minute
fullCheckTimer := time.NewTimer(fullCheckPeriod)
runErrorCh := make(chan error)
runError = runErrorCh
go func() {
defer close(done)
close(ready)
-4
View File
@@ -6,7 +6,6 @@ import (
"fmt"
"io"
"net/http"
"time"
)
var ErrHTTPStatusNotOK = errors.New("HTTP response status is not OK")
@@ -22,9 +21,6 @@ func NewClient(httpClient *http.Client) *Client {
}
func (c *Client) Check(ctx context.Context, url string) error {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return err
-33
View File
@@ -1,33 +0,0 @@
package mod
import (
"bufio"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
)
var errBuiltinModuleNotFound = errors.New("builtin module not found")
func checkModulesBuiltin(modulesPath, moduleName string) error {
f, err := os.Open(filepath.Join(modulesPath, "modules.builtin"))
if err != nil {
return err
}
defer f.Close()
moduleName = strings.TrimSuffix(moduleName, ".ko")
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
line = strings.TrimSuffix(line, ".ko")
if strings.HasSuffix(line, "/"+moduleName) {
return nil
}
}
return fmt.Errorf("%w: %s", errBuiltinModuleNotFound, moduleName)
}
-132
View File
@@ -1,132 +0,0 @@
package mod
import (
"bufio"
"compress/gzip"
"errors"
"fmt"
"os"
"strings"
)
var (
errModuleNameUnknown = errors.New("unknown module name")
errKernelFeatureIsModule = errors.New("kernel feature is a module, not built-in")
errKernelFeatureNotSet = errors.New("kernel feature not set")
errKernelFeatureNotFound = errors.New("kernel feature not found")
)
// checkProcConfig checks /proc/config.gz for a the kernel feature corresponding
// to the given module name. If the kernel feature is found and set to "y", it returns nil.
// If the kernel feature is found and set to "m", it returns an error indicating that the kernel
// feature is a module, not built-in.
// If the kernel feature is found and not set, it returns an error indicating that the kernel
// feature is not set. If the kernel feature is not found, it returns an error indicating that the kernel
// feature is not found.
func checkProcConfig(moduleName string) error {
f, err := os.Open("/proc/config.gz")
if err != nil {
return err
}
defer f.Close()
gz, err := gzip.NewReader(f)
if err != nil {
return fmt.Errorf("creating gzip reader: %w", err)
}
defer gz.Close()
// If any group of kernel features is satisfied, then the module is considered supported.
kernelFeatureGroups, ok := moduleNameToKernelFeatureGroups(moduleName)
if !ok {
return fmt.Errorf("%w: %s", errModuleNameUnknown, moduleName)
}
groups := make([]map[string]bool, len(kernelFeatureGroups))
for i, group := range kernelFeatureGroups {
featureToOK := make(map[string]bool)
for _, feature := range group {
featureToOK[feature] = false
}
groups[i] = featureToOK
}
scanner := bufio.NewScanner(gz)
for scanner.Scan() {
line := scanner.Text()
for _, featureToOK := range groups {
for name, ok := range featureToOK {
switch {
case ok:
case strings.HasPrefix(line, name+"=m"):
return fmt.Errorf("%w: %s", errKernelFeatureIsModule, name)
case strings.HasPrefix(line, name+"=y"):
featureToOK[name] = true
if allFeaturesOK(featureToOK) {
return nil
}
case strings.HasPrefix(line, "# "+name+" is not set"):
return fmt.Errorf("%w: %s", errKernelFeatureNotSet, name)
}
}
}
}
return fmt.Errorf("%w: for module name %s", errKernelFeatureNotFound, moduleName)
}
func moduleNameToKernelFeatureGroups(moduleName string) (featureGroups [][]string, ok bool) {
moduleMap := map[string][][]string{
"nf_tables": {{"CONFIG_NF_TABLES"}},
// Netfilter Matches
"xt_conntrack": {{"CONFIG_NETFILTER_XT_MATCH_CONNTRACK"}},
"xt_connmark": {
{"CONFIG_NETFILTER_XT_CONNMARK"},
{"CONFIG_NETFILTER_XT_MATCH_CONNMARK", "CONFIG_NETFILTER_XT_TARGET_CONNMARK"},
},
"xt_mark": {
{"CONFIG_NETFILTER_XT_MARK"},
{"CONFIG_NETFILTER_XT_MATCH_MARK", "CONFIG_NETFILTER_XT_TARGET_MARK"},
},
"nf_conntrack_netlink": {{"CONFIG_NF_CT_NETLINK"}},
"nf_reject_ipv4": {{"CONFIG_NF_REJECT_IPV4"}},
// Common Netfilter Targets
"xt_log": {{"CONFIG_NETFILTER_XT_TARGET_LOG"}},
"xt_reject": {
{"CONFIG_IP_NF_TARGET_REJECT", "CONFIG_NF_REJECT_IPV4"},
{"CONFIG_NETFILTER_XT_TARGET_REJECT", "CONFIG_NF_REJECT_IPV4"},
},
"xt_masquerade": {{"CONFIG_NETFILTER_XT_TARGET_MASQUERADE"}},
// Additional Netfilter Matches
"xt_addrtype": {{"CONFIG_NETFILTER_XT_MATCH_ADDRTYPE"}},
"xt_comment": {{"CONFIG_NETFILTER_XT_MATCH_COMMENT"}},
"xt_multiport": {{"CONFIG_NETFILTER_XT_MATCH_MULTIPORT"}},
"xt_state": {{"CONFIG_NETFILTER_XT_MATCH_STATE"}},
"xt_tcpudp": {{"CONFIG_NETFILTER_XT_MATCH_TCPUDP"}},
// Tunneling and Virtualization
"tun": {{"CONFIG_TUN"}},
"bridge": {{"CONFIG_BRIDGE"}},
"veth": {{"CONFIG_VETH"}},
"vxlan": {{"CONFIG_VXLAN"}},
"wireguard": {{"CONFIG_WIREGUARD"}},
// Filesystems
"overlay": {{"CONFIG_OVERLAY_FS"}},
"fuse": {{"CONFIG_FUSE_FS"}},
}
featureGroups, ok = moduleMap[strings.ToLower(moduleName)]
return featureGroups, ok
}
func allFeaturesOK(featureToOK map[string]bool) bool {
for _, ok := range featureToOK {
if !ok {
return false
}
}
return true
}
+30 -34
View File
@@ -30,7 +30,36 @@ type moduleInfo struct {
var ErrModulesDirectoryNotFound = errors.New("modules directory not found")
func getModulesInfo(modulesPath string) (modulesInfo map[string]moduleInfo, err error) {
func getModulesInfo() (modulesInfo map[string]moduleInfo, err error) {
var utsName unix.Utsname
err = unix.Uname(&utsName)
if err != nil {
return nil, fmt.Errorf("getting unix uname release: %w", err)
}
release := unix.ByteSliceToString(utsName.Release[:])
release = strings.TrimSpace(release)
modulePaths := []string{
filepath.Join("/lib/modules", release),
filepath.Join("/usr/lib/modules", release),
}
var modulesPath string
var found bool
for _, modulesPath = range modulePaths {
info, err := os.Stat(modulesPath)
if err == nil && info.IsDir() {
found = true
break
}
}
if !found {
return nil, fmt.Errorf("%w: %s are not valid existing directories"+
"; have you bind mounted the /lib/modules directory?",
ErrModulesDirectoryNotFound, strings.Join(modulePaths, ", "))
}
dependencyFilepath := filepath.Join(modulesPath, "modules.dep")
dependencyFile, err := os.Open(dependencyFilepath)
if err != nil {
@@ -82,39 +111,6 @@ func getModulesInfo(modulesPath string) (modulesInfo map[string]moduleInfo, err
return modulesInfo, nil
}
func getModulesPath() (string, error) {
release, err := getReleaseName()
if err != nil {
return "", fmt.Errorf("getting release name: %w", err)
}
modulePaths := []string{
filepath.Join("/lib/modules", release),
filepath.Join("/usr/lib/modules", release),
}
for _, modulesPath := range modulePaths {
info, err := os.Stat(modulesPath)
if err == nil && info.IsDir() {
return modulesPath, nil
}
}
return "", fmt.Errorf("%w: %s are not valid existing directories"+
"; have you bind mounted the /lib/modules directory?",
ErrModulesDirectoryNotFound, strings.Join(modulePaths, ", "))
}
func getReleaseName() (release string, err error) {
var utsName unix.Utsname
err = unix.Uname(&utsName)
if err != nil {
return "", fmt.Errorf("getting unix uname release: %w", err)
}
release = unix.ByteSliceToString(utsName.Release[:])
release = strings.TrimSpace(release)
return release, nil
}
func getBuiltinModules(modulesDirPath string, modulesInfo map[string]moduleInfo) error {
file, err := os.Open(filepath.Join(modulesDirPath, "modules.builtin"))
if err != nil {
+2 -39
View File
@@ -1,49 +1,12 @@
package mod
import (
"errors"
"fmt"
)
// Probe is a expanded version of modprobe, in which it checks if the Kernel
// built-in features contain the given module name.
// It first tries to locate the modules directory in [getModulesPath].
// If it fails (like on WSL), it then only checks for the kernel feature
// in /proc/config.gz with [checkProcConfig].
// Otherwise, it first checks if the modules directory modules.builtin
// file contains the given module name in [checkModulesBuiltin].
// If the module is not found, it then runs the classic [modProbe] behavior,
// trying to load the module in the kernel.
// If this fails, it does one final try running [checkProcConfig].
// Probe loads the given kernel module and its dependencies.
func Probe(moduleName string) error {
modulesPath, err := getModulesPath()
if err != nil {
if errors.Is(err, ErrModulesDirectoryNotFound) {
err = checkProcConfig(moduleName)
if err != nil {
return fmt.Errorf("checking /proc/config.gz: %w", err)
}
return nil
}
return fmt.Errorf("getting modules path: %w", err)
}
err = checkModulesBuiltin(modulesPath, moduleName)
if err != nil {
err = modProbe(modulesPath, moduleName)
if err != nil {
err = checkProcConfig(moduleName)
if err != nil {
return fmt.Errorf("checking /proc/config.gz: %w", err)
}
}
}
return nil
}
// modProbe is the classic modprobe behavior.
func modProbe(modulesPath, moduleName string) error {
modulesInfo, err := getModulesInfo(modulesPath)
modulesInfo, err := getModulesInfo()
if err != nil {
return fmt.Errorf("getting modules information: %w", err)
}
-38
View File
@@ -1,38 +0,0 @@
package netlink
import (
"fmt"
"github.com/mdlayher/netlink"
"github.com/ti-mo/netfilter"
)
func (n *NetLink) FlushConntrack() error {
conn, err := netfilter.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netfilter: %w", err)
}
defer conn.Close()
families := [...]netfilter.ProtoFamily{netfilter.ProtoIPv4, netfilter.ProtoIPv6}
for _, family := range families {
const IPCtnlMsgCtDelete = 2
request, err := netfilter.MarshalNetlink(
netfilter.Header{
SubsystemID: netfilter.NFSubsysCTNetlink,
MessageType: netfilter.MessageType(IPCtnlMsgCtDelete),
Family: family,
Flags: netlink.Request | netlink.Acknowledge,
},
nil)
if err != nil {
return fmt.Errorf("encoding netlink request: %w", err)
}
_, err = conn.Query(request)
if err != nil {
return fmt.Errorf("querying netlink request: %w", err)
}
}
return nil
}
@@ -1,7 +0,0 @@
//go:build !linux
package netlink
func (n *NetLink) FlushConntrack() error {
panic("not implemented")
}
+1 -13
View File
@@ -18,7 +18,6 @@ type Route struct {
Type uint8
Scope uint8
Proto uint8
AdvMSS uint32
}
func (r *Route) fromMessage(message rtnetlink.RouteMessage) {
@@ -36,9 +35,6 @@ func (r *Route) fromMessage(message rtnetlink.RouteMessage) {
r.Type = message.Type
r.Scope = message.Scope
r.Proto = message.Protocol
if metrics := message.Attributes.Metrics; metrics != nil {
r.AdvMSS = metrics.AdvMSS
}
}
func (r Route) message() *rtnetlink.RouteMessage {
@@ -62,6 +58,7 @@ func (r Route) message() *rtnetlink.RouteMessage {
Protocol: r.Proto,
Attributes: rtnetlink.RouteAttributes{
OutIface: r.LinkIndex,
Dst: *dst, // there should always be a dst for routes
Gateway: netipAddrToNetIP(r.Gw),
Priority: r.Priority,
Table: extendedTable,
@@ -70,15 +67,6 @@ func (r Route) message() *rtnetlink.RouteMessage {
if src != nil { // src is optional
message.Attributes.Src = *src
}
if dst != nil {
message.Attributes.Dst = *dst
}
if r.AdvMSS != 0 {
if message.Attributes.Metrics == nil {
message.Attributes.Metrics = &rtnetlink.RouteMetrics{}
}
message.Attributes.Metrics.AdvMSS = r.AdvMSS
}
return message
}
@@ -1,4 +1,4 @@
package icmp
package pmtud
import (
"net"
@@ -1,4 +1,4 @@
package icmp
package pmtud
import (
"bytes"
@@ -9,17 +9,17 @@ import (
)
var (
ErrNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
ErrNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
ErrICMPNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
ErrICMPNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
)
func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) {
switch {
case mtu < minMTU:
return fmt.Errorf("%w: %d", ErrNextHopMTUTooLow, mtu)
return fmt.Errorf("%w: %d", ErrICMPNextHopMTUTooLow, mtu)
case mtu > physicalLinkMTU:
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
ErrNextHopMTUTooHigh, mtu, physicalLinkMTU)
ErrICMPNextHopMTUTooHigh, mtu, physicalLinkMTU)
default:
return nil
}
@@ -34,13 +34,13 @@ func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return false, fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
return false, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
return inboundBody.ID == outboundBody.ID, nil
}
var ErrIDMismatch = errors.New("ICMP id mismatch")
var ErrICMPIDMismatch = errors.New("ICMP id mismatch")
func checkEchoReply(icmpProtocol int, received []byte,
outboundMessage *icmp.Message, truncatedBody bool,
@@ -51,12 +51,12 @@ func checkEchoReply(icmpProtocol int, received []byte,
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
if inboundBody.ID != outboundBody.ID {
return fmt.Errorf("%w: sent id %d and received id %d",
ErrIDMismatch, outboundBody.ID, inboundBody.ID)
ErrICMPIDMismatch, outboundBody.ID, inboundBody.ID)
}
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
if err != nil {
@@ -65,19 +65,19 @@ func checkEchoReply(icmpProtocol int, received []byte,
return nil
}
var ErrEchoDataMismatch = errors.New("ICMP data mismatch")
var ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
if len(received) > len(sent) {
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
ErrEchoDataMismatch, len(sent), len(received))
ErrICMPEchoDataMismatch, len(sent), len(received))
}
if receivedTruncated {
sent = sent[:len(received)]
}
if !bytes.Equal(received, sent) {
return fmt.Errorf("%w: sent %x and received %x",
ErrEchoDataMismatch, sent, received)
ErrICMPEchoDataMismatch, sent, received)
}
return nil
}
-24
View File
@@ -1,24 +0,0 @@
package constants
const (
MaxEthernetFrameSize uint32 = 1500
// MinIPv4MTU is defined according to
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
MinIPv4MTU uint32 = 68
MinIPv6MTU uint32 = 1280
IPv4HeaderLength uint32 = 20
IPv6HeaderLength uint32 = 40
UDPHeaderLength uint32 = 8
// BaseTCPHeaderLength is the TCP header length without options,
// which is the minimum TCP header length.
BaseTCPHeaderLength uint32 = 20
// MaxTCPHeaderLength is the TCP header length with the maximum options length of 40 bytes.
// Note this is a hard maximum because of the 4-bit data offset field in the TCP header (15x4=60).
MaxTCPHeaderLength uint32 = 60
WireguardHeaderLength uint32 = 32
OpenVPNHeaderMaxLength uint32 = 1 + // opcode
8 + // session id
4 + // packet id
28 // max possible auth tag/iv
)
-16
View File
@@ -1,16 +0,0 @@
//go:build linux || darwin
package constants
import "golang.org/x/sys/unix"
//nolint:revive
const (
SOCK_RAW = unix.SOCK_RAW
SOCK_STREAM = unix.SOCK_STREAM
AF_INET = unix.AF_INET
AF_INET6 = unix.AF_INET6
IPPROTO_TCP = unix.IPPROTO_TCP
EAGAIN = unix.EAGAIN
EWOULDBLOCK = unix.EWOULDBLOCK
)
@@ -1,13 +0,0 @@
package constants
import "golang.org/x/sys/windows"
const (
SOCK_RAW = windows.SOCK_RAW
SOCK_STREAM = windows.SOCK_STREAM
AF_INET = windows.AF_INET
AF_INET6 = windows.AF_INET6
IPPROTO_TCP = windows.IPPROTO_TCP
EAGAIN = windows.WSAEWOULDBLOCK
EWOULDBLOCK = windows.WSAEWOULDBLOCK
)
@@ -1,10 +1,10 @@
//go:build !linux && !windows
package icmp
package pmtud
// setDontFragment for platforms other than Linux and Windows
// is not implemented, so we just return assuming the don't
// fragment flag is set on IP packets.
func setDontFragment(fd uintptr, ipv4 bool) (err error) {
func setDontFragment(fd uintptr) (err error) {
return nil
}
+10
View File
@@ -0,0 +1,10 @@
package pmtud
import (
"syscall"
)
func setDontFragment(fd uintptr) (err error) {
return syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP,
syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_PROBE)
}
+13
View File
@@ -0,0 +1,13 @@
//go:build windows
package pmtud
import (
"syscall"
)
func setDontFragment(fd uintptr) (err error) {
// https://docs.microsoft.com/en-us/troubleshoot/windows/win32/header-library-requirement-socket-ipproto-ip
// #define IP_DONTFRAGMENT 14 /* don't fragment IP datagrams */
return syscall.SetsockoptInt(syscall.Handle(fd), syscall.IPPROTO_IP, 14, 1)
}
+29
View File
@@ -0,0 +1,29 @@
package pmtud
import (
"context"
"errors"
"fmt"
"net"
"strings"
"time"
)
var (
ErrICMPNotPermitted = errors.New("ICMP not permitted")
ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable")
ErrICMPCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
)
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
switch {
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
err = fmt.Errorf("%w", ErrICMPNotPermitted)
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
case timedCtx.Err() != nil:
err = timedCtx.Err()
}
return err
}
-14
View File
@@ -1,14 +0,0 @@
package icmp
import (
"golang.org/x/sys/unix"
)
func setDontFragment(fd uintptr, ipv4 bool) (err error) {
if ipv4 {
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IP,
unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE)
}
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6,
unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_PROBE)
}
-14
View File
@@ -1,14 +0,0 @@
package icmp
import (
"golang.org/x/sys/windows"
)
func setDontFragment(fd uintptr, ipv4 bool) (err error) {
if ipv4 {
// https://docs.microsoft.com/en-us/troubleshoot/windows/win32/header-library-requirement-socket-ipproto-ip
// #define IP_DONTFRAGMENT 14 /* don't fragment IP datagrams */
return windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, 14, 1)
}
return windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, 14, 1)
}
-30
View File
@@ -1,30 +0,0 @@
package icmp
import (
"context"
"errors"
"fmt"
"strings"
"time"
)
var (
ErrNotPermitted = errors.New("ICMP not permitted")
ErrDestinationUnreachable = errors.New("ICMP destination unreachable")
ErrCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
ErrBodyUnsupported = errors.New("ICMP body type is not supported")
ErrMTUNotFound = errors.New("MTU not found")
errTimeout = errors.New("operation timed out")
)
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
switch {
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
err = fmt.Errorf("%w", ErrNotPermitted)
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
err = fmt.Errorf("%w: after %s", errTimeout, pingTimeout)
case timedCtx.Err() != nil:
err = timedCtx.Err()
}
return err
}
-52
View File
@@ -1,52 +0,0 @@
package icmp
import (
"context"
"errors"
"fmt"
"net/netip"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
// PathMTUDiscover discovers the path MTU to the given IP address
// using ICMP.
// It first tries to get the next hop MTU using ICMP messages.
// If that fails, it falls back to sending echo requests with
// different packet sizes to find the maximum MTU.
// The function returns [ErrMTUNotFound] if the MTU could not be determined.
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
physicalLinkMTU uint32, timeout time.Duration, logger Logger,
) (mtu uint32, err error) {
if ip.Is4() {
logger.Debugf("finding IPv4 next hop MTU to %s", ip)
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, timeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, errTimeout) || errors.Is(err, ErrCommunicationAdministrativelyProhibited): // blackhole
default:
return 0, fmt.Errorf("finding IPv4 next hop MTU to %s: %w", ip, err)
}
} else {
logger.Debugf("requesting IPv6 ICMP packet-too-big reply from %s", ip)
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, timeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, errTimeout): // blackhole
default:
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
}
}
// Fall back method: send echo requests with different packet
// sizes and check which ones succeed to find the maximum MTU.
logger.Debugf("falling back to sending different sized echo packets to %s", ip)
minMTU := constants.MinIPv4MTU
if ip.Is6() {
minMTU = constants.MinIPv6MTU
}
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, timeout, logger)
}
-7
View File
@@ -1,7 +0,0 @@
package icmp
type Logger interface {
Debug(msg string)
Debugf(msg string, args ...any)
Warnf(msg string, args ...any)
}
-193
View File
@@ -1,193 +0,0 @@
package icmp
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"strings"
"time"
"github.com/qdm12/gluetun/internal/pmtud/test"
"golang.org/x/net/icmp"
)
type icmpTestUnit struct {
mtu uint32
echoID uint16
sentBytes int
ok bool
}
func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
minMTU, maxPossibleMTU uint32, pingTimeout time.Duration,
logger Logger,
) (maxMTU uint32, err error) {
var ipVersion string
var conn net.PacketConn
if ip.Is4() {
ipVersion = "v4"
conn, err = listenICMPv4(ctx)
} else {
ipVersion = "v6"
conn, err = listenICMPv6(ctx)
}
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
}
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
return minMTU, nil
}
logger.Debugf("ICMP testing the following MTUs: %v", mtusToTest)
tests := make([]icmpTestUnit, len(mtusToTest))
for i := range mtusToTest {
tests[i] = icmpTestUnit{mtu: mtusToTest[i]}
}
timedCtx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-timedCtx.Done()
conn.Close()
}()
for i := range tests {
id, message := buildMessageToSend(ipVersion, tests[i].mtu)
tests[i].echoID = id
encodedMessage, err := message.Marshal(nil)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
tests[i].sentBytes = len(encodedMessage)
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil {
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
err = fmt.Errorf("%w", ErrNotPermitted)
}
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
}
err = collectReplies(conn, ipVersion, tests, logger)
switch {
case err == nil: // max possible MTU is working
return tests[len(tests)-1].mtu, nil
case err != nil && errors.Is(err, net.ErrClosed):
// we have timeouts (IPv4 testing or IPv6 PMTUD blackholes)
// so find the highest MTU which worked.
// Note we start from index len(tests) - 2 since the max MTU
// cannot be working if we had a timeout.
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
if tests[i].ok {
return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1,
pingTimeout, logger)
}
}
// All MTUs failed.
return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound)
case err != nil:
return 0, fmt.Errorf("collecting ICMP echo replies: %w", err)
default:
panic("unreachable")
}
}
// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would
// create huge buffers which we don't really want to support anyway.
// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with
// a conventional maximum of 9000 bytes. However, some manufacturers support up
// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to
// match eventual Jumbo frames. More information at:
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
const maxPossibleMTU = 9196
func collectReplies(conn net.PacketConn, ipVersion string,
tests []icmpTestUnit, logger Logger,
) (err error) {
echoIDToTestIndex := make(map[uint16]int, len(tests))
for i, test := range tests {
echoIDToTestIndex[test.echoID] = i
}
buffer := make([]byte, maxPossibleMTU)
idsFound := 0
for idsFound < len(tests) {
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
return fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
ipPacketLength := len(packetBytes)
var icmpProtocol int
switch ipVersion {
case "v4":
icmpProtocol = icmpv4Protocol
case "v6":
icmpProtocol = icmpv6Protocol
default:
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
}
// Parse the ICMP message
// Note: this parsing works for a truncated 556 bytes ICMP reply packet.
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
if err != nil {
return fmt.Errorf("parsing message: %w", err)
}
switch message.Body.(type) {
case *icmp.Echo:
case *icmp.DstUnreach, *icmp.TimeExceeded:
logger.Debugf("ignoring ICMP message (type: %d, code: %d)", message.Type, message.Code)
continue
default:
return fmt.Errorf("%w: %T", ErrBodyUnsupported, message.Body)
}
echoBody, _ := message.Body.(*icmp.Echo)
id := uint16(echoBody.ID) //nolint:gosec
testIndex, testing := echoIDToTestIndex[id]
if !testing { // not an id we expected so ignore it
logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)",
echoBody.ID, message.Type, message.Code, ipPacketLength)
continue
}
idsFound++
sentBytes := tests[testIndex].sentBytes
// echo reply should be at most the number of bytes sent,
// and can be lower, more precisely 556 bytes, in case
// the host we are reaching wants to stay out of trouble
// and ensure its echo reply goes through without
// fragmentation, see the following page:
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
const conservativeReplyLength = 556
truncated := ipPacketLength < sentBytes &&
ipPacketLength == conservativeReplyLength
// Check the packet size is the same if the reply is not truncated
if !truncated && sentBytes != ipPacketLength {
return fmt.Errorf("%w: sent %dB and received %dB",
ErrEchoDataMismatch, sentBytes, ipPacketLength)
}
// Truncated reply or matching reply size
tests[testIndex].ok = true
}
return nil
}
-27
View File
@@ -1,27 +0,0 @@
package ip
import (
"net/netip"
"slices"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
func GetFamilies(dsts []netip.AddrPort) (families []int) {
const maxFamilies = 2
families = make([]int, 0, maxFamilies)
for _, dst := range dsts {
family := GetFamily(dst)
if !slices.Contains(families, family) {
families = append(families, family)
}
}
return families
}
func GetFamily(dst netip.AddrPort) int {
if dst.Addr().Is4() {
return constants.AF_INET
}
return constants.AF_INET6
}
-79
View File
@@ -1,79 +0,0 @@
package ip
import (
"encoding/binary"
"net/netip"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
func HeaderLength(ipv4 bool) uint32 {
if ipv4 {
return constants.IPv4HeaderLength
}
return constants.IPv6HeaderLength
}
func HeaderV4(srcIP, dstIP netip.Addr, payloadLength uint32) []byte {
ipHeader := make([]byte, constants.IPv4HeaderLength)
const version byte = 4
const headerLength byte = 20 / 4 // in 32-bit words
ipHeader[0] = (version << 4) | headerLength //nolint:mnd
ipHeader[1] = 0 // type of Service
putUint16(ipHeader[2:], uint16(constants.IPv4HeaderLength+payloadLength)) //nolint:gosec
ipHeader[4], ipHeader[5] = 0, 0 // identification
const flagsAndOffset uint16 = 0x4000 // DF bit set
putUint16(ipHeader[6:], flagsAndOffset)
ipHeader[8] = 64 // ttl
ipHeader[9] = constants.IPPROTO_TCP
srcIPBytes := srcIP.As4()
copy(ipHeader[12:16], srcIPBytes[:])
dstIPBytes := dstIP.As4()
copy(ipHeader[16:20], dstIPBytes[:])
checksum := ipChecksum(ipHeader)
ipHeader[10] = byte(checksum >> 8) //nolint:mnd
ipHeader[11] = byte(checksum & 0xff) //nolint:mnd
return ipHeader
}
// ipChecksum calculates the checksum for the IP header.
//
//nolint:mnd
func ipChecksum(header []byte) uint16 {
sum := uint32(0)
for i := 0; i < len(header)-1; i += 2 {
sum += uint32(header[i])<<8 + uint32(header[i+1])
}
if len(header)%2 != 0 {
sum += uint32(header[len(header)-1]) << 8
}
for (sum >> 16) > 0 {
sum = (sum & 0xFFFF) + (sum >> 16)
}
return ^uint16(sum) //nolint:gosec
}
// HeaderV6 makes an IPv6 header.
// payloadLen is the length of the payload following the header.
// nextHeader can be byte([constants.IPPROTO_TCP]) for example.
func HeaderV6(srcIP, dstIP netip.Addr,
payloadLen uint16, nextHeader byte,
) []byte {
ipv6Header := make([]byte, constants.IPv6HeaderLength)
ipv6Header[0] = 0x60 // version (4 bits) | traffic Class (4 bits)
ipv6Header[1] = 0x00 // traffic Class (4 bits) | flow label (4 bits)
// Flow Label (remaining 16 bits)
ipv6Header[2] = 0x00
ipv6Header[3] = 0x00
binary.BigEndian.PutUint16(ipv6Header[4:], payloadLen)
ipv6Header[6] = nextHeader
const hopLimit = 64
ipv6Header[7] = hopLimit
copy(ipv6Header[8:24], srcIP.AsSlice())
copy(ipv6Header[24:40], dstIP.AsSlice())
return ipv6Header
}
-9
View File
@@ -1,9 +0,0 @@
package ip
import (
"encoding/binary"
)
func putUint16(b []byte, v uint16) {
binary.NativeEndian.PutUint16(b, v)
}
@@ -1,9 +0,0 @@
//go:build !darwin
package ip
import "encoding/binary"
func putUint16(b []byte, v uint16) {
binary.BigEndian.PutUint16(b, v)
}
-9
View File
@@ -1,9 +0,0 @@
//go:build linux || darwin
package ip
import "golang.org/x/sys/unix"
func SetIPv4HeaderIncluded(fd int) error {
return unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_HDRINCL, 1)
}
-7
View File
@@ -1,7 +0,0 @@
//go:build !linux && !windows && !darwin
package ip
func SetIPv4HeaderIncluded(fd int) error {
panic("not implemented")
}
-10
View File
@@ -1,10 +0,0 @@
package ip
import (
"golang.org/x/sys/windows"
)
func SetIPv4HeaderIncluded(handle windows.Handle) error {
const ipHdrIncluded = windows.IP_HDRINCL
return windows.SetsockoptInt(handle, windows.IPPROTO_IP, ipHdrIncluded, 1)
}
-113
View File
@@ -1,113 +0,0 @@
package ip
import (
"errors"
"fmt"
"net/netip"
"syscall"
"github.com/jsimonetti/rtnetlink"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
// SrcAddr determines the appropriate source IP address to use when sending a packet to the
// specified destination. It also reserves an ephemeral source port for the specified protocol
// to ensure that the port is not used by other processes. The cleanup function returned should
// be called to release the reserved port when done.
func SrcAddr(dst netip.AddrPort, proto int) (src netip.AddrPort, cleanup func(), err error) {
srcAddr, err := srcIP(dst.Addr())
if err != nil {
return netip.AddrPort{}, nil, fmt.Errorf("finding source IP: %w", err)
}
srcPort, cleanup, err := srcPort(srcAddr, proto)
if err != nil {
return netip.AddrPort{}, nil, fmt.Errorf("reserving source port: %w", err)
}
return netip.AddrPortFrom(srcAddr, srcPort), cleanup, nil
}
var (
errNoRoute = fmt.Errorf("no route to destination")
ErrNetworkUnreachable = errors.New("network unreachable")
)
func srcIP(dst netip.Addr) (netip.Addr, error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return netip.Addr{}, err
}
defer conn.Close()
family := uint8(constants.AF_INET)
if dst.Is6() {
family = constants.AF_INET6
}
// Request route to destination
requestMessage := &rtnetlink.RouteMessage{
Family: family,
Attributes: rtnetlink.RouteAttributes{
Dst: dst.AsSlice(),
},
}
messages, err := conn.Route.Get(requestMessage)
if err != nil {
var sysErr syscall.Errno
if errors.As(err, &sysErr) && sysErr == syscall.ENETUNREACH {
err = ErrNetworkUnreachable
}
return netip.Addr{}, fmt.Errorf("getting routes to %s: %w", dst, err)
}
for _, message := range messages {
if message.Attributes.Src == nil {
continue
}
ipv6 := message.Attributes.Src.To4() == nil
if ipv6 {
return netip.AddrFrom16([16]byte(message.Attributes.Src)), nil
}
return netip.AddrFrom4([4]byte(message.Attributes.Src)), nil
}
return netip.Addr{}, fmt.Errorf("%w: in %d route(s)", errNoRoute, len(messages))
}
// srcPort reserves an ephemeral source port by opening a socket for the
// protocol specified and binds it to the provided source address.
// It doesn't actually listen on the port.
// The cleanup function returned should be called to release the port when done.
func srcPort(srcAddr netip.Addr, proto int) (srcPort uint16, cleanup func(), err error) {
family := constants.AF_INET
if srcAddr.Is6() {
family = constants.AF_INET6
}
fd, err := socket(family, constants.SOCK_STREAM, proto)
if err != nil {
return 0, nil, fmt.Errorf("creating reservation socket: %w", err)
}
cleanup = func() {
_ = closeSocket(fd)
}
// Bind to port 0 to get an ephemeral port
const port = 0
bindAddr := makeSockAddr(srcAddr, port)
err = bind(fd, bindAddr)
if err != nil {
cleanup()
return 0, nil, fmt.Errorf("binding reservation socket: %w", err)
}
srcPort, err = extractPortFromFD(fd)
if err != nil {
cleanup()
return 0, nil, fmt.Errorf("extracting port from socket fd: %w", err)
}
return srcPort, cleanup, nil
}
-51
View File
@@ -1,51 +0,0 @@
//go:build linux || darwin
package ip
import (
"fmt"
"net/netip"
"golang.org/x/sys/unix"
)
func socket(domain int, typ int, proto int) (fd int, err error) {
return unix.Socket(domain, typ, proto)
}
func closeSocket(fd int) error {
return unix.Close(fd)
}
func bind(fd int, addr unix.Sockaddr) error {
return unix.Bind(fd, addr)
}
func makeSockAddr(ip netip.Addr, port uint16) unix.Sockaddr {
if ip.Is4() {
return &unix.SockaddrInet4{
Port: int(port),
Addr: ip.As4(),
}
}
return &unix.SockaddrInet6{
Port: 0,
Addr: ip.As16(),
}
}
func extractPortFromFD(fd int) (uint16, error) {
sockAddr, err := unix.Getsockname(fd)
if err != nil {
return 0, fmt.Errorf("getting sockname: %w", err)
}
switch typedSockAddr := sockAddr.(type) {
case *unix.SockaddrInet4:
return uint16(typedSockAddr.Port), nil //nolint:gosec
case *unix.SockaddrInet6:
return uint16(typedSockAddr.Port), nil //nolint:gosec
default:
panic(fmt.Sprintf("unexpected sockaddr type: %T", typedSockAddr))
}
}
-49
View File
@@ -1,49 +0,0 @@
package ip
import (
"fmt"
"net/netip"
"golang.org/x/sys/windows"
)
func socket(domain int, typ int, proto int) (fd windows.Handle, err error) {
return windows.Socket(domain, typ, proto)
}
func closeSocket(fd windows.Handle) error {
return windows.Close(fd)
}
func bind(fd windows.Handle, addr windows.Sockaddr) error {
return windows.Bind(fd, addr)
}
func makeSockAddr(ip netip.Addr, port uint16) windows.Sockaddr {
if ip.Is4() {
return &windows.SockaddrInet4{
Port: int(port),
Addr: ip.As4(),
}
}
return &windows.SockaddrInet6{
Port: int(port),
Addr: ip.As16(),
}
}
func extractPortFromFD(fd windows.Handle) (uint16, error) {
sockAddr, err := windows.Getsockname(fd)
if err != nil {
return 0, fmt.Errorf("getting sockname: %w", err)
}
switch typedSockAddr := sockAddr.(type) {
case *windows.SockaddrInet4:
return uint16(typedSockAddr.Port), nil //nolint:gosec
case *windows.SockaddrInet6:
return uint16(typedSockAddr.Port), nil //nolint:gosec
default:
panic(fmt.Sprintf("unexpected sockaddr type: %T", typedSockAddr))
}
}
@@ -1,4 +1,4 @@
package icmp
package pmtud
import (
"context"
@@ -11,13 +11,14 @@ import (
"syscall"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
)
const (
icmpv4Protocol = 1
// see https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
minIPv4MTU uint32 = 68
icmpv4Protocol int = 1
)
func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
@@ -25,8 +26,7 @@ func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
listenConfig.Control = func(_, _ string, rawConn syscall.RawConn) error {
var setDFErr error
err := rawConn.Control(func(fd uintptr) {
const ipv4 = true
setDFErr = setDontFragment(fd, ipv4) // runs when calling ListenPacket
setDFErr = setDontFragment(fd) // runs when calling ListenPacket
})
if err == nil {
err = setDFErr
@@ -38,7 +38,7 @@ func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress)
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
}
return nil, err
}
@@ -83,9 +83,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
buffer := make([]byte, physicalLinkMTU)
// for loop in case we read an ICMP message from another ICMP request
// or TCP/UDP traffic triggering an ICMP response.
for {
for { // for loop in case we read an echo reply for another ICMP request
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
@@ -110,27 +108,24 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
switch typedBody := inboundMessage.Body.(type) {
case *icmp.DstUnreach:
const fragmentationRequiredAndDFFlagSetCode = 4
const portUnreachable = 3
const communicationAdministrativelyProhibitedCode = 13
switch inboundMessage.Code {
case fragmentationRequiredAndDFFlagSetCode:
case portUnreachable: // triggered by TCP or UDP from applications
continue // ignore and wait for the next message
case communicationAdministrativelyProhibitedCode:
return 0, fmt.Errorf("%w: %w (code %d)",
ErrDestinationUnreachable,
ErrCommunicationAdministrativelyProhibited,
ErrICMPDestinationUnreachable,
ErrICMPCommunicationAdministrativelyProhibited,
inboundMessage.Code)
default:
return 0, fmt.Errorf("%w: code %d",
ErrDestinationUnreachable, inboundMessage.Code)
ErrICMPDestinationUnreachable, inboundMessage.Code)
}
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
// Note: the go library does not handle this NextHopMTU section.
nextHopMTU := packetBytes[6:8]
mtu = uint32(binary.BigEndian.Uint16(nextHopMTU))
err = checkMTU(mtu, constants.MinIPv4MTU, physicalLinkMTU)
err = checkMTU(mtu, minIPv4MTU, physicalLinkMTU)
if err != nil {
return 0, fmt.Errorf("checking next-hop-mtu found: %w", err)
}
@@ -158,7 +153,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
inboundID, outboundID)
continue
default:
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
}
}
}
@@ -1,4 +1,4 @@
package icmp
package pmtud
import (
"context"
@@ -6,36 +6,24 @@ import (
"net"
"net/netip"
"strings"
"syscall"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv6"
)
const (
minIPv6MTU = 1280
icmpv6Protocol = 58
)
func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) {
var listenConfig net.ListenConfig
listenConfig.Control = func(_, _ string, rawConn syscall.RawConn) error {
var setDFErr error
err := rawConn.Control(func(fd uintptr) {
const ipv4 = false
setDFErr = setDontFragment(fd, ipv4) // runs when calling ListenPacket
})
if err == nil {
err = setDFErr
}
return err
}
const listenAddress = ""
packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress)
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
}
return nil, err
}
@@ -97,7 +85,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
case *icmp.PacketTooBig:
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.2
mtu = uint32(typedBody.MTU) //nolint:gosec
err = checkMTU(mtu, constants.MinIPv6MTU, physicalLinkMTU)
err = checkMTU(mtu, minIPv6MTU, physicalLinkMTU)
if err != nil {
return 0, fmt.Errorf("checking MTU: %w", err)
}
@@ -115,7 +103,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
if err != nil {
return 0, fmt.Errorf("checking invoking message id: %w", err)
} else if idMatch {
return 0, fmt.Errorf("%w", ErrDestinationUnreachable)
return 0, fmt.Errorf("%w", ErrICMPDestinationUnreachable)
}
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
continue
@@ -128,7 +116,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
inboundID, outboundID)
continue
default:
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
}
}
}
@@ -1,4 +1,4 @@
package icmp
package pmtud
import (
cryptorand "crypto/rand"
+230 -50
View File
@@ -4,88 +4,268 @@ import (
"context"
"errors"
"fmt"
"math"
"net"
"net/netip"
"strings"
"time"
"github.com/qdm12/gluetun/internal/firewall/iptables"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/icmp"
"github.com/qdm12/gluetun/internal/pmtud/tcp"
"golang.org/x/net/icmp"
)
var (
ErrICMPOkTCPFail = errors.New("PMTUD succeeded with ICMP but failed with TCP")
ErrICMPFailTCPFail = errors.New("PMTUD failed with both ICMP and TCP")
)
var ErrMTUNotFound = errors.New("path MTU discovery failed to find MTU")
// PathMTUDiscover discovers the maximum MTU using both ICMP and TCP.
// Multiple ICMP addresses and TCP addresses can be specified for redundancy.
// ICMP PMTUD is run first. If successful, the range of possible MTU values to
// check for TCP PMTUD is reduced to [maxMTU-150, maxMTU] where maxMTU is the
// maximum MTU found with ICMP PMTUD. Otherwise, TCP PMTUD is run with the
// whole range of possible MTU values up to the physical link MTU to check.
// PathMTUDiscover discovers the maximum MTU for the path to the given ip address.
// If the physicalLinkMTU is zero, it defaults to 1500 which is the ethernet standard MTU.
// If the pingTimeout is zero, it defaults to 1 second.
// If the logger is nil, a no-op logger is used.
// It returns [ErrMTUNotFound] if the MTU could not be determined.
func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort,
physicalLinkMTU uint32, tryTimeout time.Duration, fw tcp.Firewall, logger Logger) (
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger) (
mtu uint32, err error,
) {
if physicalLinkMTU == 0 {
const ethernetStandardMTU = 1500
physicalLinkMTU = ethernetStandardMTU
}
if tryTimeout == 0 {
tryTimeout = time.Second
if pingTimeout == 0 {
pingTimeout = time.Second
}
if logger == nil {
logger = &noopLogger{}
}
// Try finding the MTU using ICMP
maxPossibleMTU := physicalLinkMTU
icmpSuccess := false
for _, icmpIP := range icmpAddrs {
mtu, err := icmp.PathMTUDiscover(ctx, icmpIP, physicalLinkMTU,
tryTimeout, logger)
if ip.Is4() {
logger.Debug("finding IPv4 next hop MTU")
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, pingTimeout, logger)
switch {
case err == nil:
logger.Debugf("ICMP path MTU discovery against %s found maximum valid MTU %d", icmpIP, mtu)
icmpSuccess = true
maxPossibleMTU = mtu
case errors.Is(err, icmp.ErrNotPermitted), errors.Is(err, icmp.ErrMTUNotFound):
logger.Debugf("ICMP path MTU discovery failed: %s", err)
return mtu, nil
case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrICMPCommunicationAdministrativelyProhibited): // blackhole
default:
return 0, fmt.Errorf("ICMP path MTU discovery: %w", err)
return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err)
}
if icmpSuccess {
break
} else {
logger.Debug("requesting IPv6 ICMP packet-too-big reply")
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, pingTimeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, net.ErrClosed): // blackhole
default:
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
}
}
minMTU := constants.MinIPv4MTU
if tcpAddrs[0].Addr().Is6() {
minMTU = constants.MinIPv6MTU
// Fall back method: send echo requests with different packet
// sizes and check which ones succeed to find the maximum MTU.
logger.Debug("falling back to sending different sized echo packets")
minMTU := minIPv4MTU
if ip.Is6() {
minMTU = minIPv6MTU
}
if icmpSuccess {
const mtuMargin = 150
minMTU = max(maxPossibleMTU-mtuMargin, minMTU)
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, pingTimeout, logger)
}
type pmtudTestUnit struct {
mtu uint32
echoID uint16
sentBytes int
ok bool
}
func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
minMTU, maxPossibleMTU uint32, pingTimeout time.Duration,
logger Logger,
) (maxMTU uint32, err error) {
var ipVersion string
var conn net.PacketConn
if ip.Is4() {
ipVersion = "v4"
conn, err = listenICMPv4(ctx)
} else {
ipVersion = "v6"
conn, err = listenICMPv6(ctx)
}
mtu, err = tcp.PathMTUDiscover(ctx, tcpAddrs, minMTU, maxPossibleMTU, tryTimeout, fw, logger)
if err != nil {
if errors.Is(err, iptables.ErrMarkMatchModuleMissing) {
logger.Debugf("aborting TCP path MTU discovery: %s", err)
if icmpSuccess {
return maxPossibleMTU, nil // only rely on ICMP PMTUD results
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
}
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
mtusToTest := makeMTUsToTest(minMTU, maxPossibleMTU)
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
return minMTU, nil
}
logger.Debugf("testing the following MTUs: %v", mtusToTest)
tests := make([]pmtudTestUnit, len(mtusToTest))
for i := range mtusToTest {
tests[i] = pmtudTestUnit{mtu: mtusToTest[i]}
}
timedCtx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-timedCtx.Done()
conn.Close()
}()
for i := range tests {
id, message := buildMessageToSend(ipVersion, tests[i].mtu)
tests[i].echoID = id
encodedMessage, err := message.Marshal(nil)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
tests[i].sentBytes = len(encodedMessage)
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil {
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
err = fmt.Errorf("%w", ErrICMPNotPermitted)
}
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
}
err = collectReplies(conn, ipVersion, tests, logger)
switch {
case err == nil: // max possible MTU is working
return tests[len(tests)-1].mtu, nil
case err != nil && errors.Is(err, net.ErrClosed):
// we have timeouts (IPv4 testing or IPv6 PMTUD blackholes)
// so find the highest MTU which worked.
// Note we start from index len(tests) - 2 since the max MTU
// cannot be working if we had a timeout.
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
if tests[i].ok {
return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1,
pingTimeout, logger)
}
}
if icmpSuccess {
return 0, fmt.Errorf("%w - discarding ICMP obtained MTU %d",
ErrICMPOkTCPFail, maxPossibleMTU)
}
return 0, fmt.Errorf("%w", ErrICMPFailTCPFail)
// All MTUs failed.
return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound)
case err != nil:
return 0, fmt.Errorf("collecting ICMP echo replies: %w", err)
default:
panic("unreachable")
}
logger.Debugf("TCP path MTU discovery found maximum valid MTU %d", mtu)
return mtu, nil
}
// Create the MTU slice of length 11 such that:
// - the first element is the minMTU
// - the last element is the maxMTU
// - elements in-between are separated as close to each other
// The number 11 is chosen to find the final MTU in 3 searches,
// with a total search space of 1728 MTUs which is enough;
// to find it in 2 searches requires 37 parallel queries which
// could be blocked by firewalls.
func makeMTUsToTest(minMTU, maxMTU uint32) (mtus []uint32) {
const mtusLength = 11 // find the final MTU in 3 searches
diff := maxMTU - minMTU
switch {
case minMTU > maxMTU:
panic("minMTU > maxMTU")
case diff <= mtusLength:
mtus = make([]uint32, 0, diff)
for mtu := minMTU; mtu <= maxMTU; mtu++ {
mtus = append(mtus, mtu)
}
default:
step := float64(diff) / float64(mtusLength-1)
mtus = make([]uint32, 0, mtusLength)
for mtu := float64(minMTU); len(mtus) < mtusLength-1; mtu += step {
mtus = append(mtus, uint32(math.Round(mtu)))
}
mtus = append(mtus, maxMTU) // last element is the maxMTU
}
return mtus
}
func collectReplies(conn net.PacketConn, ipVersion string,
tests []pmtudTestUnit, logger Logger,
) (err error) {
echoIDToTestIndex := make(map[uint16]int, len(tests))
for i, test := range tests {
echoIDToTestIndex[test.echoID] = i
}
// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would
// create huge buffers which we don't really want to support anyway.
// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with
// a conventional maximum of 9000 bytes. However, some manufacturers support up
// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to
// match eventual Jumbo frames. More information at:
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
const maxPossibleMTU = 9196
buffer := make([]byte, maxPossibleMTU)
idsFound := 0
for idsFound < len(tests) {
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
return fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
ipPacketLength := len(packetBytes)
var icmpProtocol int
switch ipVersion {
case "v4":
icmpProtocol = icmpv4Protocol
case "v6":
icmpProtocol = icmpv6Protocol
default:
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
}
// Parse the ICMP message
// Note: this parsing works for a truncated 556 bytes ICMP reply packet.
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
if err != nil {
return fmt.Errorf("parsing message: %w", err)
}
echoBody, ok := message.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, message.Body)
}
id := uint16(echoBody.ID) //nolint:gosec
testIndex, testing := echoIDToTestIndex[id]
if !testing { // not an id we expected so ignore it
logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)",
echoBody.ID, message.Type, message.Code, ipPacketLength)
continue
}
idsFound++
sentBytes := tests[testIndex].sentBytes
// echo reply should be at most the number of bytes sent,
// and can be lower, more precisely 556 bytes, in case
// the host we are reaching wants to stay out of trouble
// and ensure its echo reply goes through without
// fragmentation, see the following page:
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
const conservativeReplyLength = 556
truncated := ipPacketLength < sentBytes &&
ipPacketLength == conservativeReplyLength
// Check the packet size is the same if the reply is not truncated
if !truncated && sentBytes != ipPacketLength {
return fmt.Errorf("%w: sent %dB and received %dB",
ErrICMPEchoDataMismatch, sentBytes, ipPacketLength)
}
// Truncated reply or matching reply size
tests[testIndex].ok = true
}
return nil
}
@@ -1,4 +1,4 @@
package test
package pmtud
import (
"testing"
@@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/assert"
)
func Test_MakeMTUsToTest(t *testing.T) {
func Test_makeMTUsToTest(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
@@ -48,7 +48,7 @@ func Test_MakeMTUsToTest(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
mtus := MakeMTUsToTest(testCase.minMTU, testCase.maxMTU)
mtus := makeMTUsToTest(testCase.minMTU, testCase.maxMTU)
assert.Equal(t, testCase.mtus, mtus)
})
}
-142
View File
@@ -1,142 +0,0 @@
package tcp
import (
"errors"
"fmt"
"sync"
"testing"
"github.com/qdm12/gluetun/internal/command"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/firewall/iptables"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
// testFirewall must be global to prevent parallel tests from interfering
// with each other since they would interact with the same filter table.
// The first test to use should initialize it, and the rest will reuse it.
var (
testFirewall *firewall.Config //nolint:gochecknoglobals
testFirewallOnce sync.Once //nolint:gochecknoglobals
)
// getFirewall returns a Firewall instance, initializing it if needed. If
// iptables is not supported, it skips the test.
func getFirewall(t *testing.T) *firewall.Config {
t.Helper()
testFirewallOnce.Do(func() {
noopLogger := &noopLogger{}
cmder := command.New()
var err error
testFirewall, err = firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil)
if errors.Is(err, iptables.ErrNotSupported) {
t.Skip("iptables not installed, skipping TCP PMTUD tests")
}
require.NoError(t, err, "creating firewall config")
})
if testFirewall == nil {
t.Skip("iptables not installed, skipping TCP PMTUD tests")
}
return testFirewall
}
type noopLogger struct{}
func (l *noopLogger) Patch(_ ...log.Option) {}
func (l *noopLogger) Debug(_ string) {}
func (l *noopLogger) Debugf(_ string, _ ...any) {}
func (l *noopLogger) Info(_ string) {}
func (l *noopLogger) Warn(_ string) {}
func (l *noopLogger) Warnf(_ string, _ ...any) {}
func (l *noopLogger) Error(_ string) {}
var errRouteNotFound = errors.New("route not found")
func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
routes, err := netlinker.RouteList(netlink.FamilyV4)
if err != nil {
return 0, fmt.Errorf("getting routes list: %w", err)
}
for _, route := range routes {
if route.Dst.IsValid() && route.Dst.Addr().IsLoopback() {
link, err := netlinker.LinkByIndex(route.LinkIndex)
if err != nil {
return 0, fmt.Errorf("getting link by index: %w", err)
}
// Quirk: make sure it is maximum 65535, and not i.e. 65536
// or the IP header 16 bits will fail to fit that packet length value.
const maxMTU = 65535
return min(link.MTU, maxMTU), nil
}
}
return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound)
}
func findDefaultRouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
noopLogger := &noopLogger{}
routing := routing.New(netlinker, noopLogger)
defaultRoutes, err := routing.DefaultRoutes()
if err != nil {
return 0, fmt.Errorf("getting default routes: %w", err)
}
families := []uint8{constants.AF_INET, constants.AF_INET6}
for _, family := range families {
for _, route := range defaultRoutes {
if route.Family != family {
continue
}
link, err := netlinker.LinkByName(route.NetInterface)
if err != nil {
return 0, fmt.Errorf("getting link by name: %w", err)
}
mtu = max(mtu, link.MTU)
}
}
if mtu == 0 {
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound)
}
return mtu, nil
}
func reserveClosedPort(t *testing.T) (port uint16) {
t.Helper()
fd, err := unix.Socket(constants.AF_INET, constants.SOCK_STREAM, constants.IPPROTO_TCP)
require.NoError(t, err)
t.Cleanup(func() {
err := unix.Close(fd)
assert.NoError(t, err)
})
addr := &unix.SockaddrInet4{
Port: 0,
Addr: [4]byte{127, 0, 0, 1},
}
err = unix.Bind(fd, addr)
if err != nil {
_ = unix.Close(fd)
t.Fatal(err)
}
sockAddr, err := unix.Getsockname(fd)
if err != nil {
_ = unix.Close(fd)
t.Fatal(err)
}
sockAddr4, ok := sockAddr.(*unix.SockaddrInet4)
if !ok {
_ = unix.Close(fd)
t.Fatal("not an IPv4 address")
}
return uint16(sockAddr4.Port) //nolint:gosec
}

Some files were not shown because too many files have changed in this diff Show More