mirror of
https://github.com/qdm12/gluetun.git
synced 2026-06-28 06:47:29 +02:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7f22fb3276 | |||
| 6909a0c123 | |||
| 3e1f48932a | |||
| 50744852c5 | |||
| 09e52bc685 | |||
| 857fe425ec |
@@ -1,2 +1,2 @@
|
||||
FROM ghcr.io/qdm12/godevcontainer:v0.21-alpine
|
||||
RUN apk add wireguard-tools htop openssl tcpdump iptables
|
||||
RUN apk add wireguard-tools htop openssl
|
||||
|
||||
@@ -45,7 +45,6 @@ jobs:
|
||||
level: error
|
||||
exclude: |
|
||||
./internal/storage/servers.json
|
||||
./golangci.yml
|
||||
*.md
|
||||
|
||||
- name: Linting
|
||||
@@ -60,13 +59,10 @@ jobs:
|
||||
- name: Run tests in test container
|
||||
run: |
|
||||
touch coverage.txt
|
||||
docker run --rm --cap-add=NET_ADMIN --device /dev/net/tun \
|
||||
docker run --rm --device /dev/net/tun \
|
||||
-v "$(pwd)/coverage.txt:/tmp/gobuild/coverage.txt" \
|
||||
test-container
|
||||
|
||||
- name: Verify dev cross platform compatibility
|
||||
run: docker build --target xcompile .
|
||||
|
||||
- name: Build final image
|
||||
run: docker build -t final-image .
|
||||
|
||||
|
||||
+1
-2
@@ -22,7 +22,6 @@ linters:
|
||||
- "^disabled$"
|
||||
# Firewall and routing strings
|
||||
- "^(ACCEPT|DROP)$"
|
||||
- "^--append$"
|
||||
- "^--delete$"
|
||||
- "^all$"
|
||||
- "^(tcp|udp)$"
|
||||
@@ -48,7 +47,7 @@ linters:
|
||||
path: internal\/server\/.+\.go
|
||||
- linters:
|
||||
- ireturn
|
||||
text: returns interface \(golang\.org\/x\/sys\/unix\.Sockaddr\)
|
||||
text: returns interface \(github\.com\/vishvananda\/netlink\.Link\)
|
||||
- linters:
|
||||
- ireturn
|
||||
path: internal\/openvpn\/pkcs8\/descbc\.go
|
||||
|
||||
+2
-9
@@ -13,7 +13,7 @@ FROM --platform=${BUILDPLATFORM} ghcr.io/qdm12/binpot:mockgen-${MOCKGEN_VERSION}
|
||||
FROM --platform=${BUILDPLATFORM} golang:${GO_VERSION}-alpine${GO_ALPINE_VERSION} AS base
|
||||
COPY --from=xcputranslate /xcputranslate /usr/local/bin/xcputranslate
|
||||
# Note: findutils needed to have xargs support `-d` flag for mocks stage.
|
||||
RUN apk --update add git g++ findutils iptables
|
||||
RUN apk --update add git g++ findutils
|
||||
ENV CGO_ENABLED=0
|
||||
COPY --from=golangci-lint /bin /go/bin/golangci-lint
|
||||
COPY --from=mockgen /bin /go/bin/mockgen
|
||||
@@ -46,10 +46,6 @@ RUN git init && \
|
||||
git diff --exit-code && \
|
||||
rm -rf .git/
|
||||
|
||||
FROM --platform=${BUILDPLATFORM} base AS xcompile
|
||||
RUN GOOS=darwin go build -o /dev/null ./...
|
||||
RUN GOOS=windows go build -o /dev/null ./...
|
||||
|
||||
FROM --platform=${BUILDPLATFORM} base AS build
|
||||
ARG TARGETPLATFORM
|
||||
ARG VERSION=unknown
|
||||
@@ -110,11 +106,8 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
||||
WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL=0 \
|
||||
WIREGUARD_ADDRESSES= \
|
||||
WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
|
||||
WIREGUARD_MTU= \
|
||||
WIREGUARD_MTU=1320 \
|
||||
WIREGUARD_IMPLEMENTATION=auto \
|
||||
# PMTUD
|
||||
PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8 \
|
||||
PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443,1.1.1.1:53,8.8.8.8:53,[2606:4700:4700::1111]:53,[2001:4860:4860::8888]:53,[2606:4700:4700::1111]:443,[2001:4860:4860::8888]:443 \
|
||||
# VPN server filtering
|
||||
SERVER_REGIONS= \
|
||||
SERVER_COUNTRIES= \
|
||||
|
||||
@@ -58,7 +58,7 @@ Lightweight swiss-army-knife-like VPN client to multiple VPN service providers
|
||||
## Features
|
||||
|
||||
- Based on Alpine 3.22 for a small Docker image of 41.1MB
|
||||
- Supports: **AirVPN**, **Cyberghost**, **ExpressVPN**, **FastestVPN**, **Giganews**, **HideMyAss**, **IPVanish**, **IVPN**, **Mullvad** (Wireguard only), **NordVPN**, **Perfect Privacy**, **Privado**, **Private Internet Access**, **PrivateVPN**, **ProtonVPN**, **PureVPN**, **SlickVPN**, **Surfshark**, **TorGuard**, **VPNSecure.me**, **VPNUnlimited**, **Vyprvpn**, **Windscribe** servers
|
||||
- Supports: **AirVPN**, **Cyberghost**, **ExpressVPN**, **FastestVPN**, **Giganews**, **HideMyAss**, **IPVanish**, **IVPN**, **Mullvad**, **NordVPN**, **Perfect Privacy**, **Privado**, **Private Internet Access**, **PrivateVPN**, **ProtonVPN**, **PureVPN**, **SlickVPN**, **Surfshark**, **TorGuard**, **VPNSecure.me**, **VPNUnlimited**, **Vyprvpn**, **Windscribe** servers
|
||||
- Supports OpenVPN for all providers listed
|
||||
- Supports Wireguard both kernelspace and userspace
|
||||
- For **AirVPN**, **FastestVPN**, **Ivpn**, **Mullvad**, **NordVPN**, **Perfect privacy**, **ProtonVPN**, **Surfshark** and **Windscribe**
|
||||
|
||||
+16
-19
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
@@ -168,7 +167,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
|
||||
defer fmt.Println(gluetunLogo)
|
||||
|
||||
announcementExp, err := time.Parse(time.RFC3339, "2026-04-01T00:00:00Z")
|
||||
announcementExp, err := time.Parse(time.RFC3339, "2024-12-01T00:00:00Z")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -179,7 +178,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
Version: buildInfo.Version,
|
||||
Commit: buildInfo.Commit,
|
||||
Created: buildInfo.Created,
|
||||
Announcement: "All control server routes are now private by default",
|
||||
Announcement: "All control server routes will become private by default after the v3.41.0 release",
|
||||
AnnounceExp: announcementExp,
|
||||
// Sponsor information
|
||||
PaypalUser: "qmcgaw",
|
||||
@@ -227,7 +226,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
firewallLogger.Patch(log.SetLevel(log.LevelDebug))
|
||||
}
|
||||
firewallConf, err := firewall.NewConfig(ctx, firewallLogger, cmder,
|
||||
netLinker, defaultRoutes, localNetworks)
|
||||
defaultRoutes, localNetworks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -264,7 +263,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
|
||||
puid, pgid := int(*allSettings.System.PUID), int(*allSettings.System.PGID)
|
||||
|
||||
const clientTimeout = 35 * time.Second
|
||||
const clientTimeout = 15 * time.Second
|
||||
httpClient := &http.Client{Timeout: clientTimeout}
|
||||
// Create configurators
|
||||
alpineConf := alpine.New()
|
||||
@@ -279,7 +278,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
err = printVersions(ctx, logger, []printVersionElement{
|
||||
{name: "Alpine", getVersion: alpineConf.Version},
|
||||
{name: "OpenVPN", getVersion: ovpnVersion},
|
||||
{name: "Firewall", getVersion: firewallConf.Version},
|
||||
{name: "IPtables", getVersion: firewallConf.Version},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -554,27 +553,26 @@ type netLinker interface {
|
||||
Router
|
||||
Ruler
|
||||
Linker
|
||||
IsWireguardSupported() (ok bool, err error)
|
||||
IsWireguardSupported() bool
|
||||
IsIPv6Supported() (ok bool, err error)
|
||||
FlushConntrack() error
|
||||
PatchLoggerLevel(level log.Level)
|
||||
}
|
||||
|
||||
type Addresser interface {
|
||||
AddrList(linkIndex uint32, family uint8) (
|
||||
addresses []netip.Prefix, err error)
|
||||
AddrReplace(linkIndex uint32, addr netip.Prefix) error
|
||||
AddrList(link netlink.Link, family int) (
|
||||
addresses []netlink.Addr, err error)
|
||||
AddrReplace(link netlink.Link, addr netlink.Addr) error
|
||||
}
|
||||
|
||||
type Router interface {
|
||||
RouteList(family uint8) (routes []netlink.Route, err error)
|
||||
RouteList(family int) (routes []netlink.Route, err error)
|
||||
RouteAdd(route netlink.Route) error
|
||||
RouteDel(route netlink.Route) error
|
||||
RouteReplace(route netlink.Route) error
|
||||
}
|
||||
|
||||
type Ruler interface {
|
||||
RuleList(family uint8) (rules []netlink.Rule, err error)
|
||||
RuleList(family int) (rules []netlink.Rule, err error)
|
||||
RuleAdd(rule netlink.Rule) error
|
||||
RuleDel(rule netlink.Rule) error
|
||||
}
|
||||
@@ -582,12 +580,11 @@ type Ruler interface {
|
||||
type Linker interface {
|
||||
LinkList() (links []netlink.Link, err error)
|
||||
LinkByName(name string) (link netlink.Link, err error)
|
||||
LinkByIndex(index uint32) (link netlink.Link, err error)
|
||||
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
|
||||
LinkDel(linkIndex uint32) (err error)
|
||||
LinkSetUp(linkIndex uint32) (err error)
|
||||
LinkSetDown(linkIndex uint32) (err error)
|
||||
LinkSetMTU(linkIndex, mtu uint32) error
|
||||
LinkByIndex(index int) (link netlink.Link, err error)
|
||||
LinkAdd(link netlink.Link) (linkIndex int, err error)
|
||||
LinkDel(link netlink.Link) (err error)
|
||||
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
||||
LinkSetDown(link netlink.Link) (err error)
|
||||
}
|
||||
|
||||
type clier interface {
|
||||
|
||||
@@ -7,13 +7,10 @@ require (
|
||||
github.com/breml/rootcerts v0.3.3
|
||||
github.com/fatih/color v1.18.0
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/jsimonetti/rtnetlink v1.4.2
|
||||
github.com/klauspost/compress v1.18.1
|
||||
github.com/klauspost/pgzip v1.2.6
|
||||
github.com/mdlayher/genetlink v1.3.2
|
||||
github.com/mdlayher/netlink v1.7.2
|
||||
github.com/pelletier/go-toml/v2 v2.2.4
|
||||
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260216151239-36b3306f2205
|
||||
github.com/qdm12/dns/v2 v2.0.0-rc10
|
||||
github.com/qdm12/gosettings v0.4.4
|
||||
github.com/qdm12/goshutdown v0.3.0
|
||||
github.com/qdm12/gosplash v0.2.0
|
||||
@@ -21,13 +18,13 @@ require (
|
||||
github.com/qdm12/log v0.1.0
|
||||
github.com/qdm12/ss-server v0.6.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/ti-mo/netfilter v0.5.3
|
||||
github.com/ulikunitz/xz v0.5.15
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a
|
||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
|
||||
golang.org/x/net v0.49.0
|
||||
golang.org/x/sys v0.40.0
|
||||
golang.org/x/text v0.33.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/sys v0.38.0
|
||||
golang.org/x/text v0.31.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||
gopkg.in/ini.v1 v1.67.0
|
||||
@@ -41,11 +38,13 @@ require (
|
||||
github.com/cloudflare/circl v1.6.1 // indirect
|
||||
github.com/cronokirby/saferith v0.33.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/josharian/native v1.1.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mdlayher/socket v0.5.1 // indirect
|
||||
github.com/mdlayher/genetlink v1.3.2 // indirect
|
||||
github.com/mdlayher/netlink v1.7.2 // indirect
|
||||
github.com/mdlayher/socket v0.4.1 // indirect
|
||||
github.com/miekg/dns v1.1.62 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
@@ -56,11 +55,12 @@ require (
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 // indirect
|
||||
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
|
||||
golang.org/x/crypto v0.47.0 // indirect
|
||||
golang.org/x/mod v0.31.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
golang.org/x/mod v0.29.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/time v0.3.0 // indirect
|
||||
golang.org/x/tools v0.40.0 // indirect
|
||||
golang.org/x/tools v0.38.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/protobuf v1.35.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
||||
@@ -13,8 +13,6 @@ github.com/breml/rootcerts v0.3.3/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXB
|
||||
github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cilium/ebpf v0.12.3 h1:8ht6F9MquybnY97at+VDZb3eQQr8ev79RueWeVaEcG4=
|
||||
github.com/cilium/ebpf v0.12.3/go.mod h1:TctK1ivibvI3znr66ljgi4hqOT8EYQjz1KWBfb1UVgM=
|
||||
github.com/cloudflare/circl v1.1.0/go.mod h1:prBCrKB9DV4poKZY1l9zBXg2QJY7mvgRvtMxxK7fi4I=
|
||||
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
|
||||
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
|
||||
@@ -28,12 +26,10 @@ github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
||||
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
|
||||
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
|
||||
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
|
||||
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
github.com/jsimonetti/rtnetlink v1.4.2 h1:Df9w9TZ3npHTyDn0Ev9e1uzmN2odmXd0QX+J5GTEn90=
|
||||
github.com/jsimonetti/rtnetlink v1.4.2/go.mod h1:92s6LJdE+1iOrw+F2/RO7LYI2Qd8pPpFNNUYW06gcoM=
|
||||
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
|
||||
github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0=
|
||||
github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU=
|
||||
@@ -51,8 +47,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5
|
||||
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
|
||||
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
||||
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
||||
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
||||
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
|
||||
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
|
||||
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
|
||||
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
|
||||
@@ -73,8 +69,8 @@ github.com/prometheus/common v0.60.1 h1:FUas6GcOw66yB/73KC+BOZoFJmbo/1pojoILArPA
|
||||
github.com/prometheus/common v0.60.1/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260216151239-36b3306f2205 h1:0ycKUDQ50cYb2QpeyGcEnvVs9HJmC9jsb/XZNC1z28c=
|
||||
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260216151239-36b3306f2205/go.mod h1:98foWgXJZ+g8gJIuO+fdO+oWpFei5WShMFTeN4Im2lE=
|
||||
github.com/qdm12/dns/v2 v2.0.0-rc10 h1:IyeNEYXfhBsaE1dwxx5eAqdAz1HS98dT+8c7xoKODa0=
|
||||
github.com/qdm12/dns/v2 v2.0.0-rc10/go.mod h1:98foWgXJZ+g8gJIuO+fdO+oWpFei5WShMFTeN4Im2lE=
|
||||
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 h1:TRGpCU1l0lNwtogEUSs5U+RFceYxkAJUmrGabno7J5c=
|
||||
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978/go.mod h1:D1Po4CRQLYjccnAR2JsVlN1sBMgQrcNLONbvyuzcdTg=
|
||||
github.com/qdm12/gosettings v0.4.4 h1:SM6tOZDf6k8qbjWU8KWyBF4mWIixfsKCfh9DGRLHlj4=
|
||||
@@ -95,10 +91,12 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR
|
||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/ti-mo/netfilter v0.5.3 h1:ikzduvnaUMwre5bhbNwWOd6bjqLMVb33vv0XXbK0xGQ=
|
||||
github.com/ti-mo/netfilter v0.5.3/go.mod h1:08SyBCg6hu1qyQk4s3DjjJKNrm3RTb32nm6AzyT972E=
|
||||
github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
|
||||
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
|
||||
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
|
||||
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
||||
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
||||
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk=
|
||||
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
@@ -108,15 +106,15 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
|
||||
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
|
||||
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
|
||||
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
|
||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
||||
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
||||
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
||||
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
@@ -124,14 +122,14 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
|
||||
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
|
||||
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -142,10 +140,12 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
@@ -155,8 +155,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
||||
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
|
||||
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -164,8 +164,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
|
||||
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
||||
@@ -23,9 +23,7 @@ type DNSBlacklist struct {
|
||||
AddBlockedIPs []netip.Addr
|
||||
AddBlockedIPPrefixes []netip.Prefix
|
||||
// RebindingProtectionExemptHostnames is a list of hostnames
|
||||
// exempt from DNS rebinding protection. It can contain parent
|
||||
// domains which are of the form "*.example.com". Note the wildcard
|
||||
// can only be used at the start of the hostname.
|
||||
// exempt from DNS rebinding protection.
|
||||
RebindingProtectionExemptHostnames []string
|
||||
}
|
||||
|
||||
@@ -57,9 +55,6 @@ func (b DNSBlacklist) validate() (err error) {
|
||||
}
|
||||
|
||||
for _, host := range b.RebindingProtectionExemptHostnames {
|
||||
if len(host) > 2 && host[:2] == "*." {
|
||||
host = host[2:]
|
||||
}
|
||||
if !hostRegex.MatchString(host) {
|
||||
return fmt.Errorf("%w: %s", ErrRebindingProtectionExemptHostNotValid, host)
|
||||
}
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gosettings"
|
||||
"github.com/qdm12/gosettings/reader"
|
||||
"github.com/qdm12/gotree"
|
||||
)
|
||||
|
||||
// PMTUD contains settings to configure Path MTU Discovery.
|
||||
type PMTUD struct {
|
||||
// ICMPAddresses is the redundancy list of addresses to use
|
||||
// for ICMP path MTU discovery. Each address MUST handle ICMP
|
||||
// packets for PMTUD to work.
|
||||
// It cannot be nil in the internal state.
|
||||
ICMPAddresses []netip.Addr `json:"icmp_addresses"`
|
||||
// TCPAddresses is the redundancy list of addresses to use
|
||||
// for TCP path MTU discovery. Each address MUST have a listening
|
||||
// TCP server on the port specified.
|
||||
// It cannot be nil in the internal state.
|
||||
TCPAddresses []netip.AddrPort `json:"tcp_addresses"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrPMTUDICMPAddressNotValid = errors.New("PMTUD ICMP address is not valid")
|
||||
ErrPMTUDTCPAddressNotValid = errors.New("PMTUD TCP address is not valid")
|
||||
)
|
||||
|
||||
// Validate validates PMTUD settings.
|
||||
func (p PMTUD) validate() (err error) {
|
||||
for i, addr := range p.ICMPAddresses {
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("%w: at index %d", ErrPMTUDICMPAddressNotValid, i)
|
||||
}
|
||||
}
|
||||
for i, addr := range p.TCPAddresses {
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("%w: at index %d", ErrPMTUDTCPAddressNotValid, i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PMTUD) copy() (copied PMTUD) {
|
||||
return PMTUD{
|
||||
ICMPAddresses: gosettings.CopySlice(p.ICMPAddresses),
|
||||
TCPAddresses: gosettings.CopySlice(p.TCPAddresses),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PMTUD) overrideWith(other PMTUD) {
|
||||
p.ICMPAddresses = gosettings.OverrideWithSlice(p.ICMPAddresses, other.ICMPAddresses)
|
||||
p.TCPAddresses = gosettings.OverrideWithSlice(p.TCPAddresses, other.TCPAddresses)
|
||||
}
|
||||
|
||||
func (p *PMTUD) setDefaults() {
|
||||
defaultICMPAddresses := []netip.Addr{
|
||||
netip.AddrFrom4([4]byte{1, 1, 1, 1}),
|
||||
netip.AddrFrom4([4]byte{8, 8, 8, 8}),
|
||||
}
|
||||
p.ICMPAddresses = gosettings.DefaultSlice(p.ICMPAddresses, defaultICMPAddresses)
|
||||
|
||||
const dnsPort, tlsPort = 53, 443
|
||||
defaultTCPAddresses := []netip.AddrPort{
|
||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), dnsPort),
|
||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), dnsPort),
|
||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), tlsPort),
|
||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), tlsPort),
|
||||
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), dnsPort),
|
||||
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), dnsPort),
|
||||
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), tlsPort),
|
||||
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), tlsPort),
|
||||
}
|
||||
p.TCPAddresses = gosettings.DefaultSlice(p.TCPAddresses, defaultTCPAddresses)
|
||||
}
|
||||
|
||||
func (p PMTUD) String() string {
|
||||
return p.toLinesNode().String()
|
||||
}
|
||||
|
||||
func (p PMTUD) toLinesNode() (node *gotree.Node) {
|
||||
node = gotree.New("Path MTU discovery:")
|
||||
|
||||
icmpAddrNode := node.Append("ICMP addresses:")
|
||||
for _, addr := range p.ICMPAddresses {
|
||||
icmpAddrNode.Append(addr.String())
|
||||
}
|
||||
|
||||
tcpAddrNode := node.Append("TCP addresses:")
|
||||
for _, addr := range p.TCPAddresses {
|
||||
tcpAddrNode.Append(addr.String())
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
func (p *PMTUD) read(r *reader.Reader) (err error) {
|
||||
p.ICMPAddresses, err = r.CSVNetipAddresses("PMTUD_ICMP_ADDRESSES")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.TCPAddresses, err = r.CSVNetipAddrPorts("PMTUD_TCP_ADDRESSES")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -2,8 +2,6 @@ package settings
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
@@ -33,11 +31,6 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
|
||||
if vpnType == vpn.OpenVPN {
|
||||
validNames = providers.AllWithCustom()
|
||||
validNames = append(validNames, "pia") // Retro-compatibility
|
||||
// Remove Mullvad since it no longer supports OpenVPN as of January 15th, 2026
|
||||
mullvadIndex := slices.Index(validNames, providers.Mullvad)
|
||||
validNames[mullvadIndex], validNames[len(validNames)-1] = validNames[len(validNames)-1], validNames[mullvadIndex]
|
||||
validNames = validNames[:len(validNames)-1]
|
||||
sort.Strings(validNames)
|
||||
} else { // Wireguard
|
||||
validNames = []string{
|
||||
providers.Airvpn,
|
||||
|
||||
@@ -29,27 +29,14 @@ func Test_Settings_String(t *testing.T) {
|
||||
| | └── OpenVPN server selection settings:
|
||||
| | ├── Protocol: UDP
|
||||
| | └── Private Internet Access encryption preset: strong
|
||||
| ├── OpenVPN settings:
|
||||
| | ├── OpenVPN version: 2.6
|
||||
| | ├── User: [not set]
|
||||
| | ├── Password: [not set]
|
||||
| | ├── Private Internet Access encryption preset: strong
|
||||
| | ├── Network interface: tun0
|
||||
| | ├── Run OpenVPN as: root
|
||||
| | └── Verbosity level: 1
|
||||
| └── Path MTU discovery:
|
||||
| ├── ICMP addresses:
|
||||
| | ├── 1.1.1.1
|
||||
| | └── 8.8.8.8
|
||||
| └── TCP addresses:
|
||||
| ├── 1.1.1.1:53
|
||||
| ├── 8.8.8.8:53
|
||||
| ├── 1.1.1.1:443
|
||||
| ├── 8.8.8.8:443
|
||||
| ├── [2606:4700:4700::1111]:53
|
||||
| ├── [2001:4860:4860::8888]:53
|
||||
| ├── [2606:4700:4700::1111]:443
|
||||
| └── [2001:4860:4860::8888]:443
|
||||
| └── OpenVPN settings:
|
||||
| ├── OpenVPN version: 2.6
|
||||
| ├── User: [not set]
|
||||
| ├── Password: [not set]
|
||||
| ├── Private Internet Access encryption preset: strong
|
||||
| ├── Network interface: tun0
|
||||
| ├── Run OpenVPN as: root
|
||||
| └── Verbosity level: 1
|
||||
├── DNS settings:
|
||||
| ├── Keep existing nameserver(s): no
|
||||
| ├── DNS server address to use: 127.0.0.1
|
||||
|
||||
@@ -18,7 +18,6 @@ type VPN struct {
|
||||
Provider Provider `json:"provider"`
|
||||
OpenVPN OpenVPN `json:"openvpn"`
|
||||
Wireguard Wireguard `json:"wireguard"`
|
||||
PMTUD PMTUD `json:"pmtud"`
|
||||
}
|
||||
|
||||
// TODO v4 remove pointer for receiver (because of Surfshark).
|
||||
@@ -46,11 +45,6 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo
|
||||
}
|
||||
}
|
||||
|
||||
err = v.PMTUD.validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("PMTUD settings: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -60,7 +54,6 @@ func (v *VPN) Copy() (copied VPN) {
|
||||
Provider: v.Provider.copy(),
|
||||
OpenVPN: v.OpenVPN.copy(),
|
||||
Wireguard: v.Wireguard.copy(),
|
||||
PMTUD: v.PMTUD.copy(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,7 +62,6 @@ func (v *VPN) OverrideWith(other VPN) {
|
||||
v.Provider.overrideWith(other.Provider)
|
||||
v.OpenVPN.overrideWith(other.OpenVPN)
|
||||
v.Wireguard.overrideWith(other.Wireguard)
|
||||
v.PMTUD.overrideWith(other.PMTUD)
|
||||
}
|
||||
|
||||
func (v *VPN) setDefaults() {
|
||||
@@ -77,7 +69,6 @@ func (v *VPN) setDefaults() {
|
||||
v.Provider.setDefaults()
|
||||
v.OpenVPN.setDefaults(v.Provider.Name)
|
||||
v.Wireguard.setDefaults(v.Provider.Name)
|
||||
v.PMTUD.setDefaults()
|
||||
}
|
||||
|
||||
func (v VPN) String() string {
|
||||
@@ -94,7 +85,6 @@ func (v VPN) toLinesNode() (node *gotree.Node) {
|
||||
} else {
|
||||
node.AppendNode(v.Wireguard.toLinesNode())
|
||||
}
|
||||
node.AppendNode(v.PMTUD.toLinesNode())
|
||||
|
||||
return node
|
||||
}
|
||||
@@ -117,10 +107,5 @@ func (v *VPN) read(r *reader.Reader) (err error) {
|
||||
return fmt.Errorf("wireguard: %w", err)
|
||||
}
|
||||
|
||||
err = v.PMTUD.read(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("PMTUD: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -38,9 +38,14 @@ type Wireguard struct {
|
||||
Interface string `json:"interface"`
|
||||
PersistentKeepaliveInterval *time.Duration `json:"persistent_keep_alive_interval"`
|
||||
// Maximum Transmission Unit (MTU) of the Wireguard interface.
|
||||
// It cannot be nil in the internal state, and defaults to
|
||||
// 0 indicating to use PMTUD.
|
||||
MTU *uint32 `json:"mtu"`
|
||||
// It cannot be zero in the internal state, and defaults to
|
||||
// 1320. Note it is not the wireguard-go MTU default of 1420
|
||||
// because this impacts bandwidth a lot on some VPN providers,
|
||||
// see https://github.com/qdm12/gluetun/issues/1650.
|
||||
// It has been lowered to 1320 following quite a bit of
|
||||
// investigation in the issue:
|
||||
// https://github.com/qdm12/gluetun/issues/2533.
|
||||
MTU uint16 `json:"mtu"`
|
||||
// Implementation is the Wireguard implementation to use.
|
||||
// It can be "auto", "userspace" or "kernelspace".
|
||||
// It defaults to "auto" and cannot be the empty string
|
||||
@@ -189,7 +194,8 @@ func (w *Wireguard) setDefaults(vpnProvider string) {
|
||||
w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs)
|
||||
w.PersistentKeepaliveInterval = gosettings.DefaultPointer(w.PersistentKeepaliveInterval, 0)
|
||||
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
|
||||
w.MTU = gosettings.DefaultPointer(w.MTU, 0)
|
||||
const defaultMTU = 1320
|
||||
w.MTU = gosettings.DefaultComparable(w.MTU, defaultMTU)
|
||||
w.Implementation = gosettings.DefaultComparable(w.Implementation, "auto")
|
||||
}
|
||||
|
||||
@@ -225,11 +231,7 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
|
||||
}
|
||||
|
||||
interfaceNode := node.Appendf("Network interface: %s", w.Interface)
|
||||
if *w.MTU == 0 {
|
||||
interfaceNode.Append("MTU: use path MTU discovery")
|
||||
} else {
|
||||
interfaceNode.Appendf("MTU: %d", *w.MTU)
|
||||
}
|
||||
interfaceNode.Appendf("MTU: %d", w.MTU)
|
||||
|
||||
if w.Implementation != "auto" {
|
||||
node.Appendf("Implementation: %s", w.Implementation)
|
||||
@@ -270,9 +272,11 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
w.MTU, err = r.Uint32Ptr("WIREGUARD_MTU")
|
||||
mtuPtr, err := r.Uint16Ptr("WIREGUARD_MTU")
|
||||
if err != nil {
|
||||
return err
|
||||
} else if mtuPtr != nil {
|
||||
w.MTU = *mtuPtr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
+7
-6
@@ -2,6 +2,7 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/qdm12/dns/v2/pkg/nameserver"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
@@ -43,12 +44,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
runError, err = l.setupServer(ctx)
|
||||
if err == nil {
|
||||
l.backoffTime = defaultBackoffTime
|
||||
l.logger.Info("ready and using DNS server at address " + settings.ServerAddress.String())
|
||||
|
||||
err = l.updateFiles(ctx, settings)
|
||||
if err != nil {
|
||||
l.logger.Warn("downloading block lists failed, skipping: " + err.Error())
|
||||
}
|
||||
l.logger.Info("ready")
|
||||
break
|
||||
}
|
||||
|
||||
@@ -57,6 +53,11 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !errors.Is(err, errUpdateBlockLists) {
|
||||
const fallback = true
|
||||
l.useUnencryptedDNS(fallback)
|
||||
}
|
||||
l.logAndWait(ctx, err)
|
||||
settings = l.GetSettings()
|
||||
}
|
||||
|
||||
@@ -2,24 +2,25 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/dns/v2/pkg/check"
|
||||
"github.com/qdm12/dns/v2/pkg/middlewares/filter/update"
|
||||
"github.com/qdm12/dns/v2/pkg/nameserver"
|
||||
"github.com/qdm12/dns/v2/pkg/server"
|
||||
)
|
||||
|
||||
var errUpdateBlockLists = errors.New("cannot update filter block lists")
|
||||
|
||||
func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err error) {
|
||||
settings := l.GetSettings()
|
||||
var updateSettings update.Settings
|
||||
updateSettings.SetRebindingProtectionExempt(settings.Blacklist.RebindingProtectionExemptHostnames)
|
||||
err = l.filter.Update(updateSettings)
|
||||
err = l.updateFiles(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("updating filter for rebinding protection: %w", err)
|
||||
return nil, fmt.Errorf("%w: %w", errUpdateBlockLists, err)
|
||||
}
|
||||
|
||||
settings := l.GetSettings()
|
||||
|
||||
serverSettings, err := buildServerSettings(settings, l.filter, l.localResolvers, l.logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("building server settings: %w", err)
|
||||
|
||||
+15
-4
@@ -28,12 +28,23 @@ func (l *Loop) RunRestartTicker(ctx context.Context, done chan<- struct{}) {
|
||||
return
|
||||
case <-timer.C:
|
||||
lastTick = l.timeNow()
|
||||
settings := l.GetSettings()
|
||||
if l.GetStatus() == constants.Running {
|
||||
if err := l.updateFiles(ctx, settings); err != nil {
|
||||
l.logger.Warn("updating block lists failed, skipping: " + err.Error())
|
||||
|
||||
status := l.GetStatus()
|
||||
if status == constants.Running {
|
||||
if err := l.updateFiles(ctx); err != nil {
|
||||
l.statusManager.SetStatus(constants.Crashed)
|
||||
l.logger.Error(err.Error())
|
||||
l.logger.Warn("skipping DNS server restart due to failed files update")
|
||||
settings := l.GetSettings()
|
||||
timer.Reset(*settings.UpdatePeriod)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = l.statusManager.ApplyStatus(ctx, constants.Stopped)
|
||||
_, _ = l.statusManager.ApplyStatus(ctx, constants.Running)
|
||||
|
||||
settings := l.GetSettings()
|
||||
timer.Reset(*settings.UpdatePeriod)
|
||||
case <-l.updateTicker:
|
||||
if !timer.Stop() {
|
||||
|
||||
@@ -6,10 +6,11 @@ import (
|
||||
|
||||
"github.com/qdm12/dns/v2/pkg/blockbuilder"
|
||||
"github.com/qdm12/dns/v2/pkg/middlewares/filter/update"
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
)
|
||||
|
||||
func (l *Loop) updateFiles(ctx context.Context, settings settings.DNS) (err error) {
|
||||
func (l *Loop) updateFiles(ctx context.Context) (err error) {
|
||||
settings := l.GetSettings()
|
||||
|
||||
l.logger.Info("downloading hostnames and IP block lists")
|
||||
blacklistSettings := settings.Blacklist.ToBlockBuilderSettings(l.client)
|
||||
|
||||
@@ -36,6 +37,7 @@ func (l *Loop) updateFiles(ctx context.Context, settings settings.DNS) (err erro
|
||||
IPPrefixes: result.BlockedIPPrefixes,
|
||||
}
|
||||
updateSettings.BlockHostnames(result.BlockedHostnames)
|
||||
updateSettings.SetRebindingProtectionExempt(settings.Blacklist.RebindingProtectionExemptHostnames)
|
||||
err = l.filter.Update(updateSettings)
|
||||
if err != nil {
|
||||
return fmt.Errorf("updating filter: %w", err)
|
||||
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -69,8 +69,8 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
"invalid_instruction": {
|
||||
instruction: "invalid",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing iptables command: parsing \"invalid\": " +
|
||||
"iptables command is malformed: flag \"invalid\" requires a value, but got none",
|
||||
errMessage: "parsing iptables command: iptables command is malformed: " +
|
||||
"fields count 1 is not even: \"invalid\"",
|
||||
},
|
||||
"list_error": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
+58
-37
@@ -22,7 +22,9 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
|
||||
|
||||
if !enabled {
|
||||
c.logger.Info("disabling...")
|
||||
c.restore(ctx)
|
||||
if err = c.disable(ctx); err != nil {
|
||||
return fmt.Errorf("disabling firewall: %w", err)
|
||||
}
|
||||
c.enabled = false
|
||||
c.logger.Info("disabled successfully")
|
||||
return nil
|
||||
@@ -39,42 +41,64 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) enable(ctx context.Context) (err error) {
|
||||
c.restore, err = c.impl.SaveAndRestore(ctx)
|
||||
func (c *Config) disable(ctx context.Context) (err error) {
|
||||
if err = c.clearAllRules(ctx); err != nil {
|
||||
return fmt.Errorf("clearing all rules: %w", err)
|
||||
}
|
||||
if err = c.setIPv4AllPolicies(ctx, "ACCEPT"); err != nil {
|
||||
return fmt.Errorf("setting ipv4 policies: %w", err)
|
||||
}
|
||||
if err = c.setIPv6AllPolicies(ctx, "ACCEPT"); err != nil {
|
||||
return fmt.Errorf("setting ipv6 policies: %w", err)
|
||||
}
|
||||
|
||||
const remove = true
|
||||
err = c.redirectPorts(ctx, remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("saving firewall rules: %w", err)
|
||||
return fmt.Errorf("removing port redirections: %w", err)
|
||||
}
|
||||
|
||||
if err = c.impl.SetIPv4AllPolicies(ctx, "DROP"); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// To use in defered call when enabling the firewall.
|
||||
func (c *Config) fallbackToDisabled(ctx context.Context) {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if err := c.disable(ctx); err != nil {
|
||||
c.logger.Error("failed reversing firewall changes: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) enable(ctx context.Context) (err error) {
|
||||
touched := false
|
||||
if err = c.setIPv4AllPolicies(ctx, "DROP"); err != nil {
|
||||
return err
|
||||
}
|
||||
touched = true
|
||||
|
||||
if err = c.impl.SetIPv6AllPolicies(ctx, "DROP"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
c.restore(context.Background())
|
||||
}
|
||||
}()
|
||||
|
||||
// Loopback traffic
|
||||
if err = c.impl.AcceptInputThroughInterface(ctx, "lo"); err != nil {
|
||||
if err = c.setIPv6AllPolicies(ctx, "DROP"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
const remove = false
|
||||
if err = c.impl.AcceptOutputThroughInterface(ctx, "lo", remove); err != nil {
|
||||
|
||||
defer func() {
|
||||
if touched && err != nil {
|
||||
c.fallbackToDisabled(ctx)
|
||||
}
|
||||
}()
|
||||
|
||||
// Loopback traffic
|
||||
if err = c.acceptInputThroughInterface(ctx, "lo", remove); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = c.acceptOutputThroughInterface(ctx, "lo", remove); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.flushExistingConnections(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("flushing existing connections: %w", err)
|
||||
}
|
||||
|
||||
if err = c.impl.AcceptEstablishedRelatedTraffic(ctx); err != nil {
|
||||
if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -84,9 +108,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
||||
|
||||
localInterfaces := make(map[string]struct{}, len(c.localNetworks))
|
||||
for _, network := range c.localNetworks {
|
||||
err = c.impl.AcceptOutputFromIPToSubnet(ctx,
|
||||
network.InterfaceName, network.IP, network.IPNet, remove)
|
||||
if err != nil {
|
||||
if err := c.acceptOutputFromIPToSubnet(ctx, network.InterfaceName, network.IP, network.IPNet, remove); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -95,7 +117,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
||||
continue
|
||||
}
|
||||
localInterfaces[network.InterfaceName] = struct{}{}
|
||||
err = c.impl.AcceptIpv6MulticastOutput(ctx, network.InterfaceName)
|
||||
err = c.acceptIpv6MulticastOutput(ctx, network.InterfaceName, remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("accepting IPv6 multicast output: %w", err)
|
||||
}
|
||||
@@ -108,7 +130,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
||||
// Allows packets from any IP address to go through eth0 / local network
|
||||
// to reach Gluetun.
|
||||
for _, network := range c.localNetworks {
|
||||
if err := c.impl.AcceptInputToSubnet(ctx, network.InterfaceName, network.IPNet); err != nil {
|
||||
if err := c.acceptInputToSubnet(ctx, network.InterfaceName, network.IPNet, remove); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -117,12 +139,12 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.redirectPorts(ctx)
|
||||
err = c.redirectPorts(ctx, remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("redirecting ports: %w", err)
|
||||
}
|
||||
|
||||
if err := c.impl.RunUserPostRules(ctx, c.customRulesPath); err != nil {
|
||||
if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil {
|
||||
return fmt.Errorf("running user defined post firewall rules: %w", err)
|
||||
}
|
||||
|
||||
@@ -142,7 +164,7 @@ func (c *Config) allowVPNIP(ctx context.Context) (err error) {
|
||||
continue
|
||||
}
|
||||
interfacesSeen[defaultRoute.NetInterface] = struct{}{}
|
||||
err = c.impl.AcceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove)
|
||||
err = c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("accepting output traffic through VPN: %w", err)
|
||||
}
|
||||
@@ -164,7 +186,7 @@ func (c *Config) allowOutboundSubnets(ctx context.Context) (err error) {
|
||||
firewallUpdated = true
|
||||
|
||||
const remove = false
|
||||
err := c.impl.AcceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
|
||||
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
|
||||
defaultRoute.AssignedIP, subnet, remove)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -182,7 +204,7 @@ func (c *Config) allowInputPorts(ctx context.Context) (err error) {
|
||||
for port, netInterfaces := range c.allowedInputPorts {
|
||||
for netInterface := range netInterfaces {
|
||||
const remove = false
|
||||
err = c.impl.AcceptInputToPort(ctx, netInterface, port, remove)
|
||||
err = c.acceptInputToPort(ctx, netInterface, port, remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("accepting input port %d on interface %s: %w",
|
||||
port, netInterface, err)
|
||||
@@ -192,10 +214,9 @@ func (c *Config) allowInputPorts(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) redirectPorts(ctx context.Context) (err error) {
|
||||
func (c *Config) redirectPorts(ctx context.Context, remove bool) (err error) {
|
||||
for _, portRedirection := range c.portRedirections {
|
||||
const remove = false
|
||||
err = c.impl.RedirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort,
|
||||
err = c.redirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort,
|
||||
portRedirection.destinationPort, remove)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -2,29 +2,28 @@ package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall/iptables"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/routing"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
runner CmdRunner
|
||||
netlinker Netlinker
|
||||
logger Logger
|
||||
defaultRoutes []routing.DefaultRoute
|
||||
localNetworks []routing.LocalNetwork
|
||||
runner CmdRunner
|
||||
logger Logger
|
||||
iptablesMutex sync.Mutex
|
||||
ip6tablesMutex sync.Mutex
|
||||
defaultRoutes []routing.DefaultRoute
|
||||
localNetworks []routing.LocalNetwork
|
||||
|
||||
// Fixed
|
||||
impl firewallImpl
|
||||
// Fixed state
|
||||
ipTables string
|
||||
ip6Tables string
|
||||
customRulesPath string
|
||||
|
||||
// State
|
||||
enabled bool
|
||||
restore func(context.Context)
|
||||
vpnConnection models.Connection
|
||||
vpnIntf string
|
||||
outboundSubnets []netip.Prefix
|
||||
@@ -36,23 +35,28 @@ type Config struct {
|
||||
// NewConfig creates a new Config instance and returns an error
|
||||
// if no iptables implementation is available.
|
||||
func NewConfig(ctx context.Context, logger Logger,
|
||||
runner CmdRunner, netlinker Netlinker,
|
||||
defaultRoutes []routing.DefaultRoute, localNetworks []routing.LocalNetwork,
|
||||
runner CmdRunner, defaultRoutes []routing.DefaultRoute,
|
||||
localNetworks []routing.LocalNetwork,
|
||||
) (config *Config, err error) {
|
||||
impl, err := iptables.New(ctx, runner, logger)
|
||||
iptables, err := checkIptablesSupport(ctx, runner, "iptables", "iptables-nft", "iptables-legacy")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating iptables firewall: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ip6tables, err := findIP6tablesSupported(ctx, runner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Config{
|
||||
runner: runner,
|
||||
netlinker: netlinker,
|
||||
logger: logger,
|
||||
allowedInputPorts: make(map[uint16]map[string]struct{}),
|
||||
ipTables: iptables,
|
||||
ip6Tables: ip6tables,
|
||||
customRulesPath: "/iptables/post-rules.txt",
|
||||
// Obtained from routing
|
||||
defaultRoutes: defaultRoutes,
|
||||
localNetworks: localNetworks,
|
||||
impl: impl,
|
||||
customRulesPath: "/iptables/post-rules.txt",
|
||||
defaultRoutes: defaultRoutes,
|
||||
localNetworks: localNetworks,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall/iptables"
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
func (c *Config) flushExistingConnections(ctx context.Context) error {
|
||||
tries := []struct {
|
||||
name string
|
||||
f func(ctx context.Context) error
|
||||
}{
|
||||
{name: "flushing conntrack", f: func(_ context.Context) error {
|
||||
return c.netlinker.FlushConntrack()
|
||||
}},
|
||||
{name: "marking and filtering unmarked packets", f: c.impl.AcceptOutputPublicOnlyNewTraffic},
|
||||
{name: "rejecting connections for one second", f: c.rejectOutputTrafficTemporarily},
|
||||
{name: "dropping connections for one second", f: c.dropOutputTrafficTemporarily},
|
||||
}
|
||||
errs := make([]error, 0, len(tries))
|
||||
for i, try := range tries {
|
||||
if i > 0 {
|
||||
c.logger.Debugf("falling back to %s because %s failed: %s", try.name, tries[i-1].name, errs[i-1])
|
||||
}
|
||||
err := try.f(ctx)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
err = fmt.Errorf("%s: %w", try.name, err)
|
||||
if !errors.Is(err, iptables.ErrKernelModuleMissing) && !errors.Is(err, netlink.ErrConntrackNetlinkNotSupported) {
|
||||
return err
|
||||
}
|
||||
errs = append(errs, err)
|
||||
}
|
||||
return fmt.Errorf("all tries failed: %v", errs) //nolint:err113
|
||||
}
|
||||
|
||||
func (c *Config) rejectOutputTrafficTemporarily(ctx context.Context) error {
|
||||
return setupThenRevert(ctx, c.impl.RejectOutputPublicTraffic)
|
||||
}
|
||||
|
||||
func (c *Config) dropOutputTrafficTemporarily(ctx context.Context) error {
|
||||
return setupThenRevert(ctx, c.impl.DropOutputPublicTraffic)
|
||||
}
|
||||
|
||||
// setupThenRevert is a helper function to run a setup function that takes a remove boolean argument,
|
||||
// and then run the same function with remove set to true after one second or when the context is canceled,
|
||||
// whichever comes first.
|
||||
func setupThenRevert(ctx context.Context, f func(ctx context.Context, remove bool) error) error {
|
||||
remove := false
|
||||
err := f(ctx, remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up: %w", err)
|
||||
}
|
||||
timer := time.NewTimer(time.Second)
|
||||
select {
|
||||
case <-timer.C:
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
}
|
||||
remove = true
|
||||
// Use [context.Background] to make sure this is removed, even if the context
|
||||
// passed to this function is canceled.
|
||||
err = f(context.Background(), remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reverting: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,12 +1,6 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
import "os/exec"
|
||||
|
||||
type CmdRunner interface {
|
||||
Run(cmd *exec.Cmd) (output string, err error)
|
||||
@@ -14,37 +8,7 @@ type CmdRunner interface {
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Debugf(format string, args ...any)
|
||||
Info(s string)
|
||||
Warn(s string)
|
||||
Error(s string)
|
||||
}
|
||||
|
||||
type Netlinker interface {
|
||||
FlushConntrack() error
|
||||
}
|
||||
|
||||
type firewallImpl interface { //nolint:interfacebloat
|
||||
SaveAndRestore(ctx context.Context) (restore func(context.Context), err error)
|
||||
AcceptOutputPublicOnlyNewTraffic(ctx context.Context) error
|
||||
RejectOutputPublicTraffic(ctx context.Context, remove bool) error
|
||||
DropOutputPublicTraffic(ctx context.Context, remove bool) error
|
||||
AcceptInputThroughInterface(ctx context.Context, intf string) error
|
||||
AcceptEstablishedRelatedTraffic(ctx context.Context) error
|
||||
AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error
|
||||
AcceptInputToSubnet(ctx context.Context, intf string, subnet netip.Prefix) error
|
||||
AcceptIpv6MulticastOutput(ctx context.Context, intf string) error
|
||||
AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr,
|
||||
subnet netip.Prefix, remove bool) error
|
||||
AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error
|
||||
AcceptOutputTrafficToVPN(ctx context.Context, intf string,
|
||||
connection models.Connection, remove bool) error
|
||||
RedirectPort(ctx context.Context, intf string, sourcePort,
|
||||
destinationPort uint16, remove bool) error
|
||||
RunUserPostRules(ctx context.Context, customRulesPath string) error
|
||||
SetIPv4AllPolicies(ctx context.Context, policy string) error
|
||||
SetIPv6AllPolicies(ctx context.Context, policy string) error
|
||||
TempDropOutputTCPRST(ctx context.Context, src, dst netip.AddrPort, excludeMark int) (
|
||||
revert func(ctx context.Context) error, err error)
|
||||
Version(ctx context.Context) (version string, err error)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -14,8 +14,8 @@ import (
|
||||
func findIP6tablesSupported(ctx context.Context, runner CmdRunner) (
|
||||
ip6tablesPath string, err error,
|
||||
) {
|
||||
ip6tablesPath, err = checkIptablesSupport(ctx, runner, "ip6tables", "ip6tables-legacy")
|
||||
if errors.Is(err, ErrNotSupported) {
|
||||
ip6tablesPath, err = checkIptablesSupport(ctx, runner, "ip6tables", "ip6tables-nft", "ip6tables-legacy")
|
||||
if errors.Is(err, ErrIPTablesNotSupported) {
|
||||
return "", nil
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
@@ -24,23 +24,8 @@ func findIP6tablesSupported(ctx context.Context, runner CmdRunner) (
|
||||
}
|
||||
|
||||
func (c *Config) runIP6tablesInstructions(ctx context.Context, instructions []string) error {
|
||||
c.ip6tablesMutex.Lock() // only one ip6tables command at once
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestoreIPv6(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.runIP6tablesInstructionsNoSave(ctx, instructions)
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Config) runIP6tablesInstructionsNoSave(ctx context.Context, instructions []string) error {
|
||||
for _, instruction := range instructions {
|
||||
if err := c.runIP6tablesInstructionNoSave(ctx, instruction); err != nil {
|
||||
if err := c.runIP6tablesInstruction(ctx, instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -48,24 +33,11 @@ func (c *Config) runIP6tablesInstructionsNoSave(ctx context.Context, instruction
|
||||
}
|
||||
|
||||
func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string) error {
|
||||
c.ip6tablesMutex.Lock() // only one ip6tables command at once
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestoreIPv6(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.runIP6tablesInstructionNoSave(ctx, instruction)
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction string) error {
|
||||
if c.ip6Tables == "" {
|
||||
return nil
|
||||
}
|
||||
c.ip6tablesMutex.Lock() // only one ip6tables command at once
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
if isDeleteMatchInstruction(instruction) {
|
||||
return deleteIPTablesRule(ctx, c.ip6Tables, instruction,
|
||||
@@ -76,9 +48,6 @@ func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction
|
||||
cmd := exec.CommandContext(ctx, c.ip6Tables, flags...) // #nosec G204
|
||||
c.logger.Debug(cmd.String())
|
||||
if output, err := c.runner.Run(cmd); err != nil {
|
||||
if strings.Contains(output, "missing kernel module") {
|
||||
err = ErrKernelModuleMissing
|
||||
}
|
||||
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
|
||||
c.ip6Tables, instruction, output, err)
|
||||
}
|
||||
@@ -87,7 +56,7 @@ func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction
|
||||
|
||||
var ErrPolicyNotValid = errors.New("policy is not valid")
|
||||
|
||||
func (c *Config) SetIPv6AllPolicies(ctx context.Context, policy string) error {
|
||||
func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
default:
|
||||
@@ -0,0 +1,331 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrIPTablesVersionTooShort = errors.New("iptables version string is too short")
|
||||
ErrPolicyUnknown = errors.New("unknown policy")
|
||||
ErrNeedIP6Tables = errors.New("ip6tables is required, please upgrade your kernel to support it")
|
||||
)
|
||||
|
||||
func appendOrDelete(remove bool) string {
|
||||
if remove {
|
||||
return "--delete"
|
||||
}
|
||||
return "--append"
|
||||
}
|
||||
|
||||
// flipRule changes an append rule in a delete rule or a delete rule into an
|
||||
// append rule.
|
||||
func flipRule(rule string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(rule, "-A"):
|
||||
return strings.Replace(rule, "-A", "-D", 1)
|
||||
case strings.HasPrefix(rule, "--append"):
|
||||
return strings.Replace(rule, "--append", "-D", 1)
|
||||
case strings.HasPrefix(rule, "-D"):
|
||||
return strings.Replace(rule, "-D", "-A", 1)
|
||||
case strings.HasPrefix(rule, "--delete"):
|
||||
return strings.Replace(rule, "--delete", "-A", 1)
|
||||
}
|
||||
return rule
|
||||
}
|
||||
|
||||
// Version obtains the version of the installed iptables.
|
||||
func (c *Config) Version(ctx context.Context) (string, error) {
|
||||
cmd := exec.CommandContext(ctx, c.ipTables, "--version") //nolint:gosec
|
||||
output, err := c.runner.Run(cmd)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
words := strings.Fields(output)
|
||||
const minWords = 2
|
||||
if len(words) < minWords {
|
||||
return "", fmt.Errorf("%w: %s", ErrIPTablesVersionTooShort, output)
|
||||
}
|
||||
return words[1], nil
|
||||
}
|
||||
|
||||
func (c *Config) runIptablesInstructions(ctx context.Context, instructions []string) error {
|
||||
for _, instruction := range instructions {
|
||||
if err := c.runIptablesInstruction(ctx, instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) runIptablesInstruction(ctx context.Context, instruction string) error {
|
||||
c.iptablesMutex.Lock() // only one iptables command at once
|
||||
defer c.iptablesMutex.Unlock()
|
||||
|
||||
if isDeleteMatchInstruction(instruction) {
|
||||
return deleteIPTablesRule(ctx, c.ipTables, instruction,
|
||||
c.runner, c.logger)
|
||||
}
|
||||
|
||||
flags := strings.Fields(instruction)
|
||||
cmd := exec.CommandContext(ctx, c.ipTables, flags...) // #nosec G204
|
||||
c.logger.Debug(cmd.String())
|
||||
if output, err := c.runner.Run(cmd); err != nil {
|
||||
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
|
||||
c.ipTables, instruction, output, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) clearAllRules(ctx context.Context) error {
|
||||
return c.runMixedIptablesInstructions(ctx, []string{
|
||||
"--flush", // flush all chains
|
||||
"--delete-chain", // delete all chains
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Config) setIPv4AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
|
||||
}
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
"--policy INPUT " + policy,
|
||||
"--policy OUTPUT " + policy,
|
||||
"--policy FORWARD " + policy,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Config) acceptInputThroughInterface(ctx context.Context, intf string, remove bool) error {
|
||||
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
|
||||
"%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf,
|
||||
))
|
||||
}
|
||||
|
||||
func (c *Config) acceptInputToSubnet(ctx context.Context, intf string,
|
||||
destination netip.Prefix, remove bool,
|
||||
) error {
|
||||
interfaceFlag := "-i " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
|
||||
instruction := fmt.Sprintf("%s INPUT %s -d %s -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destination.String())
|
||||
|
||||
if destination.Addr().Is4() {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
}
|
||||
if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept input to subnet %s: %w", destination, ErrNeedIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
func (c *Config) acceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error {
|
||||
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
|
||||
"%s OUTPUT -o %s -j ACCEPT", appendOrDelete(remove), intf,
|
||||
))
|
||||
}
|
||||
|
||||
func (c *Config) acceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error {
|
||||
return c.runMixedIptablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
|
||||
fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Config) acceptOutputTrafficToVPN(ctx context.Context,
|
||||
defaultInterface string, connection models.Connection, remove bool,
|
||||
) error {
|
||||
protocol := connection.Protocol
|
||||
if protocol == "tcp-client" {
|
||||
protocol = "tcp"
|
||||
}
|
||||
instruction := fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), connection.IP, defaultInterface, protocol,
|
||||
protocol, connection.Port)
|
||||
if connection.IP.Is4() {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
} else if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
// Thanks to @npawelek.
|
||||
func (c *Config) acceptOutputFromIPToSubnet(ctx context.Context,
|
||||
intf string, sourceIP netip.Addr, destinationSubnet netip.Prefix, remove bool,
|
||||
) error {
|
||||
doIPv4 := sourceIP.Is4() && destinationSubnet.Addr().Is4()
|
||||
|
||||
interfaceFlag := "-o " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
|
||||
instruction := fmt.Sprintf("%s OUTPUT %s -s %s -d %s -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, sourceIP.String(), destinationSubnet.String())
|
||||
|
||||
if doIPv4 {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
} else if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept output from %s to %s: %w", sourceIP, destinationSubnet, ErrNeedIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
// NDP uses multicast address (theres no broadcast in IPv6 like ARP uses in IPv4).
|
||||
func (c *Config) acceptIpv6MulticastOutput(ctx context.Context,
|
||||
intf string, remove bool,
|
||||
) error {
|
||||
interfaceFlag := "-o " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
instruction := fmt.Sprintf("%s OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag)
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
// Used for port forwarding, with intf set to tun.
|
||||
func (c *Config) acceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error {
|
||||
interfaceFlag := "-i " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
return c.runMixedIptablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port),
|
||||
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port),
|
||||
})
|
||||
}
|
||||
|
||||
// Used for VPN server side port forwarding, with intf set to the VPN tunnel interface.
|
||||
func (c *Config) redirectPort(ctx context.Context, intf string,
|
||||
sourcePort, destinationPort uint16, remove bool,
|
||||
) (err error) {
|
||||
interfaceFlag := "-i " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
|
||||
err = c.runIptablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p udp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("redirecting IPv4 source port %d to destination port %d on interface %s: %w",
|
||||
sourcePort, destinationPort, intf, err)
|
||||
}
|
||||
|
||||
err = c.runIP6tablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p udp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
})
|
||||
if err != nil {
|
||||
errMessage := err.Error()
|
||||
if strings.Contains(errMessage, "can't initialize ip6tables table `nat': Table does not exist") {
|
||||
if !remove {
|
||||
c.logger.Warn("IPv6 port redirection disabled because your kernel does not support IPv6 NAT: " + errMessage)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("redirecting IPv6 source port %d to destination port %d on interface %s: %w",
|
||||
sourcePort, destinationPort, intf, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) runUserPostRules(ctx context.Context, filepath string, remove bool) error {
|
||||
file, err := os.OpenFile(filepath, os.O_RDONLY, 0)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
b, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
lines := strings.Split(string(b), "\n")
|
||||
successfulRules := []string{}
|
||||
defer func() {
|
||||
// transaction-like rollback
|
||||
if err == nil || ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
for _, rule := range successfulRules {
|
||||
_ = c.runIptablesInstruction(ctx, flipRule(rule))
|
||||
}
|
||||
}()
|
||||
for _, line := range lines {
|
||||
var ipv4 bool
|
||||
var rule string
|
||||
switch {
|
||||
case strings.HasPrefix(line, "iptables "):
|
||||
ipv4 = true
|
||||
rule = strings.TrimPrefix(line, "iptables ")
|
||||
case strings.HasPrefix(line, "iptables-nft "):
|
||||
ipv4 = true
|
||||
rule = strings.TrimPrefix(line, "iptables-nft ")
|
||||
case strings.HasPrefix(line, "iptables-legacy "):
|
||||
ipv4 = true
|
||||
rule = strings.TrimPrefix(line, "iptables-legacy ")
|
||||
case strings.HasPrefix(line, "ip6tables "):
|
||||
ipv4 = false
|
||||
rule = strings.TrimPrefix(line, "ip6tables ")
|
||||
case strings.HasPrefix(line, "ip6tables-nft "):
|
||||
ipv4 = false
|
||||
rule = strings.TrimPrefix(line, "ip6tables-nft ")
|
||||
case strings.HasPrefix(line, "ip6tables-legacy "):
|
||||
ipv4 = false
|
||||
rule = strings.TrimPrefix(line, "ip6tables-legacy ")
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
if remove {
|
||||
rule = flipRule(rule)
|
||||
}
|
||||
|
||||
switch {
|
||||
case ipv4:
|
||||
err = c.runIptablesInstruction(ctx, rule)
|
||||
case c.ip6Tables == "":
|
||||
err = fmt.Errorf("running user ip6tables rule: %w", ErrNeedIP6Tables)
|
||||
default: // ipv6
|
||||
err = c.runIP6tablesInstruction(ctx, rule)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
successfulRules = append(successfulRules, rule)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SaveAndRestore saves the current iptables and ip6tables rules and
|
||||
// returns a restore function that can be called to restore the saved rules.
|
||||
func (c *Config) SaveAndRestore(ctx context.Context) (restore func(context.Context), err error) {
|
||||
c.iptablesMutex.Lock()
|
||||
c.ip6tablesMutex.Lock()
|
||||
defer c.iptablesMutex.Unlock()
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
return c.saveAndRestore(ctx)
|
||||
}
|
||||
|
||||
// callers MUST always lock both the [Config] iptablesMutex and the ip6tablesMutex
|
||||
// before calling this function. Note the restore function does not interact with mutexes
|
||||
// so the caller must make sure the mutexes are locked when calling the restore function.
|
||||
func (c *Config) saveAndRestore(ctx context.Context) (restore func(context.Context), err error) {
|
||||
restoreIPv4, err := c.saveAndRestoreIPv4(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
restoreIPv6, err := c.saveAndRestoreIPv6(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restore = func(ctx context.Context) {
|
||||
restoreIPv4(ctx)
|
||||
if restoreIPv6 != nil {
|
||||
restoreIPv6(ctx)
|
||||
}
|
||||
}
|
||||
return restore, nil
|
||||
}
|
||||
|
||||
// Callers of saveAndRestoreIPv4 MUST always lock the [Config] iptablesMutex
|
||||
// before calling this function.
|
||||
func (c *Config) saveAndRestoreIPv4(ctx context.Context) (restore func(context.Context), err error) {
|
||||
cmd := exec.CommandContext(ctx, c.ipTables+"-save") //nolint:gosec
|
||||
data, err := c.runner.Run(cmd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("saving IPv4 iptables: %w", err)
|
||||
}
|
||||
|
||||
restore = func(ctx context.Context) {
|
||||
cmd := exec.CommandContext(ctx, c.ipTables+"-restore") //nolint:gosec
|
||||
cmd.Stdin = strings.NewReader(data)
|
||||
output, err := c.runner.Run(cmd)
|
||||
if err != nil {
|
||||
c.logger.Warn(fmt.Sprintf("restoring IPv4 iptables failed: %v: %s", err, output))
|
||||
}
|
||||
}
|
||||
return restore, nil
|
||||
}
|
||||
|
||||
// Callers of saveAndRestoreIPv6 MUST always lock the [Config] ip6tablesMutex
|
||||
// before calling this function.
|
||||
func (c *Config) saveAndRestoreIPv6(ctx context.Context) (restore func(context.Context), err error) {
|
||||
if c.ip6Tables == "" {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, c.ip6Tables+"-save") //nolint:gosec
|
||||
data, err := c.runner.Run(cmd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("saving IPv6 iptables: %w", err)
|
||||
}
|
||||
|
||||
restore = func(ctx context.Context) {
|
||||
cmd = exec.CommandContext(ctx, c.ip6Tables+"-restore") //nolint:gosec
|
||||
cmd.Stdin = strings.NewReader(data)
|
||||
output, err := c.runner.Run(cmd)
|
||||
if err != nil {
|
||||
c.logger.Warn(fmt.Sprintf("restoring IPv6 iptables failed: %v: %s", err, output))
|
||||
}
|
||||
}
|
||||
return restore, nil
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var ErrKernelModuleMissing = errors.New("kernel module is missing for this operation")
|
||||
|
||||
type Config struct {
|
||||
runner CmdRunner
|
||||
logger Logger
|
||||
iptablesMutex sync.Mutex
|
||||
ip6tablesMutex sync.Mutex
|
||||
|
||||
// Fixed state
|
||||
ipTables string
|
||||
ip6Tables string
|
||||
}
|
||||
|
||||
func New(ctx context.Context, runner CmdRunner, logger Logger) (*Config, error) {
|
||||
iptables, err := checkIptablesSupport(ctx, runner, "iptables", "iptables-nft", "iptables-legacy")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ip6tables, err := findIP6tablesSupported(ctx, runner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Config{
|
||||
runner: runner,
|
||||
logger: logger,
|
||||
ipTables: iptables,
|
||||
ip6Tables: ip6tables,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package iptables
|
||||
|
||||
import "os/exec"
|
||||
|
||||
type CmdRunner interface {
|
||||
Run(cmd *exec.Cmd) (output string, err error)
|
||||
}
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Info(s string)
|
||||
Warn(s string)
|
||||
Error(s string)
|
||||
}
|
||||
@@ -1,485 +0,0 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrIPTablesVersionTooShort = errors.New("iptables version string is too short")
|
||||
ErrPolicyUnknown = errors.New("unknown policy")
|
||||
ErrNeedIP6Tables = errors.New("ip6tables is required, please upgrade your kernel to support it")
|
||||
)
|
||||
|
||||
func appendOrDelete(remove bool) string {
|
||||
if remove {
|
||||
return "--delete"
|
||||
}
|
||||
return "--append"
|
||||
}
|
||||
|
||||
// Version obtains the version of the installed iptables.
|
||||
func (c *Config) Version(ctx context.Context) (string, error) {
|
||||
cmd := exec.CommandContext(ctx, c.ipTables, "--version") //nolint:gosec
|
||||
output, err := c.runner.Run(cmd)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
words := strings.Fields(output)
|
||||
const minWords = 2
|
||||
if len(words) < minWords {
|
||||
return "", fmt.Errorf("%w: %s", ErrIPTablesVersionTooShort, output)
|
||||
}
|
||||
return "iptables " + words[1], nil
|
||||
}
|
||||
|
||||
func (c *Config) runIptablesInstructions(ctx context.Context, instructions []string) error {
|
||||
c.iptablesMutex.Lock()
|
||||
defer c.iptablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestoreIPv4(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.runIptablesInstructionsNoSave(ctx, instructions)
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Config) runIptablesInstructionsNoSave(ctx context.Context, instructions []string) error {
|
||||
for _, instruction := range instructions {
|
||||
if err := c.runIptablesInstructionNoSave(ctx, instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) runIptablesInstruction(ctx context.Context, instruction string) error {
|
||||
c.iptablesMutex.Lock() // only one iptables command at once
|
||||
defer c.iptablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestoreIPv4(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.runIptablesInstructionNoSave(ctx, instruction)
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Config) runIptablesInstructionNoSave(ctx context.Context, instruction string) error {
|
||||
if isDeleteMatchInstruction(instruction) {
|
||||
return deleteIPTablesRule(ctx, c.ipTables, instruction,
|
||||
c.runner, c.logger)
|
||||
}
|
||||
|
||||
flags := strings.Fields(instruction)
|
||||
cmd := exec.CommandContext(ctx, c.ipTables, flags...) // #nosec G204
|
||||
c.logger.Debug(cmd.String())
|
||||
if output, err := c.runner.Run(cmd); err != nil {
|
||||
if strings.Contains(output, "missing kernel module") {
|
||||
err = ErrKernelModuleMissing
|
||||
}
|
||||
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
|
||||
c.ipTables, instruction, output, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
|
||||
}
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
"--policy INPUT " + policy,
|
||||
"--policy OUTPUT " + policy,
|
||||
"--policy FORWARD " + policy,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Config) AcceptInputThroughInterface(ctx context.Context, intf string) error {
|
||||
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
|
||||
"--append INPUT -i %s -j ACCEPT", intf))
|
||||
}
|
||||
|
||||
func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, destination netip.Prefix) error {
|
||||
interfaceFlag := "-i " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
|
||||
instruction := fmt.Sprintf("--append INPUT %s -d %s -j ACCEPT",
|
||||
interfaceFlag, destination.String())
|
||||
|
||||
if destination.Addr().Is4() {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
}
|
||||
if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept input to subnet %s: %w", destination, ErrNeedIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
func (c *Config) AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error {
|
||||
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
|
||||
"%s OUTPUT -o %s -j ACCEPT", appendOrDelete(remove), intf,
|
||||
))
|
||||
}
|
||||
|
||||
func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context) error {
|
||||
return c.runMixedIptablesInstructions(ctx, []string{
|
||||
"--append OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
|
||||
"--append INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
|
||||
})
|
||||
}
|
||||
|
||||
// AcceptOutputPublicOnlyNewTraffic adds rules to mark new output connections, and to accept
|
||||
// established or related packets with this mark only. This effectively forces
|
||||
// previously established or related traffic to be blocked.
|
||||
// If remove is true, the rules are removed instead of appended.
|
||||
// If the relevant kernel modules are not available, it returns an error indicating
|
||||
// which kernel module is missing.
|
||||
func (c *Config) AcceptOutputPublicOnlyNewTraffic(ctx context.Context) error {
|
||||
ipv4Instructions, ipv6Instructions := makeCreatePublicIPChainInstructions()
|
||||
appendToBoth := func(instruction string) {
|
||||
ipv4Instructions = append(ipv4Instructions, instruction)
|
||||
ipv6Instructions = append(ipv6Instructions, instruction)
|
||||
}
|
||||
|
||||
// Mark new connections with mark 0x567
|
||||
appendToBoth("-A PUBLIC_ONLY -m conntrack --ctstate NEW -j CONNMARK --set-mark 0x567")
|
||||
// Drop related/established connections that made it through; marked connections would
|
||||
// be directly accepted by the first rule in the OUTPUT chain (see below)
|
||||
appendToBoth("-A PUBLIC_ONLY -m conntrack --ctstate RELATED,ESTABLISHED -j DROP")
|
||||
// Set the PUBLIC_ONLY chain as the second rule in the OUTPUT chain, so that it is evaluated
|
||||
// after the accept rule below, for performance reasons.
|
||||
appendToBoth("-I OUTPUT -j PUBLIC_ONLY")
|
||||
appendToBoth("-I OUTPUT -m conntrack --ctstate RELATED,ESTABLISHED -m connmark --mark 0x567 -j ACCEPT")
|
||||
|
||||
c.iptablesMutex.Lock()
|
||||
c.ip6tablesMutex.Lock()
|
||||
defer c.iptablesMutex.Unlock()
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestore(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.runIptablesInstructionsNoSave(ctx, ipv4Instructions)
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
return err
|
||||
}
|
||||
err = c.runIP6tablesInstructionsNoSave(ctx, ipv6Instructions)
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) RejectOutputPublicTraffic(ctx context.Context, remove bool) error {
|
||||
return c.targetOutputPublicTraffic(ctx, "REJECT", remove)
|
||||
}
|
||||
|
||||
func (c *Config) DropOutputPublicTraffic(ctx context.Context, remove bool) error {
|
||||
return c.targetOutputPublicTraffic(ctx, "DROP", remove)
|
||||
}
|
||||
|
||||
func (c *Config) targetOutputPublicTraffic(ctx context.Context, target string, remove bool) error {
|
||||
removeInstructions := []string{
|
||||
"-D OUTPUT -j PUBLIC_ONLY",
|
||||
"-F PUBLIC_ONLY",
|
||||
"-X PUBLIC_ONLY",
|
||||
}
|
||||
if remove {
|
||||
return c.runMixedIptablesInstructions(ctx, removeInstructions)
|
||||
}
|
||||
|
||||
ipv4Instructions, ipv6Instructions := makeCreatePublicIPChainInstructions()
|
||||
appendToBoth := func(instruction string) {
|
||||
ipv4Instructions = append(ipv4Instructions, instruction)
|
||||
ipv6Instructions = append(ipv6Instructions, instruction)
|
||||
}
|
||||
|
||||
if target == "REJECT" {
|
||||
// Block TCP by sending back TCP RST packets.
|
||||
appendToBoth("-A PUBLIC_ONLY -p tcp -m conntrack --ctstate RELATED,ESTABLISHED " +
|
||||
"-j REJECT --reject-with tcp-reset")
|
||||
// Block UDP and ICMP, sending back ICMP port unreachable.
|
||||
appendToBoth("-A PUBLIC_ONLY -m conntrack --ctstate RELATED,ESTABLISHED -j REJECT")
|
||||
} else {
|
||||
appendToBoth("-A PUBLIC_ONLY -m conntrack --ctstate RELATED,ESTABLISHED -j " + target)
|
||||
}
|
||||
appendToBoth("-I OUTPUT -j PUBLIC_ONLY")
|
||||
|
||||
err := c.runIptablesInstructions(ctx, ipv4Instructions)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), " support") {
|
||||
return fmt.Errorf("%w: %w", ErrKernelModuleMissing, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = c.runIP6tablesInstructions(ctx, ipv6Instructions)
|
||||
if err != nil {
|
||||
_ = c.runIptablesInstructions(ctx, removeInstructions)
|
||||
if strings.Contains(err.Error(), " support") {
|
||||
return fmt.Errorf("%w: %w", ErrKernelModuleMissing, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeCreatePublicIPChainInstructions() (ipv4Instructions, ipv6Instructions []string) {
|
||||
ipv4PrivatePrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("172.16.0.0/12"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
netip.MustParsePrefix("127.0.0.0/8"),
|
||||
}
|
||||
ipv6PrivatePrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("fc00::/7"),
|
||||
netip.MustParsePrefix("fe80::/10"),
|
||||
netip.MustParsePrefix("::1/128"),
|
||||
}
|
||||
|
||||
ipv4Instructions = append(ipv4Instructions, "-N PUBLIC_ONLY")
|
||||
ipv6Instructions = append(ipv6Instructions, "-N PUBLIC_ONLY")
|
||||
|
||||
for _, prefix := range ipv4PrivatePrefixes {
|
||||
ipv4Instructions = append(ipv4Instructions, fmt.Sprintf(
|
||||
"-A PUBLIC_ONLY -d %s -j RETURN", prefix))
|
||||
}
|
||||
|
||||
for _, prefix := range ipv6PrivatePrefixes {
|
||||
ipv6Instructions = append(ipv6Instructions, fmt.Sprintf(
|
||||
"-A PUBLIC_ONLY -d %s -j RETURN", prefix))
|
||||
}
|
||||
|
||||
return ipv4Instructions, ipv6Instructions
|
||||
}
|
||||
|
||||
func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context,
|
||||
defaultInterface string, connection models.Connection, remove bool,
|
||||
) error {
|
||||
protocol := connection.Protocol
|
||||
if protocol == "tcp-client" {
|
||||
protocol = "tcp"
|
||||
}
|
||||
instruction := fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), connection.IP, defaultInterface, protocol,
|
||||
protocol, connection.Port)
|
||||
if connection.IP.Is4() {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
} else if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
// AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet
|
||||
// on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
|
||||
// If remove is true, the rule is removed instead of added.
|
||||
// Thanks to @npawelek.
|
||||
func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
|
||||
intf string, sourceIP netip.Addr, destinationSubnet netip.Prefix, remove bool,
|
||||
) error {
|
||||
doIPv4 := sourceIP.Is4() && destinationSubnet.Addr().Is4()
|
||||
|
||||
interfaceFlag := "-o " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
|
||||
instruction := fmt.Sprintf("%s OUTPUT %s -s %s -d %s -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, sourceIP.String(), destinationSubnet.String())
|
||||
|
||||
if doIPv4 {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
} else if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept output from %s to %s: %w", sourceIP, destinationSubnet, ErrNeedIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
// AcceptIpv6MulticastOutput accepts outgoing traffic to the IPv6 multicast address
|
||||
// ff02::1:ff00:0/104, which is used for NDP (Neighbor Discovery Protocol) to resolve
|
||||
// IPv6 addresses to MAC addresses. If intf is empty, it is set to "*" which means
|
||||
// all interfaces. If remove is true, the rule is removed instead of added.
|
||||
func (c *Config) AcceptIpv6MulticastOutput(ctx context.Context, intf string) error {
|
||||
interfaceFlag := "-o " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
instruction := fmt.Sprintf("--append OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT", interfaceFlag)
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
// AcceptInputToPort accepts incoming traffic on the specified port, for both TCP and UDP
|
||||
// protocols, on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
|
||||
// If remove is true, the rule is removed instead of added. This is used for port forwarding, with
|
||||
// intf set to the VPN tunnel interface.
|
||||
func (c *Config) AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error {
|
||||
interfaceFlag := "-i " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
return c.runMixedIptablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port),
|
||||
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port),
|
||||
})
|
||||
}
|
||||
|
||||
// RedirectPort redirects incoming traffic on the specified source port to the
|
||||
// specified destination port, for both TCP and UDP protocols, on the interface intf.
|
||||
// If intf is empty, it is set to "*" which means all interfaces. If remove is true,
|
||||
// the redirection is removed instead of added. This is used for VPN server side
|
||||
// port forwarding, with intf set to the VPN tunnel interface.
|
||||
func (c *Config) RedirectPort(ctx context.Context, intf string,
|
||||
sourcePort, destinationPort uint16, remove bool,
|
||||
) (err error) {
|
||||
interfaceFlag := "-i " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
|
||||
c.iptablesMutex.Lock()
|
||||
c.ip6tablesMutex.Lock()
|
||||
defer c.iptablesMutex.Unlock()
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestore(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.runIptablesInstructionsNoSave(ctx, []string{
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p udp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
})
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
return fmt.Errorf("redirecting IPv4 source port %d to destination port %d on interface %s: %w",
|
||||
sourcePort, destinationPort, intf, err)
|
||||
}
|
||||
|
||||
err = c.runIP6tablesInstructionsNoSave(ctx, []string{
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p udp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
})
|
||||
if err != nil {
|
||||
restore(ctx) // just in case
|
||||
errMessage := err.Error()
|
||||
if strings.Contains(errMessage, "can't initialize ip6tables table `nat': Table does not exist") {
|
||||
if !remove {
|
||||
c.logger.Warn("IPv6 port redirection disabled because your kernel does not support IPv6 NAT: " + errMessage)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("redirecting IPv6 source port %d to destination port %d on interface %s: %w",
|
||||
sourcePort, destinationPort, intf, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
|
||||
file, err := os.OpenFile(filepath, os.O_RDONLY, 0)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
b, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
lines := strings.Split(string(b), "\n")
|
||||
|
||||
c.iptablesMutex.Lock()
|
||||
c.ip6tablesMutex.Lock()
|
||||
defer c.iptablesMutex.Unlock()
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestore(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
var ipv4 bool
|
||||
var rule string
|
||||
switch {
|
||||
case strings.HasPrefix(line, "iptables "):
|
||||
ipv4 = true
|
||||
rule = strings.TrimPrefix(line, "iptables ")
|
||||
case strings.HasPrefix(line, "iptables-nft "):
|
||||
ipv4 = true
|
||||
rule = strings.TrimPrefix(line, "iptables-nft ")
|
||||
case strings.HasPrefix(line, "iptables-legacy "):
|
||||
ipv4 = true
|
||||
rule = strings.TrimPrefix(line, "iptables-legacy ")
|
||||
case strings.HasPrefix(line, "ip6tables "):
|
||||
ipv4 = false
|
||||
rule = strings.TrimPrefix(line, "ip6tables ")
|
||||
case strings.HasPrefix(line, "ip6tables-nft "):
|
||||
ipv4 = false
|
||||
rule = strings.TrimPrefix(line, "ip6tables-nft ")
|
||||
case strings.HasPrefix(line, "ip6tables-legacy "):
|
||||
ipv4 = false
|
||||
rule = strings.TrimPrefix(line, "ip6tables-legacy ")
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
switch {
|
||||
case ipv4:
|
||||
err = c.runIptablesInstruction(ctx, rule)
|
||||
case c.ip6Tables == "":
|
||||
err = fmt.Errorf("running user ip6tables rule: %w", ErrNeedIP6Tables)
|
||||
default: // ipv6
|
||||
err = c.runIP6tablesInstruction(ctx, rule)
|
||||
}
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
|
||||
c.iptablesMutex.Lock()
|
||||
c.ip6tablesMutex.Lock()
|
||||
defer c.iptablesMutex.Unlock()
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestore(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, instruction := range instructions {
|
||||
if err := c.runMixedIptablesInstructionNoSave(ctx, instruction); err != nil {
|
||||
restore(ctx)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction string) error {
|
||||
c.iptablesMutex.Lock()
|
||||
c.ip6tablesMutex.Lock()
|
||||
defer c.iptablesMutex.Unlock()
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestore(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.runIptablesInstructionNoSave(ctx, instruction)
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Config) runMixedIptablesInstructionNoSave(ctx context.Context, instruction string) error {
|
||||
if err := c.runIptablesInstructionNoSave(ctx, instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.runIP6tablesInstructionNoSave(ctx, instruction)
|
||||
}
|
||||
@@ -1,339 +0,0 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type operation uint8
|
||||
|
||||
const (
|
||||
opNone operation = iota
|
||||
opAppend
|
||||
opDelete
|
||||
opInsert
|
||||
opReplace
|
||||
)
|
||||
|
||||
type iptablesInstruction struct {
|
||||
table string // defaults to "filter", and can be "nat" for example.
|
||||
operation operation
|
||||
chain string // for example INPUT, PREROUTING. Cannot be empty.
|
||||
target string // for example ACCEPT. Can be empty.
|
||||
protocol string // "tcp" or "udp" or "" for all protocols.
|
||||
inputInterface string // for example "tun0" or "" for any interface.
|
||||
outputInterface string // for example "tun0" or "" for any interface.
|
||||
source netip.Prefix // if not valid, then it is unspecified.
|
||||
sourcePort uint16 // if zero, there is no source port
|
||||
destination netip.Prefix // if not valid, then it is unspecified.
|
||||
destinationPort uint16 // if zero, there is no destination port
|
||||
toPorts []uint16 // if empty, there is no redirection
|
||||
ctstate []string // if empty, there is no ctstate
|
||||
tcpFlags tcpFlags
|
||||
mark mark
|
||||
connMark mark
|
||||
setMark uint // only used for jump CONNMARK --set-mark
|
||||
rejectWith string // only used for REJECT targets
|
||||
}
|
||||
|
||||
func (i *iptablesInstruction) setDefaults() {
|
||||
if i.table == "" {
|
||||
i.table = "filter"
|
||||
}
|
||||
}
|
||||
|
||||
// equalToRule ignores the append boolean flag of the instruction to compare against the rule.
|
||||
func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (equal bool) {
|
||||
switch {
|
||||
case i.table != table:
|
||||
return false
|
||||
case i.chain != chain:
|
||||
return false
|
||||
case i.target != rule.target:
|
||||
return false
|
||||
case i.protocol != rule.protocol:
|
||||
return false
|
||||
case i.destinationPort != rule.destinationPort:
|
||||
return false
|
||||
case i.sourcePort != rule.sourcePort:
|
||||
return false
|
||||
case !slices.Equal(i.toPorts, rule.redirPorts):
|
||||
return false
|
||||
case !slices.Equal(i.ctstate, rule.ctstate):
|
||||
return false
|
||||
case !networkInterfacesEqual(i.inputInterface, rule.inputInterface):
|
||||
return false
|
||||
case !networkInterfacesEqual(i.outputInterface, rule.outputInterface):
|
||||
return false
|
||||
case !ipPrefixesEqual(i.source, rule.source):
|
||||
return false
|
||||
case !ipPrefixesEqual(i.destination, rule.destination):
|
||||
return false
|
||||
case !slices.Equal(i.tcpFlags.mask, rule.tcpFlags.mask) ||
|
||||
!slices.Equal(i.tcpFlags.comparison, rule.tcpFlags.comparison):
|
||||
return false
|
||||
case i.mark != rule.mark:
|
||||
return false
|
||||
case i.connMark != rule.connMark:
|
||||
return false
|
||||
case i.setMark != rule.setMark:
|
||||
return false
|
||||
case i.rejectWith != rule.rejectWith:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// instruction can be "" which equivalent to the "*" chain rule interface.
|
||||
func networkInterfacesEqual(instruction, chainRule string) bool {
|
||||
return instruction == chainRule || (instruction == "" && chainRule == "*")
|
||||
}
|
||||
|
||||
func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
|
||||
return instruction == chainRule ||
|
||||
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
|
||||
}
|
||||
|
||||
var ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
|
||||
|
||||
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
|
||||
if s == "" {
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
||||
}
|
||||
fields := strings.Fields(s)
|
||||
|
||||
i := 0
|
||||
for i < len(fields) {
|
||||
consumed, err := parseInstructionFlag(fields[i:], &instruction)
|
||||
if err != nil {
|
||||
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
|
||||
}
|
||||
i += consumed
|
||||
}
|
||||
|
||||
instruction.setDefaults()
|
||||
return instruction, nil
|
||||
}
|
||||
|
||||
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
|
||||
consumed, err = preCheckInstructionFields(fields)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
flag := fields[0]
|
||||
value := fields[1]
|
||||
|
||||
switch flag {
|
||||
case "-t", "--table":
|
||||
instruction.table = value
|
||||
case "-D", "--delete":
|
||||
instruction.operation = opDelete
|
||||
instruction.chain = value
|
||||
case "-A", "--append":
|
||||
instruction.operation = opAppend
|
||||
instruction.chain = value
|
||||
case "-I", "--insert":
|
||||
instruction.operation = opInsert
|
||||
instruction.chain = value
|
||||
case "-j", "--jump":
|
||||
subConsumed, err := parseJumpFlag(fields[1:], instruction)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing jump flag: %w", err)
|
||||
}
|
||||
consumed += subConsumed
|
||||
case "-p", "--protocol":
|
||||
instruction.protocol = value
|
||||
case "-m", "--match":
|
||||
consumed, err = parseMatchModule(fields, instruction)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing match module: %w", err)
|
||||
}
|
||||
case "--mark":
|
||||
n, err := parseAny32bNumber(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing mark value %q: %w", value, err)
|
||||
}
|
||||
instruction.mark.value = n
|
||||
case "-i", "--in-interface":
|
||||
instruction.inputInterface = value
|
||||
case "-o", "--out-interface":
|
||||
instruction.outputInterface = value
|
||||
case "-s", "--source":
|
||||
instruction.source, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
}
|
||||
case "--sport":
|
||||
instruction.sourcePort, err = parsePort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing source port: %w", err)
|
||||
}
|
||||
case "-d", "--destination":
|
||||
instruction.destination, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
}
|
||||
case "--dport":
|
||||
instruction.destinationPort, err = parsePort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
case "--ctstate":
|
||||
instruction.ctstate = strings.Split(value, ",")
|
||||
case "--to-ports":
|
||||
instruction.toPorts, err = parseToPorts(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing port redirection: %w", err)
|
||||
}
|
||||
case "--tcp-flags":
|
||||
mask, comparison := value, fields[2]
|
||||
instruction.tcpFlags, err = parseTCPFlags(mask + "/" + comparison)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing TCP flags: %w", err)
|
||||
}
|
||||
case "--reject-with":
|
||||
instruction.rejectWith = value // for example "tcp-reset"
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag)
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
func preCheckInstructionFields(fields []string) (consumed int, err error) {
|
||||
flag := fields[0]
|
||||
// All flags use one value after the flag, except the following:
|
||||
switch flag {
|
||||
case "--tcp-flags":
|
||||
const expected = 3
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
||||
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
|
||||
}
|
||||
return expected, nil
|
||||
default:
|
||||
const expected = 2
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
|
||||
ErrIptablesCommandMalformed, flag)
|
||||
}
|
||||
return expected, nil
|
||||
}
|
||||
}
|
||||
|
||||
func parseJumpFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
|
||||
instruction.target = fields[0]
|
||||
// consumed in the caller already takes fields[0] into account
|
||||
if instruction.target != "CONNMARK" {
|
||||
return consumed, nil
|
||||
}
|
||||
// consumed already accounts for the "CONNMARK" value
|
||||
const expectedFields = 3
|
||||
if len(fields) < expectedFields {
|
||||
return 0, fmt.Errorf("%w: jump CONNMARK requires at least two additional values",
|
||||
ErrIptablesCommandMalformed)
|
||||
}
|
||||
switch fields[1] {
|
||||
case "--set-mark":
|
||||
n, err := parseAny32bNumber(fields[2])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing connmark mark value %q: %w", fields[2], err)
|
||||
}
|
||||
consumed++
|
||||
instruction.setMark = n
|
||||
default:
|
||||
return consumed, fmt.Errorf("%w: unsupported jump CONNMARK with value: %s",
|
||||
ErrIptablesCommandMalformed, fields[1])
|
||||
}
|
||||
consumed++
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
||||
slashIndex := strings.Index(value, "/")
|
||||
if slashIndex >= 0 {
|
||||
return netip.ParsePrefix(value)
|
||||
}
|
||||
|
||||
ip, err := netip.ParseAddr(value)
|
||||
if err != nil {
|
||||
return netip.Prefix{}, fmt.Errorf("parsing IP address: %w", err)
|
||||
}
|
||||
return netip.PrefixFrom(ip, ip.BitLen()), nil
|
||||
}
|
||||
|
||||
func parsePort(value string) (port uint16, err error) {
|
||||
const base, bitLength = 10, 16
|
||||
portValue, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint16(portValue), nil
|
||||
}
|
||||
|
||||
func parseAny32bNumber(mark string) (value uint, err error) {
|
||||
const base = 0 // auto-detect
|
||||
const bits = 32
|
||||
n, err := strconv.ParseUint(mark, base, bits)
|
||||
return uint(n), err
|
||||
}
|
||||
|
||||
func parseMatchModule(fields []string, instruction *iptablesInstruction) (
|
||||
consumed int, err error,
|
||||
) {
|
||||
_ = fields[consumed] // -m or --match flag already detected
|
||||
consumed++
|
||||
switch fields[consumed] {
|
||||
case "tcp", "udp":
|
||||
consumed++
|
||||
// for now ignore the protocol match since it's auto-loaded
|
||||
// when parsing the -p/--protocol flag, and we don't need to
|
||||
// parse it twice.
|
||||
case "mark":
|
||||
consumed++
|
||||
switch {
|
||||
case len(fields[consumed:]) == 0 || strings.HasPrefix(fields[consumed], "-"):
|
||||
// end or another flag
|
||||
return consumed, nil
|
||||
case fields[consumed] == "!":
|
||||
consumed++
|
||||
instruction.mark.invert = true
|
||||
default:
|
||||
return consumed, fmt.Errorf("%w: unsupported match mark with value: %s",
|
||||
ErrIptablesCommandMalformed, fields[2])
|
||||
}
|
||||
case "connmark":
|
||||
consumed++
|
||||
switch {
|
||||
case len(fields[consumed:]) == 0 || strings.HasPrefix(fields[consumed], "-"):
|
||||
// end or another flag
|
||||
return consumed, nil
|
||||
case fields[consumed] == "!":
|
||||
consumed++
|
||||
instruction.connMark.invert = true
|
||||
default:
|
||||
return consumed, fmt.Errorf("%w: unsupported match connmark with value: %s",
|
||||
ErrIptablesCommandMalformed, fields[2])
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: unknown match value: %s",
|
||||
ErrIptablesCommandMalformed, fields[consumed])
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
func parseToPorts(value string) (toPorts []uint16, err error) {
|
||||
portStrings := strings.Split(value, ",")
|
||||
toPorts = make([]uint16, len(portStrings))
|
||||
for i, portString := range portStrings {
|
||||
toPorts[i], err = parsePort(portString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return toPorts, nil
|
||||
}
|
||||
@@ -1,98 +0,0 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
)
|
||||
|
||||
type tcpFlags struct {
|
||||
mask []tcpFlag
|
||||
comparison []tcpFlag
|
||||
}
|
||||
|
||||
type tcpFlag uint8
|
||||
|
||||
const (
|
||||
tcpFlagFIN tcpFlag = 1 << iota
|
||||
tcpFlagSYN
|
||||
tcpFlagRST
|
||||
tcpFlagPSH
|
||||
tcpFlagACK
|
||||
tcpFlagURG
|
||||
tcpFlagECE
|
||||
tcpFlagCWR
|
||||
)
|
||||
|
||||
func (f tcpFlag) String() string {
|
||||
switch f {
|
||||
case tcpFlagFIN:
|
||||
return "FIN"
|
||||
case tcpFlagSYN:
|
||||
return "SYN"
|
||||
case tcpFlagRST:
|
||||
return "RST"
|
||||
case tcpFlagPSH:
|
||||
return "PSH"
|
||||
case tcpFlagACK:
|
||||
return "ACK"
|
||||
case tcpFlagURG:
|
||||
return "URG"
|
||||
case tcpFlagECE:
|
||||
return "ECE"
|
||||
case tcpFlagCWR:
|
||||
return "CWR"
|
||||
default:
|
||||
panic(fmt.Sprintf("%s: %d", errTCPFlagUnknown, f))
|
||||
}
|
||||
}
|
||||
|
||||
var errTCPFlagUnknown = errors.New("unknown TCP flag")
|
||||
|
||||
func parseTCPFlag(s string) (tcpFlag, error) {
|
||||
allFlags := []tcpFlag{
|
||||
tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH,
|
||||
tcpFlagACK, tcpFlagURG, tcpFlagECE, tcpFlagCWR,
|
||||
}
|
||||
for _, flag := range allFlags {
|
||||
if s == fmt.Sprintf("%#02x", uint8(flag)) || s == flag.String() {
|
||||
return flag, nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
|
||||
}
|
||||
|
||||
var ErrMarkMatchModuleMissing = errors.New("libxt_mark.so module is missing")
|
||||
|
||||
// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port,
|
||||
// for any TCP packets not marked with the excludeMark given.
|
||||
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
|
||||
// by sending a TCP RST packet, although we want to handle the connection manually.
|
||||
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
|
||||
src, dst netip.AddrPort, excludeMark int) (
|
||||
revert func(ctx context.Context) error, err error,
|
||||
) {
|
||||
_, err = os.Stat("/usr/lib/xtables/libxt_mark.so")
|
||||
if err != nil && errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("%w", ErrMarkMatchModuleMissing)
|
||||
}
|
||||
|
||||
const template = "%s OUTPUT -p tcp -s %s --sport %d -d %s --dport %d " +
|
||||
"--tcp-flags RST RST -m mark ! --mark %d -j DROP" //nolint:dupword
|
||||
instruction := fmt.Sprintf(template, "--append", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark)
|
||||
revertInstruction := fmt.Sprintf(template, "--delete", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark)
|
||||
run := c.runIptablesInstruction
|
||||
if dst.Addr().Is6() {
|
||||
run = c.runIP6tablesInstruction
|
||||
}
|
||||
revert = func(ctx context.Context) error {
|
||||
return run(ctx, revertInstruction)
|
||||
}
|
||||
err = run(ctx, instruction)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("running instruction: %w", err)
|
||||
}
|
||||
return revert, nil
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
|
||||
for _, instruction := range instructions {
|
||||
if err := c.runMixedIptablesInstruction(ctx, instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction string) error {
|
||||
if err := c.runIptablesInstruction(ctx, instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -26,21 +26,10 @@ type chainRule struct {
|
||||
inputInterface string // input interface, for example "tun0" or "*""
|
||||
outputInterface string // output interface, for example "eth0" or "*""
|
||||
source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||
sourcePort uint16 // Not specified if set to zero.
|
||||
destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||
destinationPort uint16 // Not specified if set to zero.
|
||||
redirPorts []uint16 // Not specified if empty.
|
||||
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
|
||||
tcpFlags tcpFlags
|
||||
mark mark
|
||||
connMark mark
|
||||
setMark uint
|
||||
rejectWith string // for example "tcp-reset", only used for REJECT targets
|
||||
}
|
||||
|
||||
type mark struct {
|
||||
invert bool
|
||||
value uint
|
||||
}
|
||||
|
||||
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
|
||||
@@ -222,6 +211,10 @@ func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err err
|
||||
return fmt.Errorf("parsing bytes: %w", err)
|
||||
}
|
||||
case targetIndex:
|
||||
err = checkTarget(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking target: %w", err)
|
||||
}
|
||||
rule.target = field
|
||||
case protocolIndex:
|
||||
rule.protocol, err = parseProtocol(field)
|
||||
@@ -248,23 +241,19 @@ func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err err
|
||||
}
|
||||
|
||||
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
|
||||
i := 0
|
||||
for i < len(optionalFields) {
|
||||
switch optionalFields[i] {
|
||||
case "udp":
|
||||
for i := 0; i < len(optionalFields); i++ {
|
||||
key := optionalFields[i]
|
||||
switch key {
|
||||
case "tcp", "udp":
|
||||
i++
|
||||
consumed, err := parseUDPOptional(optionalFields[i:], rule)
|
||||
value := optionalFields[i]
|
||||
value = strings.TrimPrefix(value, "dpt:")
|
||||
const base, bitLength = 10, 16
|
||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing UDP optional fields: %w", err)
|
||||
return fmt.Errorf("parsing destination port %q: %w", value, err)
|
||||
}
|
||||
i += consumed
|
||||
case "tcp":
|
||||
i++
|
||||
consumed, err := parseTCPOptional(optionalFields[i:], rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing TCP optional fields: %w", err)
|
||||
}
|
||||
i += consumed
|
||||
rule.destinationPort = uint16(destinationPort)
|
||||
case "redir":
|
||||
i++
|
||||
switch optionalFields[i] {
|
||||
@@ -275,163 +264,20 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
|
||||
return fmt.Errorf("parsing redirection ports: %w", err)
|
||||
}
|
||||
rule.redirPorts = ports
|
||||
i++
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected %q after redir",
|
||||
ErrChainRuleMalformed, optionalFields[1])
|
||||
return fmt.Errorf("%w: unexpected optional field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[i])
|
||||
}
|
||||
case "ctstate":
|
||||
i++
|
||||
rule.ctstate = strings.Split(optionalFields[i], ",")
|
||||
i++
|
||||
case "mark":
|
||||
i++
|
||||
mark, consumed, err := parseMark(optionalFields[i:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing mark: %w", err)
|
||||
}
|
||||
rule.mark = mark
|
||||
i += consumed
|
||||
case "reject-with":
|
||||
i++
|
||||
rule.rejectWith = optionalFields[i] // for example "tcp-reset"
|
||||
i++
|
||||
case "connmark":
|
||||
i++
|
||||
connMark, consumed, err := parseMark(optionalFields[i:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing connmark: %w", err)
|
||||
}
|
||||
rule.connMark = connMark
|
||||
i += consumed
|
||||
case "CONNMARK":
|
||||
i++
|
||||
switch optionalFields[i] {
|
||||
case "set":
|
||||
i++
|
||||
value, err := parseAny32bNumber(optionalFields[i])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing CONNMARK set value: %w", err)
|
||||
}
|
||||
rule.setMark = value
|
||||
i++
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected %q after CONNMARK",
|
||||
ErrChainRuleMalformed, optionalFields[i])
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected optional field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[i])
|
||||
return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var errUDPOptionalUnknown = errors.New("unknown UDP optional field")
|
||||
|
||||
func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
||||
for _, value := range optionalFields {
|
||||
if !strings.ContainsRune(value, ':') {
|
||||
// no longer a UDP-associated option
|
||||
return consumed, nil
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(value, "dpt:"):
|
||||
rule.destinationPort, err = parseDestinationPort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
consumed++
|
||||
case strings.HasPrefix(value, "spt:"):
|
||||
rule.sourcePort, err = parseSourcePort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing source port: %w", err)
|
||||
}
|
||||
consumed++
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
|
||||
}
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
var errTCPOptionalUnknown = errors.New("unknown TCP optional field")
|
||||
|
||||
func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
||||
for _, value := range optionalFields {
|
||||
if !strings.ContainsRune(value, ':') {
|
||||
// no longer a TCP-associated option
|
||||
return consumed, nil
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(value, "dpt:"):
|
||||
rule.destinationPort, err = parseDestinationPort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
consumed++
|
||||
case strings.HasPrefix(value, "spt:"):
|
||||
rule.sourcePort, err = parseSourcePort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing source port: %w", err)
|
||||
}
|
||||
consumed++
|
||||
case strings.HasPrefix(value, "flags:"):
|
||||
rule.tcpFlags, err = parseTCPFlags(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing TCP flags: %w", err)
|
||||
}
|
||||
consumed++
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %s", errTCPOptionalUnknown, value)
|
||||
}
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
func parseDestinationPort(value string) (port uint16, err error) {
|
||||
value = strings.TrimPrefix(value, "dpt:")
|
||||
return parsePort(value)
|
||||
}
|
||||
|
||||
func parseSourcePort(value string) (port uint16, err error) {
|
||||
value = strings.TrimPrefix(value, "spt:")
|
||||
return parsePort(value)
|
||||
}
|
||||
|
||||
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
|
||||
|
||||
func parseTCPFlags(value string) (tcpFlags, error) {
|
||||
value = strings.TrimPrefix(value, "flags:")
|
||||
fields := strings.Split(value, "/")
|
||||
const expectedFields = 2
|
||||
if len(fields) != expectedFields {
|
||||
return tcpFlags{}, fmt.Errorf("%w: expected format 'flags:<mask>/<comparison>' in %q",
|
||||
errTCPFlagsMalformed, value)
|
||||
}
|
||||
maskFlags := strings.Split(fields[0], ",")
|
||||
mask := make([]tcpFlag, len(maskFlags))
|
||||
var err error
|
||||
for i, maskFlag := range maskFlags {
|
||||
mask[i], err = parseTCPFlag(maskFlag)
|
||||
if err != nil {
|
||||
return tcpFlags{}, fmt.Errorf("parsing TCP mask flags: %w", err)
|
||||
}
|
||||
}
|
||||
comparisonFlags := strings.Split(fields[1], ",")
|
||||
comparison := make([]tcpFlag, len(comparisonFlags))
|
||||
for i, comparisonFlag := range comparisonFlags {
|
||||
comparison[i], err = parseTCPFlag(comparisonFlag)
|
||||
if err != nil {
|
||||
return tcpFlags{}, fmt.Errorf("parsing TCP comparison flags: %w", err)
|
||||
}
|
||||
}
|
||||
return tcpFlags{
|
||||
mask: mask,
|
||||
comparison: comparison,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parsePortsCSV(s string) (ports []uint16, err error) {
|
||||
if s == "" {
|
||||
return nil, nil
|
||||
@@ -440,36 +286,16 @@ func parsePortsCSV(s string) (ports []uint16, err error) {
|
||||
fields := strings.Split(s, ",")
|
||||
ports = make([]uint16, len(fields))
|
||||
for i, field := range fields {
|
||||
ports[i], err = parsePort(field)
|
||||
const base, bitLength = 10, 16
|
||||
port, err := strconv.ParseUint(field, base, bitLength)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("parsing port %q: %w", field, err)
|
||||
}
|
||||
ports[i] = uint16(port)
|
||||
}
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
func parseMark(optionalFields []string) (m mark, consumed int, err error) {
|
||||
switch optionalFields[consumed] {
|
||||
case "match":
|
||||
consumed++
|
||||
if optionalFields[consumed] == "!" {
|
||||
m.invert = true
|
||||
consumed++
|
||||
}
|
||||
|
||||
value, err := parseAny32bNumber(optionalFields[consumed])
|
||||
if err != nil {
|
||||
return mark{}, 0, fmt.Errorf("value malformed: %w", err)
|
||||
}
|
||||
m.value = value
|
||||
consumed++
|
||||
default:
|
||||
return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[consumed])
|
||||
}
|
||||
return m, consumed, nil
|
||||
}
|
||||
|
||||
var ErrLineNumberIsZero = errors.New("line number is zero")
|
||||
|
||||
func parseLineNumber(s string) (n uint16, err error) {
|
||||
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
+1
-1
@@ -1,3 +1,3 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package $GOPACKAGE . CmdRunner,Logger
|
||||
@@ -1,8 +1,8 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/firewall/iptables (interfaces: CmdRunner,Logger)
|
||||
// Source: github.com/qdm12/gluetun/internal/firewall (interfaces: CmdRunner,Logger)
|
||||
|
||||
// Package iptables is a generated GoMock package.
|
||||
package iptables
|
||||
// Package firewall is a generated GoMock package.
|
||||
package firewall
|
||||
|
||||
import (
|
||||
exec "os/exec"
|
||||
@@ -48,7 +48,7 @@ func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []netip.Pref
|
||||
}
|
||||
|
||||
firewallUpdated = true
|
||||
err := c.impl.AcceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
|
||||
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
|
||||
defaultRoute.AssignedIP, subNet, remove)
|
||||
if err != nil {
|
||||
c.logger.Error("cannot remove outdated outbound subnet: " + err.Error())
|
||||
@@ -77,7 +77,7 @@ func (c *Config) addOutboundSubnets(ctx context.Context, subnets []netip.Prefix)
|
||||
}
|
||||
|
||||
firewallUpdated = true
|
||||
err := c.impl.AcceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
|
||||
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
|
||||
defaultRoute.AssignedIP, subnet, remove)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type iptablesInstruction struct {
|
||||
table string // defaults to "filter", and can be "nat" for example.
|
||||
append bool
|
||||
chain string // for example INPUT, PREROUTING. Cannot be empty.
|
||||
target string // for example ACCEPT. Can be empty.
|
||||
protocol string // "tcp" or "udp" or "" for all protocols.
|
||||
inputInterface string // for example "tun0" or "" for any interface.
|
||||
outputInterface string // for example "tun0" or "" for any interface.
|
||||
source netip.Prefix // if not valid, then it is unspecified.
|
||||
destination netip.Prefix // if not valid, then it is unspecified.
|
||||
destinationPort uint16 // if zero, there is no destination port
|
||||
toPorts []uint16 // if empty, there is no redirection
|
||||
ctstate []string // if empty, there is no ctstate
|
||||
}
|
||||
|
||||
func (i *iptablesInstruction) setDefaults() {
|
||||
if i.table == "" {
|
||||
i.table = "filter"
|
||||
}
|
||||
}
|
||||
|
||||
// equalToRule ignores the append boolean flag of the instruction to compare against the rule.
|
||||
func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (equal bool) {
|
||||
switch {
|
||||
case i.table != table:
|
||||
return false
|
||||
case i.chain != chain:
|
||||
return false
|
||||
case i.target != rule.target:
|
||||
return false
|
||||
case i.protocol != rule.protocol:
|
||||
return false
|
||||
case i.destinationPort != rule.destinationPort:
|
||||
return false
|
||||
case !slices.Equal(i.toPorts, rule.redirPorts):
|
||||
return false
|
||||
case !slices.Equal(i.ctstate, rule.ctstate):
|
||||
return false
|
||||
case !networkInterfacesEqual(i.inputInterface, rule.inputInterface):
|
||||
return false
|
||||
case !networkInterfacesEqual(i.outputInterface, rule.outputInterface):
|
||||
return false
|
||||
case !ipPrefixesEqual(i.source, rule.source):
|
||||
return false
|
||||
case !ipPrefixesEqual(i.destination, rule.destination):
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// instruction can be "" which equivalent to the "*" chain rule interface.
|
||||
func networkInterfacesEqual(instruction, chainRule string) bool {
|
||||
return instruction == chainRule || (instruction == "" && chainRule == "*")
|
||||
}
|
||||
|
||||
func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
|
||||
return instruction == chainRule ||
|
||||
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
|
||||
}
|
||||
|
||||
var ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
|
||||
|
||||
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
|
||||
if s == "" {
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
||||
}
|
||||
fields := strings.Fields(s)
|
||||
if len(fields)%2 != 0 {
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q",
|
||||
ErrIptablesCommandMalformed, len(fields), s)
|
||||
}
|
||||
|
||||
for i := 0; i < len(fields); i += 2 {
|
||||
key := fields[i]
|
||||
value := fields[i+1]
|
||||
err = parseInstructionFlag(key, value, &instruction)
|
||||
if err != nil {
|
||||
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
|
||||
}
|
||||
}
|
||||
|
||||
instruction.setDefaults()
|
||||
return instruction, nil
|
||||
}
|
||||
|
||||
func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) {
|
||||
switch key {
|
||||
case "-t", "--table":
|
||||
instruction.table = value
|
||||
case "-D", "--delete":
|
||||
instruction.append = false
|
||||
instruction.chain = value
|
||||
case "-A", "--append":
|
||||
instruction.append = true
|
||||
instruction.chain = value
|
||||
case "-j", "--jump":
|
||||
instruction.target = value
|
||||
case "-p", "--protocol":
|
||||
instruction.protocol = value
|
||||
case "-m", "--match": // ignore match
|
||||
case "-i", "--in-interface":
|
||||
instruction.inputInterface = value
|
||||
case "-o", "--out-interface":
|
||||
instruction.outputInterface = value
|
||||
case "-s", "--source":
|
||||
instruction.source, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
}
|
||||
case "-d", "--destination":
|
||||
instruction.destination, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
}
|
||||
case "--dport":
|
||||
const base, bitLength = 10, 16
|
||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
instruction.destinationPort = uint16(destinationPort)
|
||||
case "--ctstate":
|
||||
instruction.ctstate = strings.Split(value, ",")
|
||||
case "--to-ports":
|
||||
portStrings := strings.Split(value, ",")
|
||||
instruction.toPorts = make([]uint16, len(portStrings))
|
||||
for i, portString := range portStrings {
|
||||
const base, bitLength = 10, 16
|
||||
port, err := strconv.ParseUint(portString, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing port redirection: %w", err)
|
||||
}
|
||||
instruction.toPorts[i] = uint16(port)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
||||
slashIndex := strings.Index(value, "/")
|
||||
if slashIndex >= 0 {
|
||||
return netip.ParsePrefix(value)
|
||||
}
|
||||
|
||||
ip, err := netip.ParseAddr(value)
|
||||
if err != nil {
|
||||
return netip.Prefix{}, fmt.Errorf("parsing IP address: %w", err)
|
||||
}
|
||||
return netip.PrefixFrom(ip, ip.BitLen()), nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
@@ -23,7 +23,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
"uneven_fields": {
|
||||
s: "-A",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
|
||||
errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"",
|
||||
},
|
||||
"unknown_key": {
|
||||
s: "-x something",
|
||||
@@ -33,9 +33,9 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
"one_pair": {
|
||||
s: "-A INPUT",
|
||||
instruction: iptablesInstruction{
|
||||
table: "filter",
|
||||
chain: "INPUT",
|
||||
operation: opAppend,
|
||||
table: "filter",
|
||||
chain: "INPUT",
|
||||
append: true,
|
||||
},
|
||||
},
|
||||
"instruction_A": {
|
||||
@@ -43,7 +43,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
instruction: iptablesInstruction{
|
||||
table: "filter",
|
||||
chain: "INPUT",
|
||||
operation: opAppend,
|
||||
append: true,
|
||||
inputInterface: "tun0",
|
||||
protocol: "tcp",
|
||||
source: netip.MustParsePrefix("1.2.3.4/32"),
|
||||
@@ -57,7 +57,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
instruction: iptablesInstruction{
|
||||
table: "nat",
|
||||
chain: "PREROUTING",
|
||||
operation: opDelete,
|
||||
append: false,
|
||||
inputInterface: "tun0",
|
||||
protocol: "tcp",
|
||||
destinationPort: 43716,
|
||||
@@ -35,7 +35,7 @@ func (c *Config) SetAllowedPort(ctx context.Context, port uint16, intf string) (
|
||||
c.logger.Info("setting allowed input port " + fmt.Sprint(port) + " through interface " + intf + "...")
|
||||
|
||||
const remove = false
|
||||
if err := c.impl.AcceptInputToPort(ctx, intf, port, remove); err != nil {
|
||||
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
|
||||
return fmt.Errorf("allowing input to port %d through interface %s: %w",
|
||||
port, intf, err)
|
||||
}
|
||||
@@ -68,7 +68,7 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
||||
|
||||
const remove = true
|
||||
for netInterface := range interfacesSet {
|
||||
err := c.impl.AcceptInputToPort(ctx, netInterface, port, remove)
|
||||
err := c.acceptInputToPort(ctx, netInterface, port, remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing allowed port %d on interface %s: %w",
|
||||
port, netInterface, err)
|
||||
|
||||
@@ -50,7 +50,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string, sourcePort,
|
||||
return nil
|
||||
case conflict != nil:
|
||||
const remove = true
|
||||
err = c.impl.RedirectPort(ctx, conflict.interfaceName, conflict.sourcePort,
|
||||
err = c.redirectPort(ctx, conflict.interfaceName, conflict.sourcePort,
|
||||
conflict.destinationPort, remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing conflicting redirection: %w", err)
|
||||
@@ -60,7 +60,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string, sourcePort,
|
||||
}
|
||||
|
||||
const remove = false
|
||||
err = c.impl.RedirectPort(ctx, intf, sourcePort, destinationPort, remove)
|
||||
err = c.redirectPort(ctx, intf, sourcePort, destinationPort, remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("redirecting port: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -11,10 +11,10 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
|
||||
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
|
||||
ErrInputPolicyNotFound = errors.New("input policy not found")
|
||||
ErrNotSupported = errors.New("no iptables supported found")
|
||||
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
|
||||
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
|
||||
ErrInputPolicyNotFound = errors.New("input policy not found")
|
||||
ErrIPTablesNotSupported = errors.New("no iptables supported found")
|
||||
)
|
||||
|
||||
func checkIptablesSupport(ctx context.Context, runner CmdRunner,
|
||||
@@ -57,7 +57,7 @@ func checkIptablesSupport(ctx context.Context, runner CmdRunner,
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("%w: errors encountered are: %s",
|
||||
ErrNotSupported, strings.Join(allUnsupportedMessages, "; "))
|
||||
ErrIPTablesNotSupported, strings.Join(allUnsupportedMessages, "; "))
|
||||
}
|
||||
|
||||
func testIptablesPath(ctx context.Context, path string,
|
||||
@@ -1,4 +1,4 @@
|
||||
package iptables
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -101,7 +101,7 @@ func Test_checkIptablesSupport(t *testing.T) {
|
||||
return runner
|
||||
},
|
||||
iptablesPathsToTry: []string{"path1", "path2"},
|
||||
errSentinel: ErrNotSupported,
|
||||
errSentinel: ErrIPTablesNotSupported,
|
||||
errMessage: "no iptables supported found: " +
|
||||
"errors encountered are: " +
|
||||
"path1: output 1 (exit code 4); " +
|
||||
@@ -28,7 +28,7 @@ func (c *Config) SetVPNConnection(ctx context.Context,
|
||||
remove := true
|
||||
if c.vpnConnection.IP.IsValid() {
|
||||
for _, defaultRoute := range c.defaultRoutes {
|
||||
if err := c.impl.AcceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove); err != nil {
|
||||
if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove); err != nil {
|
||||
c.logger.Error("cannot remove outdated VPN connection rule: " + err.Error())
|
||||
}
|
||||
}
|
||||
@@ -36,7 +36,7 @@ func (c *Config) SetVPNConnection(ctx context.Context,
|
||||
c.vpnConnection = models.Connection{}
|
||||
|
||||
if c.vpnIntf != "" {
|
||||
if err = c.impl.AcceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
|
||||
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
|
||||
c.logger.Error("cannot remove outdated VPN interface rule: " + err.Error())
|
||||
}
|
||||
}
|
||||
@@ -45,13 +45,13 @@ func (c *Config) SetVPNConnection(ctx context.Context,
|
||||
remove = false
|
||||
|
||||
for _, defaultRoute := range c.defaultRoutes {
|
||||
if err := c.impl.AcceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, connection, remove); err != nil {
|
||||
if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, connection, remove); err != nil {
|
||||
return fmt.Errorf("allowing output traffic through VPN connection: %w", err)
|
||||
}
|
||||
}
|
||||
c.vpnConnection = connection
|
||||
|
||||
if err = c.impl.AcceptOutputThroughInterface(ctx, vpnIntf, remove); err != nil {
|
||||
if err = c.acceptOutputThroughInterface(ctx, vpnIntf, remove); err != nil {
|
||||
return fmt.Errorf("accepting output traffic through interface %s: %w", vpnIntf, err)
|
||||
}
|
||||
c.vpnIntf = vpnIntf
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func (c *Config) Version(ctx context.Context) (version string, err error) {
|
||||
return c.impl.Version(ctx)
|
||||
}
|
||||
|
||||
// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port,
|
||||
// for any TCP packets not marked with the excludeMark given.
|
||||
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
|
||||
// by sending a TCP RST packet, although we want to handle the connection manually.
|
||||
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
|
||||
src, dst netip.AddrPort, excludeMark int) (
|
||||
revert func(ctx context.Context) error, err error,
|
||||
) {
|
||||
return c.impl.TempDropOutputTCPRST(ctx, src, dst, excludeMark)
|
||||
}
|
||||
@@ -23,7 +23,6 @@ type Checker struct {
|
||||
logger Logger
|
||||
icmpTargetIPs []netip.Addr
|
||||
smallCheckType string
|
||||
startupOnFail bool
|
||||
configMutex sync.Mutex
|
||||
|
||||
icmpNotPermitted *bool
|
||||
@@ -46,43 +45,26 @@ func NewChecker(logger Logger) *Checker {
|
||||
}
|
||||
}
|
||||
|
||||
// SetConfig sets the following:
|
||||
// - TCP+TLS dial addresses
|
||||
// - ICMP echo IP addresses to target
|
||||
// - the desired small check type (dns or icmp)
|
||||
// - whether to startup the periodic checks if the startup check fails.
|
||||
// SetConfig sets the TCP+TLS dial addresses, the ICMP echo IP address
|
||||
// to target and the desired small check type (dns or icmp).
|
||||
// This function MUST be called before calling [Checker.Start].
|
||||
func (c *Checker) SetConfig(tlsDialAddrs []string, icmpTargets []netip.Addr,
|
||||
smallCheckType string, startupOnFail bool,
|
||||
smallCheckType string,
|
||||
) {
|
||||
c.configMutex.Lock()
|
||||
defer c.configMutex.Unlock()
|
||||
c.tlsDialAddrs = tlsDialAddrs
|
||||
c.icmpTargetIPs = icmpTargets
|
||||
c.smallCheckType = smallCheckType
|
||||
c.startupOnFail = startupOnFail
|
||||
}
|
||||
|
||||
// Start starts the [Checker] which behaves differently according to its
|
||||
// internal field startupOnFail, which is set by calling [Checker.SetConfig].
|
||||
//
|
||||
// By default, startupOnFail should be false and the behavior is as follows:
|
||||
// A blocking 6s-timed TCP+TLS check is performed first. If it fails,
|
||||
// an error is returned and the [Checker] is not started.
|
||||
// On success, it starts the periodic checks in a separate goroutine, returning
|
||||
// the runError error channel and a nil error.
|
||||
//
|
||||
// If startupOnFail is true, the behavior is as follows:
|
||||
// A blocking 6s-timed TCP+TLS check is performed first. If it fails,
|
||||
// the error is sent to the runError channel, but no error is returned
|
||||
// and the [Checker] continues to start the periodic checks in a separate goroutine, returning
|
||||
// the runError error channel and a nil error.
|
||||
//
|
||||
// The periodic checks consist in:
|
||||
// Start starts the checker by first running a blocking 6s-timed TCP+TLS check,
|
||||
// and, on success, starts the periodic checks in a separate goroutine:
|
||||
// - a "small" ICMP echo check every minute
|
||||
// - a "full" TCP+TLS check every 5 minutes
|
||||
//
|
||||
// The [Checker] has to be ultimately stopped by calling [Checker.Stop].
|
||||
// It returns a channel `runError` that receives an error (nil or not) when a periodic check is performed.
|
||||
// It returns an error if the initial TCP+TLS check fails.
|
||||
// The Checker has to be ultimately stopped by calling [Checker.Stop].
|
||||
func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error) {
|
||||
if len(c.tlsDialAddrs) == 0 || len(c.icmpTargetIPs) == 0 || c.smallCheckType == "" {
|
||||
panic("call Checker.SetConfig with non empty values before Checker.Start")
|
||||
@@ -94,19 +76,9 @@ func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error)
|
||||
}
|
||||
c.echoer.Reset()
|
||||
|
||||
// runErrorCh MUST be buffered in the case startupOnFail is true, and
|
||||
// a startup error was encountered, to avoid blocking the startup
|
||||
// goroutine when sending the error, especially since the caller may
|
||||
// not be ready to receive from the channel yet.
|
||||
runErrorCh := make(chan error, 1)
|
||||
runError = runErrorCh
|
||||
err = c.startupCheck(ctx)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("startup check: %w", err)
|
||||
if !c.startupOnFail {
|
||||
return nil, err
|
||||
}
|
||||
runErrorCh <- err
|
||||
return nil, fmt.Errorf("startup check: %w", err)
|
||||
}
|
||||
|
||||
ready := make(chan struct{})
|
||||
@@ -118,6 +90,8 @@ func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error)
|
||||
smallCheckTimer := time.NewTimer(smallCheckPeriod)
|
||||
const fullCheckPeriod = 5 * time.Minute
|
||||
fullCheckTimer := time.NewTimer(fullCheckPeriod)
|
||||
runErrorCh := make(chan error)
|
||||
runError = runErrorCh
|
||||
go func() {
|
||||
defer close(done)
|
||||
close(ready)
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrHTTPStatusNotOK = errors.New("HTTP response status is not OK")
|
||||
@@ -22,9 +21,6 @@ func NewClient(httpClient *http.Client) *Client {
|
||||
}
|
||||
|
||||
func (c *Client) Check(ctx context.Context, url string) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
package mod
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var errBuiltinModuleNotFound = errors.New("builtin module not found")
|
||||
|
||||
func checkModulesBuiltin(modulesPath, moduleName string) error {
|
||||
f, err := os.Open(filepath.Join(modulesPath, "modules.builtin"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
moduleName = strings.TrimSuffix(moduleName, ".ko")
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
line = strings.TrimSuffix(line, ".ko")
|
||||
if strings.HasSuffix(line, "/"+moduleName) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("%w: %s", errBuiltinModuleNotFound, moduleName)
|
||||
}
|
||||
@@ -1,132 +0,0 @@
|
||||
package mod
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
errModuleNameUnknown = errors.New("unknown module name")
|
||||
errKernelFeatureIsModule = errors.New("kernel feature is a module, not built-in")
|
||||
errKernelFeatureNotSet = errors.New("kernel feature not set")
|
||||
errKernelFeatureNotFound = errors.New("kernel feature not found")
|
||||
)
|
||||
|
||||
// checkProcConfig checks /proc/config.gz for a the kernel feature corresponding
|
||||
// to the given module name. If the kernel feature is found and set to "y", it returns nil.
|
||||
// If the kernel feature is found and set to "m", it returns an error indicating that the kernel
|
||||
// feature is a module, not built-in.
|
||||
// If the kernel feature is found and not set, it returns an error indicating that the kernel
|
||||
// feature is not set. If the kernel feature is not found, it returns an error indicating that the kernel
|
||||
// feature is not found.
|
||||
func checkProcConfig(moduleName string) error {
|
||||
f, err := os.Open("/proc/config.gz")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
gz, err := gzip.NewReader(f)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating gzip reader: %w", err)
|
||||
}
|
||||
defer gz.Close()
|
||||
|
||||
// If any group of kernel features is satisfied, then the module is considered supported.
|
||||
kernelFeatureGroups, ok := moduleNameToKernelFeatureGroups(moduleName)
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: %s", errModuleNameUnknown, moduleName)
|
||||
}
|
||||
groups := make([]map[string]bool, len(kernelFeatureGroups))
|
||||
for i, group := range kernelFeatureGroups {
|
||||
featureToOK := make(map[string]bool)
|
||||
for _, feature := range group {
|
||||
featureToOK[feature] = false
|
||||
}
|
||||
groups[i] = featureToOK
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(gz)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
for _, featureToOK := range groups {
|
||||
for name, ok := range featureToOK {
|
||||
switch {
|
||||
case ok:
|
||||
case strings.HasPrefix(line, name+"=m"):
|
||||
return fmt.Errorf("%w: %s", errKernelFeatureIsModule, name)
|
||||
case strings.HasPrefix(line, name+"=y"):
|
||||
featureToOK[name] = true
|
||||
if allFeaturesOK(featureToOK) {
|
||||
return nil
|
||||
}
|
||||
case strings.HasPrefix(line, "# "+name+" is not set"):
|
||||
return fmt.Errorf("%w: %s", errKernelFeatureNotSet, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("%w: for module name %s", errKernelFeatureNotFound, moduleName)
|
||||
}
|
||||
|
||||
func moduleNameToKernelFeatureGroups(moduleName string) (featureGroups [][]string, ok bool) {
|
||||
moduleMap := map[string][][]string{
|
||||
"nf_tables": {{"CONFIG_NF_TABLES"}},
|
||||
|
||||
// Netfilter Matches
|
||||
"xt_conntrack": {{"CONFIG_NETFILTER_XT_MATCH_CONNTRACK"}},
|
||||
"xt_connmark": {
|
||||
{"CONFIG_NETFILTER_XT_CONNMARK"},
|
||||
{"CONFIG_NETFILTER_XT_MATCH_CONNMARK", "CONFIG_NETFILTER_XT_TARGET_CONNMARK"},
|
||||
},
|
||||
"xt_mark": {
|
||||
{"CONFIG_NETFILTER_XT_MARK"},
|
||||
{"CONFIG_NETFILTER_XT_MATCH_MARK", "CONFIG_NETFILTER_XT_TARGET_MARK"},
|
||||
},
|
||||
"nf_conntrack_netlink": {{"CONFIG_NF_CT_NETLINK"}},
|
||||
"nf_reject_ipv4": {{"CONFIG_NF_REJECT_IPV4"}},
|
||||
|
||||
// Common Netfilter Targets
|
||||
"xt_log": {{"CONFIG_NETFILTER_XT_TARGET_LOG"}},
|
||||
"xt_reject": {
|
||||
{"CONFIG_IP_NF_TARGET_REJECT", "CONFIG_NF_REJECT_IPV4"},
|
||||
{"CONFIG_NETFILTER_XT_TARGET_REJECT", "CONFIG_NF_REJECT_IPV4"},
|
||||
},
|
||||
"xt_masquerade": {{"CONFIG_NETFILTER_XT_TARGET_MASQUERADE"}},
|
||||
|
||||
// Additional Netfilter Matches
|
||||
"xt_addrtype": {{"CONFIG_NETFILTER_XT_MATCH_ADDRTYPE"}},
|
||||
"xt_comment": {{"CONFIG_NETFILTER_XT_MATCH_COMMENT"}},
|
||||
"xt_multiport": {{"CONFIG_NETFILTER_XT_MATCH_MULTIPORT"}},
|
||||
"xt_state": {{"CONFIG_NETFILTER_XT_MATCH_STATE"}},
|
||||
"xt_tcpudp": {{"CONFIG_NETFILTER_XT_MATCH_TCPUDP"}},
|
||||
|
||||
// Tunneling and Virtualization
|
||||
"tun": {{"CONFIG_TUN"}},
|
||||
"bridge": {{"CONFIG_BRIDGE"}},
|
||||
"veth": {{"CONFIG_VETH"}},
|
||||
"vxlan": {{"CONFIG_VXLAN"}},
|
||||
"wireguard": {{"CONFIG_WIREGUARD"}},
|
||||
|
||||
// Filesystems
|
||||
"overlay": {{"CONFIG_OVERLAY_FS"}},
|
||||
"fuse": {{"CONFIG_FUSE_FS"}},
|
||||
}
|
||||
|
||||
featureGroups, ok = moduleMap[strings.ToLower(moduleName)]
|
||||
return featureGroups, ok
|
||||
}
|
||||
|
||||
func allFeaturesOK(featureToOK map[string]bool) bool {
|
||||
for _, ok := range featureToOK {
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build !windows
|
||||
|
||||
package mod
|
||||
|
||||
import (
|
||||
@@ -30,7 +28,36 @@ type moduleInfo struct {
|
||||
|
||||
var ErrModulesDirectoryNotFound = errors.New("modules directory not found")
|
||||
|
||||
func getModulesInfo(modulesPath string) (modulesInfo map[string]moduleInfo, err error) {
|
||||
func getModulesInfo() (modulesInfo map[string]moduleInfo, err error) {
|
||||
var utsName unix.Utsname
|
||||
err = unix.Uname(&utsName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting unix uname release: %w", err)
|
||||
}
|
||||
release := unix.ByteSliceToString(utsName.Release[:])
|
||||
release = strings.TrimSpace(release)
|
||||
|
||||
modulePaths := []string{
|
||||
filepath.Join("/lib/modules", release),
|
||||
filepath.Join("/usr/lib/modules", release),
|
||||
}
|
||||
|
||||
var modulesPath string
|
||||
var found bool
|
||||
for _, modulesPath = range modulePaths {
|
||||
info, err := os.Stat(modulesPath)
|
||||
if err == nil && info.IsDir() {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, fmt.Errorf("%w: %s are not valid existing directories"+
|
||||
"; have you bind mounted the /lib/modules directory?",
|
||||
ErrModulesDirectoryNotFound, strings.Join(modulePaths, ", "))
|
||||
}
|
||||
|
||||
dependencyFilepath := filepath.Join(modulesPath, "modules.dep")
|
||||
dependencyFile, err := os.Open(dependencyFilepath)
|
||||
if err != nil {
|
||||
@@ -82,39 +109,6 @@ func getModulesInfo(modulesPath string) (modulesInfo map[string]moduleInfo, err
|
||||
return modulesInfo, nil
|
||||
}
|
||||
|
||||
func getModulesPath() (string, error) {
|
||||
release, err := getReleaseName()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting release name: %w", err)
|
||||
}
|
||||
|
||||
modulePaths := []string{
|
||||
filepath.Join("/lib/modules", release),
|
||||
filepath.Join("/usr/lib/modules", release),
|
||||
}
|
||||
|
||||
for _, modulesPath := range modulePaths {
|
||||
info, err := os.Stat(modulesPath)
|
||||
if err == nil && info.IsDir() {
|
||||
return modulesPath, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("%w: %s are not valid existing directories"+
|
||||
"; have you bind mounted the /lib/modules directory?",
|
||||
ErrModulesDirectoryNotFound, strings.Join(modulePaths, ", "))
|
||||
}
|
||||
|
||||
func getReleaseName() (release string, err error) {
|
||||
var utsName unix.Utsname
|
||||
err = unix.Uname(&utsName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting unix uname release: %w", err)
|
||||
}
|
||||
release = unix.ByteSliceToString(utsName.Release[:])
|
||||
release = strings.TrimSpace(release)
|
||||
return release, nil
|
||||
}
|
||||
|
||||
func getBuiltinModules(modulesDirPath string, modulesInfo map[string]moduleInfo) error {
|
||||
file, err := os.Open(filepath.Join(modulesDirPath, "modules.builtin"))
|
||||
if err != nil {
|
||||
@@ -0,0 +1,37 @@
|
||||
package mod
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Probe loads the given kernel module and its dependencies.
|
||||
func Probe(moduleName string) error {
|
||||
modulesInfo, err := getModulesInfo()
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting modules information: %w", err)
|
||||
}
|
||||
|
||||
modulePath, err := findModulePath(moduleName, modulesInfo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("finding module path: %w", err)
|
||||
}
|
||||
|
||||
info := modulesInfo[modulePath]
|
||||
if info.state == builtin || info.state == loaded {
|
||||
return nil
|
||||
}
|
||||
|
||||
info.state = loading
|
||||
for _, dependencyModulePath := range info.dependencyPaths {
|
||||
err = initDependencies(dependencyModulePath, modulesInfo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init dependencies: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = initModule(modulePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init module: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
package mod
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Probe is a expanded version of modprobe, in which it checks if the Kernel
|
||||
// built-in features contain the given module name.
|
||||
// It first tries to locate the modules directory in [getModulesPath].
|
||||
// If it fails (like on WSL), it then only checks for the kernel feature
|
||||
// in /proc/config.gz with [checkProcConfig].
|
||||
// Otherwise, it first checks if the modules directory modules.builtin
|
||||
// file contains the given module name in [checkModulesBuiltin].
|
||||
// If the module is not found, it then runs the classic [modProbe] behavior,
|
||||
// trying to load the module in the kernel.
|
||||
// If this fails, it does one final try running [checkProcConfig].
|
||||
func Probe(moduleName string) error {
|
||||
modulesPath, err := getModulesPath()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrModulesDirectoryNotFound) {
|
||||
err = checkProcConfig(moduleName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking /proc/config.gz: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("getting modules path: %w", err)
|
||||
}
|
||||
|
||||
err = checkModulesBuiltin(modulesPath, moduleName)
|
||||
if err != nil {
|
||||
err = modProbe(modulesPath, moduleName)
|
||||
if err != nil {
|
||||
err = checkProcConfig(moduleName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking /proc/config.gz: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// modProbe is the classic modprobe behavior.
|
||||
func modProbe(modulesPath, moduleName string) error {
|
||||
modulesInfo, err := getModulesInfo(modulesPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting modules information: %w", err)
|
||||
}
|
||||
|
||||
modulePath, err := findModulePath(moduleName, modulesInfo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("finding module path: %w", err)
|
||||
}
|
||||
|
||||
info := modulesInfo[modulePath]
|
||||
if info.state == builtin || info.state == loaded {
|
||||
return nil
|
||||
}
|
||||
|
||||
info.state = loading
|
||||
for _, dependencyModulePath := range info.dependencyPaths {
|
||||
err = initDependencies(dependencyModulePath, modulesInfo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init dependencies: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = initModule(modulePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init module: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
package mod
|
||||
|
||||
func Probe(moduleName string) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
+17
-59
@@ -1,75 +1,33 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/jsimonetti/rtnetlink/rtnl"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func (n *NetLink) AddrList(linkIndex uint32, family uint8) (
|
||||
ipPrefixes []netip.Prefix, err error,
|
||||
func (n *NetLink) AddrList(link Link, family int) (
|
||||
addresses []Addr, err error,
|
||||
) {
|
||||
conn, err := rtnl.Dial(nil)
|
||||
netlinkLink := linkToNetlinkLink(&link)
|
||||
netlinkAddresses, err := netlink.AddrList(netlinkLink, family)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ifc := &net.Interface{
|
||||
Index: int(linkIndex),
|
||||
}
|
||||
ipNets, err := conn.Addrs(ifc, int(family))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list addresses: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ipPrefixes = make([]netip.Prefix, len(ipNets))
|
||||
for i := range ipNets {
|
||||
ipPrefixes[i] = netIPNetToNetipPrefix(ipNets[i])
|
||||
addresses = make([]Addr, len(netlinkAddresses))
|
||||
for i := range netlinkAddresses {
|
||||
addresses[i].Network = netIPNetToNetipPrefix(netlinkAddresses[i].IPNet)
|
||||
}
|
||||
|
||||
return ipPrefixes, nil
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
func (n *NetLink) AddrReplace(linkIndex uint32, prefix netip.Prefix) error {
|
||||
conn, err := rtnl.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ipNet := netipPrefixToIPNet(prefix)
|
||||
|
||||
// Remove any address identical to the one we want to add
|
||||
family := FamilyV4
|
||||
if prefix.Addr().Is6() {
|
||||
family = FamilyV6
|
||||
}
|
||||
ifc := &net.Interface{
|
||||
Index: int(linkIndex),
|
||||
}
|
||||
addresses, err := conn.Addrs(ifc, int(family))
|
||||
if err != nil {
|
||||
return fmt.Errorf("listing addresses: %w", err)
|
||||
}
|
||||
for _, address := range addresses {
|
||||
if address.IP.Equal(ipNet.IP) &&
|
||||
net.IP(address.Mask).String() == net.IP(ipNet.Mask).String() {
|
||||
err = conn.AddrDel(ifc, address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting address from interface: %w", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
func (n *NetLink) AddrReplace(link Link, addr Addr) error {
|
||||
netlinkLink := linkToNetlinkLink(&link)
|
||||
netlinkAddress := netlink.Addr{
|
||||
IPNet: netipPrefixToIPNet(addr.Network),
|
||||
}
|
||||
|
||||
// Add the new address to the interface
|
||||
err = conn.AddrAdd(ifc, ipNet)
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding address to interface: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return netlink.AddrReplace(netlinkLink, &netlinkAddress)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
//go:build !linux && !darwin
|
||||
|
||||
package netlink
|
||||
|
||||
func (n *NetLink) AddrList(link Link, family int) (
|
||||
addresses []Addr, err error,
|
||||
) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) AddrReplace(Link, Addr) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/mdlayher/netlink"
|
||||
"github.com/ti-mo/netfilter"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var ErrConntrackNetlinkNotSupported = errors.New("nf_conntrack_netlink is not supported by the kernel")
|
||||
|
||||
func (n *NetLink) FlushConntrack() error {
|
||||
conn, err := netfilter.Dial(nil)
|
||||
if err != nil {
|
||||
if !n.conntrackNetlink {
|
||||
err = fmt.Errorf("%w: %w", err, ErrConntrackNetlinkNotSupported)
|
||||
}
|
||||
return fmt.Errorf("dialing netfilter: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
const ipCtnlMsgCtDelete = netfilter.MessageType(2)
|
||||
header := netfilter.Header{
|
||||
SubsystemID: netfilter.NFSubsysCTNetlink,
|
||||
MessageType: ipCtnlMsgCtDelete,
|
||||
Family: unix.AF_UNSPEC,
|
||||
Flags: netlink.Request | netlink.Acknowledge,
|
||||
}
|
||||
request, err := netfilter.MarshalNetlink(header, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoding netlink request: %w", err)
|
||||
}
|
||||
|
||||
_, err = conn.Query(request)
|
||||
if err != nil {
|
||||
if !n.conntrackNetlink {
|
||||
err = fmt.Errorf("%w: %w", err, ErrConntrackNetlinkNotSupported)
|
||||
}
|
||||
return fmt.Errorf("querying netlink request: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
package netlink
|
||||
|
||||
import "errors"
|
||||
|
||||
var ErrConntrackNetlinkNotSupported = errors.New("error not implemented")
|
||||
|
||||
func (n *NetLink) FlushConntrack() error {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -36,30 +36,6 @@ func netIPNetToNetipPrefix(ipNet *net.IPNet) (prefix netip.Prefix) {
|
||||
return netip.PrefixFrom(ip, bits)
|
||||
}
|
||||
|
||||
func ipAndLengthToPrefix(ip *net.IP, length uint8) netip.Prefix {
|
||||
if ip == nil || len(*ip) == 0 {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
var dstIP netip.Addr
|
||||
if ipv4 := ip.To4(); ipv4 != nil { // IPv6
|
||||
dstIP = netip.AddrFrom4([4]byte(*ip))
|
||||
} else {
|
||||
dstIP = netip.AddrFrom16([16]byte(*ip))
|
||||
}
|
||||
return netip.PrefixFrom(dstIP, int(length))
|
||||
}
|
||||
|
||||
func prefixToIPAndLength(prefix netip.Prefix) (ip *net.IP, length uint8) {
|
||||
if !prefix.IsValid() {
|
||||
return nil, 0
|
||||
}
|
||||
prefixIP := prefix.Addr().Unmap()
|
||||
ip = new(net.IP)
|
||||
*ip = netipAddrToNetIP(prefixIP)
|
||||
length = uint8(prefix.Bits()) //nolint:gosec
|
||||
return ip, length
|
||||
}
|
||||
|
||||
func netipAddrToNetIP(address netip.Addr) (ip net.IP) {
|
||||
switch {
|
||||
case !address.IsValid():
|
||||
|
||||
@@ -4,7 +4,13 @@ import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func FamilyToString(family uint8) string {
|
||||
const (
|
||||
FamilyAll = 0
|
||||
FamilyV4 = 2
|
||||
FamilyV6 = 10
|
||||
)
|
||||
|
||||
func FamilyToString(family int) string {
|
||||
switch family {
|
||||
case FamilyAll:
|
||||
return "all"
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
package netlink
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
const (
|
||||
FamilyAll uint8 = unix.AF_UNSPEC
|
||||
FamilyV4 uint8 = unix.AF_INET
|
||||
FamilyV6 uint8 = unix.AF_INET6
|
||||
)
|
||||
@@ -1,30 +1,16 @@
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/log"
|
||||
)
|
||||
|
||||
func ptrTo[T any](v T) *T { return &v }
|
||||
|
||||
func makeNetipPrefix(n byte) netip.Prefix {
|
||||
const bits = 24
|
||||
return netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits)
|
||||
}
|
||||
|
||||
var rng = rand.New(rand.NewChaCha8([32]byte{})) //nolint:gosec,gochecknoglobals
|
||||
|
||||
func makeLinkName() string {
|
||||
const alphabet = "abcdefghijklmnopqrstuvwxyz"
|
||||
name := make([]byte, 8)
|
||||
for i := range name {
|
||||
name[i] = alphabet[rng.IntN(len(alphabet))]
|
||||
}
|
||||
return "test" + string(name)
|
||||
}
|
||||
|
||||
type noopLogger struct{}
|
||||
|
||||
func (l *noopLogger) Debug(_ string) {}
|
||||
|
||||
@@ -19,7 +19,7 @@ func (n *NetLink) IsIPv6Supported() (supported bool, err error) {
|
||||
return false, fmt.Errorf("finding link corresponding to route: %w", err)
|
||||
}
|
||||
|
||||
sourceIsIPv6 := route.Src.Addr().IsValid() && route.Src.Addr().Is6()
|
||||
sourceIsIPv6 := route.Src.IsValid() && route.Src.Is6()
|
||||
destinationIsIPv6 := route.Dst.IsValid() && route.Dst.Addr().Is6()
|
||||
switch {
|
||||
case !sourceIsIPv6 && !destinationIsIPv6,
|
||||
|
||||
+76
-162
@@ -1,191 +1,105 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jsimonetti/rtnetlink"
|
||||
)
|
||||
|
||||
type DeviceType uint16
|
||||
|
||||
type Link struct {
|
||||
Index uint32
|
||||
Name string
|
||||
DeviceType DeviceType
|
||||
VirtualType string
|
||||
MTU uint32
|
||||
}
|
||||
import "github.com/vishvananda/netlink"
|
||||
|
||||
func (n *NetLink) LinkList() (links []Link, err error) {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
netlinkLinks, err := netlink.LinkList()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
linkMessages, err := conn.Link.List()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listing interfaces: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
links = make([]Link, len(linkMessages))
|
||||
for i, message := range linkMessages {
|
||||
virtualType := ""
|
||||
if message.Attributes.Info != nil {
|
||||
virtualType = message.Attributes.Info.Kind
|
||||
}
|
||||
links[i] = Link{
|
||||
Index: message.Index,
|
||||
Name: message.Attributes.Name,
|
||||
DeviceType: DeviceType(message.Type),
|
||||
VirtualType: virtualType,
|
||||
MTU: message.Attributes.MTU,
|
||||
}
|
||||
links = make([]Link, len(netlinkLinks))
|
||||
for i := range netlinkLinks {
|
||||
links[i] = netlinkLinkToLink(netlinkLinks[i])
|
||||
}
|
||||
|
||||
return links, nil
|
||||
}
|
||||
|
||||
var ErrLinkNotFound = errors.New("link not found")
|
||||
|
||||
func (n *NetLink) LinkByName(name string) (link Link, err error) {
|
||||
links, err := n.LinkList()
|
||||
netlinkLink, err := netlink.LinkByName(name)
|
||||
if err != nil {
|
||||
return Link{}, fmt.Errorf("listing links: %w", err)
|
||||
return Link{}, err
|
||||
}
|
||||
|
||||
for _, link := range links {
|
||||
if link.Name == name {
|
||||
return link, nil
|
||||
}
|
||||
}
|
||||
|
||||
return Link{}, fmt.Errorf("%w: for name %s", ErrLinkNotFound, name)
|
||||
return netlinkLinkToLink(netlinkLink), nil
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkByIndex(index uint32) (link Link, err error) {
|
||||
links, err := n.LinkList()
|
||||
func (n *NetLink) LinkByIndex(index int) (link Link, err error) {
|
||||
netlinkLink, err := netlink.LinkByIndex(index)
|
||||
if err != nil {
|
||||
return Link{}, fmt.Errorf("listing links: %w", err)
|
||||
return Link{}, err
|
||||
}
|
||||
|
||||
for _, link = range links {
|
||||
if link.Index == index {
|
||||
return link, nil
|
||||
}
|
||||
}
|
||||
|
||||
return Link{}, fmt.Errorf("%w: for index %d", ErrLinkNotFound, index)
|
||||
return netlinkLinkToLink(netlinkLink), nil
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkAdd(link Link) (linkIndex uint32, err error) {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
func (n *NetLink) LinkAdd(link Link) (linkIndex int, err error) {
|
||||
netlinkLink := linkToNetlinkLink(&link)
|
||||
err = netlink.LinkAdd(netlinkLink)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("dialing netlink: %w", err)
|
||||
return 0, err
|
||||
}
|
||||
defer conn.Close()
|
||||
return netlinkLink.Attrs().Index, nil
|
||||
}
|
||||
|
||||
tx := &rtnetlink.LinkMessage{
|
||||
Type: uint16(link.DeviceType),
|
||||
Attributes: &rtnetlink.LinkAttributes{
|
||||
MTU: link.MTU,
|
||||
Name: link.Name,
|
||||
func (n *NetLink) LinkDel(link Link) (err error) {
|
||||
return netlink.LinkDel(linkToNetlinkLink(&link))
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkSetUp(link Link) (linkIndex int, err error) {
|
||||
netlinkLink := linkToNetlinkLink(&link)
|
||||
err = netlink.LinkSetUp(netlinkLink)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return netlinkLink.Attrs().Index, nil
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkSetDown(link Link) (err error) {
|
||||
return netlink.LinkSetDown(linkToNetlinkLink(&link))
|
||||
}
|
||||
|
||||
type netlinkLinkImpl struct {
|
||||
attrs *netlink.LinkAttrs
|
||||
linkType string
|
||||
}
|
||||
|
||||
func (n *netlinkLinkImpl) Attrs() *netlink.LinkAttrs {
|
||||
return n.attrs
|
||||
}
|
||||
|
||||
func (n *netlinkLinkImpl) Type() string {
|
||||
return n.linkType
|
||||
}
|
||||
|
||||
func netlinkLinkToLink(netlinkLink netlink.Link) Link {
|
||||
attributes := netlinkLink.Attrs()
|
||||
return Link{
|
||||
Type: netlinkLink.Type(),
|
||||
Name: attributes.Name,
|
||||
Index: attributes.Index,
|
||||
EncapType: attributes.EncapType,
|
||||
MTU: uint16(attributes.MTU), //nolint:gosec
|
||||
}
|
||||
}
|
||||
|
||||
// Warning: we must return `netlink.Link` and not `netlinkLinkImpl`
|
||||
// so that the vishvananda/netlink package can compare the returned
|
||||
// value against an untyped nil.
|
||||
func linkToNetlinkLink(link *Link) netlink.Link {
|
||||
if link == nil {
|
||||
return nil
|
||||
}
|
||||
return &netlinkLinkImpl{
|
||||
linkType: link.Type,
|
||||
attrs: &netlink.LinkAttrs{
|
||||
Name: link.Name,
|
||||
Index: link.Index,
|
||||
EncapType: link.EncapType,
|
||||
MTU: int(link.MTU),
|
||||
},
|
||||
}
|
||||
if link.VirtualType != "" {
|
||||
tx.Attributes.Info = &rtnetlink.LinkInfo{
|
||||
Kind: link.VirtualType,
|
||||
}
|
||||
}
|
||||
|
||||
err = conn.Link.New(tx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("creating new link: %w", err)
|
||||
}
|
||||
|
||||
linkMessages, err := conn.Link.List()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("listing links: %w", err)
|
||||
}
|
||||
for _, linkMessage := range linkMessages {
|
||||
if linkMessage.Attributes.Name == link.Name {
|
||||
return linkMessage.Index, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("%w: matching name %s", ErrLinkNotFound, link.Name)
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkDel(linkIndex uint32) (err error) {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return conn.Link.Delete(linkIndex)
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkSetUp(linkIndex uint32) (err error) {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
rx, err := conn.Link.Get(linkIndex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting link: %w", err)
|
||||
}
|
||||
tx := &rtnetlink.LinkMessage{
|
||||
Type: rx.Type,
|
||||
Index: linkIndex,
|
||||
Flags: iffUp,
|
||||
Change: iffUp,
|
||||
}
|
||||
return conn.Link.Set(tx)
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkSetDown(linkIndex uint32) (err error) {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
linkInfo, err := conn.Link.Get(linkIndex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting link: %w", err)
|
||||
}
|
||||
message := &rtnetlink.LinkMessage{
|
||||
Type: linkInfo.Type,
|
||||
Index: linkIndex,
|
||||
Flags: 0,
|
||||
Change: iffUp,
|
||||
}
|
||||
return conn.Link.Set(message)
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkSetMTU(linkIndex, mtu uint32) error {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
message := &rtnetlink.LinkMessage{
|
||||
Index: linkIndex,
|
||||
Attributes: &rtnetlink.LinkAttributes{
|
||||
MTU: mtu,
|
||||
},
|
||||
}
|
||||
|
||||
err = conn.Link.Set(message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting MTU to %d for link at index %d: %w",
|
||||
mtu, linkIndex, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
package netlink
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
const (
|
||||
DeviceTypeEthernet DeviceType = unix.ARPHRD_ETHER
|
||||
DeviceTypeLoopback DeviceType = unix.ARPHRD_LOOPBACK
|
||||
DeviceTypeNone DeviceType = unix.ARPHRD_NONE
|
||||
|
||||
iffUp = unix.IFF_UP
|
||||
)
|
||||
@@ -1,85 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_NetLink_LinkList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
netlink := &NetLink{}
|
||||
|
||||
initialLinks, err := netlink.LinkList()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, initialLinks)
|
||||
|
||||
loopbackFound := false
|
||||
for _, link := range initialLinks {
|
||||
if link.Name != "lo" {
|
||||
continue
|
||||
}
|
||||
loopbackFound = true
|
||||
assert.Equal(t, DeviceTypeLoopback, link.DeviceType)
|
||||
break
|
||||
}
|
||||
assert.True(t, loopbackFound, "loopback interface not found")
|
||||
|
||||
testLink := Link{
|
||||
Name: makeLinkName(),
|
||||
// note if [Link.VirtualType] is set, [Link.DeviceType]
|
||||
// is ignored and gets set to [DeviceTypeNone] in LinkAdd.
|
||||
DeviceType: DeviceTypeNone,
|
||||
VirtualType: "wireguard",
|
||||
MTU: 1420,
|
||||
}
|
||||
index, err := netlink.LinkAdd(testLink)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = netlink.LinkDel(index)
|
||||
})
|
||||
|
||||
links, err := netlink.LinkList()
|
||||
require.NoError(t, err)
|
||||
|
||||
testLink.Index = index
|
||||
for _, link := range links {
|
||||
if link.Name != testLink.Name {
|
||||
continue
|
||||
}
|
||||
assert.Equal(t, testLink, link)
|
||||
return
|
||||
}
|
||||
t.Errorf("created link %q not found", testLink.Name)
|
||||
}
|
||||
|
||||
func Test_NetLink_LinkSetMTU(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
netlink := &NetLink{}
|
||||
|
||||
testLink := Link{
|
||||
Name: makeLinkName(),
|
||||
DeviceType: DeviceTypeNone,
|
||||
VirtualType: "wireguard",
|
||||
MTU: 1420,
|
||||
}
|
||||
index, err := netlink.LinkAdd(testLink)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = netlink.LinkDel(index)
|
||||
})
|
||||
testLink.Index = index
|
||||
|
||||
err = netlink.LinkSetMTU(index, 1500)
|
||||
require.NoError(t, err)
|
||||
|
||||
link, err := netlink.LinkByIndex(index)
|
||||
require.NoError(t, err)
|
||||
testLink.MTU = 1500
|
||||
assert.Equal(t, testLink, link)
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
//go:build !linux && !darwin
|
||||
|
||||
package netlink
|
||||
|
||||
func (n *NetLink) LinkList() (links []Link, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkByName(name string) (link Link, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkByIndex(index int) (link Link, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkAdd(link Link) (linkIndex int, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkDel(link Link) (err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkSetUp(link Link) (linkIndex int, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkSetDown(link Link) (err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -1,22 +1,14 @@
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/mod"
|
||||
"github.com/qdm12/log"
|
||||
)
|
||||
import "github.com/qdm12/log"
|
||||
|
||||
type NetLink struct {
|
||||
debugLogger DebugLogger
|
||||
|
||||
// Fixed state
|
||||
conntrackNetlink bool
|
||||
}
|
||||
|
||||
func New(debugLogger DebugLogger) *NetLink {
|
||||
conntrackNetlink := mod.Probe("nf_conntrack_netlink") == nil
|
||||
return &NetLink{
|
||||
debugLogger: debugLogger,
|
||||
conntrackNetlink: conntrackNetlink,
|
||||
debugLogger: debugLogger,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
package netlink
|
||||
|
||||
const (
|
||||
// FamilyAll is a placeholder only and should not
|
||||
// be used.
|
||||
FamilyAll uint8 = iota
|
||||
// FamilyV4 is a placeholder only and should not
|
||||
// be used.
|
||||
FamilyV4
|
||||
// FamilyV6 is a placeholder only and should not
|
||||
// be used.
|
||||
FamilyV6
|
||||
|
||||
// DeviceTypeEthernet is a placeholder only and should not be used.
|
||||
DeviceTypeEthernet DeviceType = 0
|
||||
// DeviceTypeLoopback is a placeholder only and should not be used.
|
||||
DeviceTypeLoopback DeviceType = 0
|
||||
// DeviceTypeNone is a placeholder only and should not be used.
|
||||
DeviceTypeNone DeviceType = 0
|
||||
|
||||
// iffUp is a placeholder only and should not be used.
|
||||
iffUp = 0
|
||||
|
||||
// RouteTypeUnicast is a placeholder only and should not be used.
|
||||
RouteTypeUnicast = 0
|
||||
// ScopeUniverse is a placeholder only and should not be used.
|
||||
ScopeUniverse = 0
|
||||
// ProtoStatic is a placeholder only and should not be used.
|
||||
ProtoStatic = 0
|
||||
|
||||
// FlagInvert is a placeholder only and should not be used.
|
||||
FlagInvert = 0
|
||||
// ActionToTable is a placeholder only and should not be used.
|
||||
ActionToTable = 0
|
||||
|
||||
// rtTableCompat is a placeholder only and should not be used.
|
||||
rtTableCompat = 0
|
||||
)
|
||||
|
||||
func (n *NetLink) RuleList(family uint8) (rules []Rule, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) RuleAdd(rule Rule) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) RuleDel(rule Rule) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) IsWireguardSupported() (bool, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
+48
-116
@@ -1,137 +1,69 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/jsimonetti/rtnetlink"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
type Route struct {
|
||||
LinkIndex uint32
|
||||
Dst netip.Prefix
|
||||
Src netip.Prefix
|
||||
Gw netip.Addr
|
||||
Priority uint32
|
||||
Family uint8
|
||||
Table uint32
|
||||
Type uint8
|
||||
Scope uint8
|
||||
Proto uint8
|
||||
AdvMSS uint32
|
||||
}
|
||||
func (n *NetLink) RouteList(family int) (routes []Route, err error) {
|
||||
// We set the filter to netlink.RT_FILTER_TABLE so that
|
||||
// routes from all tables are listed, as long as the filter
|
||||
// table is set to 0.
|
||||
const filterMask = netlink.RT_FILTER_TABLE
|
||||
// The filter is not left to `nil` otherwise non-main tables
|
||||
// are ignored.
|
||||
filter := &netlink.Route{}
|
||||
|
||||
func (r *Route) fromMessage(message rtnetlink.RouteMessage) {
|
||||
table := uint32(message.Table)
|
||||
if table == 0 || table == rtTableCompat {
|
||||
table = message.Attributes.Table
|
||||
}
|
||||
r.LinkIndex = message.Attributes.OutIface
|
||||
r.Dst = ipAndLengthToPrefix(&message.Attributes.Dst, message.DstLength)
|
||||
r.Src = ipAndLengthToPrefix(&message.Attributes.Src, message.SrcLength)
|
||||
r.Gw = netIPToNetipAddress(message.Attributes.Gateway)
|
||||
r.Priority = message.Attributes.Priority
|
||||
r.Family = message.Family
|
||||
r.Table = table
|
||||
r.Type = message.Type
|
||||
r.Scope = message.Scope
|
||||
r.Proto = message.Protocol
|
||||
if metrics := message.Attributes.Metrics; metrics != nil {
|
||||
r.AdvMSS = metrics.AdvMSS
|
||||
}
|
||||
}
|
||||
|
||||
func (r Route) message() *rtnetlink.RouteMessage {
|
||||
dst, dstLength := prefixToIPAndLength(r.Dst)
|
||||
src, srcLength := prefixToIPAndLength(r.Src)
|
||||
var table uint8
|
||||
var extendedTable uint32
|
||||
if r.Table <= uint32(^uint8(0)) {
|
||||
table = uint8(r.Table)
|
||||
} else {
|
||||
table = rtTableCompat
|
||||
extendedTable = r.Table
|
||||
}
|
||||
message := &rtnetlink.RouteMessage{
|
||||
Family: r.Family,
|
||||
DstLength: dstLength,
|
||||
SrcLength: srcLength,
|
||||
Table: table,
|
||||
Type: r.Type,
|
||||
Scope: r.Scope,
|
||||
Protocol: r.Proto,
|
||||
Attributes: rtnetlink.RouteAttributes{
|
||||
OutIface: r.LinkIndex,
|
||||
Gateway: netipAddrToNetIP(r.Gw),
|
||||
Priority: r.Priority,
|
||||
Table: extendedTable,
|
||||
},
|
||||
}
|
||||
if src != nil { // src is optional
|
||||
message.Attributes.Src = *src
|
||||
}
|
||||
if dst != nil {
|
||||
message.Attributes.Dst = *dst
|
||||
}
|
||||
if r.AdvMSS != 0 {
|
||||
if message.Attributes.Metrics == nil {
|
||||
message.Attributes.Metrics = &rtnetlink.RouteMetrics{}
|
||||
}
|
||||
message.Attributes.Metrics.AdvMSS = r.AdvMSS
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
func (n *NetLink) RouteList(family uint8) (routes []Route, err error) {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
netlinkRoutes, err := netlink.RouteListFiltered(family, filter, filterMask)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
routeMessages, err := conn.Route.List()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listing interfaces: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
routes = make([]Route, 0, len(routeMessages))
|
||||
for _, routeMessage := range routeMessages {
|
||||
if family != FamilyAll && routeMessage.Family != family {
|
||||
continue
|
||||
}
|
||||
var route Route
|
||||
route.fromMessage(routeMessage)
|
||||
routes = append(routes, route)
|
||||
routes = make([]Route, len(netlinkRoutes))
|
||||
for i := range netlinkRoutes {
|
||||
routes[i] = netlinkRouteToRoute(netlinkRoutes[i])
|
||||
}
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func (n *NetLink) RouteAdd(route Route) error {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return conn.Route.Add(route.message())
|
||||
netlinkRoute := routeToNetlinkRoute(route)
|
||||
return netlink.RouteAdd(&netlinkRoute)
|
||||
}
|
||||
|
||||
func (n *NetLink) RouteDel(route Route) error {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return conn.Route.Delete(route.message())
|
||||
netlinkRoute := routeToNetlinkRoute(route)
|
||||
return netlink.RouteDel(&netlinkRoute)
|
||||
}
|
||||
|
||||
func (n *NetLink) RouteReplace(route Route) error {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return conn.Route.Replace(route.message())
|
||||
netlinkRoute := routeToNetlinkRoute(route)
|
||||
return netlink.RouteReplace(&netlinkRoute)
|
||||
}
|
||||
|
||||
func netlinkRouteToRoute(netlinkRoute netlink.Route) (route Route) {
|
||||
return Route{
|
||||
LinkIndex: netlinkRoute.LinkIndex,
|
||||
Dst: netIPNetToNetipPrefix(netlinkRoute.Dst),
|
||||
Src: netIPToNetipAddress(netlinkRoute.Src),
|
||||
Gw: netIPToNetipAddress(netlinkRoute.Gw),
|
||||
Priority: netlinkRoute.Priority,
|
||||
Family: netlinkRoute.Family,
|
||||
Table: netlinkRoute.Table,
|
||||
Type: netlinkRoute.Type,
|
||||
}
|
||||
}
|
||||
|
||||
func routeToNetlinkRoute(route Route) (netlinkRoute netlink.Route) {
|
||||
return netlink.Route{
|
||||
LinkIndex: route.LinkIndex,
|
||||
Dst: netipPrefixToIPNet(route.Dst),
|
||||
Src: netipAddrToNetIP(route.Src),
|
||||
Gw: netipAddrToNetIP(route.Gw),
|
||||
Priority: route.Priority,
|
||||
Family: route.Family,
|
||||
Table: route.Table,
|
||||
Type: route.Type,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
package netlink
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
const (
|
||||
RouteTypeUnicast = unix.RTN_UNICAST
|
||||
ScopeUniverse = unix.RT_SCOPE_UNIVERSE
|
||||
ProtoStatic = unix.RTPROT_STATIC
|
||||
|
||||
rtTableCompat = unix.RT_TABLE_COMPAT
|
||||
)
|
||||
@@ -0,0 +1,21 @@
|
||||
//go:build !linux && !darwin
|
||||
|
||||
package netlink
|
||||
|
||||
func (n *NetLink) RouteList(family int) (
|
||||
routes []Route, err error,
|
||||
) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) RouteAdd(route Route) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) RouteDel(route Route) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) RouteReplace(route Route) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
+73
-78
@@ -1,96 +1,91 @@
|
||||
//go:build linux
|
||||
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/jsimonetti/rtnetlink"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
type Rule struct {
|
||||
Priority *uint32
|
||||
Family uint8
|
||||
Table uint32
|
||||
Mark *uint32
|
||||
Src netip.Prefix
|
||||
Dst netip.Prefix
|
||||
Flags uint32
|
||||
Action uint8
|
||||
func NewRule() Rule {
|
||||
// defaults found from netlink.NewRule() for fields we use,
|
||||
// the rest of the defaults is set when converting from a `Rule`
|
||||
// to a `netlink.Rule`
|
||||
return Rule{
|
||||
Priority: -1,
|
||||
Mark: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Rule) fromMessage(message rtnetlink.RuleMessage) {
|
||||
table := uint32(message.Table)
|
||||
if table == 0 || table == rtTableCompat {
|
||||
table = *message.Attributes.Table
|
||||
func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
|
||||
switch family {
|
||||
case FamilyAll:
|
||||
n.debugLogger.Debug("ip -4 rule list")
|
||||
n.debugLogger.Debug("ip -6 rule list")
|
||||
case FamilyV4:
|
||||
n.debugLogger.Debug("ip -4 rule list")
|
||||
case FamilyV6:
|
||||
n.debugLogger.Debug("ip -6 rule list")
|
||||
}
|
||||
r.Priority = message.Attributes.Priority
|
||||
r.Family = message.Family
|
||||
r.Table = table
|
||||
r.Mark = message.Attributes.FwMark
|
||||
r.Src = ipAndLengthToPrefix(message.Attributes.Src, message.SrcLength)
|
||||
r.Dst = ipAndLengthToPrefix(message.Attributes.Dst, message.DstLength)
|
||||
r.Flags = message.Flags
|
||||
r.Action = message.Action
|
||||
netlinkRules, err := netlink.RuleList(family)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rules = make([]Rule, len(netlinkRules))
|
||||
for i := range netlinkRules {
|
||||
rules[i] = netlinkRuleToRule(netlinkRules[i])
|
||||
}
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
func (r Rule) message() *rtnetlink.RuleMessage {
|
||||
src, srcLength := prefixToIPAndLength(r.Src)
|
||||
dst, dstLength := prefixToIPAndLength(r.Dst)
|
||||
|
||||
message := &rtnetlink.RuleMessage{
|
||||
Family: r.Family,
|
||||
SrcLength: srcLength,
|
||||
DstLength: dstLength,
|
||||
Flags: r.Flags,
|
||||
Action: r.Action,
|
||||
Attributes: &rtnetlink.RuleAttributes{
|
||||
Priority: r.Priority,
|
||||
FwMark: r.Mark,
|
||||
Src: src,
|
||||
Dst: dst,
|
||||
},
|
||||
}
|
||||
|
||||
if r.Table <= uint32(^uint8(0)) {
|
||||
message.Table = uint8(r.Table)
|
||||
} else {
|
||||
message.Table = rtTableCompat
|
||||
message.Attributes.Table = &r.Table
|
||||
}
|
||||
|
||||
return message
|
||||
func (n *NetLink) RuleAdd(rule Rule) error {
|
||||
n.debugLogger.Debug(ruleDbgMsg(true, rule))
|
||||
netlinkRule := ruleToNetlinkRule(rule)
|
||||
return netlink.RuleAdd(&netlinkRule)
|
||||
}
|
||||
|
||||
func (r Rule) String() string {
|
||||
from := "all"
|
||||
if r.Src.IsValid() && !r.Src.Addr().IsUnspecified() {
|
||||
from = r.Src.String()
|
||||
}
|
||||
|
||||
to := "all"
|
||||
if r.Dst.IsValid() && !r.Dst.Addr().IsUnspecified() {
|
||||
to = r.Dst.String()
|
||||
}
|
||||
|
||||
priority := ""
|
||||
if r.Priority != nil {
|
||||
priority = fmt.Sprintf(" %d", *r.Priority)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("ip rule%s: from %s to %s table %d",
|
||||
priority, from, to, r.Table)
|
||||
func (n *NetLink) RuleDel(rule Rule) error {
|
||||
n.debugLogger.Debug(ruleDbgMsg(false, rule))
|
||||
netlinkRule := ruleToNetlinkRule(rule)
|
||||
return netlink.RuleDel(&netlinkRule)
|
||||
}
|
||||
|
||||
func (r Rule) debugMessage(add bool) (debugMessage string) {
|
||||
func ruleToNetlinkRule(rule Rule) (netlinkRule netlink.Rule) {
|
||||
netlinkRule = *netlink.NewRule()
|
||||
netlinkRule.Priority = rule.Priority
|
||||
netlinkRule.Family = rule.Family
|
||||
netlinkRule.Table = rule.Table
|
||||
netlinkRule.Mark = rule.Mark
|
||||
netlinkRule.Src = netipPrefixToIPNet(rule.Src)
|
||||
netlinkRule.Dst = netipPrefixToIPNet(rule.Dst)
|
||||
netlinkRule.Invert = rule.Invert
|
||||
return netlinkRule
|
||||
}
|
||||
|
||||
func netlinkRuleToRule(netlinkRule netlink.Rule) (rule Rule) {
|
||||
return Rule{
|
||||
Priority: netlinkRule.Priority,
|
||||
Family: netlinkRule.Family,
|
||||
Table: netlinkRule.Table,
|
||||
Mark: netlinkRule.Mark,
|
||||
Src: netIPNetToNetipPrefix(netlinkRule.Src),
|
||||
Dst: netIPNetToNetipPrefix(netlinkRule.Dst),
|
||||
Invert: netlinkRule.Invert,
|
||||
}
|
||||
}
|
||||
|
||||
func ruleDbgMsg(add bool, rule Rule) (debugMessage string) {
|
||||
debugMessage = "ip"
|
||||
|
||||
switch r.Family {
|
||||
switch rule.Family {
|
||||
case FamilyV4:
|
||||
debugMessage += " -f inet"
|
||||
case FamilyV6:
|
||||
debugMessage += " -f inet6"
|
||||
default:
|
||||
debugMessage += " -f " + fmt.Sprint(r.Family)
|
||||
debugMessage += " -f " + fmt.Sprint(rule.Family)
|
||||
}
|
||||
|
||||
debugMessage += " rule"
|
||||
@@ -101,20 +96,20 @@ func (r Rule) debugMessage(add bool) (debugMessage string) {
|
||||
debugMessage += " del"
|
||||
}
|
||||
|
||||
if r.Src.IsValid() {
|
||||
debugMessage += " from " + r.Src.String()
|
||||
if rule.Src.IsValid() {
|
||||
debugMessage += " from " + rule.Src.String()
|
||||
}
|
||||
|
||||
if r.Dst.IsValid() {
|
||||
debugMessage += " to " + r.Dst.String()
|
||||
if rule.Dst.IsValid() {
|
||||
debugMessage += " to " + rule.Dst.String()
|
||||
}
|
||||
|
||||
if r.Table != 0 {
|
||||
debugMessage += " lookup " + fmt.Sprint(r.Table)
|
||||
if rule.Table != 0 {
|
||||
debugMessage += " lookup " + fmt.Sprint(rule.Table)
|
||||
}
|
||||
|
||||
if r.Priority != nil {
|
||||
debugMessage += " pref " + fmt.Sprint(*r.Priority)
|
||||
if rule.Priority != -1 {
|
||||
debugMessage += " pref " + fmt.Sprint(rule.Priority)
|
||||
}
|
||||
|
||||
return debugMessage
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/jsimonetti/rtnetlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
FlagInvert = unix.FIB_RULE_INVERT
|
||||
ActionToTable = unix.FR_ACT_TO_TBL
|
||||
)
|
||||
|
||||
func (n *NetLink) RuleList(family uint8) (rules []Rule, err error) {
|
||||
switch family {
|
||||
case FamilyAll:
|
||||
n.debugLogger.Debug("ip -4 rule list")
|
||||
n.debugLogger.Debug("ip -6 rule list")
|
||||
case FamilyV4:
|
||||
n.debugLogger.Debug("ip -4 rule list")
|
||||
case FamilyV6:
|
||||
n.debugLogger.Debug("ip -6 rule list")
|
||||
}
|
||||
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ruleMessages, err := conn.Rule.List()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rules = make([]Rule, 0, len(ruleMessages))
|
||||
for _, message := range ruleMessages {
|
||||
if family != FamilyAll && family != message.Family {
|
||||
continue
|
||||
}
|
||||
var rule Rule
|
||||
rule.fromMessage(message)
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
func (n *NetLink) RuleAdd(rule Rule) error {
|
||||
n.debugLogger.Debug(rule.debugMessage(true))
|
||||
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
return conn.Rule.Add(rule.message())
|
||||
}
|
||||
|
||||
func (n *NetLink) RuleDel(rule Rule) error {
|
||||
n.debugLogger.Debug(rule.debugMessage(false))
|
||||
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
return conn.Rule.Delete(rule.message())
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Rule_debugMessage(t *testing.T) {
|
||||
func Test_ruleDbgMsg(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
@@ -15,7 +15,7 @@ func Test_Rule_debugMessage(t *testing.T) {
|
||||
dbgMsg string
|
||||
}{
|
||||
"default values": {
|
||||
dbgMsg: "ip -f 0 rule del",
|
||||
dbgMsg: "ip -f 0 rule del pref 0",
|
||||
},
|
||||
"add rule": {
|
||||
add: true,
|
||||
@@ -24,7 +24,7 @@ func Test_Rule_debugMessage(t *testing.T) {
|
||||
Src: makeNetipPrefix(1),
|
||||
Dst: makeNetipPrefix(2),
|
||||
Table: 100,
|
||||
Priority: ptrTo(uint32(101)),
|
||||
Priority: 101,
|
||||
},
|
||||
dbgMsg: "ip -f inet rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
|
||||
},
|
||||
@@ -34,7 +34,7 @@ func Test_Rule_debugMessage(t *testing.T) {
|
||||
Src: makeNetipPrefix(1),
|
||||
Dst: makeNetipPrefix(2),
|
||||
Table: 100,
|
||||
Priority: ptrTo(uint32(101)),
|
||||
Priority: 101,
|
||||
},
|
||||
dbgMsg: "ip -f inet rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
|
||||
},
|
||||
@@ -44,7 +44,7 @@ func Test_Rule_debugMessage(t *testing.T) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbgMsg := testCase.rule.debugMessage(testCase.add)
|
||||
dbgMsg := ruleDbgMsg(testCase.add, testCase.rule)
|
||||
|
||||
assert.Equal(t, testCase.dbgMsg, dbgMsg)
|
||||
})
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
//go:build !linux
|
||||
|
||||
package netlink
|
||||
|
||||
func NewRule() Rule {
|
||||
return Rule{}
|
||||
}
|
||||
|
||||
func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) RuleAdd(rule Rule) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) RuleDel(rule Rule) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type Addr struct {
|
||||
Network netip.Prefix
|
||||
}
|
||||
|
||||
func (a Addr) String() string {
|
||||
return a.Network.String()
|
||||
}
|
||||
|
||||
type Link struct {
|
||||
Type string
|
||||
Name string
|
||||
Index int
|
||||
EncapType string
|
||||
MTU uint16
|
||||
}
|
||||
|
||||
type Route struct {
|
||||
LinkIndex int
|
||||
Dst netip.Prefix
|
||||
Src netip.Addr
|
||||
Gw netip.Addr
|
||||
Priority int
|
||||
Family int
|
||||
Table int
|
||||
Type int
|
||||
}
|
||||
|
||||
type Rule struct {
|
||||
Priority int
|
||||
Family int
|
||||
Table int
|
||||
Mark uint32
|
||||
Src netip.Prefix
|
||||
Dst netip.Prefix
|
||||
Invert bool
|
||||
}
|
||||
|
||||
func (r Rule) String() string {
|
||||
from := "all"
|
||||
if r.Src.IsValid() && !r.Src.Addr().IsUnspecified() {
|
||||
from = r.Src.String()
|
||||
}
|
||||
|
||||
to := "all"
|
||||
if r.Dst.IsValid() && !r.Dst.Addr().IsUnspecified() {
|
||||
to = r.Dst.String()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("ip rule %d: from %s to %s table %d",
|
||||
r.Priority, from, to, r.Table)
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
//go:build linux
|
||||
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/mod"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func (n *NetLink) IsWireguardSupported() bool {
|
||||
// Check for Wireguard family without loading the wireguard module.
|
||||
// Some kernels have the wireguard module built-in, and don't have a
|
||||
// modules directory, such as WSL2 kernels.
|
||||
ok := hasWireguardFamily()
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
|
||||
// Try loading the wireguard module, since some systems do not load
|
||||
// it after a boot. If this fails, wireguard is assumed to not be supported.
|
||||
n.debugLogger.Debugf("wireguard family not found, trying to load wireguard kernel module")
|
||||
err := mod.Probe("wireguard")
|
||||
if err != nil {
|
||||
n.debugLogger.Debugf("failed loading wireguard kernel module: %s", err)
|
||||
return false
|
||||
}
|
||||
n.debugLogger.Debugf("wireguard kernel module loaded successfully")
|
||||
|
||||
// Re-check if the Wireguard family is now available, after loading
|
||||
// the wireguard kernel module.
|
||||
return hasWireguardFamily()
|
||||
}
|
||||
|
||||
func hasWireguardFamily() bool {
|
||||
_, err := netlink.GenlFamilyGet("wireguard")
|
||||
return err == nil
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/mdlayher/genetlink"
|
||||
"github.com/qdm12/gluetun/internal/mod"
|
||||
)
|
||||
|
||||
func (n *NetLink) IsWireguardSupported() (ok bool, err error) {
|
||||
// Check for Wireguard family without loading the wireguard module.
|
||||
// Some kernels have the wireguard module built-in, and don't have a
|
||||
// modules directory, such as WSL2 kernels.
|
||||
ok, err = hasWireguardFamily()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("checking wireguard family: %w", err)
|
||||
} else if ok {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Try loading the wireguard module, since some systems do not load
|
||||
// it after a boot. If this fails, wireguard is assumed to not be supported.
|
||||
n.debugLogger.Debugf("wireguard family not found, trying to load wireguard kernel module")
|
||||
err = mod.Probe("wireguard")
|
||||
if err != nil {
|
||||
n.debugLogger.Debugf("failed loading wireguard kernel module: %s", err)
|
||||
return false, nil
|
||||
}
|
||||
n.debugLogger.Debugf("wireguard kernel module loaded successfully")
|
||||
|
||||
// Re-check if the Wireguard family is now available, after loading
|
||||
// the wireguard kernel module.
|
||||
ok, err = hasWireguardFamily()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("checking wireguard family: %w", err)
|
||||
}
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func hasWireguardFamily() (ok bool, err error) {
|
||||
conn, err := genetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.GetFamily("wireguard")
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("getting wireguard family: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
@@ -4,8 +4,6 @@ package netlink
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_NetLink_IsWireguardSupported(t *testing.T) {
|
||||
@@ -14,8 +12,7 @@ func Test_NetLink_IsWireguardSupported(t *testing.T) {
|
||||
netLink := &NetLink{
|
||||
debugLogger: &noopLogger{},
|
||||
}
|
||||
ok, err := netLink.IsWireguardSupported()
|
||||
require.NoError(t, err)
|
||||
ok := netLink.IsWireguardSupported()
|
||||
if ok { // cannot assert since this depends on kernel
|
||||
t.Log("wireguard is supported")
|
||||
} else {
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
//go:build !linux
|
||||
|
||||
package netlink
|
||||
|
||||
func (n *NetLink) IsWireguardSupported() (ok bool, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/openvpn"
|
||||
)
|
||||
@@ -32,7 +33,7 @@ func start(ctx context.Context, starter CmdStarter, version string, flags []stri
|
||||
args := []string{"--config", configPath}
|
||||
args = append(args, flags...)
|
||||
cmd := exec.CommandContext(ctx, bin, args...)
|
||||
setCmdSysProcAttr(cmd)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
|
||||
return starter.Start(cmd)
|
||||
}
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
package openvpn
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func setCmdSysProcAttr(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
package openvpn
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func setCmdSysProcAttr(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{}
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
package constants
|
||||
|
||||
const (
|
||||
MaxEthernetFrameSize uint32 = 1500
|
||||
// MinIPv4MTU is defined according to
|
||||
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
|
||||
MinIPv4MTU uint32 = 68
|
||||
MinIPv6MTU uint32 = 1280
|
||||
|
||||
IPv4HeaderLength uint32 = 20
|
||||
IPv6HeaderLength uint32 = 40
|
||||
UDPHeaderLength uint32 = 8
|
||||
// BaseTCPHeaderLength is the TCP header length without options,
|
||||
// which is the minimum TCP header length.
|
||||
BaseTCPHeaderLength uint32 = 20
|
||||
// MaxTCPHeaderLength is the TCP header length with the maximum options length of 40 bytes.
|
||||
// Note this is a hard maximum because of the 4-bit data offset field in the TCP header (15x4=60).
|
||||
MaxTCPHeaderLength uint32 = 60
|
||||
WireguardHeaderLength uint32 = 32
|
||||
OpenVPNHeaderMaxLength uint32 = 1 + // opcode
|
||||
8 + // session id
|
||||
4 + // packet id
|
||||
28 // max possible auth tag/iv
|
||||
)
|
||||
@@ -1,16 +0,0 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package constants
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
//nolint:revive
|
||||
const (
|
||||
SOCK_RAW = unix.SOCK_RAW
|
||||
SOCK_STREAM = unix.SOCK_STREAM
|
||||
AF_INET = unix.AF_INET
|
||||
AF_INET6 = unix.AF_INET6
|
||||
IPPROTO_TCP = unix.IPPROTO_TCP
|
||||
EAGAIN = unix.EAGAIN
|
||||
EWOULDBLOCK = unix.EWOULDBLOCK
|
||||
)
|
||||
@@ -1,13 +0,0 @@
|
||||
package constants
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
SOCK_RAW = windows.SOCK_RAW
|
||||
SOCK_STREAM = windows.SOCK_STREAM
|
||||
AF_INET = windows.AF_INET
|
||||
AF_INET6 = windows.AF_INET6
|
||||
IPPROTO_TCP = windows.IPPROTO_TCP
|
||||
EAGAIN = windows.WSAEWOULDBLOCK
|
||||
EWOULDBLOCK = windows.WSAEWOULDBLOCK
|
||||
)
|
||||
@@ -1,49 +0,0 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
var _ net.PacketConn = &ipv4Wrapper{}
|
||||
|
||||
// ipv4Wrapper is a wrapper around ipv4.PacketConn to implement
|
||||
// the net.PacketConn interface. It's only used for Darwin or iOS.
|
||||
type ipv4Wrapper struct {
|
||||
ipv4Conn *ipv4.PacketConn
|
||||
}
|
||||
|
||||
func ipv4ToNetPacketConn(ipv4 *ipv4.PacketConn) *ipv4Wrapper {
|
||||
return &ipv4Wrapper{ipv4Conn: ipv4}
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, _, addr, err = i.ipv4Conn.ReadFrom(p)
|
||||
return n, addr, err
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
return i.ipv4Conn.WriteTo(p, nil, addr)
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) Close() error {
|
||||
return i.ipv4Conn.Close()
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) LocalAddr() net.Addr {
|
||||
return i.ipv4Conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) SetDeadline(t time.Time) error {
|
||||
return i.ipv4Conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) SetReadDeadline(t time.Time) error {
|
||||
return i.ipv4Conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) SetWriteDeadline(t time.Time) error {
|
||||
return i.ipv4Conn.SetWriteDeadline(t)
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
|
||||
ErrNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
|
||||
)
|
||||
|
||||
func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) {
|
||||
switch {
|
||||
case mtu < minMTU:
|
||||
return fmt.Errorf("%w: %d", ErrNextHopMTUTooLow, mtu)
|
||||
case mtu > physicalLinkMTU:
|
||||
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
|
||||
ErrNextHopMTUTooHigh, mtu, physicalLinkMTU)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
|
||||
outboundMessage *icmp.Message,
|
||||
) (match bool, err error) {
|
||||
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parsing invoking packet: %w", err)
|
||||
}
|
||||
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
|
||||
}
|
||||
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
||||
return inboundBody.ID == outboundBody.ID, nil
|
||||
}
|
||||
|
||||
var ErrIDMismatch = errors.New("ICMP id mismatch")
|
||||
|
||||
func checkEchoReply(icmpProtocol int, received []byte,
|
||||
outboundMessage *icmp.Message, truncatedBody bool,
|
||||
) (err error) {
|
||||
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing invoking packet: %w", err)
|
||||
}
|
||||
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
|
||||
}
|
||||
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
||||
if inboundBody.ID != outboundBody.ID {
|
||||
return fmt.Errorf("%w: sent id %d and received id %d",
|
||||
ErrIDMismatch, outboundBody.ID, inboundBody.ID)
|
||||
}
|
||||
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking sent and received bodies: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrEchoDataMismatch = errors.New("ICMP data mismatch")
|
||||
|
||||
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
|
||||
if len(received) > len(sent) {
|
||||
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
|
||||
ErrEchoDataMismatch, len(sent), len(received))
|
||||
}
|
||||
if receivedTruncated {
|
||||
sent = sent[:len(received)]
|
||||
}
|
||||
if !bytes.Equal(received, sent) {
|
||||
return fmt.Errorf("%w: sent %x and received %x",
|
||||
ErrEchoDataMismatch, sent, received)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func setDontFragment(fd uintptr, ipv4 bool) (err error) {
|
||||
if ipv4 {
|
||||
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IP,
|
||||
unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE)
|
||||
}
|
||||
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6,
|
||||
unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_PROBE)
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
//go:build !linux && !windows
|
||||
|
||||
package icmp
|
||||
|
||||
// setDontFragment for platforms other than Linux and Windows
|
||||
// is not implemented, so we just return assuming the don't
|
||||
// fragment flag is set on IP packets.
|
||||
func setDontFragment(fd uintptr, ipv4 bool) (err error) {
|
||||
return nil
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func setDontFragment(fd uintptr, ipv4 bool) (err error) {
|
||||
if ipv4 {
|
||||
// https://docs.microsoft.com/en-us/troubleshoot/windows/win32/header-library-requirement-socket-ipproto-ip
|
||||
// #define IP_DONTFRAGMENT 14 /* don't fragment IP datagrams */
|
||||
return windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, 14, 1)
|
||||
}
|
||||
return windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, 14, 1)
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotPermitted = errors.New("ICMP not permitted")
|
||||
ErrDestinationUnreachable = errors.New("ICMP destination unreachable")
|
||||
ErrCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
|
||||
ErrBodyUnsupported = errors.New("ICMP body type is not supported")
|
||||
ErrMTUNotFound = errors.New("MTU not found")
|
||||
errTimeout = errors.New("operation timed out")
|
||||
)
|
||||
|
||||
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
|
||||
switch {
|
||||
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
|
||||
err = fmt.Errorf("%w", ErrNotPermitted)
|
||||
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
|
||||
err = fmt.Errorf("%w: after %s", errTimeout, pingTimeout)
|
||||
case timedCtx.Err() != nil:
|
||||
err = timedCtx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
)
|
||||
|
||||
// PathMTUDiscover discovers the path MTU to the given IP address
|
||||
// using ICMP.
|
||||
// It first tries to get the next hop MTU using ICMP messages.
|
||||
// If that fails, it falls back to sending echo requests with
|
||||
// different packet sizes to find the maximum MTU.
|
||||
// The function returns [ErrMTUNotFound] if the MTU could not be determined.
|
||||
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
|
||||
physicalLinkMTU uint32, timeout time.Duration, logger Logger,
|
||||
) (mtu uint32, err error) {
|
||||
if ip.Is4() {
|
||||
logger.Debugf("finding IPv4 next hop MTU to %s", ip)
|
||||
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, timeout, logger)
|
||||
switch {
|
||||
case err == nil:
|
||||
return mtu, nil
|
||||
case errors.Is(err, errTimeout) || errors.Is(err, ErrCommunicationAdministrativelyProhibited): // blackhole
|
||||
default:
|
||||
return 0, fmt.Errorf("finding IPv4 next hop MTU to %s: %w", ip, err)
|
||||
}
|
||||
} else {
|
||||
logger.Debugf("requesting IPv6 ICMP packet-too-big reply from %s", ip)
|
||||
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, timeout, logger)
|
||||
switch {
|
||||
case err == nil:
|
||||
return mtu, nil
|
||||
case errors.Is(err, errTimeout): // blackhole
|
||||
default:
|
||||
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back method: send echo requests with different packet
|
||||
// sizes and check which ones succeed to find the maximum MTU.
|
||||
logger.Debugf("falling back to sending different sized echo packets to %s", ip)
|
||||
minMTU := constants.MinIPv4MTU
|
||||
if ip.Is6() {
|
||||
minMTU = constants.MinIPv6MTU
|
||||
}
|
||||
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, timeout, logger)
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package icmp
|
||||
|
||||
type Logger interface {
|
||||
Debug(msg string)
|
||||
Debugf(msg string, args ...any)
|
||||
Warnf(msg string, args ...any)
|
||||
}
|
||||
@@ -1,164 +0,0 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
const (
|
||||
icmpv4Protocol = 1
|
||||
)
|
||||
|
||||
func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
|
||||
var listenConfig net.ListenConfig
|
||||
listenConfig.Control = func(_, _ string, rawConn syscall.RawConn) error {
|
||||
var setDFErr error
|
||||
err := rawConn.Control(func(fd uintptr) {
|
||||
const ipv4 = true
|
||||
setDFErr = setDontFragment(fd, ipv4) // runs when calling ListenPacket
|
||||
})
|
||||
if err == nil {
|
||||
err = setDFErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
const listenAddress = ""
|
||||
packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress)
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
|
||||
packetConn = ipv4ToNetPacketConn(ipv4.NewPacketConn(packetConn))
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
||||
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger,
|
||||
) (mtu uint32, err error) {
|
||||
if ip.Is6() {
|
||||
panic("IP address is not v4")
|
||||
}
|
||||
conn, err := listenICMPv4(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, pingTimeout)
|
||||
defer cancel()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
// First try to send a packet which is too big to get the maximum MTU
|
||||
// directly.
|
||||
outboundID, outboundMessage := buildMessageToSend("v4", physicalLinkMTU)
|
||||
encodedMessage, err := outboundMessage.Marshal(nil)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("encoding ICMP message: %w", err)
|
||||
}
|
||||
|
||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
||||
if err != nil {
|
||||
err = wrapConnErr(err, ctx, pingTimeout)
|
||||
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||
}
|
||||
|
||||
buffer := make([]byte, physicalLinkMTU)
|
||||
|
||||
// for loop in case we read an ICMP message from another ICMP request
|
||||
// or TCP/UDP traffic triggering an ICMP response.
|
||||
for {
|
||||
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||
// must be large enough to read the entire reply packet. See:
|
||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
err = wrapConnErr(err, ctx, pingTimeout)
|
||||
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
|
||||
}
|
||||
packetBytes := buffer[:bytesRead]
|
||||
// Side note: echo reply should be at most the number of bytes
|
||||
// sent, and can be lower, more precisely 576-ipHeader bytes,
|
||||
// in case the next hop we are reaching replies with a destination
|
||||
// unreachable and wants to ensure the response makes it way back
|
||||
// by keeping a low packet size, see:
|
||||
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
|
||||
|
||||
inboundMessage, err := icmp.ParseMessage(icmpv4Protocol, packetBytes)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing message: %w", err)
|
||||
}
|
||||
|
||||
switch typedBody := inboundMessage.Body.(type) {
|
||||
case *icmp.DstUnreach:
|
||||
const fragmentationRequiredAndDFFlagSetCode = 4
|
||||
const portUnreachable = 3
|
||||
const communicationAdministrativelyProhibitedCode = 13
|
||||
switch inboundMessage.Code {
|
||||
case fragmentationRequiredAndDFFlagSetCode:
|
||||
case portUnreachable: // triggered by TCP or UDP from applications
|
||||
continue // ignore and wait for the next message
|
||||
case communicationAdministrativelyProhibitedCode:
|
||||
return 0, fmt.Errorf("%w: %w (code %d)",
|
||||
ErrDestinationUnreachable,
|
||||
ErrCommunicationAdministrativelyProhibited,
|
||||
inboundMessage.Code)
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: code %d",
|
||||
ErrDestinationUnreachable, inboundMessage.Code)
|
||||
}
|
||||
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
|
||||
// Note: the go library does not handle this NextHopMTU section.
|
||||
nextHopMTU := packetBytes[6:8]
|
||||
mtu = uint32(binary.BigEndian.Uint16(nextHopMTU))
|
||||
err = checkMTU(mtu, constants.MinIPv4MTU, physicalLinkMTU)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking next-hop-mtu found: %w", err)
|
||||
}
|
||||
|
||||
// The code below is really for sanity checks
|
||||
packetBytes = packetBytes[8:]
|
||||
header, err := ipv4.ParseHeader(packetBytes)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing IPv4 header: %w", err)
|
||||
}
|
||||
packetBytes = packetBytes[header.Len:] // truncated original datagram
|
||||
|
||||
const truncated = true
|
||||
err = checkEchoReply(icmpv4Protocol, packetBytes, outboundMessage, truncated)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking echo reply: %w", err)
|
||||
}
|
||||
return mtu, nil
|
||||
case *icmp.Echo:
|
||||
inboundID := uint16(typedBody.ID) //nolint:gosec
|
||||
if inboundID == outboundID {
|
||||
return physicalLinkMTU, nil
|
||||
}
|
||||
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
|
||||
inboundID, outboundID)
|
||||
continue
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,134 +0,0 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
const (
|
||||
icmpv6Protocol = 58
|
||||
)
|
||||
|
||||
func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) {
|
||||
var listenConfig net.ListenConfig
|
||||
listenConfig.Control = func(_, _ string, rawConn syscall.RawConn) error {
|
||||
var setDFErr error
|
||||
err := rawConn.Control(func(fd uintptr) {
|
||||
const ipv4 = false
|
||||
setDFErr = setDontFragment(fd, ipv4) // runs when calling ListenPacket
|
||||
})
|
||||
if err == nil {
|
||||
err = setDFErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
const listenAddress = ""
|
||||
packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress)
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
||||
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger,
|
||||
) (mtu uint32, err error) {
|
||||
if ip.Is4() {
|
||||
panic("IP address is not v6")
|
||||
}
|
||||
conn, err := listenICMPv6(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, pingTimeout)
|
||||
defer cancel()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
// First try to send a packet which is too big to get the maximum MTU
|
||||
// directly.
|
||||
outboundID, outboundMessage := buildMessageToSend("v6", physicalLinkMTU)
|
||||
encodedMessage, err := outboundMessage.Marshal(nil)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("encoding ICMP message: %w", err)
|
||||
}
|
||||
|
||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice(), Zone: ip.Zone()})
|
||||
if err != nil {
|
||||
err = wrapConnErr(err, ctx, pingTimeout)
|
||||
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||
}
|
||||
|
||||
buffer := make([]byte, physicalLinkMTU)
|
||||
|
||||
for { // for loop if we encounter another ICMP packet with an unknown id.
|
||||
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||
// must be large enough to read the entire reply packet. See:
|
||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
err = wrapConnErr(err, ctx, pingTimeout)
|
||||
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
|
||||
}
|
||||
packetBytes := buffer[:bytesRead]
|
||||
|
||||
packetBytes = packetBytes[ipv6.HeaderLen:]
|
||||
|
||||
inboundMessage, err := icmp.ParseMessage(icmpv6Protocol, packetBytes)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing message: %w", err)
|
||||
}
|
||||
|
||||
switch typedBody := inboundMessage.Body.(type) {
|
||||
case *icmp.PacketTooBig:
|
||||
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.2
|
||||
mtu = uint32(typedBody.MTU) //nolint:gosec
|
||||
err = checkMTU(mtu, constants.MinIPv6MTU, physicalLinkMTU)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking MTU: %w", err)
|
||||
}
|
||||
|
||||
// Sanity checks
|
||||
const truncatedBody = true
|
||||
err = checkEchoReply(icmpv6Protocol, typedBody.Data, outboundMessage, truncatedBody)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking invoking message: %w", err)
|
||||
}
|
||||
return uint32(typedBody.MTU), nil //nolint:gosec
|
||||
case *icmp.DstUnreach:
|
||||
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.1
|
||||
idMatch, err := checkInvokingReplyIDMatch(icmpv6Protocol, packetBytes, outboundMessage)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking invoking message id: %w", err)
|
||||
} else if idMatch {
|
||||
return 0, fmt.Errorf("%w", ErrDestinationUnreachable)
|
||||
}
|
||||
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
|
||||
continue
|
||||
case *icmp.Echo:
|
||||
inboundID := uint16(typedBody.ID) //nolint:gosec
|
||||
if inboundID == outboundID {
|
||||
return physicalLinkMTU, nil
|
||||
}
|
||||
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
|
||||
inboundID, outboundID)
|
||||
continue
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
|
||||
}
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user