Compare commits

..

6 Commits

Author SHA1 Message Date
Quentin McGaw 7f22fb3276 fix(protonvpn): support port 51820 for UDP OpenVPN 2026-02-11 14:13:34 +00:00
Quentin McGaw 6909a0c123 fix(healthcheck): prevent race condition and fix #3096 (#3123) 2026-02-11 14:12:20 +00:00
Quentin McGaw 3e1f48932a fix(openvpn): only log openvpn version corresponding to OPENVPN_VERSION 2026-02-11 14:12:08 +00:00
Chris Duck 50744852c5 fix(protonvpn): update OpenVPN settings (#3120) 2026-02-11 14:11:57 +00:00
Quentin McGaw 09e52bc685 fix(httpproxy): remove info log when no Proxy-Authorization header is present 2026-02-11 14:11:46 +00:00
Quentin McGaw 857fe425ec fix(wireguard): fix detection of kernelspace wireguard 2026-02-11 14:11:36 +00:00
197 changed files with 13389 additions and 15074 deletions
+1 -1
View File
@@ -1,2 +1,2 @@
FROM ghcr.io/qdm12/godevcontainer:v0.21-alpine
RUN apk add wireguard-tools htop openssl tcpdump iptables
RUN apk add wireguard-tools htop openssl
+1 -5
View File
@@ -45,7 +45,6 @@ jobs:
level: error
exclude: |
./internal/storage/servers.json
./golangci.yml
*.md
- name: Linting
@@ -60,13 +59,10 @@ jobs:
- name: Run tests in test container
run: |
touch coverage.txt
docker run --rm --cap-add=NET_ADMIN --device /dev/net/tun \
docker run --rm --device /dev/net/tun \
-v "$(pwd)/coverage.txt:/tmp/gobuild/coverage.txt" \
test-container
- name: Verify dev cross platform compatibility
run: docker build --target xcompile .
- name: Build final image
run: docker build -t final-image .
+1 -2
View File
@@ -22,7 +22,6 @@ linters:
- "^disabled$"
# Firewall and routing strings
- "^(ACCEPT|DROP)$"
- "^--append$"
- "^--delete$"
- "^all$"
- "^(tcp|udp)$"
@@ -48,7 +47,7 @@ linters:
path: internal\/server\/.+\.go
- linters:
- ireturn
text: returns interface \(golang\.org\/x\/sys\/unix\.Sockaddr\)
text: returns interface \(github\.com\/vishvananda\/netlink\.Link\)
- linters:
- ireturn
path: internal\/openvpn\/pkcs8\/descbc\.go
+2 -9
View File
@@ -13,7 +13,7 @@ FROM --platform=${BUILDPLATFORM} ghcr.io/qdm12/binpot:mockgen-${MOCKGEN_VERSION}
FROM --platform=${BUILDPLATFORM} golang:${GO_VERSION}-alpine${GO_ALPINE_VERSION} AS base
COPY --from=xcputranslate /xcputranslate /usr/local/bin/xcputranslate
# Note: findutils needed to have xargs support `-d` flag for mocks stage.
RUN apk --update add git g++ findutils iptables
RUN apk --update add git g++ findutils
ENV CGO_ENABLED=0
COPY --from=golangci-lint /bin /go/bin/golangci-lint
COPY --from=mockgen /bin /go/bin/mockgen
@@ -46,10 +46,6 @@ RUN git init && \
git diff --exit-code && \
rm -rf .git/
FROM --platform=${BUILDPLATFORM} base AS xcompile
RUN GOOS=darwin go build -o /dev/null ./...
RUN GOOS=windows go build -o /dev/null ./...
FROM --platform=${BUILDPLATFORM} base AS build
ARG TARGETPLATFORM
ARG VERSION=unknown
@@ -110,11 +106,8 @@ ENV VPN_SERVICE_PROVIDER=pia \
WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL=0 \
WIREGUARD_ADDRESSES= \
WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
WIREGUARD_MTU= \
WIREGUARD_MTU=1320 \
WIREGUARD_IMPLEMENTATION=auto \
# PMTUD
PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8 \
PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443,1.1.1.1:53,8.8.8.8:53,[2606:4700:4700::1111]:53,[2001:4860:4860::8888]:53,[2606:4700:4700::1111]:443,[2001:4860:4860::8888]:443 \
# VPN server filtering
SERVER_REGIONS= \
SERVER_COUNTRIES= \
+1 -1
View File
@@ -58,7 +58,7 @@ Lightweight swiss-army-knife-like VPN client to multiple VPN service providers
## Features
- Based on Alpine 3.22 for a small Docker image of 41.1MB
- Supports: **AirVPN**, **Cyberghost**, **ExpressVPN**, **FastestVPN**, **Giganews**, **HideMyAss**, **IPVanish**, **IVPN**, **Mullvad** (Wireguard only), **NordVPN**, **Perfect Privacy**, **Privado**, **Private Internet Access**, **PrivateVPN**, **ProtonVPN**, **PureVPN**, **SlickVPN**, **Surfshark**, **TorGuard**, **VPNSecure.me**, **VPNUnlimited**, **Vyprvpn**, **Windscribe** servers
- Supports: **AirVPN**, **Cyberghost**, **ExpressVPN**, **FastestVPN**, **Giganews**, **HideMyAss**, **IPVanish**, **IVPN**, **Mullvad**, **NordVPN**, **Perfect Privacy**, **Privado**, **Private Internet Access**, **PrivateVPN**, **ProtonVPN**, **PureVPN**, **SlickVPN**, **Surfshark**, **TorGuard**, **VPNSecure.me**, **VPNUnlimited**, **Vyprvpn**, **Windscribe** servers
- Supports OpenVPN for all providers listed
- Supports Wireguard both kernelspace and userspace
- For **AirVPN**, **FastestVPN**, **Ivpn**, **Mullvad**, **NordVPN**, **Perfect privacy**, **ProtonVPN**, **Surfshark** and **Windscribe**
+16 -19
View File
@@ -6,7 +6,6 @@ import (
"fmt"
"io/fs"
"net/http"
"net/netip"
"os"
"os/exec"
"os/signal"
@@ -168,7 +167,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
defer fmt.Println(gluetunLogo)
announcementExp, err := time.Parse(time.RFC3339, "2026-04-01T00:00:00Z")
announcementExp, err := time.Parse(time.RFC3339, "2024-12-01T00:00:00Z")
if err != nil {
return err
}
@@ -179,7 +178,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
Version: buildInfo.Version,
Commit: buildInfo.Commit,
Created: buildInfo.Created,
Announcement: "All control server routes are now private by default",
Announcement: "All control server routes will become private by default after the v3.41.0 release",
AnnounceExp: announcementExp,
// Sponsor information
PaypalUser: "qmcgaw",
@@ -227,7 +226,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
firewallLogger.Patch(log.SetLevel(log.LevelDebug))
}
firewallConf, err := firewall.NewConfig(ctx, firewallLogger, cmder,
netLinker, defaultRoutes, localNetworks)
defaultRoutes, localNetworks)
if err != nil {
return err
}
@@ -264,7 +263,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
puid, pgid := int(*allSettings.System.PUID), int(*allSettings.System.PGID)
const clientTimeout = 35 * time.Second
const clientTimeout = 15 * time.Second
httpClient := &http.Client{Timeout: clientTimeout}
// Create configurators
alpineConf := alpine.New()
@@ -279,7 +278,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
err = printVersions(ctx, logger, []printVersionElement{
{name: "Alpine", getVersion: alpineConf.Version},
{name: "OpenVPN", getVersion: ovpnVersion},
{name: "Firewall", getVersion: firewallConf.Version},
{name: "IPtables", getVersion: firewallConf.Version},
})
if err != nil {
return err
@@ -554,27 +553,26 @@ type netLinker interface {
Router
Ruler
Linker
IsWireguardSupported() (ok bool, err error)
IsWireguardSupported() bool
IsIPv6Supported() (ok bool, err error)
FlushConntrack() error
PatchLoggerLevel(level log.Level)
}
type Addresser interface {
AddrList(linkIndex uint32, family uint8) (
addresses []netip.Prefix, err error)
AddrReplace(linkIndex uint32, addr netip.Prefix) error
AddrList(link netlink.Link, family int) (
addresses []netlink.Addr, err error)
AddrReplace(link netlink.Link, addr netlink.Addr) error
}
type Router interface {
RouteList(family uint8) (routes []netlink.Route, err error)
RouteList(family int) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error
RouteDel(route netlink.Route) error
RouteReplace(route netlink.Route) error
}
type Ruler interface {
RuleList(family uint8) (rules []netlink.Rule, err error)
RuleList(family int) (rules []netlink.Rule, err error)
RuleAdd(rule netlink.Rule) error
RuleDel(rule netlink.Rule) error
}
@@ -582,12 +580,11 @@ type Ruler interface {
type Linker interface {
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkByIndex(index uint32) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
LinkDel(linkIndex uint32) (err error)
LinkSetUp(linkIndex uint32) (err error)
LinkSetDown(linkIndex uint32) (err error)
LinkSetMTU(linkIndex, mtu uint32) error
LinkByIndex(index int) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error)
}
type clier interface {
+14 -14
View File
@@ -7,13 +7,10 @@ require (
github.com/breml/rootcerts v0.3.3
github.com/fatih/color v1.18.0
github.com/golang/mock v1.6.0
github.com/jsimonetti/rtnetlink v1.4.2
github.com/klauspost/compress v1.18.1
github.com/klauspost/pgzip v1.2.6
github.com/mdlayher/genetlink v1.3.2
github.com/mdlayher/netlink v1.7.2
github.com/pelletier/go-toml/v2 v2.2.4
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260216151239-36b3306f2205
github.com/qdm12/dns/v2 v2.0.0-rc10
github.com/qdm12/gosettings v0.4.4
github.com/qdm12/goshutdown v0.3.0
github.com/qdm12/gosplash v0.2.0
@@ -21,13 +18,13 @@ require (
github.com/qdm12/log v0.1.0
github.com/qdm12/ss-server v0.6.0
github.com/stretchr/testify v1.11.1
github.com/ti-mo/netfilter v0.5.3
github.com/ulikunitz/xz v0.5.15
github.com/vishvananda/netlink v1.3.1
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
golang.org/x/net v0.49.0
golang.org/x/sys v0.40.0
golang.org/x/text v0.33.0
golang.org/x/net v0.47.0
golang.org/x/sys v0.38.0
golang.org/x/text v0.31.0
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
gopkg.in/ini.v1 v1.67.0
@@ -41,11 +38,13 @@ require (
github.com/cloudflare/circl v1.6.1 // indirect
github.com/cronokirby/saferith v0.33.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mdlayher/socket v0.5.1 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.4.1 // indirect
github.com/miekg/dns v1.1.62 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pkg/errors v0.9.1 // indirect
@@ -56,11 +55,12 @@ require (
github.com/prometheus/procfs v0.15.1 // indirect
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 // indirect
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
golang.org/x/crypto v0.47.0 // indirect
golang.org/x/mod v0.31.0 // indirect
golang.org/x/sync v0.19.0 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/crypto v0.45.0 // indirect
golang.org/x/mod v0.29.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.40.0 // indirect
golang.org/x/tools v0.38.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/protobuf v1.35.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
+26 -26
View File
@@ -13,8 +13,6 @@ github.com/breml/rootcerts v0.3.3/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXB
github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cilium/ebpf v0.12.3 h1:8ht6F9MquybnY97at+VDZb3eQQr8ev79RueWeVaEcG4=
github.com/cilium/ebpf v0.12.3/go.mod h1:TctK1ivibvI3znr66ljgi4hqOT8EYQjz1KWBfb1UVgM=
github.com/cloudflare/circl v1.1.0/go.mod h1:prBCrKB9DV4poKZY1l9zBXg2QJY7mvgRvtMxxK7fi4I=
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
@@ -28,12 +26,10 @@ github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/jsimonetti/rtnetlink v1.4.2 h1:Df9w9TZ3npHTyDn0Ev9e1uzmN2odmXd0QX+J5GTEn90=
github.com/jsimonetti/rtnetlink v1.4.2/go.mod h1:92s6LJdE+1iOrw+F2/RO7LYI2Qd8pPpFNNUYW06gcoM=
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0=
github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU=
@@ -51,8 +47,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
@@ -73,8 +69,8 @@ github.com/prometheus/common v0.60.1 h1:FUas6GcOw66yB/73KC+BOZoFJmbo/1pojoILArPA
github.com/prometheus/common v0.60.1/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260216151239-36b3306f2205 h1:0ycKUDQ50cYb2QpeyGcEnvVs9HJmC9jsb/XZNC1z28c=
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260216151239-36b3306f2205/go.mod h1:98foWgXJZ+g8gJIuO+fdO+oWpFei5WShMFTeN4Im2lE=
github.com/qdm12/dns/v2 v2.0.0-rc10 h1:IyeNEYXfhBsaE1dwxx5eAqdAz1HS98dT+8c7xoKODa0=
github.com/qdm12/dns/v2 v2.0.0-rc10/go.mod h1:98foWgXJZ+g8gJIuO+fdO+oWpFei5WShMFTeN4Im2lE=
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 h1:TRGpCU1l0lNwtogEUSs5U+RFceYxkAJUmrGabno7J5c=
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978/go.mod h1:D1Po4CRQLYjccnAR2JsVlN1sBMgQrcNLONbvyuzcdTg=
github.com/qdm12/gosettings v0.4.4 h1:SM6tOZDf6k8qbjWU8KWyBF4mWIixfsKCfh9DGRLHlj4=
@@ -95,10 +91,12 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/ti-mo/netfilter v0.5.3 h1:ikzduvnaUMwre5bhbNwWOd6bjqLMVb33vv0XXbK0xGQ=
github.com/ti-mo/netfilter v0.5.3/go.mod h1:08SyBCg6hu1qyQk4s3DjjJKNrm3RTb32nm6AzyT972E=
github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk=
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
@@ -108,15 +106,15 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
@@ -124,14 +122,14 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -142,10 +140,12 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -155,8 +155,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -164,8 +164,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -23,9 +23,7 @@ type DNSBlacklist struct {
AddBlockedIPs []netip.Addr
AddBlockedIPPrefixes []netip.Prefix
// RebindingProtectionExemptHostnames is a list of hostnames
// exempt from DNS rebinding protection. It can contain parent
// domains which are of the form "*.example.com". Note the wildcard
// can only be used at the start of the hostname.
// exempt from DNS rebinding protection.
RebindingProtectionExemptHostnames []string
}
@@ -57,9 +55,6 @@ func (b DNSBlacklist) validate() (err error) {
}
for _, host := range b.RebindingProtectionExemptHostnames {
if len(host) > 2 && host[:2] == "*." {
host = host[2:]
}
if !hostRegex.MatchString(host) {
return fmt.Errorf("%w: %s", ErrRebindingProtectionExemptHostNotValid, host)
}
-111
View File
@@ -1,111 +0,0 @@
package settings
import (
"errors"
"fmt"
"net/netip"
"github.com/qdm12/gosettings"
"github.com/qdm12/gosettings/reader"
"github.com/qdm12/gotree"
)
// PMTUD contains settings to configure Path MTU Discovery.
type PMTUD struct {
// ICMPAddresses is the redundancy list of addresses to use
// for ICMP path MTU discovery. Each address MUST handle ICMP
// packets for PMTUD to work.
// It cannot be nil in the internal state.
ICMPAddresses []netip.Addr `json:"icmp_addresses"`
// TCPAddresses is the redundancy list of addresses to use
// for TCP path MTU discovery. Each address MUST have a listening
// TCP server on the port specified.
// It cannot be nil in the internal state.
TCPAddresses []netip.AddrPort `json:"tcp_addresses"`
}
var (
ErrPMTUDICMPAddressNotValid = errors.New("PMTUD ICMP address is not valid")
ErrPMTUDTCPAddressNotValid = errors.New("PMTUD TCP address is not valid")
)
// Validate validates PMTUD settings.
func (p PMTUD) validate() (err error) {
for i, addr := range p.ICMPAddresses {
if !addr.IsValid() {
return fmt.Errorf("%w: at index %d", ErrPMTUDICMPAddressNotValid, i)
}
}
for i, addr := range p.TCPAddresses {
if !addr.IsValid() {
return fmt.Errorf("%w: at index %d", ErrPMTUDTCPAddressNotValid, i)
}
}
return nil
}
func (p *PMTUD) copy() (copied PMTUD) {
return PMTUD{
ICMPAddresses: gosettings.CopySlice(p.ICMPAddresses),
TCPAddresses: gosettings.CopySlice(p.TCPAddresses),
}
}
func (p *PMTUD) overrideWith(other PMTUD) {
p.ICMPAddresses = gosettings.OverrideWithSlice(p.ICMPAddresses, other.ICMPAddresses)
p.TCPAddresses = gosettings.OverrideWithSlice(p.TCPAddresses, other.TCPAddresses)
}
func (p *PMTUD) setDefaults() {
defaultICMPAddresses := []netip.Addr{
netip.AddrFrom4([4]byte{1, 1, 1, 1}),
netip.AddrFrom4([4]byte{8, 8, 8, 8}),
}
p.ICMPAddresses = gosettings.DefaultSlice(p.ICMPAddresses, defaultICMPAddresses)
const dnsPort, tlsPort = 53, 443
defaultTCPAddresses := []netip.AddrPort{
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), dnsPort),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), dnsPort),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), tlsPort),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), tlsPort),
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), dnsPort),
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), dnsPort),
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), tlsPort),
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), tlsPort),
}
p.TCPAddresses = gosettings.DefaultSlice(p.TCPAddresses, defaultTCPAddresses)
}
func (p PMTUD) String() string {
return p.toLinesNode().String()
}
func (p PMTUD) toLinesNode() (node *gotree.Node) {
node = gotree.New("Path MTU discovery:")
icmpAddrNode := node.Append("ICMP addresses:")
for _, addr := range p.ICMPAddresses {
icmpAddrNode.Append(addr.String())
}
tcpAddrNode := node.Append("TCP addresses:")
for _, addr := range p.TCPAddresses {
tcpAddrNode.Append(addr.String())
}
return node
}
func (p *PMTUD) read(r *reader.Reader) (err error) {
p.ICMPAddresses, err = r.CSVNetipAddresses("PMTUD_ICMP_ADDRESSES")
if err != nil {
return err
}
p.TCPAddresses, err = r.CSVNetipAddrPorts("PMTUD_TCP_ADDRESSES")
if err != nil {
return err
}
return nil
}
@@ -2,8 +2,6 @@ package settings
import (
"fmt"
"slices"
"sort"
"strings"
"github.com/qdm12/gluetun/internal/constants/providers"
@@ -33,11 +31,6 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
if vpnType == vpn.OpenVPN {
validNames = providers.AllWithCustom()
validNames = append(validNames, "pia") // Retro-compatibility
// Remove Mullvad since it no longer supports OpenVPN as of January 15th, 2026
mullvadIndex := slices.Index(validNames, providers.Mullvad)
validNames[mullvadIndex], validNames[len(validNames)-1] = validNames[len(validNames)-1], validNames[mullvadIndex]
validNames = validNames[:len(validNames)-1]
sort.Strings(validNames)
} else { // Wireguard
validNames = []string{
providers.Airvpn,
@@ -29,27 +29,14 @@ func Test_Settings_String(t *testing.T) {
| | └── OpenVPN server selection settings:
| | ├── Protocol: UDP
| | └── Private Internet Access encryption preset: strong
| ── OpenVPN settings:
| | ├── OpenVPN version: 2.6
| | ├── User: [not set]
| | ├── Password: [not set]
| | ├── Private Internet Access encryption preset: strong
| | ├── Network interface: tun0
| | ├── Run OpenVPN as: root
| | └── Verbosity level: 1
| └── Path MTU discovery:
| ├── ICMP addresses:
| | ├── 1.1.1.1
| | └── 8.8.8.8
| └── TCP addresses:
| ├── 1.1.1.1:53
| ├── 8.8.8.8:53
| ├── 1.1.1.1:443
| ├── 8.8.8.8:443
| ├── [2606:4700:4700::1111]:53
| ├── [2001:4860:4860::8888]:53
| ├── [2606:4700:4700::1111]:443
| └── [2001:4860:4860::8888]:443
| ── OpenVPN settings:
| ├── OpenVPN version: 2.6
| ├── User: [not set]
| ├── Password: [not set]
| ├── Private Internet Access encryption preset: strong
| ├── Network interface: tun0
| ├── Run OpenVPN as: root
| └── Verbosity level: 1
├── DNS settings:
| ├── Keep existing nameserver(s): no
| ├── DNS server address to use: 127.0.0.1
-15
View File
@@ -18,7 +18,6 @@ type VPN struct {
Provider Provider `json:"provider"`
OpenVPN OpenVPN `json:"openvpn"`
Wireguard Wireguard `json:"wireguard"`
PMTUD PMTUD `json:"pmtud"`
}
// TODO v4 remove pointer for receiver (because of Surfshark).
@@ -46,11 +45,6 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo
}
}
err = v.PMTUD.validate()
if err != nil {
return fmt.Errorf("PMTUD settings: %w", err)
}
return nil
}
@@ -60,7 +54,6 @@ func (v *VPN) Copy() (copied VPN) {
Provider: v.Provider.copy(),
OpenVPN: v.OpenVPN.copy(),
Wireguard: v.Wireguard.copy(),
PMTUD: v.PMTUD.copy(),
}
}
@@ -69,7 +62,6 @@ func (v *VPN) OverrideWith(other VPN) {
v.Provider.overrideWith(other.Provider)
v.OpenVPN.overrideWith(other.OpenVPN)
v.Wireguard.overrideWith(other.Wireguard)
v.PMTUD.overrideWith(other.PMTUD)
}
func (v *VPN) setDefaults() {
@@ -77,7 +69,6 @@ func (v *VPN) setDefaults() {
v.Provider.setDefaults()
v.OpenVPN.setDefaults(v.Provider.Name)
v.Wireguard.setDefaults(v.Provider.Name)
v.PMTUD.setDefaults()
}
func (v VPN) String() string {
@@ -94,7 +85,6 @@ func (v VPN) toLinesNode() (node *gotree.Node) {
} else {
node.AppendNode(v.Wireguard.toLinesNode())
}
node.AppendNode(v.PMTUD.toLinesNode())
return node
}
@@ -117,10 +107,5 @@ func (v *VPN) read(r *reader.Reader) (err error) {
return fmt.Errorf("wireguard: %w", err)
}
err = v.PMTUD.read(r)
if err != nil {
return fmt.Errorf("PMTUD: %w", err)
}
return nil
}
+14 -10
View File
@@ -38,9 +38,14 @@ type Wireguard struct {
Interface string `json:"interface"`
PersistentKeepaliveInterval *time.Duration `json:"persistent_keep_alive_interval"`
// Maximum Transmission Unit (MTU) of the Wireguard interface.
// It cannot be nil in the internal state, and defaults to
// 0 indicating to use PMTUD.
MTU *uint32 `json:"mtu"`
// It cannot be zero in the internal state, and defaults to
// 1320. Note it is not the wireguard-go MTU default of 1420
// because this impacts bandwidth a lot on some VPN providers,
// see https://github.com/qdm12/gluetun/issues/1650.
// It has been lowered to 1320 following quite a bit of
// investigation in the issue:
// https://github.com/qdm12/gluetun/issues/2533.
MTU uint16 `json:"mtu"`
// Implementation is the Wireguard implementation to use.
// It can be "auto", "userspace" or "kernelspace".
// It defaults to "auto" and cannot be the empty string
@@ -189,7 +194,8 @@ func (w *Wireguard) setDefaults(vpnProvider string) {
w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs)
w.PersistentKeepaliveInterval = gosettings.DefaultPointer(w.PersistentKeepaliveInterval, 0)
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
w.MTU = gosettings.DefaultPointer(w.MTU, 0)
const defaultMTU = 1320
w.MTU = gosettings.DefaultComparable(w.MTU, defaultMTU)
w.Implementation = gosettings.DefaultComparable(w.Implementation, "auto")
}
@@ -225,11 +231,7 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
}
interfaceNode := node.Appendf("Network interface: %s", w.Interface)
if *w.MTU == 0 {
interfaceNode.Append("MTU: use path MTU discovery")
} else {
interfaceNode.Appendf("MTU: %d", *w.MTU)
}
interfaceNode.Appendf("MTU: %d", w.MTU)
if w.Implementation != "auto" {
node.Appendf("Implementation: %s", w.Implementation)
@@ -270,9 +272,11 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
return err
}
w.MTU, err = r.Uint32Ptr("WIREGUARD_MTU")
mtuPtr, err := r.Uint16Ptr("WIREGUARD_MTU")
if err != nil {
return err
} else if mtuPtr != nil {
w.MTU = *mtuPtr
}
return nil
}
+7 -6
View File
@@ -2,6 +2,7 @@ package dns
import (
"context"
"errors"
"github.com/qdm12/dns/v2/pkg/nameserver"
"github.com/qdm12/gluetun/internal/constants"
@@ -43,12 +44,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
runError, err = l.setupServer(ctx)
if err == nil {
l.backoffTime = defaultBackoffTime
l.logger.Info("ready and using DNS server at address " + settings.ServerAddress.String())
err = l.updateFiles(ctx, settings)
if err != nil {
l.logger.Warn("downloading block lists failed, skipping: " + err.Error())
}
l.logger.Info("ready")
break
}
@@ -57,6 +53,11 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
if ctx.Err() != nil {
return
}
if !errors.Is(err, errUpdateBlockLists) {
const fallback = true
l.useUnencryptedDNS(fallback)
}
l.logAndWait(ctx, err)
settings = l.GetSettings()
}
+7 -6
View File
@@ -2,24 +2,25 @@ package dns
import (
"context"
"errors"
"fmt"
"net/netip"
"github.com/qdm12/dns/v2/pkg/check"
"github.com/qdm12/dns/v2/pkg/middlewares/filter/update"
"github.com/qdm12/dns/v2/pkg/nameserver"
"github.com/qdm12/dns/v2/pkg/server"
)
var errUpdateBlockLists = errors.New("cannot update filter block lists")
func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err error) {
settings := l.GetSettings()
var updateSettings update.Settings
updateSettings.SetRebindingProtectionExempt(settings.Blacklist.RebindingProtectionExemptHostnames)
err = l.filter.Update(updateSettings)
err = l.updateFiles(ctx)
if err != nil {
return nil, fmt.Errorf("updating filter for rebinding protection: %w", err)
return nil, fmt.Errorf("%w: %w", errUpdateBlockLists, err)
}
settings := l.GetSettings()
serverSettings, err := buildServerSettings(settings, l.filter, l.localResolvers, l.logger)
if err != nil {
return nil, fmt.Errorf("building server settings: %w", err)
+15 -4
View File
@@ -28,12 +28,23 @@ func (l *Loop) RunRestartTicker(ctx context.Context, done chan<- struct{}) {
return
case <-timer.C:
lastTick = l.timeNow()
settings := l.GetSettings()
if l.GetStatus() == constants.Running {
if err := l.updateFiles(ctx, settings); err != nil {
l.logger.Warn("updating block lists failed, skipping: " + err.Error())
status := l.GetStatus()
if status == constants.Running {
if err := l.updateFiles(ctx); err != nil {
l.statusManager.SetStatus(constants.Crashed)
l.logger.Error(err.Error())
l.logger.Warn("skipping DNS server restart due to failed files update")
settings := l.GetSettings()
timer.Reset(*settings.UpdatePeriod)
continue
}
}
_, _ = l.statusManager.ApplyStatus(ctx, constants.Stopped)
_, _ = l.statusManager.ApplyStatus(ctx, constants.Running)
settings := l.GetSettings()
timer.Reset(*settings.UpdatePeriod)
case <-l.updateTicker:
if !timer.Stop() {
+4 -2
View File
@@ -6,10 +6,11 @@ import (
"github.com/qdm12/dns/v2/pkg/blockbuilder"
"github.com/qdm12/dns/v2/pkg/middlewares/filter/update"
"github.com/qdm12/gluetun/internal/configuration/settings"
)
func (l *Loop) updateFiles(ctx context.Context, settings settings.DNS) (err error) {
func (l *Loop) updateFiles(ctx context.Context) (err error) {
settings := l.GetSettings()
l.logger.Info("downloading hostnames and IP block lists")
blacklistSettings := settings.Blacklist.ToBlockBuilderSettings(l.client)
@@ -36,6 +37,7 @@ func (l *Loop) updateFiles(ctx context.Context, settings settings.DNS) (err erro
IPPrefixes: result.BlockedIPPrefixes,
}
updateSettings.BlockHostnames(result.BlockedHostnames)
updateSettings.SetRebindingProtectionExempt(settings.Blacklist.RebindingProtectionExemptHostnames)
err = l.filter.Update(updateSettings)
if err != nil {
return fmt.Errorf("updating filter: %w", err)
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"fmt"
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"context"
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"context"
@@ -69,8 +69,8 @@ func Test_deleteIPTablesRule(t *testing.T) {
"invalid_instruction": {
instruction: "invalid",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing iptables command: parsing \"invalid\": " +
"iptables command is malformed: flag \"invalid\" requires a value, but got none",
errMessage: "parsing iptables command: iptables command is malformed: " +
"fields count 1 is not even: \"invalid\"",
},
"list_error": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
+58 -37
View File
@@ -22,7 +22,9 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
if !enabled {
c.logger.Info("disabling...")
c.restore(ctx)
if err = c.disable(ctx); err != nil {
return fmt.Errorf("disabling firewall: %w", err)
}
c.enabled = false
c.logger.Info("disabled successfully")
return nil
@@ -39,42 +41,64 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
return nil
}
func (c *Config) enable(ctx context.Context) (err error) {
c.restore, err = c.impl.SaveAndRestore(ctx)
func (c *Config) disable(ctx context.Context) (err error) {
if err = c.clearAllRules(ctx); err != nil {
return fmt.Errorf("clearing all rules: %w", err)
}
if err = c.setIPv4AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("setting ipv4 policies: %w", err)
}
if err = c.setIPv6AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("setting ipv6 policies: %w", err)
}
const remove = true
err = c.redirectPorts(ctx, remove)
if err != nil {
return fmt.Errorf("saving firewall rules: %w", err)
return fmt.Errorf("removing port redirections: %w", err)
}
if err = c.impl.SetIPv4AllPolicies(ctx, "DROP"); err != nil {
return nil
}
// To use in defered call when enabling the firewall.
func (c *Config) fallbackToDisabled(ctx context.Context) {
if ctx.Err() != nil {
return
}
if err := c.disable(ctx); err != nil {
c.logger.Error("failed reversing firewall changes: " + err.Error())
}
}
func (c *Config) enable(ctx context.Context) (err error) {
touched := false
if err = c.setIPv4AllPolicies(ctx, "DROP"); err != nil {
return err
}
touched = true
if err = c.impl.SetIPv6AllPolicies(ctx, "DROP"); err != nil {
return err
}
defer func() {
if err != nil {
c.restore(context.Background())
}
}()
// Loopback traffic
if err = c.impl.AcceptInputThroughInterface(ctx, "lo"); err != nil {
if err = c.setIPv6AllPolicies(ctx, "DROP"); err != nil {
return err
}
const remove = false
if err = c.impl.AcceptOutputThroughInterface(ctx, "lo", remove); err != nil {
defer func() {
if touched && err != nil {
c.fallbackToDisabled(ctx)
}
}()
// Loopback traffic
if err = c.acceptInputThroughInterface(ctx, "lo", remove); err != nil {
return err
}
if err = c.acceptOutputThroughInterface(ctx, "lo", remove); err != nil {
return err
}
err = c.flushExistingConnections(ctx)
if err != nil {
return fmt.Errorf("flushing existing connections: %w", err)
}
if err = c.impl.AcceptEstablishedRelatedTraffic(ctx); err != nil {
if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil {
return err
}
@@ -84,9 +108,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
localInterfaces := make(map[string]struct{}, len(c.localNetworks))
for _, network := range c.localNetworks {
err = c.impl.AcceptOutputFromIPToSubnet(ctx,
network.InterfaceName, network.IP, network.IPNet, remove)
if err != nil {
if err := c.acceptOutputFromIPToSubnet(ctx, network.InterfaceName, network.IP, network.IPNet, remove); err != nil {
return err
}
@@ -95,7 +117,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
continue
}
localInterfaces[network.InterfaceName] = struct{}{}
err = c.impl.AcceptIpv6MulticastOutput(ctx, network.InterfaceName)
err = c.acceptIpv6MulticastOutput(ctx, network.InterfaceName, remove)
if err != nil {
return fmt.Errorf("accepting IPv6 multicast output: %w", err)
}
@@ -108,7 +130,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
// Allows packets from any IP address to go through eth0 / local network
// to reach Gluetun.
for _, network := range c.localNetworks {
if err := c.impl.AcceptInputToSubnet(ctx, network.InterfaceName, network.IPNet); err != nil {
if err := c.acceptInputToSubnet(ctx, network.InterfaceName, network.IPNet, remove); err != nil {
return err
}
}
@@ -117,12 +139,12 @@ func (c *Config) enable(ctx context.Context) (err error) {
return err
}
err = c.redirectPorts(ctx)
err = c.redirectPorts(ctx, remove)
if err != nil {
return fmt.Errorf("redirecting ports: %w", err)
}
if err := c.impl.RunUserPostRules(ctx, c.customRulesPath); err != nil {
if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil {
return fmt.Errorf("running user defined post firewall rules: %w", err)
}
@@ -142,7 +164,7 @@ func (c *Config) allowVPNIP(ctx context.Context) (err error) {
continue
}
interfacesSeen[defaultRoute.NetInterface] = struct{}{}
err = c.impl.AcceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove)
err = c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove)
if err != nil {
return fmt.Errorf("accepting output traffic through VPN: %w", err)
}
@@ -164,7 +186,7 @@ func (c *Config) allowOutboundSubnets(ctx context.Context) (err error) {
firewallUpdated = true
const remove = false
err := c.impl.AcceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
defaultRoute.AssignedIP, subnet, remove)
if err != nil {
return err
@@ -182,7 +204,7 @@ func (c *Config) allowInputPorts(ctx context.Context) (err error) {
for port, netInterfaces := range c.allowedInputPorts {
for netInterface := range netInterfaces {
const remove = false
err = c.impl.AcceptInputToPort(ctx, netInterface, port, remove)
err = c.acceptInputToPort(ctx, netInterface, port, remove)
if err != nil {
return fmt.Errorf("accepting input port %d on interface %s: %w",
port, netInterface, err)
@@ -192,10 +214,9 @@ func (c *Config) allowInputPorts(ctx context.Context) (err error) {
return nil
}
func (c *Config) redirectPorts(ctx context.Context) (err error) {
func (c *Config) redirectPorts(ctx context.Context, remove bool) (err error) {
for _, portRedirection := range c.portRedirections {
const remove = false
err = c.impl.RedirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort,
err = c.redirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort,
portRedirection.destinationPort, remove)
if err != nil {
return err
+23 -19
View File
@@ -2,29 +2,28 @@ package firewall
import (
"context"
"fmt"
"net/netip"
"sync"
"github.com/qdm12/gluetun/internal/firewall/iptables"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/routing"
)
type Config struct {
runner CmdRunner
netlinker Netlinker
logger Logger
defaultRoutes []routing.DefaultRoute
localNetworks []routing.LocalNetwork
runner CmdRunner
logger Logger
iptablesMutex sync.Mutex
ip6tablesMutex sync.Mutex
defaultRoutes []routing.DefaultRoute
localNetworks []routing.LocalNetwork
// Fixed
impl firewallImpl
// Fixed state
ipTables string
ip6Tables string
customRulesPath string
// State
enabled bool
restore func(context.Context)
vpnConnection models.Connection
vpnIntf string
outboundSubnets []netip.Prefix
@@ -36,23 +35,28 @@ type Config struct {
// NewConfig creates a new Config instance and returns an error
// if no iptables implementation is available.
func NewConfig(ctx context.Context, logger Logger,
runner CmdRunner, netlinker Netlinker,
defaultRoutes []routing.DefaultRoute, localNetworks []routing.LocalNetwork,
runner CmdRunner, defaultRoutes []routing.DefaultRoute,
localNetworks []routing.LocalNetwork,
) (config *Config, err error) {
impl, err := iptables.New(ctx, runner, logger)
iptables, err := checkIptablesSupport(ctx, runner, "iptables", "iptables-nft", "iptables-legacy")
if err != nil {
return nil, fmt.Errorf("creating iptables firewall: %w", err)
return nil, err
}
ip6tables, err := findIP6tablesSupported(ctx, runner)
if err != nil {
return nil, err
}
return &Config{
runner: runner,
netlinker: netlinker,
logger: logger,
allowedInputPorts: make(map[uint16]map[string]struct{}),
ipTables: iptables,
ip6Tables: ip6tables,
customRulesPath: "/iptables/post-rules.txt",
// Obtained from routing
defaultRoutes: defaultRoutes,
localNetworks: localNetworks,
impl: impl,
customRulesPath: "/iptables/post-rules.txt",
defaultRoutes: defaultRoutes,
localNetworks: localNetworks,
}, nil
}
-74
View File
@@ -1,74 +0,0 @@
package firewall
import (
"context"
"errors"
"fmt"
"time"
"github.com/qdm12/gluetun/internal/firewall/iptables"
"github.com/qdm12/gluetun/internal/netlink"
)
func (c *Config) flushExistingConnections(ctx context.Context) error {
tries := []struct {
name string
f func(ctx context.Context) error
}{
{name: "flushing conntrack", f: func(_ context.Context) error {
return c.netlinker.FlushConntrack()
}},
{name: "marking and filtering unmarked packets", f: c.impl.AcceptOutputPublicOnlyNewTraffic},
{name: "rejecting connections for one second", f: c.rejectOutputTrafficTemporarily},
{name: "dropping connections for one second", f: c.dropOutputTrafficTemporarily},
}
errs := make([]error, 0, len(tries))
for i, try := range tries {
if i > 0 {
c.logger.Debugf("falling back to %s because %s failed: %s", try.name, tries[i-1].name, errs[i-1])
}
err := try.f(ctx)
if err == nil {
return nil
}
err = fmt.Errorf("%s: %w", try.name, err)
if !errors.Is(err, iptables.ErrKernelModuleMissing) && !errors.Is(err, netlink.ErrConntrackNetlinkNotSupported) {
return err
}
errs = append(errs, err)
}
return fmt.Errorf("all tries failed: %v", errs) //nolint:err113
}
func (c *Config) rejectOutputTrafficTemporarily(ctx context.Context) error {
return setupThenRevert(ctx, c.impl.RejectOutputPublicTraffic)
}
func (c *Config) dropOutputTrafficTemporarily(ctx context.Context) error {
return setupThenRevert(ctx, c.impl.DropOutputPublicTraffic)
}
// setupThenRevert is a helper function to run a setup function that takes a remove boolean argument,
// and then run the same function with remove set to true after one second or when the context is canceled,
// whichever comes first.
func setupThenRevert(ctx context.Context, f func(ctx context.Context, remove bool) error) error {
remove := false
err := f(ctx, remove)
if err != nil {
return fmt.Errorf("setting up: %w", err)
}
timer := time.NewTimer(time.Second)
select {
case <-timer.C:
case <-ctx.Done():
timer.Stop()
}
remove = true
// Use [context.Background] to make sure this is removed, even if the context
// passed to this function is canceled.
err = f(context.Background(), remove)
if err != nil {
return fmt.Errorf("reverting: %w", err)
}
return nil
}
+1 -37
View File
@@ -1,12 +1,6 @@
package firewall
import (
"context"
"net/netip"
"os/exec"
"github.com/qdm12/gluetun/internal/models"
)
import "os/exec"
type CmdRunner interface {
Run(cmd *exec.Cmd) (output string, err error)
@@ -14,37 +8,7 @@ type CmdRunner interface {
type Logger interface {
Debug(s string)
Debugf(format string, args ...any)
Info(s string)
Warn(s string)
Error(s string)
}
type Netlinker interface {
FlushConntrack() error
}
type firewallImpl interface { //nolint:interfacebloat
SaveAndRestore(ctx context.Context) (restore func(context.Context), err error)
AcceptOutputPublicOnlyNewTraffic(ctx context.Context) error
RejectOutputPublicTraffic(ctx context.Context, remove bool) error
DropOutputPublicTraffic(ctx context.Context, remove bool) error
AcceptInputThroughInterface(ctx context.Context, intf string) error
AcceptEstablishedRelatedTraffic(ctx context.Context) error
AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error
AcceptInputToSubnet(ctx context.Context, intf string, subnet netip.Prefix) error
AcceptIpv6MulticastOutput(ctx context.Context, intf string) error
AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr,
subnet netip.Prefix, remove bool) error
AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error
AcceptOutputTrafficToVPN(ctx context.Context, intf string,
connection models.Connection, remove bool) error
RedirectPort(ctx context.Context, intf string, sourcePort,
destinationPort uint16, remove bool) error
RunUserPostRules(ctx context.Context, customRulesPath string) error
SetIPv4AllPolicies(ctx context.Context, policy string) error
SetIPv6AllPolicies(ctx context.Context, policy string) error
TempDropOutputTCPRST(ctx context.Context, src, dst netip.AddrPort, excludeMark int) (
revert func(ctx context.Context) error, err error)
Version(ctx context.Context) (version string, err error)
}
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"context"
@@ -14,8 +14,8 @@ import (
func findIP6tablesSupported(ctx context.Context, runner CmdRunner) (
ip6tablesPath string, err error,
) {
ip6tablesPath, err = checkIptablesSupport(ctx, runner, "ip6tables", "ip6tables-legacy")
if errors.Is(err, ErrNotSupported) {
ip6tablesPath, err = checkIptablesSupport(ctx, runner, "ip6tables", "ip6tables-nft", "ip6tables-legacy")
if errors.Is(err, ErrIPTablesNotSupported) {
return "", nil
} else if err != nil {
return "", err
@@ -24,23 +24,8 @@ func findIP6tablesSupported(ctx context.Context, runner CmdRunner) (
}
func (c *Config) runIP6tablesInstructions(ctx context.Context, instructions []string) error {
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv6(ctx)
if err != nil {
return err
}
err = c.runIP6tablesInstructionsNoSave(ctx, instructions)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIP6tablesInstructionsNoSave(ctx context.Context, instructions []string) error {
for _, instruction := range instructions {
if err := c.runIP6tablesInstructionNoSave(ctx, instruction); err != nil {
if err := c.runIP6tablesInstruction(ctx, instruction); err != nil {
return err
}
}
@@ -48,24 +33,11 @@ func (c *Config) runIP6tablesInstructionsNoSave(ctx context.Context, instruction
}
func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string) error {
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv6(ctx)
if err != nil {
return err
}
err = c.runIP6tablesInstructionNoSave(ctx, instruction)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction string) error {
if c.ip6Tables == "" {
return nil
}
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()
if isDeleteMatchInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ip6Tables, instruction,
@@ -76,9 +48,6 @@ func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction
cmd := exec.CommandContext(ctx, c.ip6Tables, flags...) // #nosec G204
c.logger.Debug(cmd.String())
if output, err := c.runner.Run(cmd); err != nil {
if strings.Contains(output, "missing kernel module") {
err = ErrKernelModuleMissing
}
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
c.ip6Tables, instruction, output, err)
}
@@ -87,7 +56,7 @@ func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction
var ErrPolicyNotValid = errors.New("policy is not valid")
func (c *Config) SetIPv6AllPolicies(ctx context.Context, policy string) error {
func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP":
default:
+331
View File
@@ -0,0 +1,331 @@
package firewall
import (
"context"
"errors"
"fmt"
"io"
"net/netip"
"os"
"os/exec"
"strings"
"github.com/qdm12/gluetun/internal/models"
)
var (
ErrIPTablesVersionTooShort = errors.New("iptables version string is too short")
ErrPolicyUnknown = errors.New("unknown policy")
ErrNeedIP6Tables = errors.New("ip6tables is required, please upgrade your kernel to support it")
)
func appendOrDelete(remove bool) string {
if remove {
return "--delete"
}
return "--append"
}
// flipRule changes an append rule in a delete rule or a delete rule into an
// append rule.
func flipRule(rule string) string {
switch {
case strings.HasPrefix(rule, "-A"):
return strings.Replace(rule, "-A", "-D", 1)
case strings.HasPrefix(rule, "--append"):
return strings.Replace(rule, "--append", "-D", 1)
case strings.HasPrefix(rule, "-D"):
return strings.Replace(rule, "-D", "-A", 1)
case strings.HasPrefix(rule, "--delete"):
return strings.Replace(rule, "--delete", "-A", 1)
}
return rule
}
// Version obtains the version of the installed iptables.
func (c *Config) Version(ctx context.Context) (string, error) {
cmd := exec.CommandContext(ctx, c.ipTables, "--version") //nolint:gosec
output, err := c.runner.Run(cmd)
if err != nil {
return "", err
}
words := strings.Fields(output)
const minWords = 2
if len(words) < minWords {
return "", fmt.Errorf("%w: %s", ErrIPTablesVersionTooShort, output)
}
return words[1], nil
}
func (c *Config) runIptablesInstructions(ctx context.Context, instructions []string) error {
for _, instruction := range instructions {
if err := c.runIptablesInstruction(ctx, instruction); err != nil {
return err
}
}
return nil
}
func (c *Config) runIptablesInstruction(ctx context.Context, instruction string) error {
c.iptablesMutex.Lock() // only one iptables command at once
defer c.iptablesMutex.Unlock()
if isDeleteMatchInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ipTables, instruction,
c.runner, c.logger)
}
flags := strings.Fields(instruction)
cmd := exec.CommandContext(ctx, c.ipTables, flags...) // #nosec G204
c.logger.Debug(cmd.String())
if output, err := c.runner.Run(cmd); err != nil {
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
c.ipTables, instruction, output, err)
}
return nil
}
func (c *Config) clearAllRules(ctx context.Context) error {
return c.runMixedIptablesInstructions(ctx, []string{
"--flush", // flush all chains
"--delete-chain", // delete all chains
})
}
func (c *Config) setIPv4AllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP":
default:
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
}
return c.runIptablesInstructions(ctx, []string{
"--policy INPUT " + policy,
"--policy OUTPUT " + policy,
"--policy FORWARD " + policy,
})
}
func (c *Config) acceptInputThroughInterface(ctx context.Context, intf string, remove bool) error {
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
"%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf,
))
}
func (c *Config) acceptInputToSubnet(ctx context.Context, intf string,
destination netip.Prefix, remove bool,
) error {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("%s INPUT %s -d %s -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destination.String())
if destination.Addr().Is4() {
return c.runIptablesInstruction(ctx, instruction)
}
if c.ip6Tables == "" {
return fmt.Errorf("accept input to subnet %s: %w", destination, ErrNeedIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
func (c *Config) acceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error {
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
"%s OUTPUT -o %s -j ACCEPT", appendOrDelete(remove), intf,
))
}
func (c *Config) acceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error {
return c.runMixedIptablesInstructions(ctx, []string{
fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
})
}
func (c *Config) acceptOutputTrafficToVPN(ctx context.Context,
defaultInterface string, connection models.Connection, remove bool,
) error {
protocol := connection.Protocol
if protocol == "tcp-client" {
protocol = "tcp"
}
instruction := fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
appendOrDelete(remove), connection.IP, defaultInterface, protocol,
protocol, connection.Port)
if connection.IP.Is4() {
return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" {
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
// Thanks to @npawelek.
func (c *Config) acceptOutputFromIPToSubnet(ctx context.Context,
intf string, sourceIP netip.Addr, destinationSubnet netip.Prefix, remove bool,
) error {
doIPv4 := sourceIP.Is4() && destinationSubnet.Addr().Is4()
interfaceFlag := "-o " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("%s OUTPUT %s -s %s -d %s -j ACCEPT",
appendOrDelete(remove), interfaceFlag, sourceIP.String(), destinationSubnet.String())
if doIPv4 {
return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" {
return fmt.Errorf("accept output from %s to %s: %w", sourceIP, destinationSubnet, ErrNeedIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
// NDP uses multicast address (theres no broadcast in IPv6 like ARP uses in IPv4).
func (c *Config) acceptIpv6MulticastOutput(ctx context.Context,
intf string, remove bool,
) error {
interfaceFlag := "-o " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("%s OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT",
appendOrDelete(remove), interfaceFlag)
return c.runIP6tablesInstruction(ctx, instruction)
}
// Used for port forwarding, with intf set to tun.
func (c *Config) acceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
return c.runMixedIptablesInstructions(ctx, []string{
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port),
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port),
})
}
// Used for VPN server side port forwarding, with intf set to the VPN tunnel interface.
func (c *Config) redirectPort(ctx context.Context, intf string,
sourcePort, destinationPort uint16, remove bool,
) (err error) {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
err = c.runIptablesInstructions(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destinationPort),
fmt.Sprintf("-t nat %s PREROUTING %s -p udp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destinationPort),
})
if err != nil {
return fmt.Errorf("redirecting IPv4 source port %d to destination port %d on interface %s: %w",
sourcePort, destinationPort, intf, err)
}
err = c.runIP6tablesInstructions(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destinationPort),
fmt.Sprintf("-t nat %s PREROUTING %s -p udp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destinationPort),
})
if err != nil {
errMessage := err.Error()
if strings.Contains(errMessage, "can't initialize ip6tables table `nat': Table does not exist") {
if !remove {
c.logger.Warn("IPv6 port redirection disabled because your kernel does not support IPv6 NAT: " + errMessage)
}
return nil
}
return fmt.Errorf("redirecting IPv6 source port %d to destination port %d on interface %s: %w",
sourcePort, destinationPort, intf, err)
}
return nil
}
func (c *Config) runUserPostRules(ctx context.Context, filepath string, remove bool) error {
file, err := os.OpenFile(filepath, os.O_RDONLY, 0)
if os.IsNotExist(err) {
return nil
} else if err != nil {
return err
}
b, err := io.ReadAll(file)
if err != nil {
_ = file.Close()
return err
}
if err := file.Close(); err != nil {
return err
}
lines := strings.Split(string(b), "\n")
successfulRules := []string{}
defer func() {
// transaction-like rollback
if err == nil || ctx.Err() != nil {
return
}
for _, rule := range successfulRules {
_ = c.runIptablesInstruction(ctx, flipRule(rule))
}
}()
for _, line := range lines {
var ipv4 bool
var rule string
switch {
case strings.HasPrefix(line, "iptables "):
ipv4 = true
rule = strings.TrimPrefix(line, "iptables ")
case strings.HasPrefix(line, "iptables-nft "):
ipv4 = true
rule = strings.TrimPrefix(line, "iptables-nft ")
case strings.HasPrefix(line, "iptables-legacy "):
ipv4 = true
rule = strings.TrimPrefix(line, "iptables-legacy ")
case strings.HasPrefix(line, "ip6tables "):
ipv4 = false
rule = strings.TrimPrefix(line, "ip6tables ")
case strings.HasPrefix(line, "ip6tables-nft "):
ipv4 = false
rule = strings.TrimPrefix(line, "ip6tables-nft ")
case strings.HasPrefix(line, "ip6tables-legacy "):
ipv4 = false
rule = strings.TrimPrefix(line, "ip6tables-legacy ")
default:
continue
}
if remove {
rule = flipRule(rule)
}
switch {
case ipv4:
err = c.runIptablesInstruction(ctx, rule)
case c.ip6Tables == "":
err = fmt.Errorf("running user ip6tables rule: %w", ErrNeedIP6Tables)
default: // ipv6
err = c.runIP6tablesInstruction(ctx, rule)
}
if err != nil {
return err
}
successfulRules = append(successfulRules, rule)
}
return nil
}
-85
View File
@@ -1,85 +0,0 @@
package iptables
import (
"context"
"fmt"
"os/exec"
"strings"
)
// SaveAndRestore saves the current iptables and ip6tables rules and
// returns a restore function that can be called to restore the saved rules.
func (c *Config) SaveAndRestore(ctx context.Context) (restore func(context.Context), err error) {
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
return c.saveAndRestore(ctx)
}
// callers MUST always lock both the [Config] iptablesMutex and the ip6tablesMutex
// before calling this function. Note the restore function does not interact with mutexes
// so the caller must make sure the mutexes are locked when calling the restore function.
func (c *Config) saveAndRestore(ctx context.Context) (restore func(context.Context), err error) {
restoreIPv4, err := c.saveAndRestoreIPv4(ctx)
if err != nil {
return nil, err
}
restoreIPv6, err := c.saveAndRestoreIPv6(ctx)
if err != nil {
return nil, err
}
restore = func(ctx context.Context) {
restoreIPv4(ctx)
if restoreIPv6 != nil {
restoreIPv6(ctx)
}
}
return restore, nil
}
// Callers of saveAndRestoreIPv4 MUST always lock the [Config] iptablesMutex
// before calling this function.
func (c *Config) saveAndRestoreIPv4(ctx context.Context) (restore func(context.Context), err error) {
cmd := exec.CommandContext(ctx, c.ipTables+"-save") //nolint:gosec
data, err := c.runner.Run(cmd)
if err != nil {
return nil, fmt.Errorf("saving IPv4 iptables: %w", err)
}
restore = func(ctx context.Context) {
cmd := exec.CommandContext(ctx, c.ipTables+"-restore") //nolint:gosec
cmd.Stdin = strings.NewReader(data)
output, err := c.runner.Run(cmd)
if err != nil {
c.logger.Warn(fmt.Sprintf("restoring IPv4 iptables failed: %v: %s", err, output))
}
}
return restore, nil
}
// Callers of saveAndRestoreIPv6 MUST always lock the [Config] ip6tablesMutex
// before calling this function.
func (c *Config) saveAndRestoreIPv6(ctx context.Context) (restore func(context.Context), err error) {
if c.ip6Tables == "" {
return nil, nil //nolint:nilnil
}
cmd := exec.CommandContext(ctx, c.ip6Tables+"-save") //nolint:gosec
data, err := c.runner.Run(cmd)
if err != nil {
return nil, fmt.Errorf("saving IPv6 iptables: %w", err)
}
restore = func(ctx context.Context) {
cmd = exec.CommandContext(ctx, c.ip6Tables+"-restore") //nolint:gosec
cmd.Stdin = strings.NewReader(data)
output, err := c.runner.Run(cmd)
if err != nil {
c.logger.Warn(fmt.Sprintf("restoring IPv6 iptables failed: %v: %s", err, output))
}
}
return restore, nil
}
-39
View File
@@ -1,39 +0,0 @@
package iptables
import (
"context"
"errors"
"sync"
)
var ErrKernelModuleMissing = errors.New("kernel module is missing for this operation")
type Config struct {
runner CmdRunner
logger Logger
iptablesMutex sync.Mutex
ip6tablesMutex sync.Mutex
// Fixed state
ipTables string
ip6Tables string
}
func New(ctx context.Context, runner CmdRunner, logger Logger) (*Config, error) {
iptables, err := checkIptablesSupport(ctx, runner, "iptables", "iptables-nft", "iptables-legacy")
if err != nil {
return nil, err
}
ip6tables, err := findIP6tablesSupported(ctx, runner)
if err != nil {
return nil, err
}
return &Config{
runner: runner,
logger: logger,
ipTables: iptables,
ip6Tables: ip6tables,
}, nil
}
-14
View File
@@ -1,14 +0,0 @@
package iptables
import "os/exec"
type CmdRunner interface {
Run(cmd *exec.Cmd) (output string, err error)
}
type Logger interface {
Debug(s string)
Info(s string)
Warn(s string)
Error(s string)
}
-485
View File
@@ -1,485 +0,0 @@
package iptables
import (
"context"
"errors"
"fmt"
"io"
"net/netip"
"os"
"os/exec"
"strings"
"github.com/qdm12/gluetun/internal/models"
)
var (
ErrIPTablesVersionTooShort = errors.New("iptables version string is too short")
ErrPolicyUnknown = errors.New("unknown policy")
ErrNeedIP6Tables = errors.New("ip6tables is required, please upgrade your kernel to support it")
)
func appendOrDelete(remove bool) string {
if remove {
return "--delete"
}
return "--append"
}
// Version obtains the version of the installed iptables.
func (c *Config) Version(ctx context.Context) (string, error) {
cmd := exec.CommandContext(ctx, c.ipTables, "--version") //nolint:gosec
output, err := c.runner.Run(cmd)
if err != nil {
return "", err
}
words := strings.Fields(output)
const minWords = 2
if len(words) < minWords {
return "", fmt.Errorf("%w: %s", ErrIPTablesVersionTooShort, output)
}
return "iptables " + words[1], nil
}
func (c *Config) runIptablesInstructions(ctx context.Context, instructions []string) error {
c.iptablesMutex.Lock()
defer c.iptablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv4(ctx)
if err != nil {
return err
}
err = c.runIptablesInstructionsNoSave(ctx, instructions)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIptablesInstructionsNoSave(ctx context.Context, instructions []string) error {
for _, instruction := range instructions {
if err := c.runIptablesInstructionNoSave(ctx, instruction); err != nil {
return err
}
}
return nil
}
func (c *Config) runIptablesInstruction(ctx context.Context, instruction string) error {
c.iptablesMutex.Lock() // only one iptables command at once
defer c.iptablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv4(ctx)
if err != nil {
return err
}
err = c.runIptablesInstructionNoSave(ctx, instruction)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIptablesInstructionNoSave(ctx context.Context, instruction string) error {
if isDeleteMatchInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ipTables, instruction,
c.runner, c.logger)
}
flags := strings.Fields(instruction)
cmd := exec.CommandContext(ctx, c.ipTables, flags...) // #nosec G204
c.logger.Debug(cmd.String())
if output, err := c.runner.Run(cmd); err != nil {
if strings.Contains(output, "missing kernel module") {
err = ErrKernelModuleMissing
}
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
c.ipTables, instruction, output, err)
}
return nil
}
func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP":
default:
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
}
return c.runIptablesInstructions(ctx, []string{
"--policy INPUT " + policy,
"--policy OUTPUT " + policy,
"--policy FORWARD " + policy,
})
}
func (c *Config) AcceptInputThroughInterface(ctx context.Context, intf string) error {
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
"--append INPUT -i %s -j ACCEPT", intf))
}
func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, destination netip.Prefix) error {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("--append INPUT %s -d %s -j ACCEPT",
interfaceFlag, destination.String())
if destination.Addr().Is4() {
return c.runIptablesInstruction(ctx, instruction)
}
if c.ip6Tables == "" {
return fmt.Errorf("accept input to subnet %s: %w", destination, ErrNeedIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
func (c *Config) AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error {
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
"%s OUTPUT -o %s -j ACCEPT", appendOrDelete(remove), intf,
))
}
func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context) error {
return c.runMixedIptablesInstructions(ctx, []string{
"--append OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
"--append INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
})
}
// AcceptOutputPublicOnlyNewTraffic adds rules to mark new output connections, and to accept
// established or related packets with this mark only. This effectively forces
// previously established or related traffic to be blocked.
// If remove is true, the rules are removed instead of appended.
// If the relevant kernel modules are not available, it returns an error indicating
// which kernel module is missing.
func (c *Config) AcceptOutputPublicOnlyNewTraffic(ctx context.Context) error {
ipv4Instructions, ipv6Instructions := makeCreatePublicIPChainInstructions()
appendToBoth := func(instruction string) {
ipv4Instructions = append(ipv4Instructions, instruction)
ipv6Instructions = append(ipv6Instructions, instruction)
}
// Mark new connections with mark 0x567
appendToBoth("-A PUBLIC_ONLY -m conntrack --ctstate NEW -j CONNMARK --set-mark 0x567")
// Drop related/established connections that made it through; marked connections would
// be directly accepted by the first rule in the OUTPUT chain (see below)
appendToBoth("-A PUBLIC_ONLY -m conntrack --ctstate RELATED,ESTABLISHED -j DROP")
// Set the PUBLIC_ONLY chain as the second rule in the OUTPUT chain, so that it is evaluated
// after the accept rule below, for performance reasons.
appendToBoth("-I OUTPUT -j PUBLIC_ONLY")
appendToBoth("-I OUTPUT -m conntrack --ctstate RELATED,ESTABLISHED -m connmark --mark 0x567 -j ACCEPT")
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
err = c.runIptablesInstructionsNoSave(ctx, ipv4Instructions)
if err != nil {
restore(ctx)
return err
}
err = c.runIP6tablesInstructionsNoSave(ctx, ipv6Instructions)
if err != nil {
restore(ctx)
return err
}
return nil
}
func (c *Config) RejectOutputPublicTraffic(ctx context.Context, remove bool) error {
return c.targetOutputPublicTraffic(ctx, "REJECT", remove)
}
func (c *Config) DropOutputPublicTraffic(ctx context.Context, remove bool) error {
return c.targetOutputPublicTraffic(ctx, "DROP", remove)
}
func (c *Config) targetOutputPublicTraffic(ctx context.Context, target string, remove bool) error {
removeInstructions := []string{
"-D OUTPUT -j PUBLIC_ONLY",
"-F PUBLIC_ONLY",
"-X PUBLIC_ONLY",
}
if remove {
return c.runMixedIptablesInstructions(ctx, removeInstructions)
}
ipv4Instructions, ipv6Instructions := makeCreatePublicIPChainInstructions()
appendToBoth := func(instruction string) {
ipv4Instructions = append(ipv4Instructions, instruction)
ipv6Instructions = append(ipv6Instructions, instruction)
}
if target == "REJECT" {
// Block TCP by sending back TCP RST packets.
appendToBoth("-A PUBLIC_ONLY -p tcp -m conntrack --ctstate RELATED,ESTABLISHED " +
"-j REJECT --reject-with tcp-reset")
// Block UDP and ICMP, sending back ICMP port unreachable.
appendToBoth("-A PUBLIC_ONLY -m conntrack --ctstate RELATED,ESTABLISHED -j REJECT")
} else {
appendToBoth("-A PUBLIC_ONLY -m conntrack --ctstate RELATED,ESTABLISHED -j " + target)
}
appendToBoth("-I OUTPUT -j PUBLIC_ONLY")
err := c.runIptablesInstructions(ctx, ipv4Instructions)
if err != nil {
if strings.Contains(err.Error(), " support") {
return fmt.Errorf("%w: %w", ErrKernelModuleMissing, err)
}
}
err = c.runIP6tablesInstructions(ctx, ipv6Instructions)
if err != nil {
_ = c.runIptablesInstructions(ctx, removeInstructions)
if strings.Contains(err.Error(), " support") {
return fmt.Errorf("%w: %w", ErrKernelModuleMissing, err)
}
return err
}
return nil
}
func makeCreatePublicIPChainInstructions() (ipv4Instructions, ipv6Instructions []string) {
ipv4PrivatePrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("172.16.0.0/12"),
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("127.0.0.0/8"),
}
ipv6PrivatePrefixes := []netip.Prefix{
netip.MustParsePrefix("fc00::/7"),
netip.MustParsePrefix("fe80::/10"),
netip.MustParsePrefix("::1/128"),
}
ipv4Instructions = append(ipv4Instructions, "-N PUBLIC_ONLY")
ipv6Instructions = append(ipv6Instructions, "-N PUBLIC_ONLY")
for _, prefix := range ipv4PrivatePrefixes {
ipv4Instructions = append(ipv4Instructions, fmt.Sprintf(
"-A PUBLIC_ONLY -d %s -j RETURN", prefix))
}
for _, prefix := range ipv6PrivatePrefixes {
ipv6Instructions = append(ipv6Instructions, fmt.Sprintf(
"-A PUBLIC_ONLY -d %s -j RETURN", prefix))
}
return ipv4Instructions, ipv6Instructions
}
func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context,
defaultInterface string, connection models.Connection, remove bool,
) error {
protocol := connection.Protocol
if protocol == "tcp-client" {
protocol = "tcp"
}
instruction := fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
appendOrDelete(remove), connection.IP, defaultInterface, protocol,
protocol, connection.Port)
if connection.IP.Is4() {
return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" {
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
// AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet
// on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
// If remove is true, the rule is removed instead of added.
// Thanks to @npawelek.
func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
intf string, sourceIP netip.Addr, destinationSubnet netip.Prefix, remove bool,
) error {
doIPv4 := sourceIP.Is4() && destinationSubnet.Addr().Is4()
interfaceFlag := "-o " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("%s OUTPUT %s -s %s -d %s -j ACCEPT",
appendOrDelete(remove), interfaceFlag, sourceIP.String(), destinationSubnet.String())
if doIPv4 {
return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" {
return fmt.Errorf("accept output from %s to %s: %w", sourceIP, destinationSubnet, ErrNeedIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
// AcceptIpv6MulticastOutput accepts outgoing traffic to the IPv6 multicast address
// ff02::1:ff00:0/104, which is used for NDP (Neighbor Discovery Protocol) to resolve
// IPv6 addresses to MAC addresses. If intf is empty, it is set to "*" which means
// all interfaces. If remove is true, the rule is removed instead of added.
func (c *Config) AcceptIpv6MulticastOutput(ctx context.Context, intf string) error {
interfaceFlag := "-o " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("--append OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT", interfaceFlag)
return c.runIP6tablesInstruction(ctx, instruction)
}
// AcceptInputToPort accepts incoming traffic on the specified port, for both TCP and UDP
// protocols, on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
// If remove is true, the rule is removed instead of added. This is used for port forwarding, with
// intf set to the VPN tunnel interface.
func (c *Config) AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
return c.runMixedIptablesInstructions(ctx, []string{
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port),
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port),
})
}
// RedirectPort redirects incoming traffic on the specified source port to the
// specified destination port, for both TCP and UDP protocols, on the interface intf.
// If intf is empty, it is set to "*" which means all interfaces. If remove is true,
// the redirection is removed instead of added. This is used for VPN server side
// port forwarding, with intf set to the VPN tunnel interface.
func (c *Config) RedirectPort(ctx context.Context, intf string,
sourcePort, destinationPort uint16, remove bool,
) (err error) {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
err = c.runIptablesInstructionsNoSave(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destinationPort),
fmt.Sprintf("-t nat %s PREROUTING %s -p udp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destinationPort),
})
if err != nil {
restore(ctx)
return fmt.Errorf("redirecting IPv4 source port %d to destination port %d on interface %s: %w",
sourcePort, destinationPort, intf, err)
}
err = c.runIP6tablesInstructionsNoSave(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destinationPort),
fmt.Sprintf("-t nat %s PREROUTING %s -p udp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destinationPort),
})
if err != nil {
restore(ctx) // just in case
errMessage := err.Error()
if strings.Contains(errMessage, "can't initialize ip6tables table `nat': Table does not exist") {
if !remove {
c.logger.Warn("IPv6 port redirection disabled because your kernel does not support IPv6 NAT: " + errMessage)
}
return nil
}
return fmt.Errorf("redirecting IPv6 source port %d to destination port %d on interface %s: %w",
sourcePort, destinationPort, intf, err)
}
return nil
}
func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
file, err := os.OpenFile(filepath, os.O_RDONLY, 0)
if os.IsNotExist(err) {
return nil
} else if err != nil {
return err
}
b, err := io.ReadAll(file)
if err != nil {
_ = file.Close()
return err
}
if err := file.Close(); err != nil {
return err
}
lines := strings.Split(string(b), "\n")
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
for _, line := range lines {
var ipv4 bool
var rule string
switch {
case strings.HasPrefix(line, "iptables "):
ipv4 = true
rule = strings.TrimPrefix(line, "iptables ")
case strings.HasPrefix(line, "iptables-nft "):
ipv4 = true
rule = strings.TrimPrefix(line, "iptables-nft ")
case strings.HasPrefix(line, "iptables-legacy "):
ipv4 = true
rule = strings.TrimPrefix(line, "iptables-legacy ")
case strings.HasPrefix(line, "ip6tables "):
ipv4 = false
rule = strings.TrimPrefix(line, "ip6tables ")
case strings.HasPrefix(line, "ip6tables-nft "):
ipv4 = false
rule = strings.TrimPrefix(line, "ip6tables-nft ")
case strings.HasPrefix(line, "ip6tables-legacy "):
ipv4 = false
rule = strings.TrimPrefix(line, "ip6tables-legacy ")
default:
continue
}
switch {
case ipv4:
err = c.runIptablesInstruction(ctx, rule)
case c.ip6Tables == "":
err = fmt.Errorf("running user ip6tables rule: %w", ErrNeedIP6Tables)
default: // ipv6
err = c.runIP6tablesInstruction(ctx, rule)
}
if err != nil {
restore(ctx)
return err
}
}
return nil
}
-49
View File
@@ -1,49 +0,0 @@
package iptables
import (
"context"
)
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
for _, instruction := range instructions {
if err := c.runMixedIptablesInstructionNoSave(ctx, instruction); err != nil {
restore(ctx)
return err
}
}
return nil
}
func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction string) error {
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
err = c.runIptablesInstructionNoSave(ctx, instruction)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runMixedIptablesInstructionNoSave(ctx context.Context, instruction string) error {
if err := c.runIptablesInstructionNoSave(ctx, instruction); err != nil {
return err
}
return c.runIP6tablesInstructionNoSave(ctx, instruction)
}
-339
View File
@@ -1,339 +0,0 @@
package iptables
import (
"errors"
"fmt"
"net/netip"
"slices"
"strconv"
"strings"
)
type operation uint8
const (
opNone operation = iota
opAppend
opDelete
opInsert
opReplace
)
type iptablesInstruction struct {
table string // defaults to "filter", and can be "nat" for example.
operation operation
chain string // for example INPUT, PREROUTING. Cannot be empty.
target string // for example ACCEPT. Can be empty.
protocol string // "tcp" or "udp" or "" for all protocols.
inputInterface string // for example "tun0" or "" for any interface.
outputInterface string // for example "tun0" or "" for any interface.
source netip.Prefix // if not valid, then it is unspecified.
sourcePort uint16 // if zero, there is no source port
destination netip.Prefix // if not valid, then it is unspecified.
destinationPort uint16 // if zero, there is no destination port
toPorts []uint16 // if empty, there is no redirection
ctstate []string // if empty, there is no ctstate
tcpFlags tcpFlags
mark mark
connMark mark
setMark uint // only used for jump CONNMARK --set-mark
rejectWith string // only used for REJECT targets
}
func (i *iptablesInstruction) setDefaults() {
if i.table == "" {
i.table = "filter"
}
}
// equalToRule ignores the append boolean flag of the instruction to compare against the rule.
func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (equal bool) {
switch {
case i.table != table:
return false
case i.chain != chain:
return false
case i.target != rule.target:
return false
case i.protocol != rule.protocol:
return false
case i.destinationPort != rule.destinationPort:
return false
case i.sourcePort != rule.sourcePort:
return false
case !slices.Equal(i.toPorts, rule.redirPorts):
return false
case !slices.Equal(i.ctstate, rule.ctstate):
return false
case !networkInterfacesEqual(i.inputInterface, rule.inputInterface):
return false
case !networkInterfacesEqual(i.outputInterface, rule.outputInterface):
return false
case !ipPrefixesEqual(i.source, rule.source):
return false
case !ipPrefixesEqual(i.destination, rule.destination):
return false
case !slices.Equal(i.tcpFlags.mask, rule.tcpFlags.mask) ||
!slices.Equal(i.tcpFlags.comparison, rule.tcpFlags.comparison):
return false
case i.mark != rule.mark:
return false
case i.connMark != rule.connMark:
return false
case i.setMark != rule.setMark:
return false
case i.rejectWith != rule.rejectWith:
return false
default:
return true
}
}
// instruction can be "" which equivalent to the "*" chain rule interface.
func networkInterfacesEqual(instruction, chainRule string) bool {
return instruction == chainRule || (instruction == "" && chainRule == "*")
}
func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
return instruction == chainRule ||
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
}
var ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
if s == "" {
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
}
fields := strings.Fields(s)
i := 0
for i < len(fields) {
consumed, err := parseInstructionFlag(fields[i:], &instruction)
if err != nil {
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
}
i += consumed
}
instruction.setDefaults()
return instruction, nil
}
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
consumed, err = preCheckInstructionFields(fields)
if err != nil {
return 0, err
}
flag := fields[0]
value := fields[1]
switch flag {
case "-t", "--table":
instruction.table = value
case "-D", "--delete":
instruction.operation = opDelete
instruction.chain = value
case "-A", "--append":
instruction.operation = opAppend
instruction.chain = value
case "-I", "--insert":
instruction.operation = opInsert
instruction.chain = value
case "-j", "--jump":
subConsumed, err := parseJumpFlag(fields[1:], instruction)
if err != nil {
return 0, fmt.Errorf("parsing jump flag: %w", err)
}
consumed += subConsumed
case "-p", "--protocol":
instruction.protocol = value
case "-m", "--match":
consumed, err = parseMatchModule(fields, instruction)
if err != nil {
return 0, fmt.Errorf("parsing match module: %w", err)
}
case "--mark":
n, err := parseAny32bNumber(value)
if err != nil {
return 0, fmt.Errorf("parsing mark value %q: %w", value, err)
}
instruction.mark.value = n
case "-i", "--in-interface":
instruction.inputInterface = value
case "-o", "--out-interface":
instruction.outputInterface = value
case "-s", "--source":
instruction.source, err = parseIPPrefix(value)
if err != nil {
return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
}
case "--sport":
instruction.sourcePort, err = parsePort(value)
if err != nil {
return 0, fmt.Errorf("parsing source port: %w", err)
}
case "-d", "--destination":
instruction.destination, err = parseIPPrefix(value)
if err != nil {
return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
}
case "--dport":
instruction.destinationPort, err = parsePort(value)
if err != nil {
return 0, fmt.Errorf("parsing destination port: %w", err)
}
case "--ctstate":
instruction.ctstate = strings.Split(value, ",")
case "--to-ports":
instruction.toPorts, err = parseToPorts(value)
if err != nil {
return 0, fmt.Errorf("parsing port redirection: %w", err)
}
case "--tcp-flags":
mask, comparison := value, fields[2]
instruction.tcpFlags, err = parseTCPFlags(mask + "/" + comparison)
if err != nil {
return 0, fmt.Errorf("parsing TCP flags: %w", err)
}
case "--reject-with":
instruction.rejectWith = value // for example "tcp-reset"
default:
return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag)
}
return consumed, nil
}
func preCheckInstructionFields(fields []string) (consumed int, err error) {
flag := fields[0]
// All flags use one value after the flag, except the following:
switch flag {
case "--tcp-flags":
const expected = 3
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
}
return expected, nil
default:
const expected = 2
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
ErrIptablesCommandMalformed, flag)
}
return expected, nil
}
}
func parseJumpFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
instruction.target = fields[0]
// consumed in the caller already takes fields[0] into account
if instruction.target != "CONNMARK" {
return consumed, nil
}
// consumed already accounts for the "CONNMARK" value
const expectedFields = 3
if len(fields) < expectedFields {
return 0, fmt.Errorf("%w: jump CONNMARK requires at least two additional values",
ErrIptablesCommandMalformed)
}
switch fields[1] {
case "--set-mark":
n, err := parseAny32bNumber(fields[2])
if err != nil {
return 0, fmt.Errorf("parsing connmark mark value %q: %w", fields[2], err)
}
consumed++
instruction.setMark = n
default:
return consumed, fmt.Errorf("%w: unsupported jump CONNMARK with value: %s",
ErrIptablesCommandMalformed, fields[1])
}
consumed++
return consumed, nil
}
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
slashIndex := strings.Index(value, "/")
if slashIndex >= 0 {
return netip.ParsePrefix(value)
}
ip, err := netip.ParseAddr(value)
if err != nil {
return netip.Prefix{}, fmt.Errorf("parsing IP address: %w", err)
}
return netip.PrefixFrom(ip, ip.BitLen()), nil
}
func parsePort(value string) (port uint16, err error) {
const base, bitLength = 10, 16
portValue, err := strconv.ParseUint(value, base, bitLength)
if err != nil {
return 0, err
}
return uint16(portValue), nil
}
func parseAny32bNumber(mark string) (value uint, err error) {
const base = 0 // auto-detect
const bits = 32
n, err := strconv.ParseUint(mark, base, bits)
return uint(n), err
}
func parseMatchModule(fields []string, instruction *iptablesInstruction) (
consumed int, err error,
) {
_ = fields[consumed] // -m or --match flag already detected
consumed++
switch fields[consumed] {
case "tcp", "udp":
consumed++
// for now ignore the protocol match since it's auto-loaded
// when parsing the -p/--protocol flag, and we don't need to
// parse it twice.
case "mark":
consumed++
switch {
case len(fields[consumed:]) == 0 || strings.HasPrefix(fields[consumed], "-"):
// end or another flag
return consumed, nil
case fields[consumed] == "!":
consumed++
instruction.mark.invert = true
default:
return consumed, fmt.Errorf("%w: unsupported match mark with value: %s",
ErrIptablesCommandMalformed, fields[2])
}
case "connmark":
consumed++
switch {
case len(fields[consumed:]) == 0 || strings.HasPrefix(fields[consumed], "-"):
// end or another flag
return consumed, nil
case fields[consumed] == "!":
consumed++
instruction.connMark.invert = true
default:
return consumed, fmt.Errorf("%w: unsupported match connmark with value: %s",
ErrIptablesCommandMalformed, fields[2])
}
default:
return 0, fmt.Errorf("%w: unknown match value: %s",
ErrIptablesCommandMalformed, fields[consumed])
}
return consumed, nil
}
func parseToPorts(value string) (toPorts []uint16, err error) {
portStrings := strings.Split(value, ",")
toPorts = make([]uint16, len(portStrings))
for i, portString := range portStrings {
toPorts[i], err = parsePort(portString)
if err != nil {
return nil, err
}
}
return toPorts, nil
}
-98
View File
@@ -1,98 +0,0 @@
package iptables
import (
"context"
"errors"
"fmt"
"net/netip"
"os"
)
type tcpFlags struct {
mask []tcpFlag
comparison []tcpFlag
}
type tcpFlag uint8
const (
tcpFlagFIN tcpFlag = 1 << iota
tcpFlagSYN
tcpFlagRST
tcpFlagPSH
tcpFlagACK
tcpFlagURG
tcpFlagECE
tcpFlagCWR
)
func (f tcpFlag) String() string {
switch f {
case tcpFlagFIN:
return "FIN"
case tcpFlagSYN:
return "SYN"
case tcpFlagRST:
return "RST"
case tcpFlagPSH:
return "PSH"
case tcpFlagACK:
return "ACK"
case tcpFlagURG:
return "URG"
case tcpFlagECE:
return "ECE"
case tcpFlagCWR:
return "CWR"
default:
panic(fmt.Sprintf("%s: %d", errTCPFlagUnknown, f))
}
}
var errTCPFlagUnknown = errors.New("unknown TCP flag")
func parseTCPFlag(s string) (tcpFlag, error) {
allFlags := []tcpFlag{
tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH,
tcpFlagACK, tcpFlagURG, tcpFlagECE, tcpFlagCWR,
}
for _, flag := range allFlags {
if s == fmt.Sprintf("%#02x", uint8(flag)) || s == flag.String() {
return flag, nil
}
}
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
}
var ErrMarkMatchModuleMissing = errors.New("libxt_mark.so module is missing")
// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port,
// for any TCP packets not marked with the excludeMark given.
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
// by sending a TCP RST packet, although we want to handle the connection manually.
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
src, dst netip.AddrPort, excludeMark int) (
revert func(ctx context.Context) error, err error,
) {
_, err = os.Stat("/usr/lib/xtables/libxt_mark.so")
if err != nil && errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("%w", ErrMarkMatchModuleMissing)
}
const template = "%s OUTPUT -p tcp -s %s --sport %d -d %s --dport %d " +
"--tcp-flags RST RST -m mark ! --mark %d -j DROP" //nolint:dupword
instruction := fmt.Sprintf(template, "--append", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark)
revertInstruction := fmt.Sprintf(template, "--delete", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark)
run := c.runIptablesInstruction
if dst.Addr().Is6() {
run = c.runIP6tablesInstruction
}
revert = func(ctx context.Context) error {
return run(ctx, revertInstruction)
}
err = run(ctx, instruction)
if err != nil {
return nil, fmt.Errorf("running instruction: %w", err)
}
return revert, nil
}
+21
View File
@@ -0,0 +1,21 @@
package firewall
import (
"context"
)
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
for _, instruction := range instructions {
if err := c.runMixedIptablesInstruction(ctx, instruction); err != nil {
return err
}
}
return nil
}
func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction string) error {
if err := c.runIptablesInstruction(ctx, instruction); err != nil {
return err
}
return c.runIP6tablesInstruction(ctx, instruction)
}
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"errors"
@@ -26,21 +26,10 @@ type chainRule struct {
inputInterface string // input interface, for example "tun0" or "*""
outputInterface string // output interface, for example "eth0" or "*""
source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid.
sourcePort uint16 // Not specified if set to zero.
destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid.
destinationPort uint16 // Not specified if set to zero.
redirPorts []uint16 // Not specified if empty.
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
tcpFlags tcpFlags
mark mark
connMark mark
setMark uint
rejectWith string // for example "tcp-reset", only used for REJECT targets
}
type mark struct {
invert bool
value uint
}
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
@@ -222,6 +211,10 @@ func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err err
return fmt.Errorf("parsing bytes: %w", err)
}
case targetIndex:
err = checkTarget(field)
if err != nil {
return fmt.Errorf("checking target: %w", err)
}
rule.target = field
case protocolIndex:
rule.protocol, err = parseProtocol(field)
@@ -248,23 +241,19 @@ func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err err
}
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
i := 0
for i < len(optionalFields) {
switch optionalFields[i] {
case "udp":
for i := 0; i < len(optionalFields); i++ {
key := optionalFields[i]
switch key {
case "tcp", "udp":
i++
consumed, err := parseUDPOptional(optionalFields[i:], rule)
value := optionalFields[i]
value = strings.TrimPrefix(value, "dpt:")
const base, bitLength = 10, 16
destinationPort, err := strconv.ParseUint(value, base, bitLength)
if err != nil {
return fmt.Errorf("parsing UDP optional fields: %w", err)
return fmt.Errorf("parsing destination port %q: %w", value, err)
}
i += consumed
case "tcp":
i++
consumed, err := parseTCPOptional(optionalFields[i:], rule)
if err != nil {
return fmt.Errorf("parsing TCP optional fields: %w", err)
}
i += consumed
rule.destinationPort = uint16(destinationPort)
case "redir":
i++
switch optionalFields[i] {
@@ -275,163 +264,20 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
return fmt.Errorf("parsing redirection ports: %w", err)
}
rule.redirPorts = ports
i++
default:
return fmt.Errorf("%w: unexpected %q after redir",
ErrChainRuleMalformed, optionalFields[1])
return fmt.Errorf("%w: unexpected optional field: %s",
ErrChainRuleMalformed, optionalFields[i])
}
case "ctstate":
i++
rule.ctstate = strings.Split(optionalFields[i], ",")
i++
case "mark":
i++
mark, consumed, err := parseMark(optionalFields[i:])
if err != nil {
return fmt.Errorf("parsing mark: %w", err)
}
rule.mark = mark
i += consumed
case "reject-with":
i++
rule.rejectWith = optionalFields[i] // for example "tcp-reset"
i++
case "connmark":
i++
connMark, consumed, err := parseMark(optionalFields[i:])
if err != nil {
return fmt.Errorf("parsing connmark: %w", err)
}
rule.connMark = connMark
i += consumed
case "CONNMARK":
i++
switch optionalFields[i] {
case "set":
i++
value, err := parseAny32bNumber(optionalFields[i])
if err != nil {
return fmt.Errorf("parsing CONNMARK set value: %w", err)
}
rule.setMark = value
i++
default:
return fmt.Errorf("%w: unexpected %q after CONNMARK",
ErrChainRuleMalformed, optionalFields[i])
}
default:
return fmt.Errorf("%w: unexpected optional field: %s",
ErrChainRuleMalformed, optionalFields[i])
return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key)
}
}
return nil
}
var errUDPOptionalUnknown = errors.New("unknown UDP optional field")
func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
for _, value := range optionalFields {
if !strings.ContainsRune(value, ':') {
// no longer a UDP-associated option
return consumed, nil
}
switch {
case strings.HasPrefix(value, "dpt:"):
rule.destinationPort, err = parseDestinationPort(value)
if err != nil {
return 0, fmt.Errorf("parsing destination port: %w", err)
}
consumed++
case strings.HasPrefix(value, "spt:"):
rule.sourcePort, err = parseSourcePort(value)
if err != nil {
return 0, fmt.Errorf("parsing source port: %w", err)
}
consumed++
default:
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
}
}
return consumed, nil
}
var errTCPOptionalUnknown = errors.New("unknown TCP optional field")
func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
for _, value := range optionalFields {
if !strings.ContainsRune(value, ':') {
// no longer a TCP-associated option
return consumed, nil
}
switch {
case strings.HasPrefix(value, "dpt:"):
rule.destinationPort, err = parseDestinationPort(value)
if err != nil {
return 0, fmt.Errorf("parsing destination port: %w", err)
}
consumed++
case strings.HasPrefix(value, "spt:"):
rule.sourcePort, err = parseSourcePort(value)
if err != nil {
return 0, fmt.Errorf("parsing source port: %w", err)
}
consumed++
case strings.HasPrefix(value, "flags:"):
rule.tcpFlags, err = parseTCPFlags(value)
if err != nil {
return 0, fmt.Errorf("parsing TCP flags: %w", err)
}
consumed++
default:
return 0, fmt.Errorf("%w: %s", errTCPOptionalUnknown, value)
}
}
return consumed, nil
}
func parseDestinationPort(value string) (port uint16, err error) {
value = strings.TrimPrefix(value, "dpt:")
return parsePort(value)
}
func parseSourcePort(value string) (port uint16, err error) {
value = strings.TrimPrefix(value, "spt:")
return parsePort(value)
}
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
func parseTCPFlags(value string) (tcpFlags, error) {
value = strings.TrimPrefix(value, "flags:")
fields := strings.Split(value, "/")
const expectedFields = 2
if len(fields) != expectedFields {
return tcpFlags{}, fmt.Errorf("%w: expected format 'flags:<mask>/<comparison>' in %q",
errTCPFlagsMalformed, value)
}
maskFlags := strings.Split(fields[0], ",")
mask := make([]tcpFlag, len(maskFlags))
var err error
for i, maskFlag := range maskFlags {
mask[i], err = parseTCPFlag(maskFlag)
if err != nil {
return tcpFlags{}, fmt.Errorf("parsing TCP mask flags: %w", err)
}
}
comparisonFlags := strings.Split(fields[1], ",")
comparison := make([]tcpFlag, len(comparisonFlags))
for i, comparisonFlag := range comparisonFlags {
comparison[i], err = parseTCPFlag(comparisonFlag)
if err != nil {
return tcpFlags{}, fmt.Errorf("parsing TCP comparison flags: %w", err)
}
}
return tcpFlags{
mask: mask,
comparison: comparison,
}, nil
}
func parsePortsCSV(s string) (ports []uint16, err error) {
if s == "" {
return nil, nil
@@ -440,36 +286,16 @@ func parsePortsCSV(s string) (ports []uint16, err error) {
fields := strings.Split(s, ",")
ports = make([]uint16, len(fields))
for i, field := range fields {
ports[i], err = parsePort(field)
const base, bitLength = 10, 16
port, err := strconv.ParseUint(field, base, bitLength)
if err != nil {
return nil, err
return nil, fmt.Errorf("parsing port %q: %w", field, err)
}
ports[i] = uint16(port)
}
return ports, nil
}
func parseMark(optionalFields []string) (m mark, consumed int, err error) {
switch optionalFields[consumed] {
case "match":
consumed++
if optionalFields[consumed] == "!" {
m.invert = true
consumed++
}
value, err := parseAny32bNumber(optionalFields[consumed])
if err != nil {
return mark{}, 0, fmt.Errorf("value malformed: %w", err)
}
m.value = value
consumed++
default:
return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s",
ErrChainRuleMalformed, optionalFields[consumed])
}
return m, consumed, nil
}
var ErrLineNumberIsZero = errors.New("line number is zero")
func parseLineNumber(s string) (n uint16, err error) {
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"net/netip"
@@ -1,3 +1,3 @@
package iptables
package firewall
//go:generate mockgen -destination=mocks_test.go -package $GOPACKAGE . CmdRunner,Logger
@@ -1,8 +1,8 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/firewall/iptables (interfaces: CmdRunner,Logger)
// Source: github.com/qdm12/gluetun/internal/firewall (interfaces: CmdRunner,Logger)
// Package iptables is a generated GoMock package.
package iptables
// Package firewall is a generated GoMock package.
package firewall
import (
exec "os/exec"
+2 -2
View File
@@ -48,7 +48,7 @@ func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []netip.Pref
}
firewallUpdated = true
err := c.impl.AcceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
defaultRoute.AssignedIP, subNet, remove)
if err != nil {
c.logger.Error("cannot remove outdated outbound subnet: " + err.Error())
@@ -77,7 +77,7 @@ func (c *Config) addOutboundSubnets(ctx context.Context, subnets []netip.Prefix)
}
firewallUpdated = true
err := c.impl.AcceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
defaultRoute.AssignedIP, subnet, remove)
if err != nil {
return err
+164
View File
@@ -0,0 +1,164 @@
package firewall
import (
"errors"
"fmt"
"net/netip"
"slices"
"strconv"
"strings"
)
type iptablesInstruction struct {
table string // defaults to "filter", and can be "nat" for example.
append bool
chain string // for example INPUT, PREROUTING. Cannot be empty.
target string // for example ACCEPT. Can be empty.
protocol string // "tcp" or "udp" or "" for all protocols.
inputInterface string // for example "tun0" or "" for any interface.
outputInterface string // for example "tun0" or "" for any interface.
source netip.Prefix // if not valid, then it is unspecified.
destination netip.Prefix // if not valid, then it is unspecified.
destinationPort uint16 // if zero, there is no destination port
toPorts []uint16 // if empty, there is no redirection
ctstate []string // if empty, there is no ctstate
}
func (i *iptablesInstruction) setDefaults() {
if i.table == "" {
i.table = "filter"
}
}
// equalToRule ignores the append boolean flag of the instruction to compare against the rule.
func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (equal bool) {
switch {
case i.table != table:
return false
case i.chain != chain:
return false
case i.target != rule.target:
return false
case i.protocol != rule.protocol:
return false
case i.destinationPort != rule.destinationPort:
return false
case !slices.Equal(i.toPorts, rule.redirPorts):
return false
case !slices.Equal(i.ctstate, rule.ctstate):
return false
case !networkInterfacesEqual(i.inputInterface, rule.inputInterface):
return false
case !networkInterfacesEqual(i.outputInterface, rule.outputInterface):
return false
case !ipPrefixesEqual(i.source, rule.source):
return false
case !ipPrefixesEqual(i.destination, rule.destination):
return false
default:
return true
}
}
// instruction can be "" which equivalent to the "*" chain rule interface.
func networkInterfacesEqual(instruction, chainRule string) bool {
return instruction == chainRule || (instruction == "" && chainRule == "*")
}
func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
return instruction == chainRule ||
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
}
var ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
if s == "" {
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
}
fields := strings.Fields(s)
if len(fields)%2 != 0 {
return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q",
ErrIptablesCommandMalformed, len(fields), s)
}
for i := 0; i < len(fields); i += 2 {
key := fields[i]
value := fields[i+1]
err = parseInstructionFlag(key, value, &instruction)
if err != nil {
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
}
}
instruction.setDefaults()
return instruction, nil
}
func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) {
switch key {
case "-t", "--table":
instruction.table = value
case "-D", "--delete":
instruction.append = false
instruction.chain = value
case "-A", "--append":
instruction.append = true
instruction.chain = value
case "-j", "--jump":
instruction.target = value
case "-p", "--protocol":
instruction.protocol = value
case "-m", "--match": // ignore match
case "-i", "--in-interface":
instruction.inputInterface = value
case "-o", "--out-interface":
instruction.outputInterface = value
case "-s", "--source":
instruction.source, err = parseIPPrefix(value)
if err != nil {
return fmt.Errorf("parsing source IP CIDR: %w", err)
}
case "-d", "--destination":
instruction.destination, err = parseIPPrefix(value)
if err != nil {
return fmt.Errorf("parsing destination IP CIDR: %w", err)
}
case "--dport":
const base, bitLength = 10, 16
destinationPort, err := strconv.ParseUint(value, base, bitLength)
if err != nil {
return fmt.Errorf("parsing destination port: %w", err)
}
instruction.destinationPort = uint16(destinationPort)
case "--ctstate":
instruction.ctstate = strings.Split(value, ",")
case "--to-ports":
portStrings := strings.Split(value, ",")
instruction.toPorts = make([]uint16, len(portStrings))
for i, portString := range portStrings {
const base, bitLength = 10, 16
port, err := strconv.ParseUint(portString, base, bitLength)
if err != nil {
return fmt.Errorf("parsing port redirection: %w", err)
}
instruction.toPorts[i] = uint16(port)
}
default:
return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key)
}
return nil
}
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
slashIndex := strings.Index(value, "/")
if slashIndex >= 0 {
return netip.ParsePrefix(value)
}
ip, err := netip.ParseAddr(value)
if err != nil {
return netip.Prefix{}, fmt.Errorf("parsing IP address: %w", err)
}
return netip.PrefixFrom(ip, ip.BitLen()), nil
}
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"net/netip"
@@ -23,7 +23,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
"uneven_fields": {
s: "-A",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"",
},
"unknown_key": {
s: "-x something",
@@ -33,9 +33,9 @@ func Test_parseIptablesInstruction(t *testing.T) {
"one_pair": {
s: "-A INPUT",
instruction: iptablesInstruction{
table: "filter",
chain: "INPUT",
operation: opAppend,
table: "filter",
chain: "INPUT",
append: true,
},
},
"instruction_A": {
@@ -43,7 +43,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
instruction: iptablesInstruction{
table: "filter",
chain: "INPUT",
operation: opAppend,
append: true,
inputInterface: "tun0",
protocol: "tcp",
source: netip.MustParsePrefix("1.2.3.4/32"),
@@ -57,7 +57,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
instruction: iptablesInstruction{
table: "nat",
chain: "PREROUTING",
operation: opDelete,
append: false,
inputInterface: "tun0",
protocol: "tcp",
destinationPort: 43716,
+2 -2
View File
@@ -35,7 +35,7 @@ func (c *Config) SetAllowedPort(ctx context.Context, port uint16, intf string) (
c.logger.Info("setting allowed input port " + fmt.Sprint(port) + " through interface " + intf + "...")
const remove = false
if err := c.impl.AcceptInputToPort(ctx, intf, port, remove); err != nil {
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
return fmt.Errorf("allowing input to port %d through interface %s: %w",
port, intf, err)
}
@@ -68,7 +68,7 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
const remove = true
for netInterface := range interfacesSet {
err := c.impl.AcceptInputToPort(ctx, netInterface, port, remove)
err := c.acceptInputToPort(ctx, netInterface, port, remove)
if err != nil {
return fmt.Errorf("removing allowed port %d on interface %s: %w",
port, netInterface, err)
+2 -2
View File
@@ -50,7 +50,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string, sourcePort,
return nil
case conflict != nil:
const remove = true
err = c.impl.RedirectPort(ctx, conflict.interfaceName, conflict.sourcePort,
err = c.redirectPort(ctx, conflict.interfaceName, conflict.sourcePort,
conflict.destinationPort, remove)
if err != nil {
return fmt.Errorf("removing conflicting redirection: %w", err)
@@ -60,7 +60,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string, sourcePort,
}
const remove = false
err = c.impl.RedirectPort(ctx, intf, sourcePort, destinationPort, remove)
err = c.redirectPort(ctx, intf, sourcePort, destinationPort, remove)
if err != nil {
return fmt.Errorf("redirecting port: %w", err)
}
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"context"
@@ -11,10 +11,10 @@ import (
)
var (
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
ErrInputPolicyNotFound = errors.New("input policy not found")
ErrNotSupported = errors.New("no iptables supported found")
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
ErrInputPolicyNotFound = errors.New("input policy not found")
ErrIPTablesNotSupported = errors.New("no iptables supported found")
)
func checkIptablesSupport(ctx context.Context, runner CmdRunner,
@@ -57,7 +57,7 @@ func checkIptablesSupport(ctx context.Context, runner CmdRunner,
}
return "", fmt.Errorf("%w: errors encountered are: %s",
ErrNotSupported, strings.Join(allUnsupportedMessages, "; "))
ErrIPTablesNotSupported, strings.Join(allUnsupportedMessages, "; "))
}
func testIptablesPath(ctx context.Context, path string,
@@ -1,4 +1,4 @@
package iptables
package firewall
import (
"context"
@@ -101,7 +101,7 @@ func Test_checkIptablesSupport(t *testing.T) {
return runner
},
iptablesPathsToTry: []string{"path1", "path2"},
errSentinel: ErrNotSupported,
errSentinel: ErrIPTablesNotSupported,
errMessage: "no iptables supported found: " +
"errors encountered are: " +
"path1: output 1 (exit code 4); " +
+4 -4
View File
@@ -28,7 +28,7 @@ func (c *Config) SetVPNConnection(ctx context.Context,
remove := true
if c.vpnConnection.IP.IsValid() {
for _, defaultRoute := range c.defaultRoutes {
if err := c.impl.AcceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove); err != nil {
if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove); err != nil {
c.logger.Error("cannot remove outdated VPN connection rule: " + err.Error())
}
}
@@ -36,7 +36,7 @@ func (c *Config) SetVPNConnection(ctx context.Context,
c.vpnConnection = models.Connection{}
if c.vpnIntf != "" {
if err = c.impl.AcceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
c.logger.Error("cannot remove outdated VPN interface rule: " + err.Error())
}
}
@@ -45,13 +45,13 @@ func (c *Config) SetVPNConnection(ctx context.Context,
remove = false
for _, defaultRoute := range c.defaultRoutes {
if err := c.impl.AcceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, connection, remove); err != nil {
if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, connection, remove); err != nil {
return fmt.Errorf("allowing output traffic through VPN connection: %w", err)
}
}
c.vpnConnection = connection
if err = c.impl.AcceptOutputThroughInterface(ctx, vpnIntf, remove); err != nil {
if err = c.acceptOutputThroughInterface(ctx, vpnIntf, remove); err != nil {
return fmt.Errorf("accepting output traffic through interface %s: %w", vpnIntf, err)
}
c.vpnIntf = vpnIntf
-21
View File
@@ -1,21 +0,0 @@
package firewall
import (
"context"
"net/netip"
)
func (c *Config) Version(ctx context.Context) (version string, err error) {
return c.impl.Version(ctx)
}
// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port,
// for any TCP packets not marked with the excludeMark given.
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
// by sending a TCP RST packet, although we want to handle the connection manually.
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
src, dst netip.AddrPort, excludeMark int) (
revert func(ctx context.Context) error, err error,
) {
return c.impl.TempDropOutputTCPRST(ctx, src, dst, excludeMark)
}
+11 -37
View File
@@ -23,7 +23,6 @@ type Checker struct {
logger Logger
icmpTargetIPs []netip.Addr
smallCheckType string
startupOnFail bool
configMutex sync.Mutex
icmpNotPermitted *bool
@@ -46,43 +45,26 @@ func NewChecker(logger Logger) *Checker {
}
}
// SetConfig sets the following:
// - TCP+TLS dial addresses
// - ICMP echo IP addresses to target
// - the desired small check type (dns or icmp)
// - whether to startup the periodic checks if the startup check fails.
// SetConfig sets the TCP+TLS dial addresses, the ICMP echo IP address
// to target and the desired small check type (dns or icmp).
// This function MUST be called before calling [Checker.Start].
func (c *Checker) SetConfig(tlsDialAddrs []string, icmpTargets []netip.Addr,
smallCheckType string, startupOnFail bool,
smallCheckType string,
) {
c.configMutex.Lock()
defer c.configMutex.Unlock()
c.tlsDialAddrs = tlsDialAddrs
c.icmpTargetIPs = icmpTargets
c.smallCheckType = smallCheckType
c.startupOnFail = startupOnFail
}
// Start starts the [Checker] which behaves differently according to its
// internal field startupOnFail, which is set by calling [Checker.SetConfig].
//
// By default, startupOnFail should be false and the behavior is as follows:
// A blocking 6s-timed TCP+TLS check is performed first. If it fails,
// an error is returned and the [Checker] is not started.
// On success, it starts the periodic checks in a separate goroutine, returning
// the runError error channel and a nil error.
//
// If startupOnFail is true, the behavior is as follows:
// A blocking 6s-timed TCP+TLS check is performed first. If it fails,
// the error is sent to the runError channel, but no error is returned
// and the [Checker] continues to start the periodic checks in a separate goroutine, returning
// the runError error channel and a nil error.
//
// The periodic checks consist in:
// Start starts the checker by first running a blocking 6s-timed TCP+TLS check,
// and, on success, starts the periodic checks in a separate goroutine:
// - a "small" ICMP echo check every minute
// - a "full" TCP+TLS check every 5 minutes
//
// The [Checker] has to be ultimately stopped by calling [Checker.Stop].
// It returns a channel `runError` that receives an error (nil or not) when a periodic check is performed.
// It returns an error if the initial TCP+TLS check fails.
// The Checker has to be ultimately stopped by calling [Checker.Stop].
func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error) {
if len(c.tlsDialAddrs) == 0 || len(c.icmpTargetIPs) == 0 || c.smallCheckType == "" {
panic("call Checker.SetConfig with non empty values before Checker.Start")
@@ -94,19 +76,9 @@ func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error)
}
c.echoer.Reset()
// runErrorCh MUST be buffered in the case startupOnFail is true, and
// a startup error was encountered, to avoid blocking the startup
// goroutine when sending the error, especially since the caller may
// not be ready to receive from the channel yet.
runErrorCh := make(chan error, 1)
runError = runErrorCh
err = c.startupCheck(ctx)
if err != nil {
err = fmt.Errorf("startup check: %w", err)
if !c.startupOnFail {
return nil, err
}
runErrorCh <- err
return nil, fmt.Errorf("startup check: %w", err)
}
ready := make(chan struct{})
@@ -118,6 +90,8 @@ func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error)
smallCheckTimer := time.NewTimer(smallCheckPeriod)
const fullCheckPeriod = 5 * time.Minute
fullCheckTimer := time.NewTimer(fullCheckPeriod)
runErrorCh := make(chan error)
runError = runErrorCh
go func() {
defer close(done)
close(ready)
-4
View File
@@ -6,7 +6,6 @@ import (
"fmt"
"io"
"net/http"
"time"
)
var ErrHTTPStatusNotOK = errors.New("HTTP response status is not OK")
@@ -22,9 +21,6 @@ func NewClient(httpClient *http.Client) *Client {
}
func (c *Client) Check(ctx context.Context, url string) error {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return err
-33
View File
@@ -1,33 +0,0 @@
package mod
import (
"bufio"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
)
var errBuiltinModuleNotFound = errors.New("builtin module not found")
func checkModulesBuiltin(modulesPath, moduleName string) error {
f, err := os.Open(filepath.Join(modulesPath, "modules.builtin"))
if err != nil {
return err
}
defer f.Close()
moduleName = strings.TrimSuffix(moduleName, ".ko")
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
line = strings.TrimSuffix(line, ".ko")
if strings.HasSuffix(line, "/"+moduleName) {
return nil
}
}
return fmt.Errorf("%w: %s", errBuiltinModuleNotFound, moduleName)
}
-132
View File
@@ -1,132 +0,0 @@
package mod
import (
"bufio"
"compress/gzip"
"errors"
"fmt"
"os"
"strings"
)
var (
errModuleNameUnknown = errors.New("unknown module name")
errKernelFeatureIsModule = errors.New("kernel feature is a module, not built-in")
errKernelFeatureNotSet = errors.New("kernel feature not set")
errKernelFeatureNotFound = errors.New("kernel feature not found")
)
// checkProcConfig checks /proc/config.gz for a the kernel feature corresponding
// to the given module name. If the kernel feature is found and set to "y", it returns nil.
// If the kernel feature is found and set to "m", it returns an error indicating that the kernel
// feature is a module, not built-in.
// If the kernel feature is found and not set, it returns an error indicating that the kernel
// feature is not set. If the kernel feature is not found, it returns an error indicating that the kernel
// feature is not found.
func checkProcConfig(moduleName string) error {
f, err := os.Open("/proc/config.gz")
if err != nil {
return err
}
defer f.Close()
gz, err := gzip.NewReader(f)
if err != nil {
return fmt.Errorf("creating gzip reader: %w", err)
}
defer gz.Close()
// If any group of kernel features is satisfied, then the module is considered supported.
kernelFeatureGroups, ok := moduleNameToKernelFeatureGroups(moduleName)
if !ok {
return fmt.Errorf("%w: %s", errModuleNameUnknown, moduleName)
}
groups := make([]map[string]bool, len(kernelFeatureGroups))
for i, group := range kernelFeatureGroups {
featureToOK := make(map[string]bool)
for _, feature := range group {
featureToOK[feature] = false
}
groups[i] = featureToOK
}
scanner := bufio.NewScanner(gz)
for scanner.Scan() {
line := scanner.Text()
for _, featureToOK := range groups {
for name, ok := range featureToOK {
switch {
case ok:
case strings.HasPrefix(line, name+"=m"):
return fmt.Errorf("%w: %s", errKernelFeatureIsModule, name)
case strings.HasPrefix(line, name+"=y"):
featureToOK[name] = true
if allFeaturesOK(featureToOK) {
return nil
}
case strings.HasPrefix(line, "# "+name+" is not set"):
return fmt.Errorf("%w: %s", errKernelFeatureNotSet, name)
}
}
}
}
return fmt.Errorf("%w: for module name %s", errKernelFeatureNotFound, moduleName)
}
func moduleNameToKernelFeatureGroups(moduleName string) (featureGroups [][]string, ok bool) {
moduleMap := map[string][][]string{
"nf_tables": {{"CONFIG_NF_TABLES"}},
// Netfilter Matches
"xt_conntrack": {{"CONFIG_NETFILTER_XT_MATCH_CONNTRACK"}},
"xt_connmark": {
{"CONFIG_NETFILTER_XT_CONNMARK"},
{"CONFIG_NETFILTER_XT_MATCH_CONNMARK", "CONFIG_NETFILTER_XT_TARGET_CONNMARK"},
},
"xt_mark": {
{"CONFIG_NETFILTER_XT_MARK"},
{"CONFIG_NETFILTER_XT_MATCH_MARK", "CONFIG_NETFILTER_XT_TARGET_MARK"},
},
"nf_conntrack_netlink": {{"CONFIG_NF_CT_NETLINK"}},
"nf_reject_ipv4": {{"CONFIG_NF_REJECT_IPV4"}},
// Common Netfilter Targets
"xt_log": {{"CONFIG_NETFILTER_XT_TARGET_LOG"}},
"xt_reject": {
{"CONFIG_IP_NF_TARGET_REJECT", "CONFIG_NF_REJECT_IPV4"},
{"CONFIG_NETFILTER_XT_TARGET_REJECT", "CONFIG_NF_REJECT_IPV4"},
},
"xt_masquerade": {{"CONFIG_NETFILTER_XT_TARGET_MASQUERADE"}},
// Additional Netfilter Matches
"xt_addrtype": {{"CONFIG_NETFILTER_XT_MATCH_ADDRTYPE"}},
"xt_comment": {{"CONFIG_NETFILTER_XT_MATCH_COMMENT"}},
"xt_multiport": {{"CONFIG_NETFILTER_XT_MATCH_MULTIPORT"}},
"xt_state": {{"CONFIG_NETFILTER_XT_MATCH_STATE"}},
"xt_tcpudp": {{"CONFIG_NETFILTER_XT_MATCH_TCPUDP"}},
// Tunneling and Virtualization
"tun": {{"CONFIG_TUN"}},
"bridge": {{"CONFIG_BRIDGE"}},
"veth": {{"CONFIG_VETH"}},
"vxlan": {{"CONFIG_VXLAN"}},
"wireguard": {{"CONFIG_WIREGUARD"}},
// Filesystems
"overlay": {{"CONFIG_OVERLAY_FS"}},
"fuse": {{"CONFIG_FUSE_FS"}},
}
featureGroups, ok = moduleMap[strings.ToLower(moduleName)]
return featureGroups, ok
}
func allFeaturesOK(featureToOK map[string]bool) bool {
for _, ok := range featureToOK {
if !ok {
return false
}
}
return true
}
@@ -1,5 +1,3 @@
//go:build !windows
package mod
import (
@@ -30,7 +28,36 @@ type moduleInfo struct {
var ErrModulesDirectoryNotFound = errors.New("modules directory not found")
func getModulesInfo(modulesPath string) (modulesInfo map[string]moduleInfo, err error) {
func getModulesInfo() (modulesInfo map[string]moduleInfo, err error) {
var utsName unix.Utsname
err = unix.Uname(&utsName)
if err != nil {
return nil, fmt.Errorf("getting unix uname release: %w", err)
}
release := unix.ByteSliceToString(utsName.Release[:])
release = strings.TrimSpace(release)
modulePaths := []string{
filepath.Join("/lib/modules", release),
filepath.Join("/usr/lib/modules", release),
}
var modulesPath string
var found bool
for _, modulesPath = range modulePaths {
info, err := os.Stat(modulesPath)
if err == nil && info.IsDir() {
found = true
break
}
}
if !found {
return nil, fmt.Errorf("%w: %s are not valid existing directories"+
"; have you bind mounted the /lib/modules directory?",
ErrModulesDirectoryNotFound, strings.Join(modulePaths, ", "))
}
dependencyFilepath := filepath.Join(modulesPath, "modules.dep")
dependencyFile, err := os.Open(dependencyFilepath)
if err != nil {
@@ -82,39 +109,6 @@ func getModulesInfo(modulesPath string) (modulesInfo map[string]moduleInfo, err
return modulesInfo, nil
}
func getModulesPath() (string, error) {
release, err := getReleaseName()
if err != nil {
return "", fmt.Errorf("getting release name: %w", err)
}
modulePaths := []string{
filepath.Join("/lib/modules", release),
filepath.Join("/usr/lib/modules", release),
}
for _, modulesPath := range modulePaths {
info, err := os.Stat(modulesPath)
if err == nil && info.IsDir() {
return modulesPath, nil
}
}
return "", fmt.Errorf("%w: %s are not valid existing directories"+
"; have you bind mounted the /lib/modules directory?",
ErrModulesDirectoryNotFound, strings.Join(modulePaths, ", "))
}
func getReleaseName() (release string, err error) {
var utsName unix.Utsname
err = unix.Uname(&utsName)
if err != nil {
return "", fmt.Errorf("getting unix uname release: %w", err)
}
release = unix.ByteSliceToString(utsName.Release[:])
release = strings.TrimSpace(release)
return release, nil
}
func getBuiltinModules(modulesDirPath string, modulesInfo map[string]moduleInfo) error {
file, err := os.Open(filepath.Join(modulesDirPath, "modules.builtin"))
if err != nil {
+37
View File
@@ -0,0 +1,37 @@
package mod
import (
"fmt"
)
// Probe loads the given kernel module and its dependencies.
func Probe(moduleName string) error {
modulesInfo, err := getModulesInfo()
if err != nil {
return fmt.Errorf("getting modules information: %w", err)
}
modulePath, err := findModulePath(moduleName, modulesInfo)
if err != nil {
return fmt.Errorf("finding module path: %w", err)
}
info := modulesInfo[modulePath]
if info.state == builtin || info.state == loaded {
return nil
}
info.state = loading
for _, dependencyModulePath := range info.dependencyPaths {
err = initDependencies(dependencyModulePath, modulesInfo)
if err != nil {
return fmt.Errorf("init dependencies: %w", err)
}
}
err = initModule(modulePath)
if err != nil {
return fmt.Errorf("init module: %w", err)
}
return nil
}
-74
View File
@@ -1,74 +0,0 @@
package mod
import (
"errors"
"fmt"
)
// Probe is a expanded version of modprobe, in which it checks if the Kernel
// built-in features contain the given module name.
// It first tries to locate the modules directory in [getModulesPath].
// If it fails (like on WSL), it then only checks for the kernel feature
// in /proc/config.gz with [checkProcConfig].
// Otherwise, it first checks if the modules directory modules.builtin
// file contains the given module name in [checkModulesBuiltin].
// If the module is not found, it then runs the classic [modProbe] behavior,
// trying to load the module in the kernel.
// If this fails, it does one final try running [checkProcConfig].
func Probe(moduleName string) error {
modulesPath, err := getModulesPath()
if err != nil {
if errors.Is(err, ErrModulesDirectoryNotFound) {
err = checkProcConfig(moduleName)
if err != nil {
return fmt.Errorf("checking /proc/config.gz: %w", err)
}
return nil
}
return fmt.Errorf("getting modules path: %w", err)
}
err = checkModulesBuiltin(modulesPath, moduleName)
if err != nil {
err = modProbe(modulesPath, moduleName)
if err != nil {
err = checkProcConfig(moduleName)
if err != nil {
return fmt.Errorf("checking /proc/config.gz: %w", err)
}
}
}
return nil
}
// modProbe is the classic modprobe behavior.
func modProbe(modulesPath, moduleName string) error {
modulesInfo, err := getModulesInfo(modulesPath)
if err != nil {
return fmt.Errorf("getting modules information: %w", err)
}
modulePath, err := findModulePath(moduleName, modulesInfo)
if err != nil {
return fmt.Errorf("finding module path: %w", err)
}
info := modulesInfo[modulePath]
if info.state == builtin || info.state == loaded {
return nil
}
info.state = loading
for _, dependencyModulePath := range info.dependencyPaths {
err = initDependencies(dependencyModulePath, modulesInfo)
if err != nil {
return fmt.Errorf("init dependencies: %w", err)
}
}
err = initModule(modulePath)
if err != nil {
return fmt.Errorf("init module: %w", err)
}
return nil
}
-7
View File
@@ -1,7 +0,0 @@
//go:build !linux
package mod
func Probe(moduleName string) error {
panic("not implemented")
}
+17 -59
View File
@@ -1,75 +1,33 @@
//go:build linux || darwin
package netlink
import (
"fmt"
"net"
"net/netip"
"github.com/jsimonetti/rtnetlink/rtnl"
"github.com/vishvananda/netlink"
)
func (n *NetLink) AddrList(linkIndex uint32, family uint8) (
ipPrefixes []netip.Prefix, err error,
func (n *NetLink) AddrList(link Link, family int) (
addresses []Addr, err error,
) {
conn, err := rtnl.Dial(nil)
netlinkLink := linkToNetlinkLink(&link)
netlinkAddresses, err := netlink.AddrList(netlinkLink, family)
if err != nil {
return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
ifc := &net.Interface{
Index: int(linkIndex),
}
ipNets, err := conn.Addrs(ifc, int(family))
if err != nil {
return nil, fmt.Errorf("failed to list addresses: %w", err)
return nil, err
}
ipPrefixes = make([]netip.Prefix, len(ipNets))
for i := range ipNets {
ipPrefixes[i] = netIPNetToNetipPrefix(ipNets[i])
addresses = make([]Addr, len(netlinkAddresses))
for i := range netlinkAddresses {
addresses[i].Network = netIPNetToNetipPrefix(netlinkAddresses[i].IPNet)
}
return ipPrefixes, nil
return addresses, nil
}
func (n *NetLink) AddrReplace(linkIndex uint32, prefix netip.Prefix) error {
conn, err := rtnl.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
ipNet := netipPrefixToIPNet(prefix)
// Remove any address identical to the one we want to add
family := FamilyV4
if prefix.Addr().Is6() {
family = FamilyV6
}
ifc := &net.Interface{
Index: int(linkIndex),
}
addresses, err := conn.Addrs(ifc, int(family))
if err != nil {
return fmt.Errorf("listing addresses: %w", err)
}
for _, address := range addresses {
if address.IP.Equal(ipNet.IP) &&
net.IP(address.Mask).String() == net.IP(ipNet.Mask).String() {
err = conn.AddrDel(ifc, address)
if err != nil {
return fmt.Errorf("deleting address from interface: %w", err)
}
break
}
func (n *NetLink) AddrReplace(link Link, addr Addr) error {
netlinkLink := linkToNetlinkLink(&link)
netlinkAddress := netlink.Addr{
IPNet: netipPrefixToIPNet(addr.Network),
}
// Add the new address to the interface
err = conn.AddrAdd(ifc, ipNet)
if err != nil {
return fmt.Errorf("adding address to interface: %w", err)
}
return nil
return netlink.AddrReplace(netlinkLink, &netlinkAddress)
}
+13
View File
@@ -0,0 +1,13 @@
//go:build !linux && !darwin
package netlink
func (n *NetLink) AddrList(link Link, family int) (
addresses []Addr, err error,
) {
panic("not implemented")
}
func (n *NetLink) AddrReplace(Link, Addr) error {
panic("not implemented")
}
-44
View File
@@ -1,44 +0,0 @@
package netlink
import (
"errors"
"fmt"
"github.com/mdlayher/netlink"
"github.com/ti-mo/netfilter"
"golang.org/x/sys/unix"
)
var ErrConntrackNetlinkNotSupported = errors.New("nf_conntrack_netlink is not supported by the kernel")
func (n *NetLink) FlushConntrack() error {
conn, err := netfilter.Dial(nil)
if err != nil {
if !n.conntrackNetlink {
err = fmt.Errorf("%w: %w", err, ErrConntrackNetlinkNotSupported)
}
return fmt.Errorf("dialing netfilter: %w", err)
}
defer conn.Close()
const ipCtnlMsgCtDelete = netfilter.MessageType(2)
header := netfilter.Header{
SubsystemID: netfilter.NFSubsysCTNetlink,
MessageType: ipCtnlMsgCtDelete,
Family: unix.AF_UNSPEC,
Flags: netlink.Request | netlink.Acknowledge,
}
request, err := netfilter.MarshalNetlink(header, nil)
if err != nil {
return fmt.Errorf("encoding netlink request: %w", err)
}
_, err = conn.Query(request)
if err != nil {
if !n.conntrackNetlink {
err = fmt.Errorf("%w: %w", err, ErrConntrackNetlinkNotSupported)
}
return fmt.Errorf("querying netlink request: %w", err)
}
return nil
}
-11
View File
@@ -1,11 +0,0 @@
//go:build !linux
package netlink
import "errors"
var ErrConntrackNetlinkNotSupported = errors.New("error not implemented")
func (n *NetLink) FlushConntrack() error {
panic("not implemented")
}
-24
View File
@@ -36,30 +36,6 @@ func netIPNetToNetipPrefix(ipNet *net.IPNet) (prefix netip.Prefix) {
return netip.PrefixFrom(ip, bits)
}
func ipAndLengthToPrefix(ip *net.IP, length uint8) netip.Prefix {
if ip == nil || len(*ip) == 0 {
return netip.Prefix{}
}
var dstIP netip.Addr
if ipv4 := ip.To4(); ipv4 != nil { // IPv6
dstIP = netip.AddrFrom4([4]byte(*ip))
} else {
dstIP = netip.AddrFrom16([16]byte(*ip))
}
return netip.PrefixFrom(dstIP, int(length))
}
func prefixToIPAndLength(prefix netip.Prefix) (ip *net.IP, length uint8) {
if !prefix.IsValid() {
return nil, 0
}
prefixIP := prefix.Addr().Unmap()
ip = new(net.IP)
*ip = netipAddrToNetIP(prefixIP)
length = uint8(prefix.Bits()) //nolint:gosec
return ip, length
}
func netipAddrToNetIP(address netip.Addr) (ip net.IP) {
switch {
case !address.IsValid():
+7 -1
View File
@@ -4,7 +4,13 @@ import (
"fmt"
)
func FamilyToString(family uint8) string {
const (
FamilyAll = 0
FamilyV4 = 2
FamilyV6 = 10
)
func FamilyToString(family int) string {
switch family {
case FamilyAll:
return "all"
-9
View File
@@ -1,9 +0,0 @@
package netlink
import "golang.org/x/sys/unix"
const (
FamilyAll uint8 = unix.AF_UNSPEC
FamilyV4 uint8 = unix.AF_INET
FamilyV6 uint8 = unix.AF_INET6
)
-14
View File
@@ -1,30 +1,16 @@
package netlink
import (
"math/rand/v2"
"net/netip"
"github.com/qdm12/log"
)
func ptrTo[T any](v T) *T { return &v }
func makeNetipPrefix(n byte) netip.Prefix {
const bits = 24
return netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits)
}
var rng = rand.New(rand.NewChaCha8([32]byte{})) //nolint:gosec,gochecknoglobals
func makeLinkName() string {
const alphabet = "abcdefghijklmnopqrstuvwxyz"
name := make([]byte, 8)
for i := range name {
name[i] = alphabet[rng.IntN(len(alphabet))]
}
return "test" + string(name)
}
type noopLogger struct{}
func (l *noopLogger) Debug(_ string) {}
+1 -1
View File
@@ -19,7 +19,7 @@ func (n *NetLink) IsIPv6Supported() (supported bool, err error) {
return false, fmt.Errorf("finding link corresponding to route: %w", err)
}
sourceIsIPv6 := route.Src.Addr().IsValid() && route.Src.Addr().Is6()
sourceIsIPv6 := route.Src.IsValid() && route.Src.Is6()
destinationIsIPv6 := route.Dst.IsValid() && route.Dst.Addr().Is6()
switch {
case !sourceIsIPv6 && !destinationIsIPv6,
+76 -162
View File
@@ -1,191 +1,105 @@
//go:build linux || darwin
package netlink
import (
"errors"
"fmt"
"github.com/jsimonetti/rtnetlink"
)
type DeviceType uint16
type Link struct {
Index uint32
Name string
DeviceType DeviceType
VirtualType string
MTU uint32
}
import "github.com/vishvananda/netlink"
func (n *NetLink) LinkList() (links []Link, err error) {
conn, err := rtnetlink.Dial(nil)
netlinkLinks, err := netlink.LinkList()
if err != nil {
return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
linkMessages, err := conn.Link.List()
if err != nil {
return nil, fmt.Errorf("listing interfaces: %w", err)
return nil, err
}
links = make([]Link, len(linkMessages))
for i, message := range linkMessages {
virtualType := ""
if message.Attributes.Info != nil {
virtualType = message.Attributes.Info.Kind
}
links[i] = Link{
Index: message.Index,
Name: message.Attributes.Name,
DeviceType: DeviceType(message.Type),
VirtualType: virtualType,
MTU: message.Attributes.MTU,
}
links = make([]Link, len(netlinkLinks))
for i := range netlinkLinks {
links[i] = netlinkLinkToLink(netlinkLinks[i])
}
return links, nil
}
var ErrLinkNotFound = errors.New("link not found")
func (n *NetLink) LinkByName(name string) (link Link, err error) {
links, err := n.LinkList()
netlinkLink, err := netlink.LinkByName(name)
if err != nil {
return Link{}, fmt.Errorf("listing links: %w", err)
return Link{}, err
}
for _, link := range links {
if link.Name == name {
return link, nil
}
}
return Link{}, fmt.Errorf("%w: for name %s", ErrLinkNotFound, name)
return netlinkLinkToLink(netlinkLink), nil
}
func (n *NetLink) LinkByIndex(index uint32) (link Link, err error) {
links, err := n.LinkList()
func (n *NetLink) LinkByIndex(index int) (link Link, err error) {
netlinkLink, err := netlink.LinkByIndex(index)
if err != nil {
return Link{}, fmt.Errorf("listing links: %w", err)
return Link{}, err
}
for _, link = range links {
if link.Index == index {
return link, nil
}
}
return Link{}, fmt.Errorf("%w: for index %d", ErrLinkNotFound, index)
return netlinkLinkToLink(netlinkLink), nil
}
func (n *NetLink) LinkAdd(link Link) (linkIndex uint32, err error) {
conn, err := rtnetlink.Dial(nil)
func (n *NetLink) LinkAdd(link Link) (linkIndex int, err error) {
netlinkLink := linkToNetlinkLink(&link)
err = netlink.LinkAdd(netlinkLink)
if err != nil {
return 0, fmt.Errorf("dialing netlink: %w", err)
return 0, err
}
defer conn.Close()
return netlinkLink.Attrs().Index, nil
}
tx := &rtnetlink.LinkMessage{
Type: uint16(link.DeviceType),
Attributes: &rtnetlink.LinkAttributes{
MTU: link.MTU,
Name: link.Name,
func (n *NetLink) LinkDel(link Link) (err error) {
return netlink.LinkDel(linkToNetlinkLink(&link))
}
func (n *NetLink) LinkSetUp(link Link) (linkIndex int, err error) {
netlinkLink := linkToNetlinkLink(&link)
err = netlink.LinkSetUp(netlinkLink)
if err != nil {
return 0, err
}
return netlinkLink.Attrs().Index, nil
}
func (n *NetLink) LinkSetDown(link Link) (err error) {
return netlink.LinkSetDown(linkToNetlinkLink(&link))
}
type netlinkLinkImpl struct {
attrs *netlink.LinkAttrs
linkType string
}
func (n *netlinkLinkImpl) Attrs() *netlink.LinkAttrs {
return n.attrs
}
func (n *netlinkLinkImpl) Type() string {
return n.linkType
}
func netlinkLinkToLink(netlinkLink netlink.Link) Link {
attributes := netlinkLink.Attrs()
return Link{
Type: netlinkLink.Type(),
Name: attributes.Name,
Index: attributes.Index,
EncapType: attributes.EncapType,
MTU: uint16(attributes.MTU), //nolint:gosec
}
}
// Warning: we must return `netlink.Link` and not `netlinkLinkImpl`
// so that the vishvananda/netlink package can compare the returned
// value against an untyped nil.
func linkToNetlinkLink(link *Link) netlink.Link {
if link == nil {
return nil
}
return &netlinkLinkImpl{
linkType: link.Type,
attrs: &netlink.LinkAttrs{
Name: link.Name,
Index: link.Index,
EncapType: link.EncapType,
MTU: int(link.MTU),
},
}
if link.VirtualType != "" {
tx.Attributes.Info = &rtnetlink.LinkInfo{
Kind: link.VirtualType,
}
}
err = conn.Link.New(tx)
if err != nil {
return 0, fmt.Errorf("creating new link: %w", err)
}
linkMessages, err := conn.Link.List()
if err != nil {
return 0, fmt.Errorf("listing links: %w", err)
}
for _, linkMessage := range linkMessages {
if linkMessage.Attributes.Name == link.Name {
return linkMessage.Index, nil
}
}
return 0, fmt.Errorf("%w: matching name %s", ErrLinkNotFound, link.Name)
}
func (n *NetLink) LinkDel(linkIndex uint32) (err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Link.Delete(linkIndex)
}
func (n *NetLink) LinkSetUp(linkIndex uint32) (err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
rx, err := conn.Link.Get(linkIndex)
if err != nil {
return fmt.Errorf("getting link: %w", err)
}
tx := &rtnetlink.LinkMessage{
Type: rx.Type,
Index: linkIndex,
Flags: iffUp,
Change: iffUp,
}
return conn.Link.Set(tx)
}
func (n *NetLink) LinkSetDown(linkIndex uint32) (err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
linkInfo, err := conn.Link.Get(linkIndex)
if err != nil {
return fmt.Errorf("getting link: %w", err)
}
message := &rtnetlink.LinkMessage{
Type: linkInfo.Type,
Index: linkIndex,
Flags: 0,
Change: iffUp,
}
return conn.Link.Set(message)
}
func (n *NetLink) LinkSetMTU(linkIndex, mtu uint32) error {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
message := &rtnetlink.LinkMessage{
Index: linkIndex,
Attributes: &rtnetlink.LinkAttributes{
MTU: mtu,
},
}
err = conn.Link.Set(message)
if err != nil {
return fmt.Errorf("setting MTU to %d for link at index %d: %w",
mtu, linkIndex, err)
}
return nil
}
-11
View File
@@ -1,11 +0,0 @@
package netlink
import "golang.org/x/sys/unix"
const (
DeviceTypeEthernet DeviceType = unix.ARPHRD_ETHER
DeviceTypeLoopback DeviceType = unix.ARPHRD_LOOPBACK
DeviceTypeNone DeviceType = unix.ARPHRD_NONE
iffUp = unix.IFF_UP
)
-85
View File
@@ -1,85 +0,0 @@
//go:build linux
package netlink
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_NetLink_LinkList(t *testing.T) {
t.Parallel()
netlink := &NetLink{}
initialLinks, err := netlink.LinkList()
require.NoError(t, err)
require.NotEmpty(t, initialLinks)
loopbackFound := false
for _, link := range initialLinks {
if link.Name != "lo" {
continue
}
loopbackFound = true
assert.Equal(t, DeviceTypeLoopback, link.DeviceType)
break
}
assert.True(t, loopbackFound, "loopback interface not found")
testLink := Link{
Name: makeLinkName(),
// note if [Link.VirtualType] is set, [Link.DeviceType]
// is ignored and gets set to [DeviceTypeNone] in LinkAdd.
DeviceType: DeviceTypeNone,
VirtualType: "wireguard",
MTU: 1420,
}
index, err := netlink.LinkAdd(testLink)
require.NoError(t, err)
t.Cleanup(func() {
_ = netlink.LinkDel(index)
})
links, err := netlink.LinkList()
require.NoError(t, err)
testLink.Index = index
for _, link := range links {
if link.Name != testLink.Name {
continue
}
assert.Equal(t, testLink, link)
return
}
t.Errorf("created link %q not found", testLink.Name)
}
func Test_NetLink_LinkSetMTU(t *testing.T) {
t.Parallel()
netlink := &NetLink{}
testLink := Link{
Name: makeLinkName(),
DeviceType: DeviceTypeNone,
VirtualType: "wireguard",
MTU: 1420,
}
index, err := netlink.LinkAdd(testLink)
require.NoError(t, err)
t.Cleanup(func() {
_ = netlink.LinkDel(index)
})
testLink.Index = index
err = netlink.LinkSetMTU(index, 1500)
require.NoError(t, err)
link, err := netlink.LinkByIndex(index)
require.NoError(t, err)
testLink.MTU = 1500
assert.Equal(t, testLink, link)
}
+31
View File
@@ -0,0 +1,31 @@
//go:build !linux && !darwin
package netlink
func (n *NetLink) LinkList() (links []Link, err error) {
panic("not implemented")
}
func (n *NetLink) LinkByName(name string) (link Link, err error) {
panic("not implemented")
}
func (n *NetLink) LinkByIndex(index int) (link Link, err error) {
panic("not implemented")
}
func (n *NetLink) LinkAdd(link Link) (linkIndex int, err error) {
panic("not implemented")
}
func (n *NetLink) LinkDel(link Link) (err error) {
panic("not implemented")
}
func (n *NetLink) LinkSetUp(link Link) (linkIndex int, err error) {
panic("not implemented")
}
func (n *NetLink) LinkSetDown(link Link) (err error) {
panic("not implemented")
}
+2 -10
View File
@@ -1,22 +1,14 @@
package netlink
import (
"github.com/qdm12/gluetun/internal/mod"
"github.com/qdm12/log"
)
import "github.com/qdm12/log"
type NetLink struct {
debugLogger DebugLogger
// Fixed state
conntrackNetlink bool
}
func New(debugLogger DebugLogger) *NetLink {
conntrackNetlink := mod.Probe("nf_conntrack_netlink") == nil
return &NetLink{
debugLogger: debugLogger,
conntrackNetlink: conntrackNetlink,
debugLogger: debugLogger,
}
}
-56
View File
@@ -1,56 +0,0 @@
//go:build !linux
package netlink
const (
// FamilyAll is a placeholder only and should not
// be used.
FamilyAll uint8 = iota
// FamilyV4 is a placeholder only and should not
// be used.
FamilyV4
// FamilyV6 is a placeholder only and should not
// be used.
FamilyV6
// DeviceTypeEthernet is a placeholder only and should not be used.
DeviceTypeEthernet DeviceType = 0
// DeviceTypeLoopback is a placeholder only and should not be used.
DeviceTypeLoopback DeviceType = 0
// DeviceTypeNone is a placeholder only and should not be used.
DeviceTypeNone DeviceType = 0
// iffUp is a placeholder only and should not be used.
iffUp = 0
// RouteTypeUnicast is a placeholder only and should not be used.
RouteTypeUnicast = 0
// ScopeUniverse is a placeholder only and should not be used.
ScopeUniverse = 0
// ProtoStatic is a placeholder only and should not be used.
ProtoStatic = 0
// FlagInvert is a placeholder only and should not be used.
FlagInvert = 0
// ActionToTable is a placeholder only and should not be used.
ActionToTable = 0
// rtTableCompat is a placeholder only and should not be used.
rtTableCompat = 0
)
func (n *NetLink) RuleList(family uint8) (rules []Rule, err error) {
panic("not implemented")
}
func (n *NetLink) RuleAdd(rule Rule) error {
panic("not implemented")
}
func (n *NetLink) RuleDel(rule Rule) error {
panic("not implemented")
}
func (n *NetLink) IsWireguardSupported() (bool, error) {
panic("not implemented")
}
+48 -116
View File
@@ -1,137 +1,69 @@
//go:build linux || darwin
package netlink
import (
"fmt"
"net/netip"
"github.com/jsimonetti/rtnetlink"
"github.com/vishvananda/netlink"
)
type Route struct {
LinkIndex uint32
Dst netip.Prefix
Src netip.Prefix
Gw netip.Addr
Priority uint32
Family uint8
Table uint32
Type uint8
Scope uint8
Proto uint8
AdvMSS uint32
}
func (n *NetLink) RouteList(family int) (routes []Route, err error) {
// We set the filter to netlink.RT_FILTER_TABLE so that
// routes from all tables are listed, as long as the filter
// table is set to 0.
const filterMask = netlink.RT_FILTER_TABLE
// The filter is not left to `nil` otherwise non-main tables
// are ignored.
filter := &netlink.Route{}
func (r *Route) fromMessage(message rtnetlink.RouteMessage) {
table := uint32(message.Table)
if table == 0 || table == rtTableCompat {
table = message.Attributes.Table
}
r.LinkIndex = message.Attributes.OutIface
r.Dst = ipAndLengthToPrefix(&message.Attributes.Dst, message.DstLength)
r.Src = ipAndLengthToPrefix(&message.Attributes.Src, message.SrcLength)
r.Gw = netIPToNetipAddress(message.Attributes.Gateway)
r.Priority = message.Attributes.Priority
r.Family = message.Family
r.Table = table
r.Type = message.Type
r.Scope = message.Scope
r.Proto = message.Protocol
if metrics := message.Attributes.Metrics; metrics != nil {
r.AdvMSS = metrics.AdvMSS
}
}
func (r Route) message() *rtnetlink.RouteMessage {
dst, dstLength := prefixToIPAndLength(r.Dst)
src, srcLength := prefixToIPAndLength(r.Src)
var table uint8
var extendedTable uint32
if r.Table <= uint32(^uint8(0)) {
table = uint8(r.Table)
} else {
table = rtTableCompat
extendedTable = r.Table
}
message := &rtnetlink.RouteMessage{
Family: r.Family,
DstLength: dstLength,
SrcLength: srcLength,
Table: table,
Type: r.Type,
Scope: r.Scope,
Protocol: r.Proto,
Attributes: rtnetlink.RouteAttributes{
OutIface: r.LinkIndex,
Gateway: netipAddrToNetIP(r.Gw),
Priority: r.Priority,
Table: extendedTable,
},
}
if src != nil { // src is optional
message.Attributes.Src = *src
}
if dst != nil {
message.Attributes.Dst = *dst
}
if r.AdvMSS != 0 {
if message.Attributes.Metrics == nil {
message.Attributes.Metrics = &rtnetlink.RouteMetrics{}
}
message.Attributes.Metrics.AdvMSS = r.AdvMSS
}
return message
}
func (n *NetLink) RouteList(family uint8) (routes []Route, err error) {
conn, err := rtnetlink.Dial(nil)
netlinkRoutes, err := netlink.RouteListFiltered(family, filter, filterMask)
if err != nil {
return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
routeMessages, err := conn.Route.List()
if err != nil {
return nil, fmt.Errorf("listing interfaces: %w", err)
return nil, err
}
routes = make([]Route, 0, len(routeMessages))
for _, routeMessage := range routeMessages {
if family != FamilyAll && routeMessage.Family != family {
continue
}
var route Route
route.fromMessage(routeMessage)
routes = append(routes, route)
routes = make([]Route, len(netlinkRoutes))
for i := range netlinkRoutes {
routes[i] = netlinkRouteToRoute(netlinkRoutes[i])
}
return routes, nil
}
func (n *NetLink) RouteAdd(route Route) error {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Route.Add(route.message())
netlinkRoute := routeToNetlinkRoute(route)
return netlink.RouteAdd(&netlinkRoute)
}
func (n *NetLink) RouteDel(route Route) error {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Route.Delete(route.message())
netlinkRoute := routeToNetlinkRoute(route)
return netlink.RouteDel(&netlinkRoute)
}
func (n *NetLink) RouteReplace(route Route) error {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Route.Replace(route.message())
netlinkRoute := routeToNetlinkRoute(route)
return netlink.RouteReplace(&netlinkRoute)
}
func netlinkRouteToRoute(netlinkRoute netlink.Route) (route Route) {
return Route{
LinkIndex: netlinkRoute.LinkIndex,
Dst: netIPNetToNetipPrefix(netlinkRoute.Dst),
Src: netIPToNetipAddress(netlinkRoute.Src),
Gw: netIPToNetipAddress(netlinkRoute.Gw),
Priority: netlinkRoute.Priority,
Family: netlinkRoute.Family,
Table: netlinkRoute.Table,
Type: netlinkRoute.Type,
}
}
func routeToNetlinkRoute(route Route) (netlinkRoute netlink.Route) {
return netlink.Route{
LinkIndex: route.LinkIndex,
Dst: netipPrefixToIPNet(route.Dst),
Src: netipAddrToNetIP(route.Src),
Gw: netipAddrToNetIP(route.Gw),
Priority: route.Priority,
Family: route.Family,
Table: route.Table,
Type: route.Type,
}
}
-11
View File
@@ -1,11 +0,0 @@
package netlink
import "golang.org/x/sys/unix"
const (
RouteTypeUnicast = unix.RTN_UNICAST
ScopeUniverse = unix.RT_SCOPE_UNIVERSE
ProtoStatic = unix.RTPROT_STATIC
rtTableCompat = unix.RT_TABLE_COMPAT
)
+21
View File
@@ -0,0 +1,21 @@
//go:build !linux && !darwin
package netlink
func (n *NetLink) RouteList(family int) (
routes []Route, err error,
) {
panic("not implemented")
}
func (n *NetLink) RouteAdd(route Route) error {
panic("not implemented")
}
func (n *NetLink) RouteDel(route Route) error {
panic("not implemented")
}
func (n *NetLink) RouteReplace(route Route) error {
panic("not implemented")
}
+73 -78
View File
@@ -1,96 +1,91 @@
//go:build linux
package netlink
import (
"fmt"
"net/netip"
"github.com/jsimonetti/rtnetlink"
"github.com/vishvananda/netlink"
)
type Rule struct {
Priority *uint32
Family uint8
Table uint32
Mark *uint32
Src netip.Prefix
Dst netip.Prefix
Flags uint32
Action uint8
func NewRule() Rule {
// defaults found from netlink.NewRule() for fields we use,
// the rest of the defaults is set when converting from a `Rule`
// to a `netlink.Rule`
return Rule{
Priority: -1,
Mark: 0,
}
}
func (r *Rule) fromMessage(message rtnetlink.RuleMessage) {
table := uint32(message.Table)
if table == 0 || table == rtTableCompat {
table = *message.Attributes.Table
func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
switch family {
case FamilyAll:
n.debugLogger.Debug("ip -4 rule list")
n.debugLogger.Debug("ip -6 rule list")
case FamilyV4:
n.debugLogger.Debug("ip -4 rule list")
case FamilyV6:
n.debugLogger.Debug("ip -6 rule list")
}
r.Priority = message.Attributes.Priority
r.Family = message.Family
r.Table = table
r.Mark = message.Attributes.FwMark
r.Src = ipAndLengthToPrefix(message.Attributes.Src, message.SrcLength)
r.Dst = ipAndLengthToPrefix(message.Attributes.Dst, message.DstLength)
r.Flags = message.Flags
r.Action = message.Action
netlinkRules, err := netlink.RuleList(family)
if err != nil {
return nil, err
}
rules = make([]Rule, len(netlinkRules))
for i := range netlinkRules {
rules[i] = netlinkRuleToRule(netlinkRules[i])
}
return rules, nil
}
func (r Rule) message() *rtnetlink.RuleMessage {
src, srcLength := prefixToIPAndLength(r.Src)
dst, dstLength := prefixToIPAndLength(r.Dst)
message := &rtnetlink.RuleMessage{
Family: r.Family,
SrcLength: srcLength,
DstLength: dstLength,
Flags: r.Flags,
Action: r.Action,
Attributes: &rtnetlink.RuleAttributes{
Priority: r.Priority,
FwMark: r.Mark,
Src: src,
Dst: dst,
},
}
if r.Table <= uint32(^uint8(0)) {
message.Table = uint8(r.Table)
} else {
message.Table = rtTableCompat
message.Attributes.Table = &r.Table
}
return message
func (n *NetLink) RuleAdd(rule Rule) error {
n.debugLogger.Debug(ruleDbgMsg(true, rule))
netlinkRule := ruleToNetlinkRule(rule)
return netlink.RuleAdd(&netlinkRule)
}
func (r Rule) String() string {
from := "all"
if r.Src.IsValid() && !r.Src.Addr().IsUnspecified() {
from = r.Src.String()
}
to := "all"
if r.Dst.IsValid() && !r.Dst.Addr().IsUnspecified() {
to = r.Dst.String()
}
priority := ""
if r.Priority != nil {
priority = fmt.Sprintf(" %d", *r.Priority)
}
return fmt.Sprintf("ip rule%s: from %s to %s table %d",
priority, from, to, r.Table)
func (n *NetLink) RuleDel(rule Rule) error {
n.debugLogger.Debug(ruleDbgMsg(false, rule))
netlinkRule := ruleToNetlinkRule(rule)
return netlink.RuleDel(&netlinkRule)
}
func (r Rule) debugMessage(add bool) (debugMessage string) {
func ruleToNetlinkRule(rule Rule) (netlinkRule netlink.Rule) {
netlinkRule = *netlink.NewRule()
netlinkRule.Priority = rule.Priority
netlinkRule.Family = rule.Family
netlinkRule.Table = rule.Table
netlinkRule.Mark = rule.Mark
netlinkRule.Src = netipPrefixToIPNet(rule.Src)
netlinkRule.Dst = netipPrefixToIPNet(rule.Dst)
netlinkRule.Invert = rule.Invert
return netlinkRule
}
func netlinkRuleToRule(netlinkRule netlink.Rule) (rule Rule) {
return Rule{
Priority: netlinkRule.Priority,
Family: netlinkRule.Family,
Table: netlinkRule.Table,
Mark: netlinkRule.Mark,
Src: netIPNetToNetipPrefix(netlinkRule.Src),
Dst: netIPNetToNetipPrefix(netlinkRule.Dst),
Invert: netlinkRule.Invert,
}
}
func ruleDbgMsg(add bool, rule Rule) (debugMessage string) {
debugMessage = "ip"
switch r.Family {
switch rule.Family {
case FamilyV4:
debugMessage += " -f inet"
case FamilyV6:
debugMessage += " -f inet6"
default:
debugMessage += " -f " + fmt.Sprint(r.Family)
debugMessage += " -f " + fmt.Sprint(rule.Family)
}
debugMessage += " rule"
@@ -101,20 +96,20 @@ func (r Rule) debugMessage(add bool) (debugMessage string) {
debugMessage += " del"
}
if r.Src.IsValid() {
debugMessage += " from " + r.Src.String()
if rule.Src.IsValid() {
debugMessage += " from " + rule.Src.String()
}
if r.Dst.IsValid() {
debugMessage += " to " + r.Dst.String()
if rule.Dst.IsValid() {
debugMessage += " to " + rule.Dst.String()
}
if r.Table != 0 {
debugMessage += " lookup " + fmt.Sprint(r.Table)
if rule.Table != 0 {
debugMessage += " lookup " + fmt.Sprint(rule.Table)
}
if r.Priority != nil {
debugMessage += " pref " + fmt.Sprint(*r.Priority)
if rule.Priority != -1 {
debugMessage += " pref " + fmt.Sprint(rule.Priority)
}
return debugMessage
-69
View File
@@ -1,69 +0,0 @@
package netlink
import (
"fmt"
"github.com/jsimonetti/rtnetlink"
"golang.org/x/sys/unix"
)
const (
FlagInvert = unix.FIB_RULE_INVERT
ActionToTable = unix.FR_ACT_TO_TBL
)
func (n *NetLink) RuleList(family uint8) (rules []Rule, err error) {
switch family {
case FamilyAll:
n.debugLogger.Debug("ip -4 rule list")
n.debugLogger.Debug("ip -6 rule list")
case FamilyV4:
n.debugLogger.Debug("ip -4 rule list")
case FamilyV6:
n.debugLogger.Debug("ip -6 rule list")
}
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
ruleMessages, err := conn.Rule.List()
if err != nil {
return nil, err
}
rules = make([]Rule, 0, len(ruleMessages))
for _, message := range ruleMessages {
if family != FamilyAll && family != message.Family {
continue
}
var rule Rule
rule.fromMessage(message)
rules = append(rules, rule)
}
return rules, nil
}
func (n *NetLink) RuleAdd(rule Rule) error {
n.debugLogger.Debug(rule.debugMessage(true))
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Rule.Add(rule.message())
}
func (n *NetLink) RuleDel(rule Rule) error {
n.debugLogger.Debug(rule.debugMessage(false))
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Rule.Delete(rule.message())
}
+5 -5
View File
@@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/assert"
)
func Test_Rule_debugMessage(t *testing.T) {
func Test_ruleDbgMsg(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
@@ -15,7 +15,7 @@ func Test_Rule_debugMessage(t *testing.T) {
dbgMsg string
}{
"default values": {
dbgMsg: "ip -f 0 rule del",
dbgMsg: "ip -f 0 rule del pref 0",
},
"add rule": {
add: true,
@@ -24,7 +24,7 @@ func Test_Rule_debugMessage(t *testing.T) {
Src: makeNetipPrefix(1),
Dst: makeNetipPrefix(2),
Table: 100,
Priority: ptrTo(uint32(101)),
Priority: 101,
},
dbgMsg: "ip -f inet rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
},
@@ -34,7 +34,7 @@ func Test_Rule_debugMessage(t *testing.T) {
Src: makeNetipPrefix(1),
Dst: makeNetipPrefix(2),
Table: 100,
Priority: ptrTo(uint32(101)),
Priority: 101,
},
dbgMsg: "ip -f inet rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
},
@@ -44,7 +44,7 @@ func Test_Rule_debugMessage(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
dbgMsg := testCase.rule.debugMessage(testCase.add)
dbgMsg := ruleDbgMsg(testCase.add, testCase.rule)
assert.Equal(t, testCase.dbgMsg, dbgMsg)
})
+19
View File
@@ -0,0 +1,19 @@
//go:build !linux
package netlink
func NewRule() Rule {
return Rule{}
}
func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
panic("not implemented")
}
func (n *NetLink) RuleAdd(rule Rule) error {
panic("not implemented")
}
func (n *NetLink) RuleDel(rule Rule) error {
panic("not implemented")
}
+58
View File
@@ -0,0 +1,58 @@
package netlink
import (
"fmt"
"net/netip"
)
type Addr struct {
Network netip.Prefix
}
func (a Addr) String() string {
return a.Network.String()
}
type Link struct {
Type string
Name string
Index int
EncapType string
MTU uint16
}
type Route struct {
LinkIndex int
Dst netip.Prefix
Src netip.Addr
Gw netip.Addr
Priority int
Family int
Table int
Type int
}
type Rule struct {
Priority int
Family int
Table int
Mark uint32
Src netip.Prefix
Dst netip.Prefix
Invert bool
}
func (r Rule) String() string {
from := "all"
if r.Src.IsValid() && !r.Src.Addr().IsUnspecified() {
from = r.Src.String()
}
to := "all"
if r.Dst.IsValid() && !r.Dst.Addr().IsUnspecified() {
to = r.Dst.String()
}
return fmt.Sprintf("ip rule %d: from %s to %s table %d",
r.Priority, from, to, r.Table)
}
+37
View File
@@ -0,0 +1,37 @@
//go:build linux
package netlink
import (
"github.com/qdm12/gluetun/internal/mod"
"github.com/vishvananda/netlink"
)
func (n *NetLink) IsWireguardSupported() bool {
// Check for Wireguard family without loading the wireguard module.
// Some kernels have the wireguard module built-in, and don't have a
// modules directory, such as WSL2 kernels.
ok := hasWireguardFamily()
if ok {
return true
}
// Try loading the wireguard module, since some systems do not load
// it after a boot. If this fails, wireguard is assumed to not be supported.
n.debugLogger.Debugf("wireguard family not found, trying to load wireguard kernel module")
err := mod.Probe("wireguard")
if err != nil {
n.debugLogger.Debugf("failed loading wireguard kernel module: %s", err)
return false
}
n.debugLogger.Debugf("wireguard kernel module loaded successfully")
// Re-check if the Wireguard family is now available, after loading
// the wireguard kernel module.
return hasWireguardFamily()
}
func hasWireguardFamily() bool {
_, err := netlink.GenlFamilyGet("wireguard")
return err == nil
}
-58
View File
@@ -1,58 +0,0 @@
package netlink
import (
"errors"
"fmt"
"os"
"github.com/mdlayher/genetlink"
"github.com/qdm12/gluetun/internal/mod"
)
func (n *NetLink) IsWireguardSupported() (ok bool, err error) {
// Check for Wireguard family without loading the wireguard module.
// Some kernels have the wireguard module built-in, and don't have a
// modules directory, such as WSL2 kernels.
ok, err = hasWireguardFamily()
if err != nil {
return false, fmt.Errorf("checking wireguard family: %w", err)
} else if ok {
return true, nil
}
// Try loading the wireguard module, since some systems do not load
// it after a boot. If this fails, wireguard is assumed to not be supported.
n.debugLogger.Debugf("wireguard family not found, trying to load wireguard kernel module")
err = mod.Probe("wireguard")
if err != nil {
n.debugLogger.Debugf("failed loading wireguard kernel module: %s", err)
return false, nil
}
n.debugLogger.Debugf("wireguard kernel module loaded successfully")
// Re-check if the Wireguard family is now available, after loading
// the wireguard kernel module.
ok, err = hasWireguardFamily()
if err != nil {
return false, fmt.Errorf("checking wireguard family: %w", err)
}
return ok, nil
}
func hasWireguardFamily() (ok bool, err error) {
conn, err := genetlink.Dial(nil)
if err != nil {
return false, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
_, err = conn.GetFamily("wireguard")
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return false, nil
}
return false, fmt.Errorf("getting wireguard family: %w", err)
}
return true, nil
}
+1 -4
View File
@@ -4,8 +4,6 @@ package netlink
import (
"testing"
"github.com/stretchr/testify/require"
)
func Test_NetLink_IsWireguardSupported(t *testing.T) {
@@ -14,8 +12,7 @@ func Test_NetLink_IsWireguardSupported(t *testing.T) {
netLink := &NetLink{
debugLogger: &noopLogger{},
}
ok, err := netLink.IsWireguardSupported()
require.NoError(t, err)
ok := netLink.IsWireguardSupported()
if ok { // cannot assert since this depends on kernel
t.Log("wireguard is supported")
} else {
@@ -0,0 +1,7 @@
//go:build !linux
package netlink
func (n *NetLink) IsWireguardSupported() (ok bool, err error) {
panic("not implemented")
}
+2 -1
View File
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"os/exec"
"syscall"
"github.com/qdm12/gluetun/internal/constants/openvpn"
)
@@ -32,7 +33,7 @@ func start(ctx context.Context, starter CmdStarter, version string, flags []stri
args := []string{"--config", configPath}
args = append(args, flags...)
cmd := exec.CommandContext(ctx, bin, args...)
setCmdSysProcAttr(cmd)
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
return starter.Start(cmd)
}
-10
View File
@@ -1,10 +0,0 @@
package openvpn
import (
"os/exec"
"syscall"
)
func setCmdSysProcAttr(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
}
-12
View File
@@ -1,12 +0,0 @@
//go:build !linux
package openvpn
import (
"os/exec"
"syscall"
)
func setCmdSysProcAttr(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{}
}
-24
View File
@@ -1,24 +0,0 @@
package constants
const (
MaxEthernetFrameSize uint32 = 1500
// MinIPv4MTU is defined according to
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
MinIPv4MTU uint32 = 68
MinIPv6MTU uint32 = 1280
IPv4HeaderLength uint32 = 20
IPv6HeaderLength uint32 = 40
UDPHeaderLength uint32 = 8
// BaseTCPHeaderLength is the TCP header length without options,
// which is the minimum TCP header length.
BaseTCPHeaderLength uint32 = 20
// MaxTCPHeaderLength is the TCP header length with the maximum options length of 40 bytes.
// Note this is a hard maximum because of the 4-bit data offset field in the TCP header (15x4=60).
MaxTCPHeaderLength uint32 = 60
WireguardHeaderLength uint32 = 32
OpenVPNHeaderMaxLength uint32 = 1 + // opcode
8 + // session id
4 + // packet id
28 // max possible auth tag/iv
)
-16
View File
@@ -1,16 +0,0 @@
//go:build linux || darwin
package constants
import "golang.org/x/sys/unix"
//nolint:revive
const (
SOCK_RAW = unix.SOCK_RAW
SOCK_STREAM = unix.SOCK_STREAM
AF_INET = unix.AF_INET
AF_INET6 = unix.AF_INET6
IPPROTO_TCP = unix.IPPROTO_TCP
EAGAIN = unix.EAGAIN
EWOULDBLOCK = unix.EWOULDBLOCK
)
@@ -1,13 +0,0 @@
package constants
import "golang.org/x/sys/windows"
const (
SOCK_RAW = windows.SOCK_RAW
SOCK_STREAM = windows.SOCK_STREAM
AF_INET = windows.AF_INET
AF_INET6 = windows.AF_INET6
IPPROTO_TCP = windows.IPPROTO_TCP
EAGAIN = windows.WSAEWOULDBLOCK
EWOULDBLOCK = windows.WSAEWOULDBLOCK
)
-49
View File
@@ -1,49 +0,0 @@
package icmp
import (
"net"
"time"
"golang.org/x/net/ipv4"
)
var _ net.PacketConn = &ipv4Wrapper{}
// ipv4Wrapper is a wrapper around ipv4.PacketConn to implement
// the net.PacketConn interface. It's only used for Darwin or iOS.
type ipv4Wrapper struct {
ipv4Conn *ipv4.PacketConn
}
func ipv4ToNetPacketConn(ipv4 *ipv4.PacketConn) *ipv4Wrapper {
return &ipv4Wrapper{ipv4Conn: ipv4}
}
func (i *ipv4Wrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, _, addr, err = i.ipv4Conn.ReadFrom(p)
return n, addr, err
}
func (i *ipv4Wrapper) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return i.ipv4Conn.WriteTo(p, nil, addr)
}
func (i *ipv4Wrapper) Close() error {
return i.ipv4Conn.Close()
}
func (i *ipv4Wrapper) LocalAddr() net.Addr {
return i.ipv4Conn.LocalAddr()
}
func (i *ipv4Wrapper) SetDeadline(t time.Time) error {
return i.ipv4Conn.SetDeadline(t)
}
func (i *ipv4Wrapper) SetReadDeadline(t time.Time) error {
return i.ipv4Conn.SetReadDeadline(t)
}
func (i *ipv4Wrapper) SetWriteDeadline(t time.Time) error {
return i.ipv4Conn.SetWriteDeadline(t)
}
-83
View File
@@ -1,83 +0,0 @@
package icmp
import (
"bytes"
"errors"
"fmt"
"golang.org/x/net/icmp"
)
var (
ErrNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
ErrNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
)
func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) {
switch {
case mtu < minMTU:
return fmt.Errorf("%w: %d", ErrNextHopMTUTooLow, mtu)
case mtu > physicalLinkMTU:
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
ErrNextHopMTUTooHigh, mtu, physicalLinkMTU)
default:
return nil
}
}
func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
outboundMessage *icmp.Message,
) (match bool, err error) {
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
if err != nil {
return false, fmt.Errorf("parsing invoking packet: %w", err)
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return false, fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
return inboundBody.ID == outboundBody.ID, nil
}
var ErrIDMismatch = errors.New("ICMP id mismatch")
func checkEchoReply(icmpProtocol int, received []byte,
outboundMessage *icmp.Message, truncatedBody bool,
) (err error) {
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
if err != nil {
return fmt.Errorf("parsing invoking packet: %w", err)
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
if inboundBody.ID != outboundBody.ID {
return fmt.Errorf("%w: sent id %d and received id %d",
ErrIDMismatch, outboundBody.ID, inboundBody.ID)
}
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
if err != nil {
return fmt.Errorf("checking sent and received bodies: %w", err)
}
return nil
}
var ErrEchoDataMismatch = errors.New("ICMP data mismatch")
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
if len(received) > len(sent) {
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
ErrEchoDataMismatch, len(sent), len(received))
}
if receivedTruncated {
sent = sent[:len(received)]
}
if !bytes.Equal(received, sent) {
return fmt.Errorf("%w: sent %x and received %x",
ErrEchoDataMismatch, sent, received)
}
return nil
}
-14
View File
@@ -1,14 +0,0 @@
package icmp
import (
"golang.org/x/sys/unix"
)
func setDontFragment(fd uintptr, ipv4 bool) (err error) {
if ipv4 {
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IP,
unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE)
}
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6,
unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_PROBE)
}
-10
View File
@@ -1,10 +0,0 @@
//go:build !linux && !windows
package icmp
// setDontFragment for platforms other than Linux and Windows
// is not implemented, so we just return assuming the don't
// fragment flag is set on IP packets.
func setDontFragment(fd uintptr, ipv4 bool) (err error) {
return nil
}
-14
View File
@@ -1,14 +0,0 @@
package icmp
import (
"golang.org/x/sys/windows"
)
func setDontFragment(fd uintptr, ipv4 bool) (err error) {
if ipv4 {
// https://docs.microsoft.com/en-us/troubleshoot/windows/win32/header-library-requirement-socket-ipproto-ip
// #define IP_DONTFRAGMENT 14 /* don't fragment IP datagrams */
return windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, 14, 1)
}
return windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, 14, 1)
}
-30
View File
@@ -1,30 +0,0 @@
package icmp
import (
"context"
"errors"
"fmt"
"strings"
"time"
)
var (
ErrNotPermitted = errors.New("ICMP not permitted")
ErrDestinationUnreachable = errors.New("ICMP destination unreachable")
ErrCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
ErrBodyUnsupported = errors.New("ICMP body type is not supported")
ErrMTUNotFound = errors.New("MTU not found")
errTimeout = errors.New("operation timed out")
)
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
switch {
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
err = fmt.Errorf("%w", ErrNotPermitted)
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
err = fmt.Errorf("%w: after %s", errTimeout, pingTimeout)
case timedCtx.Err() != nil:
err = timedCtx.Err()
}
return err
}
-52
View File
@@ -1,52 +0,0 @@
package icmp
import (
"context"
"errors"
"fmt"
"net/netip"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
// PathMTUDiscover discovers the path MTU to the given IP address
// using ICMP.
// It first tries to get the next hop MTU using ICMP messages.
// If that fails, it falls back to sending echo requests with
// different packet sizes to find the maximum MTU.
// The function returns [ErrMTUNotFound] if the MTU could not be determined.
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
physicalLinkMTU uint32, timeout time.Duration, logger Logger,
) (mtu uint32, err error) {
if ip.Is4() {
logger.Debugf("finding IPv4 next hop MTU to %s", ip)
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, timeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, errTimeout) || errors.Is(err, ErrCommunicationAdministrativelyProhibited): // blackhole
default:
return 0, fmt.Errorf("finding IPv4 next hop MTU to %s: %w", ip, err)
}
} else {
logger.Debugf("requesting IPv6 ICMP packet-too-big reply from %s", ip)
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, timeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, errTimeout): // blackhole
default:
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
}
}
// Fall back method: send echo requests with different packet
// sizes and check which ones succeed to find the maximum MTU.
logger.Debugf("falling back to sending different sized echo packets to %s", ip)
minMTU := constants.MinIPv4MTU
if ip.Is6() {
minMTU = constants.MinIPv6MTU
}
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, timeout, logger)
}
-7
View File
@@ -1,7 +0,0 @@
package icmp
type Logger interface {
Debug(msg string)
Debugf(msg string, args ...any)
Warnf(msg string, args ...any)
}
-164
View File
@@ -1,164 +0,0 @@
package icmp
import (
"context"
"encoding/binary"
"fmt"
"net"
"net/netip"
"runtime"
"strings"
"syscall"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
)
const (
icmpv4Protocol = 1
)
func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
var listenConfig net.ListenConfig
listenConfig.Control = func(_, _ string, rawConn syscall.RawConn) error {
var setDFErr error
err := rawConn.Control(func(fd uintptr) {
const ipv4 = true
setDFErr = setDontFragment(fd, ipv4) // runs when calling ListenPacket
})
if err == nil {
err = setDFErr
}
return err
}
const listenAddress = ""
packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress)
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
}
return nil, err
}
if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
packetConn = ipv4ToNetPacketConn(ipv4.NewPacketConn(packetConn))
}
return packetConn, nil
}
func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger,
) (mtu uint32, err error) {
if ip.Is6() {
panic("IP address is not v4")
}
conn, err := listenICMPv4(ctx)
if err != nil {
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-ctx.Done()
conn.Close()
}()
// First try to send a packet which is too big to get the maximum MTU
// directly.
outboundID, outboundMessage := buildMessageToSend("v4", physicalLinkMTU)
encodedMessage, err := outboundMessage.Marshal(nil)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
buffer := make([]byte, physicalLinkMTU)
// for loop in case we read an ICMP message from another ICMP request
// or TCP/UDP traffic triggering an ICMP response.
for {
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
// Side note: echo reply should be at most the number of bytes
// sent, and can be lower, more precisely 576-ipHeader bytes,
// in case the next hop we are reaching replies with a destination
// unreachable and wants to ensure the response makes it way back
// by keeping a low packet size, see:
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
inboundMessage, err := icmp.ParseMessage(icmpv4Protocol, packetBytes)
if err != nil {
return 0, fmt.Errorf("parsing message: %w", err)
}
switch typedBody := inboundMessage.Body.(type) {
case *icmp.DstUnreach:
const fragmentationRequiredAndDFFlagSetCode = 4
const portUnreachable = 3
const communicationAdministrativelyProhibitedCode = 13
switch inboundMessage.Code {
case fragmentationRequiredAndDFFlagSetCode:
case portUnreachable: // triggered by TCP or UDP from applications
continue // ignore and wait for the next message
case communicationAdministrativelyProhibitedCode:
return 0, fmt.Errorf("%w: %w (code %d)",
ErrDestinationUnreachable,
ErrCommunicationAdministrativelyProhibited,
inboundMessage.Code)
default:
return 0, fmt.Errorf("%w: code %d",
ErrDestinationUnreachable, inboundMessage.Code)
}
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
// Note: the go library does not handle this NextHopMTU section.
nextHopMTU := packetBytes[6:8]
mtu = uint32(binary.BigEndian.Uint16(nextHopMTU))
err = checkMTU(mtu, constants.MinIPv4MTU, physicalLinkMTU)
if err != nil {
return 0, fmt.Errorf("checking next-hop-mtu found: %w", err)
}
// The code below is really for sanity checks
packetBytes = packetBytes[8:]
header, err := ipv4.ParseHeader(packetBytes)
if err != nil {
return 0, fmt.Errorf("parsing IPv4 header: %w", err)
}
packetBytes = packetBytes[header.Len:] // truncated original datagram
const truncated = true
err = checkEchoReply(icmpv4Protocol, packetBytes, outboundMessage, truncated)
if err != nil {
return 0, fmt.Errorf("checking echo reply: %w", err)
}
return mtu, nil
case *icmp.Echo:
inboundID := uint16(typedBody.ID) //nolint:gosec
if inboundID == outboundID {
return physicalLinkMTU, nil
}
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
inboundID, outboundID)
continue
default:
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
}
}
}
-134
View File
@@ -1,134 +0,0 @@
package icmp
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"syscall"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv6"
)
const (
icmpv6Protocol = 58
)
func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) {
var listenConfig net.ListenConfig
listenConfig.Control = func(_, _ string, rawConn syscall.RawConn) error {
var setDFErr error
err := rawConn.Control(func(fd uintptr) {
const ipv4 = false
setDFErr = setDontFragment(fd, ipv4) // runs when calling ListenPacket
})
if err == nil {
err = setDFErr
}
return err
}
const listenAddress = ""
packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress)
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
}
return nil, err
}
return packetConn, nil
}
func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger,
) (mtu uint32, err error) {
if ip.Is4() {
panic("IP address is not v6")
}
conn, err := listenICMPv6(ctx)
if err != nil {
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-ctx.Done()
conn.Close()
}()
// First try to send a packet which is too big to get the maximum MTU
// directly.
outboundID, outboundMessage := buildMessageToSend("v6", physicalLinkMTU)
encodedMessage, err := outboundMessage.Marshal(nil)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice(), Zone: ip.Zone()})
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
buffer := make([]byte, physicalLinkMTU)
for { // for loop if we encounter another ICMP packet with an unknown id.
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
packetBytes = packetBytes[ipv6.HeaderLen:]
inboundMessage, err := icmp.ParseMessage(icmpv6Protocol, packetBytes)
if err != nil {
return 0, fmt.Errorf("parsing message: %w", err)
}
switch typedBody := inboundMessage.Body.(type) {
case *icmp.PacketTooBig:
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.2
mtu = uint32(typedBody.MTU) //nolint:gosec
err = checkMTU(mtu, constants.MinIPv6MTU, physicalLinkMTU)
if err != nil {
return 0, fmt.Errorf("checking MTU: %w", err)
}
// Sanity checks
const truncatedBody = true
err = checkEchoReply(icmpv6Protocol, typedBody.Data, outboundMessage, truncatedBody)
if err != nil {
return 0, fmt.Errorf("checking invoking message: %w", err)
}
return uint32(typedBody.MTU), nil //nolint:gosec
case *icmp.DstUnreach:
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.1
idMatch, err := checkInvokingReplyIDMatch(icmpv6Protocol, packetBytes, outboundMessage)
if err != nil {
return 0, fmt.Errorf("checking invoking message id: %w", err)
} else if idMatch {
return 0, fmt.Errorf("%w", ErrDestinationUnreachable)
}
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
continue
case *icmp.Echo:
inboundID := uint16(typedBody.ID) //nolint:gosec
if inboundID == outboundID {
return physicalLinkMTU, nil
}
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
inboundID, outboundID)
continue
default:
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
}
}
}

Some files were not shown because too many files have changed in this diff Show More