mirror of
https://github.com/qdm12/gluetun.git
synced 2026-06-27 22:37:33 +02:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| db947c17a8 |
@@ -1,2 +1,2 @@
|
||||
FROM ghcr.io/qdm12/godevcontainer:v0.21-alpine
|
||||
RUN apk add wireguard-tools htop openssl tcpdump iptables nftables
|
||||
RUN apk add wireguard-tools htop openssl
|
||||
|
||||
@@ -45,7 +45,6 @@ jobs:
|
||||
level: error
|
||||
exclude: |
|
||||
./internal/storage/servers.json
|
||||
./golangci.yml
|
||||
*.md
|
||||
|
||||
- name: Linting
|
||||
|
||||
+1
-2
@@ -22,7 +22,6 @@ linters:
|
||||
- "^disabled$"
|
||||
# Firewall and routing strings
|
||||
- "^(ACCEPT|DROP)$"
|
||||
- "^--append$"
|
||||
- "^--delete$"
|
||||
- "^all$"
|
||||
- "^(tcp|udp)$"
|
||||
@@ -48,7 +47,7 @@ linters:
|
||||
path: internal\/server\/.+\.go
|
||||
- linters:
|
||||
- ireturn
|
||||
text: returns interface \(golang\.org\/x\/sys\/unix\.Sockaddr\)
|
||||
text: returns interface \(github\.com\/vishvananda\/netlink\.Link\)
|
||||
- linters:
|
||||
- ireturn
|
||||
path: internal\/openvpn\/pkcs8\/descbc\.go
|
||||
|
||||
+2
-5
@@ -13,7 +13,7 @@ FROM --platform=${BUILDPLATFORM} ghcr.io/qdm12/binpot:mockgen-${MOCKGEN_VERSION}
|
||||
FROM --platform=${BUILDPLATFORM} golang:${GO_VERSION}-alpine${GO_ALPINE_VERSION} AS base
|
||||
COPY --from=xcputranslate /xcputranslate /usr/local/bin/xcputranslate
|
||||
# Note: findutils needed to have xargs support `-d` flag for mocks stage.
|
||||
RUN apk --update add git g++ findutils iptables
|
||||
RUN apk --update add git g++ findutils
|
||||
ENV CGO_ENABLED=0
|
||||
COPY --from=golangci-lint /bin /go/bin/golangci-lint
|
||||
COPY --from=mockgen /bin /go/bin/mockgen
|
||||
@@ -110,11 +110,8 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
||||
WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL=0 \
|
||||
WIREGUARD_ADDRESSES= \
|
||||
WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
|
||||
WIREGUARD_MTU= \
|
||||
WIREGUARD_MTU=1320 \
|
||||
WIREGUARD_IMPLEMENTATION=auto \
|
||||
# PMTUD
|
||||
PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8 \
|
||||
PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443,1.1.1.1:53,8.8.8.8:53,[2606:4700:4700::1111]:53,[2001:4860:4860::8888]:53,[2606:4700:4700::1111]:443,[2001:4860:4860::8888]:443 \
|
||||
# VPN server filtering
|
||||
SERVER_REGIONS= \
|
||||
SERVER_COUNTRIES= \
|
||||
|
||||
@@ -58,7 +58,7 @@ Lightweight swiss-army-knife-like VPN client to multiple VPN service providers
|
||||
## Features
|
||||
|
||||
- Based on Alpine 3.22 for a small Docker image of 41.1MB
|
||||
- Supports: **AirVPN**, **Cyberghost**, **ExpressVPN**, **FastestVPN**, **Giganews**, **HideMyAss**, **IPVanish**, **IVPN**, **Mullvad** (Wireguard only), **NordVPN**, **Perfect Privacy**, **Privado**, **Private Internet Access**, **PrivateVPN**, **ProtonVPN**, **PureVPN**, **SlickVPN**, **Surfshark**, **TorGuard**, **VPNSecure.me**, **VPNUnlimited**, **Vyprvpn**, **Windscribe** servers
|
||||
- Supports: **AirVPN**, **Cyberghost**, **ExpressVPN**, **FastestVPN**, **Giganews**, **HideMyAss**, **IPVanish**, **IVPN**, **Mullvad**, **NordVPN**, **Perfect Privacy**, **Privado**, **Private Internet Access**, **PrivateVPN**, **ProtonVPN**, **PureVPN**, **SlickVPN**, **Surfshark**, **TorGuard**, **VPNSecure.me**, **VPNUnlimited**, **Vyprvpn**, **Windscribe** servers
|
||||
- Supports OpenVPN for all providers listed
|
||||
- Supports Wireguard both kernelspace and userspace
|
||||
- For **AirVPN**, **FastestVPN**, **Ivpn**, **Mullvad**, **NordVPN**, **Perfect privacy**, **ProtonVPN**, **Surfshark** and **Windscribe**
|
||||
|
||||
+5
-10
@@ -168,7 +168,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
|
||||
defer fmt.Println(gluetunLogo)
|
||||
|
||||
announcementExp, err := time.Parse(time.RFC3339, "2026-04-01T00:00:00Z")
|
||||
announcementExp, err := time.Parse(time.RFC3339, "2024-12-01T00:00:00Z")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -179,7 +179,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
Version: buildInfo.Version,
|
||||
Commit: buildInfo.Commit,
|
||||
Created: buildInfo.Created,
|
||||
Announcement: "All control server routes are now private by default",
|
||||
Announcement: "All control server routes will become private by default after the v3.41.0 release",
|
||||
AnnounceExp: announcementExp,
|
||||
// Sponsor information
|
||||
PaypalUser: "qmcgaw",
|
||||
@@ -237,10 +237,6 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = netLinker.FlushConntrack()
|
||||
if err != nil {
|
||||
logger.Warnf("flushing conntrack failed: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO run this in a loop or in openvpn to reload from file without restarting
|
||||
@@ -268,7 +264,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
|
||||
puid, pgid := int(*allSettings.System.PUID), int(*allSettings.System.PGID)
|
||||
|
||||
const clientTimeout = 35 * time.Second
|
||||
const clientTimeout = 15 * time.Second
|
||||
httpClient := &http.Client{Timeout: clientTimeout}
|
||||
// Create configurators
|
||||
alpineConf := alpine.New()
|
||||
@@ -283,7 +279,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
err = printVersions(ctx, logger, []printVersionElement{
|
||||
{name: "Alpine", getVersion: alpineConf.Version},
|
||||
{name: "OpenVPN", getVersion: ovpnVersion},
|
||||
{name: "Firewall", getVersion: firewallConf.Version},
|
||||
{name: "IPtables", getVersion: firewallConf.Version},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -398,7 +394,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
}
|
||||
|
||||
dnsLogger := logger.New(log.SetComponent("dns"))
|
||||
dnsLooper, err := dns.NewLoop(allSettings.DNS, httpClient,
|
||||
dnsLooper, err := dns.NewLoop(allSettings.DNS, httpClient, firewallConf,
|
||||
dnsLogger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating DNS loop: %w", err)
|
||||
@@ -560,7 +556,6 @@ type netLinker interface {
|
||||
Linker
|
||||
IsWireguardSupported() (ok bool, err error)
|
||||
IsIPv6Supported() (ok bool, err error)
|
||||
FlushConntrack() error
|
||||
PatchLoggerLevel(level log.Level)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,17 +4,15 @@ go 1.25.0
|
||||
|
||||
require (
|
||||
github.com/ProtonMail/go-srp v0.0.7
|
||||
github.com/breml/rootcerts v0.3.4
|
||||
github.com/breml/rootcerts v0.3.3
|
||||
github.com/fatih/color v1.18.0
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/google/nftables v0.3.0
|
||||
github.com/jsimonetti/rtnetlink v1.4.2
|
||||
github.com/klauspost/compress v1.18.1
|
||||
github.com/klauspost/pgzip v1.2.6
|
||||
github.com/mdlayher/genetlink v1.3.2
|
||||
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42
|
||||
github.com/pelletier/go-toml/v2 v2.2.4
|
||||
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260216151239-36b3306f2205
|
||||
github.com/qdm12/dns/v2 v2.0.0-rc10
|
||||
github.com/qdm12/gosettings v0.4.4
|
||||
github.com/qdm12/goshutdown v0.3.0
|
||||
github.com/qdm12/gosplash v0.2.0
|
||||
@@ -22,7 +20,6 @@ require (
|
||||
github.com/qdm12/log v0.1.0
|
||||
github.com/qdm12/ss-server v0.6.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/ti-mo/netfilter v0.5.3
|
||||
github.com/ulikunitz/xz v0.5.15
|
||||
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a
|
||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
|
||||
@@ -43,8 +40,10 @@ require (
|
||||
github.com/cronokirby/saferith v0.33.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/josharian/native v1.1.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mdlayher/netlink v1.7.2 // indirect
|
||||
github.com/mdlayher/socket v0.5.1 // indirect
|
||||
github.com/miekg/dns v1.1.62 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
|
||||
@@ -8,8 +8,8 @@ github.com/ProtonMail/go-srp v0.0.7 h1:Sos3Qk+th4tQR64vsxGIxYpN3rdnG9Wf9K4ZloC1J
|
||||
github.com/ProtonMail/go-srp v0.0.7/go.mod h1:giCp+7qRnMIcCvI6V6U3S1lDDXDQYx2ewJ6F/9wdlJk=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/breml/rootcerts v0.3.4 h1:9i7WNl/ctd9OEAOaTfLy//Wrlfxq/tRQ7v4okYFN9Ys=
|
||||
github.com/breml/rootcerts v0.3.4/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXBhsegfXw=
|
||||
github.com/breml/rootcerts v0.3.3 h1://GnaRtQ/9BY2+GtMk2wtWxVdCRysiaPr5/xBwl7NKw=
|
||||
github.com/breml/rootcerts v0.3.3/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXBhsegfXw=
|
||||
github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
@@ -30,8 +30,8 @@ github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
|
||||
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/nftables v0.3.0 h1:bkyZ0cbpVeMHXOrtlFc8ISmfVqq5gPJukoYieyVmITg=
|
||||
github.com/google/nftables v0.3.0/go.mod h1:BCp9FsrbF1Fn/Yu6CLUc9GGZFw/+hsxfluNXXmxBfRM=
|
||||
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
|
||||
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
github.com/jsimonetti/rtnetlink v1.4.2 h1:Df9w9TZ3npHTyDn0Ev9e1uzmN2odmXd0QX+J5GTEn90=
|
||||
github.com/jsimonetti/rtnetlink v1.4.2/go.mod h1:92s6LJdE+1iOrw+F2/RO7LYI2Qd8pPpFNNUYW06gcoM=
|
||||
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
|
||||
@@ -49,8 +49,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
|
||||
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
|
||||
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg=
|
||||
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o=
|
||||
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
||||
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
||||
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
||||
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
|
||||
@@ -73,8 +73,8 @@ github.com/prometheus/common v0.60.1 h1:FUas6GcOw66yB/73KC+BOZoFJmbo/1pojoILArPA
|
||||
github.com/prometheus/common v0.60.1/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260216151239-36b3306f2205 h1:0ycKUDQ50cYb2QpeyGcEnvVs9HJmC9jsb/XZNC1z28c=
|
||||
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260216151239-36b3306f2205/go.mod h1:98foWgXJZ+g8gJIuO+fdO+oWpFei5WShMFTeN4Im2lE=
|
||||
github.com/qdm12/dns/v2 v2.0.0-rc10 h1:IyeNEYXfhBsaE1dwxx5eAqdAz1HS98dT+8c7xoKODa0=
|
||||
github.com/qdm12/dns/v2 v2.0.0-rc10/go.mod h1:98foWgXJZ+g8gJIuO+fdO+oWpFei5WShMFTeN4Im2lE=
|
||||
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 h1:TRGpCU1l0lNwtogEUSs5U+RFceYxkAJUmrGabno7J5c=
|
||||
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978/go.mod h1:D1Po4CRQLYjccnAR2JsVlN1sBMgQrcNLONbvyuzcdTg=
|
||||
github.com/qdm12/gosettings v0.4.4 h1:SM6tOZDf6k8qbjWU8KWyBF4mWIixfsKCfh9DGRLHlj4=
|
||||
@@ -95,12 +95,8 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR
|
||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/ti-mo/netfilter v0.5.3 h1:ikzduvnaUMwre5bhbNwWOd6bjqLMVb33vv0XXbK0xGQ=
|
||||
github.com/ti-mo/netfilter v0.5.3/go.mod h1:08SyBCg6hu1qyQk4s3DjjJKNrm3RTb32nm6AzyT972E=
|
||||
github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
|
||||
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk=
|
||||
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
|
||||
@@ -23,9 +23,7 @@ type DNSBlacklist struct {
|
||||
AddBlockedIPs []netip.Addr
|
||||
AddBlockedIPPrefixes []netip.Prefix
|
||||
// RebindingProtectionExemptHostnames is a list of hostnames
|
||||
// exempt from DNS rebinding protection. It can contain parent
|
||||
// domains which are of the form "*.example.com". Note the wildcard
|
||||
// can only be used at the start of the hostname.
|
||||
// exempt from DNS rebinding protection.
|
||||
RebindingProtectionExemptHostnames []string
|
||||
}
|
||||
|
||||
@@ -57,9 +55,6 @@ func (b DNSBlacklist) validate() (err error) {
|
||||
}
|
||||
|
||||
for _, host := range b.RebindingProtectionExemptHostnames {
|
||||
if len(host) > 2 && host[:2] == "*." {
|
||||
host = host[2:]
|
||||
}
|
||||
if !hostRegex.MatchString(host) {
|
||||
return fmt.Errorf("%w: %s", ErrRebindingProtectionExemptHostNotValid, host)
|
||||
}
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gosettings"
|
||||
"github.com/qdm12/gosettings/reader"
|
||||
"github.com/qdm12/gotree"
|
||||
)
|
||||
|
||||
// PMTUD contains settings to configure Path MTU Discovery.
|
||||
type PMTUD struct {
|
||||
// ICMPAddresses is the redundancy list of addresses to use
|
||||
// for ICMP path MTU discovery. Each address MUST handle ICMP
|
||||
// packets for PMTUD to work.
|
||||
// It cannot be nil in the internal state.
|
||||
ICMPAddresses []netip.Addr `json:"icmp_addresses"`
|
||||
// TCPAddresses is the redundancy list of addresses to use
|
||||
// for TCP path MTU discovery. Each address MUST have a listening
|
||||
// TCP server on the port specified.
|
||||
// It cannot be nil in the internal state.
|
||||
TCPAddresses []netip.AddrPort `json:"tcp_addresses"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrPMTUDICMPAddressNotValid = errors.New("PMTUD ICMP address is not valid")
|
||||
ErrPMTUDTCPAddressNotValid = errors.New("PMTUD TCP address is not valid")
|
||||
)
|
||||
|
||||
// Validate validates PMTUD settings.
|
||||
func (p PMTUD) validate() (err error) {
|
||||
for i, addr := range p.ICMPAddresses {
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("%w: at index %d", ErrPMTUDICMPAddressNotValid, i)
|
||||
}
|
||||
}
|
||||
for i, addr := range p.TCPAddresses {
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("%w: at index %d", ErrPMTUDTCPAddressNotValid, i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PMTUD) copy() (copied PMTUD) {
|
||||
return PMTUD{
|
||||
ICMPAddresses: gosettings.CopySlice(p.ICMPAddresses),
|
||||
TCPAddresses: gosettings.CopySlice(p.TCPAddresses),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PMTUD) overrideWith(other PMTUD) {
|
||||
p.ICMPAddresses = gosettings.OverrideWithSlice(p.ICMPAddresses, other.ICMPAddresses)
|
||||
p.TCPAddresses = gosettings.OverrideWithSlice(p.TCPAddresses, other.TCPAddresses)
|
||||
}
|
||||
|
||||
func (p *PMTUD) setDefaults() {
|
||||
defaultICMPAddresses := []netip.Addr{
|
||||
netip.AddrFrom4([4]byte{1, 1, 1, 1}),
|
||||
netip.AddrFrom4([4]byte{8, 8, 8, 8}),
|
||||
}
|
||||
p.ICMPAddresses = gosettings.DefaultSlice(p.ICMPAddresses, defaultICMPAddresses)
|
||||
|
||||
const dnsPort, tlsPort = 53, 443
|
||||
defaultTCPAddresses := []netip.AddrPort{
|
||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), dnsPort),
|
||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), dnsPort),
|
||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), tlsPort),
|
||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), tlsPort),
|
||||
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), dnsPort),
|
||||
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), dnsPort),
|
||||
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), tlsPort),
|
||||
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), tlsPort),
|
||||
}
|
||||
p.TCPAddresses = gosettings.DefaultSlice(p.TCPAddresses, defaultTCPAddresses)
|
||||
}
|
||||
|
||||
func (p PMTUD) String() string {
|
||||
return p.toLinesNode().String()
|
||||
}
|
||||
|
||||
func (p PMTUD) toLinesNode() (node *gotree.Node) {
|
||||
node = gotree.New("Path MTU discovery:")
|
||||
|
||||
icmpAddrNode := node.Append("ICMP addresses:")
|
||||
for _, addr := range p.ICMPAddresses {
|
||||
icmpAddrNode.Append(addr.String())
|
||||
}
|
||||
|
||||
tcpAddrNode := node.Append("TCP addresses:")
|
||||
for _, addr := range p.TCPAddresses {
|
||||
tcpAddrNode.Append(addr.String())
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
func (p *PMTUD) read(r *reader.Reader) (err error) {
|
||||
p.ICMPAddresses, err = r.CSVNetipAddresses("PMTUD_ICMP_ADDRESSES")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.TCPAddresses, err = r.CSVNetipAddrPorts("PMTUD_TCP_ADDRESSES")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -2,8 +2,6 @@ package settings
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
@@ -33,11 +31,6 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
|
||||
if vpnType == vpn.OpenVPN {
|
||||
validNames = providers.AllWithCustom()
|
||||
validNames = append(validNames, "pia") // Retro-compatibility
|
||||
// Remove Mullvad since it no longer supports OpenVPN as of January 15th, 2026
|
||||
mullvadIndex := slices.Index(validNames, providers.Mullvad)
|
||||
validNames[mullvadIndex], validNames[len(validNames)-1] = validNames[len(validNames)-1], validNames[mullvadIndex]
|
||||
validNames = validNames[:len(validNames)-1]
|
||||
sort.Strings(validNames)
|
||||
} else { // Wireguard
|
||||
validNames = []string{
|
||||
providers.Airvpn,
|
||||
@@ -55,6 +48,10 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
|
||||
return fmt.Errorf("%w for Wireguard: %w", ErrVPNProviderNameNotValid, err)
|
||||
}
|
||||
|
||||
if p.Name == providers.Mullvad && vpnType == vpn.OpenVPN {
|
||||
warner.Warn("https://mullvad.net/en/blog/removing-openvpn-15th-january-2026")
|
||||
}
|
||||
|
||||
err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("server selection: %w", err)
|
||||
|
||||
@@ -29,27 +29,14 @@ func Test_Settings_String(t *testing.T) {
|
||||
| | └── OpenVPN server selection settings:
|
||||
| | ├── Protocol: UDP
|
||||
| | └── Private Internet Access encryption preset: strong
|
||||
| ├── OpenVPN settings:
|
||||
| | ├── OpenVPN version: 2.6
|
||||
| | ├── User: [not set]
|
||||
| | ├── Password: [not set]
|
||||
| | ├── Private Internet Access encryption preset: strong
|
||||
| | ├── Network interface: tun0
|
||||
| | ├── Run OpenVPN as: root
|
||||
| | └── Verbosity level: 1
|
||||
| └── Path MTU discovery:
|
||||
| ├── ICMP addresses:
|
||||
| | ├── 1.1.1.1
|
||||
| | └── 8.8.8.8
|
||||
| └── TCP addresses:
|
||||
| ├── 1.1.1.1:53
|
||||
| ├── 8.8.8.8:53
|
||||
| ├── 1.1.1.1:443
|
||||
| ├── 8.8.8.8:443
|
||||
| ├── [2606:4700:4700::1111]:53
|
||||
| ├── [2001:4860:4860::8888]:53
|
||||
| ├── [2606:4700:4700::1111]:443
|
||||
| └── [2001:4860:4860::8888]:443
|
||||
| └── OpenVPN settings:
|
||||
| ├── OpenVPN version: 2.6
|
||||
| ├── User: [not set]
|
||||
| ├── Password: [not set]
|
||||
| ├── Private Internet Access encryption preset: strong
|
||||
| ├── Network interface: tun0
|
||||
| ├── Run OpenVPN as: root
|
||||
| └── Verbosity level: 1
|
||||
├── DNS settings:
|
||||
| ├── Keep existing nameserver(s): no
|
||||
| ├── DNS server address to use: 127.0.0.1
|
||||
|
||||
@@ -18,7 +18,6 @@ type VPN struct {
|
||||
Provider Provider `json:"provider"`
|
||||
OpenVPN OpenVPN `json:"openvpn"`
|
||||
Wireguard Wireguard `json:"wireguard"`
|
||||
PMTUD PMTUD `json:"pmtud"`
|
||||
}
|
||||
|
||||
// TODO v4 remove pointer for receiver (because of Surfshark).
|
||||
@@ -46,11 +45,6 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo
|
||||
}
|
||||
}
|
||||
|
||||
err = v.PMTUD.validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("PMTUD settings: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -60,7 +54,6 @@ func (v *VPN) Copy() (copied VPN) {
|
||||
Provider: v.Provider.copy(),
|
||||
OpenVPN: v.OpenVPN.copy(),
|
||||
Wireguard: v.Wireguard.copy(),
|
||||
PMTUD: v.PMTUD.copy(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,7 +62,6 @@ func (v *VPN) OverrideWith(other VPN) {
|
||||
v.Provider.overrideWith(other.Provider)
|
||||
v.OpenVPN.overrideWith(other.OpenVPN)
|
||||
v.Wireguard.overrideWith(other.Wireguard)
|
||||
v.PMTUD.overrideWith(other.PMTUD)
|
||||
}
|
||||
|
||||
func (v *VPN) setDefaults() {
|
||||
@@ -77,7 +69,6 @@ func (v *VPN) setDefaults() {
|
||||
v.Provider.setDefaults()
|
||||
v.OpenVPN.setDefaults(v.Provider.Name)
|
||||
v.Wireguard.setDefaults(v.Provider.Name)
|
||||
v.PMTUD.setDefaults()
|
||||
}
|
||||
|
||||
func (v VPN) String() string {
|
||||
@@ -94,7 +85,6 @@ func (v VPN) toLinesNode() (node *gotree.Node) {
|
||||
} else {
|
||||
node.AppendNode(v.Wireguard.toLinesNode())
|
||||
}
|
||||
node.AppendNode(v.PMTUD.toLinesNode())
|
||||
|
||||
return node
|
||||
}
|
||||
@@ -117,10 +107,5 @@ func (v *VPN) read(r *reader.Reader) (err error) {
|
||||
return fmt.Errorf("wireguard: %w", err)
|
||||
}
|
||||
|
||||
err = v.PMTUD.read(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("PMTUD: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -38,9 +38,15 @@ type Wireguard struct {
|
||||
Interface string `json:"interface"`
|
||||
PersistentKeepaliveInterval *time.Duration `json:"persistent_keep_alive_interval"`
|
||||
// Maximum Transmission Unit (MTU) of the Wireguard interface.
|
||||
// It cannot be nil in the internal state, and defaults to
|
||||
// 0 indicating to use PMTUD.
|
||||
MTU *uint32 `json:"mtu"`
|
||||
// It cannot be zero in the internal state, and defaults to
|
||||
// 1320. Note it is not the wireguard-go MTU default of 1420
|
||||
// because this impacts bandwidth a lot on some VPN providers,
|
||||
// see https://github.com/qdm12/gluetun/issues/1650.
|
||||
// It has been lowered to 1320 following quite a bit of
|
||||
// investigation in the issue:
|
||||
// https://github.com/qdm12/gluetun/issues/2533.
|
||||
// Note this should now be replaced with the PMTUD feature.
|
||||
MTU uint32 `json:"mtu"`
|
||||
// Implementation is the Wireguard implementation to use.
|
||||
// It can be "auto", "userspace" or "kernelspace".
|
||||
// It defaults to "auto" and cannot be the empty string
|
||||
@@ -189,7 +195,8 @@ func (w *Wireguard) setDefaults(vpnProvider string) {
|
||||
w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs)
|
||||
w.PersistentKeepaliveInterval = gosettings.DefaultPointer(w.PersistentKeepaliveInterval, 0)
|
||||
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
|
||||
w.MTU = gosettings.DefaultPointer(w.MTU, 0)
|
||||
const defaultMTU = 1320
|
||||
w.MTU = gosettings.DefaultComparable(w.MTU, defaultMTU)
|
||||
w.Implementation = gosettings.DefaultComparable(w.Implementation, "auto")
|
||||
}
|
||||
|
||||
@@ -225,11 +232,7 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
|
||||
}
|
||||
|
||||
interfaceNode := node.Appendf("Network interface: %s", w.Interface)
|
||||
if *w.MTU == 0 {
|
||||
interfaceNode.Append("MTU: use path MTU discovery")
|
||||
} else {
|
||||
interfaceNode.Appendf("MTU: %d", *w.MTU)
|
||||
}
|
||||
interfaceNode.Appendf("MTU: %d", w.MTU)
|
||||
|
||||
if w.Implementation != "auto" {
|
||||
node.Appendf("Implementation: %s", w.Implementation)
|
||||
@@ -270,9 +273,11 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
w.MTU, err = r.Uint32Ptr("WIREGUARD_MTU")
|
||||
mtuPtr, err := r.Uint32Ptr("WIREGUARD_MTU")
|
||||
if err != nil {
|
||||
return err
|
||||
} else if mtuPtr != nil {
|
||||
w.MTU = *mtuPtr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Info(s string)
|
||||
Warn(s string)
|
||||
Error(s string)
|
||||
}
|
||||
|
||||
type Firewall interface {
|
||||
RestrictOutputAddrPort(ctx context.Context, addrPort netip.AddrPort) (err error)
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package dns
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Info(s string)
|
||||
Warn(s string)
|
||||
Error(s string)
|
||||
}
|
||||
@@ -24,6 +24,7 @@ type Loop struct {
|
||||
localResolvers []netip.Addr
|
||||
resolvConf string
|
||||
client *http.Client
|
||||
firewall Firewall
|
||||
logger Logger
|
||||
userTrigger bool
|
||||
start <-chan struct{}
|
||||
@@ -39,7 +40,7 @@ type Loop struct {
|
||||
const defaultBackoffTime = 10 * time.Second
|
||||
|
||||
func NewLoop(settings settings.DNS,
|
||||
client *http.Client, logger Logger,
|
||||
client *http.Client, firewall Firewall, logger Logger,
|
||||
) (loop *Loop, err error) {
|
||||
start := make(chan struct{})
|
||||
running := make(chan models.LoopStatus)
|
||||
@@ -64,6 +65,7 @@ func NewLoop(settings settings.DNS,
|
||||
filter: filter,
|
||||
resolvConf: "/etc/resolv.conf",
|
||||
client: client,
|
||||
firewall: firewall,
|
||||
logger: logger,
|
||||
userTrigger: true,
|
||||
start: start,
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/dns/v2/pkg/nameserver"
|
||||
)
|
||||
|
||||
func (l *Loop) useUnencryptedDNS(fallback bool) {
|
||||
func (l *Loop) useUnencryptedDNS(ctx context.Context, fallback bool) {
|
||||
settings := l.GetSettings()
|
||||
|
||||
targetIP := settings.GetFirstPlaintextIPv4()
|
||||
@@ -20,8 +21,9 @@ func (l *Loop) useUnencryptedDNS(fallback bool) {
|
||||
|
||||
const dialTimeout = 3 * time.Second
|
||||
const defaultDNSPort = 53
|
||||
addrPort := netip.AddrPortFrom(targetIP, defaultDNSPort)
|
||||
settingsInternalDNS := nameserver.SettingsInternalDNS{
|
||||
AddrPort: netip.AddrPortFrom(targetIP, defaultDNSPort),
|
||||
AddrPort: addrPort,
|
||||
Timeout: dialTimeout,
|
||||
}
|
||||
nameserver.UseDNSInternally(settingsInternalDNS)
|
||||
@@ -34,4 +36,9 @@ func (l *Loop) useUnencryptedDNS(fallback bool) {
|
||||
if err != nil {
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
|
||||
err = l.firewall.RestrictOutputAddrPort(ctx, addrPort)
|
||||
if err != nil {
|
||||
l.logger.Error("restricting plain DNS traffic to " + targetIP.String() + ": " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
+11
-10
@@ -2,6 +2,7 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/qdm12/dns/v2/pkg/nameserver"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
@@ -23,7 +24,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
"and go through your container network DNS outside the VPN tunnel!")
|
||||
} else {
|
||||
const fallback = false
|
||||
l.useUnencryptedDNS(fallback)
|
||||
l.useUnencryptedDNS(ctx, fallback)
|
||||
}
|
||||
|
||||
select {
|
||||
@@ -43,12 +44,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
runError, err = l.setupServer(ctx)
|
||||
if err == nil {
|
||||
l.backoffTime = defaultBackoffTime
|
||||
l.logger.Info("ready and using DNS server at address " + settings.ServerAddress.String())
|
||||
|
||||
err = l.updateFiles(ctx, settings)
|
||||
if err != nil {
|
||||
l.logger.Warn("downloading block lists failed, skipping: " + err.Error())
|
||||
}
|
||||
l.logger.Info("ready")
|
||||
break
|
||||
}
|
||||
|
||||
@@ -57,6 +53,11 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !errors.Is(err, errUpdateBlockLists) {
|
||||
const fallback = true
|
||||
l.useUnencryptedDNS(ctx, fallback)
|
||||
}
|
||||
l.logAndWait(ctx, err)
|
||||
settings = l.GetSettings()
|
||||
}
|
||||
@@ -65,7 +66,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
settings = l.GetSettings()
|
||||
if !*settings.KeepNameserver && !*settings.ServerEnabled {
|
||||
const fallback = false
|
||||
l.useUnencryptedDNS(fallback)
|
||||
l.useUnencryptedDNS(ctx, fallback)
|
||||
}
|
||||
|
||||
l.userTrigger = false
|
||||
@@ -93,7 +94,7 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
|
||||
settings := l.GetSettings()
|
||||
if !*settings.KeepNameserver && *settings.ServerEnabled {
|
||||
const fallback = false
|
||||
l.useUnencryptedDNS(fallback)
|
||||
l.useUnencryptedDNS(ctx, fallback)
|
||||
l.stopServer()
|
||||
}
|
||||
l.stopped <- struct{}{}
|
||||
@@ -104,7 +105,7 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
|
||||
case err := <-runError: // unexpected error
|
||||
l.statusManager.SetStatus(constants.Crashed)
|
||||
const fallback = true
|
||||
l.useUnencryptedDNS(fallback)
|
||||
l.useUnencryptedDNS(ctx, fallback)
|
||||
l.logAndWait(ctx, err)
|
||||
return false
|
||||
}
|
||||
|
||||
+14
-7
@@ -2,24 +2,25 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/dns/v2/pkg/check"
|
||||
"github.com/qdm12/dns/v2/pkg/middlewares/filter/update"
|
||||
"github.com/qdm12/dns/v2/pkg/nameserver"
|
||||
"github.com/qdm12/dns/v2/pkg/server"
|
||||
)
|
||||
|
||||
var errUpdateBlockLists = errors.New("cannot update filter block lists")
|
||||
|
||||
func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err error) {
|
||||
settings := l.GetSettings()
|
||||
var updateSettings update.Settings
|
||||
updateSettings.SetRebindingProtectionExempt(settings.Blacklist.RebindingProtectionExemptHostnames)
|
||||
err = l.filter.Update(updateSettings)
|
||||
err = l.updateFiles(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("updating filter for rebinding protection: %w", err)
|
||||
return nil, fmt.Errorf("%w: %w", errUpdateBlockLists, err)
|
||||
}
|
||||
|
||||
settings := l.GetSettings()
|
||||
|
||||
serverSettings, err := buildServerSettings(settings, l.filter, l.localResolvers, l.logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("building server settings: %w", err)
|
||||
@@ -38,8 +39,9 @@ func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err erro
|
||||
|
||||
// use internal DNS server
|
||||
const defaultDNSPort = 53
|
||||
addrPort := netip.AddrPortFrom(settings.ServerAddress, defaultDNSPort)
|
||||
nameserver.UseDNSInternally(nameserver.SettingsInternalDNS{
|
||||
AddrPort: netip.AddrPortFrom(settings.ServerAddress, defaultDNSPort),
|
||||
AddrPort: addrPort,
|
||||
})
|
||||
err = nameserver.UseDNSSystemWide(nameserver.SettingsSystemDNS{
|
||||
IPs: []netip.Addr{settings.ServerAddress},
|
||||
@@ -49,6 +51,11 @@ func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err erro
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
|
||||
err = l.firewall.RestrictOutputAddrPort(ctx, addrPort)
|
||||
if err != nil {
|
||||
l.logger.Error("restricting plain DNS traffic to " + addrPort.Addr().String() + ": " + err.Error())
|
||||
}
|
||||
|
||||
err = check.WaitForDNS(ctx, check.Settings{})
|
||||
if err != nil {
|
||||
l.stopServer()
|
||||
|
||||
+15
-4
@@ -28,12 +28,23 @@ func (l *Loop) RunRestartTicker(ctx context.Context, done chan<- struct{}) {
|
||||
return
|
||||
case <-timer.C:
|
||||
lastTick = l.timeNow()
|
||||
settings := l.GetSettings()
|
||||
if l.GetStatus() == constants.Running {
|
||||
if err := l.updateFiles(ctx, settings); err != nil {
|
||||
l.logger.Warn("updating block lists failed, skipping: " + err.Error())
|
||||
|
||||
status := l.GetStatus()
|
||||
if status == constants.Running {
|
||||
if err := l.updateFiles(ctx); err != nil {
|
||||
l.statusManager.SetStatus(constants.Crashed)
|
||||
l.logger.Error(err.Error())
|
||||
l.logger.Warn("skipping DNS server restart due to failed files update")
|
||||
settings := l.GetSettings()
|
||||
timer.Reset(*settings.UpdatePeriod)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = l.statusManager.ApplyStatus(ctx, constants.Stopped)
|
||||
_, _ = l.statusManager.ApplyStatus(ctx, constants.Running)
|
||||
|
||||
settings := l.GetSettings()
|
||||
timer.Reset(*settings.UpdatePeriod)
|
||||
case <-l.updateTicker:
|
||||
if !timer.Stop() {
|
||||
|
||||
@@ -6,10 +6,11 @@ import (
|
||||
|
||||
"github.com/qdm12/dns/v2/pkg/blockbuilder"
|
||||
"github.com/qdm12/dns/v2/pkg/middlewares/filter/update"
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
)
|
||||
|
||||
func (l *Loop) updateFiles(ctx context.Context, settings settings.DNS) (err error) {
|
||||
func (l *Loop) updateFiles(ctx context.Context) (err error) {
|
||||
settings := l.GetSettings()
|
||||
|
||||
l.logger.Info("downloading hostnames and IP block lists")
|
||||
blacklistSettings := settings.Blacklist.ToBlockBuilderSettings(l.client)
|
||||
|
||||
@@ -36,6 +37,7 @@ func (l *Loop) updateFiles(ctx context.Context, settings settings.DNS) (err erro
|
||||
IPPrefixes: result.BlockedIPPrefixes,
|
||||
}
|
||||
updateSettings.BlockHostnames(result.BlockedHostnames)
|
||||
updateSettings.SetRebindingProtectionExempt(settings.Blacklist.RebindingProtectionExemptHostnames)
|
||||
err = l.filter.Update(updateSettings)
|
||||
if err != nil {
|
||||
return fmt.Errorf("updating filter: %w", err)
|
||||
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
+58
-28
@@ -22,7 +22,9 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
|
||||
|
||||
if !enabled {
|
||||
c.logger.Info("disabling...")
|
||||
c.restore(ctx)
|
||||
if err = c.disable(ctx); err != nil {
|
||||
return fmt.Errorf("disabling firewall: %w", err)
|
||||
}
|
||||
c.enabled = false
|
||||
c.logger.Info("disabled successfully")
|
||||
return nil
|
||||
@@ -39,33 +41,64 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) enable(ctx context.Context) (err error) {
|
||||
c.restore, err = c.impl.SaveAndRestore(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("saving firewall rules: %w", err)
|
||||
func (c *Config) disable(ctx context.Context) (err error) {
|
||||
if err = c.clearAllRules(ctx); err != nil {
|
||||
return fmt.Errorf("clearing all rules: %w", err)
|
||||
}
|
||||
if err = c.setIPv4AllPolicies(ctx, "ACCEPT"); err != nil {
|
||||
return fmt.Errorf("setting ipv4 policies: %w", err)
|
||||
}
|
||||
if err = c.setIPv6AllPolicies(ctx, "ACCEPT"); err != nil {
|
||||
return fmt.Errorf("setting ipv6 policies: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
c.restore(context.Background())
|
||||
}
|
||||
}()
|
||||
const remove = true
|
||||
err = c.redirectPorts(ctx, remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing port redirections: %w", err)
|
||||
}
|
||||
|
||||
if err = c.impl.SetBaseChainsPolicy(ctx, "DROP"); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// To use in defered call when enabling the firewall.
|
||||
func (c *Config) fallbackToDisabled(ctx context.Context) {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if err := c.disable(ctx); err != nil {
|
||||
c.logger.Error("failed reversing firewall changes: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) enable(ctx context.Context) (err error) {
|
||||
touched := false
|
||||
if err = c.setIPv4AllPolicies(ctx, "DROP"); err != nil {
|
||||
return err
|
||||
}
|
||||
touched = true
|
||||
|
||||
// Loopback traffic
|
||||
if err = c.impl.AcceptInputThroughInterface(ctx, "lo"); err != nil {
|
||||
if err = c.setIPv6AllPolicies(ctx, "DROP"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
const remove = false
|
||||
if err = c.impl.AcceptOutputThroughInterface(ctx, "lo", remove); err != nil {
|
||||
|
||||
defer func() {
|
||||
if touched && err != nil {
|
||||
c.fallbackToDisabled(ctx)
|
||||
}
|
||||
}()
|
||||
|
||||
// Loopback traffic
|
||||
if err = c.acceptInputThroughInterface(ctx, "lo", remove); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = c.acceptOutputThroughInterface(ctx, "lo", remove); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = c.impl.AcceptEstablishedRelatedTraffic(ctx); err != nil {
|
||||
if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -75,9 +108,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
||||
|
||||
localInterfaces := make(map[string]struct{}, len(c.localNetworks))
|
||||
for _, network := range c.localNetworks {
|
||||
err = c.impl.AcceptOutputFromIPToSubnet(ctx,
|
||||
network.InterfaceName, network.IP, network.IPNet, remove)
|
||||
if err != nil {
|
||||
if err := c.acceptOutputFromIPToSubnet(ctx, network.InterfaceName, network.IP, network.IPNet, remove); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -86,7 +117,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
||||
continue
|
||||
}
|
||||
localInterfaces[network.InterfaceName] = struct{}{}
|
||||
err = c.impl.AcceptIpv6MulticastOutput(ctx, network.InterfaceName)
|
||||
err = c.acceptIpv6MulticastOutput(ctx, network.InterfaceName, remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("accepting IPv6 multicast output: %w", err)
|
||||
}
|
||||
@@ -99,7 +130,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
||||
// Allows packets from any IP address to go through eth0 / local network
|
||||
// to reach Gluetun.
|
||||
for _, network := range c.localNetworks {
|
||||
if err := c.impl.AcceptInputToSubnet(ctx, network.InterfaceName, network.IPNet); err != nil {
|
||||
if err := c.acceptInputToSubnet(ctx, network.InterfaceName, network.IPNet, remove); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -108,12 +139,12 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.redirectPorts(ctx)
|
||||
err = c.redirectPorts(ctx, remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("redirecting ports: %w", err)
|
||||
}
|
||||
|
||||
if err := c.impl.RunUserPostRules(ctx, c.customRulesPath); err != nil {
|
||||
if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil {
|
||||
return fmt.Errorf("running user defined post firewall rules: %w", err)
|
||||
}
|
||||
|
||||
@@ -133,7 +164,7 @@ func (c *Config) allowVPNIP(ctx context.Context) (err error) {
|
||||
continue
|
||||
}
|
||||
interfacesSeen[defaultRoute.NetInterface] = struct{}{}
|
||||
err = c.impl.AcceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove)
|
||||
err = c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("accepting output traffic through VPN: %w", err)
|
||||
}
|
||||
@@ -155,7 +186,7 @@ func (c *Config) allowOutboundSubnets(ctx context.Context) (err error) {
|
||||
firewallUpdated = true
|
||||
|
||||
const remove = false
|
||||
err := c.impl.AcceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
|
||||
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
|
||||
defaultRoute.AssignedIP, subnet, remove)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -173,7 +204,7 @@ func (c *Config) allowInputPorts(ctx context.Context) (err error) {
|
||||
for port, netInterfaces := range c.allowedInputPorts {
|
||||
for netInterface := range netInterfaces {
|
||||
const remove = false
|
||||
err = c.impl.AcceptInputToPort(ctx, netInterface, port, remove)
|
||||
err = c.acceptInputToPort(ctx, netInterface, port, remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("accepting input port %d on interface %s: %w",
|
||||
port, netInterface, err)
|
||||
@@ -183,10 +214,9 @@ func (c *Config) allowInputPorts(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) redirectPorts(ctx context.Context) (err error) {
|
||||
func (c *Config) redirectPorts(ctx context.Context, remove bool) (err error) {
|
||||
for _, portRedirection := range c.portRedirections {
|
||||
const remove = false
|
||||
err = c.impl.RedirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort,
|
||||
err = c.redirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort,
|
||||
portRedirection.destinationPort, remove)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -2,33 +2,34 @@ package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall/iptables"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/routing"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
runner CmdRunner
|
||||
logger Logger
|
||||
defaultRoutes []routing.DefaultRoute
|
||||
localNetworks []routing.LocalNetwork
|
||||
runner CmdRunner
|
||||
logger Logger
|
||||
iptablesMutex sync.Mutex
|
||||
ip6tablesMutex sync.Mutex
|
||||
defaultRoutes []routing.DefaultRoute
|
||||
localNetworks []routing.LocalNetwork
|
||||
|
||||
// Fixed
|
||||
impl firewallImpl
|
||||
// Fixed state
|
||||
ipTables string
|
||||
ip6Tables string
|
||||
customRulesPath string
|
||||
|
||||
// State
|
||||
enabled bool
|
||||
restore func(context.Context)
|
||||
vpnConnection models.Connection
|
||||
vpnIntf string
|
||||
outboundSubnets []netip.Prefix
|
||||
allowedInputPorts map[uint16]map[string]struct{} // port to interfaces set mapping
|
||||
portRedirections portRedirections
|
||||
outputAddrPort map[uint16]netip.Addr
|
||||
stateMutex sync.Mutex
|
||||
}
|
||||
|
||||
@@ -38,19 +39,26 @@ func NewConfig(ctx context.Context, logger Logger,
|
||||
runner CmdRunner, defaultRoutes []routing.DefaultRoute,
|
||||
localNetworks []routing.LocalNetwork,
|
||||
) (config *Config, err error) {
|
||||
impl, err := iptables.New(ctx, runner, logger)
|
||||
iptables, err := checkIptablesSupport(ctx, runner, "iptables", "iptables-nft", "iptables-legacy")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating iptables firewall: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ip6tables, err := findIP6tablesSupported(ctx, runner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Config{
|
||||
runner: runner,
|
||||
logger: logger,
|
||||
allowedInputPorts: make(map[uint16]map[string]struct{}),
|
||||
outputAddrPort: make(map[uint16]netip.Addr),
|
||||
ipTables: iptables,
|
||||
ip6Tables: ip6tables,
|
||||
customRulesPath: "/iptables/post-rules.txt",
|
||||
// Obtained from routing
|
||||
defaultRoutes: defaultRoutes,
|
||||
localNetworks: localNetworks,
|
||||
impl: impl,
|
||||
customRulesPath: "/iptables/post-rules.txt",
|
||||
defaultRoutes: defaultRoutes,
|
||||
localNetworks: localNetworks,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1,12 +1,6 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
import "os/exec"
|
||||
|
||||
type CmdRunner interface {
|
||||
Run(cmd *exec.Cmd) (output string, err error)
|
||||
@@ -18,24 +12,3 @@ type Logger interface {
|
||||
Warn(s string)
|
||||
Error(s string)
|
||||
}
|
||||
|
||||
type firewallImpl interface { //nolint:interfacebloat
|
||||
SaveAndRestore(ctx context.Context) (restore func(context.Context), err error)
|
||||
AcceptEstablishedRelatedTraffic(ctx context.Context) error
|
||||
AcceptInputThroughInterface(ctx context.Context, intf string) error
|
||||
AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error
|
||||
AcceptInputToSubnet(ctx context.Context, intf string, subnet netip.Prefix) error
|
||||
AcceptIpv6MulticastOutput(ctx context.Context, intf string) error
|
||||
AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr,
|
||||
subnet netip.Prefix, remove bool) error
|
||||
AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error
|
||||
AcceptOutputTrafficToVPN(ctx context.Context, intf string,
|
||||
connection models.Connection, remove bool) error
|
||||
RedirectPort(ctx context.Context, intf string, sourcePort,
|
||||
destinationPort uint16, remove bool) error
|
||||
RunUserPostRules(ctx context.Context, customRulesPath string) error
|
||||
SetBaseChainsPolicy(ctx context.Context, policy string) error
|
||||
TempDropOutputTCPRST(ctx context.Context, src, dst netip.AddrPort, excludeMark int) (
|
||||
revert func(ctx context.Context) error, err error)
|
||||
Version(ctx context.Context) (version string, err error)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -14,8 +14,8 @@ import (
|
||||
func findIP6tablesSupported(ctx context.Context, runner CmdRunner) (
|
||||
ip6tablesPath string, err error,
|
||||
) {
|
||||
ip6tablesPath, err = checkIptablesSupport(ctx, runner, "ip6tables", "ip6tables-legacy")
|
||||
if errors.Is(err, ErrNotSupported) {
|
||||
ip6tablesPath, err = checkIptablesSupport(ctx, runner, "ip6tables", "ip6tables-nft", "ip6tables-legacy")
|
||||
if errors.Is(err, ErrIPTablesNotSupported) {
|
||||
return "", nil
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
@@ -24,23 +24,8 @@ func findIP6tablesSupported(ctx context.Context, runner CmdRunner) (
|
||||
}
|
||||
|
||||
func (c *Config) runIP6tablesInstructions(ctx context.Context, instructions []string) error {
|
||||
c.ip6tablesMutex.Lock() // only one ip6tables command at once
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestoreIPv6(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.runIP6tablesInstructionsNoSave(ctx, instructions)
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Config) runIP6tablesInstructionsNoSave(ctx context.Context, instructions []string) error {
|
||||
for _, instruction := range instructions {
|
||||
if err := c.runIP6tablesInstructionNoSave(ctx, instruction); err != nil {
|
||||
if err := c.runIP6tablesInstruction(ctx, instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -48,24 +33,11 @@ func (c *Config) runIP6tablesInstructionsNoSave(ctx context.Context, instruction
|
||||
}
|
||||
|
||||
func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string) error {
|
||||
c.ip6tablesMutex.Lock() // only one ip6tables command at once
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestoreIPv6(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.runIP6tablesInstructionNoSave(ctx, instruction)
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction string) error {
|
||||
if c.ip6Tables == "" {
|
||||
return nil
|
||||
}
|
||||
c.ip6tablesMutex.Lock() // only one ip6tables command at once
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
if isDeleteMatchInstruction(instruction) {
|
||||
return deleteIPTablesRule(ctx, c.ip6Tables, instruction,
|
||||
@@ -81,3 +53,18 @@ func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrPolicyNotValid = errors.New("policy is not valid")
|
||||
|
||||
func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
|
||||
}
|
||||
return c.runIP6tablesInstructions(ctx, []string{
|
||||
"--policy INPUT " + policy,
|
||||
"--policy OUTPUT " + policy,
|
||||
"--policy FORWARD " + policy,
|
||||
})
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -26,6 +26,22 @@ func appendOrDelete(remove bool) string {
|
||||
return "--append"
|
||||
}
|
||||
|
||||
// flipRule changes an append rule in a delete rule or a delete rule into an
|
||||
// append rule.
|
||||
func flipRule(rule string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(rule, "-A"):
|
||||
return strings.Replace(rule, "-A", "-D", 1)
|
||||
case strings.HasPrefix(rule, "--append"):
|
||||
return strings.Replace(rule, "--append", "-D", 1)
|
||||
case strings.HasPrefix(rule, "-D"):
|
||||
return strings.Replace(rule, "-D", "-A", 1)
|
||||
case strings.HasPrefix(rule, "--delete"):
|
||||
return strings.Replace(rule, "--delete", "-A", 1)
|
||||
}
|
||||
return rule
|
||||
}
|
||||
|
||||
// Version obtains the version of the installed iptables.
|
||||
func (c *Config) Version(ctx context.Context) (string, error) {
|
||||
cmd := exec.CommandContext(ctx, c.ipTables, "--version") //nolint:gosec
|
||||
@@ -38,28 +54,12 @@ func (c *Config) Version(ctx context.Context) (string, error) {
|
||||
if len(words) < minWords {
|
||||
return "", fmt.Errorf("%w: %s", ErrIPTablesVersionTooShort, output)
|
||||
}
|
||||
return "iptables " + words[1], nil
|
||||
return words[1], nil
|
||||
}
|
||||
|
||||
func (c *Config) runIptablesInstructions(ctx context.Context, instructions []string) error {
|
||||
c.iptablesMutex.Lock()
|
||||
defer c.iptablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestoreIPv4(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.runIptablesInstructionsNoSave(ctx, instructions)
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Config) runIptablesInstructionsNoSave(ctx context.Context, instructions []string) error {
|
||||
for _, instruction := range instructions {
|
||||
if err := c.runIptablesInstructionNoSave(ctx, instruction); err != nil {
|
||||
if err := c.runIptablesInstruction(ctx, instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -70,19 +70,6 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
|
||||
c.iptablesMutex.Lock() // only one iptables command at once
|
||||
defer c.iptablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestoreIPv4(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.runIptablesInstructionNoSave(ctx, instruction)
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Config) runIptablesInstructionNoSave(ctx context.Context, instruction string) error {
|
||||
if isDeleteMatchInstruction(instruction) {
|
||||
return deleteIPTablesRule(ctx, c.ipTables, instruction,
|
||||
c.runner, c.logger)
|
||||
@@ -98,33 +85,42 @@ func (c *Config) runIptablesInstructionNoSave(ctx context.Context, instruction s
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) SetBaseChainsPolicy(ctx context.Context, policy string) error {
|
||||
policy = strings.ToUpper(policy)
|
||||
func (c *Config) clearAllRules(ctx context.Context) error {
|
||||
return c.runMixedIptablesInstructions(ctx, []string{
|
||||
"--flush", // flush all chains
|
||||
"--delete-chain", // delete all chains
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Config) setIPv4AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
|
||||
}
|
||||
return c.runMixedIptablesInstructions(ctx, []string{
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
"--policy INPUT " + policy,
|
||||
"--policy OUTPUT " + policy,
|
||||
"--policy FORWARD " + policy,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Config) AcceptInputThroughInterface(ctx context.Context, intf string) error {
|
||||
func (c *Config) acceptInputThroughInterface(ctx context.Context, intf string, remove bool) error {
|
||||
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
|
||||
"--append INPUT -i %s -j ACCEPT", intf))
|
||||
"%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf,
|
||||
))
|
||||
}
|
||||
|
||||
func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, destination netip.Prefix) error {
|
||||
func (c *Config) acceptInputToSubnet(ctx context.Context, intf string,
|
||||
destination netip.Prefix, remove bool,
|
||||
) error {
|
||||
interfaceFlag := "-i " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
|
||||
instruction := fmt.Sprintf("--append INPUT %s -d %s -j ACCEPT",
|
||||
interfaceFlag, destination.String())
|
||||
instruction := fmt.Sprintf("%s INPUT %s -d %s -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destination.String())
|
||||
|
||||
if destination.Addr().Is4() {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
@@ -135,20 +131,20 @@ func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, destinati
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
func (c *Config) AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error {
|
||||
func (c *Config) acceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error {
|
||||
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
|
||||
"%s OUTPUT -o %s -j ACCEPT", appendOrDelete(remove), intf,
|
||||
))
|
||||
}
|
||||
|
||||
func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context) error {
|
||||
func (c *Config) acceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error {
|
||||
return c.runMixedIptablesInstructions(ctx, []string{
|
||||
"--append OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
|
||||
"--append INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
|
||||
fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
|
||||
fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context,
|
||||
func (c *Config) acceptOutputTrafficToVPN(ctx context.Context,
|
||||
defaultInterface string, connection models.Connection, remove bool,
|
||||
) error {
|
||||
protocol := connection.Protocol
|
||||
@@ -166,11 +162,8 @@ func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context,
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
// AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet
|
||||
// on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
|
||||
// If remove is true, the rule is removed instead of added.
|
||||
// Thanks to @npawelek.
|
||||
func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
|
||||
func (c *Config) acceptOutputFromIPToSubnet(ctx context.Context,
|
||||
intf string, sourceIP netip.Addr, destinationSubnet netip.Prefix, remove bool,
|
||||
) error {
|
||||
doIPv4 := sourceIP.Is4() && destinationSubnet.Addr().Is4()
|
||||
@@ -191,24 +184,21 @@ func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
// AcceptIpv6MulticastOutput accepts outgoing traffic to the IPv6 multicast address
|
||||
// ff02::1:ff00:0/104, which is used for NDP (Neighbor Discovery Protocol) to resolve
|
||||
// IPv6 addresses to MAC addresses. If intf is empty, it is set to "*" which means
|
||||
// all interfaces. If remove is true, the rule is removed instead of added.
|
||||
func (c *Config) AcceptIpv6MulticastOutput(ctx context.Context, intf string) error {
|
||||
// NDP uses multicast address (theres no broadcast in IPv6 like ARP uses in IPv4).
|
||||
func (c *Config) acceptIpv6MulticastOutput(ctx context.Context,
|
||||
intf string, remove bool,
|
||||
) error {
|
||||
interfaceFlag := "-o " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
instruction := fmt.Sprintf("--append OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT", interfaceFlag)
|
||||
instruction := fmt.Sprintf("%s OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag)
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
// AcceptInputToPort accepts incoming traffic on the specified port, for both TCP and UDP
|
||||
// protocols, on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
|
||||
// If remove is true, the rule is removed instead of added. This is used for port forwarding, with
|
||||
// intf set to the VPN tunnel interface.
|
||||
func (c *Config) AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error {
|
||||
// Used for port forwarding, with intf set to tun.
|
||||
func (c *Config) acceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error {
|
||||
interfaceFlag := "-i " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
@@ -219,12 +209,8 @@ func (c *Config) AcceptInputToPort(ctx context.Context, intf string, port uint16
|
||||
})
|
||||
}
|
||||
|
||||
// RedirectPort redirects incoming traffic on the specified source port to the
|
||||
// specified destination port, for both TCP and UDP protocols, on the interface intf.
|
||||
// If intf is empty, it is set to "*" which means all interfaces. If remove is true,
|
||||
// the redirection is removed instead of added. This is used for VPN server side
|
||||
// port forwarding, with intf set to the VPN tunnel interface.
|
||||
func (c *Config) RedirectPort(ctx context.Context, intf string,
|
||||
// Used for VPN server side port forwarding, with intf set to the VPN tunnel interface.
|
||||
func (c *Config) redirectPort(ctx context.Context, intf string,
|
||||
sourcePort, destinationPort uint16, remove bool,
|
||||
) (err error) {
|
||||
interfaceFlag := "-i " + intf
|
||||
@@ -232,17 +218,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
|
||||
interfaceFlag = ""
|
||||
}
|
||||
|
||||
c.iptablesMutex.Lock()
|
||||
c.ip6tablesMutex.Lock()
|
||||
defer c.iptablesMutex.Unlock()
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestore(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.runIptablesInstructionsNoSave(ctx, []string{
|
||||
err = c.runIptablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
|
||||
@@ -253,12 +229,11 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
})
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
return fmt.Errorf("redirecting IPv4 source port %d to destination port %d on interface %s: %w",
|
||||
sourcePort, destinationPort, intf, err)
|
||||
}
|
||||
|
||||
err = c.runIP6tablesInstructionsNoSave(ctx, []string{
|
||||
err = c.runIP6tablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
|
||||
@@ -269,7 +244,6 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
})
|
||||
if err != nil {
|
||||
restore(ctx) // just in case
|
||||
errMessage := err.Error()
|
||||
if strings.Contains(errMessage, "can't initialize ip6tables table `nat': Table does not exist") {
|
||||
if !remove {
|
||||
@@ -283,7 +257,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
|
||||
func (c *Config) runUserPostRules(ctx context.Context, filepath string, remove bool) error {
|
||||
file, err := os.OpenFile(filepath, os.O_RDONLY, 0)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
@@ -299,17 +273,16 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
|
||||
return err
|
||||
}
|
||||
lines := strings.Split(string(b), "\n")
|
||||
|
||||
c.iptablesMutex.Lock()
|
||||
c.ip6tablesMutex.Lock()
|
||||
defer c.iptablesMutex.Unlock()
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestore(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
successfulRules := []string{}
|
||||
defer func() {
|
||||
// transaction-like rollback
|
||||
if err == nil || ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
for _, rule := range successfulRules {
|
||||
_ = c.runIptablesInstruction(ctx, flipRule(rule))
|
||||
}
|
||||
}()
|
||||
for _, line := range lines {
|
||||
var ipv4 bool
|
||||
var rule string
|
||||
@@ -336,18 +309,23 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
|
||||
continue
|
||||
}
|
||||
|
||||
if remove {
|
||||
rule = flipRule(rule)
|
||||
}
|
||||
|
||||
switch {
|
||||
case ipv4:
|
||||
err = c.runIptablesInstructionNoSave(ctx, rule)
|
||||
err = c.runIptablesInstruction(ctx, rule)
|
||||
case c.ip6Tables == "":
|
||||
err = fmt.Errorf("running user ip6tables rule: %w", ErrNeedIP6Tables)
|
||||
default: // ipv6
|
||||
err = c.runIP6tablesInstructionNoSave(ctx, rule)
|
||||
err = c.runIP6tablesInstruction(ctx, rule)
|
||||
}
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
successfulRules = append(successfulRules, rule)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SaveAndRestore saves the current iptables and ip6tables rules and
|
||||
// returns a restore function that can be called to restore the saved rules.
|
||||
func (c *Config) SaveAndRestore(ctx context.Context) (restore func(context.Context), err error) {
|
||||
c.iptablesMutex.Lock()
|
||||
c.ip6tablesMutex.Lock()
|
||||
defer c.iptablesMutex.Unlock()
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
return c.saveAndRestore(ctx)
|
||||
}
|
||||
|
||||
// callers MUST always lock both the [Config] iptablesMutex and the ip6tablesMutex
|
||||
// before calling this function. Note the restore function does not interact with mutexes
|
||||
// so the caller must make sure the mutexes are locked when calling the restore function.
|
||||
func (c *Config) saveAndRestore(ctx context.Context) (restore func(context.Context), err error) {
|
||||
restoreIPv4, err := c.saveAndRestoreIPv4(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
restoreIPv6, err := c.saveAndRestoreIPv6(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restore = func(ctx context.Context) {
|
||||
restoreIPv4(ctx)
|
||||
if restoreIPv6 != nil {
|
||||
restoreIPv6(ctx)
|
||||
}
|
||||
}
|
||||
return restore, nil
|
||||
}
|
||||
|
||||
// Callers of saveAndRestoreIPv4 MUST always lock the [Config] iptablesMutex
|
||||
// before calling this function.
|
||||
func (c *Config) saveAndRestoreIPv4(ctx context.Context) (restore func(context.Context), err error) {
|
||||
cmd := exec.CommandContext(ctx, c.ipTables+"-save") //nolint:gosec
|
||||
data, err := c.runner.Run(cmd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("saving IPv4 iptables: %w", err)
|
||||
}
|
||||
|
||||
restore = func(ctx context.Context) {
|
||||
cmd := exec.CommandContext(ctx, c.ipTables+"-restore") //nolint:gosec
|
||||
cmd.Stdin = strings.NewReader(data)
|
||||
output, err := c.runner.Run(cmd)
|
||||
if err != nil {
|
||||
c.logger.Warn(fmt.Sprintf("restoring IPv4 iptables failed: %v: %s", err, output))
|
||||
}
|
||||
}
|
||||
return restore, nil
|
||||
}
|
||||
|
||||
// Callers of saveAndRestoreIPv6 MUST always lock the [Config] ip6tablesMutex
|
||||
// before calling this function.
|
||||
func (c *Config) saveAndRestoreIPv6(ctx context.Context) (restore func(context.Context), err error) {
|
||||
if c.ip6Tables == "" {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, c.ip6Tables+"-save") //nolint:gosec
|
||||
data, err := c.runner.Run(cmd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("saving IPv6 iptables: %w", err)
|
||||
}
|
||||
|
||||
restore = func(ctx context.Context) {
|
||||
cmd = exec.CommandContext(ctx, c.ip6Tables+"-restore") //nolint:gosec
|
||||
cmd.Stdin = strings.NewReader(data)
|
||||
output, err := c.runner.Run(cmd)
|
||||
if err != nil {
|
||||
c.logger.Warn(fmt.Sprintf("restoring IPv6 iptables failed: %v: %s", err, output))
|
||||
}
|
||||
}
|
||||
return restore, nil
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/mod"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
runner CmdRunner
|
||||
logger Logger
|
||||
iptablesMutex sync.Mutex
|
||||
ip6tablesMutex sync.Mutex
|
||||
|
||||
// Fixed state
|
||||
ipTables string
|
||||
ip6Tables string
|
||||
nftables bool
|
||||
xtMark bool
|
||||
}
|
||||
|
||||
func New(ctx context.Context, runner CmdRunner, logger Logger) (*Config, error) {
|
||||
iptables, err := checkIptablesSupport(ctx, runner, "iptables", "iptables-nft", "iptables-legacy")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ip6tables, err := findIP6tablesSupported(ctx, runner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modules := map[string]bool{
|
||||
"xt_mark": false,
|
||||
"nf_tables": false,
|
||||
}
|
||||
for module := range modules {
|
||||
err := mod.Probe(module)
|
||||
modules[module] = err == nil
|
||||
}
|
||||
|
||||
return &Config{
|
||||
runner: runner,
|
||||
logger: logger,
|
||||
ipTables: iptables,
|
||||
ip6Tables: ip6tables,
|
||||
nftables: modules["nf_tables"],
|
||||
xtMark: modules["xt_mark"],
|
||||
}, nil
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package iptables
|
||||
|
||||
import "os/exec"
|
||||
|
||||
type CmdRunner interface {
|
||||
Run(cmd *exec.Cmd) (output string, err error)
|
||||
}
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Info(s string)
|
||||
Warn(s string)
|
||||
Error(s string)
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
|
||||
c.iptablesMutex.Lock()
|
||||
c.ip6tablesMutex.Lock()
|
||||
defer c.iptablesMutex.Unlock()
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestore(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, instruction := range instructions {
|
||||
if err := c.runMixedIptablesInstructionNoSave(ctx, instruction); err != nil {
|
||||
restore(ctx)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction string) error {
|
||||
c.iptablesMutex.Lock()
|
||||
c.ip6tablesMutex.Lock()
|
||||
defer c.iptablesMutex.Unlock()
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestore(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.runMixedIptablesInstructionNoSave(ctx, instruction)
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Config) runMixedIptablesInstructionNoSave(ctx context.Context, instruction string) error {
|
||||
if err := c.runIptablesInstructionNoSave(ctx, instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.runIP6tablesInstructionNoSave(ctx, instruction)
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type tcpFlags struct {
|
||||
mask []tcpFlag
|
||||
comparison []tcpFlag
|
||||
}
|
||||
|
||||
type tcpFlag uint8
|
||||
|
||||
const (
|
||||
tcpFlagFIN tcpFlag = 1 << iota
|
||||
tcpFlagSYN
|
||||
tcpFlagRST
|
||||
tcpFlagPSH
|
||||
tcpFlagACK
|
||||
tcpFlagURG
|
||||
tcpFlagECE
|
||||
tcpFlagCWR
|
||||
)
|
||||
|
||||
func (f tcpFlag) String() string {
|
||||
switch f {
|
||||
case tcpFlagFIN:
|
||||
return "FIN"
|
||||
case tcpFlagSYN:
|
||||
return "SYN"
|
||||
case tcpFlagRST:
|
||||
return "RST"
|
||||
case tcpFlagPSH:
|
||||
return "PSH"
|
||||
case tcpFlagACK:
|
||||
return "ACK"
|
||||
case tcpFlagURG:
|
||||
return "URG"
|
||||
case tcpFlagECE:
|
||||
return "ECE"
|
||||
case tcpFlagCWR:
|
||||
return "CWR"
|
||||
default:
|
||||
panic(fmt.Sprintf("%s: %d", errTCPFlagUnknown, f))
|
||||
}
|
||||
}
|
||||
|
||||
var errTCPFlagUnknown = errors.New("unknown TCP flag")
|
||||
|
||||
func parseTCPFlag(s string) (tcpFlag, error) {
|
||||
allFlags := []tcpFlag{
|
||||
tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH,
|
||||
tcpFlagACK, tcpFlagURG, tcpFlagECE, tcpFlagCWR,
|
||||
}
|
||||
for _, flag := range allFlags {
|
||||
if s == fmt.Sprintf("%#02x", uint8(flag)) || s == flag.String() {
|
||||
return flag, nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
|
||||
}
|
||||
|
||||
var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so")
|
||||
|
||||
// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port,
|
||||
// for any TCP packets not marked with the excludeMark given.
|
||||
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
|
||||
// by sending a TCP RST packet, although we want to handle the connection manually.
|
||||
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
|
||||
src, dst netip.AddrPort, excludeMark int) (
|
||||
revert func(ctx context.Context) error, err error,
|
||||
) {
|
||||
if !c.nftables && !c.xtMark {
|
||||
return nil, fmt.Errorf("%w", ErrMarkMatchModuleMissing)
|
||||
}
|
||||
|
||||
const template = "%s OUTPUT -p tcp -s %s --sport %d -d %s --dport %d " +
|
||||
"--tcp-flags RST RST -m mark ! --mark %d -j DROP" //nolint:dupword
|
||||
instruction := fmt.Sprintf(template, "--append", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark)
|
||||
revertInstruction := fmt.Sprintf(template, "--delete", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark)
|
||||
run := c.runIptablesInstruction
|
||||
if dst.Addr().Is6() {
|
||||
run = c.runIP6tablesInstruction
|
||||
}
|
||||
revert = func(ctx context.Context) error {
|
||||
return run(ctx, revertInstruction)
|
||||
}
|
||||
err = run(ctx, instruction)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("running instruction: %w", err)
|
||||
}
|
||||
return revert, nil
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
|
||||
for _, instruction := range instructions {
|
||||
if err := c.runMixedIptablesInstruction(ctx, instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction string) error {
|
||||
if err := c.runIptablesInstruction(ctx, instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
func (c *Config) runIPv4AndV6IptablesInstructions(ctx context.Context,
|
||||
ipv4Instructions, ipv6Instructions []string,
|
||||
) error {
|
||||
if err := c.runIptablesInstructions(ctx, ipv4Instructions); err != nil {
|
||||
return fmt.Errorf("running iptables instructions: %w", err)
|
||||
}
|
||||
if err := c.runIP6tablesInstructions(ctx, ipv6Instructions); err != nil {
|
||||
return fmt.Errorf("running ip6tables instructions: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -26,18 +26,10 @@ type chainRule struct {
|
||||
inputInterface string // input interface, for example "tun0" or "*""
|
||||
outputInterface string // output interface, for example "eth0" or "*""
|
||||
source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||
sourcePort uint16 // Not specified if set to zero.
|
||||
destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||
destinationPort uint16 // Not specified if set to zero.
|
||||
redirPorts []uint16 // Not specified if empty.
|
||||
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
|
||||
tcpFlags tcpFlags
|
||||
mark mark
|
||||
}
|
||||
|
||||
type mark struct {
|
||||
invert bool
|
||||
value uint
|
||||
}
|
||||
|
||||
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
|
||||
@@ -249,23 +241,19 @@ func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err err
|
||||
}
|
||||
|
||||
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
|
||||
i := 0
|
||||
for i < len(optionalFields) {
|
||||
switch optionalFields[i] {
|
||||
case "udp":
|
||||
for i := 0; i < len(optionalFields); i++ {
|
||||
key := optionalFields[i]
|
||||
switch key {
|
||||
case "tcp", "udp":
|
||||
i++
|
||||
consumed, err := parseUDPOptional(optionalFields[i:], rule)
|
||||
value := optionalFields[i]
|
||||
value = strings.TrimPrefix(value, "dpt:")
|
||||
const base, bitLength = 10, 16
|
||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing UDP optional fields: %w", err)
|
||||
return fmt.Errorf("parsing destination port %q: %w", value, err)
|
||||
}
|
||||
i += consumed
|
||||
case "tcp":
|
||||
i++
|
||||
consumed, err := parseTCPOptional(optionalFields[i:], rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing TCP optional fields: %w", err)
|
||||
}
|
||||
i += consumed
|
||||
rule.destinationPort = uint16(destinationPort)
|
||||
case "redir":
|
||||
i++
|
||||
switch optionalFields[i] {
|
||||
@@ -276,136 +264,20 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
|
||||
return fmt.Errorf("parsing redirection ports: %w", err)
|
||||
}
|
||||
rule.redirPorts = ports
|
||||
i++
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected %q after redir",
|
||||
ErrChainRuleMalformed, optionalFields[1])
|
||||
return fmt.Errorf("%w: unexpected optional field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[i])
|
||||
}
|
||||
case "ctstate":
|
||||
i++
|
||||
rule.ctstate = strings.Split(optionalFields[i], ",")
|
||||
i++
|
||||
case "mark":
|
||||
i++
|
||||
mark, consumed, err := parseMark(optionalFields[i:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing mark: %w", err)
|
||||
}
|
||||
rule.mark = mark
|
||||
i += consumed
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected optional field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[i])
|
||||
return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var errUDPOptionalUnknown = errors.New("unknown UDP optional field")
|
||||
|
||||
func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
||||
for _, value := range optionalFields {
|
||||
if !strings.ContainsRune(value, ':') {
|
||||
// no longer a UDP-associated option
|
||||
return consumed, nil
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(value, "dpt:"):
|
||||
rule.destinationPort, err = parseDestinationPort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
consumed++
|
||||
case strings.HasPrefix(value, "spt:"):
|
||||
rule.sourcePort, err = parseSourcePort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing source port: %w", err)
|
||||
}
|
||||
consumed++
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
|
||||
}
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
var errTCPOptionalUnknown = errors.New("unknown TCP optional field")
|
||||
|
||||
func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
||||
for _, value := range optionalFields {
|
||||
if !strings.ContainsRune(value, ':') {
|
||||
// no longer a TCP-associated option
|
||||
return consumed, nil
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(value, "dpt:"):
|
||||
rule.destinationPort, err = parseDestinationPort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
consumed++
|
||||
case strings.HasPrefix(value, "spt:"):
|
||||
rule.sourcePort, err = parseSourcePort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing source port: %w", err)
|
||||
}
|
||||
consumed++
|
||||
case strings.HasPrefix(value, "flags:"):
|
||||
rule.tcpFlags, err = parseTCPFlags(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing TCP flags: %w", err)
|
||||
}
|
||||
consumed++
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %s", errTCPOptionalUnknown, value)
|
||||
}
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
func parseDestinationPort(value string) (port uint16, err error) {
|
||||
value = strings.TrimPrefix(value, "dpt:")
|
||||
return parsePort(value)
|
||||
}
|
||||
|
||||
func parseSourcePort(value string) (port uint16, err error) {
|
||||
value = strings.TrimPrefix(value, "spt:")
|
||||
return parsePort(value)
|
||||
}
|
||||
|
||||
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
|
||||
|
||||
func parseTCPFlags(value string) (tcpFlags, error) {
|
||||
value = strings.TrimPrefix(value, "flags:")
|
||||
fields := strings.Split(value, "/")
|
||||
const expectedFields = 2
|
||||
if len(fields) != expectedFields {
|
||||
return tcpFlags{}, fmt.Errorf("%w: expected format 'flags:<mask>/<comparison>' in %q",
|
||||
errTCPFlagsMalformed, value)
|
||||
}
|
||||
maskFlags := strings.Split(fields[0], ",")
|
||||
mask := make([]tcpFlag, len(maskFlags))
|
||||
var err error
|
||||
for i, maskFlag := range maskFlags {
|
||||
mask[i], err = parseTCPFlag(maskFlag)
|
||||
if err != nil {
|
||||
return tcpFlags{}, fmt.Errorf("parsing TCP mask flags: %w", err)
|
||||
}
|
||||
}
|
||||
comparisonFlags := strings.Split(fields[1], ",")
|
||||
comparison := make([]tcpFlag, len(comparisonFlags))
|
||||
for i, comparisonFlag := range comparisonFlags {
|
||||
comparison[i], err = parseTCPFlag(comparisonFlag)
|
||||
if err != nil {
|
||||
return tcpFlags{}, fmt.Errorf("parsing TCP comparison flags: %w", err)
|
||||
}
|
||||
}
|
||||
return tcpFlags{
|
||||
mask: mask,
|
||||
comparison: comparison,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parsePortsCSV(s string) (ports []uint16, err error) {
|
||||
if s == "" {
|
||||
return nil, nil
|
||||
@@ -414,40 +286,16 @@ func parsePortsCSV(s string) (ports []uint16, err error) {
|
||||
fields := strings.Split(s, ",")
|
||||
ports = make([]uint16, len(fields))
|
||||
for i, field := range fields {
|
||||
ports[i], err = parsePort(field)
|
||||
const base, bitLength = 10, 16
|
||||
port, err := strconv.ParseUint(field, base, bitLength)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("parsing port %q: %w", field, err)
|
||||
}
|
||||
ports[i] = uint16(port)
|
||||
}
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
var errMarkValueMalformed = errors.New("mark value is malformed")
|
||||
|
||||
func parseMark(optionalFields []string) (m mark, consumed int, err error) {
|
||||
switch optionalFields[consumed] {
|
||||
case "match":
|
||||
consumed++
|
||||
if optionalFields[consumed] == "!" {
|
||||
m.invert = true
|
||||
consumed++
|
||||
}
|
||||
|
||||
const base = 0 // auto-detect
|
||||
const bits = 32
|
||||
value, err := strconv.ParseUint(optionalFields[consumed], base, bits)
|
||||
if err != nil {
|
||||
return mark{}, 0, fmt.Errorf("%w: %s", errMarkValueMalformed, optionalFields[consumed])
|
||||
}
|
||||
m.value = uint(value)
|
||||
consumed++
|
||||
default:
|
||||
return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[consumed])
|
||||
}
|
||||
return m, consumed, nil
|
||||
}
|
||||
|
||||
var ErrLineNumberIsZero = errors.New("line number is zero")
|
||||
|
||||
func parseLineNumber(s string) (n uint16, err error) {
|
||||
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
+1
-1
@@ -1,3 +1,3 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package $GOPACKAGE . CmdRunner,Logger
|
||||
@@ -1,8 +1,8 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/firewall/iptables (interfaces: CmdRunner,Logger)
|
||||
// Source: github.com/qdm12/gluetun/internal/firewall (interfaces: CmdRunner,Logger)
|
||||
|
||||
// Package iptables is a generated GoMock package.
|
||||
package iptables
|
||||
// Package firewall is a generated GoMock package.
|
||||
package firewall
|
||||
|
||||
import (
|
||||
exec "os/exec"
|
||||
@@ -1,99 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/nftables"
|
||||
)
|
||||
|
||||
// SaveAndRestore saves the current nftables tree and returns a restore function that
|
||||
// can be called to restore the saved tree.
|
||||
func (f *Firewall) SaveAndRestore(_ context.Context) (restore func(context.Context), err error) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
conn, err := nftables.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating nftables connection: %w", err)
|
||||
}
|
||||
tables, err := saveTables(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("saving nftables state: %w", err)
|
||||
}
|
||||
return func(_ context.Context) {
|
||||
conn, err := nftables.New()
|
||||
if err != nil {
|
||||
f.logger.Warnf("creating nftables connection for restore: %s", err)
|
||||
return
|
||||
}
|
||||
err = restoreTables(conn, tables)
|
||||
if err != nil {
|
||||
f.logger.Warnf("restoring nftables state: %s", err)
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
type savedTable struct {
|
||||
table *nftables.Table
|
||||
chains []savedChain
|
||||
}
|
||||
|
||||
type savedChain struct {
|
||||
chain *nftables.Chain
|
||||
rules []*nftables.Rule
|
||||
}
|
||||
|
||||
func saveTables(conn *nftables.Conn) ([]savedTable, error) {
|
||||
tables, err := conn.ListTables()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
savedTables := make([]savedTable, len(tables))
|
||||
for i, table := range tables {
|
||||
savedTables[i].table = table
|
||||
|
||||
chains, err := conn.ListChains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != table.Name ||
|
||||
chain.Table.Family != table.Family {
|
||||
continue
|
||||
}
|
||||
rules, err := conn.GetRules(table, chain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting rules for chain %s in table %s: %w", chain.Name, table.Name, err)
|
||||
}
|
||||
savedChain := savedChain{chain: chain, rules: rules}
|
||||
savedTables[i].chains = append(savedTables[i].chains, savedChain)
|
||||
}
|
||||
}
|
||||
|
||||
return savedTables, nil
|
||||
}
|
||||
|
||||
func restoreTables(conn *nftables.Conn, savedTables []savedTable) error {
|
||||
conn.FlushRuleset()
|
||||
|
||||
for _, savedTable := range savedTables {
|
||||
table := conn.AddTable(savedTable.table)
|
||||
for _, savedChain := range savedTable.chains {
|
||||
// Make the [nftables.Chain.Table] points to the new [nftables.Table]
|
||||
// created in this connection.
|
||||
savedChain.chain.Table = table
|
||||
savedChain.chain = conn.AddChain(savedChain.chain)
|
||||
|
||||
for _, rule := range savedChain.rules {
|
||||
rule.Table = table
|
||||
rule.Chain = savedChain.chain
|
||||
conn.AddRule(rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return conn.Flush()
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/nftables"
|
||||
)
|
||||
|
||||
var ErrPolicyUnknown = errors.New("unknown policy")
|
||||
|
||||
// SetBaseChainsPolicy sets the policy of all the base chains (INPUT, FORWARD, or OUTPUT)
|
||||
// for the filter table to the given policy (accept or drop).
|
||||
func (f *Firewall) SetBaseChainsPolicy(_ context.Context, policy string) error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
var chainPolicy nftables.ChainPolicy
|
||||
switch strings.ToLower(policy) {
|
||||
case "accept":
|
||||
chainPolicy = nftables.ChainPolicyAccept
|
||||
case "drop":
|
||||
chainPolicy = nftables.ChainPolicyDrop
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
|
||||
}
|
||||
|
||||
conn, err := nftables.New()
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating nftables connection: %w", err)
|
||||
}
|
||||
|
||||
_, inputChain, forwardChain, outputChain := setupFilterWithBaseChains(conn)
|
||||
inputChain.Policy = &chainPolicy
|
||||
forwardChain.Policy = &chainPolicy
|
||||
outputChain.Policy = &chainPolicy
|
||||
|
||||
conn.AddChain(inputChain)
|
||||
conn.AddChain(forwardChain)
|
||||
conn.AddChain(outputChain)
|
||||
|
||||
err = conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("flushing nftables changes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
)
|
||||
|
||||
func (f *Firewall) AcceptEstablishedRelatedTraffic(_ context.Context) error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
conn, err := nftables.New()
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating nftables connection: %w", err)
|
||||
}
|
||||
|
||||
table, inputChain, _, outputChain := setupFilterWithBaseChains(conn)
|
||||
|
||||
ctStateExprs := []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4, //nolint:mnd
|
||||
Mask: []byte{byte(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), 0x00, 0x00, 0x00},
|
||||
Xor: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
conn.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: inputChain,
|
||||
Exprs: ctStateExprs,
|
||||
})
|
||||
|
||||
conn.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: outputChain,
|
||||
Exprs: ctStateExprs,
|
||||
})
|
||||
|
||||
if err := conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flushing: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/google/nftables"
|
||||
)
|
||||
|
||||
var errRuleToDeleteNotFound = errors.New("rule not found for removal")
|
||||
|
||||
func (f *Firewall) deleteRule(conn *nftables.Conn, rule *nftables.Rule) error {
|
||||
for i, existing := range f.rules {
|
||||
if !reflect.DeepEqual(existing, rule) {
|
||||
continue
|
||||
}
|
||||
err := conn.DelRule(existing)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting rule: %w", err)
|
||||
}
|
||||
f.rules[i], f.rules[len(f.rules)-1] = f.rules[len(f.rules)-1], f.rules[i]
|
||||
f.rules = f.rules[:len(f.rules)-1]
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("%w: %#v", errRuleToDeleteNotFound, rule)
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import "github.com/google/nftables"
|
||||
|
||||
func setupFilterWithBaseChains(conn *nftables.Conn) (table *nftables.Table,
|
||||
inputChain, forwardChain, outputChain *nftables.Chain,
|
||||
) {
|
||||
table = conn.AddTable(&nftables.Table{
|
||||
Family: nftables.TableFamilyINet,
|
||||
Name: "filter",
|
||||
})
|
||||
|
||||
inputChain = conn.AddChain(&nftables.Chain{
|
||||
Name: "input",
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookInput,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
})
|
||||
|
||||
forwardChain = conn.AddChain(&nftables.Chain{
|
||||
Name: "forward",
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
})
|
||||
|
||||
outputChain = conn.AddChain(&nftables.Chain{
|
||||
Name: "output",
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookOutput,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
})
|
||||
|
||||
return table, inputChain, forwardChain, outputChain
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/google/nftables"
|
||||
)
|
||||
|
||||
type Firewall struct {
|
||||
logger Logger
|
||||
|
||||
// rules are only rules added and tracked for later removal.
|
||||
// Not all rules added are tracked for removal.
|
||||
rules []*nftables.Rule
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
func New(logger Logger) *Firewall {
|
||||
return &Firewall{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
@@ -1,170 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
)
|
||||
|
||||
func (f *Firewall) AcceptInputThroughInterface(_ context.Context, intf string) error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
conn, err := nftables.New()
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating nftables connection: %w", err)
|
||||
}
|
||||
|
||||
table, inputChain, _, _ := setupFilterWithBaseChains(conn)
|
||||
|
||||
rule := &nftables.Rule{
|
||||
Table: table,
|
||||
Chain: inputChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte(intf + "\x00"),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
conn.AddRule(rule)
|
||||
|
||||
err = conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("flushing: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AcceptInputToPort accepts incoming traffic on the specified port, for both TCP and UDP
|
||||
// protocols, on the interface intf. If intf is empty or "*", the interface is not used as a filter.
|
||||
// If remove is true, the rule is removed instead of added. This is used for port forwarding, with
|
||||
// intf set to the VPN tunnel interface.
|
||||
func (f *Firewall) AcceptInputToPort(_ context.Context, intf string, port uint16, remove bool) error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
conn, err := nftables.New()
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating nftables connection: %w", err)
|
||||
}
|
||||
|
||||
table, inputChain, _, _ := setupFilterWithBaseChains(conn)
|
||||
portBytes := []byte{byte(port >> 8), byte(port)} //nolint:mnd
|
||||
const tcp, udp uint8 = 6, 17
|
||||
protocols := []uint8{tcp, udp}
|
||||
|
||||
for _, protocol := range protocols {
|
||||
const maxExprsLen = 7
|
||||
exprs := make([]expr.Any, 0, maxExprsLen)
|
||||
if intf != "" && intf != "*" {
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(intf + "\x00")},
|
||||
)
|
||||
}
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 9, Len: 1}, //nolint:mnd
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protocol}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2}, //nolint:mnd
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: portBytes},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
)
|
||||
|
||||
rule := &nftables.Rule{
|
||||
Table: table,
|
||||
Chain: inputChain,
|
||||
Exprs: exprs,
|
||||
}
|
||||
|
||||
if !remove {
|
||||
conn.AddRule(rule)
|
||||
f.rules = append(f.rules, rule)
|
||||
continue
|
||||
}
|
||||
err = f.deleteRule(conn, rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = conn.Flush()
|
||||
if err != nil {
|
||||
f.rules = f.rules[:len(f.rules)-len(protocols)]
|
||||
return fmt.Errorf("flushing: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Firewall) AcceptInputToSubnet(_ context.Context, intf string, subnet netip.Prefix) error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
conn, err := nftables.New()
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating nftables connection: %w", err)
|
||||
}
|
||||
|
||||
table, inputChain, _, _ := setupFilterWithBaseChains(conn)
|
||||
|
||||
const maxExprsLen = 5
|
||||
exprs := make([]expr.Any, 0, maxExprsLen)
|
||||
|
||||
if intf != "" && intf != "*" {
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(intf + "\x00")},
|
||||
)
|
||||
}
|
||||
|
||||
var payloadOffset uint32
|
||||
if subnet.Addr().Is4() {
|
||||
payloadOffset = 16
|
||||
} else {
|
||||
payloadOffset = 24
|
||||
}
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: payloadOffset,
|
||||
Len: uint32(len(subnet.Addr().AsSlice())), //nolint:gosec
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: subnet.Addr().AsSlice(),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
)
|
||||
|
||||
rule := &nftables.Rule{
|
||||
Table: table,
|
||||
Chain: inputChain,
|
||||
Exprs: exprs,
|
||||
}
|
||||
|
||||
conn.AddRule(rule)
|
||||
|
||||
err = conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("flushing: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
package nftables
|
||||
|
||||
type Logger interface {
|
||||
Warnf(format string, args ...any)
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
)
|
||||
|
||||
func (f *Firewall) AcceptIpv6MulticastOutput(_ context.Context, intf string) error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
conn, err := nftables.New()
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating nftables connection: %w", err)
|
||||
}
|
||||
|
||||
table, _, _, outputChain := setupFilterWithBaseChains(conn)
|
||||
|
||||
const maxExprsLen = 6
|
||||
exprs := make([]expr.Any, 0, maxExprsLen)
|
||||
|
||||
if intf != "" && intf != "*" {
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(intf + "\x00")},
|
||||
)
|
||||
}
|
||||
|
||||
// ff02::1:ff00:0/104 mask is 13 bytes of 0xff
|
||||
mask := []byte{
|
||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00,
|
||||
} //nolint:mnd
|
||||
addr := []byte{
|
||||
0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x01, 0xff, 0x00, 0x00, 0x00,
|
||||
} //nolint:mnd
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 24, // IPv6 Destination Address offset //nolint:mnd
|
||||
Len: 16, //nolint:mnd
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 16, //nolint:mnd
|
||||
Mask: mask,
|
||||
Xor: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //nolint:mnd
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: addr,
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
)
|
||||
|
||||
rule := &nftables.Rule{
|
||||
Table: table,
|
||||
Chain: outputChain,
|
||||
Exprs: exprs,
|
||||
}
|
||||
|
||||
conn.AddRule(rule)
|
||||
|
||||
err = conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("flushing: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import "github.com/google/nftables"
|
||||
|
||||
func IsSupported() bool {
|
||||
conn, err := nftables.New()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_, err = conn.ListTable("filter")
|
||||
return err == nil
|
||||
}
|
||||
@@ -48,7 +48,7 @@ func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []netip.Pref
|
||||
}
|
||||
|
||||
firewallUpdated = true
|
||||
err := c.impl.AcceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
|
||||
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
|
||||
defaultRoute.AssignedIP, subNet, remove)
|
||||
if err != nil {
|
||||
c.logger.Error("cannot remove outdated outbound subnet: " + err.Error())
|
||||
@@ -77,7 +77,7 @@ func (c *Config) addOutboundSubnets(ctx context.Context, subnets []netip.Prefix)
|
||||
}
|
||||
|
||||
firewallUpdated = true
|
||||
err := c.impl.AcceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
|
||||
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
|
||||
defaultRoute.AssignedIP, subnet, remove)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -9,22 +9,30 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type operation uint8
|
||||
|
||||
const (
|
||||
opNone operation = iota
|
||||
opAppend
|
||||
opDelete
|
||||
opInsert
|
||||
opReplace
|
||||
)
|
||||
|
||||
type iptablesInstruction struct {
|
||||
table string // defaults to "filter", and can be "nat" for example.
|
||||
append bool
|
||||
operation operation
|
||||
chain string // for example INPUT, PREROUTING. Cannot be empty.
|
||||
target string // for example ACCEPT. Can be empty.
|
||||
protocol string // "tcp" or "udp" or "" for all protocols.
|
||||
inputInterface string // for example "tun0" or "" for any interface.
|
||||
outputInterface string // for example "tun0" or "" for any interface.
|
||||
source netip.Prefix // if not valid, then it is unspecified.
|
||||
sourcePort uint16 // if zero, there is no source port
|
||||
destination netip.Prefix // if not valid, then it is unspecified.
|
||||
destinationPort uint16 // if zero, there is no destination port
|
||||
toPorts []uint16 // if empty, there is no redirection
|
||||
ctstate []string // if empty, there is no ctstate
|
||||
tcpFlags tcpFlags
|
||||
mark mark
|
||||
lineNumber uint16 // for replace operation, the line number to replace
|
||||
}
|
||||
|
||||
func (i *iptablesInstruction) setDefaults() {
|
||||
@@ -46,8 +54,6 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
|
||||
return false
|
||||
case i.destinationPort != rule.destinationPort:
|
||||
return false
|
||||
case i.sourcePort != rule.sourcePort:
|
||||
return false
|
||||
case !slices.Equal(i.toPorts, rule.redirPorts):
|
||||
return false
|
||||
case !slices.Equal(i.ctstate, rule.ctstate):
|
||||
@@ -60,16 +66,63 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
|
||||
return false
|
||||
case !ipPrefixesEqual(i.destination, rule.destination):
|
||||
return false
|
||||
case !slices.Equal(i.tcpFlags.mask, rule.tcpFlags.mask) ||
|
||||
!slices.Equal(i.tcpFlags.comparison, rule.tcpFlags.comparison):
|
||||
return false
|
||||
case i.mark != rule.mark:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (i *iptablesInstruction) String() string {
|
||||
var sb strings.Builder
|
||||
if i.table != "" && i.table != "filter" {
|
||||
sb.WriteString(fmt.Sprintf("-t %s ", i.table))
|
||||
}
|
||||
switch i.operation {
|
||||
case opNone:
|
||||
panic("no operation specified")
|
||||
case opAppend:
|
||||
sb.WriteString(fmt.Sprintf("--append %s ", i.chain))
|
||||
case opDelete:
|
||||
sb.WriteString(fmt.Sprintf("--delete %s ", i.chain))
|
||||
case opInsert:
|
||||
sb.WriteString(fmt.Sprintf("--insert %s ", i.chain))
|
||||
case opReplace:
|
||||
sb.WriteString(fmt.Sprintf("--replace %s %d ", i.chain, i.lineNumber))
|
||||
}
|
||||
if i.inputInterface != "" {
|
||||
sb.WriteString(fmt.Sprintf("-i %s ", i.inputInterface))
|
||||
}
|
||||
if i.outputInterface != "" {
|
||||
sb.WriteString(fmt.Sprintf("-o %s ", i.outputInterface))
|
||||
}
|
||||
if i.protocol != "" {
|
||||
sb.WriteString(fmt.Sprintf("-p %s ", i.protocol))
|
||||
}
|
||||
if i.source.IsValid() {
|
||||
sb.WriteString(fmt.Sprintf("-s %s ", i.source.String()))
|
||||
}
|
||||
if i.destination.IsValid() {
|
||||
sb.WriteString(fmt.Sprintf("-d %s ", i.destination.String()))
|
||||
}
|
||||
if i.destinationPort != 0 {
|
||||
sb.WriteString(fmt.Sprintf("--dport %d ", i.destinationPort))
|
||||
}
|
||||
if len(i.ctstate) > 0 {
|
||||
sb.WriteString(fmt.Sprintf("--ctstate %s ", strings.Join(i.ctstate, ",")))
|
||||
}
|
||||
if len(i.toPorts) > 0 {
|
||||
var portStrings []string
|
||||
for _, port := range i.toPorts {
|
||||
portStrings = append(portStrings, strconv.FormatUint(uint64(port), 10))
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("--to-ports %s ", strings.Join(portStrings, ",")))
|
||||
}
|
||||
if i.target != "" {
|
||||
sb.WriteString(fmt.Sprintf("-j %s ", i.target))
|
||||
}
|
||||
|
||||
return strings.TrimSpace(sb.String())
|
||||
}
|
||||
|
||||
// instruction can be "" which equivalent to the "*" chain rule interface.
|
||||
func networkInterfacesEqual(instruction, chainRule string) bool {
|
||||
return instruction == chainRule || (instruction == "" && chainRule == "*")
|
||||
@@ -102,39 +155,53 @@ func parseIptablesInstruction(s string) (instruction iptablesInstruction, err er
|
||||
}
|
||||
|
||||
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
|
||||
consumed, err = preCheckInstructionFields(fields)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
flag := fields[0]
|
||||
|
||||
// All flags use one value after the flag, except the following:
|
||||
switch flag {
|
||||
case "-R", "--replace":
|
||||
const expected = 3
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
||||
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
|
||||
}
|
||||
consumed = expected
|
||||
default:
|
||||
const expected = 2
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
|
||||
ErrIptablesCommandMalformed, flag)
|
||||
}
|
||||
consumed = expected
|
||||
}
|
||||
value := fields[1]
|
||||
|
||||
switch flag {
|
||||
case "-t", "--table":
|
||||
instruction.table = value
|
||||
case "-D", "--delete":
|
||||
instruction.append = false
|
||||
instruction.operation = opDelete
|
||||
instruction.chain = value
|
||||
case "-A", "--append":
|
||||
instruction.append = true
|
||||
instruction.operation = opAppend
|
||||
instruction.chain = value
|
||||
case "-I", "--insert":
|
||||
instruction.operation = opInsert
|
||||
instruction.chain = value
|
||||
case "-R", "--replace":
|
||||
instruction.operation = opReplace
|
||||
instruction.chain = value
|
||||
const base, bits = 10, 16
|
||||
n, err := strconv.ParseUint(fields[2], base, bits)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing line number for --replace operation: %w", err)
|
||||
}
|
||||
instruction.lineNumber = uint16(n)
|
||||
case "-j", "--jump":
|
||||
instruction.target = value
|
||||
case "-p", "--protocol":
|
||||
instruction.protocol = value
|
||||
case "-m", "--match":
|
||||
consumed, err = parseMatchModule(fields, instruction)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing match module: %w", err)
|
||||
}
|
||||
case "--mark":
|
||||
const base = 0 // auto-detect
|
||||
const bits = 32
|
||||
value, err := strconv.ParseUint(value, base, bits)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing mark value %q: %w", fields[2], err)
|
||||
}
|
||||
instruction.mark.value = uint(value)
|
||||
case "-m", "--match": // ignore match
|
||||
case "-i", "--in-interface":
|
||||
instruction.inputInterface = value
|
||||
case "-o", "--out-interface":
|
||||
@@ -144,61 +211,37 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
}
|
||||
case "--sport":
|
||||
instruction.sourcePort, err = parsePort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing source port: %w", err)
|
||||
}
|
||||
case "-d", "--destination":
|
||||
instruction.destination, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
}
|
||||
case "--dport":
|
||||
instruction.destinationPort, err = parsePort(value)
|
||||
const base, bitLength = 10, 16
|
||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
instruction.destinationPort = uint16(destinationPort)
|
||||
case "--ctstate":
|
||||
instruction.ctstate = strings.Split(value, ",")
|
||||
case "--to-ports":
|
||||
instruction.toPorts, err = parseToPorts(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing port redirection: %w", err)
|
||||
}
|
||||
case "--tcp-flags":
|
||||
mask, comparison := value, fields[2]
|
||||
instruction.tcpFlags, err = parseTCPFlags(mask + "/" + comparison)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing TCP flags: %w", err)
|
||||
portStrings := strings.Split(value, ",")
|
||||
instruction.toPorts = make([]uint16, len(portStrings))
|
||||
for i, portString := range portStrings {
|
||||
const base, bitLength = 10, 16
|
||||
port, err := strconv.ParseUint(portString, base, bitLength)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing port redirection: %w", err)
|
||||
}
|
||||
instruction.toPorts[i] = uint16(port)
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag)
|
||||
return 0, fmt.Errorf("%w: unknown flag %q", ErrIptablesCommandMalformed, flag)
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
func preCheckInstructionFields(fields []string) (consumed int, err error) {
|
||||
flag := fields[0]
|
||||
// All flags use one value after the flag, except the following:
|
||||
switch flag {
|
||||
case "--tcp-flags": // -m can have 1 or 2 values
|
||||
const expected = 3
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
||||
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
|
||||
}
|
||||
return expected, nil
|
||||
default:
|
||||
const expected = 2
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
|
||||
ErrIptablesCommandMalformed, flag)
|
||||
}
|
||||
return expected, nil
|
||||
}
|
||||
}
|
||||
|
||||
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
||||
slashIndex := strings.Index(value, "/")
|
||||
if slashIndex >= 0 {
|
||||
@@ -211,52 +254,3 @@ func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
||||
}
|
||||
return netip.PrefixFrom(ip, ip.BitLen()), nil
|
||||
}
|
||||
|
||||
func parsePort(value string) (port uint16, err error) {
|
||||
const base, bitLength = 10, 16
|
||||
portValue, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint16(portValue), nil
|
||||
}
|
||||
|
||||
func parseMatchModule(fields []string, instruction *iptablesInstruction) (
|
||||
consumed int, err error,
|
||||
) {
|
||||
_ = fields[consumed] // -m or --match flag already detected
|
||||
consumed++
|
||||
switch fields[consumed] {
|
||||
case "tcp", "udp":
|
||||
consumed++
|
||||
// for now ignore the protocol match since it's auto-loaded
|
||||
// when parsing the -p/--protocol flag, and we don't need to
|
||||
// parse it twice.
|
||||
case "mark":
|
||||
consumed++
|
||||
switch fields[consumed] {
|
||||
case "!":
|
||||
consumed++
|
||||
instruction.mark.invert = true
|
||||
default:
|
||||
return consumed, fmt.Errorf("%w: unsupported match mark with value: %s",
|
||||
ErrIptablesCommandMalformed, fields[2])
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: unknown match value: %s",
|
||||
ErrIptablesCommandMalformed, fields[consumed])
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
func parseToPorts(value string) (toPorts []uint16, err error) {
|
||||
portStrings := strings.Split(value, ",")
|
||||
toPorts = make([]uint16, len(portStrings))
|
||||
for i, portString := range portStrings {
|
||||
toPorts[i], err = parsePort(portString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return toPorts, nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
@@ -28,14 +28,14 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
"unknown_key": {
|
||||
s: "-x something",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
|
||||
errMessage: "parsing \"-x something\": iptables command is malformed: unknown flag \"-x\"",
|
||||
},
|
||||
"one_pair": {
|
||||
s: "-A INPUT",
|
||||
s: "-I INPUT",
|
||||
instruction: iptablesInstruction{
|
||||
table: "filter",
|
||||
chain: "INPUT",
|
||||
append: true,
|
||||
table: "filter",
|
||||
chain: "INPUT",
|
||||
operation: opInsert,
|
||||
},
|
||||
},
|
||||
"instruction_A": {
|
||||
@@ -43,7 +43,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
instruction: iptablesInstruction{
|
||||
table: "filter",
|
||||
chain: "INPUT",
|
||||
append: true,
|
||||
operation: opAppend,
|
||||
inputInterface: "tun0",
|
||||
protocol: "tcp",
|
||||
source: netip.MustParsePrefix("1.2.3.4/32"),
|
||||
@@ -57,7 +57,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
instruction: iptablesInstruction{
|
||||
table: "nat",
|
||||
chain: "PREROUTING",
|
||||
append: false,
|
||||
operation: opDelete,
|
||||
inputInterface: "tun0",
|
||||
protocol: "tcp",
|
||||
destinationPort: 43716,
|
||||
+133
-2
@@ -3,6 +3,7 @@ package firewall
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
@@ -35,7 +36,7 @@ func (c *Config) SetAllowedPort(ctx context.Context, port uint16, intf string) (
|
||||
c.logger.Info("setting allowed input port " + fmt.Sprint(port) + " through interface " + intf + "...")
|
||||
|
||||
const remove = false
|
||||
if err := c.impl.AcceptInputToPort(ctx, intf, port, remove); err != nil {
|
||||
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
|
||||
return fmt.Errorf("allowing input to port %d through interface %s: %w",
|
||||
port, intf, err)
|
||||
}
|
||||
@@ -68,7 +69,7 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
||||
|
||||
const remove = true
|
||||
for netInterface := range interfacesSet {
|
||||
err := c.impl.AcceptInputToPort(ctx, netInterface, port, remove)
|
||||
err := c.acceptInputToPort(ctx, netInterface, port, remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing allowed port %d on interface %s: %w",
|
||||
port, netInterface, err)
|
||||
@@ -81,3 +82,133 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestrictOutputAddrPort allows outgoing traffic to a specific IP and port for both tcp and udp,
|
||||
// while blocking other tcp or udp traffic to that port going to other IP addresses, both IPv4 and IPv6.
|
||||
// If the port was previously allowed for another IP address, that previous allowance will be removed.
|
||||
// Giving an invalid address will remove any existing restrictions for the port specified.
|
||||
func (c *Config) RestrictOutputAddrPort(ctx context.Context, addrPort netip.AddrPort) (err error) {
|
||||
c.stateMutex.Lock()
|
||||
defer c.stateMutex.Unlock()
|
||||
existingIP := c.outputAddrPort[addrPort.Port()]
|
||||
|
||||
switch {
|
||||
case existingIP == addrPort.Addr():
|
||||
return nil
|
||||
case !addrPort.Addr().IsValid():
|
||||
// invalid address, remove any existing rules for the port
|
||||
return c.removeOutputAddrPortRestriction(ctx, existingIP, addrPort.Port())
|
||||
case !existingIP.IsValid():
|
||||
// no previous existing address for the port
|
||||
return c.insertOutputAddrPortRestriction(ctx, addrPort)
|
||||
default:
|
||||
// existing rule in the same IP family or different family
|
||||
return c.replaceOutputAddrPortRestriction(ctx, existingIP, addrPort)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) removeOutputAddrPortRestriction(ctx context.Context, existingIP netip.Addr, port uint16) (err error) {
|
||||
commonInstructions := []string{
|
||||
fmt.Sprintf("--delete OUTPUT -p udp --dport %d -j DROP", port),
|
||||
fmt.Sprintf("--delete OUTPUT -p tcp --dport %d -j DROP", port),
|
||||
}
|
||||
ipv4Instructions := commonInstructions
|
||||
ipv6Instructions := commonInstructions
|
||||
|
||||
familySpecificInstructions := []string{
|
||||
fmt.Sprintf("--delete OUTPUT -p udp --dport %d -d %s -j ACCEPT", port, existingIP),
|
||||
fmt.Sprintf("--delete OUTPUT -p tcp --dport %d -d %s -j ACCEPT", port, existingIP),
|
||||
}
|
||||
if existingIP.Is4() {
|
||||
ipv4Instructions = append(ipv4Instructions, familySpecificInstructions...)
|
||||
} else {
|
||||
ipv6Instructions = append(ipv6Instructions, familySpecificInstructions...)
|
||||
}
|
||||
|
||||
err = c.runIPv4AndV6IptablesInstructions(ctx, ipv4Instructions, ipv6Instructions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
delete(c.outputAddrPort, port)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) insertOutputAddrPortRestriction(ctx context.Context, addrPort netip.AddrPort) (err error) {
|
||||
commonInstructions := []string{
|
||||
fmt.Sprintf("--insert OUTPUT -p udp --dport %d -j DROP", addrPort.Port()),
|
||||
fmt.Sprintf("--insert OUTPUT -p tcp --dport %d -j DROP", addrPort.Port()),
|
||||
}
|
||||
ipv4Instructions := commonInstructions
|
||||
ipv6Instructions := commonInstructions
|
||||
|
||||
familySpecificInstructions := []string{
|
||||
fmt.Sprintf("--insert OUTPUT -p udp --dport %d -d %s -j ACCEPT", addrPort.Port(), addrPort.Addr()),
|
||||
fmt.Sprintf("--insert OUTPUT -p tcp --dport %d -d %s -j ACCEPT", addrPort.Port(), addrPort.Addr()),
|
||||
}
|
||||
if addrPort.Addr().Is4() {
|
||||
ipv4Instructions = append(ipv4Instructions, familySpecificInstructions...)
|
||||
} else {
|
||||
ipv6Instructions = append(ipv6Instructions, familySpecificInstructions...)
|
||||
}
|
||||
err = c.runIPv4AndV6IptablesInstructions(ctx, ipv4Instructions, ipv6Instructions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.outputAddrPort[addrPort.Port()] = addrPort.Addr()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) replaceOutputAddrPortRestriction(ctx context.Context,
|
||||
existingIP netip.Addr, addrPort netip.AddrPort,
|
||||
) (err error) {
|
||||
for _, protocol := range [...]string{"udp", "tcp"} {
|
||||
switch {
|
||||
case existingIP.Is4() && addrPort.Addr().Is4():
|
||||
oldInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), existingIP)
|
||||
newInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), addrPort.Addr())
|
||||
err = c.replaceIptablesRule(ctx, oldInstruction, newInstruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("replacing existing IPv4 rule: %w", err)
|
||||
}
|
||||
case existingIP.Is6() && addrPort.Addr().Is6():
|
||||
oldInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), existingIP)
|
||||
newInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), addrPort.Addr())
|
||||
err = c.replaceIP6tablesRule(ctx, oldInstruction, newInstruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("replacing existing IPv6 rule: %w", err)
|
||||
}
|
||||
case existingIP.Is4() && addrPort.Addr().Is6():
|
||||
instruction := fmt.Sprintf("--delete OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), existingIP)
|
||||
err = c.runIptablesInstruction(ctx, instruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing existing IPv4 rule: %w", err)
|
||||
}
|
||||
instruction = fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), addrPort.Addr())
|
||||
err = c.runIP6tablesInstruction(ctx, instruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("inserting new IPv6 rule: %w", err)
|
||||
}
|
||||
case existingIP.Is6() && addrPort.Addr().Is4():
|
||||
instruction := fmt.Sprintf("--delete OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), existingIP)
|
||||
err = c.runIP6tablesInstruction(ctx, instruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing existing IPv6 rule: %w", err)
|
||||
}
|
||||
instruction = fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), addrPort.Addr())
|
||||
err = c.runIptablesInstruction(ctx, instruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("inserting new IPv4 rule: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
c.outputAddrPort[addrPort.Port()] = addrPort.Addr()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string, sourcePort,
|
||||
return nil
|
||||
case conflict != nil:
|
||||
const remove = true
|
||||
err = c.impl.RedirectPort(ctx, conflict.interfaceName, conflict.sourcePort,
|
||||
err = c.redirectPort(ctx, conflict.interfaceName, conflict.sourcePort,
|
||||
conflict.destinationPort, remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing conflicting redirection: %w", err)
|
||||
@@ -60,7 +60,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string, sourcePort,
|
||||
}
|
||||
|
||||
const remove = false
|
||||
err = c.impl.RedirectPort(ctx, intf, sourcePort, destinationPort, remove)
|
||||
err = c.redirectPort(ctx, intf, sourcePort, destinationPort, remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("redirecting port: %w", err)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var errRuleNotFound = errors.New("rule not found")
|
||||
|
||||
func (c *Config) replaceIptablesRule(ctx context.Context, oldInstruction, newInstruction string) error {
|
||||
targetRule, err := parseIptablesInstruction(oldInstruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing iptables command to replace: %w", err)
|
||||
}
|
||||
|
||||
lineNumber, err := findLineNumber(ctx, c.ipTables, targetRule, c.runner, c.logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("finding to-be-replaced chain rule line number: %w", err)
|
||||
} else if lineNumber == 0 {
|
||||
return fmt.Errorf("%w: matching to-be-replaced instruction %q", errRuleNotFound, oldInstruction)
|
||||
}
|
||||
parsed, err := parseIptablesInstruction(newInstruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing replacement iptables command: %w", err)
|
||||
}
|
||||
parsed.operation = opReplace
|
||||
parsed.lineNumber = lineNumber
|
||||
return c.runIptablesInstruction(ctx, parsed.String())
|
||||
}
|
||||
|
||||
func (c *Config) replaceIP6tablesRule(ctx context.Context, oldInstruction, newInstruction string) error {
|
||||
targetRule, err := parseIptablesInstruction(oldInstruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing iptables command to replace: %w", err)
|
||||
}
|
||||
|
||||
lineNumber, err := findLineNumber(ctx, c.ip6Tables, targetRule, c.runner, c.logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("finding to-be-replaced chain rule line number: %w", err)
|
||||
} else if lineNumber == 0 {
|
||||
return fmt.Errorf("%w: matching to-be-replaced instruction %q", errRuleNotFound, oldInstruction)
|
||||
}
|
||||
parsed, err := parseIptablesInstruction(newInstruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing replacement iptables command: %w", err)
|
||||
}
|
||||
parsed.operation = opReplace
|
||||
parsed.lineNumber = lineNumber
|
||||
return c.runIP6tablesInstruction(ctx, parsed.String())
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -11,10 +11,10 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
|
||||
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
|
||||
ErrInputPolicyNotFound = errors.New("input policy not found")
|
||||
ErrNotSupported = errors.New("no iptables supported found")
|
||||
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
|
||||
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
|
||||
ErrInputPolicyNotFound = errors.New("input policy not found")
|
||||
ErrIPTablesNotSupported = errors.New("no iptables supported found")
|
||||
)
|
||||
|
||||
func checkIptablesSupport(ctx context.Context, runner CmdRunner,
|
||||
@@ -57,7 +57,7 @@ func checkIptablesSupport(ctx context.Context, runner CmdRunner,
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("%w: errors encountered are: %s",
|
||||
ErrNotSupported, strings.Join(allUnsupportedMessages, "; "))
|
||||
ErrIPTablesNotSupported, strings.Join(allUnsupportedMessages, "; "))
|
||||
}
|
||||
|
||||
func testIptablesPath(ctx context.Context, path string,
|
||||
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -101,7 +101,7 @@ func Test_checkIptablesSupport(t *testing.T) {
|
||||
return runner
|
||||
},
|
||||
iptablesPathsToTry: []string{"path1", "path2"},
|
||||
errSentinel: ErrNotSupported,
|
||||
errSentinel: ErrIPTablesNotSupported,
|
||||
errMessage: "no iptables supported found: " +
|
||||
"errors encountered are: " +
|
||||
"path1: output 1 (exit code 4); " +
|
||||
@@ -28,7 +28,7 @@ func (c *Config) SetVPNConnection(ctx context.Context,
|
||||
remove := true
|
||||
if c.vpnConnection.IP.IsValid() {
|
||||
for _, defaultRoute := range c.defaultRoutes {
|
||||
if err := c.impl.AcceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove); err != nil {
|
||||
if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove); err != nil {
|
||||
c.logger.Error("cannot remove outdated VPN connection rule: " + err.Error())
|
||||
}
|
||||
}
|
||||
@@ -36,7 +36,7 @@ func (c *Config) SetVPNConnection(ctx context.Context,
|
||||
c.vpnConnection = models.Connection{}
|
||||
|
||||
if c.vpnIntf != "" {
|
||||
if err = c.impl.AcceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
|
||||
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
|
||||
c.logger.Error("cannot remove outdated VPN interface rule: " + err.Error())
|
||||
}
|
||||
}
|
||||
@@ -45,13 +45,13 @@ func (c *Config) SetVPNConnection(ctx context.Context,
|
||||
remove = false
|
||||
|
||||
for _, defaultRoute := range c.defaultRoutes {
|
||||
if err := c.impl.AcceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, connection, remove); err != nil {
|
||||
if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, connection, remove); err != nil {
|
||||
return fmt.Errorf("allowing output traffic through VPN connection: %w", err)
|
||||
}
|
||||
}
|
||||
c.vpnConnection = connection
|
||||
|
||||
if err = c.impl.AcceptOutputThroughInterface(ctx, vpnIntf, remove); err != nil {
|
||||
if err = c.acceptOutputThroughInterface(ctx, vpnIntf, remove); err != nil {
|
||||
return fmt.Errorf("accepting output traffic through interface %s: %w", vpnIntf, err)
|
||||
}
|
||||
c.vpnIntf = vpnIntf
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func (c *Config) Version(ctx context.Context) (version string, err error) {
|
||||
return c.impl.Version(ctx)
|
||||
}
|
||||
|
||||
// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port,
|
||||
// for any TCP packets not marked with the excludeMark given.
|
||||
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
|
||||
// by sending a TCP RST packet, although we want to handle the connection manually.
|
||||
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
|
||||
src, dst netip.AddrPort, excludeMark int) (
|
||||
revert func(ctx context.Context) error, err error,
|
||||
) {
|
||||
return c.impl.TempDropOutputTCPRST(ctx, src, dst, excludeMark)
|
||||
}
|
||||
@@ -23,7 +23,6 @@ type Checker struct {
|
||||
logger Logger
|
||||
icmpTargetIPs []netip.Addr
|
||||
smallCheckType string
|
||||
startupOnFail bool
|
||||
configMutex sync.Mutex
|
||||
|
||||
icmpNotPermitted *bool
|
||||
@@ -46,43 +45,26 @@ func NewChecker(logger Logger) *Checker {
|
||||
}
|
||||
}
|
||||
|
||||
// SetConfig sets the following:
|
||||
// - TCP+TLS dial addresses
|
||||
// - ICMP echo IP addresses to target
|
||||
// - the desired small check type (dns or icmp)
|
||||
// - whether to startup the periodic checks if the startup check fails.
|
||||
// SetConfig sets the TCP+TLS dial addresses, the ICMP echo IP address
|
||||
// to target and the desired small check type (dns or icmp).
|
||||
// This function MUST be called before calling [Checker.Start].
|
||||
func (c *Checker) SetConfig(tlsDialAddrs []string, icmpTargets []netip.Addr,
|
||||
smallCheckType string, startupOnFail bool,
|
||||
smallCheckType string,
|
||||
) {
|
||||
c.configMutex.Lock()
|
||||
defer c.configMutex.Unlock()
|
||||
c.tlsDialAddrs = tlsDialAddrs
|
||||
c.icmpTargetIPs = icmpTargets
|
||||
c.smallCheckType = smallCheckType
|
||||
c.startupOnFail = startupOnFail
|
||||
}
|
||||
|
||||
// Start starts the [Checker] which behaves differently according to its
|
||||
// internal field startupOnFail, which is set by calling [Checker.SetConfig].
|
||||
//
|
||||
// By default, startupOnFail should be false and the behavior is as follows:
|
||||
// A blocking 6s-timed TCP+TLS check is performed first. If it fails,
|
||||
// an error is returned and the [Checker] is not started.
|
||||
// On success, it starts the periodic checks in a separate goroutine, returning
|
||||
// the runError error channel and a nil error.
|
||||
//
|
||||
// If startupOnFail is true, the behavior is as follows:
|
||||
// A blocking 6s-timed TCP+TLS check is performed first. If it fails,
|
||||
// the error is sent to the runError channel, but no error is returned
|
||||
// and the [Checker] continues to start the periodic checks in a separate goroutine, returning
|
||||
// the runError error channel and a nil error.
|
||||
//
|
||||
// The periodic checks consist in:
|
||||
// Start starts the checker by first running a blocking 6s-timed TCP+TLS check,
|
||||
// and, on success, starts the periodic checks in a separate goroutine:
|
||||
// - a "small" ICMP echo check every minute
|
||||
// - a "full" TCP+TLS check every 5 minutes
|
||||
//
|
||||
// The [Checker] has to be ultimately stopped by calling [Checker.Stop].
|
||||
// It returns a channel `runError` that receives an error (nil or not) when a periodic check is performed.
|
||||
// It returns an error if the initial TCP+TLS check fails.
|
||||
// The Checker has to be ultimately stopped by calling [Checker.Stop].
|
||||
func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error) {
|
||||
if len(c.tlsDialAddrs) == 0 || len(c.icmpTargetIPs) == 0 || c.smallCheckType == "" {
|
||||
panic("call Checker.SetConfig with non empty values before Checker.Start")
|
||||
@@ -94,19 +76,9 @@ func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error)
|
||||
}
|
||||
c.echoer.Reset()
|
||||
|
||||
// runErrorCh MUST be buffered in the case startupOnFail is true, and
|
||||
// a startup error was encountered, to avoid blocking the startup
|
||||
// goroutine when sending the error, especially since the caller may
|
||||
// not be ready to receive from the channel yet.
|
||||
runErrorCh := make(chan error, 1)
|
||||
runError = runErrorCh
|
||||
err = c.startupCheck(ctx)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("startup check: %w", err)
|
||||
if !c.startupOnFail {
|
||||
return nil, err
|
||||
}
|
||||
runErrorCh <- err
|
||||
return nil, fmt.Errorf("startup check: %w", err)
|
||||
}
|
||||
|
||||
ready := make(chan struct{})
|
||||
@@ -118,6 +90,8 @@ func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error)
|
||||
smallCheckTimer := time.NewTimer(smallCheckPeriod)
|
||||
const fullCheckPeriod = 5 * time.Minute
|
||||
fullCheckTimer := time.NewTimer(fullCheckPeriod)
|
||||
runErrorCh := make(chan error)
|
||||
runError = runErrorCh
|
||||
go func() {
|
||||
defer close(done)
|
||||
close(ready)
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrHTTPStatusNotOK = errors.New("HTTP response status is not OK")
|
||||
@@ -22,9 +21,6 @@ func NewClient(httpClient *http.Client) *Client {
|
||||
}
|
||||
|
||||
func (c *Client) Check(ctx context.Context, url string) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
package mod
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var errBuiltinModuleNotFound = errors.New("builtin module not found")
|
||||
|
||||
func checkModulesBuiltin(modulesPath, moduleName string) error {
|
||||
f, err := os.Open(filepath.Join(modulesPath, "modules.builtin"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
moduleName = strings.TrimSuffix(moduleName, ".ko")
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
line = strings.TrimSuffix(line, ".ko")
|
||||
if strings.HasSuffix(line, "/"+moduleName) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("%w: %s", errBuiltinModuleNotFound, moduleName)
|
||||
}
|
||||
@@ -1,132 +0,0 @@
|
||||
package mod
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
errModuleNameUnknown = errors.New("unknown module name")
|
||||
errKernelFeatureIsModule = errors.New("kernel feature is a module, not built-in")
|
||||
errKernelFeatureNotSet = errors.New("kernel feature not set")
|
||||
errKernelFeatureNotFound = errors.New("kernel feature not found")
|
||||
)
|
||||
|
||||
// checkProcConfig checks /proc/config.gz for a the kernel feature corresponding
|
||||
// to the given module name. If the kernel feature is found and set to "y", it returns nil.
|
||||
// If the kernel feature is found and set to "m", it returns an error indicating that the kernel
|
||||
// feature is a module, not built-in.
|
||||
// If the kernel feature is found and not set, it returns an error indicating that the kernel
|
||||
// feature is not set. If the kernel feature is not found, it returns an error indicating that the kernel
|
||||
// feature is not found.
|
||||
func checkProcConfig(moduleName string) error {
|
||||
f, err := os.Open("/proc/config.gz")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
gz, err := gzip.NewReader(f)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating gzip reader: %w", err)
|
||||
}
|
||||
defer gz.Close()
|
||||
|
||||
// If any group of kernel features is satisfied, then the module is considered supported.
|
||||
kernelFeatureGroups, ok := moduleNameToKernelFeatureGroups(moduleName)
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: %s", errModuleNameUnknown, moduleName)
|
||||
}
|
||||
groups := make([]map[string]bool, len(kernelFeatureGroups))
|
||||
for i, group := range kernelFeatureGroups {
|
||||
featureToOK := make(map[string]bool)
|
||||
for _, feature := range group {
|
||||
featureToOK[feature] = false
|
||||
}
|
||||
groups[i] = featureToOK
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(gz)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
for _, featureToOK := range groups {
|
||||
for name, ok := range featureToOK {
|
||||
switch {
|
||||
case ok:
|
||||
case strings.HasPrefix(line, name+"=m"):
|
||||
return fmt.Errorf("%w: %s", errKernelFeatureIsModule, name)
|
||||
case strings.HasPrefix(line, name+"=y"):
|
||||
featureToOK[name] = true
|
||||
if allFeaturesOK(featureToOK) {
|
||||
return nil
|
||||
}
|
||||
case strings.HasPrefix(line, "# "+name+" is not set"):
|
||||
return fmt.Errorf("%w: %s", errKernelFeatureNotSet, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("%w: for module name %s", errKernelFeatureNotFound, moduleName)
|
||||
}
|
||||
|
||||
func moduleNameToKernelFeatureGroups(moduleName string) (featureGroups [][]string, ok bool) {
|
||||
moduleMap := map[string][][]string{
|
||||
"nf_tables": {{"CONFIG_NF_TABLES"}},
|
||||
|
||||
// Netfilter Matches
|
||||
"xt_conntrack": {{"CONFIG_NETFILTER_XT_MATCH_CONNTRACK"}},
|
||||
"xt_connmark": {
|
||||
{"CONFIG_NETFILTER_XT_CONNMARK"},
|
||||
{"CONFIG_NETFILTER_XT_MATCH_CONNMARK", "CONFIG_NETFILTER_XT_TARGET_CONNMARK"},
|
||||
},
|
||||
"xt_mark": {
|
||||
{"CONFIG_NETFILTER_XT_MARK"},
|
||||
{"CONFIG_NETFILTER_XT_MATCH_MARK", "CONFIG_NETFILTER_XT_TARGET_MARK"},
|
||||
},
|
||||
"nf_conntrack_netlink": {{"CONFIG_NF_CT_NETLINK"}},
|
||||
"nf_reject_ipv4": {{"CONFIG_NF_REJECT_IPV4"}},
|
||||
|
||||
// Common Netfilter Targets
|
||||
"xt_log": {{"CONFIG_NETFILTER_XT_TARGET_LOG"}},
|
||||
"xt_reject": {
|
||||
{"CONFIG_IP_NF_TARGET_REJECT", "CONFIG_NF_REJECT_IPV4"},
|
||||
{"CONFIG_NETFILTER_XT_TARGET_REJECT", "CONFIG_NF_REJECT_IPV4"},
|
||||
},
|
||||
"xt_masquerade": {{"CONFIG_NETFILTER_XT_TARGET_MASQUERADE"}},
|
||||
|
||||
// Additional Netfilter Matches
|
||||
"xt_addrtype": {{"CONFIG_NETFILTER_XT_MATCH_ADDRTYPE"}},
|
||||
"xt_comment": {{"CONFIG_NETFILTER_XT_MATCH_COMMENT"}},
|
||||
"xt_multiport": {{"CONFIG_NETFILTER_XT_MATCH_MULTIPORT"}},
|
||||
"xt_state": {{"CONFIG_NETFILTER_XT_MATCH_STATE"}},
|
||||
"xt_tcpudp": {{"CONFIG_NETFILTER_XT_MATCH_TCPUDP"}},
|
||||
|
||||
// Tunneling and Virtualization
|
||||
"tun": {{"CONFIG_TUN"}},
|
||||
"bridge": {{"CONFIG_BRIDGE"}},
|
||||
"veth": {{"CONFIG_VETH"}},
|
||||
"vxlan": {{"CONFIG_VXLAN"}},
|
||||
"wireguard": {{"CONFIG_WIREGUARD"}},
|
||||
|
||||
// Filesystems
|
||||
"overlay": {{"CONFIG_OVERLAY_FS"}},
|
||||
"fuse": {{"CONFIG_FUSE_FS"}},
|
||||
}
|
||||
|
||||
featureGroups, ok = moduleMap[strings.ToLower(moduleName)]
|
||||
return featureGroups, ok
|
||||
}
|
||||
|
||||
func allFeaturesOK(featureToOK map[string]bool) bool {
|
||||
for _, ok := range featureToOK {
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
+30
-34
@@ -30,7 +30,36 @@ type moduleInfo struct {
|
||||
|
||||
var ErrModulesDirectoryNotFound = errors.New("modules directory not found")
|
||||
|
||||
func getModulesInfo(modulesPath string) (modulesInfo map[string]moduleInfo, err error) {
|
||||
func getModulesInfo() (modulesInfo map[string]moduleInfo, err error) {
|
||||
var utsName unix.Utsname
|
||||
err = unix.Uname(&utsName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting unix uname release: %w", err)
|
||||
}
|
||||
release := unix.ByteSliceToString(utsName.Release[:])
|
||||
release = strings.TrimSpace(release)
|
||||
|
||||
modulePaths := []string{
|
||||
filepath.Join("/lib/modules", release),
|
||||
filepath.Join("/usr/lib/modules", release),
|
||||
}
|
||||
|
||||
var modulesPath string
|
||||
var found bool
|
||||
for _, modulesPath = range modulePaths {
|
||||
info, err := os.Stat(modulesPath)
|
||||
if err == nil && info.IsDir() {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, fmt.Errorf("%w: %s are not valid existing directories"+
|
||||
"; have you bind mounted the /lib/modules directory?",
|
||||
ErrModulesDirectoryNotFound, strings.Join(modulePaths, ", "))
|
||||
}
|
||||
|
||||
dependencyFilepath := filepath.Join(modulesPath, "modules.dep")
|
||||
dependencyFile, err := os.Open(dependencyFilepath)
|
||||
if err != nil {
|
||||
@@ -82,39 +111,6 @@ func getModulesInfo(modulesPath string) (modulesInfo map[string]moduleInfo, err
|
||||
return modulesInfo, nil
|
||||
}
|
||||
|
||||
func getModulesPath() (string, error) {
|
||||
release, err := getReleaseName()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting release name: %w", err)
|
||||
}
|
||||
|
||||
modulePaths := []string{
|
||||
filepath.Join("/lib/modules", release),
|
||||
filepath.Join("/usr/lib/modules", release),
|
||||
}
|
||||
|
||||
for _, modulesPath := range modulePaths {
|
||||
info, err := os.Stat(modulesPath)
|
||||
if err == nil && info.IsDir() {
|
||||
return modulesPath, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("%w: %s are not valid existing directories"+
|
||||
"; have you bind mounted the /lib/modules directory?",
|
||||
ErrModulesDirectoryNotFound, strings.Join(modulePaths, ", "))
|
||||
}
|
||||
|
||||
func getReleaseName() (release string, err error) {
|
||||
var utsName unix.Utsname
|
||||
err = unix.Uname(&utsName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting unix uname release: %w", err)
|
||||
}
|
||||
release = unix.ByteSliceToString(utsName.Release[:])
|
||||
release = strings.TrimSpace(release)
|
||||
return release, nil
|
||||
}
|
||||
|
||||
func getBuiltinModules(modulesDirPath string, modulesInfo map[string]moduleInfo) error {
|
||||
file, err := os.Open(filepath.Join(modulesDirPath, "modules.builtin"))
|
||||
if err != nil {
|
||||
|
||||
@@ -1,49 +1,12 @@
|
||||
package mod
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Probe is a expanded version of modprobe, in which it checks if the Kernel
|
||||
// built-in features contain the given module name.
|
||||
// It first tries to locate the modules directory in [getModulesPath].
|
||||
// If it fails (like on WSL), it then only checks for the kernel feature
|
||||
// in /proc/config.gz with [checkProcConfig].
|
||||
// Otherwise, it first checks if the modules directory modules.builtin
|
||||
// file contains the given module name in [checkModulesBuiltin].
|
||||
// If the module is not found, it then runs the classic [modProbe] behavior,
|
||||
// trying to load the module in the kernel.
|
||||
// If this fails, it does one final try running [checkProcConfig].
|
||||
// Probe loads the given kernel module and its dependencies.
|
||||
func Probe(moduleName string) error {
|
||||
modulesPath, err := getModulesPath()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrModulesDirectoryNotFound) {
|
||||
err = checkProcConfig(moduleName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking /proc/config.gz: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("getting modules path: %w", err)
|
||||
}
|
||||
|
||||
err = checkModulesBuiltin(modulesPath, moduleName)
|
||||
if err != nil {
|
||||
err = modProbe(modulesPath, moduleName)
|
||||
if err != nil {
|
||||
err = checkProcConfig(moduleName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking /proc/config.gz: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// modProbe is the classic modprobe behavior.
|
||||
func modProbe(modulesPath, moduleName string) error {
|
||||
modulesInfo, err := getModulesInfo(modulesPath)
|
||||
modulesInfo, err := getModulesInfo()
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting modules information: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mdlayher/netlink"
|
||||
"github.com/ti-mo/netfilter"
|
||||
)
|
||||
|
||||
func (n *NetLink) FlushConntrack() error {
|
||||
conn, err := netfilter.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netfilter: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
families := [...]netfilter.ProtoFamily{netfilter.ProtoIPv4, netfilter.ProtoIPv6}
|
||||
for _, family := range families {
|
||||
const IPCtnlMsgCtDelete = 2
|
||||
request, err := netfilter.MarshalNetlink(
|
||||
netfilter.Header{
|
||||
SubsystemID: netfilter.NFSubsysCTNetlink,
|
||||
MessageType: netfilter.MessageType(IPCtnlMsgCtDelete),
|
||||
Family: family,
|
||||
Flags: netlink.Request | netlink.Acknowledge,
|
||||
},
|
||||
nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoding netlink request: %w", err)
|
||||
}
|
||||
|
||||
_, err = conn.Query(request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("querying netlink request: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
package netlink
|
||||
|
||||
func (n *NetLink) FlushConntrack() error {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -18,7 +18,6 @@ type Route struct {
|
||||
Type uint8
|
||||
Scope uint8
|
||||
Proto uint8
|
||||
AdvMSS uint32
|
||||
}
|
||||
|
||||
func (r *Route) fromMessage(message rtnetlink.RouteMessage) {
|
||||
@@ -36,9 +35,6 @@ func (r *Route) fromMessage(message rtnetlink.RouteMessage) {
|
||||
r.Type = message.Type
|
||||
r.Scope = message.Scope
|
||||
r.Proto = message.Protocol
|
||||
if metrics := message.Attributes.Metrics; metrics != nil {
|
||||
r.AdvMSS = metrics.AdvMSS
|
||||
}
|
||||
}
|
||||
|
||||
func (r Route) message() *rtnetlink.RouteMessage {
|
||||
@@ -62,6 +58,7 @@ func (r Route) message() *rtnetlink.RouteMessage {
|
||||
Protocol: r.Proto,
|
||||
Attributes: rtnetlink.RouteAttributes{
|
||||
OutIface: r.LinkIndex,
|
||||
Dst: *dst, // there should always be a dst for routes
|
||||
Gateway: netipAddrToNetIP(r.Gw),
|
||||
Priority: r.Priority,
|
||||
Table: extendedTable,
|
||||
@@ -70,15 +67,6 @@ func (r Route) message() *rtnetlink.RouteMessage {
|
||||
if src != nil { // src is optional
|
||||
message.Attributes.Src = *src
|
||||
}
|
||||
if dst != nil {
|
||||
message.Attributes.Dst = *dst
|
||||
}
|
||||
if r.AdvMSS != 0 {
|
||||
if message.Attributes.Metrics == nil {
|
||||
message.Attributes.Metrics = &rtnetlink.RouteMetrics{}
|
||||
}
|
||||
message.Attributes.Metrics.AdvMSS = r.AdvMSS
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package icmp
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"net"
|
||||
@@ -1,4 +1,4 @@
|
||||
package icmp
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -9,17 +9,17 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
|
||||
ErrNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
|
||||
ErrICMPNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
|
||||
ErrICMPNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
|
||||
)
|
||||
|
||||
func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) {
|
||||
switch {
|
||||
case mtu < minMTU:
|
||||
return fmt.Errorf("%w: %d", ErrNextHopMTUTooLow, mtu)
|
||||
return fmt.Errorf("%w: %d", ErrICMPNextHopMTUTooLow, mtu)
|
||||
case mtu > physicalLinkMTU:
|
||||
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
|
||||
ErrNextHopMTUTooHigh, mtu, physicalLinkMTU)
|
||||
ErrICMPNextHopMTUTooHigh, mtu, physicalLinkMTU)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
@@ -34,13 +34,13 @@ func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
|
||||
}
|
||||
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
|
||||
return false, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
|
||||
}
|
||||
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
||||
return inboundBody.ID == outboundBody.ID, nil
|
||||
}
|
||||
|
||||
var ErrIDMismatch = errors.New("ICMP id mismatch")
|
||||
var ErrICMPIDMismatch = errors.New("ICMP id mismatch")
|
||||
|
||||
func checkEchoReply(icmpProtocol int, received []byte,
|
||||
outboundMessage *icmp.Message, truncatedBody bool,
|
||||
@@ -51,12 +51,12 @@ func checkEchoReply(icmpProtocol int, received []byte,
|
||||
}
|
||||
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
|
||||
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
|
||||
}
|
||||
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
||||
if inboundBody.ID != outboundBody.ID {
|
||||
return fmt.Errorf("%w: sent id %d and received id %d",
|
||||
ErrIDMismatch, outboundBody.ID, inboundBody.ID)
|
||||
ErrICMPIDMismatch, outboundBody.ID, inboundBody.ID)
|
||||
}
|
||||
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
|
||||
if err != nil {
|
||||
@@ -65,19 +65,19 @@ func checkEchoReply(icmpProtocol int, received []byte,
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrEchoDataMismatch = errors.New("ICMP data mismatch")
|
||||
var ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")
|
||||
|
||||
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
|
||||
if len(received) > len(sent) {
|
||||
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
|
||||
ErrEchoDataMismatch, len(sent), len(received))
|
||||
ErrICMPEchoDataMismatch, len(sent), len(received))
|
||||
}
|
||||
if receivedTruncated {
|
||||
sent = sent[:len(received)]
|
||||
}
|
||||
if !bytes.Equal(received, sent) {
|
||||
return fmt.Errorf("%w: sent %x and received %x",
|
||||
ErrEchoDataMismatch, sent, received)
|
||||
ErrICMPEchoDataMismatch, sent, received)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
package constants
|
||||
|
||||
const (
|
||||
MaxEthernetFrameSize uint32 = 1500
|
||||
// MinIPv4MTU is defined according to
|
||||
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
|
||||
MinIPv4MTU uint32 = 68
|
||||
MinIPv6MTU uint32 = 1280
|
||||
|
||||
IPv4HeaderLength uint32 = 20
|
||||
IPv6HeaderLength uint32 = 40
|
||||
UDPHeaderLength uint32 = 8
|
||||
// BaseTCPHeaderLength is the TCP header length without options,
|
||||
// which is the minimum TCP header length.
|
||||
BaseTCPHeaderLength uint32 = 20
|
||||
// MaxTCPHeaderLength is the TCP header length with the maximum options length of 40 bytes.
|
||||
// Note this is a hard maximum because of the 4-bit data offset field in the TCP header (15x4=60).
|
||||
MaxTCPHeaderLength uint32 = 60
|
||||
WireguardHeaderLength uint32 = 32
|
||||
OpenVPNHeaderMaxLength uint32 = 1 + // opcode
|
||||
8 + // session id
|
||||
4 + // packet id
|
||||
28 // max possible auth tag/iv
|
||||
)
|
||||
@@ -1,16 +0,0 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package constants
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
//nolint:revive
|
||||
const (
|
||||
SOCK_RAW = unix.SOCK_RAW
|
||||
SOCK_STREAM = unix.SOCK_STREAM
|
||||
AF_INET = unix.AF_INET
|
||||
AF_INET6 = unix.AF_INET6
|
||||
IPPROTO_TCP = unix.IPPROTO_TCP
|
||||
EAGAIN = unix.EAGAIN
|
||||
EWOULDBLOCK = unix.EWOULDBLOCK
|
||||
)
|
||||
@@ -1,13 +0,0 @@
|
||||
package constants
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
SOCK_RAW = windows.SOCK_RAW
|
||||
SOCK_STREAM = windows.SOCK_STREAM
|
||||
AF_INET = windows.AF_INET
|
||||
AF_INET6 = windows.AF_INET6
|
||||
IPPROTO_TCP = windows.IPPROTO_TCP
|
||||
EAGAIN = windows.WSAEWOULDBLOCK
|
||||
EWOULDBLOCK = windows.WSAEWOULDBLOCK
|
||||
)
|
||||
@@ -1,10 +1,10 @@
|
||||
//go:build !linux && !windows
|
||||
|
||||
package icmp
|
||||
package pmtud
|
||||
|
||||
// setDontFragment for platforms other than Linux and Windows
|
||||
// is not implemented, so we just return assuming the don't
|
||||
// fragment flag is set on IP packets.
|
||||
func setDontFragment(fd uintptr, ipv4 bool) (err error) {
|
||||
func setDontFragment(fd uintptr) (err error) {
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func setDontFragment(fd uintptr) (err error) {
|
||||
return syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP,
|
||||
syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_PROBE)
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
//go:build windows
|
||||
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func setDontFragment(fd uintptr) (err error) {
|
||||
// https://docs.microsoft.com/en-us/troubleshoot/windows/win32/header-library-requirement-socket-ipproto-ip
|
||||
// #define IP_DONTFRAGMENT 14 /* don't fragment IP datagrams */
|
||||
return syscall.SetsockoptInt(syscall.Handle(fd), syscall.IPPROTO_IP, 14, 1)
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrICMPNotPermitted = errors.New("ICMP not permitted")
|
||||
ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable")
|
||||
ErrICMPCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
|
||||
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
|
||||
)
|
||||
|
||||
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
|
||||
switch {
|
||||
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
|
||||
err = fmt.Errorf("%w", ErrICMPNotPermitted)
|
||||
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
|
||||
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
|
||||
case timedCtx.Err() != nil:
|
||||
err = timedCtx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func setDontFragment(fd uintptr, ipv4 bool) (err error) {
|
||||
if ipv4 {
|
||||
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IP,
|
||||
unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE)
|
||||
}
|
||||
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6,
|
||||
unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_PROBE)
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func setDontFragment(fd uintptr, ipv4 bool) (err error) {
|
||||
if ipv4 {
|
||||
// https://docs.microsoft.com/en-us/troubleshoot/windows/win32/header-library-requirement-socket-ipproto-ip
|
||||
// #define IP_DONTFRAGMENT 14 /* don't fragment IP datagrams */
|
||||
return windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, 14, 1)
|
||||
}
|
||||
return windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, 14, 1)
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotPermitted = errors.New("ICMP not permitted")
|
||||
ErrDestinationUnreachable = errors.New("ICMP destination unreachable")
|
||||
ErrCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
|
||||
ErrBodyUnsupported = errors.New("ICMP body type is not supported")
|
||||
ErrMTUNotFound = errors.New("MTU not found")
|
||||
errTimeout = errors.New("operation timed out")
|
||||
)
|
||||
|
||||
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
|
||||
switch {
|
||||
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
|
||||
err = fmt.Errorf("%w", ErrNotPermitted)
|
||||
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
|
||||
err = fmt.Errorf("%w: after %s", errTimeout, pingTimeout)
|
||||
case timedCtx.Err() != nil:
|
||||
err = timedCtx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
)
|
||||
|
||||
// PathMTUDiscover discovers the path MTU to the given IP address
|
||||
// using ICMP.
|
||||
// It first tries to get the next hop MTU using ICMP messages.
|
||||
// If that fails, it falls back to sending echo requests with
|
||||
// different packet sizes to find the maximum MTU.
|
||||
// The function returns [ErrMTUNotFound] if the MTU could not be determined.
|
||||
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
|
||||
physicalLinkMTU uint32, timeout time.Duration, logger Logger,
|
||||
) (mtu uint32, err error) {
|
||||
if ip.Is4() {
|
||||
logger.Debugf("finding IPv4 next hop MTU to %s", ip)
|
||||
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, timeout, logger)
|
||||
switch {
|
||||
case err == nil:
|
||||
return mtu, nil
|
||||
case errors.Is(err, errTimeout) || errors.Is(err, ErrCommunicationAdministrativelyProhibited): // blackhole
|
||||
default:
|
||||
return 0, fmt.Errorf("finding IPv4 next hop MTU to %s: %w", ip, err)
|
||||
}
|
||||
} else {
|
||||
logger.Debugf("requesting IPv6 ICMP packet-too-big reply from %s", ip)
|
||||
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, timeout, logger)
|
||||
switch {
|
||||
case err == nil:
|
||||
return mtu, nil
|
||||
case errors.Is(err, errTimeout): // blackhole
|
||||
default:
|
||||
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back method: send echo requests with different packet
|
||||
// sizes and check which ones succeed to find the maximum MTU.
|
||||
logger.Debugf("falling back to sending different sized echo packets to %s", ip)
|
||||
minMTU := constants.MinIPv4MTU
|
||||
if ip.Is6() {
|
||||
minMTU = constants.MinIPv6MTU
|
||||
}
|
||||
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, timeout, logger)
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package icmp
|
||||
|
||||
type Logger interface {
|
||||
Debug(msg string)
|
||||
Debugf(msg string, args ...any)
|
||||
Warnf(msg string, args ...any)
|
||||
}
|
||||
@@ -1,193 +0,0 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/test"
|
||||
"golang.org/x/net/icmp"
|
||||
)
|
||||
|
||||
type icmpTestUnit struct {
|
||||
mtu uint32
|
||||
echoID uint16
|
||||
sentBytes int
|
||||
ok bool
|
||||
}
|
||||
|
||||
func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
|
||||
minMTU, maxPossibleMTU uint32, pingTimeout time.Duration,
|
||||
logger Logger,
|
||||
) (maxMTU uint32, err error) {
|
||||
var ipVersion string
|
||||
var conn net.PacketConn
|
||||
if ip.Is4() {
|
||||
ipVersion = "v4"
|
||||
conn, err = listenICMPv4(ctx)
|
||||
} else {
|
||||
ipVersion = "v6"
|
||||
conn, err = listenICMPv6(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
|
||||
}
|
||||
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
|
||||
}
|
||||
|
||||
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
|
||||
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
|
||||
return minMTU, nil
|
||||
}
|
||||
logger.Debugf("ICMP testing the following MTUs: %v", mtusToTest)
|
||||
|
||||
tests := make([]icmpTestUnit, len(mtusToTest))
|
||||
for i := range mtusToTest {
|
||||
tests[i] = icmpTestUnit{mtu: mtusToTest[i]}
|
||||
}
|
||||
|
||||
timedCtx, cancel := context.WithTimeout(ctx, pingTimeout)
|
||||
defer cancel()
|
||||
go func() {
|
||||
<-timedCtx.Done()
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
for i := range tests {
|
||||
id, message := buildMessageToSend(ipVersion, tests[i].mtu)
|
||||
tests[i].echoID = id
|
||||
|
||||
encodedMessage, err := message.Marshal(nil)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("encoding ICMP message: %w", err)
|
||||
}
|
||||
tests[i].sentBytes = len(encodedMessage)
|
||||
|
||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
|
||||
err = fmt.Errorf("%w", ErrNotPermitted)
|
||||
}
|
||||
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = collectReplies(conn, ipVersion, tests, logger)
|
||||
switch {
|
||||
case err == nil: // max possible MTU is working
|
||||
return tests[len(tests)-1].mtu, nil
|
||||
case err != nil && errors.Is(err, net.ErrClosed):
|
||||
// we have timeouts (IPv4 testing or IPv6 PMTUD blackholes)
|
||||
// so find the highest MTU which worked.
|
||||
// Note we start from index len(tests) - 2 since the max MTU
|
||||
// cannot be working if we had a timeout.
|
||||
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
|
||||
if tests[i].ok {
|
||||
return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1,
|
||||
pingTimeout, logger)
|
||||
}
|
||||
}
|
||||
|
||||
// All MTUs failed.
|
||||
return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound)
|
||||
case err != nil:
|
||||
return 0, fmt.Errorf("collecting ICMP echo replies: %w", err)
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would
|
||||
// create huge buffers which we don't really want to support anyway.
|
||||
// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with
|
||||
// a conventional maximum of 9000 bytes. However, some manufacturers support up
|
||||
// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to
|
||||
// match eventual Jumbo frames. More information at:
|
||||
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
|
||||
const maxPossibleMTU = 9196
|
||||
|
||||
func collectReplies(conn net.PacketConn, ipVersion string,
|
||||
tests []icmpTestUnit, logger Logger,
|
||||
) (err error) {
|
||||
echoIDToTestIndex := make(map[uint16]int, len(tests))
|
||||
for i, test := range tests {
|
||||
echoIDToTestIndex[test.echoID] = i
|
||||
}
|
||||
|
||||
buffer := make([]byte, maxPossibleMTU)
|
||||
|
||||
idsFound := 0
|
||||
for idsFound < len(tests) {
|
||||
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||
// must be large enough to read the entire reply packet. See:
|
||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading from ICMP connection: %w", err)
|
||||
}
|
||||
packetBytes := buffer[:bytesRead]
|
||||
|
||||
ipPacketLength := len(packetBytes)
|
||||
|
||||
var icmpProtocol int
|
||||
switch ipVersion {
|
||||
case "v4":
|
||||
icmpProtocol = icmpv4Protocol
|
||||
case "v6":
|
||||
icmpProtocol = icmpv6Protocol
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
|
||||
}
|
||||
|
||||
// Parse the ICMP message
|
||||
// Note: this parsing works for a truncated 556 bytes ICMP reply packet.
|
||||
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing message: %w", err)
|
||||
}
|
||||
|
||||
switch message.Body.(type) {
|
||||
case *icmp.Echo:
|
||||
case *icmp.DstUnreach, *icmp.TimeExceeded:
|
||||
logger.Debugf("ignoring ICMP message (type: %d, code: %d)", message.Type, message.Code)
|
||||
continue
|
||||
default:
|
||||
return fmt.Errorf("%w: %T", ErrBodyUnsupported, message.Body)
|
||||
}
|
||||
|
||||
echoBody, _ := message.Body.(*icmp.Echo)
|
||||
|
||||
id := uint16(echoBody.ID) //nolint:gosec
|
||||
testIndex, testing := echoIDToTestIndex[id]
|
||||
if !testing { // not an id we expected so ignore it
|
||||
logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)",
|
||||
echoBody.ID, message.Type, message.Code, ipPacketLength)
|
||||
continue
|
||||
}
|
||||
idsFound++
|
||||
sentBytes := tests[testIndex].sentBytes
|
||||
|
||||
// echo reply should be at most the number of bytes sent,
|
||||
// and can be lower, more precisely 556 bytes, in case
|
||||
// the host we are reaching wants to stay out of trouble
|
||||
// and ensure its echo reply goes through without
|
||||
// fragmentation, see the following page:
|
||||
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
|
||||
const conservativeReplyLength = 556
|
||||
truncated := ipPacketLength < sentBytes &&
|
||||
ipPacketLength == conservativeReplyLength
|
||||
// Check the packet size is the same if the reply is not truncated
|
||||
if !truncated && sentBytes != ipPacketLength {
|
||||
return fmt.Errorf("%w: sent %dB and received %dB",
|
||||
ErrEchoDataMismatch, sentBytes, ipPacketLength)
|
||||
}
|
||||
// Truncated reply or matching reply size
|
||||
tests[testIndex].ok = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
package ip
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
)
|
||||
|
||||
func GetFamilies(dsts []netip.AddrPort) (families []int) {
|
||||
const maxFamilies = 2
|
||||
families = make([]int, 0, maxFamilies)
|
||||
for _, dst := range dsts {
|
||||
family := GetFamily(dst)
|
||||
if !slices.Contains(families, family) {
|
||||
families = append(families, family)
|
||||
}
|
||||
}
|
||||
return families
|
||||
}
|
||||
|
||||
func GetFamily(dst netip.AddrPort) int {
|
||||
if dst.Addr().Is4() {
|
||||
return constants.AF_INET
|
||||
}
|
||||
return constants.AF_INET6
|
||||
}
|
||||
@@ -1,79 +0,0 @@
|
||||
package ip
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
)
|
||||
|
||||
func HeaderLength(ipv4 bool) uint32 {
|
||||
if ipv4 {
|
||||
return constants.IPv4HeaderLength
|
||||
}
|
||||
return constants.IPv6HeaderLength
|
||||
}
|
||||
|
||||
func HeaderV4(srcIP, dstIP netip.Addr, payloadLength uint32) []byte {
|
||||
ipHeader := make([]byte, constants.IPv4HeaderLength)
|
||||
const version byte = 4
|
||||
const headerLength byte = 20 / 4 // in 32-bit words
|
||||
ipHeader[0] = (version << 4) | headerLength //nolint:mnd
|
||||
ipHeader[1] = 0 // type of Service
|
||||
putUint16(ipHeader[2:], uint16(constants.IPv4HeaderLength+payloadLength)) //nolint:gosec
|
||||
ipHeader[4], ipHeader[5] = 0, 0 // identification
|
||||
const flagsAndOffset uint16 = 0x4000 // DF bit set
|
||||
putUint16(ipHeader[6:], flagsAndOffset)
|
||||
ipHeader[8] = 64 // ttl
|
||||
ipHeader[9] = constants.IPPROTO_TCP
|
||||
srcIPBytes := srcIP.As4()
|
||||
copy(ipHeader[12:16], srcIPBytes[:])
|
||||
dstIPBytes := dstIP.As4()
|
||||
copy(ipHeader[16:20], dstIPBytes[:])
|
||||
|
||||
checksum := ipChecksum(ipHeader)
|
||||
ipHeader[10] = byte(checksum >> 8) //nolint:mnd
|
||||
ipHeader[11] = byte(checksum & 0xff) //nolint:mnd
|
||||
|
||||
return ipHeader
|
||||
}
|
||||
|
||||
// ipChecksum calculates the checksum for the IP header.
|
||||
//
|
||||
//nolint:mnd
|
||||
func ipChecksum(header []byte) uint16 {
|
||||
sum := uint32(0)
|
||||
for i := 0; i < len(header)-1; i += 2 {
|
||||
sum += uint32(header[i])<<8 + uint32(header[i+1])
|
||||
}
|
||||
if len(header)%2 != 0 {
|
||||
sum += uint32(header[len(header)-1]) << 8
|
||||
}
|
||||
for (sum >> 16) > 0 {
|
||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||
}
|
||||
return ^uint16(sum) //nolint:gosec
|
||||
}
|
||||
|
||||
// HeaderV6 makes an IPv6 header.
|
||||
// payloadLen is the length of the payload following the header.
|
||||
// nextHeader can be byte([constants.IPPROTO_TCP]) for example.
|
||||
func HeaderV6(srcIP, dstIP netip.Addr,
|
||||
payloadLen uint16, nextHeader byte,
|
||||
) []byte {
|
||||
ipv6Header := make([]byte, constants.IPv6HeaderLength)
|
||||
ipv6Header[0] = 0x60 // version (4 bits) | traffic Class (4 bits)
|
||||
ipv6Header[1] = 0x00 // traffic Class (4 bits) | flow label (4 bits)
|
||||
|
||||
// Flow Label (remaining 16 bits)
|
||||
ipv6Header[2] = 0x00
|
||||
ipv6Header[3] = 0x00
|
||||
|
||||
binary.BigEndian.PutUint16(ipv6Header[4:], payloadLen)
|
||||
ipv6Header[6] = nextHeader
|
||||
const hopLimit = 64
|
||||
ipv6Header[7] = hopLimit
|
||||
copy(ipv6Header[8:24], srcIP.AsSlice())
|
||||
copy(ipv6Header[24:40], dstIP.AsSlice())
|
||||
return ipv6Header
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
package ip
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
func putUint16(b []byte, v uint16) {
|
||||
binary.NativeEndian.PutUint16(b, v)
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
//go:build !darwin
|
||||
|
||||
package ip
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
func putUint16(b []byte, v uint16) {
|
||||
binary.BigEndian.PutUint16(b, v)
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package ip
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
func SetIPv4HeaderIncluded(fd int) error {
|
||||
return unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_HDRINCL, 1)
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build !linux && !windows && !darwin
|
||||
|
||||
package ip
|
||||
|
||||
func SetIPv4HeaderIncluded(fd int) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
package ip
|
||||
|
||||
import (
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func SetIPv4HeaderIncluded(handle windows.Handle) error {
|
||||
const ipHdrIncluded = windows.IP_HDRINCL
|
||||
return windows.SetsockoptInt(handle, windows.IPPROTO_IP, ipHdrIncluded, 1)
|
||||
}
|
||||
@@ -1,113 +0,0 @@
|
||||
package ip
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/jsimonetti/rtnetlink"
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
)
|
||||
|
||||
// SrcAddr determines the appropriate source IP address to use when sending a packet to the
|
||||
// specified destination. It also reserves an ephemeral source port for the specified protocol
|
||||
// to ensure that the port is not used by other processes. The cleanup function returned should
|
||||
// be called to release the reserved port when done.
|
||||
func SrcAddr(dst netip.AddrPort, proto int) (src netip.AddrPort, cleanup func(), err error) {
|
||||
srcAddr, err := srcIP(dst.Addr())
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, nil, fmt.Errorf("finding source IP: %w", err)
|
||||
}
|
||||
|
||||
srcPort, cleanup, err := srcPort(srcAddr, proto)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, nil, fmt.Errorf("reserving source port: %w", err)
|
||||
}
|
||||
|
||||
return netip.AddrPortFrom(srcAddr, srcPort), cleanup, nil
|
||||
}
|
||||
|
||||
var (
|
||||
errNoRoute = fmt.Errorf("no route to destination")
|
||||
ErrNetworkUnreachable = errors.New("network unreachable")
|
||||
)
|
||||
|
||||
func srcIP(dst netip.Addr) (netip.Addr, error) {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return netip.Addr{}, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
family := uint8(constants.AF_INET)
|
||||
if dst.Is6() {
|
||||
family = constants.AF_INET6
|
||||
}
|
||||
|
||||
// Request route to destination
|
||||
requestMessage := &rtnetlink.RouteMessage{
|
||||
Family: family,
|
||||
Attributes: rtnetlink.RouteAttributes{
|
||||
Dst: dst.AsSlice(),
|
||||
},
|
||||
}
|
||||
messages, err := conn.Route.Get(requestMessage)
|
||||
if err != nil {
|
||||
var sysErr syscall.Errno
|
||||
if errors.As(err, &sysErr) && sysErr == syscall.ENETUNREACH {
|
||||
err = ErrNetworkUnreachable
|
||||
}
|
||||
return netip.Addr{}, fmt.Errorf("getting routes to %s: %w", dst, err)
|
||||
}
|
||||
|
||||
for _, message := range messages {
|
||||
if message.Attributes.Src == nil {
|
||||
continue
|
||||
}
|
||||
ipv6 := message.Attributes.Src.To4() == nil
|
||||
if ipv6 {
|
||||
return netip.AddrFrom16([16]byte(message.Attributes.Src)), nil
|
||||
}
|
||||
return netip.AddrFrom4([4]byte(message.Attributes.Src)), nil
|
||||
}
|
||||
|
||||
return netip.Addr{}, fmt.Errorf("%w: in %d route(s)", errNoRoute, len(messages))
|
||||
}
|
||||
|
||||
// srcPort reserves an ephemeral source port by opening a socket for the
|
||||
// protocol specified and binds it to the provided source address.
|
||||
// It doesn't actually listen on the port.
|
||||
// The cleanup function returned should be called to release the port when done.
|
||||
func srcPort(srcAddr netip.Addr, proto int) (srcPort uint16, cleanup func(), err error) {
|
||||
family := constants.AF_INET
|
||||
if srcAddr.Is6() {
|
||||
family = constants.AF_INET6
|
||||
}
|
||||
|
||||
fd, err := socket(family, constants.SOCK_STREAM, proto)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("creating reservation socket: %w", err)
|
||||
}
|
||||
cleanup = func() {
|
||||
_ = closeSocket(fd)
|
||||
}
|
||||
|
||||
// Bind to port 0 to get an ephemeral port
|
||||
const port = 0
|
||||
bindAddr := makeSockAddr(srcAddr, port)
|
||||
|
||||
err = bind(fd, bindAddr)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return 0, nil, fmt.Errorf("binding reservation socket: %w", err)
|
||||
}
|
||||
|
||||
srcPort, err = extractPortFromFD(fd)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return 0, nil, fmt.Errorf("extracting port from socket fd: %w", err)
|
||||
}
|
||||
|
||||
return srcPort, cleanup, nil
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package ip
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func socket(domain int, typ int, proto int) (fd int, err error) {
|
||||
return unix.Socket(domain, typ, proto)
|
||||
}
|
||||
|
||||
func closeSocket(fd int) error {
|
||||
return unix.Close(fd)
|
||||
}
|
||||
|
||||
func bind(fd int, addr unix.Sockaddr) error {
|
||||
return unix.Bind(fd, addr)
|
||||
}
|
||||
|
||||
func makeSockAddr(ip netip.Addr, port uint16) unix.Sockaddr {
|
||||
if ip.Is4() {
|
||||
return &unix.SockaddrInet4{
|
||||
Port: int(port),
|
||||
Addr: ip.As4(),
|
||||
}
|
||||
}
|
||||
return &unix.SockaddrInet6{
|
||||
Port: 0,
|
||||
Addr: ip.As16(),
|
||||
}
|
||||
}
|
||||
|
||||
func extractPortFromFD(fd int) (uint16, error) {
|
||||
sockAddr, err := unix.Getsockname(fd)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("getting sockname: %w", err)
|
||||
}
|
||||
|
||||
switch typedSockAddr := sockAddr.(type) {
|
||||
case *unix.SockaddrInet4:
|
||||
return uint16(typedSockAddr.Port), nil //nolint:gosec
|
||||
case *unix.SockaddrInet6:
|
||||
return uint16(typedSockAddr.Port), nil //nolint:gosec
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected sockaddr type: %T", typedSockAddr))
|
||||
}
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
package ip
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func socket(domain int, typ int, proto int) (fd windows.Handle, err error) {
|
||||
return windows.Socket(domain, typ, proto)
|
||||
}
|
||||
|
||||
func closeSocket(fd windows.Handle) error {
|
||||
return windows.Close(fd)
|
||||
}
|
||||
|
||||
func bind(fd windows.Handle, addr windows.Sockaddr) error {
|
||||
return windows.Bind(fd, addr)
|
||||
}
|
||||
|
||||
func makeSockAddr(ip netip.Addr, port uint16) windows.Sockaddr {
|
||||
if ip.Is4() {
|
||||
return &windows.SockaddrInet4{
|
||||
Port: int(port),
|
||||
Addr: ip.As4(),
|
||||
}
|
||||
}
|
||||
return &windows.SockaddrInet6{
|
||||
Port: int(port),
|
||||
Addr: ip.As16(),
|
||||
}
|
||||
}
|
||||
|
||||
func extractPortFromFD(fd windows.Handle) (uint16, error) {
|
||||
sockAddr, err := windows.Getsockname(fd)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("getting sockname: %w", err)
|
||||
}
|
||||
|
||||
switch typedSockAddr := sockAddr.(type) {
|
||||
case *windows.SockaddrInet4:
|
||||
return uint16(typedSockAddr.Port), nil //nolint:gosec
|
||||
case *windows.SockaddrInet6:
|
||||
return uint16(typedSockAddr.Port), nil //nolint:gosec
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected sockaddr type: %T", typedSockAddr))
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package icmp
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -11,13 +11,14 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
const (
|
||||
icmpv4Protocol = 1
|
||||
// see https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
|
||||
minIPv4MTU uint32 = 68
|
||||
icmpv4Protocol int = 1
|
||||
)
|
||||
|
||||
func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
|
||||
@@ -25,8 +26,7 @@ func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
|
||||
listenConfig.Control = func(_, _ string, rawConn syscall.RawConn) error {
|
||||
var setDFErr error
|
||||
err := rawConn.Control(func(fd uintptr) {
|
||||
const ipv4 = true
|
||||
setDFErr = setDontFragment(fd, ipv4) // runs when calling ListenPacket
|
||||
setDFErr = setDontFragment(fd) // runs when calling ListenPacket
|
||||
})
|
||||
if err == nil {
|
||||
err = setDFErr
|
||||
@@ -38,7 +38,7 @@ func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
|
||||
packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress)
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@@ -83,9 +83,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
||||
|
||||
buffer := make([]byte, physicalLinkMTU)
|
||||
|
||||
// for loop in case we read an ICMP message from another ICMP request
|
||||
// or TCP/UDP traffic triggering an ICMP response.
|
||||
for {
|
||||
for { // for loop in case we read an echo reply for another ICMP request
|
||||
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||
// must be large enough to read the entire reply packet. See:
|
||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||
@@ -110,27 +108,24 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
||||
switch typedBody := inboundMessage.Body.(type) {
|
||||
case *icmp.DstUnreach:
|
||||
const fragmentationRequiredAndDFFlagSetCode = 4
|
||||
const portUnreachable = 3
|
||||
const communicationAdministrativelyProhibitedCode = 13
|
||||
switch inboundMessage.Code {
|
||||
case fragmentationRequiredAndDFFlagSetCode:
|
||||
case portUnreachable: // triggered by TCP or UDP from applications
|
||||
continue // ignore and wait for the next message
|
||||
case communicationAdministrativelyProhibitedCode:
|
||||
return 0, fmt.Errorf("%w: %w (code %d)",
|
||||
ErrDestinationUnreachable,
|
||||
ErrCommunicationAdministrativelyProhibited,
|
||||
ErrICMPDestinationUnreachable,
|
||||
ErrICMPCommunicationAdministrativelyProhibited,
|
||||
inboundMessage.Code)
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: code %d",
|
||||
ErrDestinationUnreachable, inboundMessage.Code)
|
||||
ErrICMPDestinationUnreachable, inboundMessage.Code)
|
||||
}
|
||||
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
|
||||
// Note: the go library does not handle this NextHopMTU section.
|
||||
nextHopMTU := packetBytes[6:8]
|
||||
mtu = uint32(binary.BigEndian.Uint16(nextHopMTU))
|
||||
err = checkMTU(mtu, constants.MinIPv4MTU, physicalLinkMTU)
|
||||
err = checkMTU(mtu, minIPv4MTU, physicalLinkMTU)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking next-hop-mtu found: %w", err)
|
||||
}
|
||||
@@ -158,7 +153,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
||||
inboundID, outboundID)
|
||||
continue
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
|
||||
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package icmp
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -6,36 +6,24 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
const (
|
||||
minIPv6MTU = 1280
|
||||
icmpv6Protocol = 58
|
||||
)
|
||||
|
||||
func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) {
|
||||
var listenConfig net.ListenConfig
|
||||
listenConfig.Control = func(_, _ string, rawConn syscall.RawConn) error {
|
||||
var setDFErr error
|
||||
err := rawConn.Control(func(fd uintptr) {
|
||||
const ipv4 = false
|
||||
setDFErr = setDontFragment(fd, ipv4) // runs when calling ListenPacket
|
||||
})
|
||||
if err == nil {
|
||||
err = setDFErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
const listenAddress = ""
|
||||
packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress)
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@@ -97,7 +85,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
||||
case *icmp.PacketTooBig:
|
||||
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.2
|
||||
mtu = uint32(typedBody.MTU) //nolint:gosec
|
||||
err = checkMTU(mtu, constants.MinIPv6MTU, physicalLinkMTU)
|
||||
err = checkMTU(mtu, minIPv6MTU, physicalLinkMTU)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking MTU: %w", err)
|
||||
}
|
||||
@@ -115,7 +103,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking invoking message id: %w", err)
|
||||
} else if idMatch {
|
||||
return 0, fmt.Errorf("%w", ErrDestinationUnreachable)
|
||||
return 0, fmt.Errorf("%w", ErrICMPDestinationUnreachable)
|
||||
}
|
||||
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
|
||||
continue
|
||||
@@ -128,7 +116,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
||||
inboundID, outboundID)
|
||||
continue
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
|
||||
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package icmp
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
cryptorand "crypto/rand"
|
||||
+230
-50
@@ -4,88 +4,268 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall/iptables"
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
"github.com/qdm12/gluetun/internal/pmtud/icmp"
|
||||
"github.com/qdm12/gluetun/internal/pmtud/tcp"
|
||||
"golang.org/x/net/icmp"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrICMPOkTCPFail = errors.New("PMTUD succeeded with ICMP but failed with TCP")
|
||||
ErrICMPFailTCPFail = errors.New("PMTUD failed with both ICMP and TCP")
|
||||
)
|
||||
var ErrMTUNotFound = errors.New("path MTU discovery failed to find MTU")
|
||||
|
||||
// PathMTUDiscover discovers the maximum MTU using both ICMP and TCP.
|
||||
// Multiple ICMP addresses and TCP addresses can be specified for redundancy.
|
||||
// ICMP PMTUD is run first. If successful, the range of possible MTU values to
|
||||
// check for TCP PMTUD is reduced to [maxMTU-150, maxMTU] where maxMTU is the
|
||||
// maximum MTU found with ICMP PMTUD. Otherwise, TCP PMTUD is run with the
|
||||
// whole range of possible MTU values up to the physical link MTU to check.
|
||||
// PathMTUDiscover discovers the maximum MTU for the path to the given ip address.
|
||||
// If the physicalLinkMTU is zero, it defaults to 1500 which is the ethernet standard MTU.
|
||||
// If the pingTimeout is zero, it defaults to 1 second.
|
||||
// If the logger is nil, a no-op logger is used.
|
||||
// It returns [ErrMTUNotFound] if the MTU could not be determined.
|
||||
func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort,
|
||||
physicalLinkMTU uint32, tryTimeout time.Duration, fw tcp.Firewall, logger Logger) (
|
||||
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
|
||||
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger) (
|
||||
mtu uint32, err error,
|
||||
) {
|
||||
if physicalLinkMTU == 0 {
|
||||
const ethernetStandardMTU = 1500
|
||||
physicalLinkMTU = ethernetStandardMTU
|
||||
}
|
||||
if tryTimeout == 0 {
|
||||
tryTimeout = time.Second
|
||||
if pingTimeout == 0 {
|
||||
pingTimeout = time.Second
|
||||
}
|
||||
if logger == nil {
|
||||
logger = &noopLogger{}
|
||||
}
|
||||
|
||||
// Try finding the MTU using ICMP
|
||||
maxPossibleMTU := physicalLinkMTU
|
||||
icmpSuccess := false
|
||||
for _, icmpIP := range icmpAddrs {
|
||||
mtu, err := icmp.PathMTUDiscover(ctx, icmpIP, physicalLinkMTU,
|
||||
tryTimeout, logger)
|
||||
if ip.Is4() {
|
||||
logger.Debug("finding IPv4 next hop MTU")
|
||||
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, pingTimeout, logger)
|
||||
switch {
|
||||
case err == nil:
|
||||
logger.Debugf("ICMP path MTU discovery against %s found maximum valid MTU %d", icmpIP, mtu)
|
||||
icmpSuccess = true
|
||||
maxPossibleMTU = mtu
|
||||
case errors.Is(err, icmp.ErrNotPermitted), errors.Is(err, icmp.ErrMTUNotFound):
|
||||
logger.Debugf("ICMP path MTU discovery failed: %s", err)
|
||||
return mtu, nil
|
||||
case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrICMPCommunicationAdministrativelyProhibited): // blackhole
|
||||
default:
|
||||
return 0, fmt.Errorf("ICMP path MTU discovery: %w", err)
|
||||
return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err)
|
||||
}
|
||||
if icmpSuccess {
|
||||
break
|
||||
} else {
|
||||
logger.Debug("requesting IPv6 ICMP packet-too-big reply")
|
||||
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, pingTimeout, logger)
|
||||
switch {
|
||||
case err == nil:
|
||||
return mtu, nil
|
||||
case errors.Is(err, net.ErrClosed): // blackhole
|
||||
default:
|
||||
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
minMTU := constants.MinIPv4MTU
|
||||
if tcpAddrs[0].Addr().Is6() {
|
||||
minMTU = constants.MinIPv6MTU
|
||||
// Fall back method: send echo requests with different packet
|
||||
// sizes and check which ones succeed to find the maximum MTU.
|
||||
logger.Debug("falling back to sending different sized echo packets")
|
||||
minMTU := minIPv4MTU
|
||||
if ip.Is6() {
|
||||
minMTU = minIPv6MTU
|
||||
}
|
||||
if icmpSuccess {
|
||||
const mtuMargin = 150
|
||||
minMTU = max(maxPossibleMTU-mtuMargin, minMTU)
|
||||
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, pingTimeout, logger)
|
||||
}
|
||||
|
||||
type pmtudTestUnit struct {
|
||||
mtu uint32
|
||||
echoID uint16
|
||||
sentBytes int
|
||||
ok bool
|
||||
}
|
||||
|
||||
func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
|
||||
minMTU, maxPossibleMTU uint32, pingTimeout time.Duration,
|
||||
logger Logger,
|
||||
) (maxMTU uint32, err error) {
|
||||
var ipVersion string
|
||||
var conn net.PacketConn
|
||||
if ip.Is4() {
|
||||
ipVersion = "v4"
|
||||
conn, err = listenICMPv4(ctx)
|
||||
} else {
|
||||
ipVersion = "v6"
|
||||
conn, err = listenICMPv6(ctx)
|
||||
}
|
||||
mtu, err = tcp.PathMTUDiscover(ctx, tcpAddrs, minMTU, maxPossibleMTU, tryTimeout, fw, logger)
|
||||
if err != nil {
|
||||
if errors.Is(err, iptables.ErrMarkMatchModuleMissing) {
|
||||
logger.Debugf("aborting TCP path MTU discovery: %s", err)
|
||||
if icmpSuccess {
|
||||
return maxPossibleMTU, nil // only rely on ICMP PMTUD results
|
||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
|
||||
}
|
||||
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
|
||||
}
|
||||
|
||||
mtusToTest := makeMTUsToTest(minMTU, maxPossibleMTU)
|
||||
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
|
||||
return minMTU, nil
|
||||
}
|
||||
logger.Debugf("testing the following MTUs: %v", mtusToTest)
|
||||
|
||||
tests := make([]pmtudTestUnit, len(mtusToTest))
|
||||
for i := range mtusToTest {
|
||||
tests[i] = pmtudTestUnit{mtu: mtusToTest[i]}
|
||||
}
|
||||
|
||||
timedCtx, cancel := context.WithTimeout(ctx, pingTimeout)
|
||||
defer cancel()
|
||||
go func() {
|
||||
<-timedCtx.Done()
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
for i := range tests {
|
||||
id, message := buildMessageToSend(ipVersion, tests[i].mtu)
|
||||
tests[i].echoID = id
|
||||
|
||||
encodedMessage, err := message.Marshal(nil)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("encoding ICMP message: %w", err)
|
||||
}
|
||||
tests[i].sentBytes = len(encodedMessage)
|
||||
|
||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
|
||||
err = fmt.Errorf("%w", ErrICMPNotPermitted)
|
||||
}
|
||||
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = collectReplies(conn, ipVersion, tests, logger)
|
||||
switch {
|
||||
case err == nil: // max possible MTU is working
|
||||
return tests[len(tests)-1].mtu, nil
|
||||
case err != nil && errors.Is(err, net.ErrClosed):
|
||||
// we have timeouts (IPv4 testing or IPv6 PMTUD blackholes)
|
||||
// so find the highest MTU which worked.
|
||||
// Note we start from index len(tests) - 2 since the max MTU
|
||||
// cannot be working if we had a timeout.
|
||||
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
|
||||
if tests[i].ok {
|
||||
return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1,
|
||||
pingTimeout, logger)
|
||||
}
|
||||
}
|
||||
if icmpSuccess {
|
||||
return 0, fmt.Errorf("%w - discarding ICMP obtained MTU %d",
|
||||
ErrICMPOkTCPFail, maxPossibleMTU)
|
||||
}
|
||||
return 0, fmt.Errorf("%w", ErrICMPFailTCPFail)
|
||||
|
||||
// All MTUs failed.
|
||||
return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound)
|
||||
case err != nil:
|
||||
return 0, fmt.Errorf("collecting ICMP echo replies: %w", err)
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
logger.Debugf("TCP path MTU discovery found maximum valid MTU %d", mtu)
|
||||
return mtu, nil
|
||||
}
|
||||
|
||||
// Create the MTU slice of length 11 such that:
|
||||
// - the first element is the minMTU
|
||||
// - the last element is the maxMTU
|
||||
// - elements in-between are separated as close to each other
|
||||
// The number 11 is chosen to find the final MTU in 3 searches,
|
||||
// with a total search space of 1728 MTUs which is enough;
|
||||
// to find it in 2 searches requires 37 parallel queries which
|
||||
// could be blocked by firewalls.
|
||||
func makeMTUsToTest(minMTU, maxMTU uint32) (mtus []uint32) {
|
||||
const mtusLength = 11 // find the final MTU in 3 searches
|
||||
diff := maxMTU - minMTU
|
||||
switch {
|
||||
case minMTU > maxMTU:
|
||||
panic("minMTU > maxMTU")
|
||||
case diff <= mtusLength:
|
||||
mtus = make([]uint32, 0, diff)
|
||||
for mtu := minMTU; mtu <= maxMTU; mtu++ {
|
||||
mtus = append(mtus, mtu)
|
||||
}
|
||||
default:
|
||||
step := float64(diff) / float64(mtusLength-1)
|
||||
mtus = make([]uint32, 0, mtusLength)
|
||||
for mtu := float64(minMTU); len(mtus) < mtusLength-1; mtu += step {
|
||||
mtus = append(mtus, uint32(math.Round(mtu)))
|
||||
}
|
||||
mtus = append(mtus, maxMTU) // last element is the maxMTU
|
||||
}
|
||||
|
||||
return mtus
|
||||
}
|
||||
|
||||
func collectReplies(conn net.PacketConn, ipVersion string,
|
||||
tests []pmtudTestUnit, logger Logger,
|
||||
) (err error) {
|
||||
echoIDToTestIndex := make(map[uint16]int, len(tests))
|
||||
for i, test := range tests {
|
||||
echoIDToTestIndex[test.echoID] = i
|
||||
}
|
||||
|
||||
// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would
|
||||
// create huge buffers which we don't really want to support anyway.
|
||||
// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with
|
||||
// a conventional maximum of 9000 bytes. However, some manufacturers support up
|
||||
// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to
|
||||
// match eventual Jumbo frames. More information at:
|
||||
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
|
||||
const maxPossibleMTU = 9196
|
||||
buffer := make([]byte, maxPossibleMTU)
|
||||
|
||||
idsFound := 0
|
||||
for idsFound < len(tests) {
|
||||
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||
// must be large enough to read the entire reply packet. See:
|
||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading from ICMP connection: %w", err)
|
||||
}
|
||||
packetBytes := buffer[:bytesRead]
|
||||
|
||||
ipPacketLength := len(packetBytes)
|
||||
|
||||
var icmpProtocol int
|
||||
switch ipVersion {
|
||||
case "v4":
|
||||
icmpProtocol = icmpv4Protocol
|
||||
case "v6":
|
||||
icmpProtocol = icmpv6Protocol
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
|
||||
}
|
||||
|
||||
// Parse the ICMP message
|
||||
// Note: this parsing works for a truncated 556 bytes ICMP reply packet.
|
||||
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing message: %w", err)
|
||||
}
|
||||
|
||||
echoBody, ok := message.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, message.Body)
|
||||
}
|
||||
|
||||
id := uint16(echoBody.ID) //nolint:gosec
|
||||
testIndex, testing := echoIDToTestIndex[id]
|
||||
if !testing { // not an id we expected so ignore it
|
||||
logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)",
|
||||
echoBody.ID, message.Type, message.Code, ipPacketLength)
|
||||
continue
|
||||
}
|
||||
idsFound++
|
||||
sentBytes := tests[testIndex].sentBytes
|
||||
|
||||
// echo reply should be at most the number of bytes sent,
|
||||
// and can be lower, more precisely 556 bytes, in case
|
||||
// the host we are reaching wants to stay out of trouble
|
||||
// and ensure its echo reply goes through without
|
||||
// fragmentation, see the following page:
|
||||
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
|
||||
const conservativeReplyLength = 556
|
||||
truncated := ipPacketLength < sentBytes &&
|
||||
ipPacketLength == conservativeReplyLength
|
||||
// Check the packet size is the same if the reply is not truncated
|
||||
if !truncated && sentBytes != ipPacketLength {
|
||||
return fmt.Errorf("%w: sent %dB and received %dB",
|
||||
ErrICMPEchoDataMismatch, sentBytes, ipPacketLength)
|
||||
}
|
||||
// Truncated reply or matching reply size
|
||||
tests[testIndex].ok = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package test
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_MakeMTUsToTest(t *testing.T) {
|
||||
func Test_makeMTUsToTest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
@@ -48,7 +48,7 @@ func Test_MakeMTUsToTest(t *testing.T) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mtus := MakeMTUsToTest(testCase.minMTU, testCase.maxMTU)
|
||||
mtus := makeMTUsToTest(testCase.minMTU, testCase.maxMTU)
|
||||
assert.Equal(t, testCase.mtus, mtus)
|
||||
})
|
||||
}
|
||||
@@ -1,142 +0,0 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/command"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/firewall/iptables"
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
"github.com/qdm12/gluetun/internal/routing"
|
||||
"github.com/qdm12/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// testFirewall must be global to prevent parallel tests from interfering
|
||||
// with each other since they would interact with the same filter table.
|
||||
// The first test to use should initialize it, and the rest will reuse it.
|
||||
var (
|
||||
testFirewall *firewall.Config //nolint:gochecknoglobals
|
||||
testFirewallOnce sync.Once //nolint:gochecknoglobals
|
||||
)
|
||||
|
||||
// getFirewall returns a Firewall instance, initializing it if needed. If
|
||||
// iptables is not supported, it skips the test.
|
||||
func getFirewall(t *testing.T) *firewall.Config {
|
||||
t.Helper()
|
||||
|
||||
testFirewallOnce.Do(func() {
|
||||
noopLogger := &noopLogger{}
|
||||
cmder := command.New()
|
||||
var err error
|
||||
testFirewall, err = firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil)
|
||||
if errors.Is(err, iptables.ErrNotSupported) {
|
||||
t.Skip("iptables not installed, skipping TCP PMTUD tests")
|
||||
}
|
||||
require.NoError(t, err, "creating firewall config")
|
||||
})
|
||||
if testFirewall == nil {
|
||||
t.Skip("iptables not installed, skipping TCP PMTUD tests")
|
||||
}
|
||||
return testFirewall
|
||||
}
|
||||
|
||||
type noopLogger struct{}
|
||||
|
||||
func (l *noopLogger) Patch(_ ...log.Option) {}
|
||||
func (l *noopLogger) Debug(_ string) {}
|
||||
func (l *noopLogger) Debugf(_ string, _ ...any) {}
|
||||
func (l *noopLogger) Info(_ string) {}
|
||||
func (l *noopLogger) Warn(_ string) {}
|
||||
func (l *noopLogger) Warnf(_ string, _ ...any) {}
|
||||
func (l *noopLogger) Error(_ string) {}
|
||||
|
||||
var errRouteNotFound = errors.New("route not found")
|
||||
|
||||
func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||
routes, err := netlinker.RouteList(netlink.FamilyV4)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("getting routes list: %w", err)
|
||||
}
|
||||
for _, route := range routes {
|
||||
if route.Dst.IsValid() && route.Dst.Addr().IsLoopback() {
|
||||
link, err := netlinker.LinkByIndex(route.LinkIndex)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("getting link by index: %w", err)
|
||||
}
|
||||
// Quirk: make sure it is maximum 65535, and not i.e. 65536
|
||||
// or the IP header 16 bits will fail to fit that packet length value.
|
||||
const maxMTU = 65535
|
||||
return min(link.MTU, maxMTU), nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound)
|
||||
}
|
||||
|
||||
func findDefaultRouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||
noopLogger := &noopLogger{}
|
||||
routing := routing.New(netlinker, noopLogger)
|
||||
defaultRoutes, err := routing.DefaultRoutes()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("getting default routes: %w", err)
|
||||
}
|
||||
families := []uint8{constants.AF_INET, constants.AF_INET6}
|
||||
for _, family := range families {
|
||||
for _, route := range defaultRoutes {
|
||||
if route.Family != family {
|
||||
continue
|
||||
}
|
||||
link, err := netlinker.LinkByName(route.NetInterface)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("getting link by name: %w", err)
|
||||
}
|
||||
mtu = max(mtu, link.MTU)
|
||||
}
|
||||
}
|
||||
if mtu == 0 {
|
||||
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound)
|
||||
}
|
||||
return mtu, nil
|
||||
}
|
||||
|
||||
func reserveClosedPort(t *testing.T) (port uint16) {
|
||||
t.Helper()
|
||||
|
||||
fd, err := unix.Socket(constants.AF_INET, constants.SOCK_STREAM, constants.IPPROTO_TCP)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := unix.Close(fd)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
addr := &unix.SockaddrInet4{
|
||||
Port: 0,
|
||||
Addr: [4]byte{127, 0, 0, 1},
|
||||
}
|
||||
|
||||
err = unix.Bind(fd, addr)
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sockAddr, err := unix.Getsockname(fd)
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sockAddr4, ok := sockAddr.(*unix.SockaddrInet4)
|
||||
if !ok {
|
||||
_ = unix.Close(fd)
|
||||
t.Fatal("not an IPv4 address")
|
||||
}
|
||||
|
||||
return uint16(sockAddr4.Port) //nolint:gosec
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user