Compare commits

..

63 Commits

Author SHA1 Message Date
Quentin McGaw cd9ba54b37 wip 2026-02-28 22:38:52 +00:00
Quentin McGaw 781e74f77a chore: merge iptables SetIPv4AllPolicies and SetIPv6AllPolicies together 2026-02-28 15:25:15 +00:00
Quentin McGaw fa0941a529 add nftables to dev container 2026-02-28 15:24:37 +00:00
Quentin McGaw e87d915f15 chore(firewall/iptables): modprobe and cache support for xt_mark and nf_tables 2026-02-28 15:23:30 +00:00
Quentin McGaw ec24ffdfd8 hotfix(firewall): save and restore behavior fixed
- restore if IPv4 set all policies fails
- fix deadlock when using iptables custom rules
- fix setting ipv6 rules when running runMixedIptablesInstruction
2026-02-28 14:37:58 +00:00
dependabot[bot] b9d49e0661 Chore(deps): Bump github.com/breml/rootcerts from 0.3.3 to 0.3.4 (#3128) 2026-02-27 02:16:31 +01:00
Quentin McGaw 2bb4deccd5 feat(firewall): atomic iptables operations
- all operations rollback on failure
- disabling the firewall means rolling back to its state before enabling it
- aligns with nftables atomicity feature
2026-02-26 22:58:52 +00:00
Quentin McGaw 0d0c0fb143 feat(dns): update block files after DNS server is up for a faster bootup 2026-02-26 18:45:52 +00:00
Quentin McGaw 885e491bb7 chore(dns): clarify "ready" dns message when DNS server is up and being used 2026-02-26 18:45:52 +00:00
Quentin McGaw e75ae21dcd fix(mod): probe searches for features built-in the kernel 2026-02-26 18:45:52 +00:00
Quentin McGaw 4b8dc8ded7 fix(privado): update servers data using JSON API
- Fixes #3159
- Fixes #2118
- Fixes #2657
2026-02-25 16:02:52 +00:00
Quentin McGaw 0eeee5c496 chore(pmtud): clarify debug logs and fix log error message 2026-02-25 04:23:56 +00:00
Quentin McGaw d21953f62e chore(firewall): split apart iptables specific code in internal/firewall/iptables 2026-02-25 04:23:53 +00:00
Quentin McGaw 034f8f6331 hotfix(netlink): specify IP family for conntrack calls and make conntrack failure a warning 2026-02-25 02:44:07 +00:00
Quentin McGaw 01487b5caf feat(protonvpn): add suggestions on some port forwarding errors 2026-02-23 21:19:08 +00:00
Quentin McGaw 625a63e7c2 fix(firewall): flush conntrack table after enabling firewall at container start
- prevent leaks for connections made the first ~10 milliseconds when Gluetun starts
- seems critical,  but in practice this very rarely happen and it very hard to reproduce
2026-02-22 13:31:38 +00:00
Quentin McGaw 0c3e5d94d8 change!(server): auth is now required for all routes (#2980) 2026-02-20 18:10:53 +01:00
Quentin McGaw d586793169 fix(all): increase global http client timeout to 35s and precise lower timeouts where needed
- Fix DNS blocklists slow downloads, fix #3102
- Leave 35s timeout for updaters
- Set timeouts to 1s for local calls
- Set timeouts to 5s for LAN VPN calls and small external calls
- Set timeouts to 10s external VPN API calls
2026-02-20 16:40:51 +00:00
Quentin McGaw c5eacac644 chore(pmtud/tcp): remove unused TCP flags 2026-02-20 16:25:14 +00:00
Quentin McGaw 7fbf2cbee3 hotfix(pmtud/tcp): return an error if no MSS destination server worked 2026-02-20 16:25:02 +00:00
Quentin McGaw 1dee183a70 chore(pmtud/tcp): silently discard IPv6 network unreachable errors 2026-02-20 16:24:25 +00:00
Quentin McGaw c66d8bed00 hotfix(pmtud/tcp): fix code for IPv6 destinations 2026-02-20 16:23:40 +00:00
Quentin McGaw 73b3e2c88a chore(pmtud/tcp): remove unused test code 2026-02-20 15:37:56 +00:00
Quentin McGaw ea87c0a2aa hotfix(pmtud): lower min MTU to MSS-matching-MTU minus 100 in case MSS is very small 2026-02-19 22:39:24 +00:00
Quentin McGaw 2192874de8 hotfix(pmtud/icmp): ignore non echo messages instead of returning an error 2026-02-19 18:05:48 +00:00
Quentin McGaw 007c5159f4 hotfix(pmtud): increase TCP margin from 150 to 300 compared to ICMP found MTU 2026-02-19 17:24:06 +00:00
Quentin McGaw c6b211ef9b feat(pmtud/tcp): support mixed IPv4 and IPv6 TCP servers
- Add default cloudflare and google tls ipv6 servers to default tcp servers
- update integration test to try against both ipv4 and ipv6 servers
2026-02-19 17:11:16 +00:00
Quentin McGaw 1c43a045d1 hotfix(pmtud/tcp): fix timeout apply per network call, not globally 2026-02-19 17:10:30 +00:00
Quentin McGaw 56b9e108be chore(pmtud/tcp): add :53 TCP servers to the default list 2026-02-19 17:10:30 +00:00
Quentin McGaw 67b66bba9e hotfix(pmtud/icmp): set IPv6 dont fragment options just in case 2026-02-19 17:10:30 +00:00
Quentin McGaw 8d86470905 feat(pmtud/tcp): use the TCP server with highest MSS to run MTU tests 2026-02-19 14:03:46 +00:00
Quentin McGaw fb85ae79d1 chore(pmtud/tcp): move test helpers in helpers_test.go 2026-02-19 13:20:59 +00:00
Quentin McGaw 783616f61d chore(pmtud/tcp): close connections with an RST packet on context cancelation 2026-02-19 13:20:59 +00:00
Quentin McGaw bc79901f1e chore(pmtud/tcp): restrict temp firewall rules to source ip and source port 2026-02-19 13:20:58 +00:00
Quentin McGaw 1c56189abc hotfix(pmtud/tcp): fix rare race condition 2026-02-18 19:07:31 +00:00
Quentin McGaw 224618337c hotfix(pmtud/tcp): respect MSS from server into account 2026-02-18 18:32:31 +00:00
Quentin McGaw 183d351b58 chore(pmtud/icmp): do not use net.ErrClosed when inappropriate 2026-02-17 21:46:24 +00:00
Quentin McGaw 04d7cef294 hotfix(pmtud/tcp): block kernel from racing to send RST packets
- this makes PMTUD TCP reliable
- this only works on kernels with the mark module
- on kernels without the mark module, the icmp pmtud mtu found is used
2026-02-17 21:46:24 +00:00
Quentin McGaw 5f903d1fbf chore(pmtud): remove calls to syscall in favor of unix and windows
- syscall is deprecated and is not kept up-to-date
- each OS is inherently different hence the syscall being deprecated
2026-02-17 21:46:04 +00:00
Quentin McGaw d43eb1658f chore(firewall): support TCP flags for future changes 2026-02-17 19:38:20 +00:00
Quentin McGaw 36dfd5b631 hotfix(pmtud): do not try every address for ICMP PMTUD 2026-02-16 23:54:38 +00:00
Quentin McGaw f81b8342d6 hotfix(pmtud/tcp): temporary test fix 2026-02-16 23:54:38 +00:00
Quentin McGaw cdec25da52 feat(pmtud/tcp): generate MTU test data to mimic TLS if possible to avoid being blocked 2026-02-16 19:57:12 +00:00
Quentin McGaw 201d1041f4 hotfix(pmtud/tcp): send MTU data in first and only ACK packet
- less likely to be flagged
- correct using TCP fast-open
2026-02-16 19:56:14 +00:00
Quentin McGaw dc78b4ecce fix(dns): skip blocking if block lists download fails 2026-02-16 15:27:07 +00:00
Quentin McGaw d75b48d123 chore(dns): update filter block lists without restarting DNS server 2026-02-16 15:23:57 +00:00
Quentin McGaw e828ea1462 feat(dns): allow parent domains to be exempt from rebinding protection
- Specify with `*.domain.com` in DNS_REBINDING_PROTECTION_EXEMPT_HOSTNAMES
- Fix #3135
2026-02-16 14:45:09 +00:00
Quentin McGaw be92aa2ac4 Path MTU discovery fixes and improvements (#3109)
- Existing option `WIREGUARD_MTU` , if set, disables PMTUD and is used
- New option `PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8` and `PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443`
- ICMP PMTUD now targets external-by-default IP addresses
- New TCP PMTUD (binary search only) as a second MTU confirmation and fallback mechanism.
- Force set TCP MSS to MTU - IP header - TCP base header - "magic 20 bytes" 🎆
- Fix #3108
2026-02-14 19:40:34 -05:00
Quentin McGaw 8f1fda7646 fix(healthcheck): corret behavior when HEALTH_RESTART_VPN=off and startup check fails 2026-02-11 17:33:14 +00:00
Quentin McGaw 8eb990eb66 chore(ci): ignore .golangci.yml file for reviewdog 2026-02-11 14:25:28 +00:00
Quentin McGaw 4698daea16 chore(mullvad): remove openvpn support 2026-02-11 00:09:36 +00:00
Quentin McGaw b0a75673bd chore(dev): ensure project compiles on darwin and windows 2026-02-09 15:41:52 +00:00
Quentin McGaw 5f0c499808 fix(protonvpn): support port 51820 for UDP OpenVPN 2026-02-09 15:41:52 +00:00
Quentin McGaw bdd69a1fb7 fix(healthcheck): prevent race condition and fix #3096 (#3123) 2026-02-07 18:11:04 +01:00
Quentin McGaw 1af75bb30c fix(openvpn): only log openvpn version corresponding to OPENVPN_VERSION 2026-02-07 16:49:21 +00:00
Chris Duck 9c1cd7e8b1 fix(protonvpn): update OpenVPN settings (#3120) 2026-02-06 14:18:10 +01:00
Quentin McGaw facc6df3be chore(all): replace netlink library for more flexibility (#3107) 2026-01-27 01:11:39 -08:00
Quentin McGaw e292a4c9be fix(httpproxy): remove info log when no Proxy-Authorization header is present 2026-01-24 19:39:20 +00:00
Quentin McGaw 9e4dd61c19 feat(ipvanish): update servers data 2026-01-24 19:32:21 +00:00
Quentin McGaw fe3d4a94d4 chore(all): make code compilable for other platforms than Linux 2026-01-24 17:56:10 +00:00
Quentin McGaw de38d759a4 feat(vpn): path MTU discovery to find the best MTU (#2586) 2026-01-21 09:02:23 -08:00
Quentin McGaw fba60af772 fix(wireguard): fix detection of kernelspace wireguard 2026-01-20 21:39:30 +00:00
Quentin McGaw 9b9b723887 chore(mullvad): add openvpn removal warning 2025-12-29 05:28:13 +00:00
204 changed files with 14919 additions and 13038 deletions
+1 -1
View File
@@ -1,2 +1,2 @@
FROM ghcr.io/qdm12/godevcontainer:v0.21-alpine
RUN apk add wireguard-tools htop openssl
RUN apk add wireguard-tools htop openssl tcpdump iptables nftables
+5 -1
View File
@@ -45,6 +45,7 @@ jobs:
level: error
exclude: |
./internal/storage/servers.json
./golangci.yml
*.md
- name: Linting
@@ -59,10 +60,13 @@ jobs:
- name: Run tests in test container
run: |
touch coverage.txt
docker run --rm --device /dev/net/tun \
docker run --rm --cap-add=NET_ADMIN --device /dev/net/tun \
-v "$(pwd)/coverage.txt:/tmp/gobuild/coverage.txt" \
test-container
- name: Verify dev cross platform compatibility
run: docker build --target xcompile .
- name: Build final image
run: docker build -t final-image .
+2 -1
View File
@@ -22,6 +22,7 @@ linters:
- "^disabled$"
# Firewall and routing strings
- "^(ACCEPT|DROP)$"
- "^--append$"
- "^--delete$"
- "^all$"
- "^(tcp|udp)$"
@@ -47,7 +48,7 @@ linters:
path: internal\/server\/.+\.go
- linters:
- ireturn
text: returns interface \(github\.com\/vishvananda\/netlink\.Link\)
text: returns interface \(golang\.org\/x\/sys\/unix\.Sockaddr\)
- linters:
- ireturn
path: internal\/openvpn\/pkcs8\/descbc\.go
+9 -2
View File
@@ -13,7 +13,7 @@ FROM --platform=${BUILDPLATFORM} ghcr.io/qdm12/binpot:mockgen-${MOCKGEN_VERSION}
FROM --platform=${BUILDPLATFORM} golang:${GO_VERSION}-alpine${GO_ALPINE_VERSION} AS base
COPY --from=xcputranslate /xcputranslate /usr/local/bin/xcputranslate
# Note: findutils needed to have xargs support `-d` flag for mocks stage.
RUN apk --update add git g++ findutils
RUN apk --update add git g++ findutils iptables
ENV CGO_ENABLED=0
COPY --from=golangci-lint /bin /go/bin/golangci-lint
COPY --from=mockgen /bin /go/bin/mockgen
@@ -46,6 +46,10 @@ RUN git init && \
git diff --exit-code && \
rm -rf .git/
FROM --platform=${BUILDPLATFORM} base AS xcompile
RUN GOOS=darwin go build -o /dev/null ./...
RUN GOOS=windows go build -o /dev/null ./...
FROM --platform=${BUILDPLATFORM} base AS build
ARG TARGETPLATFORM
ARG VERSION=unknown
@@ -106,8 +110,11 @@ ENV VPN_SERVICE_PROVIDER=pia \
WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL=0 \
WIREGUARD_ADDRESSES= \
WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
WIREGUARD_MTU=1320 \
WIREGUARD_MTU= \
WIREGUARD_IMPLEMENTATION=auto \
# PMTUD
PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8 \
PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443,1.1.1.1:53,8.8.8.8:53,[2606:4700:4700::1111]:53,[2001:4860:4860::8888]:53,[2606:4700:4700::1111]:443,[2001:4860:4860::8888]:443 \
# VPN server filtering
SERVER_REGIONS= \
SERVER_COUNTRIES= \
+1 -1
View File
@@ -58,7 +58,7 @@ Lightweight swiss-army-knife-like VPN client to multiple VPN service providers
## Features
- Based on Alpine 3.22 for a small Docker image of 41.1MB
- Supports: **AirVPN**, **Cyberghost**, **ExpressVPN**, **FastestVPN**, **Giganews**, **HideMyAss**, **IPVanish**, **IVPN**, **Mullvad**, **NordVPN**, **Perfect Privacy**, **Privado**, **Private Internet Access**, **PrivateVPN**, **ProtonVPN**, **PureVPN**, **SlickVPN**, **Surfshark**, **TorGuard**, **VPNSecure.me**, **VPNUnlimited**, **Vyprvpn**, **Windscribe** servers
- Supports: **AirVPN**, **Cyberghost**, **ExpressVPN**, **FastestVPN**, **Giganews**, **HideMyAss**, **IPVanish**, **IVPN**, **Mullvad** (Wireguard only), **NordVPN**, **Perfect Privacy**, **Privado**, **Private Internet Access**, **PrivateVPN**, **ProtonVPN**, **PureVPN**, **SlickVPN**, **Surfshark**, **TorGuard**, **VPNSecure.me**, **VPNUnlimited**, **Vyprvpn**, **Windscribe** servers
- Supports OpenVPN for all providers listed
- Supports Wireguard both kernelspace and userspace
- For **AirVPN**, **FastestVPN**, **Ivpn**, **Mullvad**, **NordVPN**, **Perfect privacy**, **ProtonVPN**, **Surfshark** and **Windscribe**
+27 -16
View File
@@ -6,6 +6,7 @@ import (
"fmt"
"io/fs"
"net/http"
"net/netip"
"os"
"os/exec"
"os/signal"
@@ -22,6 +23,7 @@ import (
"github.com/qdm12/gluetun/internal/configuration/sources/files"
"github.com/qdm12/gluetun/internal/configuration/sources/secrets"
"github.com/qdm12/gluetun/internal/constants"
copenvpn "github.com/qdm12/gluetun/internal/constants/openvpn"
"github.com/qdm12/gluetun/internal/dns"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/healthcheck"
@@ -166,7 +168,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
defer fmt.Println(gluetunLogo)
announcementExp, err := time.Parse(time.RFC3339, "2024-12-01T00:00:00Z")
announcementExp, err := time.Parse(time.RFC3339, "2026-04-01T00:00:00Z")
if err != nil {
return err
}
@@ -177,7 +179,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
Version: buildInfo.Version,
Commit: buildInfo.Commit,
Created: buildInfo.Created,
Announcement: "All control server routes will become private by default after the v3.41.0 release",
Announcement: "All control server routes are now private by default",
AnnounceExp: announcementExp,
// Sponsor information
PaypalUser: "qmcgaw",
@@ -235,6 +237,10 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
if err != nil {
return err
}
err = netLinker.FlushConntrack()
if err != nil {
logger.Warnf("flushing conntrack failed: %s", err)
}
}
// TODO run this in a loop or in openvpn to reload from file without restarting
@@ -262,19 +268,22 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
puid, pgid := int(*allSettings.System.PUID), int(*allSettings.System.PGID)
const clientTimeout = 15 * time.Second
const clientTimeout = 35 * time.Second
httpClient := &http.Client{Timeout: clientTimeout}
// Create configurators
alpineConf := alpine.New()
ovpnConf := openvpn.New(
logger.New(log.SetComponent("openvpn configurator")),
cmder, puid, pgid)
ovpnVersion := ovpnConf.Version26
if allSettings.VPN.OpenVPN.Version == copenvpn.Openvpn25 {
ovpnVersion = ovpnConf.Version25
}
err = printVersions(ctx, logger, []printVersionElement{
{name: "Alpine", getVersion: alpineConf.Version},
{name: "OpenVPN 2.5", getVersion: ovpnConf.Version25},
{name: "OpenVPN 2.6", getVersion: ovpnConf.Version26},
{name: "IPtables", getVersion: firewallConf.Version},
{name: "OpenVPN", getVersion: ovpnVersion},
{name: "Firewall", getVersion: firewallConf.Version},
})
if err != nil {
return err
@@ -551,24 +560,25 @@ type netLinker interface {
Linker
IsWireguardSupported() (ok bool, err error)
IsIPv6Supported() (ok bool, err error)
FlushConntrack() error
PatchLoggerLevel(level log.Level)
}
type Addresser interface {
AddrList(link netlink.Link, family int) (
addresses []netlink.Addr, err error)
AddrReplace(link netlink.Link, addr netlink.Addr) error
AddrList(linkIndex uint32, family uint8) (
addresses []netip.Prefix, err error)
AddrReplace(linkIndex uint32, addr netip.Prefix) error
}
type Router interface {
RouteList(family int) (routes []netlink.Route, err error)
RouteList(family uint8) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error
RouteDel(route netlink.Route) error
RouteReplace(route netlink.Route) error
}
type Ruler interface {
RuleList(family int) (rules []netlink.Rule, err error)
RuleList(family uint8) (rules []netlink.Rule, err error)
RuleAdd(rule netlink.Rule) error
RuleDel(rule netlink.Rule) error
}
@@ -576,11 +586,12 @@ type Ruler interface {
type Linker interface {
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkByIndex(index int) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error)
LinkByIndex(index uint32) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
LinkDel(linkIndex uint32) (err error)
LinkSetUp(linkIndex uint32) (err error)
LinkSetDown(linkIndex uint32) (err error)
LinkSetMTU(linkIndex, mtu uint32) error
}
type clier interface {
+16 -16
View File
@@ -4,13 +4,17 @@ go 1.25.0
require (
github.com/ProtonMail/go-srp v0.0.7
github.com/breml/rootcerts v0.3.3
github.com/breml/rootcerts v0.3.4
github.com/fatih/color v1.18.0
github.com/golang/mock v1.6.0
github.com/google/nftables v0.3.0
github.com/jsimonetti/rtnetlink v1.4.2
github.com/klauspost/compress v1.18.1
github.com/klauspost/pgzip v1.2.6
github.com/mdlayher/genetlink v1.3.2
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42
github.com/pelletier/go-toml/v2 v2.2.4
github.com/qdm12/dns/v2 v2.0.0-rc10
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260216151239-36b3306f2205
github.com/qdm12/gosettings v0.4.4
github.com/qdm12/goshutdown v0.3.0
github.com/qdm12/gosplash v0.2.0
@@ -18,13 +22,13 @@ require (
github.com/qdm12/log v0.1.0
github.com/qdm12/ss-server v0.6.0
github.com/stretchr/testify v1.11.1
github.com/ti-mo/netfilter v0.5.3
github.com/ulikunitz/xz v0.5.15
github.com/vishvananda/netlink v1.3.1
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
golang.org/x/net v0.47.0
golang.org/x/sys v0.38.0
golang.org/x/text v0.31.0
golang.org/x/net v0.49.0
golang.org/x/sys v0.40.0
golang.org/x/text v0.33.0
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
gopkg.in/ini.v1 v1.67.0
@@ -38,13 +42,10 @@ require (
github.com/cloudflare/circl v1.6.1 // indirect
github.com/cronokirby/saferith v0.33.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.4.1 // indirect
github.com/mdlayher/socket v0.5.1 // indirect
github.com/miekg/dns v1.1.62 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pkg/errors v0.9.1 // indirect
@@ -55,12 +56,11 @@ require (
github.com/prometheus/procfs v0.15.1 // indirect
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 // indirect
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/crypto v0.45.0 // indirect
golang.org/x/mod v0.29.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/crypto v0.47.0 // indirect
golang.org/x/mod v0.31.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.38.0 // indirect
golang.org/x/tools v0.40.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/protobuf v1.35.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
+34 -32
View File
@@ -8,11 +8,13 @@ github.com/ProtonMail/go-srp v0.0.7 h1:Sos3Qk+th4tQR64vsxGIxYpN3rdnG9Wf9K4ZloC1J
github.com/ProtonMail/go-srp v0.0.7/go.mod h1:giCp+7qRnMIcCvI6V6U3S1lDDXDQYx2ewJ6F/9wdlJk=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/breml/rootcerts v0.3.3 h1://GnaRtQ/9BY2+GtMk2wtWxVdCRysiaPr5/xBwl7NKw=
github.com/breml/rootcerts v0.3.3/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXBhsegfXw=
github.com/breml/rootcerts v0.3.4 h1:9i7WNl/ctd9OEAOaTfLy//Wrlfxq/tRQ7v4okYFN9Ys=
github.com/breml/rootcerts v0.3.4/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXBhsegfXw=
github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cilium/ebpf v0.12.3 h1:8ht6F9MquybnY97at+VDZb3eQQr8ev79RueWeVaEcG4=
github.com/cilium/ebpf v0.12.3/go.mod h1:TctK1ivibvI3znr66ljgi4hqOT8EYQjz1KWBfb1UVgM=
github.com/cloudflare/circl v1.1.0/go.mod h1:prBCrKB9DV4poKZY1l9zBXg2QJY7mvgRvtMxxK7fi4I=
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
@@ -26,10 +28,12 @@ github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/nftables v0.3.0 h1:bkyZ0cbpVeMHXOrtlFc8ISmfVqq5gPJukoYieyVmITg=
github.com/google/nftables v0.3.0/go.mod h1:BCp9FsrbF1Fn/Yu6CLUc9GGZFw/+hsxfluNXXmxBfRM=
github.com/jsimonetti/rtnetlink v1.4.2 h1:Df9w9TZ3npHTyDn0Ev9e1uzmN2odmXd0QX+J5GTEn90=
github.com/jsimonetti/rtnetlink v1.4.2/go.mod h1:92s6LJdE+1iOrw+F2/RO7LYI2Qd8pPpFNNUYW06gcoM=
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0=
github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU=
@@ -45,10 +49,10 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg=
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
@@ -69,8 +73,8 @@ github.com/prometheus/common v0.60.1 h1:FUas6GcOw66yB/73KC+BOZoFJmbo/1pojoILArPA
github.com/prometheus/common v0.60.1/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/qdm12/dns/v2 v2.0.0-rc10 h1:IyeNEYXfhBsaE1dwxx5eAqdAz1HS98dT+8c7xoKODa0=
github.com/qdm12/dns/v2 v2.0.0-rc10/go.mod h1:98foWgXJZ+g8gJIuO+fdO+oWpFei5WShMFTeN4Im2lE=
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260216151239-36b3306f2205 h1:0ycKUDQ50cYb2QpeyGcEnvVs9HJmC9jsb/XZNC1z28c=
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260216151239-36b3306f2205/go.mod h1:98foWgXJZ+g8gJIuO+fdO+oWpFei5WShMFTeN4Im2lE=
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 h1:TRGpCU1l0lNwtogEUSs5U+RFceYxkAJUmrGabno7J5c=
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978/go.mod h1:D1Po4CRQLYjccnAR2JsVlN1sBMgQrcNLONbvyuzcdTg=
github.com/qdm12/gosettings v0.4.4 h1:SM6tOZDf6k8qbjWU8KWyBF4mWIixfsKCfh9DGRLHlj4=
@@ -91,12 +95,12 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/ti-mo/netfilter v0.5.3 h1:ikzduvnaUMwre5bhbNwWOd6bjqLMVb33vv0XXbK0xGQ=
github.com/ti-mo/netfilter v0.5.3/go.mod h1:08SyBCg6hu1qyQk4s3DjjJKNrm3RTb32nm6AzyT972E=
github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk=
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
@@ -106,15 +110,15 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
@@ -122,14 +126,14 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -140,12 +144,10 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -155,8 +157,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -164,8 +166,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -23,7 +23,9 @@ type DNSBlacklist struct {
AddBlockedIPs []netip.Addr
AddBlockedIPPrefixes []netip.Prefix
// RebindingProtectionExemptHostnames is a list of hostnames
// exempt from DNS rebinding protection.
// exempt from DNS rebinding protection. It can contain parent
// domains which are of the form "*.example.com". Note the wildcard
// can only be used at the start of the hostname.
RebindingProtectionExemptHostnames []string
}
@@ -55,6 +57,9 @@ func (b DNSBlacklist) validate() (err error) {
}
for _, host := range b.RebindingProtectionExemptHostnames {
if len(host) > 2 && host[:2] == "*." {
host = host[2:]
}
if !hostRegex.MatchString(host) {
return fmt.Errorf("%w: %s", ErrRebindingProtectionExemptHostNotValid, host)
}
@@ -104,7 +104,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
allowedUDP = []uint16{53, 1194, 1197, 1198, 8080, 9201}
case providers.Protonvpn:
allowedTCP = []uint16{443, 5995, 8443}
allowedUDP = []uint16{80, 443, 1194, 4569, 5060}
allowedUDP = []uint16{80, 443, 1194, 4569, 5060, 51820}
case providers.SlickVPN:
allowedTCP = []uint16{443, 8080, 8888}
allowedUDP = []uint16{443, 8080, 8888}
+111
View File
@@ -0,0 +1,111 @@
package settings
import (
"errors"
"fmt"
"net/netip"
"github.com/qdm12/gosettings"
"github.com/qdm12/gosettings/reader"
"github.com/qdm12/gotree"
)
// PMTUD contains settings to configure Path MTU Discovery.
type PMTUD struct {
// ICMPAddresses is the redundancy list of addresses to use
// for ICMP path MTU discovery. Each address MUST handle ICMP
// packets for PMTUD to work.
// It cannot be nil in the internal state.
ICMPAddresses []netip.Addr `json:"icmp_addresses"`
// TCPAddresses is the redundancy list of addresses to use
// for TCP path MTU discovery. Each address MUST have a listening
// TCP server on the port specified.
// It cannot be nil in the internal state.
TCPAddresses []netip.AddrPort `json:"tcp_addresses"`
}
var (
ErrPMTUDICMPAddressNotValid = errors.New("PMTUD ICMP address is not valid")
ErrPMTUDTCPAddressNotValid = errors.New("PMTUD TCP address is not valid")
)
// Validate validates PMTUD settings.
func (p PMTUD) validate() (err error) {
for i, addr := range p.ICMPAddresses {
if !addr.IsValid() {
return fmt.Errorf("%w: at index %d", ErrPMTUDICMPAddressNotValid, i)
}
}
for i, addr := range p.TCPAddresses {
if !addr.IsValid() {
return fmt.Errorf("%w: at index %d", ErrPMTUDTCPAddressNotValid, i)
}
}
return nil
}
func (p *PMTUD) copy() (copied PMTUD) {
return PMTUD{
ICMPAddresses: gosettings.CopySlice(p.ICMPAddresses),
TCPAddresses: gosettings.CopySlice(p.TCPAddresses),
}
}
func (p *PMTUD) overrideWith(other PMTUD) {
p.ICMPAddresses = gosettings.OverrideWithSlice(p.ICMPAddresses, other.ICMPAddresses)
p.TCPAddresses = gosettings.OverrideWithSlice(p.TCPAddresses, other.TCPAddresses)
}
func (p *PMTUD) setDefaults() {
defaultICMPAddresses := []netip.Addr{
netip.AddrFrom4([4]byte{1, 1, 1, 1}),
netip.AddrFrom4([4]byte{8, 8, 8, 8}),
}
p.ICMPAddresses = gosettings.DefaultSlice(p.ICMPAddresses, defaultICMPAddresses)
const dnsPort, tlsPort = 53, 443
defaultTCPAddresses := []netip.AddrPort{
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), dnsPort),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), dnsPort),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), tlsPort),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), tlsPort),
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), dnsPort),
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), dnsPort),
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), tlsPort),
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), tlsPort),
}
p.TCPAddresses = gosettings.DefaultSlice(p.TCPAddresses, defaultTCPAddresses)
}
func (p PMTUD) String() string {
return p.toLinesNode().String()
}
func (p PMTUD) toLinesNode() (node *gotree.Node) {
node = gotree.New("Path MTU discovery:")
icmpAddrNode := node.Append("ICMP addresses:")
for _, addr := range p.ICMPAddresses {
icmpAddrNode.Append(addr.String())
}
tcpAddrNode := node.Append("TCP addresses:")
for _, addr := range p.TCPAddresses {
tcpAddrNode.Append(addr.String())
}
return node
}
func (p *PMTUD) read(r *reader.Reader) (err error) {
p.ICMPAddresses, err = r.CSVNetipAddresses("PMTUD_ICMP_ADDRESSES")
if err != nil {
return err
}
p.TCPAddresses, err = r.CSVNetipAddrPorts("PMTUD_TCP_ADDRESSES")
if err != nil {
return err
}
return nil
}
@@ -2,6 +2,8 @@ package settings
import (
"fmt"
"slices"
"sort"
"strings"
"github.com/qdm12/gluetun/internal/constants/providers"
@@ -31,6 +33,11 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
if vpnType == vpn.OpenVPN {
validNames = providers.AllWithCustom()
validNames = append(validNames, "pia") // Retro-compatibility
// Remove Mullvad since it no longer supports OpenVPN as of January 15th, 2026
mullvadIndex := slices.Index(validNames, providers.Mullvad)
validNames[mullvadIndex], validNames[len(validNames)-1] = validNames[len(validNames)-1], validNames[mullvadIndex]
validNames = validNames[:len(validNames)-1]
sort.Strings(validNames)
} else { // Wireguard
validNames = []string{
providers.Airvpn,
@@ -29,14 +29,27 @@ func Test_Settings_String(t *testing.T) {
| | └── OpenVPN server selection settings:
| | ├── Protocol: UDP
| | └── Private Internet Access encryption preset: strong
| ── OpenVPN settings:
| ├── OpenVPN version: 2.6
| ├── User: [not set]
| ├── Password: [not set]
| ├── Private Internet Access encryption preset: strong
| ├── Network interface: tun0
| ├── Run OpenVPN as: root
| └── Verbosity level: 1
| ── OpenVPN settings:
| | ├── OpenVPN version: 2.6
| | ├── User: [not set]
| | ├── Password: [not set]
| | ├── Private Internet Access encryption preset: strong
| | ├── Network interface: tun0
| | ├── Run OpenVPN as: root
| | └── Verbosity level: 1
| └── Path MTU discovery:
| ├── ICMP addresses:
| | ├── 1.1.1.1
| | └── 8.8.8.8
| └── TCP addresses:
| ├── 1.1.1.1:53
| ├── 8.8.8.8:53
| ├── 1.1.1.1:443
| ├── 8.8.8.8:443
| ├── [2606:4700:4700::1111]:53
| ├── [2001:4860:4860::8888]:53
| ├── [2606:4700:4700::1111]:443
| └── [2001:4860:4860::8888]:443
├── DNS settings:
| ├── Keep existing nameserver(s): no
| ├── DNS server address to use: 127.0.0.1
+15
View File
@@ -18,6 +18,7 @@ type VPN struct {
Provider Provider `json:"provider"`
OpenVPN OpenVPN `json:"openvpn"`
Wireguard Wireguard `json:"wireguard"`
PMTUD PMTUD `json:"pmtud"`
}
// TODO v4 remove pointer for receiver (because of Surfshark).
@@ -45,6 +46,11 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo
}
}
err = v.PMTUD.validate()
if err != nil {
return fmt.Errorf("PMTUD settings: %w", err)
}
return nil
}
@@ -54,6 +60,7 @@ func (v *VPN) Copy() (copied VPN) {
Provider: v.Provider.copy(),
OpenVPN: v.OpenVPN.copy(),
Wireguard: v.Wireguard.copy(),
PMTUD: v.PMTUD.copy(),
}
}
@@ -62,6 +69,7 @@ func (v *VPN) OverrideWith(other VPN) {
v.Provider.overrideWith(other.Provider)
v.OpenVPN.overrideWith(other.OpenVPN)
v.Wireguard.overrideWith(other.Wireguard)
v.PMTUD.overrideWith(other.PMTUD)
}
func (v *VPN) setDefaults() {
@@ -69,6 +77,7 @@ func (v *VPN) setDefaults() {
v.Provider.setDefaults()
v.OpenVPN.setDefaults(v.Provider.Name)
v.Wireguard.setDefaults(v.Provider.Name)
v.PMTUD.setDefaults()
}
func (v VPN) String() string {
@@ -85,6 +94,7 @@ func (v VPN) toLinesNode() (node *gotree.Node) {
} else {
node.AppendNode(v.Wireguard.toLinesNode())
}
node.AppendNode(v.PMTUD.toLinesNode())
return node
}
@@ -107,5 +117,10 @@ func (v *VPN) read(r *reader.Reader) (err error) {
return fmt.Errorf("wireguard: %w", err)
}
err = v.PMTUD.read(r)
if err != nil {
return fmt.Errorf("PMTUD: %w", err)
}
return nil
}
+10 -14
View File
@@ -38,14 +38,9 @@ type Wireguard struct {
Interface string `json:"interface"`
PersistentKeepaliveInterval *time.Duration `json:"persistent_keep_alive_interval"`
// Maximum Transmission Unit (MTU) of the Wireguard interface.
// It cannot be zero in the internal state, and defaults to
// 1320. Note it is not the wireguard-go MTU default of 1420
// because this impacts bandwidth a lot on some VPN providers,
// see https://github.com/qdm12/gluetun/issues/1650.
// It has been lowered to 1320 following quite a bit of
// investigation in the issue:
// https://github.com/qdm12/gluetun/issues/2533.
MTU uint16 `json:"mtu"`
// It cannot be nil in the internal state, and defaults to
// 0 indicating to use PMTUD.
MTU *uint32 `json:"mtu"`
// Implementation is the Wireguard implementation to use.
// It can be "auto", "userspace" or "kernelspace".
// It defaults to "auto" and cannot be the empty string
@@ -194,8 +189,7 @@ func (w *Wireguard) setDefaults(vpnProvider string) {
w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs)
w.PersistentKeepaliveInterval = gosettings.DefaultPointer(w.PersistentKeepaliveInterval, 0)
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
const defaultMTU = 1320
w.MTU = gosettings.DefaultComparable(w.MTU, defaultMTU)
w.MTU = gosettings.DefaultPointer(w.MTU, 0)
w.Implementation = gosettings.DefaultComparable(w.Implementation, "auto")
}
@@ -231,7 +225,11 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
}
interfaceNode := node.Appendf("Network interface: %s", w.Interface)
interfaceNode.Appendf("MTU: %d", w.MTU)
if *w.MTU == 0 {
interfaceNode.Append("MTU: use path MTU discovery")
} else {
interfaceNode.Appendf("MTU: %d", *w.MTU)
}
if w.Implementation != "auto" {
node.Appendf("Implementation: %s", w.Implementation)
@@ -272,11 +270,9 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
return err
}
mtuPtr, err := r.Uint16Ptr("WIREGUARD_MTU")
w.MTU, err = r.Uint32Ptr("WIREGUARD_MTU")
if err != nil {
return err
} else if mtuPtr != nil {
w.MTU = *mtuPtr
}
return nil
}
+6 -7
View File
@@ -2,7 +2,6 @@ package dns
import (
"context"
"errors"
"github.com/qdm12/dns/v2/pkg/nameserver"
"github.com/qdm12/gluetun/internal/constants"
@@ -44,7 +43,12 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
runError, err = l.setupServer(ctx)
if err == nil {
l.backoffTime = defaultBackoffTime
l.logger.Info("ready")
l.logger.Info("ready and using DNS server at address " + settings.ServerAddress.String())
err = l.updateFiles(ctx, settings)
if err != nil {
l.logger.Warn("downloading block lists failed, skipping: " + err.Error())
}
break
}
@@ -53,11 +57,6 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
if ctx.Err() != nil {
return
}
if !errors.Is(err, errUpdateBlockLists) {
const fallback = true
l.useUnencryptedDNS(fallback)
}
l.logAndWait(ctx, err)
settings = l.GetSettings()
}
+7 -8
View File
@@ -2,24 +2,23 @@ package dns
import (
"context"
"errors"
"fmt"
"net/netip"
"github.com/qdm12/dns/v2/pkg/check"
"github.com/qdm12/dns/v2/pkg/middlewares/filter/update"
"github.com/qdm12/dns/v2/pkg/nameserver"
"github.com/qdm12/dns/v2/pkg/server"
)
var errUpdateBlockLists = errors.New("cannot update filter block lists")
func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err error) {
err = l.updateFiles(ctx)
if err != nil {
return nil, fmt.Errorf("%w: %w", errUpdateBlockLists, err)
}
settings := l.GetSettings()
var updateSettings update.Settings
updateSettings.SetRebindingProtectionExempt(settings.Blacklist.RebindingProtectionExemptHostnames)
err = l.filter.Update(updateSettings)
if err != nil {
return nil, fmt.Errorf("updating filter for rebinding protection: %w", err)
}
serverSettings, err := buildServerSettings(settings, l.filter, l.localResolvers, l.logger)
if err != nil {
+4 -15
View File
@@ -28,23 +28,12 @@ func (l *Loop) RunRestartTicker(ctx context.Context, done chan<- struct{}) {
return
case <-timer.C:
lastTick = l.timeNow()
status := l.GetStatus()
if status == constants.Running {
if err := l.updateFiles(ctx); err != nil {
l.statusManager.SetStatus(constants.Crashed)
l.logger.Error(err.Error())
l.logger.Warn("skipping DNS server restart due to failed files update")
settings := l.GetSettings()
timer.Reset(*settings.UpdatePeriod)
continue
settings := l.GetSettings()
if l.GetStatus() == constants.Running {
if err := l.updateFiles(ctx, settings); err != nil {
l.logger.Warn("updating block lists failed, skipping: " + err.Error())
}
}
_, _ = l.statusManager.ApplyStatus(ctx, constants.Stopped)
_, _ = l.statusManager.ApplyStatus(ctx, constants.Running)
settings := l.GetSettings()
timer.Reset(*settings.UpdatePeriod)
case <-l.updateTicker:
if !timer.Stop() {
+2 -4
View File
@@ -6,11 +6,10 @@ import (
"github.com/qdm12/dns/v2/pkg/blockbuilder"
"github.com/qdm12/dns/v2/pkg/middlewares/filter/update"
"github.com/qdm12/gluetun/internal/configuration/settings"
)
func (l *Loop) updateFiles(ctx context.Context) (err error) {
settings := l.GetSettings()
func (l *Loop) updateFiles(ctx context.Context, settings settings.DNS) (err error) {
l.logger.Info("downloading hostnames and IP block lists")
blacklistSettings := settings.Blacklist.ToBlockBuilderSettings(l.client)
@@ -37,7 +36,6 @@ func (l *Loop) updateFiles(ctx context.Context) (err error) {
IPPrefixes: result.BlockedIPPrefixes,
}
updateSettings.BlockHostnames(result.BlockedHostnames)
updateSettings.SetRebindingProtectionExempt(settings.Blacklist.RebindingProtectionExemptHostnames)
err = l.filter.Update(updateSettings)
if err != nil {
return fmt.Errorf("updating filter: %w", err)
+30 -60
View File
@@ -22,9 +22,7 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
if !enabled {
c.logger.Info("disabling...")
if err = c.disable(ctx); err != nil {
return fmt.Errorf("disabling firewall: %w", err)
}
c.restore(ctx)
c.enabled = false
c.logger.Info("disabled successfully")
return nil
@@ -41,64 +39,33 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
return nil
}
func (c *Config) disable(ctx context.Context) (err error) {
if err = c.clearAllRules(ctx); err != nil {
return fmt.Errorf("clearing all rules: %w", err)
}
if err = c.setIPv4AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("setting ipv4 policies: %w", err)
}
if err = c.setIPv6AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("setting ipv6 policies: %w", err)
}
const remove = true
err = c.redirectPorts(ctx, remove)
if err != nil {
return fmt.Errorf("removing port redirections: %w", err)
}
return nil
}
// To use in defered call when enabling the firewall.
func (c *Config) fallbackToDisabled(ctx context.Context) {
if ctx.Err() != nil {
return
}
if err := c.disable(ctx); err != nil {
c.logger.Error("failed reversing firewall changes: " + err.Error())
}
}
func (c *Config) enable(ctx context.Context) (err error) {
touched := false
if err = c.setIPv4AllPolicies(ctx, "DROP"); err != nil {
c.restore, err = c.impl.SaveAndRestore(ctx)
if err != nil {
return fmt.Errorf("saving firewall rules: %w", err)
}
defer func() {
if err != nil {
c.restore(context.Background())
}
}()
if err = c.impl.SetBaseChainsPolicy(ctx, "DROP"); err != nil {
return err
}
touched = true
if err = c.setIPv6AllPolicies(ctx, "DROP"); err != nil {
// Loopback traffic
if err = c.impl.AcceptInputThroughInterface(ctx, "lo"); err != nil {
return err
}
const remove = false
defer func() {
if touched && err != nil {
c.fallbackToDisabled(ctx)
}
}()
// Loopback traffic
if err = c.acceptInputThroughInterface(ctx, "lo", remove); err != nil {
return err
}
if err = c.acceptOutputThroughInterface(ctx, "lo", remove); err != nil {
if err = c.impl.AcceptOutputThroughInterface(ctx, "lo", remove); err != nil {
return err
}
if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil {
if err = c.impl.AcceptEstablishedRelatedTraffic(ctx); err != nil {
return err
}
@@ -108,7 +75,9 @@ func (c *Config) enable(ctx context.Context) (err error) {
localInterfaces := make(map[string]struct{}, len(c.localNetworks))
for _, network := range c.localNetworks {
if err := c.acceptOutputFromIPToSubnet(ctx, network.InterfaceName, network.IP, network.IPNet, remove); err != nil {
err = c.impl.AcceptOutputFromIPToSubnet(ctx,
network.InterfaceName, network.IP, network.IPNet, remove)
if err != nil {
return err
}
@@ -117,7 +86,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
continue
}
localInterfaces[network.InterfaceName] = struct{}{}
err = c.acceptIpv6MulticastOutput(ctx, network.InterfaceName, remove)
err = c.impl.AcceptIpv6MulticastOutput(ctx, network.InterfaceName)
if err != nil {
return fmt.Errorf("accepting IPv6 multicast output: %w", err)
}
@@ -130,7 +99,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
// Allows packets from any IP address to go through eth0 / local network
// to reach Gluetun.
for _, network := range c.localNetworks {
if err := c.acceptInputToSubnet(ctx, network.InterfaceName, network.IPNet, remove); err != nil {
if err := c.impl.AcceptInputToSubnet(ctx, network.InterfaceName, network.IPNet); err != nil {
return err
}
}
@@ -139,12 +108,12 @@ func (c *Config) enable(ctx context.Context) (err error) {
return err
}
err = c.redirectPorts(ctx, remove)
err = c.redirectPorts(ctx)
if err != nil {
return fmt.Errorf("redirecting ports: %w", err)
}
if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil {
if err := c.impl.RunUserPostRules(ctx, c.customRulesPath); err != nil {
return fmt.Errorf("running user defined post firewall rules: %w", err)
}
@@ -164,7 +133,7 @@ func (c *Config) allowVPNIP(ctx context.Context) (err error) {
continue
}
interfacesSeen[defaultRoute.NetInterface] = struct{}{}
err = c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove)
err = c.impl.AcceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove)
if err != nil {
return fmt.Errorf("accepting output traffic through VPN: %w", err)
}
@@ -186,7 +155,7 @@ func (c *Config) allowOutboundSubnets(ctx context.Context) (err error) {
firewallUpdated = true
const remove = false
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
err := c.impl.AcceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
defaultRoute.AssignedIP, subnet, remove)
if err != nil {
return err
@@ -204,7 +173,7 @@ func (c *Config) allowInputPorts(ctx context.Context) (err error) {
for port, netInterfaces := range c.allowedInputPorts {
for netInterface := range netInterfaces {
const remove = false
err = c.acceptInputToPort(ctx, netInterface, port, remove)
err = c.impl.AcceptInputToPort(ctx, netInterface, port, remove)
if err != nil {
return fmt.Errorf("accepting input port %d on interface %s: %w",
port, netInterface, err)
@@ -214,9 +183,10 @@ func (c *Config) allowInputPorts(ctx context.Context) (err error) {
return nil
}
func (c *Config) redirectPorts(ctx context.Context, remove bool) (err error) {
func (c *Config) redirectPorts(ctx context.Context) (err error) {
for _, portRedirection := range c.portRedirections {
err = c.redirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort,
const remove = false
err = c.impl.RedirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort,
portRedirection.destinationPort, remove)
if err != nil {
return err
+15 -21
View File
@@ -2,28 +2,28 @@ package firewall
import (
"context"
"fmt"
"net/netip"
"sync"
"github.com/qdm12/gluetun/internal/firewall/iptables"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/routing"
)
type Config struct {
runner CmdRunner
logger Logger
iptablesMutex sync.Mutex
ip6tablesMutex sync.Mutex
defaultRoutes []routing.DefaultRoute
localNetworks []routing.LocalNetwork
runner CmdRunner
logger Logger
defaultRoutes []routing.DefaultRoute
localNetworks []routing.LocalNetwork
// Fixed state
ipTables string
ip6Tables string
// Fixed
impl firewallImpl
customRulesPath string
// State
enabled bool
restore func(context.Context)
vpnConnection models.Connection
vpnIntf string
outboundSubnets []netip.Prefix
@@ -38,25 +38,19 @@ func NewConfig(ctx context.Context, logger Logger,
runner CmdRunner, defaultRoutes []routing.DefaultRoute,
localNetworks []routing.LocalNetwork,
) (config *Config, err error) {
iptables, err := checkIptablesSupport(ctx, runner, "iptables", "iptables-nft", "iptables-legacy")
impl, err := iptables.New(ctx, runner, logger)
if err != nil {
return nil, err
}
ip6tables, err := findIP6tablesSupported(ctx, runner)
if err != nil {
return nil, err
return nil, fmt.Errorf("creating iptables firewall: %w", err)
}
return &Config{
runner: runner,
logger: logger,
allowedInputPorts: make(map[uint16]map[string]struct{}),
ipTables: iptables,
ip6Tables: ip6tables,
customRulesPath: "/iptables/post-rules.txt",
// Obtained from routing
defaultRoutes: defaultRoutes,
localNetworks: localNetworks,
defaultRoutes: defaultRoutes,
localNetworks: localNetworks,
impl: impl,
customRulesPath: "/iptables/post-rules.txt",
}, nil
}
+28 -1
View File
@@ -1,6 +1,12 @@
package firewall
import "os/exec"
import (
"context"
"net/netip"
"os/exec"
"github.com/qdm12/gluetun/internal/models"
)
type CmdRunner interface {
Run(cmd *exec.Cmd) (output string, err error)
@@ -12,3 +18,24 @@ type Logger interface {
Warn(s string)
Error(s string)
}
type firewallImpl interface { //nolint:interfacebloat
SaveAndRestore(ctx context.Context) (restore func(context.Context), err error)
AcceptEstablishedRelatedTraffic(ctx context.Context) error
AcceptInputThroughInterface(ctx context.Context, intf string) error
AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error
AcceptInputToSubnet(ctx context.Context, intf string, subnet netip.Prefix) error
AcceptIpv6MulticastOutput(ctx context.Context, intf string) error
AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr,
subnet netip.Prefix, remove bool) error
AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error
AcceptOutputTrafficToVPN(ctx context.Context, intf string,
connection models.Connection, remove bool) error
RedirectPort(ctx context.Context, intf string, sourcePort,
destinationPort uint16, remove bool) error
RunUserPostRules(ctx context.Context, customRulesPath string) error
SetBaseChainsPolicy(ctx context.Context, policy string) error
TempDropOutputTCPRST(ctx context.Context, src, dst netip.AddrPort, excludeMark int) (
revert func(ctx context.Context) error, err error)
Version(ctx context.Context) (version string, err error)
}
+85
View File
@@ -0,0 +1,85 @@
package iptables
import (
"context"
"fmt"
"os/exec"
"strings"
)
// SaveAndRestore saves the current iptables and ip6tables rules and
// returns a restore function that can be called to restore the saved rules.
func (c *Config) SaveAndRestore(ctx context.Context) (restore func(context.Context), err error) {
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
return c.saveAndRestore(ctx)
}
// callers MUST always lock both the [Config] iptablesMutex and the ip6tablesMutex
// before calling this function. Note the restore function does not interact with mutexes
// so the caller must make sure the mutexes are locked when calling the restore function.
func (c *Config) saveAndRestore(ctx context.Context) (restore func(context.Context), err error) {
restoreIPv4, err := c.saveAndRestoreIPv4(ctx)
if err != nil {
return nil, err
}
restoreIPv6, err := c.saveAndRestoreIPv6(ctx)
if err != nil {
return nil, err
}
restore = func(ctx context.Context) {
restoreIPv4(ctx)
if restoreIPv6 != nil {
restoreIPv6(ctx)
}
}
return restore, nil
}
// Callers of saveAndRestoreIPv4 MUST always lock the [Config] iptablesMutex
// before calling this function.
func (c *Config) saveAndRestoreIPv4(ctx context.Context) (restore func(context.Context), err error) {
cmd := exec.CommandContext(ctx, c.ipTables+"-save") //nolint:gosec
data, err := c.runner.Run(cmd)
if err != nil {
return nil, fmt.Errorf("saving IPv4 iptables: %w", err)
}
restore = func(ctx context.Context) {
cmd := exec.CommandContext(ctx, c.ipTables+"-restore") //nolint:gosec
cmd.Stdin = strings.NewReader(data)
output, err := c.runner.Run(cmd)
if err != nil {
c.logger.Warn(fmt.Sprintf("restoring IPv4 iptables failed: %v: %s", err, output))
}
}
return restore, nil
}
// Callers of saveAndRestoreIPv6 MUST always lock the [Config] ip6tablesMutex
// before calling this function.
func (c *Config) saveAndRestoreIPv6(ctx context.Context) (restore func(context.Context), err error) {
if c.ip6Tables == "" {
return nil, nil //nolint:nilnil
}
cmd := exec.CommandContext(ctx, c.ip6Tables+"-save") //nolint:gosec
data, err := c.runner.Run(cmd)
if err != nil {
return nil, fmt.Errorf("saving IPv6 iptables: %w", err)
}
restore = func(ctx context.Context) {
cmd = exec.CommandContext(ctx, c.ip6Tables+"-restore") //nolint:gosec
cmd.Stdin = strings.NewReader(data)
output, err := c.runner.Run(cmd)
if err != nil {
c.logger.Warn(fmt.Sprintf("restoring IPv6 iptables failed: %v: %s", err, output))
}
}
return restore, nil
}
@@ -1,4 +1,4 @@
package firewall
package iptables
import (
"fmt"
@@ -1,4 +1,4 @@
package firewall
package iptables
import (
"context"
@@ -1,4 +1,4 @@
package firewall
package iptables
import (
"context"
@@ -69,8 +69,8 @@ func Test_deleteIPTablesRule(t *testing.T) {
"invalid_instruction": {
instruction: "invalid",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing iptables command: iptables command is malformed: " +
"fields count 1 is not even: \"invalid\"",
errMessage: "parsing iptables command: parsing \"invalid\": " +
"iptables command is malformed: flag \"invalid\" requires a value, but got none",
},
"list_error": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
+51
View File
@@ -0,0 +1,51 @@
package iptables
import (
"context"
"sync"
"github.com/qdm12/gluetun/internal/mod"
)
type Config struct {
runner CmdRunner
logger Logger
iptablesMutex sync.Mutex
ip6tablesMutex sync.Mutex
// Fixed state
ipTables string
ip6Tables string
nftables bool
xtMark bool
}
func New(ctx context.Context, runner CmdRunner, logger Logger) (*Config, error) {
iptables, err := checkIptablesSupport(ctx, runner, "iptables", "iptables-nft", "iptables-legacy")
if err != nil {
return nil, err
}
ip6tables, err := findIP6tablesSupported(ctx, runner)
if err != nil {
return nil, err
}
modules := map[string]bool{
"xt_mark": false,
"nf_tables": false,
}
for module := range modules {
err := mod.Probe(module)
modules[module] = err == nil
}
return &Config{
runner: runner,
logger: logger,
ipTables: iptables,
ip6Tables: ip6tables,
nftables: modules["nf_tables"],
xtMark: modules["xt_mark"],
}, nil
}
+14
View File
@@ -0,0 +1,14 @@
package iptables
import "os/exec"
type CmdRunner interface {
Run(cmd *exec.Cmd) (output string, err error)
}
type Logger interface {
Debug(s string)
Info(s string)
Warn(s string)
Error(s string)
}
@@ -1,4 +1,4 @@
package firewall
package iptables
import (
"context"
@@ -14,8 +14,8 @@ import (
func findIP6tablesSupported(ctx context.Context, runner CmdRunner) (
ip6tablesPath string, err error,
) {
ip6tablesPath, err = checkIptablesSupport(ctx, runner, "ip6tables", "ip6tables-nft", "ip6tables-legacy")
if errors.Is(err, ErrIPTablesNotSupported) {
ip6tablesPath, err = checkIptablesSupport(ctx, runner, "ip6tables", "ip6tables-legacy")
if errors.Is(err, ErrNotSupported) {
return "", nil
} else if err != nil {
return "", err
@@ -24,8 +24,23 @@ func findIP6tablesSupported(ctx context.Context, runner CmdRunner) (
}
func (c *Config) runIP6tablesInstructions(ctx context.Context, instructions []string) error {
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv6(ctx)
if err != nil {
return err
}
err = c.runIP6tablesInstructionsNoSave(ctx, instructions)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIP6tablesInstructionsNoSave(ctx context.Context, instructions []string) error {
for _, instruction := range instructions {
if err := c.runIP6tablesInstruction(ctx, instruction); err != nil {
if err := c.runIP6tablesInstructionNoSave(ctx, instruction); err != nil {
return err
}
}
@@ -33,11 +48,24 @@ func (c *Config) runIP6tablesInstructions(ctx context.Context, instructions []st
}
func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string) error {
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv6(ctx)
if err != nil {
return err
}
err = c.runIP6tablesInstructionNoSave(ctx, instruction)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction string) error {
if c.ip6Tables == "" {
return nil
}
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()
if isDeleteMatchInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ip6Tables, instruction,
@@ -53,18 +81,3 @@ func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string
}
return nil
}
var ErrPolicyNotValid = errors.New("policy is not valid")
func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP":
default:
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
}
return c.runIP6tablesInstructions(ctx, []string{
"--policy INPUT " + policy,
"--policy OUTPUT " + policy,
"--policy FORWARD " + policy,
})
}
@@ -1,4 +1,4 @@
package firewall
package iptables
import (
"context"
@@ -26,22 +26,6 @@ func appendOrDelete(remove bool) string {
return "--append"
}
// flipRule changes an append rule in a delete rule or a delete rule into an
// append rule.
func flipRule(rule string) string {
switch {
case strings.HasPrefix(rule, "-A"):
return strings.Replace(rule, "-A", "-D", 1)
case strings.HasPrefix(rule, "--append"):
return strings.Replace(rule, "--append", "-D", 1)
case strings.HasPrefix(rule, "-D"):
return strings.Replace(rule, "-D", "-A", 1)
case strings.HasPrefix(rule, "--delete"):
return strings.Replace(rule, "--delete", "-A", 1)
}
return rule
}
// Version obtains the version of the installed iptables.
func (c *Config) Version(ctx context.Context) (string, error) {
cmd := exec.CommandContext(ctx, c.ipTables, "--version") //nolint:gosec
@@ -54,12 +38,28 @@ func (c *Config) Version(ctx context.Context) (string, error) {
if len(words) < minWords {
return "", fmt.Errorf("%w: %s", ErrIPTablesVersionTooShort, output)
}
return words[1], nil
return "iptables " + words[1], nil
}
func (c *Config) runIptablesInstructions(ctx context.Context, instructions []string) error {
c.iptablesMutex.Lock()
defer c.iptablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv4(ctx)
if err != nil {
return err
}
err = c.runIptablesInstructionsNoSave(ctx, instructions)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIptablesInstructionsNoSave(ctx context.Context, instructions []string) error {
for _, instruction := range instructions {
if err := c.runIptablesInstruction(ctx, instruction); err != nil {
if err := c.runIptablesInstructionNoSave(ctx, instruction); err != nil {
return err
}
}
@@ -70,6 +70,19 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
c.iptablesMutex.Lock() // only one iptables command at once
defer c.iptablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv4(ctx)
if err != nil {
return err
}
err = c.runIptablesInstructionNoSave(ctx, instruction)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIptablesInstructionNoSave(ctx context.Context, instruction string) error {
if isDeleteMatchInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ipTables, instruction,
c.runner, c.logger)
@@ -85,42 +98,33 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
return nil
}
func (c *Config) clearAllRules(ctx context.Context) error {
return c.runMixedIptablesInstructions(ctx, []string{
"--flush", // flush all chains
"--delete-chain", // delete all chains
})
}
func (c *Config) setIPv4AllPolicies(ctx context.Context, policy string) error {
func (c *Config) SetBaseChainsPolicy(ctx context.Context, policy string) error {
policy = strings.ToUpper(policy)
switch policy {
case "ACCEPT", "DROP":
default:
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
}
return c.runIptablesInstructions(ctx, []string{
return c.runMixedIptablesInstructions(ctx, []string{
"--policy INPUT " + policy,
"--policy OUTPUT " + policy,
"--policy FORWARD " + policy,
})
}
func (c *Config) acceptInputThroughInterface(ctx context.Context, intf string, remove bool) error {
func (c *Config) AcceptInputThroughInterface(ctx context.Context, intf string) error {
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
"%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf,
))
"--append INPUT -i %s -j ACCEPT", intf))
}
func (c *Config) acceptInputToSubnet(ctx context.Context, intf string,
destination netip.Prefix, remove bool,
) error {
func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, destination netip.Prefix) error {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("%s INPUT %s -d %s -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destination.String())
instruction := fmt.Sprintf("--append INPUT %s -d %s -j ACCEPT",
interfaceFlag, destination.String())
if destination.Addr().Is4() {
return c.runIptablesInstruction(ctx, instruction)
@@ -131,20 +135,20 @@ func (c *Config) acceptInputToSubnet(ctx context.Context, intf string,
return c.runIP6tablesInstruction(ctx, instruction)
}
func (c *Config) acceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error {
func (c *Config) AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error {
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
"%s OUTPUT -o %s -j ACCEPT", appendOrDelete(remove), intf,
))
}
func (c *Config) acceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error {
func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context) error {
return c.runMixedIptablesInstructions(ctx, []string{
fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
"--append OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
"--append INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
})
}
func (c *Config) acceptOutputTrafficToVPN(ctx context.Context,
func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context,
defaultInterface string, connection models.Connection, remove bool,
) error {
protocol := connection.Protocol
@@ -162,8 +166,11 @@ func (c *Config) acceptOutputTrafficToVPN(ctx context.Context,
return c.runIP6tablesInstruction(ctx, instruction)
}
// AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet
// on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
// If remove is true, the rule is removed instead of added.
// Thanks to @npawelek.
func (c *Config) acceptOutputFromIPToSubnet(ctx context.Context,
func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
intf string, sourceIP netip.Addr, destinationSubnet netip.Prefix, remove bool,
) error {
doIPv4 := sourceIP.Is4() && destinationSubnet.Addr().Is4()
@@ -184,21 +191,24 @@ func (c *Config) acceptOutputFromIPToSubnet(ctx context.Context,
return c.runIP6tablesInstruction(ctx, instruction)
}
// NDP uses multicast address (theres no broadcast in IPv6 like ARP uses in IPv4).
func (c *Config) acceptIpv6MulticastOutput(ctx context.Context,
intf string, remove bool,
) error {
// AcceptIpv6MulticastOutput accepts outgoing traffic to the IPv6 multicast address
// ff02::1:ff00:0/104, which is used for NDP (Neighbor Discovery Protocol) to resolve
// IPv6 addresses to MAC addresses. If intf is empty, it is set to "*" which means
// all interfaces. If remove is true, the rule is removed instead of added.
func (c *Config) AcceptIpv6MulticastOutput(ctx context.Context, intf string) error {
interfaceFlag := "-o " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("%s OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT",
appendOrDelete(remove), interfaceFlag)
instruction := fmt.Sprintf("--append OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT", interfaceFlag)
return c.runIP6tablesInstruction(ctx, instruction)
}
// Used for port forwarding, with intf set to tun.
func (c *Config) acceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error {
// AcceptInputToPort accepts incoming traffic on the specified port, for both TCP and UDP
// protocols, on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
// If remove is true, the rule is removed instead of added. This is used for port forwarding, with
// intf set to the VPN tunnel interface.
func (c *Config) AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
@@ -209,8 +219,12 @@ func (c *Config) acceptInputToPort(ctx context.Context, intf string, port uint16
})
}
// Used for VPN server side port forwarding, with intf set to the VPN tunnel interface.
func (c *Config) redirectPort(ctx context.Context, intf string,
// RedirectPort redirects incoming traffic on the specified source port to the
// specified destination port, for both TCP and UDP protocols, on the interface intf.
// If intf is empty, it is set to "*" which means all interfaces. If remove is true,
// the redirection is removed instead of added. This is used for VPN server side
// port forwarding, with intf set to the VPN tunnel interface.
func (c *Config) RedirectPort(ctx context.Context, intf string,
sourcePort, destinationPort uint16, remove bool,
) (err error) {
interfaceFlag := "-i " + intf
@@ -218,7 +232,17 @@ func (c *Config) redirectPort(ctx context.Context, intf string,
interfaceFlag = ""
}
err = c.runIptablesInstructions(ctx, []string{
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
err = c.runIptablesInstructionsNoSave(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
@@ -229,11 +253,12 @@ func (c *Config) redirectPort(ctx context.Context, intf string,
appendOrDelete(remove), interfaceFlag, destinationPort),
})
if err != nil {
restore(ctx)
return fmt.Errorf("redirecting IPv4 source port %d to destination port %d on interface %s: %w",
sourcePort, destinationPort, intf, err)
}
err = c.runIP6tablesInstructions(ctx, []string{
err = c.runIP6tablesInstructionsNoSave(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
@@ -244,6 +269,7 @@ func (c *Config) redirectPort(ctx context.Context, intf string,
appendOrDelete(remove), interfaceFlag, destinationPort),
})
if err != nil {
restore(ctx) // just in case
errMessage := err.Error()
if strings.Contains(errMessage, "can't initialize ip6tables table `nat': Table does not exist") {
if !remove {
@@ -257,7 +283,7 @@ func (c *Config) redirectPort(ctx context.Context, intf string,
return nil
}
func (c *Config) runUserPostRules(ctx context.Context, filepath string, remove bool) error {
func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
file, err := os.OpenFile(filepath, os.O_RDONLY, 0)
if os.IsNotExist(err) {
return nil
@@ -273,16 +299,17 @@ func (c *Config) runUserPostRules(ctx context.Context, filepath string, remove b
return err
}
lines := strings.Split(string(b), "\n")
successfulRules := []string{}
defer func() {
// transaction-like rollback
if err == nil || ctx.Err() != nil {
return
}
for _, rule := range successfulRules {
_ = c.runIptablesInstruction(ctx, flipRule(rule))
}
}()
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
for _, line := range lines {
var ipv4 bool
var rule string
@@ -309,23 +336,18 @@ func (c *Config) runUserPostRules(ctx context.Context, filepath string, remove b
continue
}
if remove {
rule = flipRule(rule)
}
switch {
case ipv4:
err = c.runIptablesInstruction(ctx, rule)
err = c.runIptablesInstructionNoSave(ctx, rule)
case c.ip6Tables == "":
err = fmt.Errorf("running user ip6tables rule: %w", ErrNeedIP6Tables)
default: // ipv6
err = c.runIP6tablesInstruction(ctx, rule)
err = c.runIP6tablesInstructionNoSave(ctx, rule)
}
if err != nil {
restore(ctx)
return err
}
successfulRules = append(successfulRules, rule)
}
return nil
}
+49
View File
@@ -0,0 +1,49 @@
package iptables
import (
"context"
)
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
for _, instruction := range instructions {
if err := c.runMixedIptablesInstructionNoSave(ctx, instruction); err != nil {
restore(ctx)
return err
}
}
return nil
}
func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction string) error {
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
err = c.runMixedIptablesInstructionNoSave(ctx, instruction)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runMixedIptablesInstructionNoSave(ctx context.Context, instruction string) error {
if err := c.runIptablesInstructionNoSave(ctx, instruction); err != nil {
return err
}
return c.runIP6tablesInstructionNoSave(ctx, instruction)
}
@@ -1,4 +1,4 @@
package firewall
package iptables
import (
"errors"
@@ -26,10 +26,18 @@ type chainRule struct {
inputInterface string // input interface, for example "tun0" or "*""
outputInterface string // output interface, for example "eth0" or "*""
source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid.
sourcePort uint16 // Not specified if set to zero.
destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid.
destinationPort uint16 // Not specified if set to zero.
redirPorts []uint16 // Not specified if empty.
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
tcpFlags tcpFlags
mark mark
}
type mark struct {
invert bool
value uint
}
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
@@ -241,19 +249,23 @@ func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err err
}
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
for i := 0; i < len(optionalFields); i++ {
key := optionalFields[i]
switch key {
case "tcp", "udp":
i := 0
for i < len(optionalFields) {
switch optionalFields[i] {
case "udp":
i++
value := optionalFields[i]
value = strings.TrimPrefix(value, "dpt:")
const base, bitLength = 10, 16
destinationPort, err := strconv.ParseUint(value, base, bitLength)
consumed, err := parseUDPOptional(optionalFields[i:], rule)
if err != nil {
return fmt.Errorf("parsing destination port %q: %w", value, err)
return fmt.Errorf("parsing UDP optional fields: %w", err)
}
rule.destinationPort = uint16(destinationPort)
i += consumed
case "tcp":
i++
consumed, err := parseTCPOptional(optionalFields[i:], rule)
if err != nil {
return fmt.Errorf("parsing TCP optional fields: %w", err)
}
i += consumed
case "redir":
i++
switch optionalFields[i] {
@@ -264,20 +276,136 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
return fmt.Errorf("parsing redirection ports: %w", err)
}
rule.redirPorts = ports
i++
default:
return fmt.Errorf("%w: unexpected optional field: %s",
ErrChainRuleMalformed, optionalFields[i])
return fmt.Errorf("%w: unexpected %q after redir",
ErrChainRuleMalformed, optionalFields[1])
}
case "ctstate":
i++
rule.ctstate = strings.Split(optionalFields[i], ",")
i++
case "mark":
i++
mark, consumed, err := parseMark(optionalFields[i:])
if err != nil {
return fmt.Errorf("parsing mark: %w", err)
}
rule.mark = mark
i += consumed
default:
return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key)
return fmt.Errorf("%w: unexpected optional field: %s",
ErrChainRuleMalformed, optionalFields[i])
}
}
return nil
}
var errUDPOptionalUnknown = errors.New("unknown UDP optional field")
func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
for _, value := range optionalFields {
if !strings.ContainsRune(value, ':') {
// no longer a UDP-associated option
return consumed, nil
}
switch {
case strings.HasPrefix(value, "dpt:"):
rule.destinationPort, err = parseDestinationPort(value)
if err != nil {
return 0, fmt.Errorf("parsing destination port: %w", err)
}
consumed++
case strings.HasPrefix(value, "spt:"):
rule.sourcePort, err = parseSourcePort(value)
if err != nil {
return 0, fmt.Errorf("parsing source port: %w", err)
}
consumed++
default:
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
}
}
return consumed, nil
}
var errTCPOptionalUnknown = errors.New("unknown TCP optional field")
func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
for _, value := range optionalFields {
if !strings.ContainsRune(value, ':') {
// no longer a TCP-associated option
return consumed, nil
}
switch {
case strings.HasPrefix(value, "dpt:"):
rule.destinationPort, err = parseDestinationPort(value)
if err != nil {
return 0, fmt.Errorf("parsing destination port: %w", err)
}
consumed++
case strings.HasPrefix(value, "spt:"):
rule.sourcePort, err = parseSourcePort(value)
if err != nil {
return 0, fmt.Errorf("parsing source port: %w", err)
}
consumed++
case strings.HasPrefix(value, "flags:"):
rule.tcpFlags, err = parseTCPFlags(value)
if err != nil {
return 0, fmt.Errorf("parsing TCP flags: %w", err)
}
consumed++
default:
return 0, fmt.Errorf("%w: %s", errTCPOptionalUnknown, value)
}
}
return consumed, nil
}
func parseDestinationPort(value string) (port uint16, err error) {
value = strings.TrimPrefix(value, "dpt:")
return parsePort(value)
}
func parseSourcePort(value string) (port uint16, err error) {
value = strings.TrimPrefix(value, "spt:")
return parsePort(value)
}
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
func parseTCPFlags(value string) (tcpFlags, error) {
value = strings.TrimPrefix(value, "flags:")
fields := strings.Split(value, "/")
const expectedFields = 2
if len(fields) != expectedFields {
return tcpFlags{}, fmt.Errorf("%w: expected format 'flags:<mask>/<comparison>' in %q",
errTCPFlagsMalformed, value)
}
maskFlags := strings.Split(fields[0], ",")
mask := make([]tcpFlag, len(maskFlags))
var err error
for i, maskFlag := range maskFlags {
mask[i], err = parseTCPFlag(maskFlag)
if err != nil {
return tcpFlags{}, fmt.Errorf("parsing TCP mask flags: %w", err)
}
}
comparisonFlags := strings.Split(fields[1], ",")
comparison := make([]tcpFlag, len(comparisonFlags))
for i, comparisonFlag := range comparisonFlags {
comparison[i], err = parseTCPFlag(comparisonFlag)
if err != nil {
return tcpFlags{}, fmt.Errorf("parsing TCP comparison flags: %w", err)
}
}
return tcpFlags{
mask: mask,
comparison: comparison,
}, nil
}
func parsePortsCSV(s string) (ports []uint16, err error) {
if s == "" {
return nil, nil
@@ -286,16 +414,40 @@ func parsePortsCSV(s string) (ports []uint16, err error) {
fields := strings.Split(s, ",")
ports = make([]uint16, len(fields))
for i, field := range fields {
const base, bitLength = 10, 16
port, err := strconv.ParseUint(field, base, bitLength)
ports[i], err = parsePort(field)
if err != nil {
return nil, fmt.Errorf("parsing port %q: %w", field, err)
return nil, err
}
ports[i] = uint16(port)
}
return ports, nil
}
var errMarkValueMalformed = errors.New("mark value is malformed")
func parseMark(optionalFields []string) (m mark, consumed int, err error) {
switch optionalFields[consumed] {
case "match":
consumed++
if optionalFields[consumed] == "!" {
m.invert = true
consumed++
}
const base = 0 // auto-detect
const bits = 32
value, err := strconv.ParseUint(optionalFields[consumed], base, bits)
if err != nil {
return mark{}, 0, fmt.Errorf("%w: %s", errMarkValueMalformed, optionalFields[consumed])
}
m.value = uint(value)
consumed++
default:
return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s",
ErrChainRuleMalformed, optionalFields[consumed])
}
return m, consumed, nil
}
var ErrLineNumberIsZero = errors.New("line number is zero")
func parseLineNumber(s string) (n uint16, err error) {
@@ -1,4 +1,4 @@
package firewall
package iptables
import (
"net/netip"
@@ -1,3 +1,3 @@
package firewall
package iptables
//go:generate mockgen -destination=mocks_test.go -package $GOPACKAGE . CmdRunner,Logger
@@ -1,8 +1,8 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/firewall (interfaces: CmdRunner,Logger)
// Source: github.com/qdm12/gluetun/internal/firewall/iptables (interfaces: CmdRunner,Logger)
// Package firewall is a generated GoMock package.
package firewall
// Package iptables is a generated GoMock package.
package iptables
import (
exec "os/exec"
@@ -1,4 +1,4 @@
package firewall
package iptables
import (
"errors"
@@ -18,10 +18,13 @@ type iptablesInstruction struct {
inputInterface string // for example "tun0" or "" for any interface.
outputInterface string // for example "tun0" or "" for any interface.
source netip.Prefix // if not valid, then it is unspecified.
sourcePort uint16 // if zero, there is no source port
destination netip.Prefix // if not valid, then it is unspecified.
destinationPort uint16 // if zero, there is no destination port
toPorts []uint16 // if empty, there is no redirection
ctstate []string // if empty, there is no ctstate
tcpFlags tcpFlags
mark mark
}
func (i *iptablesInstruction) setDefaults() {
@@ -43,6 +46,8 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
return false
case i.destinationPort != rule.destinationPort:
return false
case i.sourcePort != rule.sourcePort:
return false
case !slices.Equal(i.toPorts, rule.redirPorts):
return false
case !slices.Equal(i.ctstate, rule.ctstate):
@@ -55,6 +60,11 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
return false
case !ipPrefixesEqual(i.destination, rule.destination):
return false
case !slices.Equal(i.tcpFlags.mask, rule.tcpFlags.mask) ||
!slices.Equal(i.tcpFlags.comparison, rule.tcpFlags.comparison):
return false
case i.mark != rule.mark:
return false
default:
return true
}
@@ -77,26 +87,29 @@ func parseIptablesInstruction(s string) (instruction iptablesInstruction, err er
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
}
fields := strings.Fields(s)
if len(fields)%2 != 0 {
return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q",
ErrIptablesCommandMalformed, len(fields), s)
}
for i := 0; i < len(fields); i += 2 {
key := fields[i]
value := fields[i+1]
err = parseInstructionFlag(key, value, &instruction)
i := 0
for i < len(fields) {
consumed, err := parseInstructionFlag(fields[i:], &instruction)
if err != nil {
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
}
i += consumed
}
instruction.setDefaults()
return instruction, nil
}
func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) {
switch key {
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
consumed, err = preCheckInstructionFields(fields)
if err != nil {
return 0, err
}
flag := fields[0]
value := fields[1]
switch flag {
case "-t", "--table":
instruction.table = value
case "-D", "--delete":
@@ -109,7 +122,19 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (
instruction.target = value
case "-p", "--protocol":
instruction.protocol = value
case "-m", "--match": // ignore match
case "-m", "--match":
consumed, err = parseMatchModule(fields, instruction)
if err != nil {
return 0, fmt.Errorf("parsing match module: %w", err)
}
case "--mark":
const base = 0 // auto-detect
const bits = 32
value, err := strconv.ParseUint(value, base, bits)
if err != nil {
return 0, fmt.Errorf("parsing mark value %q: %w", fields[2], err)
}
instruction.mark.value = uint(value)
case "-i", "--in-interface":
instruction.inputInterface = value
case "-o", "--out-interface":
@@ -117,37 +142,61 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (
case "-s", "--source":
instruction.source, err = parseIPPrefix(value)
if err != nil {
return fmt.Errorf("parsing source IP CIDR: %w", err)
return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
}
case "--sport":
instruction.sourcePort, err = parsePort(value)
if err != nil {
return 0, fmt.Errorf("parsing source port: %w", err)
}
case "-d", "--destination":
instruction.destination, err = parseIPPrefix(value)
if err != nil {
return fmt.Errorf("parsing destination IP CIDR: %w", err)
return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
}
case "--dport":
const base, bitLength = 10, 16
destinationPort, err := strconv.ParseUint(value, base, bitLength)
instruction.destinationPort, err = parsePort(value)
if err != nil {
return fmt.Errorf("parsing destination port: %w", err)
return 0, fmt.Errorf("parsing destination port: %w", err)
}
instruction.destinationPort = uint16(destinationPort)
case "--ctstate":
instruction.ctstate = strings.Split(value, ",")
case "--to-ports":
portStrings := strings.Split(value, ",")
instruction.toPorts = make([]uint16, len(portStrings))
for i, portString := range portStrings {
const base, bitLength = 10, 16
port, err := strconv.ParseUint(portString, base, bitLength)
if err != nil {
return fmt.Errorf("parsing port redirection: %w", err)
}
instruction.toPorts[i] = uint16(port)
instruction.toPorts, err = parseToPorts(value)
if err != nil {
return 0, fmt.Errorf("parsing port redirection: %w", err)
}
case "--tcp-flags":
mask, comparison := value, fields[2]
instruction.tcpFlags, err = parseTCPFlags(mask + "/" + comparison)
if err != nil {
return 0, fmt.Errorf("parsing TCP flags: %w", err)
}
default:
return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key)
return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag)
}
return consumed, nil
}
func preCheckInstructionFields(fields []string) (consumed int, err error) {
flag := fields[0]
// All flags use one value after the flag, except the following:
switch flag {
case "--tcp-flags": // -m can have 1 or 2 values
const expected = 3
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
}
return expected, nil
default:
const expected = 2
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
ErrIptablesCommandMalformed, flag)
}
return expected, nil
}
return nil
}
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
@@ -162,3 +211,52 @@ func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
}
return netip.PrefixFrom(ip, ip.BitLen()), nil
}
func parsePort(value string) (port uint16, err error) {
const base, bitLength = 10, 16
portValue, err := strconv.ParseUint(value, base, bitLength)
if err != nil {
return 0, err
}
return uint16(portValue), nil
}
func parseMatchModule(fields []string, instruction *iptablesInstruction) (
consumed int, err error,
) {
_ = fields[consumed] // -m or --match flag already detected
consumed++
switch fields[consumed] {
case "tcp", "udp":
consumed++
// for now ignore the protocol match since it's auto-loaded
// when parsing the -p/--protocol flag, and we don't need to
// parse it twice.
case "mark":
consumed++
switch fields[consumed] {
case "!":
consumed++
instruction.mark.invert = true
default:
return consumed, fmt.Errorf("%w: unsupported match mark with value: %s",
ErrIptablesCommandMalformed, fields[2])
}
default:
return 0, fmt.Errorf("%w: unknown match value: %s",
ErrIptablesCommandMalformed, fields[consumed])
}
return consumed, nil
}
func parseToPorts(value string) (toPorts []uint16, err error) {
portStrings := strings.Split(value, ",")
toPorts = make([]uint16, len(portStrings))
for i, portString := range portStrings {
toPorts[i], err = parsePort(portString)
if err != nil {
return nil, err
}
}
return toPorts, nil
}
@@ -1,4 +1,4 @@
package firewall
package iptables
import (
"net/netip"
@@ -23,7 +23,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
"uneven_fields": {
s: "-A",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"",
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
},
"unknown_key": {
s: "-x something",
@@ -1,4 +1,4 @@
package firewall
package iptables
import (
"context"
@@ -11,10 +11,10 @@ import (
)
var (
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
ErrInputPolicyNotFound = errors.New("input policy not found")
ErrIPTablesNotSupported = errors.New("no iptables supported found")
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
ErrInputPolicyNotFound = errors.New("input policy not found")
ErrNotSupported = errors.New("no iptables supported found")
)
func checkIptablesSupport(ctx context.Context, runner CmdRunner,
@@ -57,7 +57,7 @@ func checkIptablesSupport(ctx context.Context, runner CmdRunner,
}
return "", fmt.Errorf("%w: errors encountered are: %s",
ErrIPTablesNotSupported, strings.Join(allUnsupportedMessages, "; "))
ErrNotSupported, strings.Join(allUnsupportedMessages, "; "))
}
func testIptablesPath(ctx context.Context, path string,
@@ -1,4 +1,4 @@
package firewall
package iptables
import (
"context"
@@ -101,7 +101,7 @@ func Test_checkIptablesSupport(t *testing.T) {
return runner
},
iptablesPathsToTry: []string{"path1", "path2"},
errSentinel: ErrIPTablesNotSupported,
errSentinel: ErrNotSupported,
errMessage: "no iptables supported found: " +
"errors encountered are: " +
"path1: output 1 (exit code 4); " +
+96
View File
@@ -0,0 +1,96 @@
package iptables
import (
"context"
"errors"
"fmt"
"net/netip"
)
type tcpFlags struct {
mask []tcpFlag
comparison []tcpFlag
}
type tcpFlag uint8
const (
tcpFlagFIN tcpFlag = 1 << iota
tcpFlagSYN
tcpFlagRST
tcpFlagPSH
tcpFlagACK
tcpFlagURG
tcpFlagECE
tcpFlagCWR
)
func (f tcpFlag) String() string {
switch f {
case tcpFlagFIN:
return "FIN"
case tcpFlagSYN:
return "SYN"
case tcpFlagRST:
return "RST"
case tcpFlagPSH:
return "PSH"
case tcpFlagACK:
return "ACK"
case tcpFlagURG:
return "URG"
case tcpFlagECE:
return "ECE"
case tcpFlagCWR:
return "CWR"
default:
panic(fmt.Sprintf("%s: %d", errTCPFlagUnknown, f))
}
}
var errTCPFlagUnknown = errors.New("unknown TCP flag")
func parseTCPFlag(s string) (tcpFlag, error) {
allFlags := []tcpFlag{
tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH,
tcpFlagACK, tcpFlagURG, tcpFlagECE, tcpFlagCWR,
}
for _, flag := range allFlags {
if s == fmt.Sprintf("%#02x", uint8(flag)) || s == flag.String() {
return flag, nil
}
}
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
}
var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so")
// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port,
// for any TCP packets not marked with the excludeMark given.
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
// by sending a TCP RST packet, although we want to handle the connection manually.
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
src, dst netip.AddrPort, excludeMark int) (
revert func(ctx context.Context) error, err error,
) {
if !c.nftables && !c.xtMark {
return nil, fmt.Errorf("%w", ErrMarkMatchModuleMissing)
}
const template = "%s OUTPUT -p tcp -s %s --sport %d -d %s --dport %d " +
"--tcp-flags RST RST -m mark ! --mark %d -j DROP" //nolint:dupword
instruction := fmt.Sprintf(template, "--append", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark)
revertInstruction := fmt.Sprintf(template, "--delete", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark)
run := c.runIptablesInstruction
if dst.Addr().Is6() {
run = c.runIP6tablesInstruction
}
revert = func(ctx context.Context) error {
return run(ctx, revertInstruction)
}
err = run(ctx, instruction)
if err != nil {
return nil, fmt.Errorf("running instruction: %w", err)
}
return revert, nil
}
-21
View File
@@ -1,21 +0,0 @@
package firewall
import (
"context"
)
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
for _, instruction := range instructions {
if err := c.runMixedIptablesInstruction(ctx, instruction); err != nil {
return err
}
}
return nil
}
func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction string) error {
if err := c.runIptablesInstruction(ctx, instruction); err != nil {
return err
}
return c.runIP6tablesInstruction(ctx, instruction)
}
+99
View File
@@ -0,0 +1,99 @@
package nftables
import (
"context"
"fmt"
"github.com/google/nftables"
)
// SaveAndRestore saves the current nftables tree and returns a restore function that
// can be called to restore the saved tree.
func (f *Firewall) SaveAndRestore(_ context.Context) (restore func(context.Context), err error) {
f.mutex.Lock()
defer f.mutex.Unlock()
conn, err := nftables.New()
if err != nil {
return nil, fmt.Errorf("creating nftables connection: %w", err)
}
tables, err := saveTables(conn)
if err != nil {
return nil, fmt.Errorf("saving nftables state: %w", err)
}
return func(_ context.Context) {
conn, err := nftables.New()
if err != nil {
f.logger.Warnf("creating nftables connection for restore: %s", err)
return
}
err = restoreTables(conn, tables)
if err != nil {
f.logger.Warnf("restoring nftables state: %s", err)
}
}, nil
}
type savedTable struct {
table *nftables.Table
chains []savedChain
}
type savedChain struct {
chain *nftables.Chain
rules []*nftables.Rule
}
func saveTables(conn *nftables.Conn) ([]savedTable, error) {
tables, err := conn.ListTables()
if err != nil {
return nil, err
}
savedTables := make([]savedTable, len(tables))
for i, table := range tables {
savedTables[i].table = table
chains, err := conn.ListChains()
if err != nil {
return nil, err
}
for _, chain := range chains {
if chain.Table.Name != table.Name ||
chain.Table.Family != table.Family {
continue
}
rules, err := conn.GetRules(table, chain)
if err != nil {
return nil, fmt.Errorf("getting rules for chain %s in table %s: %w", chain.Name, table.Name, err)
}
savedChain := savedChain{chain: chain, rules: rules}
savedTables[i].chains = append(savedTables[i].chains, savedChain)
}
}
return savedTables, nil
}
func restoreTables(conn *nftables.Conn, savedTables []savedTable) error {
conn.FlushRuleset()
for _, savedTable := range savedTables {
table := conn.AddTable(savedTable.table)
for _, savedChain := range savedTable.chains {
// Make the [nftables.Chain.Table] points to the new [nftables.Table]
// created in this connection.
savedChain.chain.Table = table
savedChain.chain = conn.AddChain(savedChain.chain)
for _, rule := range savedChain.rules {
rule.Table = table
rule.Chain = savedChain.chain
conn.AddRule(rule)
}
}
}
return conn.Flush()
}
+50
View File
@@ -0,0 +1,50 @@
package nftables
import (
"context"
"errors"
"fmt"
"strings"
"github.com/google/nftables"
)
var ErrPolicyUnknown = errors.New("unknown policy")
// SetBaseChainsPolicy sets the policy of all the base chains (INPUT, FORWARD, or OUTPUT)
// for the filter table to the given policy (accept or drop).
func (f *Firewall) SetBaseChainsPolicy(_ context.Context, policy string) error {
f.mutex.Lock()
defer f.mutex.Unlock()
var chainPolicy nftables.ChainPolicy
switch strings.ToLower(policy) {
case "accept":
chainPolicy = nftables.ChainPolicyAccept
case "drop":
chainPolicy = nftables.ChainPolicyDrop
default:
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
}
conn, err := nftables.New()
if err != nil {
return fmt.Errorf("creating nftables connection: %w", err)
}
_, inputChain, forwardChain, outputChain := setupFilterWithBaseChains(conn)
inputChain.Policy = &chainPolicy
forwardChain.Policy = &chainPolicy
outputChain.Policy = &chainPolicy
conn.AddChain(inputChain)
conn.AddChain(forwardChain)
conn.AddChain(outputChain)
err = conn.Flush()
if err != nil {
return fmt.Errorf("flushing nftables changes: %w", err)
}
return nil
}
+61
View File
@@ -0,0 +1,61 @@
package nftables
import (
"context"
"fmt"
"github.com/google/nftables"
"github.com/google/nftables/expr"
)
func (f *Firewall) AcceptEstablishedRelatedTraffic(_ context.Context) error {
f.mutex.Lock()
defer f.mutex.Unlock()
conn, err := nftables.New()
if err != nil {
return fmt.Errorf("creating nftables connection: %w", err)
}
table, inputChain, _, outputChain := setupFilterWithBaseChains(conn)
ctStateExprs := []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4, //nolint:mnd
Mask: []byte{byte(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), 0x00, 0x00, 0x00},
Xor: []byte{0x00, 0x00, 0x00, 0x00},
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0x00, 0x00, 0x00, 0x00},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
conn.AddRule(&nftables.Rule{
Table: table,
Chain: inputChain,
Exprs: ctStateExprs,
})
conn.AddRule(&nftables.Rule{
Table: table,
Chain: outputChain,
Exprs: ctStateExprs,
})
if err := conn.Flush(); err != nil {
return fmt.Errorf("flushing: %w", err)
}
return nil
}
+27
View File
@@ -0,0 +1,27 @@
package nftables
import (
"errors"
"fmt"
"reflect"
"github.com/google/nftables"
)
var errRuleToDeleteNotFound = errors.New("rule not found for removal")
func (f *Firewall) deleteRule(conn *nftables.Conn, rule *nftables.Rule) error {
for i, existing := range f.rules {
if !reflect.DeepEqual(existing, rule) {
continue
}
err := conn.DelRule(existing)
if err != nil {
return fmt.Errorf("deleting rule: %w", err)
}
f.rules[i], f.rules[len(f.rules)-1] = f.rules[len(f.rules)-1], f.rules[i]
f.rules = f.rules[:len(f.rules)-1]
return nil
}
return fmt.Errorf("%w: %#v", errRuleToDeleteNotFound, rule)
}
+38
View File
@@ -0,0 +1,38 @@
package nftables
import "github.com/google/nftables"
func setupFilterWithBaseChains(conn *nftables.Conn) (table *nftables.Table,
inputChain, forwardChain, outputChain *nftables.Chain,
) {
table = conn.AddTable(&nftables.Table{
Family: nftables.TableFamilyINet,
Name: "filter",
})
inputChain = conn.AddChain(&nftables.Chain{
Name: "input",
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
})
forwardChain = conn.AddChain(&nftables.Chain{
Name: "forward",
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
})
outputChain = conn.AddChain(&nftables.Chain{
Name: "output",
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityFilter,
})
return table, inputChain, forwardChain, outputChain
}
+22
View File
@@ -0,0 +1,22 @@
package nftables
import (
"sync"
"github.com/google/nftables"
)
type Firewall struct {
logger Logger
// rules are only rules added and tracked for later removal.
// Not all rules added are tracked for removal.
rules []*nftables.Rule
mutex sync.Mutex
}
func New(logger Logger) *Firewall {
return &Firewall{
logger: logger,
}
}
+170
View File
@@ -0,0 +1,170 @@
package nftables
import (
"context"
"fmt"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/expr"
)
func (f *Firewall) AcceptInputThroughInterface(_ context.Context, intf string) error {
f.mutex.Lock()
defer f.mutex.Unlock()
conn, err := nftables.New()
if err != nil {
return fmt.Errorf("creating nftables connection: %w", err)
}
table, inputChain, _, _ := setupFilterWithBaseChains(conn)
rule := &nftables.Rule{
Table: table,
Chain: inputChain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte(intf + "\x00"),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
}
conn.AddRule(rule)
err = conn.Flush()
if err != nil {
return fmt.Errorf("flushing: %w", err)
}
return nil
}
// AcceptInputToPort accepts incoming traffic on the specified port, for both TCP and UDP
// protocols, on the interface intf. If intf is empty or "*", the interface is not used as a filter.
// If remove is true, the rule is removed instead of added. This is used for port forwarding, with
// intf set to the VPN tunnel interface.
func (f *Firewall) AcceptInputToPort(_ context.Context, intf string, port uint16, remove bool) error {
f.mutex.Lock()
defer f.mutex.Unlock()
conn, err := nftables.New()
if err != nil {
return fmt.Errorf("creating nftables connection: %w", err)
}
table, inputChain, _, _ := setupFilterWithBaseChains(conn)
portBytes := []byte{byte(port >> 8), byte(port)} //nolint:mnd
const tcp, udp uint8 = 6, 17
protocols := []uint8{tcp, udp}
for _, protocol := range protocols {
const maxExprsLen = 7
exprs := make([]expr.Any, 0, maxExprsLen)
if intf != "" && intf != "*" {
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(intf + "\x00")},
)
}
exprs = append(exprs,
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 9, Len: 1}, //nolint:mnd
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protocol}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2}, //nolint:mnd
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: portBytes},
&expr.Verdict{Kind: expr.VerdictAccept},
)
rule := &nftables.Rule{
Table: table,
Chain: inputChain,
Exprs: exprs,
}
if !remove {
conn.AddRule(rule)
f.rules = append(f.rules, rule)
continue
}
err = f.deleteRule(conn, rule)
if err != nil {
return fmt.Errorf("deleting rule: %w", err)
}
}
err = conn.Flush()
if err != nil {
f.rules = f.rules[:len(f.rules)-len(protocols)]
return fmt.Errorf("flushing: %w", err)
}
return nil
}
func (f *Firewall) AcceptInputToSubnet(_ context.Context, intf string, subnet netip.Prefix) error {
f.mutex.Lock()
defer f.mutex.Unlock()
conn, err := nftables.New()
if err != nil {
return fmt.Errorf("creating nftables connection: %w", err)
}
table, inputChain, _, _ := setupFilterWithBaseChains(conn)
const maxExprsLen = 5
exprs := make([]expr.Any, 0, maxExprsLen)
if intf != "" && intf != "*" {
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(intf + "\x00")},
)
}
var payloadOffset uint32
if subnet.Addr().Is4() {
payloadOffset = 16
} else {
payloadOffset = 24
}
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: payloadOffset,
Len: uint32(len(subnet.Addr().AsSlice())), //nolint:gosec
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: subnet.Addr().AsSlice(),
},
&expr.Verdict{Kind: expr.VerdictAccept},
)
rule := &nftables.Rule{
Table: table,
Chain: inputChain,
Exprs: exprs,
}
conn.AddRule(rule)
err = conn.Flush()
if err != nil {
return fmt.Errorf("flushing: %w", err)
}
return nil
}
+5
View File
@@ -0,0 +1,5 @@
package nftables
type Logger interface {
Warnf(format string, args ...any)
}
+78
View File
@@ -0,0 +1,78 @@
package nftables
import (
"context"
"fmt"
"github.com/google/nftables"
"github.com/google/nftables/expr"
)
func (f *Firewall) AcceptIpv6MulticastOutput(_ context.Context, intf string) error {
f.mutex.Lock()
defer f.mutex.Unlock()
conn, err := nftables.New()
if err != nil {
return fmt.Errorf("creating nftables connection: %w", err)
}
table, _, _, outputChain := setupFilterWithBaseChains(conn)
const maxExprsLen = 6
exprs := make([]expr.Any, 0, maxExprsLen)
if intf != "" && intf != "*" {
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(intf + "\x00")},
)
}
// ff02::1:ff00:0/104 mask is 13 bytes of 0xff
mask := []byte{
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00,
} //nolint:mnd
addr := []byte{
0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x01, 0xff, 0x00, 0x00, 0x00,
} //nolint:mnd
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 24, // IPv6 Destination Address offset //nolint:mnd
Len: 16, //nolint:mnd
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 16, //nolint:mnd
Mask: mask,
Xor: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //nolint:mnd
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: addr,
},
&expr.Verdict{Kind: expr.VerdictAccept},
)
rule := &nftables.Rule{
Table: table,
Chain: outputChain,
Exprs: exprs,
}
conn.AddRule(rule)
err = conn.Flush()
if err != nil {
return fmt.Errorf("flushing: %w", err)
}
return nil
}
+12
View File
@@ -0,0 +1,12 @@
package nftables
import "github.com/google/nftables"
func IsSupported() bool {
conn, err := nftables.New()
if err != nil {
return false
}
_, err = conn.ListTable("filter")
return err == nil
}
+2 -2
View File
@@ -48,7 +48,7 @@ func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []netip.Pref
}
firewallUpdated = true
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
err := c.impl.AcceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
defaultRoute.AssignedIP, subNet, remove)
if err != nil {
c.logger.Error("cannot remove outdated outbound subnet: " + err.Error())
@@ -77,7 +77,7 @@ func (c *Config) addOutboundSubnets(ctx context.Context, subnets []netip.Prefix)
}
firewallUpdated = true
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
err := c.impl.AcceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
defaultRoute.AssignedIP, subnet, remove)
if err != nil {
return err
+2 -2
View File
@@ -35,7 +35,7 @@ func (c *Config) SetAllowedPort(ctx context.Context, port uint16, intf string) (
c.logger.Info("setting allowed input port " + fmt.Sprint(port) + " through interface " + intf + "...")
const remove = false
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
if err := c.impl.AcceptInputToPort(ctx, intf, port, remove); err != nil {
return fmt.Errorf("allowing input to port %d through interface %s: %w",
port, intf, err)
}
@@ -68,7 +68,7 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
const remove = true
for netInterface := range interfacesSet {
err := c.acceptInputToPort(ctx, netInterface, port, remove)
err := c.impl.AcceptInputToPort(ctx, netInterface, port, remove)
if err != nil {
return fmt.Errorf("removing allowed port %d on interface %s: %w",
port, netInterface, err)
+2 -2
View File
@@ -50,7 +50,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string, sourcePort,
return nil
case conflict != nil:
const remove = true
err = c.redirectPort(ctx, conflict.interfaceName, conflict.sourcePort,
err = c.impl.RedirectPort(ctx, conflict.interfaceName, conflict.sourcePort,
conflict.destinationPort, remove)
if err != nil {
return fmt.Errorf("removing conflicting redirection: %w", err)
@@ -60,7 +60,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string, sourcePort,
}
const remove = false
err = c.redirectPort(ctx, intf, sourcePort, destinationPort, remove)
err = c.impl.RedirectPort(ctx, intf, sourcePort, destinationPort, remove)
if err != nil {
return fmt.Errorf("redirecting port: %w", err)
}
+4 -4
View File
@@ -28,7 +28,7 @@ func (c *Config) SetVPNConnection(ctx context.Context,
remove := true
if c.vpnConnection.IP.IsValid() {
for _, defaultRoute := range c.defaultRoutes {
if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove); err != nil {
if err := c.impl.AcceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove); err != nil {
c.logger.Error("cannot remove outdated VPN connection rule: " + err.Error())
}
}
@@ -36,7 +36,7 @@ func (c *Config) SetVPNConnection(ctx context.Context,
c.vpnConnection = models.Connection{}
if c.vpnIntf != "" {
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
if err = c.impl.AcceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
c.logger.Error("cannot remove outdated VPN interface rule: " + err.Error())
}
}
@@ -45,13 +45,13 @@ func (c *Config) SetVPNConnection(ctx context.Context,
remove = false
for _, defaultRoute := range c.defaultRoutes {
if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, connection, remove); err != nil {
if err := c.impl.AcceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, connection, remove); err != nil {
return fmt.Errorf("allowing output traffic through VPN connection: %w", err)
}
}
c.vpnConnection = connection
if err = c.acceptOutputThroughInterface(ctx, vpnIntf, remove); err != nil {
if err = c.impl.AcceptOutputThroughInterface(ctx, vpnIntf, remove); err != nil {
return fmt.Errorf("accepting output traffic through interface %s: %w", vpnIntf, err)
}
c.vpnIntf = vpnIntf
+21
View File
@@ -0,0 +1,21 @@
package firewall
import (
"context"
"net/netip"
)
func (c *Config) Version(ctx context.Context) (version string, err error) {
return c.impl.Version(ctx)
}
// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port,
// for any TCP packets not marked with the excludeMark given.
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
// by sending a TCP RST packet, although we want to handle the connection manually.
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
src, dst netip.AddrPort, excludeMark int) (
revert func(ctx context.Context) error, err error,
) {
return c.impl.TempDropOutputTCPRST(ctx, src, dst, excludeMark)
}
+47 -13
View File
@@ -23,6 +23,7 @@ type Checker struct {
logger Logger
icmpTargetIPs []netip.Addr
smallCheckType string
startupOnFail bool
configMutex sync.Mutex
icmpNotPermitted *bool
@@ -45,26 +46,43 @@ func NewChecker(logger Logger) *Checker {
}
}
// SetConfig sets the TCP+TLS dial addresses, the ICMP echo IP address
// to target and the desired small check type (dns or icmp).
// SetConfig sets the following:
// - TCP+TLS dial addresses
// - ICMP echo IP addresses to target
// - the desired small check type (dns or icmp)
// - whether to startup the periodic checks if the startup check fails.
// This function MUST be called before calling [Checker.Start].
func (c *Checker) SetConfig(tlsDialAddrs []string, icmpTargets []netip.Addr,
smallCheckType string,
smallCheckType string, startupOnFail bool,
) {
c.configMutex.Lock()
defer c.configMutex.Unlock()
c.tlsDialAddrs = tlsDialAddrs
c.icmpTargetIPs = icmpTargets
c.smallCheckType = smallCheckType
c.startupOnFail = startupOnFail
}
// Start starts the checker by first running a blocking 6s-timed TCP+TLS check,
// and, on success, starts the periodic checks in a separate goroutine:
// Start starts the [Checker] which behaves differently according to its
// internal field startupOnFail, which is set by calling [Checker.SetConfig].
//
// By default, startupOnFail should be false and the behavior is as follows:
// A blocking 6s-timed TCP+TLS check is performed first. If it fails,
// an error is returned and the [Checker] is not started.
// On success, it starts the periodic checks in a separate goroutine, returning
// the runError error channel and a nil error.
//
// If startupOnFail is true, the behavior is as follows:
// A blocking 6s-timed TCP+TLS check is performed first. If it fails,
// the error is sent to the runError channel, but no error is returned
// and the [Checker] continues to start the periodic checks in a separate goroutine, returning
// the runError error channel and a nil error.
//
// The periodic checks consist in:
// - a "small" ICMP echo check every minute
// - a "full" TCP+TLS check every 5 minutes
// It returns a channel `runError` that receives an error (nil or not) when a periodic check is performed.
// It returns an error if the initial TCP+TLS check fails.
// The Checker has to be ultimately stopped by calling [Checker.Stop].
//
// The [Checker] has to be ultimately stopped by calling [Checker.Stop].
func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error) {
if len(c.tlsDialAddrs) == 0 || len(c.icmpTargetIPs) == 0 || c.smallCheckType == "" {
panic("call Checker.SetConfig with non empty values before Checker.Start")
@@ -76,9 +94,19 @@ func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error)
}
c.echoer.Reset()
// runErrorCh MUST be buffered in the case startupOnFail is true, and
// a startup error was encountered, to avoid blocking the startup
// goroutine when sending the error, especially since the caller may
// not be ready to receive from the channel yet.
runErrorCh := make(chan error, 1)
runError = runErrorCh
err = c.startupCheck(ctx)
if err != nil {
return nil, fmt.Errorf("startup check: %w", err)
err = fmt.Errorf("startup check: %w", err)
if !c.startupOnFail {
return nil, err
}
runErrorCh <- err
}
ready := make(chan struct{})
@@ -90,8 +118,6 @@ func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error)
smallCheckTimer := time.NewTimer(smallCheckPeriod)
const fullCheckPeriod = 5 * time.Minute
fullCheckTimer := time.NewTimer(fullCheckPeriod)
runErrorCh := make(chan error)
runError = runErrorCh
go func() {
defer close(done)
close(ready)
@@ -106,14 +132,22 @@ func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error)
if err != nil {
err = fmt.Errorf("small periodic check: %w", err)
}
runErrorCh <- err
select {
case <-ctx.Done():
continue
case runErrorCh <- err:
}
smallCheckTimer.Reset(smallCheckPeriod)
case <-fullCheckTimer.C:
err := c.fullPeriodicCheck(ctx)
if err != nil {
err = fmt.Errorf("full periodic check: %w", err)
}
runErrorCh <- err
select {
case <-ctx.Done():
continue
case runErrorCh <- err:
}
fullCheckTimer.Reset(fullCheckPeriod)
}
}
+4
View File
@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"time"
)
var ErrHTTPStatusNotOK = errors.New("HTTP response status is not OK")
@@ -21,6 +22,9 @@ func NewClient(httpClient *http.Client) *Client {
}
func (c *Client) Check(ctx context.Context, url string) error {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return err
-1
View File
@@ -13,7 +13,6 @@ func (h *handler) isAuthorized(responseWriter http.ResponseWriter, request *http
}
basicAuth := request.Header.Get("Proxy-Authorization")
if basicAuth == "" {
h.logger.Info("Proxy-Authorization header not found from " + request.RemoteAddr)
responseWriter.Header().Set("Proxy-Authenticate", `Basic realm="Access to Gluetun over HTTP"`)
responseWriter.WriteHeader(http.StatusProxyAuthRequired)
return false
+33
View File
@@ -0,0 +1,33 @@
package mod
import (
"bufio"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
)
var errBuiltinModuleNotFound = errors.New("builtin module not found")
func checkModulesBuiltin(modulesPath, moduleName string) error {
f, err := os.Open(filepath.Join(modulesPath, "modules.builtin"))
if err != nil {
return err
}
defer f.Close()
moduleName = strings.TrimSuffix(moduleName, ".ko")
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
line = strings.TrimSuffix(line, ".ko")
if strings.HasSuffix(line, "/"+moduleName) {
return nil
}
}
return fmt.Errorf("%w: %s", errBuiltinModuleNotFound, moduleName)
}
+132
View File
@@ -0,0 +1,132 @@
package mod
import (
"bufio"
"compress/gzip"
"errors"
"fmt"
"os"
"strings"
)
var (
errModuleNameUnknown = errors.New("unknown module name")
errKernelFeatureIsModule = errors.New("kernel feature is a module, not built-in")
errKernelFeatureNotSet = errors.New("kernel feature not set")
errKernelFeatureNotFound = errors.New("kernel feature not found")
)
// checkProcConfig checks /proc/config.gz for a the kernel feature corresponding
// to the given module name. If the kernel feature is found and set to "y", it returns nil.
// If the kernel feature is found and set to "m", it returns an error indicating that the kernel
// feature is a module, not built-in.
// If the kernel feature is found and not set, it returns an error indicating that the kernel
// feature is not set. If the kernel feature is not found, it returns an error indicating that the kernel
// feature is not found.
func checkProcConfig(moduleName string) error {
f, err := os.Open("/proc/config.gz")
if err != nil {
return err
}
defer f.Close()
gz, err := gzip.NewReader(f)
if err != nil {
return fmt.Errorf("creating gzip reader: %w", err)
}
defer gz.Close()
// If any group of kernel features is satisfied, then the module is considered supported.
kernelFeatureGroups, ok := moduleNameToKernelFeatureGroups(moduleName)
if !ok {
return fmt.Errorf("%w: %s", errModuleNameUnknown, moduleName)
}
groups := make([]map[string]bool, len(kernelFeatureGroups))
for i, group := range kernelFeatureGroups {
featureToOK := make(map[string]bool)
for _, feature := range group {
featureToOK[feature] = false
}
groups[i] = featureToOK
}
scanner := bufio.NewScanner(gz)
for scanner.Scan() {
line := scanner.Text()
for _, featureToOK := range groups {
for name, ok := range featureToOK {
switch {
case ok:
case strings.HasPrefix(line, name+"=m"):
return fmt.Errorf("%w: %s", errKernelFeatureIsModule, name)
case strings.HasPrefix(line, name+"=y"):
featureToOK[name] = true
if allFeaturesOK(featureToOK) {
return nil
}
case strings.HasPrefix(line, "# "+name+" is not set"):
return fmt.Errorf("%w: %s", errKernelFeatureNotSet, name)
}
}
}
}
return fmt.Errorf("%w: for module name %s", errKernelFeatureNotFound, moduleName)
}
func moduleNameToKernelFeatureGroups(moduleName string) (featureGroups [][]string, ok bool) {
moduleMap := map[string][][]string{
"nf_tables": {{"CONFIG_NF_TABLES"}},
// Netfilter Matches
"xt_conntrack": {{"CONFIG_NETFILTER_XT_MATCH_CONNTRACK"}},
"xt_connmark": {
{"CONFIG_NETFILTER_XT_CONNMARK"},
{"CONFIG_NETFILTER_XT_MATCH_CONNMARK", "CONFIG_NETFILTER_XT_TARGET_CONNMARK"},
},
"xt_mark": {
{"CONFIG_NETFILTER_XT_MARK"},
{"CONFIG_NETFILTER_XT_MATCH_MARK", "CONFIG_NETFILTER_XT_TARGET_MARK"},
},
"nf_conntrack_netlink": {{"CONFIG_NF_CT_NETLINK"}},
"nf_reject_ipv4": {{"CONFIG_NF_REJECT_IPV4"}},
// Common Netfilter Targets
"xt_log": {{"CONFIG_NETFILTER_XT_TARGET_LOG"}},
"xt_reject": {
{"CONFIG_IP_NF_TARGET_REJECT", "CONFIG_NF_REJECT_IPV4"},
{"CONFIG_NETFILTER_XT_TARGET_REJECT", "CONFIG_NF_REJECT_IPV4"},
},
"xt_masquerade": {{"CONFIG_NETFILTER_XT_TARGET_MASQUERADE"}},
// Additional Netfilter Matches
"xt_addrtype": {{"CONFIG_NETFILTER_XT_MATCH_ADDRTYPE"}},
"xt_comment": {{"CONFIG_NETFILTER_XT_MATCH_COMMENT"}},
"xt_multiport": {{"CONFIG_NETFILTER_XT_MATCH_MULTIPORT"}},
"xt_state": {{"CONFIG_NETFILTER_XT_MATCH_STATE"}},
"xt_tcpudp": {{"CONFIG_NETFILTER_XT_MATCH_TCPUDP"}},
// Tunneling and Virtualization
"tun": {{"CONFIG_TUN"}},
"bridge": {{"CONFIG_BRIDGE"}},
"veth": {{"CONFIG_VETH"}},
"vxlan": {{"CONFIG_VXLAN"}},
"wireguard": {{"CONFIG_WIREGUARD"}},
// Filesystems
"overlay": {{"CONFIG_OVERLAY_FS"}},
"fuse": {{"CONFIG_FUSE_FS"}},
}
featureGroups, ok = moduleMap[strings.ToLower(moduleName)]
return featureGroups, ok
}
func allFeaturesOK(featureToOK map[string]bool) bool {
for _, ok := range featureToOK {
if !ok {
return false
}
}
return true
}
@@ -1,3 +1,5 @@
//go:build !windows
package mod
import (
@@ -28,36 +30,7 @@ type moduleInfo struct {
var ErrModulesDirectoryNotFound = errors.New("modules directory not found")
func getModulesInfo() (modulesInfo map[string]moduleInfo, err error) {
var utsName unix.Utsname
err = unix.Uname(&utsName)
if err != nil {
return nil, fmt.Errorf("getting unix uname release: %w", err)
}
release := unix.ByteSliceToString(utsName.Release[:])
release = strings.TrimSpace(release)
modulePaths := []string{
filepath.Join("/lib/modules", release),
filepath.Join("/usr/lib/modules", release),
}
var modulesPath string
var found bool
for _, modulesPath = range modulePaths {
info, err := os.Stat(modulesPath)
if err == nil && info.IsDir() {
found = true
break
}
}
if !found {
return nil, fmt.Errorf("%w: %s are not valid existing directories"+
"; have you bind mounted the /lib/modules directory?",
ErrModulesDirectoryNotFound, strings.Join(modulePaths, ", "))
}
func getModulesInfo(modulesPath string) (modulesInfo map[string]moduleInfo, err error) {
dependencyFilepath := filepath.Join(modulesPath, "modules.dep")
dependencyFile, err := os.Open(dependencyFilepath)
if err != nil {
@@ -109,6 +82,39 @@ func getModulesInfo() (modulesInfo map[string]moduleInfo, err error) {
return modulesInfo, nil
}
func getModulesPath() (string, error) {
release, err := getReleaseName()
if err != nil {
return "", fmt.Errorf("getting release name: %w", err)
}
modulePaths := []string{
filepath.Join("/lib/modules", release),
filepath.Join("/usr/lib/modules", release),
}
for _, modulesPath := range modulePaths {
info, err := os.Stat(modulesPath)
if err == nil && info.IsDir() {
return modulesPath, nil
}
}
return "", fmt.Errorf("%w: %s are not valid existing directories"+
"; have you bind mounted the /lib/modules directory?",
ErrModulesDirectoryNotFound, strings.Join(modulePaths, ", "))
}
func getReleaseName() (release string, err error) {
var utsName unix.Utsname
err = unix.Uname(&utsName)
if err != nil {
return "", fmt.Errorf("getting unix uname release: %w", err)
}
release = unix.ByteSliceToString(utsName.Release[:])
release = strings.TrimSpace(release)
return release, nil
}
func getBuiltinModules(modulesDirPath string, modulesInfo map[string]moduleInfo) error {
file, err := os.Open(filepath.Join(modulesDirPath, "modules.builtin"))
if err != nil {
-37
View File
@@ -1,37 +0,0 @@
package mod
import (
"fmt"
)
// Probe loads the given kernel module and its dependencies.
func Probe(moduleName string) error {
modulesInfo, err := getModulesInfo()
if err != nil {
return fmt.Errorf("getting modules information: %w", err)
}
modulePath, err := findModulePath(moduleName, modulesInfo)
if err != nil {
return fmt.Errorf("finding module path: %w", err)
}
info := modulesInfo[modulePath]
if info.state == builtin || info.state == loaded {
return nil
}
info.state = loading
for _, dependencyModulePath := range info.dependencyPaths {
err = initDependencies(dependencyModulePath, modulesInfo)
if err != nil {
return fmt.Errorf("init dependencies: %w", err)
}
}
err = initModule(modulePath)
if err != nil {
return fmt.Errorf("init module: %w", err)
}
return nil
}
+74
View File
@@ -0,0 +1,74 @@
package mod
import (
"errors"
"fmt"
)
// Probe is a expanded version of modprobe, in which it checks if the Kernel
// built-in features contain the given module name.
// It first tries to locate the modules directory in [getModulesPath].
// If it fails (like on WSL), it then only checks for the kernel feature
// in /proc/config.gz with [checkProcConfig].
// Otherwise, it first checks if the modules directory modules.builtin
// file contains the given module name in [checkModulesBuiltin].
// If the module is not found, it then runs the classic [modProbe] behavior,
// trying to load the module in the kernel.
// If this fails, it does one final try running [checkProcConfig].
func Probe(moduleName string) error {
modulesPath, err := getModulesPath()
if err != nil {
if errors.Is(err, ErrModulesDirectoryNotFound) {
err = checkProcConfig(moduleName)
if err != nil {
return fmt.Errorf("checking /proc/config.gz: %w", err)
}
return nil
}
return fmt.Errorf("getting modules path: %w", err)
}
err = checkModulesBuiltin(modulesPath, moduleName)
if err != nil {
err = modProbe(modulesPath, moduleName)
if err != nil {
err = checkProcConfig(moduleName)
if err != nil {
return fmt.Errorf("checking /proc/config.gz: %w", err)
}
}
}
return nil
}
// modProbe is the classic modprobe behavior.
func modProbe(modulesPath, moduleName string) error {
modulesInfo, err := getModulesInfo(modulesPath)
if err != nil {
return fmt.Errorf("getting modules information: %w", err)
}
modulePath, err := findModulePath(moduleName, modulesInfo)
if err != nil {
return fmt.Errorf("finding module path: %w", err)
}
info := modulesInfo[modulePath]
if info.state == builtin || info.state == loaded {
return nil
}
info.state = loading
for _, dependencyModulePath := range info.dependencyPaths {
err = initDependencies(dependencyModulePath, modulesInfo)
if err != nil {
return fmt.Errorf("init dependencies: %w", err)
}
}
err = initModule(modulePath)
if err != nil {
return fmt.Errorf("init module: %w", err)
}
return nil
}
+7
View File
@@ -0,0 +1,7 @@
//go:build !linux
package mod
func Probe(moduleName string) error {
panic("not implemented")
}
+59 -17
View File
@@ -1,33 +1,75 @@
//go:build linux || darwin
package netlink
import (
"github.com/vishvananda/netlink"
"fmt"
"net"
"net/netip"
"github.com/jsimonetti/rtnetlink/rtnl"
)
func (n *NetLink) AddrList(link Link, family int) (
addresses []Addr, err error,
func (n *NetLink) AddrList(linkIndex uint32, family uint8) (
ipPrefixes []netip.Prefix, err error,
) {
netlinkLink := linkToNetlinkLink(&link)
netlinkAddresses, err := netlink.AddrList(netlinkLink, family)
conn, err := rtnl.Dial(nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
ifc := &net.Interface{
Index: int(linkIndex),
}
ipNets, err := conn.Addrs(ifc, int(family))
if err != nil {
return nil, fmt.Errorf("failed to list addresses: %w", err)
}
addresses = make([]Addr, len(netlinkAddresses))
for i := range netlinkAddresses {
addresses[i].Network = netIPNetToNetipPrefix(netlinkAddresses[i].IPNet)
ipPrefixes = make([]netip.Prefix, len(ipNets))
for i := range ipNets {
ipPrefixes[i] = netIPNetToNetipPrefix(ipNets[i])
}
return addresses, nil
return ipPrefixes, nil
}
func (n *NetLink) AddrReplace(link Link, addr Addr) error {
netlinkLink := linkToNetlinkLink(&link)
netlinkAddress := netlink.Addr{
IPNet: netipPrefixToIPNet(addr.Network),
func (n *NetLink) AddrReplace(linkIndex uint32, prefix netip.Prefix) error {
conn, err := rtnl.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
ipNet := netipPrefixToIPNet(prefix)
// Remove any address identical to the one we want to add
family := FamilyV4
if prefix.Addr().Is6() {
family = FamilyV6
}
ifc := &net.Interface{
Index: int(linkIndex),
}
addresses, err := conn.Addrs(ifc, int(family))
if err != nil {
return fmt.Errorf("listing addresses: %w", err)
}
for _, address := range addresses {
if address.IP.Equal(ipNet.IP) &&
net.IP(address.Mask).String() == net.IP(ipNet.Mask).String() {
err = conn.AddrDel(ifc, address)
if err != nil {
return fmt.Errorf("deleting address from interface: %w", err)
}
break
}
}
return netlink.AddrReplace(netlinkLink, &netlinkAddress)
// Add the new address to the interface
err = conn.AddrAdd(ifc, ipNet)
if err != nil {
return fmt.Errorf("adding address to interface: %w", err)
}
return nil
}
-13
View File
@@ -1,13 +0,0 @@
//go:build !linux && !darwin
package netlink
func (n *NetLink) AddrList(link Link, family int) (
addresses []Addr, err error,
) {
panic("not implemented")
}
func (n *NetLink) AddrReplace(Link, Addr) error {
panic("not implemented")
}
+38
View File
@@ -0,0 +1,38 @@
package netlink
import (
"fmt"
"github.com/mdlayher/netlink"
"github.com/ti-mo/netfilter"
)
func (n *NetLink) FlushConntrack() error {
conn, err := netfilter.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netfilter: %w", err)
}
defer conn.Close()
families := [...]netfilter.ProtoFamily{netfilter.ProtoIPv4, netfilter.ProtoIPv6}
for _, family := range families {
const IPCtnlMsgCtDelete = 2
request, err := netfilter.MarshalNetlink(
netfilter.Header{
SubsystemID: netfilter.NFSubsysCTNetlink,
MessageType: netfilter.MessageType(IPCtnlMsgCtDelete),
Family: family,
Flags: netlink.Request | netlink.Acknowledge,
},
nil)
if err != nil {
return fmt.Errorf("encoding netlink request: %w", err)
}
_, err = conn.Query(request)
if err != nil {
return fmt.Errorf("querying netlink request: %w", err)
}
}
return nil
}
@@ -2,6 +2,6 @@
package netlink
func (n *NetLink) IsWireguardSupported() (ok bool, err error) {
func (n *NetLink) FlushConntrack() error {
panic("not implemented")
}
+24
View File
@@ -36,6 +36,30 @@ func netIPNetToNetipPrefix(ipNet *net.IPNet) (prefix netip.Prefix) {
return netip.PrefixFrom(ip, bits)
}
func ipAndLengthToPrefix(ip *net.IP, length uint8) netip.Prefix {
if ip == nil || len(*ip) == 0 {
return netip.Prefix{}
}
var dstIP netip.Addr
if ipv4 := ip.To4(); ipv4 != nil { // IPv6
dstIP = netip.AddrFrom4([4]byte(*ip))
} else {
dstIP = netip.AddrFrom16([16]byte(*ip))
}
return netip.PrefixFrom(dstIP, int(length))
}
func prefixToIPAndLength(prefix netip.Prefix) (ip *net.IP, length uint8) {
if !prefix.IsValid() {
return nil, 0
}
prefixIP := prefix.Addr().Unmap()
ip = new(net.IP)
*ip = netipAddrToNetIP(prefixIP)
length = uint8(prefix.Bits()) //nolint:gosec
return ip, length
}
func netipAddrToNetIP(address netip.Addr) (ip net.IP) {
switch {
case !address.IsValid():
+1 -7
View File
@@ -4,13 +4,7 @@ import (
"fmt"
)
const (
FamilyAll = 0
FamilyV4 = 2
FamilyV6 = 10
)
func FamilyToString(family int) string {
func FamilyToString(family uint8) string {
switch family {
case FamilyAll:
return "all"
+9
View File
@@ -0,0 +1,9 @@
package netlink
import "golang.org/x/sys/unix"
const (
FamilyAll uint8 = unix.AF_UNSPEC
FamilyV4 uint8 = unix.AF_INET
FamilyV6 uint8 = unix.AF_INET6
)
+25 -1
View File
@@ -1,8 +1,32 @@
package netlink
import "net/netip"
import (
"math/rand/v2"
"net/netip"
"github.com/qdm12/log"
)
func ptrTo[T any](v T) *T { return &v }
func makeNetipPrefix(n byte) netip.Prefix {
const bits = 24
return netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits)
}
var rng = rand.New(rand.NewChaCha8([32]byte{})) //nolint:gosec,gochecknoglobals
func makeLinkName() string {
const alphabet = "abcdefghijklmnopqrstuvwxyz"
name := make([]byte, 8)
for i := range name {
name[i] = alphabet[rng.IntN(len(alphabet))]
}
return "test" + string(name)
}
type noopLogger struct{}
func (l *noopLogger) Debug(_ string) {}
func (l *noopLogger) Debugf(_ string, _ ...any) {}
func (l *noopLogger) Patch(_ ...log.Option) {}
+1 -1
View File
@@ -19,7 +19,7 @@ func (n *NetLink) IsIPv6Supported() (supported bool, err error) {
return false, fmt.Errorf("finding link corresponding to route: %w", err)
}
sourceIsIPv6 := route.Src.IsValid() && route.Src.Is6()
sourceIsIPv6 := route.Src.Addr().IsValid() && route.Src.Addr().Is6()
destinationIsIPv6 := route.Dst.IsValid() && route.Dst.Addr().Is6()
switch {
case !sourceIsIPv6 && !destinationIsIPv6,
+162 -76
View File
@@ -1,105 +1,191 @@
//go:build linux || darwin
package netlink
import "github.com/vishvananda/netlink"
import (
"errors"
"fmt"
"github.com/jsimonetti/rtnetlink"
)
type DeviceType uint16
type Link struct {
Index uint32
Name string
DeviceType DeviceType
VirtualType string
MTU uint32
}
func (n *NetLink) LinkList() (links []Link, err error) {
netlinkLinks, err := netlink.LinkList()
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
linkMessages, err := conn.Link.List()
if err != nil {
return nil, fmt.Errorf("listing interfaces: %w", err)
}
links = make([]Link, len(netlinkLinks))
for i := range netlinkLinks {
links[i] = netlinkLinkToLink(netlinkLinks[i])
links = make([]Link, len(linkMessages))
for i, message := range linkMessages {
virtualType := ""
if message.Attributes.Info != nil {
virtualType = message.Attributes.Info.Kind
}
links[i] = Link{
Index: message.Index,
Name: message.Attributes.Name,
DeviceType: DeviceType(message.Type),
VirtualType: virtualType,
MTU: message.Attributes.MTU,
}
}
return links, nil
}
var ErrLinkNotFound = errors.New("link not found")
func (n *NetLink) LinkByName(name string) (link Link, err error) {
netlinkLink, err := netlink.LinkByName(name)
links, err := n.LinkList()
if err != nil {
return Link{}, err
return Link{}, fmt.Errorf("listing links: %w", err)
}
return netlinkLinkToLink(netlinkLink), nil
for _, link := range links {
if link.Name == name {
return link, nil
}
}
return Link{}, fmt.Errorf("%w: for name %s", ErrLinkNotFound, name)
}
func (n *NetLink) LinkByIndex(index int) (link Link, err error) {
netlinkLink, err := netlink.LinkByIndex(index)
func (n *NetLink) LinkByIndex(index uint32) (link Link, err error) {
links, err := n.LinkList()
if err != nil {
return Link{}, err
return Link{}, fmt.Errorf("listing links: %w", err)
}
return netlinkLinkToLink(netlinkLink), nil
for _, link = range links {
if link.Index == index {
return link, nil
}
}
return Link{}, fmt.Errorf("%w: for index %d", ErrLinkNotFound, index)
}
func (n *NetLink) LinkAdd(link Link) (linkIndex int, err error) {
netlinkLink := linkToNetlinkLink(&link)
err = netlink.LinkAdd(netlinkLink)
func (n *NetLink) LinkAdd(link Link) (linkIndex uint32, err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return 0, err
return 0, fmt.Errorf("dialing netlink: %w", err)
}
return netlinkLink.Attrs().Index, nil
}
defer conn.Close()
func (n *NetLink) LinkDel(link Link) (err error) {
return netlink.LinkDel(linkToNetlinkLink(&link))
}
func (n *NetLink) LinkSetUp(link Link) (linkIndex int, err error) {
netlinkLink := linkToNetlinkLink(&link)
err = netlink.LinkSetUp(netlinkLink)
if err != nil {
return 0, err
}
return netlinkLink.Attrs().Index, nil
}
func (n *NetLink) LinkSetDown(link Link) (err error) {
return netlink.LinkSetDown(linkToNetlinkLink(&link))
}
type netlinkLinkImpl struct {
attrs *netlink.LinkAttrs
linkType string
}
func (n *netlinkLinkImpl) Attrs() *netlink.LinkAttrs {
return n.attrs
}
func (n *netlinkLinkImpl) Type() string {
return n.linkType
}
func netlinkLinkToLink(netlinkLink netlink.Link) Link {
attributes := netlinkLink.Attrs()
return Link{
Type: netlinkLink.Type(),
Name: attributes.Name,
Index: attributes.Index,
EncapType: attributes.EncapType,
MTU: uint16(attributes.MTU), //nolint:gosec
}
}
// Warning: we must return `netlink.Link` and not `netlinkLinkImpl`
// so that the vishvananda/netlink package can compare the returned
// value against an untyped nil.
func linkToNetlinkLink(link *Link) netlink.Link {
if link == nil {
return nil
}
return &netlinkLinkImpl{
linkType: link.Type,
attrs: &netlink.LinkAttrs{
Name: link.Name,
Index: link.Index,
EncapType: link.EncapType,
MTU: int(link.MTU),
tx := &rtnetlink.LinkMessage{
Type: uint16(link.DeviceType),
Attributes: &rtnetlink.LinkAttributes{
MTU: link.MTU,
Name: link.Name,
},
}
if link.VirtualType != "" {
tx.Attributes.Info = &rtnetlink.LinkInfo{
Kind: link.VirtualType,
}
}
err = conn.Link.New(tx)
if err != nil {
return 0, fmt.Errorf("creating new link: %w", err)
}
linkMessages, err := conn.Link.List()
if err != nil {
return 0, fmt.Errorf("listing links: %w", err)
}
for _, linkMessage := range linkMessages {
if linkMessage.Attributes.Name == link.Name {
return linkMessage.Index, nil
}
}
return 0, fmt.Errorf("%w: matching name %s", ErrLinkNotFound, link.Name)
}
func (n *NetLink) LinkDel(linkIndex uint32) (err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Link.Delete(linkIndex)
}
func (n *NetLink) LinkSetUp(linkIndex uint32) (err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
rx, err := conn.Link.Get(linkIndex)
if err != nil {
return fmt.Errorf("getting link: %w", err)
}
tx := &rtnetlink.LinkMessage{
Type: rx.Type,
Index: linkIndex,
Flags: iffUp,
Change: iffUp,
}
return conn.Link.Set(tx)
}
func (n *NetLink) LinkSetDown(linkIndex uint32) (err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
linkInfo, err := conn.Link.Get(linkIndex)
if err != nil {
return fmt.Errorf("getting link: %w", err)
}
message := &rtnetlink.LinkMessage{
Type: linkInfo.Type,
Index: linkIndex,
Flags: 0,
Change: iffUp,
}
return conn.Link.Set(message)
}
func (n *NetLink) LinkSetMTU(linkIndex, mtu uint32) error {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
message := &rtnetlink.LinkMessage{
Index: linkIndex,
Attributes: &rtnetlink.LinkAttributes{
MTU: mtu,
},
}
err = conn.Link.Set(message)
if err != nil {
return fmt.Errorf("setting MTU to %d for link at index %d: %w",
mtu, linkIndex, err)
}
return nil
}
+11
View File
@@ -0,0 +1,11 @@
package netlink
import "golang.org/x/sys/unix"
const (
DeviceTypeEthernet DeviceType = unix.ARPHRD_ETHER
DeviceTypeLoopback DeviceType = unix.ARPHRD_LOOPBACK
DeviceTypeNone DeviceType = unix.ARPHRD_NONE
iffUp = unix.IFF_UP
)
+85
View File
@@ -0,0 +1,85 @@
//go:build linux
package netlink
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_NetLink_LinkList(t *testing.T) {
t.Parallel()
netlink := &NetLink{}
initialLinks, err := netlink.LinkList()
require.NoError(t, err)
require.NotEmpty(t, initialLinks)
loopbackFound := false
for _, link := range initialLinks {
if link.Name != "lo" {
continue
}
loopbackFound = true
assert.Equal(t, DeviceTypeLoopback, link.DeviceType)
break
}
assert.True(t, loopbackFound, "loopback interface not found")
testLink := Link{
Name: makeLinkName(),
// note if [Link.VirtualType] is set, [Link.DeviceType]
// is ignored and gets set to [DeviceTypeNone] in LinkAdd.
DeviceType: DeviceTypeNone,
VirtualType: "wireguard",
MTU: 1420,
}
index, err := netlink.LinkAdd(testLink)
require.NoError(t, err)
t.Cleanup(func() {
_ = netlink.LinkDel(index)
})
links, err := netlink.LinkList()
require.NoError(t, err)
testLink.Index = index
for _, link := range links {
if link.Name != testLink.Name {
continue
}
assert.Equal(t, testLink, link)
return
}
t.Errorf("created link %q not found", testLink.Name)
}
func Test_NetLink_LinkSetMTU(t *testing.T) {
t.Parallel()
netlink := &NetLink{}
testLink := Link{
Name: makeLinkName(),
DeviceType: DeviceTypeNone,
VirtualType: "wireguard",
MTU: 1420,
}
index, err := netlink.LinkAdd(testLink)
require.NoError(t, err)
t.Cleanup(func() {
_ = netlink.LinkDel(index)
})
testLink.Index = index
err = netlink.LinkSetMTU(index, 1500)
require.NoError(t, err)
link, err := netlink.LinkByIndex(index)
require.NoError(t, err)
testLink.MTU = 1500
assert.Equal(t, testLink, link)
}
-31
View File
@@ -1,31 +0,0 @@
//go:build !linux && !darwin
package netlink
func (n *NetLink) LinkList() (links []Link, err error) {
panic("not implemented")
}
func (n *NetLink) LinkByName(name string) (link Link, err error) {
panic("not implemented")
}
func (n *NetLink) LinkByIndex(index int) (link Link, err error) {
panic("not implemented")
}
func (n *NetLink) LinkAdd(link Link) (linkIndex int, err error) {
panic("not implemented")
}
func (n *NetLink) LinkDel(link Link) (err error) {
panic("not implemented")
}
func (n *NetLink) LinkSetUp(link Link) (linkIndex int, err error) {
panic("not implemented")
}
func (n *NetLink) LinkSetDown(link Link) (err error) {
panic("not implemented")
}
+56
View File
@@ -0,0 +1,56 @@
//go:build !linux
package netlink
const (
// FamilyAll is a placeholder only and should not
// be used.
FamilyAll uint8 = iota
// FamilyV4 is a placeholder only and should not
// be used.
FamilyV4
// FamilyV6 is a placeholder only and should not
// be used.
FamilyV6
// DeviceTypeEthernet is a placeholder only and should not be used.
DeviceTypeEthernet DeviceType = 0
// DeviceTypeLoopback is a placeholder only and should not be used.
DeviceTypeLoopback DeviceType = 0
// DeviceTypeNone is a placeholder only and should not be used.
DeviceTypeNone DeviceType = 0
// iffUp is a placeholder only and should not be used.
iffUp = 0
// RouteTypeUnicast is a placeholder only and should not be used.
RouteTypeUnicast = 0
// ScopeUniverse is a placeholder only and should not be used.
ScopeUniverse = 0
// ProtoStatic is a placeholder only and should not be used.
ProtoStatic = 0
// FlagInvert is a placeholder only and should not be used.
FlagInvert = 0
// ActionToTable is a placeholder only and should not be used.
ActionToTable = 0
// rtTableCompat is a placeholder only and should not be used.
rtTableCompat = 0
)
func (n *NetLink) RuleList(family uint8) (rules []Rule, err error) {
panic("not implemented")
}
func (n *NetLink) RuleAdd(rule Rule) error {
panic("not implemented")
}
func (n *NetLink) RuleDel(rule Rule) error {
panic("not implemented")
}
func (n *NetLink) IsWireguardSupported() (bool, error) {
panic("not implemented")
}
+114 -46
View File
@@ -1,69 +1,137 @@
//go:build linux || darwin
package netlink
import (
"github.com/vishvananda/netlink"
"fmt"
"net/netip"
"github.com/jsimonetti/rtnetlink"
)
func (n *NetLink) RouteList(family int) (routes []Route, err error) {
// We set the filter to netlink.RT_FILTER_TABLE so that
// routes from all tables are listed, as long as the filter
// table is set to 0.
const filterMask = netlink.RT_FILTER_TABLE
// The filter is not left to `nil` otherwise non-main tables
// are ignored.
filter := &netlink.Route{}
type Route struct {
LinkIndex uint32
Dst netip.Prefix
Src netip.Prefix
Gw netip.Addr
Priority uint32
Family uint8
Table uint32
Type uint8
Scope uint8
Proto uint8
AdvMSS uint32
}
netlinkRoutes, err := netlink.RouteListFiltered(family, filter, filterMask)
func (r *Route) fromMessage(message rtnetlink.RouteMessage) {
table := uint32(message.Table)
if table == 0 || table == rtTableCompat {
table = message.Attributes.Table
}
r.LinkIndex = message.Attributes.OutIface
r.Dst = ipAndLengthToPrefix(&message.Attributes.Dst, message.DstLength)
r.Src = ipAndLengthToPrefix(&message.Attributes.Src, message.SrcLength)
r.Gw = netIPToNetipAddress(message.Attributes.Gateway)
r.Priority = message.Attributes.Priority
r.Family = message.Family
r.Table = table
r.Type = message.Type
r.Scope = message.Scope
r.Proto = message.Protocol
if metrics := message.Attributes.Metrics; metrics != nil {
r.AdvMSS = metrics.AdvMSS
}
}
func (r Route) message() *rtnetlink.RouteMessage {
dst, dstLength := prefixToIPAndLength(r.Dst)
src, srcLength := prefixToIPAndLength(r.Src)
var table uint8
var extendedTable uint32
if r.Table <= uint32(^uint8(0)) {
table = uint8(r.Table)
} else {
table = rtTableCompat
extendedTable = r.Table
}
message := &rtnetlink.RouteMessage{
Family: r.Family,
DstLength: dstLength,
SrcLength: srcLength,
Table: table,
Type: r.Type,
Scope: r.Scope,
Protocol: r.Proto,
Attributes: rtnetlink.RouteAttributes{
OutIface: r.LinkIndex,
Gateway: netipAddrToNetIP(r.Gw),
Priority: r.Priority,
Table: extendedTable,
},
}
if src != nil { // src is optional
message.Attributes.Src = *src
}
if dst != nil {
message.Attributes.Dst = *dst
}
if r.AdvMSS != 0 {
if message.Attributes.Metrics == nil {
message.Attributes.Metrics = &rtnetlink.RouteMetrics{}
}
message.Attributes.Metrics.AdvMSS = r.AdvMSS
}
return message
}
func (n *NetLink) RouteList(family uint8) (routes []Route, err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
routeMessages, err := conn.Route.List()
if err != nil {
return nil, fmt.Errorf("listing interfaces: %w", err)
}
routes = make([]Route, len(netlinkRoutes))
for i := range netlinkRoutes {
routes[i] = netlinkRouteToRoute(netlinkRoutes[i])
routes = make([]Route, 0, len(routeMessages))
for _, routeMessage := range routeMessages {
if family != FamilyAll && routeMessage.Family != family {
continue
}
var route Route
route.fromMessage(routeMessage)
routes = append(routes, route)
}
return routes, nil
}
func (n *NetLink) RouteAdd(route Route) error {
netlinkRoute := routeToNetlinkRoute(route)
return netlink.RouteAdd(&netlinkRoute)
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Route.Add(route.message())
}
func (n *NetLink) RouteDel(route Route) error {
netlinkRoute := routeToNetlinkRoute(route)
return netlink.RouteDel(&netlinkRoute)
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Route.Delete(route.message())
}
func (n *NetLink) RouteReplace(route Route) error {
netlinkRoute := routeToNetlinkRoute(route)
return netlink.RouteReplace(&netlinkRoute)
}
func netlinkRouteToRoute(netlinkRoute netlink.Route) (route Route) {
return Route{
LinkIndex: netlinkRoute.LinkIndex,
Dst: netIPNetToNetipPrefix(netlinkRoute.Dst),
Src: netIPToNetipAddress(netlinkRoute.Src),
Gw: netIPToNetipAddress(netlinkRoute.Gw),
Priority: netlinkRoute.Priority,
Family: netlinkRoute.Family,
Table: netlinkRoute.Table,
Type: netlinkRoute.Type,
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
}
defer conn.Close()
func routeToNetlinkRoute(route Route) (netlinkRoute netlink.Route) {
return netlink.Route{
LinkIndex: route.LinkIndex,
Dst: netipPrefixToIPNet(route.Dst),
Src: netipAddrToNetIP(route.Src),
Gw: netipAddrToNetIP(route.Gw),
Priority: route.Priority,
Family: route.Family,
Table: route.Table,
Type: route.Type,
}
return conn.Route.Replace(route.message())
}
+11
View File
@@ -0,0 +1,11 @@
package netlink
import "golang.org/x/sys/unix"
const (
RouteTypeUnicast = unix.RTN_UNICAST
ScopeUniverse = unix.RT_SCOPE_UNIVERSE
ProtoStatic = unix.RTPROT_STATIC
rtTableCompat = unix.RT_TABLE_COMPAT
)
-21
View File
@@ -1,21 +0,0 @@
//go:build !linux && !darwin
package netlink
func (n *NetLink) RouteList(family int) (
routes []Route, err error,
) {
panic("not implemented")
}
func (n *NetLink) RouteAdd(route Route) error {
panic("not implemented")
}
func (n *NetLink) RouteDel(route Route) error {
panic("not implemented")
}
func (n *NetLink) RouteReplace(route Route) error {
panic("not implemented")
}
+76 -71
View File
@@ -1,91 +1,96 @@
//go:build linux
package netlink
import (
"fmt"
"net/netip"
"github.com/vishvananda/netlink"
"github.com/jsimonetti/rtnetlink"
)
func NewRule() Rule {
// defaults found from netlink.NewRule() for fields we use,
// the rest of the defaults is set when converting from a `Rule`
// to a `netlink.Rule`
return Rule{
Priority: -1,
Mark: 0,
}
type Rule struct {
Priority *uint32
Family uint8
Table uint32
Mark *uint32
Src netip.Prefix
Dst netip.Prefix
Flags uint32
Action uint8
}
func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
switch family {
case FamilyAll:
n.debugLogger.Debug("ip -4 rule list")
n.debugLogger.Debug("ip -6 rule list")
case FamilyV4:
n.debugLogger.Debug("ip -4 rule list")
case FamilyV6:
n.debugLogger.Debug("ip -6 rule list")
func (r *Rule) fromMessage(message rtnetlink.RuleMessage) {
table := uint32(message.Table)
if table == 0 || table == rtTableCompat {
table = *message.Attributes.Table
}
netlinkRules, err := netlink.RuleList(family)
if err != nil {
return nil, err
r.Priority = message.Attributes.Priority
r.Family = message.Family
r.Table = table
r.Mark = message.Attributes.FwMark
r.Src = ipAndLengthToPrefix(message.Attributes.Src, message.SrcLength)
r.Dst = ipAndLengthToPrefix(message.Attributes.Dst, message.DstLength)
r.Flags = message.Flags
r.Action = message.Action
}
func (r Rule) message() *rtnetlink.RuleMessage {
src, srcLength := prefixToIPAndLength(r.Src)
dst, dstLength := prefixToIPAndLength(r.Dst)
message := &rtnetlink.RuleMessage{
Family: r.Family,
SrcLength: srcLength,
DstLength: dstLength,
Flags: r.Flags,
Action: r.Action,
Attributes: &rtnetlink.RuleAttributes{
Priority: r.Priority,
FwMark: r.Mark,
Src: src,
Dst: dst,
},
}
rules = make([]Rule, len(netlinkRules))
for i := range netlinkRules {
rules[i] = netlinkRuleToRule(netlinkRules[i])
if r.Table <= uint32(^uint8(0)) {
message.Table = uint8(r.Table)
} else {
message.Table = rtTableCompat
message.Attributes.Table = &r.Table
}
return rules, nil
return message
}
func (n *NetLink) RuleAdd(rule Rule) error {
n.debugLogger.Debug(ruleDbgMsg(true, rule))
netlinkRule := ruleToNetlinkRule(rule)
return netlink.RuleAdd(&netlinkRule)
}
func (n *NetLink) RuleDel(rule Rule) error {
n.debugLogger.Debug(ruleDbgMsg(false, rule))
netlinkRule := ruleToNetlinkRule(rule)
return netlink.RuleDel(&netlinkRule)
}
func ruleToNetlinkRule(rule Rule) (netlinkRule netlink.Rule) {
netlinkRule = *netlink.NewRule()
netlinkRule.Priority = rule.Priority
netlinkRule.Family = rule.Family
netlinkRule.Table = rule.Table
netlinkRule.Mark = rule.Mark
netlinkRule.Src = netipPrefixToIPNet(rule.Src)
netlinkRule.Dst = netipPrefixToIPNet(rule.Dst)
netlinkRule.Invert = rule.Invert
return netlinkRule
}
func netlinkRuleToRule(netlinkRule netlink.Rule) (rule Rule) {
return Rule{
Priority: netlinkRule.Priority,
Family: netlinkRule.Family,
Table: netlinkRule.Table,
Mark: netlinkRule.Mark,
Src: netIPNetToNetipPrefix(netlinkRule.Src),
Dst: netIPNetToNetipPrefix(netlinkRule.Dst),
Invert: netlinkRule.Invert,
func (r Rule) String() string {
from := "all"
if r.Src.IsValid() && !r.Src.Addr().IsUnspecified() {
from = r.Src.String()
}
to := "all"
if r.Dst.IsValid() && !r.Dst.Addr().IsUnspecified() {
to = r.Dst.String()
}
priority := ""
if r.Priority != nil {
priority = fmt.Sprintf(" %d", *r.Priority)
}
return fmt.Sprintf("ip rule%s: from %s to %s table %d",
priority, from, to, r.Table)
}
func ruleDbgMsg(add bool, rule Rule) (debugMessage string) {
func (r Rule) debugMessage(add bool) (debugMessage string) {
debugMessage = "ip"
switch rule.Family {
switch r.Family {
case FamilyV4:
debugMessage += " -f inet"
case FamilyV6:
debugMessage += " -f inet6"
default:
debugMessage += " -f " + fmt.Sprint(rule.Family)
debugMessage += " -f " + fmt.Sprint(r.Family)
}
debugMessage += " rule"
@@ -96,20 +101,20 @@ func ruleDbgMsg(add bool, rule Rule) (debugMessage string) {
debugMessage += " del"
}
if rule.Src.IsValid() {
debugMessage += " from " + rule.Src.String()
if r.Src.IsValid() {
debugMessage += " from " + r.Src.String()
}
if rule.Dst.IsValid() {
debugMessage += " to " + rule.Dst.String()
if r.Dst.IsValid() {
debugMessage += " to " + r.Dst.String()
}
if rule.Table != 0 {
debugMessage += " lookup " + fmt.Sprint(rule.Table)
if r.Table != 0 {
debugMessage += " lookup " + fmt.Sprint(r.Table)
}
if rule.Priority != -1 {
debugMessage += " pref " + fmt.Sprint(rule.Priority)
if r.Priority != nil {
debugMessage += " pref " + fmt.Sprint(*r.Priority)
}
return debugMessage
+69
View File
@@ -0,0 +1,69 @@
package netlink
import (
"fmt"
"github.com/jsimonetti/rtnetlink"
"golang.org/x/sys/unix"
)
const (
FlagInvert = unix.FIB_RULE_INVERT
ActionToTable = unix.FR_ACT_TO_TBL
)
func (n *NetLink) RuleList(family uint8) (rules []Rule, err error) {
switch family {
case FamilyAll:
n.debugLogger.Debug("ip -4 rule list")
n.debugLogger.Debug("ip -6 rule list")
case FamilyV4:
n.debugLogger.Debug("ip -4 rule list")
case FamilyV6:
n.debugLogger.Debug("ip -6 rule list")
}
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
ruleMessages, err := conn.Rule.List()
if err != nil {
return nil, err
}
rules = make([]Rule, 0, len(ruleMessages))
for _, message := range ruleMessages {
if family != FamilyAll && family != message.Family {
continue
}
var rule Rule
rule.fromMessage(message)
rules = append(rules, rule)
}
return rules, nil
}
func (n *NetLink) RuleAdd(rule Rule) error {
n.debugLogger.Debug(rule.debugMessage(true))
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Rule.Add(rule.message())
}
func (n *NetLink) RuleDel(rule Rule) error {
n.debugLogger.Debug(rule.debugMessage(false))
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Rule.Delete(rule.message())
}
+5 -5
View File
@@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/assert"
)
func Test_ruleDbgMsg(t *testing.T) {
func Test_Rule_debugMessage(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
@@ -15,7 +15,7 @@ func Test_ruleDbgMsg(t *testing.T) {
dbgMsg string
}{
"default values": {
dbgMsg: "ip -f 0 rule del pref 0",
dbgMsg: "ip -f 0 rule del",
},
"add rule": {
add: true,
@@ -24,7 +24,7 @@ func Test_ruleDbgMsg(t *testing.T) {
Src: makeNetipPrefix(1),
Dst: makeNetipPrefix(2),
Table: 100,
Priority: 101,
Priority: ptrTo(uint32(101)),
},
dbgMsg: "ip -f inet rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
},
@@ -34,7 +34,7 @@ func Test_ruleDbgMsg(t *testing.T) {
Src: makeNetipPrefix(1),
Dst: makeNetipPrefix(2),
Table: 100,
Priority: 101,
Priority: ptrTo(uint32(101)),
},
dbgMsg: "ip -f inet rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
},
@@ -44,7 +44,7 @@ func Test_ruleDbgMsg(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
dbgMsg := ruleDbgMsg(testCase.add, testCase.rule)
dbgMsg := testCase.rule.debugMessage(testCase.add)
assert.Equal(t, testCase.dbgMsg, dbgMsg)
})
-19
View File
@@ -1,19 +0,0 @@
//go:build !linux
package netlink
func NewRule() Rule {
return Rule{}
}
func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
panic("not implemented")
}
func (n *NetLink) RuleAdd(rule Rule) error {
panic("not implemented")
}
func (n *NetLink) RuleDel(rule Rule) error {
panic("not implemented")
}
-58
View File
@@ -1,58 +0,0 @@
package netlink
import (
"fmt"
"net/netip"
)
type Addr struct {
Network netip.Prefix
}
func (a Addr) String() string {
return a.Network.String()
}
type Link struct {
Type string
Name string
Index int
EncapType string
MTU uint16
}
type Route struct {
LinkIndex int
Dst netip.Prefix
Src netip.Addr
Gw netip.Addr
Priority int
Family int
Table int
Type int
}
type Rule struct {
Priority int
Family int
Table int
Mark uint32
Src netip.Prefix
Dst netip.Prefix
Invert bool
}
func (r Rule) String() string {
from := "all"
if r.Src.IsValid() && !r.Src.Addr().IsUnspecified() {
from = r.Src.String()
}
to := "all"
if r.Dst.IsValid() && !r.Dst.Addr().IsUnspecified() {
to = r.Dst.String()
}
return fmt.Sprintf("ip rule %d: from %s to %s table %d",
r.Priority, from, to, r.Table)
}
@@ -1,12 +1,12 @@
//go:build linux
package netlink
import (
"errors"
"fmt"
"os"
"github.com/mdlayher/genetlink"
"github.com/qdm12/gluetun/internal/mod"
"github.com/vishvananda/netlink"
)
func (n *NetLink) IsWireguardSupported() (ok bool, err error) {
@@ -15,9 +15,8 @@ func (n *NetLink) IsWireguardSupported() (ok bool, err error) {
// modules directory, such as WSL2 kernels.
ok, err = hasWireguardFamily()
if err != nil {
return false, fmt.Errorf("checking for wireguard family: %w", err)
}
if ok {
return false, fmt.Errorf("checking wireguard family: %w", err)
} else if ok {
return true, nil
}
@@ -35,20 +34,25 @@ func (n *NetLink) IsWireguardSupported() (ok bool, err error) {
// the wireguard kernel module.
ok, err = hasWireguardFamily()
if err != nil {
return false, fmt.Errorf("checking for wireguard family: %w", err)
return false, fmt.Errorf("checking wireguard family: %w", err)
}
return ok, nil
}
func hasWireguardFamily() (ok bool, err error) {
families, err := netlink.GenlFamilyList()
conn, err := genetlink.Dial(nil)
if err != nil {
return false, fmt.Errorf("listing gen 1 families: %w", err)
return false, fmt.Errorf("dialing netlink: %w", err)
}
for _, family := range families {
if family.Name == "wireguard" {
return true, nil
defer conn.Close()
_, err = conn.GetFamily("wireguard")
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return false, nil
}
return false, fmt.Errorf("getting wireguard family: %w", err)
}
return false, nil
return true, nil
}
+4 -3
View File
@@ -9,10 +9,11 @@ import (
)
func Test_NetLink_IsWireguardSupported(t *testing.T) {
t.Skip() // TODO unskip once the data race problem with netlink.GenlFamilyList() is fixed
t.Parallel()
netLink := &NetLink{}
netLink := &NetLink{
debugLogger: &noopLogger{},
}
ok, err := netLink.IsWireguardSupported()
require.NoError(t, err)
if ok { // cannot assert since this depends on kernel
+1 -2
View File
@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"os/exec"
"syscall"
"github.com/qdm12/gluetun/internal/constants/openvpn"
)
@@ -33,7 +32,7 @@ func start(ctx context.Context, starter CmdStarter, version string, flags []stri
args := []string{"--config", configPath}
args = append(args, flags...)
cmd := exec.CommandContext(ctx, bin, args...)
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
setCmdSysProcAttr(cmd)
return starter.Start(cmd)
}
+10
View File
@@ -0,0 +1,10 @@
package openvpn
import (
"os/exec"
"syscall"
)
func setCmdSysProcAttr(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
}
+12
View File
@@ -0,0 +1,12 @@
//go:build !linux
package openvpn
import (
"os/exec"
"syscall"
)
func setCmdSysProcAttr(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{}
}
+24
View File
@@ -0,0 +1,24 @@
package constants
const (
MaxEthernetFrameSize uint32 = 1500
// MinIPv4MTU is defined according to
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
MinIPv4MTU uint32 = 68
MinIPv6MTU uint32 = 1280
IPv4HeaderLength uint32 = 20
IPv6HeaderLength uint32 = 40
UDPHeaderLength uint32 = 8
// BaseTCPHeaderLength is the TCP header length without options,
// which is the minimum TCP header length.
BaseTCPHeaderLength uint32 = 20
// MaxTCPHeaderLength is the TCP header length with the maximum options length of 40 bytes.
// Note this is a hard maximum because of the 4-bit data offset field in the TCP header (15x4=60).
MaxTCPHeaderLength uint32 = 60
WireguardHeaderLength uint32 = 32
OpenVPNHeaderMaxLength uint32 = 1 + // opcode
8 + // session id
4 + // packet id
28 // max possible auth tag/iv
)
+16
View File
@@ -0,0 +1,16 @@
//go:build linux || darwin
package constants
import "golang.org/x/sys/unix"
//nolint:revive
const (
SOCK_RAW = unix.SOCK_RAW
SOCK_STREAM = unix.SOCK_STREAM
AF_INET = unix.AF_INET
AF_INET6 = unix.AF_INET6
IPPROTO_TCP = unix.IPPROTO_TCP
EAGAIN = unix.EAGAIN
EWOULDBLOCK = unix.EWOULDBLOCK
)
@@ -0,0 +1,13 @@
package constants
import "golang.org/x/sys/windows"
const (
SOCK_RAW = windows.SOCK_RAW
SOCK_STREAM = windows.SOCK_STREAM
AF_INET = windows.AF_INET
AF_INET6 = windows.AF_INET6
IPPROTO_TCP = windows.IPPROTO_TCP
EAGAIN = windows.WSAEWOULDBLOCK
EWOULDBLOCK = windows.WSAEWOULDBLOCK
)
+49
View File
@@ -0,0 +1,49 @@
package icmp
import (
"net"
"time"
"golang.org/x/net/ipv4"
)
var _ net.PacketConn = &ipv4Wrapper{}
// ipv4Wrapper is a wrapper around ipv4.PacketConn to implement
// the net.PacketConn interface. It's only used for Darwin or iOS.
type ipv4Wrapper struct {
ipv4Conn *ipv4.PacketConn
}
func ipv4ToNetPacketConn(ipv4 *ipv4.PacketConn) *ipv4Wrapper {
return &ipv4Wrapper{ipv4Conn: ipv4}
}
func (i *ipv4Wrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, _, addr, err = i.ipv4Conn.ReadFrom(p)
return n, addr, err
}
func (i *ipv4Wrapper) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return i.ipv4Conn.WriteTo(p, nil, addr)
}
func (i *ipv4Wrapper) Close() error {
return i.ipv4Conn.Close()
}
func (i *ipv4Wrapper) LocalAddr() net.Addr {
return i.ipv4Conn.LocalAddr()
}
func (i *ipv4Wrapper) SetDeadline(t time.Time) error {
return i.ipv4Conn.SetDeadline(t)
}
func (i *ipv4Wrapper) SetReadDeadline(t time.Time) error {
return i.ipv4Conn.SetReadDeadline(t)
}
func (i *ipv4Wrapper) SetWriteDeadline(t time.Time) error {
return i.ipv4Conn.SetWriteDeadline(t)
}
+83
View File
@@ -0,0 +1,83 @@
package icmp
import (
"bytes"
"errors"
"fmt"
"golang.org/x/net/icmp"
)
var (
ErrNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
ErrNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
)
func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) {
switch {
case mtu < minMTU:
return fmt.Errorf("%w: %d", ErrNextHopMTUTooLow, mtu)
case mtu > physicalLinkMTU:
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
ErrNextHopMTUTooHigh, mtu, physicalLinkMTU)
default:
return nil
}
}
func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
outboundMessage *icmp.Message,
) (match bool, err error) {
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
if err != nil {
return false, fmt.Errorf("parsing invoking packet: %w", err)
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return false, fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
return inboundBody.ID == outboundBody.ID, nil
}
var ErrIDMismatch = errors.New("ICMP id mismatch")
func checkEchoReply(icmpProtocol int, received []byte,
outboundMessage *icmp.Message, truncatedBody bool,
) (err error) {
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
if err != nil {
return fmt.Errorf("parsing invoking packet: %w", err)
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
if inboundBody.ID != outboundBody.ID {
return fmt.Errorf("%w: sent id %d and received id %d",
ErrIDMismatch, outboundBody.ID, inboundBody.ID)
}
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
if err != nil {
return fmt.Errorf("checking sent and received bodies: %w", err)
}
return nil
}
var ErrEchoDataMismatch = errors.New("ICMP data mismatch")
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
if len(received) > len(sent) {
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
ErrEchoDataMismatch, len(sent), len(received))
}
if receivedTruncated {
sent = sent[:len(received)]
}
if !bytes.Equal(received, sent) {
return fmt.Errorf("%w: sent %x and received %x",
ErrEchoDataMismatch, sent, received)
}
return nil
}
+14
View File
@@ -0,0 +1,14 @@
package icmp
import (
"golang.org/x/sys/unix"
)
func setDontFragment(fd uintptr, ipv4 bool) (err error) {
if ipv4 {
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IP,
unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE)
}
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6,
unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_PROBE)
}
+10
View File
@@ -0,0 +1,10 @@
//go:build !linux && !windows
package icmp
// setDontFragment for platforms other than Linux and Windows
// is not implemented, so we just return assuming the don't
// fragment flag is set on IP packets.
func setDontFragment(fd uintptr, ipv4 bool) (err error) {
return nil
}

Some files were not shown because too many files have changed in this diff Show More