mirror of
https://github.com/qdm12/gluetun.git
synced 2026-06-29 07:17:34 +02:00
Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| db947c17a8 | |||
| b0a75673bd | |||
| 5f0c499808 | |||
| bdd69a1fb7 | |||
| 1af75bb30c | |||
| 9c1cd7e8b1 | |||
| facc6df3be | |||
| e292a4c9be | |||
| 9e4dd61c19 | |||
| fe3d4a94d4 | |||
| de38d759a4 | |||
| fba60af772 | |||
| 9b9b723887 |
@@ -59,10 +59,13 @@ jobs:
|
||||
- name: Run tests in test container
|
||||
run: |
|
||||
touch coverage.txt
|
||||
docker run --rm --device /dev/net/tun \
|
||||
docker run --rm --cap-add=NET_ADMIN --device /dev/net/tun \
|
||||
-v "$(pwd)/coverage.txt:/tmp/gobuild/coverage.txt" \
|
||||
test-container
|
||||
|
||||
- name: Verify dev cross platform compatibility
|
||||
run: docker build --target xcompile .
|
||||
|
||||
- name: Build final image
|
||||
run: docker build -t final-image .
|
||||
|
||||
|
||||
@@ -46,6 +46,10 @@ RUN git init && \
|
||||
git diff --exit-code && \
|
||||
rm -rf .git/
|
||||
|
||||
FROM --platform=${BUILDPLATFORM} base AS xcompile
|
||||
RUN GOOS=darwin go build -o /dev/null ./...
|
||||
RUN GOOS=windows go build -o /dev/null ./...
|
||||
|
||||
FROM --platform=${BUILDPLATFORM} base AS build
|
||||
ARG TARGETPLATFORM
|
||||
ARG VERSION=unknown
|
||||
|
||||
+14
-12
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
@@ -393,7 +394,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
}
|
||||
|
||||
dnsLogger := logger.New(log.SetComponent("dns"))
|
||||
dnsLooper, err := dns.NewLoop(allSettings.DNS, httpClient,
|
||||
dnsLooper, err := dns.NewLoop(allSettings.DNS, httpClient, firewallConf,
|
||||
dnsLogger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating DNS loop: %w", err)
|
||||
@@ -553,26 +554,26 @@ type netLinker interface {
|
||||
Router
|
||||
Ruler
|
||||
Linker
|
||||
IsWireguardSupported() bool
|
||||
IsWireguardSupported() (ok bool, err error)
|
||||
IsIPv6Supported() (ok bool, err error)
|
||||
PatchLoggerLevel(level log.Level)
|
||||
}
|
||||
|
||||
type Addresser interface {
|
||||
AddrList(link netlink.Link, family int) (
|
||||
addresses []netlink.Addr, err error)
|
||||
AddrReplace(link netlink.Link, addr netlink.Addr) error
|
||||
AddrList(linkIndex uint32, family uint8) (
|
||||
addresses []netip.Prefix, err error)
|
||||
AddrReplace(linkIndex uint32, addr netip.Prefix) error
|
||||
}
|
||||
|
||||
type Router interface {
|
||||
RouteList(family int) (routes []netlink.Route, err error)
|
||||
RouteList(family uint8) (routes []netlink.Route, err error)
|
||||
RouteAdd(route netlink.Route) error
|
||||
RouteDel(route netlink.Route) error
|
||||
RouteReplace(route netlink.Route) error
|
||||
}
|
||||
|
||||
type Ruler interface {
|
||||
RuleList(family int) (rules []netlink.Rule, err error)
|
||||
RuleList(family uint8) (rules []netlink.Rule, err error)
|
||||
RuleAdd(rule netlink.Rule) error
|
||||
RuleDel(rule netlink.Rule) error
|
||||
}
|
||||
@@ -580,11 +581,12 @@ type Ruler interface {
|
||||
type Linker interface {
|
||||
LinkList() (links []netlink.Link, err error)
|
||||
LinkByName(name string) (link netlink.Link, err error)
|
||||
LinkByIndex(index int) (link netlink.Link, err error)
|
||||
LinkAdd(link netlink.Link) (linkIndex int, err error)
|
||||
LinkDel(link netlink.Link) (err error)
|
||||
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
||||
LinkSetDown(link netlink.Link) (err error)
|
||||
LinkByIndex(index uint32) (link netlink.Link, err error)
|
||||
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
|
||||
LinkDel(linkIndex uint32) (err error)
|
||||
LinkSetUp(linkIndex uint32) (err error)
|
||||
LinkSetDown(linkIndex uint32) (err error)
|
||||
LinkSetMTU(linkIndex, mtu uint32) error
|
||||
}
|
||||
|
||||
type clier interface {
|
||||
|
||||
@@ -7,8 +7,10 @@ require (
|
||||
github.com/breml/rootcerts v0.3.3
|
||||
github.com/fatih/color v1.18.0
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/jsimonetti/rtnetlink v1.4.2
|
||||
github.com/klauspost/compress v1.18.1
|
||||
github.com/klauspost/pgzip v1.2.6
|
||||
github.com/mdlayher/genetlink v1.3.2
|
||||
github.com/pelletier/go-toml/v2 v2.2.4
|
||||
github.com/qdm12/dns/v2 v2.0.0-rc10
|
||||
github.com/qdm12/gosettings v0.4.4
|
||||
@@ -19,12 +21,11 @@ require (
|
||||
github.com/qdm12/ss-server v0.6.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/ulikunitz/xz v0.5.15
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a
|
||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/sys v0.38.0
|
||||
golang.org/x/text v0.31.0
|
||||
golang.org/x/net v0.49.0
|
||||
golang.org/x/sys v0.40.0
|
||||
golang.org/x/text v0.33.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||
gopkg.in/ini.v1 v1.67.0
|
||||
@@ -38,13 +39,12 @@ require (
|
||||
github.com/cloudflare/circl v1.6.1 // indirect
|
||||
github.com/cronokirby/saferith v0.33.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/josharian/native v1.1.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mdlayher/genetlink v1.3.2 // indirect
|
||||
github.com/mdlayher/netlink v1.7.2 // indirect
|
||||
github.com/mdlayher/socket v0.4.1 // indirect
|
||||
github.com/mdlayher/socket v0.5.1 // indirect
|
||||
github.com/miekg/dns v1.1.62 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
@@ -55,12 +55,11 @@ require (
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 // indirect
|
||||
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
golang.org/x/mod v0.29.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/crypto v0.47.0 // indirect
|
||||
golang.org/x/mod v0.31.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/time v0.3.0 // indirect
|
||||
golang.org/x/tools v0.38.0 // indirect
|
||||
golang.org/x/tools v0.40.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/protobuf v1.35.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
||||
@@ -13,6 +13,8 @@ github.com/breml/rootcerts v0.3.3/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXB
|
||||
github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cilium/ebpf v0.12.3 h1:8ht6F9MquybnY97at+VDZb3eQQr8ev79RueWeVaEcG4=
|
||||
github.com/cilium/ebpf v0.12.3/go.mod h1:TctK1ivibvI3znr66ljgi4hqOT8EYQjz1KWBfb1UVgM=
|
||||
github.com/cloudflare/circl v1.1.0/go.mod h1:prBCrKB9DV4poKZY1l9zBXg2QJY7mvgRvtMxxK7fi4I=
|
||||
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
|
||||
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
|
||||
@@ -26,10 +28,12 @@ github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
||||
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
|
||||
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
|
||||
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
|
||||
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
github.com/jsimonetti/rtnetlink v1.4.2 h1:Df9w9TZ3npHTyDn0Ev9e1uzmN2odmXd0QX+J5GTEn90=
|
||||
github.com/jsimonetti/rtnetlink v1.4.2/go.mod h1:92s6LJdE+1iOrw+F2/RO7LYI2Qd8pPpFNNUYW06gcoM=
|
||||
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
|
||||
github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0=
|
||||
github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU=
|
||||
@@ -47,8 +51,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5
|
||||
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
|
||||
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
||||
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
|
||||
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
|
||||
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
||||
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
||||
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
|
||||
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
|
||||
@@ -93,10 +97,6 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
|
||||
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
|
||||
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
|
||||
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
||||
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
||||
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk=
|
||||
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
@@ -106,15 +106,15 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
|
||||
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
|
||||
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
|
||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
|
||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
||||
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
||||
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
||||
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
@@ -122,14 +122,14 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
|
||||
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -140,12 +140,10 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
@@ -155,8 +153,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
||||
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
||||
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
|
||||
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -164,8 +162,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
|
||||
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
||||
@@ -48,6 +48,10 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
|
||||
return fmt.Errorf("%w for Wireguard: %w", ErrVPNProviderNameNotValid, err)
|
||||
}
|
||||
|
||||
if p.Name == providers.Mullvad && vpnType == vpn.OpenVPN {
|
||||
warner.Warn("https://mullvad.net/en/blog/removing-openvpn-15th-january-2026")
|
||||
}
|
||||
|
||||
err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("server selection: %w", err)
|
||||
|
||||
@@ -45,7 +45,8 @@ type Wireguard struct {
|
||||
// It has been lowered to 1320 following quite a bit of
|
||||
// investigation in the issue:
|
||||
// https://github.com/qdm12/gluetun/issues/2533.
|
||||
MTU uint16 `json:"mtu"`
|
||||
// Note this should now be replaced with the PMTUD feature.
|
||||
MTU uint32 `json:"mtu"`
|
||||
// Implementation is the Wireguard implementation to use.
|
||||
// It can be "auto", "userspace" or "kernelspace".
|
||||
// It defaults to "auto" and cannot be the empty string
|
||||
@@ -272,7 +273,7 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
mtuPtr, err := r.Uint16Ptr("WIREGUARD_MTU")
|
||||
mtuPtr, err := r.Uint32Ptr("WIREGUARD_MTU")
|
||||
if err != nil {
|
||||
return err
|
||||
} else if mtuPtr != nil {
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Info(s string)
|
||||
Warn(s string)
|
||||
Error(s string)
|
||||
}
|
||||
|
||||
type Firewall interface {
|
||||
RestrictOutputAddrPort(ctx context.Context, addrPort netip.AddrPort) (err error)
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package dns
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Info(s string)
|
||||
Warn(s string)
|
||||
Error(s string)
|
||||
}
|
||||
@@ -24,6 +24,7 @@ type Loop struct {
|
||||
localResolvers []netip.Addr
|
||||
resolvConf string
|
||||
client *http.Client
|
||||
firewall Firewall
|
||||
logger Logger
|
||||
userTrigger bool
|
||||
start <-chan struct{}
|
||||
@@ -39,7 +40,7 @@ type Loop struct {
|
||||
const defaultBackoffTime = 10 * time.Second
|
||||
|
||||
func NewLoop(settings settings.DNS,
|
||||
client *http.Client, logger Logger,
|
||||
client *http.Client, firewall Firewall, logger Logger,
|
||||
) (loop *Loop, err error) {
|
||||
start := make(chan struct{})
|
||||
running := make(chan models.LoopStatus)
|
||||
@@ -64,6 +65,7 @@ func NewLoop(settings settings.DNS,
|
||||
filter: filter,
|
||||
resolvConf: "/etc/resolv.conf",
|
||||
client: client,
|
||||
firewall: firewall,
|
||||
logger: logger,
|
||||
userTrigger: true,
|
||||
start: start,
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/dns/v2/pkg/nameserver"
|
||||
)
|
||||
|
||||
func (l *Loop) useUnencryptedDNS(fallback bool) {
|
||||
func (l *Loop) useUnencryptedDNS(ctx context.Context, fallback bool) {
|
||||
settings := l.GetSettings()
|
||||
|
||||
targetIP := settings.GetFirstPlaintextIPv4()
|
||||
@@ -20,8 +21,9 @@ func (l *Loop) useUnencryptedDNS(fallback bool) {
|
||||
|
||||
const dialTimeout = 3 * time.Second
|
||||
const defaultDNSPort = 53
|
||||
addrPort := netip.AddrPortFrom(targetIP, defaultDNSPort)
|
||||
settingsInternalDNS := nameserver.SettingsInternalDNS{
|
||||
AddrPort: netip.AddrPortFrom(targetIP, defaultDNSPort),
|
||||
AddrPort: addrPort,
|
||||
Timeout: dialTimeout,
|
||||
}
|
||||
nameserver.UseDNSInternally(settingsInternalDNS)
|
||||
@@ -34,4 +36,9 @@ func (l *Loop) useUnencryptedDNS(fallback bool) {
|
||||
if err != nil {
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
|
||||
err = l.firewall.RestrictOutputAddrPort(ctx, addrPort)
|
||||
if err != nil {
|
||||
l.logger.Error("restricting plain DNS traffic to " + targetIP.String() + ": " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
+5
-5
@@ -24,7 +24,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
"and go through your container network DNS outside the VPN tunnel!")
|
||||
} else {
|
||||
const fallback = false
|
||||
l.useUnencryptedDNS(fallback)
|
||||
l.useUnencryptedDNS(ctx, fallback)
|
||||
}
|
||||
|
||||
select {
|
||||
@@ -56,7 +56,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
|
||||
if !errors.Is(err, errUpdateBlockLists) {
|
||||
const fallback = true
|
||||
l.useUnencryptedDNS(fallback)
|
||||
l.useUnencryptedDNS(ctx, fallback)
|
||||
}
|
||||
l.logAndWait(ctx, err)
|
||||
settings = l.GetSettings()
|
||||
@@ -66,7 +66,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
settings = l.GetSettings()
|
||||
if !*settings.KeepNameserver && !*settings.ServerEnabled {
|
||||
const fallback = false
|
||||
l.useUnencryptedDNS(fallback)
|
||||
l.useUnencryptedDNS(ctx, fallback)
|
||||
}
|
||||
|
||||
l.userTrigger = false
|
||||
@@ -94,7 +94,7 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
|
||||
settings := l.GetSettings()
|
||||
if !*settings.KeepNameserver && *settings.ServerEnabled {
|
||||
const fallback = false
|
||||
l.useUnencryptedDNS(fallback)
|
||||
l.useUnencryptedDNS(ctx, fallback)
|
||||
l.stopServer()
|
||||
}
|
||||
l.stopped <- struct{}{}
|
||||
@@ -105,7 +105,7 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
|
||||
case err := <-runError: // unexpected error
|
||||
l.statusManager.SetStatus(constants.Crashed)
|
||||
const fallback = true
|
||||
l.useUnencryptedDNS(fallback)
|
||||
l.useUnencryptedDNS(ctx, fallback)
|
||||
l.logAndWait(ctx, err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -39,8 +39,9 @@ func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err erro
|
||||
|
||||
// use internal DNS server
|
||||
const defaultDNSPort = 53
|
||||
addrPort := netip.AddrPortFrom(settings.ServerAddress, defaultDNSPort)
|
||||
nameserver.UseDNSInternally(nameserver.SettingsInternalDNS{
|
||||
AddrPort: netip.AddrPortFrom(settings.ServerAddress, defaultDNSPort),
|
||||
AddrPort: addrPort,
|
||||
})
|
||||
err = nameserver.UseDNSSystemWide(nameserver.SettingsSystemDNS{
|
||||
IPs: []netip.Addr{settings.ServerAddress},
|
||||
@@ -50,6 +51,11 @@ func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err erro
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
|
||||
err = l.firewall.RestrictOutputAddrPort(ctx, addrPort)
|
||||
if err != nil {
|
||||
l.logger.Error("restricting plain DNS traffic to " + addrPort.Addr().String() + ": " + err.Error())
|
||||
}
|
||||
|
||||
err = check.WaitForDNS(ctx, check.Settings{})
|
||||
if err != nil {
|
||||
l.stopServer()
|
||||
|
||||
@@ -69,8 +69,8 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
"invalid_instruction": {
|
||||
instruction: "invalid",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing iptables command: iptables command is malformed: " +
|
||||
"fields count 1 is not even: \"invalid\"",
|
||||
errMessage: "parsing iptables command: parsing \"invalid\": " +
|
||||
"iptables command is malformed: flag \"invalid\" requires a value, but got none",
|
||||
},
|
||||
"list_error": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
|
||||
@@ -29,6 +29,7 @@ type Config struct {
|
||||
outboundSubnets []netip.Prefix
|
||||
allowedInputPorts map[uint16]map[string]struct{} // port to interfaces set mapping
|
||||
portRedirections portRedirections
|
||||
outputAddrPort map[uint16]netip.Addr
|
||||
stateMutex sync.Mutex
|
||||
}
|
||||
|
||||
@@ -52,6 +53,7 @@ func NewConfig(ctx context.Context, logger Logger,
|
||||
runner: runner,
|
||||
logger: logger,
|
||||
allowedInputPorts: make(map[uint16]map[string]struct{}),
|
||||
outputAddrPort: make(map[uint16]netip.Addr),
|
||||
ipTables: iptables,
|
||||
ip6Tables: ip6tables,
|
||||
customRulesPath: "/iptables/post-rules.txt",
|
||||
|
||||
@@ -2,6 +2,7 @@ package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
|
||||
@@ -19,3 +20,15 @@ func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction st
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
func (c *Config) runIPv4AndV6IptablesInstructions(ctx context.Context,
|
||||
ipv4Instructions, ipv6Instructions []string,
|
||||
) error {
|
||||
if err := c.runIptablesInstructions(ctx, ipv4Instructions); err != nil {
|
||||
return fmt.Errorf("running iptables instructions: %w", err)
|
||||
}
|
||||
if err := c.runIP6tablesInstructions(ctx, ipv6Instructions); err != nil {
|
||||
return fmt.Errorf("running ip6tables instructions: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
+111
-19
@@ -9,9 +9,19 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type operation uint8
|
||||
|
||||
const (
|
||||
opNone operation = iota
|
||||
opAppend
|
||||
opDelete
|
||||
opInsert
|
||||
opReplace
|
||||
)
|
||||
|
||||
type iptablesInstruction struct {
|
||||
table string // defaults to "filter", and can be "nat" for example.
|
||||
append bool
|
||||
operation operation
|
||||
chain string // for example INPUT, PREROUTING. Cannot be empty.
|
||||
target string // for example ACCEPT. Can be empty.
|
||||
protocol string // "tcp" or "udp" or "" for all protocols.
|
||||
@@ -22,6 +32,7 @@ type iptablesInstruction struct {
|
||||
destinationPort uint16 // if zero, there is no destination port
|
||||
toPorts []uint16 // if empty, there is no redirection
|
||||
ctstate []string // if empty, there is no ctstate
|
||||
lineNumber uint16 // for replace operation, the line number to replace
|
||||
}
|
||||
|
||||
func (i *iptablesInstruction) setDefaults() {
|
||||
@@ -60,6 +71,58 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
|
||||
}
|
||||
}
|
||||
|
||||
func (i *iptablesInstruction) String() string {
|
||||
var sb strings.Builder
|
||||
if i.table != "" && i.table != "filter" {
|
||||
sb.WriteString(fmt.Sprintf("-t %s ", i.table))
|
||||
}
|
||||
switch i.operation {
|
||||
case opNone:
|
||||
panic("no operation specified")
|
||||
case opAppend:
|
||||
sb.WriteString(fmt.Sprintf("--append %s ", i.chain))
|
||||
case opDelete:
|
||||
sb.WriteString(fmt.Sprintf("--delete %s ", i.chain))
|
||||
case opInsert:
|
||||
sb.WriteString(fmt.Sprintf("--insert %s ", i.chain))
|
||||
case opReplace:
|
||||
sb.WriteString(fmt.Sprintf("--replace %s %d ", i.chain, i.lineNumber))
|
||||
}
|
||||
if i.inputInterface != "" {
|
||||
sb.WriteString(fmt.Sprintf("-i %s ", i.inputInterface))
|
||||
}
|
||||
if i.outputInterface != "" {
|
||||
sb.WriteString(fmt.Sprintf("-o %s ", i.outputInterface))
|
||||
}
|
||||
if i.protocol != "" {
|
||||
sb.WriteString(fmt.Sprintf("-p %s ", i.protocol))
|
||||
}
|
||||
if i.source.IsValid() {
|
||||
sb.WriteString(fmt.Sprintf("-s %s ", i.source.String()))
|
||||
}
|
||||
if i.destination.IsValid() {
|
||||
sb.WriteString(fmt.Sprintf("-d %s ", i.destination.String()))
|
||||
}
|
||||
if i.destinationPort != 0 {
|
||||
sb.WriteString(fmt.Sprintf("--dport %d ", i.destinationPort))
|
||||
}
|
||||
if len(i.ctstate) > 0 {
|
||||
sb.WriteString(fmt.Sprintf("--ctstate %s ", strings.Join(i.ctstate, ",")))
|
||||
}
|
||||
if len(i.toPorts) > 0 {
|
||||
var portStrings []string
|
||||
for _, port := range i.toPorts {
|
||||
portStrings = append(portStrings, strconv.FormatUint(uint64(port), 10))
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("--to-ports %s ", strings.Join(portStrings, ",")))
|
||||
}
|
||||
if i.target != "" {
|
||||
sb.WriteString(fmt.Sprintf("-j %s ", i.target))
|
||||
}
|
||||
|
||||
return strings.TrimSpace(sb.String())
|
||||
}
|
||||
|
||||
// instruction can be "" which equivalent to the "*" chain rule interface.
|
||||
func networkInterfacesEqual(instruction, chainRule string) bool {
|
||||
return instruction == chainRule || (instruction == "" && chainRule == "*")
|
||||
@@ -77,34 +140,63 @@ func parseIptablesInstruction(s string) (instruction iptablesInstruction, err er
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
||||
}
|
||||
fields := strings.Fields(s)
|
||||
if len(fields)%2 != 0 {
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q",
|
||||
ErrIptablesCommandMalformed, len(fields), s)
|
||||
}
|
||||
|
||||
for i := 0; i < len(fields); i += 2 {
|
||||
key := fields[i]
|
||||
value := fields[i+1]
|
||||
err = parseInstructionFlag(key, value, &instruction)
|
||||
i := 0
|
||||
for i < len(fields) {
|
||||
consumed, err := parseInstructionFlag(fields[i:], &instruction)
|
||||
if err != nil {
|
||||
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
|
||||
}
|
||||
i += consumed
|
||||
}
|
||||
|
||||
instruction.setDefaults()
|
||||
return instruction, nil
|
||||
}
|
||||
|
||||
func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) {
|
||||
switch key {
|
||||
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
|
||||
flag := fields[0]
|
||||
|
||||
// All flags use one value after the flag, except the following:
|
||||
switch flag {
|
||||
case "-R", "--replace":
|
||||
const expected = 3
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
||||
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
|
||||
}
|
||||
consumed = expected
|
||||
default:
|
||||
const expected = 2
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
|
||||
ErrIptablesCommandMalformed, flag)
|
||||
}
|
||||
consumed = expected
|
||||
}
|
||||
value := fields[1]
|
||||
|
||||
switch flag {
|
||||
case "-t", "--table":
|
||||
instruction.table = value
|
||||
case "-D", "--delete":
|
||||
instruction.append = false
|
||||
instruction.operation = opDelete
|
||||
instruction.chain = value
|
||||
case "-A", "--append":
|
||||
instruction.append = true
|
||||
instruction.operation = opAppend
|
||||
instruction.chain = value
|
||||
case "-I", "--insert":
|
||||
instruction.operation = opInsert
|
||||
instruction.chain = value
|
||||
case "-R", "--replace":
|
||||
instruction.operation = opReplace
|
||||
instruction.chain = value
|
||||
const base, bits = 10, 16
|
||||
n, err := strconv.ParseUint(fields[2], base, bits)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing line number for --replace operation: %w", err)
|
||||
}
|
||||
instruction.lineNumber = uint16(n)
|
||||
case "-j", "--jump":
|
||||
instruction.target = value
|
||||
case "-p", "--protocol":
|
||||
@@ -117,18 +209,18 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (
|
||||
case "-s", "--source":
|
||||
instruction.source, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
}
|
||||
case "-d", "--destination":
|
||||
instruction.destination, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
}
|
||||
case "--dport":
|
||||
const base, bitLength = 10, 16
|
||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination port: %w", err)
|
||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
instruction.destinationPort = uint16(destinationPort)
|
||||
case "--ctstate":
|
||||
@@ -140,14 +232,14 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (
|
||||
const base, bitLength = 10, 16
|
||||
port, err := strconv.ParseUint(portString, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing port redirection: %w", err)
|
||||
return 0, fmt.Errorf("parsing port redirection: %w", err)
|
||||
}
|
||||
instruction.toPorts[i] = uint16(port)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key)
|
||||
return 0, fmt.Errorf("%w: unknown flag %q", ErrIptablesCommandMalformed, flag)
|
||||
}
|
||||
return nil
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
||||
|
||||
@@ -23,19 +23,19 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
"uneven_fields": {
|
||||
s: "-A",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"",
|
||||
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
|
||||
},
|
||||
"unknown_key": {
|
||||
s: "-x something",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
|
||||
errMessage: "parsing \"-x something\": iptables command is malformed: unknown flag \"-x\"",
|
||||
},
|
||||
"one_pair": {
|
||||
s: "-A INPUT",
|
||||
s: "-I INPUT",
|
||||
instruction: iptablesInstruction{
|
||||
table: "filter",
|
||||
chain: "INPUT",
|
||||
append: true,
|
||||
table: "filter",
|
||||
chain: "INPUT",
|
||||
operation: opInsert,
|
||||
},
|
||||
},
|
||||
"instruction_A": {
|
||||
@@ -43,7 +43,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
instruction: iptablesInstruction{
|
||||
table: "filter",
|
||||
chain: "INPUT",
|
||||
append: true,
|
||||
operation: opAppend,
|
||||
inputInterface: "tun0",
|
||||
protocol: "tcp",
|
||||
source: netip.MustParsePrefix("1.2.3.4/32"),
|
||||
@@ -57,7 +57,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
instruction: iptablesInstruction{
|
||||
table: "nat",
|
||||
chain: "PREROUTING",
|
||||
append: false,
|
||||
operation: opDelete,
|
||||
inputInterface: "tun0",
|
||||
protocol: "tcp",
|
||||
destinationPort: 43716,
|
||||
|
||||
@@ -3,6 +3,7 @@ package firewall
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
@@ -81,3 +82,133 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestrictOutputAddrPort allows outgoing traffic to a specific IP and port for both tcp and udp,
|
||||
// while blocking other tcp or udp traffic to that port going to other IP addresses, both IPv4 and IPv6.
|
||||
// If the port was previously allowed for another IP address, that previous allowance will be removed.
|
||||
// Giving an invalid address will remove any existing restrictions for the port specified.
|
||||
func (c *Config) RestrictOutputAddrPort(ctx context.Context, addrPort netip.AddrPort) (err error) {
|
||||
c.stateMutex.Lock()
|
||||
defer c.stateMutex.Unlock()
|
||||
existingIP := c.outputAddrPort[addrPort.Port()]
|
||||
|
||||
switch {
|
||||
case existingIP == addrPort.Addr():
|
||||
return nil
|
||||
case !addrPort.Addr().IsValid():
|
||||
// invalid address, remove any existing rules for the port
|
||||
return c.removeOutputAddrPortRestriction(ctx, existingIP, addrPort.Port())
|
||||
case !existingIP.IsValid():
|
||||
// no previous existing address for the port
|
||||
return c.insertOutputAddrPortRestriction(ctx, addrPort)
|
||||
default:
|
||||
// existing rule in the same IP family or different family
|
||||
return c.replaceOutputAddrPortRestriction(ctx, existingIP, addrPort)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) removeOutputAddrPortRestriction(ctx context.Context, existingIP netip.Addr, port uint16) (err error) {
|
||||
commonInstructions := []string{
|
||||
fmt.Sprintf("--delete OUTPUT -p udp --dport %d -j DROP", port),
|
||||
fmt.Sprintf("--delete OUTPUT -p tcp --dport %d -j DROP", port),
|
||||
}
|
||||
ipv4Instructions := commonInstructions
|
||||
ipv6Instructions := commonInstructions
|
||||
|
||||
familySpecificInstructions := []string{
|
||||
fmt.Sprintf("--delete OUTPUT -p udp --dport %d -d %s -j ACCEPT", port, existingIP),
|
||||
fmt.Sprintf("--delete OUTPUT -p tcp --dport %d -d %s -j ACCEPT", port, existingIP),
|
||||
}
|
||||
if existingIP.Is4() {
|
||||
ipv4Instructions = append(ipv4Instructions, familySpecificInstructions...)
|
||||
} else {
|
||||
ipv6Instructions = append(ipv6Instructions, familySpecificInstructions...)
|
||||
}
|
||||
|
||||
err = c.runIPv4AndV6IptablesInstructions(ctx, ipv4Instructions, ipv6Instructions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
delete(c.outputAddrPort, port)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) insertOutputAddrPortRestriction(ctx context.Context, addrPort netip.AddrPort) (err error) {
|
||||
commonInstructions := []string{
|
||||
fmt.Sprintf("--insert OUTPUT -p udp --dport %d -j DROP", addrPort.Port()),
|
||||
fmt.Sprintf("--insert OUTPUT -p tcp --dport %d -j DROP", addrPort.Port()),
|
||||
}
|
||||
ipv4Instructions := commonInstructions
|
||||
ipv6Instructions := commonInstructions
|
||||
|
||||
familySpecificInstructions := []string{
|
||||
fmt.Sprintf("--insert OUTPUT -p udp --dport %d -d %s -j ACCEPT", addrPort.Port(), addrPort.Addr()),
|
||||
fmt.Sprintf("--insert OUTPUT -p tcp --dport %d -d %s -j ACCEPT", addrPort.Port(), addrPort.Addr()),
|
||||
}
|
||||
if addrPort.Addr().Is4() {
|
||||
ipv4Instructions = append(ipv4Instructions, familySpecificInstructions...)
|
||||
} else {
|
||||
ipv6Instructions = append(ipv6Instructions, familySpecificInstructions...)
|
||||
}
|
||||
err = c.runIPv4AndV6IptablesInstructions(ctx, ipv4Instructions, ipv6Instructions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.outputAddrPort[addrPort.Port()] = addrPort.Addr()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) replaceOutputAddrPortRestriction(ctx context.Context,
|
||||
existingIP netip.Addr, addrPort netip.AddrPort,
|
||||
) (err error) {
|
||||
for _, protocol := range [...]string{"udp", "tcp"} {
|
||||
switch {
|
||||
case existingIP.Is4() && addrPort.Addr().Is4():
|
||||
oldInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), existingIP)
|
||||
newInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), addrPort.Addr())
|
||||
err = c.replaceIptablesRule(ctx, oldInstruction, newInstruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("replacing existing IPv4 rule: %w", err)
|
||||
}
|
||||
case existingIP.Is6() && addrPort.Addr().Is6():
|
||||
oldInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), existingIP)
|
||||
newInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), addrPort.Addr())
|
||||
err = c.replaceIP6tablesRule(ctx, oldInstruction, newInstruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("replacing existing IPv6 rule: %w", err)
|
||||
}
|
||||
case existingIP.Is4() && addrPort.Addr().Is6():
|
||||
instruction := fmt.Sprintf("--delete OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), existingIP)
|
||||
err = c.runIptablesInstruction(ctx, instruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing existing IPv4 rule: %w", err)
|
||||
}
|
||||
instruction = fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), addrPort.Addr())
|
||||
err = c.runIP6tablesInstruction(ctx, instruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("inserting new IPv6 rule: %w", err)
|
||||
}
|
||||
case existingIP.Is6() && addrPort.Addr().Is4():
|
||||
instruction := fmt.Sprintf("--delete OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), existingIP)
|
||||
err = c.runIP6tablesInstruction(ctx, instruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing existing IPv6 rule: %w", err)
|
||||
}
|
||||
instruction = fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||
protocol, addrPort.Port(), addrPort.Addr())
|
||||
err = c.runIptablesInstruction(ctx, instruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("inserting new IPv4 rule: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
c.outputAddrPort[addrPort.Port()] = addrPort.Addr()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var errRuleNotFound = errors.New("rule not found")
|
||||
|
||||
func (c *Config) replaceIptablesRule(ctx context.Context, oldInstruction, newInstruction string) error {
|
||||
targetRule, err := parseIptablesInstruction(oldInstruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing iptables command to replace: %w", err)
|
||||
}
|
||||
|
||||
lineNumber, err := findLineNumber(ctx, c.ipTables, targetRule, c.runner, c.logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("finding to-be-replaced chain rule line number: %w", err)
|
||||
} else if lineNumber == 0 {
|
||||
return fmt.Errorf("%w: matching to-be-replaced instruction %q", errRuleNotFound, oldInstruction)
|
||||
}
|
||||
parsed, err := parseIptablesInstruction(newInstruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing replacement iptables command: %w", err)
|
||||
}
|
||||
parsed.operation = opReplace
|
||||
parsed.lineNumber = lineNumber
|
||||
return c.runIptablesInstruction(ctx, parsed.String())
|
||||
}
|
||||
|
||||
func (c *Config) replaceIP6tablesRule(ctx context.Context, oldInstruction, newInstruction string) error {
|
||||
targetRule, err := parseIptablesInstruction(oldInstruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing iptables command to replace: %w", err)
|
||||
}
|
||||
|
||||
lineNumber, err := findLineNumber(ctx, c.ip6Tables, targetRule, c.runner, c.logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("finding to-be-replaced chain rule line number: %w", err)
|
||||
} else if lineNumber == 0 {
|
||||
return fmt.Errorf("%w: matching to-be-replaced instruction %q", errRuleNotFound, oldInstruction)
|
||||
}
|
||||
parsed, err := parseIptablesInstruction(newInstruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing replacement iptables command: %w", err)
|
||||
}
|
||||
parsed.operation = opReplace
|
||||
parsed.lineNumber = lineNumber
|
||||
return c.runIP6tablesInstruction(ctx, parsed.String())
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !windows
|
||||
|
||||
package mod
|
||||
|
||||
import (
|
||||
@@ -0,0 +1,7 @@
|
||||
//go:build !linux
|
||||
|
||||
package mod
|
||||
|
||||
func Probe(moduleName string) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
+59
-17
@@ -1,33 +1,75 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"github.com/vishvananda/netlink"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/jsimonetti/rtnetlink/rtnl"
|
||||
)
|
||||
|
||||
func (n *NetLink) AddrList(link Link, family int) (
|
||||
addresses []Addr, err error,
|
||||
func (n *NetLink) AddrList(linkIndex uint32, family uint8) (
|
||||
ipPrefixes []netip.Prefix, err error,
|
||||
) {
|
||||
netlinkLink := linkToNetlinkLink(&link)
|
||||
netlinkAddresses, err := netlink.AddrList(netlinkLink, family)
|
||||
conn, err := rtnl.Dial(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ifc := &net.Interface{
|
||||
Index: int(linkIndex),
|
||||
}
|
||||
ipNets, err := conn.Addrs(ifc, int(family))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list addresses: %w", err)
|
||||
}
|
||||
|
||||
addresses = make([]Addr, len(netlinkAddresses))
|
||||
for i := range netlinkAddresses {
|
||||
addresses[i].Network = netIPNetToNetipPrefix(netlinkAddresses[i].IPNet)
|
||||
ipPrefixes = make([]netip.Prefix, len(ipNets))
|
||||
for i := range ipNets {
|
||||
ipPrefixes[i] = netIPNetToNetipPrefix(ipNets[i])
|
||||
}
|
||||
|
||||
return addresses, nil
|
||||
return ipPrefixes, nil
|
||||
}
|
||||
|
||||
func (n *NetLink) AddrReplace(link Link, addr Addr) error {
|
||||
netlinkLink := linkToNetlinkLink(&link)
|
||||
netlinkAddress := netlink.Addr{
|
||||
IPNet: netipPrefixToIPNet(addr.Network),
|
||||
func (n *NetLink) AddrReplace(linkIndex uint32, prefix netip.Prefix) error {
|
||||
conn, err := rtnl.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ipNet := netipPrefixToIPNet(prefix)
|
||||
|
||||
// Remove any address identical to the one we want to add
|
||||
family := FamilyV4
|
||||
if prefix.Addr().Is6() {
|
||||
family = FamilyV6
|
||||
}
|
||||
ifc := &net.Interface{
|
||||
Index: int(linkIndex),
|
||||
}
|
||||
addresses, err := conn.Addrs(ifc, int(family))
|
||||
if err != nil {
|
||||
return fmt.Errorf("listing addresses: %w", err)
|
||||
}
|
||||
for _, address := range addresses {
|
||||
if address.IP.Equal(ipNet.IP) &&
|
||||
net.IP(address.Mask).String() == net.IP(ipNet.Mask).String() {
|
||||
err = conn.AddrDel(ifc, address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting address from interface: %w", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return netlink.AddrReplace(netlinkLink, &netlinkAddress)
|
||||
// Add the new address to the interface
|
||||
err = conn.AddrAdd(ifc, ipNet)
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding address to interface: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
//go:build !linux && !darwin
|
||||
|
||||
package netlink
|
||||
|
||||
func (n *NetLink) AddrList(link Link, family int) (
|
||||
addresses []Addr, err error,
|
||||
) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) AddrReplace(Link, Addr) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -36,6 +36,30 @@ func netIPNetToNetipPrefix(ipNet *net.IPNet) (prefix netip.Prefix) {
|
||||
return netip.PrefixFrom(ip, bits)
|
||||
}
|
||||
|
||||
func ipAndLengthToPrefix(ip *net.IP, length uint8) netip.Prefix {
|
||||
if ip == nil || len(*ip) == 0 {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
var dstIP netip.Addr
|
||||
if ipv4 := ip.To4(); ipv4 != nil { // IPv6
|
||||
dstIP = netip.AddrFrom4([4]byte(*ip))
|
||||
} else {
|
||||
dstIP = netip.AddrFrom16([16]byte(*ip))
|
||||
}
|
||||
return netip.PrefixFrom(dstIP, int(length))
|
||||
}
|
||||
|
||||
func prefixToIPAndLength(prefix netip.Prefix) (ip *net.IP, length uint8) {
|
||||
if !prefix.IsValid() {
|
||||
return nil, 0
|
||||
}
|
||||
prefixIP := prefix.Addr().Unmap()
|
||||
ip = new(net.IP)
|
||||
*ip = netipAddrToNetIP(prefixIP)
|
||||
length = uint8(prefix.Bits()) //nolint:gosec
|
||||
return ip, length
|
||||
}
|
||||
|
||||
func netipAddrToNetIP(address netip.Addr) (ip net.IP) {
|
||||
switch {
|
||||
case !address.IsValid():
|
||||
|
||||
@@ -4,13 +4,7 @@ import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
FamilyAll = 0
|
||||
FamilyV4 = 2
|
||||
FamilyV6 = 10
|
||||
)
|
||||
|
||||
func FamilyToString(family int) string {
|
||||
func FamilyToString(family uint8) string {
|
||||
switch family {
|
||||
case FamilyAll:
|
||||
return "all"
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
package netlink
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
const (
|
||||
FamilyAll uint8 = unix.AF_UNSPEC
|
||||
FamilyV4 uint8 = unix.AF_INET
|
||||
FamilyV6 uint8 = unix.AF_INET6
|
||||
)
|
||||
@@ -1,16 +1,30 @@
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/log"
|
||||
)
|
||||
|
||||
func ptrTo[T any](v T) *T { return &v }
|
||||
|
||||
func makeNetipPrefix(n byte) netip.Prefix {
|
||||
const bits = 24
|
||||
return netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits)
|
||||
}
|
||||
|
||||
var rng = rand.New(rand.NewChaCha8([32]byte{})) //nolint:gosec,gochecknoglobals
|
||||
|
||||
func makeLinkName() string {
|
||||
const alphabet = "abcdefghijklmnopqrstuvwxyz"
|
||||
name := make([]byte, 8)
|
||||
for i := range name {
|
||||
name[i] = alphabet[rng.IntN(len(alphabet))]
|
||||
}
|
||||
return "test" + string(name)
|
||||
}
|
||||
|
||||
type noopLogger struct{}
|
||||
|
||||
func (l *noopLogger) Debug(_ string) {}
|
||||
|
||||
@@ -19,7 +19,7 @@ func (n *NetLink) IsIPv6Supported() (supported bool, err error) {
|
||||
return false, fmt.Errorf("finding link corresponding to route: %w", err)
|
||||
}
|
||||
|
||||
sourceIsIPv6 := route.Src.IsValid() && route.Src.Is6()
|
||||
sourceIsIPv6 := route.Src.Addr().IsValid() && route.Src.Addr().Is6()
|
||||
destinationIsIPv6 := route.Dst.IsValid() && route.Dst.Addr().Is6()
|
||||
switch {
|
||||
case !sourceIsIPv6 && !destinationIsIPv6,
|
||||
|
||||
+162
-76
@@ -1,105 +1,191 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package netlink
|
||||
|
||||
import "github.com/vishvananda/netlink"
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jsimonetti/rtnetlink"
|
||||
)
|
||||
|
||||
type DeviceType uint16
|
||||
|
||||
type Link struct {
|
||||
Index uint32
|
||||
Name string
|
||||
DeviceType DeviceType
|
||||
VirtualType string
|
||||
MTU uint32
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkList() (links []Link, err error) {
|
||||
netlinkLinks, err := netlink.LinkList()
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
linkMessages, err := conn.Link.List()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listing interfaces: %w", err)
|
||||
}
|
||||
|
||||
links = make([]Link, len(netlinkLinks))
|
||||
for i := range netlinkLinks {
|
||||
links[i] = netlinkLinkToLink(netlinkLinks[i])
|
||||
links = make([]Link, len(linkMessages))
|
||||
for i, message := range linkMessages {
|
||||
virtualType := ""
|
||||
if message.Attributes.Info != nil {
|
||||
virtualType = message.Attributes.Info.Kind
|
||||
}
|
||||
links[i] = Link{
|
||||
Index: message.Index,
|
||||
Name: message.Attributes.Name,
|
||||
DeviceType: DeviceType(message.Type),
|
||||
VirtualType: virtualType,
|
||||
MTU: message.Attributes.MTU,
|
||||
}
|
||||
}
|
||||
|
||||
return links, nil
|
||||
}
|
||||
|
||||
var ErrLinkNotFound = errors.New("link not found")
|
||||
|
||||
func (n *NetLink) LinkByName(name string) (link Link, err error) {
|
||||
netlinkLink, err := netlink.LinkByName(name)
|
||||
links, err := n.LinkList()
|
||||
if err != nil {
|
||||
return Link{}, err
|
||||
return Link{}, fmt.Errorf("listing links: %w", err)
|
||||
}
|
||||
|
||||
return netlinkLinkToLink(netlinkLink), nil
|
||||
for _, link := range links {
|
||||
if link.Name == name {
|
||||
return link, nil
|
||||
}
|
||||
}
|
||||
|
||||
return Link{}, fmt.Errorf("%w: for name %s", ErrLinkNotFound, name)
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkByIndex(index int) (link Link, err error) {
|
||||
netlinkLink, err := netlink.LinkByIndex(index)
|
||||
func (n *NetLink) LinkByIndex(index uint32) (link Link, err error) {
|
||||
links, err := n.LinkList()
|
||||
if err != nil {
|
||||
return Link{}, err
|
||||
return Link{}, fmt.Errorf("listing links: %w", err)
|
||||
}
|
||||
|
||||
return netlinkLinkToLink(netlinkLink), nil
|
||||
for _, link = range links {
|
||||
if link.Index == index {
|
||||
return link, nil
|
||||
}
|
||||
}
|
||||
|
||||
return Link{}, fmt.Errorf("%w: for index %d", ErrLinkNotFound, index)
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkAdd(link Link) (linkIndex int, err error) {
|
||||
netlinkLink := linkToNetlinkLink(&link)
|
||||
err = netlink.LinkAdd(netlinkLink)
|
||||
func (n *NetLink) LinkAdd(link Link) (linkIndex uint32, err error) {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return 0, fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
return netlinkLink.Attrs().Index, nil
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
func (n *NetLink) LinkDel(link Link) (err error) {
|
||||
return netlink.LinkDel(linkToNetlinkLink(&link))
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkSetUp(link Link) (linkIndex int, err error) {
|
||||
netlinkLink := linkToNetlinkLink(&link)
|
||||
err = netlink.LinkSetUp(netlinkLink)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return netlinkLink.Attrs().Index, nil
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkSetDown(link Link) (err error) {
|
||||
return netlink.LinkSetDown(linkToNetlinkLink(&link))
|
||||
}
|
||||
|
||||
type netlinkLinkImpl struct {
|
||||
attrs *netlink.LinkAttrs
|
||||
linkType string
|
||||
}
|
||||
|
||||
func (n *netlinkLinkImpl) Attrs() *netlink.LinkAttrs {
|
||||
return n.attrs
|
||||
}
|
||||
|
||||
func (n *netlinkLinkImpl) Type() string {
|
||||
return n.linkType
|
||||
}
|
||||
|
||||
func netlinkLinkToLink(netlinkLink netlink.Link) Link {
|
||||
attributes := netlinkLink.Attrs()
|
||||
return Link{
|
||||
Type: netlinkLink.Type(),
|
||||
Name: attributes.Name,
|
||||
Index: attributes.Index,
|
||||
EncapType: attributes.EncapType,
|
||||
MTU: uint16(attributes.MTU), //nolint:gosec
|
||||
}
|
||||
}
|
||||
|
||||
// Warning: we must return `netlink.Link` and not `netlinkLinkImpl`
|
||||
// so that the vishvananda/netlink package can compare the returned
|
||||
// value against an untyped nil.
|
||||
func linkToNetlinkLink(link *Link) netlink.Link {
|
||||
if link == nil {
|
||||
return nil
|
||||
}
|
||||
return &netlinkLinkImpl{
|
||||
linkType: link.Type,
|
||||
attrs: &netlink.LinkAttrs{
|
||||
Name: link.Name,
|
||||
Index: link.Index,
|
||||
EncapType: link.EncapType,
|
||||
MTU: int(link.MTU),
|
||||
tx := &rtnetlink.LinkMessage{
|
||||
Type: uint16(link.DeviceType),
|
||||
Attributes: &rtnetlink.LinkAttributes{
|
||||
MTU: link.MTU,
|
||||
Name: link.Name,
|
||||
},
|
||||
}
|
||||
if link.VirtualType != "" {
|
||||
tx.Attributes.Info = &rtnetlink.LinkInfo{
|
||||
Kind: link.VirtualType,
|
||||
}
|
||||
}
|
||||
|
||||
err = conn.Link.New(tx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("creating new link: %w", err)
|
||||
}
|
||||
|
||||
linkMessages, err := conn.Link.List()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("listing links: %w", err)
|
||||
}
|
||||
for _, linkMessage := range linkMessages {
|
||||
if linkMessage.Attributes.Name == link.Name {
|
||||
return linkMessage.Index, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("%w: matching name %s", ErrLinkNotFound, link.Name)
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkDel(linkIndex uint32) (err error) {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return conn.Link.Delete(linkIndex)
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkSetUp(linkIndex uint32) (err error) {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
rx, err := conn.Link.Get(linkIndex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting link: %w", err)
|
||||
}
|
||||
tx := &rtnetlink.LinkMessage{
|
||||
Type: rx.Type,
|
||||
Index: linkIndex,
|
||||
Flags: iffUp,
|
||||
Change: iffUp,
|
||||
}
|
||||
return conn.Link.Set(tx)
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkSetDown(linkIndex uint32) (err error) {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
linkInfo, err := conn.Link.Get(linkIndex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting link: %w", err)
|
||||
}
|
||||
message := &rtnetlink.LinkMessage{
|
||||
Type: linkInfo.Type,
|
||||
Index: linkIndex,
|
||||
Flags: 0,
|
||||
Change: iffUp,
|
||||
}
|
||||
return conn.Link.Set(message)
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkSetMTU(linkIndex, mtu uint32) error {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
message := &rtnetlink.LinkMessage{
|
||||
Index: linkIndex,
|
||||
Attributes: &rtnetlink.LinkAttributes{
|
||||
MTU: mtu,
|
||||
},
|
||||
}
|
||||
|
||||
err = conn.Link.Set(message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting MTU to %d for link at index %d: %w",
|
||||
mtu, linkIndex, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
package netlink
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
const (
|
||||
DeviceTypeEthernet DeviceType = unix.ARPHRD_ETHER
|
||||
DeviceTypeLoopback DeviceType = unix.ARPHRD_LOOPBACK
|
||||
DeviceTypeNone DeviceType = unix.ARPHRD_NONE
|
||||
|
||||
iffUp = unix.IFF_UP
|
||||
)
|
||||
@@ -0,0 +1,85 @@
|
||||
//go:build linux
|
||||
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_NetLink_LinkList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
netlink := &NetLink{}
|
||||
|
||||
initialLinks, err := netlink.LinkList()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, initialLinks)
|
||||
|
||||
loopbackFound := false
|
||||
for _, link := range initialLinks {
|
||||
if link.Name != "lo" {
|
||||
continue
|
||||
}
|
||||
loopbackFound = true
|
||||
assert.Equal(t, DeviceTypeLoopback, link.DeviceType)
|
||||
break
|
||||
}
|
||||
assert.True(t, loopbackFound, "loopback interface not found")
|
||||
|
||||
testLink := Link{
|
||||
Name: makeLinkName(),
|
||||
// note if [Link.VirtualType] is set, [Link.DeviceType]
|
||||
// is ignored and gets set to [DeviceTypeNone] in LinkAdd.
|
||||
DeviceType: DeviceTypeNone,
|
||||
VirtualType: "wireguard",
|
||||
MTU: 1420,
|
||||
}
|
||||
index, err := netlink.LinkAdd(testLink)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = netlink.LinkDel(index)
|
||||
})
|
||||
|
||||
links, err := netlink.LinkList()
|
||||
require.NoError(t, err)
|
||||
|
||||
testLink.Index = index
|
||||
for _, link := range links {
|
||||
if link.Name != testLink.Name {
|
||||
continue
|
||||
}
|
||||
assert.Equal(t, testLink, link)
|
||||
return
|
||||
}
|
||||
t.Errorf("created link %q not found", testLink.Name)
|
||||
}
|
||||
|
||||
func Test_NetLink_LinkSetMTU(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
netlink := &NetLink{}
|
||||
|
||||
testLink := Link{
|
||||
Name: makeLinkName(),
|
||||
DeviceType: DeviceTypeNone,
|
||||
VirtualType: "wireguard",
|
||||
MTU: 1420,
|
||||
}
|
||||
index, err := netlink.LinkAdd(testLink)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = netlink.LinkDel(index)
|
||||
})
|
||||
testLink.Index = index
|
||||
|
||||
err = netlink.LinkSetMTU(index, 1500)
|
||||
require.NoError(t, err)
|
||||
|
||||
link, err := netlink.LinkByIndex(index)
|
||||
require.NoError(t, err)
|
||||
testLink.MTU = 1500
|
||||
assert.Equal(t, testLink, link)
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
//go:build !linux && !darwin
|
||||
|
||||
package netlink
|
||||
|
||||
func (n *NetLink) LinkList() (links []Link, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkByName(name string) (link Link, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkByIndex(index int) (link Link, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkAdd(link Link) (linkIndex int, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkDel(link Link) (err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkSetUp(link Link) (linkIndex int, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkSetDown(link Link) (err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
//go:build !linux
|
||||
|
||||
package netlink
|
||||
|
||||
const (
|
||||
// FamilyAll is a placeholder only and should not
|
||||
// be used.
|
||||
FamilyAll uint8 = iota
|
||||
// FamilyV4 is a placeholder only and should not
|
||||
// be used.
|
||||
FamilyV4
|
||||
// FamilyV6 is a placeholder only and should not
|
||||
// be used.
|
||||
FamilyV6
|
||||
|
||||
// DeviceTypeEthernet is a placeholder only and should not be used.
|
||||
DeviceTypeEthernet DeviceType = 0
|
||||
// DeviceTypeLoopback is a placeholder only and should not be used.
|
||||
DeviceTypeLoopback DeviceType = 0
|
||||
// DeviceTypeNone is a placeholder only and should not be used.
|
||||
DeviceTypeNone DeviceType = 0
|
||||
|
||||
// iffUp is a placeholder only and should not be used.
|
||||
iffUp = 0
|
||||
|
||||
// RouteTypeUnicast is a placeholder only and should not be used.
|
||||
RouteTypeUnicast = 0
|
||||
// ScopeUniverse is a placeholder only and should not be used.
|
||||
ScopeUniverse = 0
|
||||
// ProtoStatic is a placeholder only and should not be used.
|
||||
ProtoStatic = 0
|
||||
|
||||
// FlagInvert is a placeholder only and should not be used.
|
||||
FlagInvert = 0
|
||||
// ActionToTable is a placeholder only and should not be used.
|
||||
ActionToTable = 0
|
||||
|
||||
// rtTableCompat is a placeholder only and should not be used.
|
||||
rtTableCompat = 0
|
||||
)
|
||||
|
||||
func (n *NetLink) RuleList(family uint8) (rules []Rule, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) RuleAdd(rule Rule) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) RuleDel(rule Rule) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) IsWireguardSupported() (bool, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
+102
-46
@@ -1,69 +1,125 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"github.com/vishvananda/netlink"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/jsimonetti/rtnetlink"
|
||||
)
|
||||
|
||||
func (n *NetLink) RouteList(family int) (routes []Route, err error) {
|
||||
// We set the filter to netlink.RT_FILTER_TABLE so that
|
||||
// routes from all tables are listed, as long as the filter
|
||||
// table is set to 0.
|
||||
const filterMask = netlink.RT_FILTER_TABLE
|
||||
// The filter is not left to `nil` otherwise non-main tables
|
||||
// are ignored.
|
||||
filter := &netlink.Route{}
|
||||
type Route struct {
|
||||
LinkIndex uint32
|
||||
Dst netip.Prefix
|
||||
Src netip.Prefix
|
||||
Gw netip.Addr
|
||||
Priority uint32
|
||||
Family uint8
|
||||
Table uint32
|
||||
Type uint8
|
||||
Scope uint8
|
||||
Proto uint8
|
||||
}
|
||||
|
||||
netlinkRoutes, err := netlink.RouteListFiltered(family, filter, filterMask)
|
||||
func (r *Route) fromMessage(message rtnetlink.RouteMessage) {
|
||||
table := uint32(message.Table)
|
||||
if table == 0 || table == rtTableCompat {
|
||||
table = message.Attributes.Table
|
||||
}
|
||||
r.LinkIndex = message.Attributes.OutIface
|
||||
r.Dst = ipAndLengthToPrefix(&message.Attributes.Dst, message.DstLength)
|
||||
r.Src = ipAndLengthToPrefix(&message.Attributes.Src, message.SrcLength)
|
||||
r.Gw = netIPToNetipAddress(message.Attributes.Gateway)
|
||||
r.Priority = message.Attributes.Priority
|
||||
r.Family = message.Family
|
||||
r.Table = table
|
||||
r.Type = message.Type
|
||||
r.Scope = message.Scope
|
||||
r.Proto = message.Protocol
|
||||
}
|
||||
|
||||
func (r Route) message() *rtnetlink.RouteMessage {
|
||||
dst, dstLength := prefixToIPAndLength(r.Dst)
|
||||
src, srcLength := prefixToIPAndLength(r.Src)
|
||||
var table uint8
|
||||
var extendedTable uint32
|
||||
if r.Table <= uint32(^uint8(0)) {
|
||||
table = uint8(r.Table)
|
||||
} else {
|
||||
table = rtTableCompat
|
||||
extendedTable = r.Table
|
||||
}
|
||||
message := &rtnetlink.RouteMessage{
|
||||
Family: r.Family,
|
||||
DstLength: dstLength,
|
||||
SrcLength: srcLength,
|
||||
Table: table,
|
||||
Type: r.Type,
|
||||
Scope: r.Scope,
|
||||
Protocol: r.Proto,
|
||||
Attributes: rtnetlink.RouteAttributes{
|
||||
OutIface: r.LinkIndex,
|
||||
Dst: *dst, // there should always be a dst for routes
|
||||
Gateway: netipAddrToNetIP(r.Gw),
|
||||
Priority: r.Priority,
|
||||
Table: extendedTable,
|
||||
},
|
||||
}
|
||||
if src != nil { // src is optional
|
||||
message.Attributes.Src = *src
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
func (n *NetLink) RouteList(family uint8) (routes []Route, err error) {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
routeMessages, err := conn.Route.List()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listing interfaces: %w", err)
|
||||
}
|
||||
|
||||
routes = make([]Route, len(netlinkRoutes))
|
||||
for i := range netlinkRoutes {
|
||||
routes[i] = netlinkRouteToRoute(netlinkRoutes[i])
|
||||
routes = make([]Route, 0, len(routeMessages))
|
||||
for _, routeMessage := range routeMessages {
|
||||
if family != FamilyAll && routeMessage.Family != family {
|
||||
continue
|
||||
}
|
||||
var route Route
|
||||
route.fromMessage(routeMessage)
|
||||
routes = append(routes, route)
|
||||
}
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func (n *NetLink) RouteAdd(route Route) error {
|
||||
netlinkRoute := routeToNetlinkRoute(route)
|
||||
return netlink.RouteAdd(&netlinkRoute)
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return conn.Route.Add(route.message())
|
||||
}
|
||||
|
||||
func (n *NetLink) RouteDel(route Route) error {
|
||||
netlinkRoute := routeToNetlinkRoute(route)
|
||||
return netlink.RouteDel(&netlinkRoute)
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return conn.Route.Delete(route.message())
|
||||
}
|
||||
|
||||
func (n *NetLink) RouteReplace(route Route) error {
|
||||
netlinkRoute := routeToNetlinkRoute(route)
|
||||
return netlink.RouteReplace(&netlinkRoute)
|
||||
}
|
||||
|
||||
func netlinkRouteToRoute(netlinkRoute netlink.Route) (route Route) {
|
||||
return Route{
|
||||
LinkIndex: netlinkRoute.LinkIndex,
|
||||
Dst: netIPNetToNetipPrefix(netlinkRoute.Dst),
|
||||
Src: netIPToNetipAddress(netlinkRoute.Src),
|
||||
Gw: netIPToNetipAddress(netlinkRoute.Gw),
|
||||
Priority: netlinkRoute.Priority,
|
||||
Family: netlinkRoute.Family,
|
||||
Table: netlinkRoute.Table,
|
||||
Type: netlinkRoute.Type,
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
func routeToNetlinkRoute(route Route) (netlinkRoute netlink.Route) {
|
||||
return netlink.Route{
|
||||
LinkIndex: route.LinkIndex,
|
||||
Dst: netipPrefixToIPNet(route.Dst),
|
||||
Src: netipAddrToNetIP(route.Src),
|
||||
Gw: netipAddrToNetIP(route.Gw),
|
||||
Priority: route.Priority,
|
||||
Family: route.Family,
|
||||
Table: route.Table,
|
||||
Type: route.Type,
|
||||
}
|
||||
return conn.Route.Replace(route.message())
|
||||
}
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
package netlink
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
const (
|
||||
RouteTypeUnicast = unix.RTN_UNICAST
|
||||
ScopeUniverse = unix.RT_SCOPE_UNIVERSE
|
||||
ProtoStatic = unix.RTPROT_STATIC
|
||||
|
||||
rtTableCompat = unix.RT_TABLE_COMPAT
|
||||
)
|
||||
@@ -1,21 +0,0 @@
|
||||
//go:build !linux && !darwin
|
||||
|
||||
package netlink
|
||||
|
||||
func (n *NetLink) RouteList(family int) (
|
||||
routes []Route, err error,
|
||||
) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) RouteAdd(route Route) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) RouteDel(route Route) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) RouteReplace(route Route) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
+76
-71
@@ -1,91 +1,96 @@
|
||||
//go:build linux
|
||||
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
"github.com/jsimonetti/rtnetlink"
|
||||
)
|
||||
|
||||
func NewRule() Rule {
|
||||
// defaults found from netlink.NewRule() for fields we use,
|
||||
// the rest of the defaults is set when converting from a `Rule`
|
||||
// to a `netlink.Rule`
|
||||
return Rule{
|
||||
Priority: -1,
|
||||
Mark: 0,
|
||||
}
|
||||
type Rule struct {
|
||||
Priority *uint32
|
||||
Family uint8
|
||||
Table uint32
|
||||
Mark *uint32
|
||||
Src netip.Prefix
|
||||
Dst netip.Prefix
|
||||
Flags uint32
|
||||
Action uint8
|
||||
}
|
||||
|
||||
func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
|
||||
switch family {
|
||||
case FamilyAll:
|
||||
n.debugLogger.Debug("ip -4 rule list")
|
||||
n.debugLogger.Debug("ip -6 rule list")
|
||||
case FamilyV4:
|
||||
n.debugLogger.Debug("ip -4 rule list")
|
||||
case FamilyV6:
|
||||
n.debugLogger.Debug("ip -6 rule list")
|
||||
func (r *Rule) fromMessage(message rtnetlink.RuleMessage) {
|
||||
table := uint32(message.Table)
|
||||
if table == 0 || table == rtTableCompat {
|
||||
table = *message.Attributes.Table
|
||||
}
|
||||
netlinkRules, err := netlink.RuleList(family)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
r.Priority = message.Attributes.Priority
|
||||
r.Family = message.Family
|
||||
r.Table = table
|
||||
r.Mark = message.Attributes.FwMark
|
||||
r.Src = ipAndLengthToPrefix(message.Attributes.Src, message.SrcLength)
|
||||
r.Dst = ipAndLengthToPrefix(message.Attributes.Dst, message.DstLength)
|
||||
r.Flags = message.Flags
|
||||
r.Action = message.Action
|
||||
}
|
||||
|
||||
func (r Rule) message() *rtnetlink.RuleMessage {
|
||||
src, srcLength := prefixToIPAndLength(r.Src)
|
||||
dst, dstLength := prefixToIPAndLength(r.Dst)
|
||||
|
||||
message := &rtnetlink.RuleMessage{
|
||||
Family: r.Family,
|
||||
SrcLength: srcLength,
|
||||
DstLength: dstLength,
|
||||
Flags: r.Flags,
|
||||
Action: r.Action,
|
||||
Attributes: &rtnetlink.RuleAttributes{
|
||||
Priority: r.Priority,
|
||||
FwMark: r.Mark,
|
||||
Src: src,
|
||||
Dst: dst,
|
||||
},
|
||||
}
|
||||
|
||||
rules = make([]Rule, len(netlinkRules))
|
||||
for i := range netlinkRules {
|
||||
rules[i] = netlinkRuleToRule(netlinkRules[i])
|
||||
if r.Table <= uint32(^uint8(0)) {
|
||||
message.Table = uint8(r.Table)
|
||||
} else {
|
||||
message.Table = rtTableCompat
|
||||
message.Attributes.Table = &r.Table
|
||||
}
|
||||
return rules, nil
|
||||
|
||||
return message
|
||||
}
|
||||
|
||||
func (n *NetLink) RuleAdd(rule Rule) error {
|
||||
n.debugLogger.Debug(ruleDbgMsg(true, rule))
|
||||
netlinkRule := ruleToNetlinkRule(rule)
|
||||
return netlink.RuleAdd(&netlinkRule)
|
||||
}
|
||||
|
||||
func (n *NetLink) RuleDel(rule Rule) error {
|
||||
n.debugLogger.Debug(ruleDbgMsg(false, rule))
|
||||
netlinkRule := ruleToNetlinkRule(rule)
|
||||
return netlink.RuleDel(&netlinkRule)
|
||||
}
|
||||
|
||||
func ruleToNetlinkRule(rule Rule) (netlinkRule netlink.Rule) {
|
||||
netlinkRule = *netlink.NewRule()
|
||||
netlinkRule.Priority = rule.Priority
|
||||
netlinkRule.Family = rule.Family
|
||||
netlinkRule.Table = rule.Table
|
||||
netlinkRule.Mark = rule.Mark
|
||||
netlinkRule.Src = netipPrefixToIPNet(rule.Src)
|
||||
netlinkRule.Dst = netipPrefixToIPNet(rule.Dst)
|
||||
netlinkRule.Invert = rule.Invert
|
||||
return netlinkRule
|
||||
}
|
||||
|
||||
func netlinkRuleToRule(netlinkRule netlink.Rule) (rule Rule) {
|
||||
return Rule{
|
||||
Priority: netlinkRule.Priority,
|
||||
Family: netlinkRule.Family,
|
||||
Table: netlinkRule.Table,
|
||||
Mark: netlinkRule.Mark,
|
||||
Src: netIPNetToNetipPrefix(netlinkRule.Src),
|
||||
Dst: netIPNetToNetipPrefix(netlinkRule.Dst),
|
||||
Invert: netlinkRule.Invert,
|
||||
func (r Rule) String() string {
|
||||
from := "all"
|
||||
if r.Src.IsValid() && !r.Src.Addr().IsUnspecified() {
|
||||
from = r.Src.String()
|
||||
}
|
||||
|
||||
to := "all"
|
||||
if r.Dst.IsValid() && !r.Dst.Addr().IsUnspecified() {
|
||||
to = r.Dst.String()
|
||||
}
|
||||
|
||||
priority := ""
|
||||
if r.Priority != nil {
|
||||
priority = fmt.Sprintf(" %d", *r.Priority)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("ip rule%s: from %s to %s table %d",
|
||||
priority, from, to, r.Table)
|
||||
}
|
||||
|
||||
func ruleDbgMsg(add bool, rule Rule) (debugMessage string) {
|
||||
func (r Rule) debugMessage(add bool) (debugMessage string) {
|
||||
debugMessage = "ip"
|
||||
|
||||
switch rule.Family {
|
||||
switch r.Family {
|
||||
case FamilyV4:
|
||||
debugMessage += " -f inet"
|
||||
case FamilyV6:
|
||||
debugMessage += " -f inet6"
|
||||
default:
|
||||
debugMessage += " -f " + fmt.Sprint(rule.Family)
|
||||
debugMessage += " -f " + fmt.Sprint(r.Family)
|
||||
}
|
||||
|
||||
debugMessage += " rule"
|
||||
@@ -96,20 +101,20 @@ func ruleDbgMsg(add bool, rule Rule) (debugMessage string) {
|
||||
debugMessage += " del"
|
||||
}
|
||||
|
||||
if rule.Src.IsValid() {
|
||||
debugMessage += " from " + rule.Src.String()
|
||||
if r.Src.IsValid() {
|
||||
debugMessage += " from " + r.Src.String()
|
||||
}
|
||||
|
||||
if rule.Dst.IsValid() {
|
||||
debugMessage += " to " + rule.Dst.String()
|
||||
if r.Dst.IsValid() {
|
||||
debugMessage += " to " + r.Dst.String()
|
||||
}
|
||||
|
||||
if rule.Table != 0 {
|
||||
debugMessage += " lookup " + fmt.Sprint(rule.Table)
|
||||
if r.Table != 0 {
|
||||
debugMessage += " lookup " + fmt.Sprint(r.Table)
|
||||
}
|
||||
|
||||
if rule.Priority != -1 {
|
||||
debugMessage += " pref " + fmt.Sprint(rule.Priority)
|
||||
if r.Priority != nil {
|
||||
debugMessage += " pref " + fmt.Sprint(*r.Priority)
|
||||
}
|
||||
|
||||
return debugMessage
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/jsimonetti/rtnetlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
FlagInvert = unix.FIB_RULE_INVERT
|
||||
ActionToTable = unix.FR_ACT_TO_TBL
|
||||
)
|
||||
|
||||
func (n *NetLink) RuleList(family uint8) (rules []Rule, err error) {
|
||||
switch family {
|
||||
case FamilyAll:
|
||||
n.debugLogger.Debug("ip -4 rule list")
|
||||
n.debugLogger.Debug("ip -6 rule list")
|
||||
case FamilyV4:
|
||||
n.debugLogger.Debug("ip -4 rule list")
|
||||
case FamilyV6:
|
||||
n.debugLogger.Debug("ip -6 rule list")
|
||||
}
|
||||
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ruleMessages, err := conn.Rule.List()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rules = make([]Rule, 0, len(ruleMessages))
|
||||
for _, message := range ruleMessages {
|
||||
if family != FamilyAll && family != message.Family {
|
||||
continue
|
||||
}
|
||||
var rule Rule
|
||||
rule.fromMessage(message)
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
func (n *NetLink) RuleAdd(rule Rule) error {
|
||||
n.debugLogger.Debug(rule.debugMessage(true))
|
||||
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
return conn.Rule.Add(rule.message())
|
||||
}
|
||||
|
||||
func (n *NetLink) RuleDel(rule Rule) error {
|
||||
n.debugLogger.Debug(rule.debugMessage(false))
|
||||
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
return conn.Rule.Delete(rule.message())
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_ruleDbgMsg(t *testing.T) {
|
||||
func Test_Rule_debugMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
@@ -15,7 +15,7 @@ func Test_ruleDbgMsg(t *testing.T) {
|
||||
dbgMsg string
|
||||
}{
|
||||
"default values": {
|
||||
dbgMsg: "ip -f 0 rule del pref 0",
|
||||
dbgMsg: "ip -f 0 rule del",
|
||||
},
|
||||
"add rule": {
|
||||
add: true,
|
||||
@@ -24,7 +24,7 @@ func Test_ruleDbgMsg(t *testing.T) {
|
||||
Src: makeNetipPrefix(1),
|
||||
Dst: makeNetipPrefix(2),
|
||||
Table: 100,
|
||||
Priority: 101,
|
||||
Priority: ptrTo(uint32(101)),
|
||||
},
|
||||
dbgMsg: "ip -f inet rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
|
||||
},
|
||||
@@ -34,7 +34,7 @@ func Test_ruleDbgMsg(t *testing.T) {
|
||||
Src: makeNetipPrefix(1),
|
||||
Dst: makeNetipPrefix(2),
|
||||
Table: 100,
|
||||
Priority: 101,
|
||||
Priority: ptrTo(uint32(101)),
|
||||
},
|
||||
dbgMsg: "ip -f inet rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
|
||||
},
|
||||
@@ -44,7 +44,7 @@ func Test_ruleDbgMsg(t *testing.T) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbgMsg := ruleDbgMsg(testCase.add, testCase.rule)
|
||||
dbgMsg := testCase.rule.debugMessage(testCase.add)
|
||||
|
||||
assert.Equal(t, testCase.dbgMsg, dbgMsg)
|
||||
})
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
package netlink
|
||||
|
||||
func NewRule() Rule {
|
||||
return Rule{}
|
||||
}
|
||||
|
||||
func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) RuleAdd(rule Rule) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (n *NetLink) RuleDel(rule Rule) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type Addr struct {
|
||||
Network netip.Prefix
|
||||
}
|
||||
|
||||
func (a Addr) String() string {
|
||||
return a.Network.String()
|
||||
}
|
||||
|
||||
type Link struct {
|
||||
Type string
|
||||
Name string
|
||||
Index int
|
||||
EncapType string
|
||||
MTU uint16
|
||||
}
|
||||
|
||||
type Route struct {
|
||||
LinkIndex int
|
||||
Dst netip.Prefix
|
||||
Src netip.Addr
|
||||
Gw netip.Addr
|
||||
Priority int
|
||||
Family int
|
||||
Table int
|
||||
Type int
|
||||
}
|
||||
|
||||
type Rule struct {
|
||||
Priority int
|
||||
Family int
|
||||
Table int
|
||||
Mark uint32
|
||||
Src netip.Prefix
|
||||
Dst netip.Prefix
|
||||
Invert bool
|
||||
}
|
||||
|
||||
func (r Rule) String() string {
|
||||
from := "all"
|
||||
if r.Src.IsValid() && !r.Src.Addr().IsUnspecified() {
|
||||
from = r.Src.String()
|
||||
}
|
||||
|
||||
to := "all"
|
||||
if r.Dst.IsValid() && !r.Dst.Addr().IsUnspecified() {
|
||||
to = r.Dst.String()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("ip rule %d: from %s to %s table %d",
|
||||
r.Priority, from, to, r.Table)
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/mod"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func (n *NetLink) IsWireguardSupported() bool {
|
||||
// Check for Wireguard family without loading the wireguard module.
|
||||
// Some kernels have the wireguard module built-in, and don't have a
|
||||
// modules directory, such as WSL2 kernels.
|
||||
ok := hasWireguardFamily()
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
|
||||
// Try loading the wireguard module, since some systems do not load
|
||||
// it after a boot. If this fails, wireguard is assumed to not be supported.
|
||||
n.debugLogger.Debugf("wireguard family not found, trying to load wireguard kernel module")
|
||||
err := mod.Probe("wireguard")
|
||||
if err != nil {
|
||||
n.debugLogger.Debugf("failed loading wireguard kernel module: %s", err)
|
||||
return false
|
||||
}
|
||||
n.debugLogger.Debugf("wireguard kernel module loaded successfully")
|
||||
|
||||
// Re-check if the Wireguard family is now available, after loading
|
||||
// the wireguard kernel module.
|
||||
return hasWireguardFamily()
|
||||
}
|
||||
|
||||
func hasWireguardFamily() bool {
|
||||
_, err := netlink.GenlFamilyGet("wireguard")
|
||||
return err == nil
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/mdlayher/genetlink"
|
||||
"github.com/qdm12/gluetun/internal/mod"
|
||||
)
|
||||
|
||||
func (n *NetLink) IsWireguardSupported() (ok bool, err error) {
|
||||
// Check for Wireguard family without loading the wireguard module.
|
||||
// Some kernels have the wireguard module built-in, and don't have a
|
||||
// modules directory, such as WSL2 kernels.
|
||||
ok, err = hasWireguardFamily()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("checking wireguard family: %w", err)
|
||||
} else if ok {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Try loading the wireguard module, since some systems do not load
|
||||
// it after a boot. If this fails, wireguard is assumed to not be supported.
|
||||
n.debugLogger.Debugf("wireguard family not found, trying to load wireguard kernel module")
|
||||
err = mod.Probe("wireguard")
|
||||
if err != nil {
|
||||
n.debugLogger.Debugf("failed loading wireguard kernel module: %s", err)
|
||||
return false, nil
|
||||
}
|
||||
n.debugLogger.Debugf("wireguard kernel module loaded successfully")
|
||||
|
||||
// Re-check if the Wireguard family is now available, after loading
|
||||
// the wireguard kernel module.
|
||||
ok, err = hasWireguardFamily()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("checking wireguard family: %w", err)
|
||||
}
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func hasWireguardFamily() (ok bool, err error) {
|
||||
conn, err := genetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("dialing netlink: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.GetFamily("wireguard")
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("getting wireguard family: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
@@ -4,6 +4,8 @@ package netlink
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_NetLink_IsWireguardSupported(t *testing.T) {
|
||||
@@ -12,7 +14,8 @@ func Test_NetLink_IsWireguardSupported(t *testing.T) {
|
||||
netLink := &NetLink{
|
||||
debugLogger: &noopLogger{},
|
||||
}
|
||||
ok := netLink.IsWireguardSupported()
|
||||
ok, err := netLink.IsWireguardSupported()
|
||||
require.NoError(t, err)
|
||||
if ok { // cannot assert since this depends on kernel
|
||||
t.Log("wireguard is supported")
|
||||
} else {
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
package netlink
|
||||
|
||||
func (n *NetLink) IsWireguardSupported() (ok bool, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/openvpn"
|
||||
)
|
||||
@@ -33,7 +32,7 @@ func start(ctx context.Context, starter CmdStarter, version string, flags []stri
|
||||
args := []string{"--config", configPath}
|
||||
args = append(args, flags...)
|
||||
cmd := exec.CommandContext(ctx, bin, args...)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
setCmdSysProcAttr(cmd)
|
||||
|
||||
return starter.Start(cmd)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
package openvpn
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func setCmdSysProcAttr(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
//go:build !linux
|
||||
|
||||
package openvpn
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func setCmdSysProcAttr(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{}
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
var _ net.PacketConn = &ipv4Wrapper{}
|
||||
|
||||
// ipv4Wrapper is a wrapper around ipv4.PacketConn to implement
|
||||
// the net.PacketConn interface. It's only used for Darwin or iOS.
|
||||
type ipv4Wrapper struct {
|
||||
ipv4Conn *ipv4.PacketConn
|
||||
}
|
||||
|
||||
func ipv4ToNetPacketConn(ipv4 *ipv4.PacketConn) *ipv4Wrapper {
|
||||
return &ipv4Wrapper{ipv4Conn: ipv4}
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, _, addr, err = i.ipv4Conn.ReadFrom(p)
|
||||
return n, addr, err
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
return i.ipv4Conn.WriteTo(p, nil, addr)
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) Close() error {
|
||||
return i.ipv4Conn.Close()
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) LocalAddr() net.Addr {
|
||||
return i.ipv4Conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) SetDeadline(t time.Time) error {
|
||||
return i.ipv4Conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) SetReadDeadline(t time.Time) error {
|
||||
return i.ipv4Conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) SetWriteDeadline(t time.Time) error {
|
||||
return i.ipv4Conn.SetWriteDeadline(t)
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrICMPNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
|
||||
ErrICMPNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
|
||||
)
|
||||
|
||||
func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) {
|
||||
switch {
|
||||
case mtu < minMTU:
|
||||
return fmt.Errorf("%w: %d", ErrICMPNextHopMTUTooLow, mtu)
|
||||
case mtu > physicalLinkMTU:
|
||||
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
|
||||
ErrICMPNextHopMTUTooHigh, mtu, physicalLinkMTU)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
|
||||
outboundMessage *icmp.Message,
|
||||
) (match bool, err error) {
|
||||
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parsing invoking packet: %w", err)
|
||||
}
|
||||
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
|
||||
}
|
||||
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
||||
return inboundBody.ID == outboundBody.ID, nil
|
||||
}
|
||||
|
||||
var ErrICMPIDMismatch = errors.New("ICMP id mismatch")
|
||||
|
||||
func checkEchoReply(icmpProtocol int, received []byte,
|
||||
outboundMessage *icmp.Message, truncatedBody bool,
|
||||
) (err error) {
|
||||
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing invoking packet: %w", err)
|
||||
}
|
||||
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
|
||||
}
|
||||
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
||||
if inboundBody.ID != outboundBody.ID {
|
||||
return fmt.Errorf("%w: sent id %d and received id %d",
|
||||
ErrICMPIDMismatch, outboundBody.ID, inboundBody.ID)
|
||||
}
|
||||
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking sent and received bodies: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")
|
||||
|
||||
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
|
||||
if len(received) > len(sent) {
|
||||
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
|
||||
ErrICMPEchoDataMismatch, len(sent), len(received))
|
||||
}
|
||||
if receivedTruncated {
|
||||
sent = sent[:len(received)]
|
||||
}
|
||||
if !bytes.Equal(received, sent) {
|
||||
return fmt.Errorf("%w: sent %x and received %x",
|
||||
ErrICMPEchoDataMismatch, sent, received)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
//go:build !linux && !windows
|
||||
|
||||
package pmtud
|
||||
|
||||
// setDontFragment for platforms other than Linux and Windows
|
||||
// is not implemented, so we just return assuming the don't
|
||||
// fragment flag is set on IP packets.
|
||||
func setDontFragment(fd uintptr) (err error) {
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func setDontFragment(fd uintptr) (err error) {
|
||||
return syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP,
|
||||
syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_PROBE)
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
//go:build windows
|
||||
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func setDontFragment(fd uintptr) (err error) {
|
||||
// https://docs.microsoft.com/en-us/troubleshoot/windows/win32/header-library-requirement-socket-ipproto-ip
|
||||
// #define IP_DONTFRAGMENT 14 /* don't fragment IP datagrams */
|
||||
return syscall.SetsockoptInt(syscall.Handle(fd), syscall.IPPROTO_IP, 14, 1)
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrICMPNotPermitted = errors.New("ICMP not permitted")
|
||||
ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable")
|
||||
ErrICMPCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
|
||||
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
|
||||
)
|
||||
|
||||
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
|
||||
switch {
|
||||
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
|
||||
err = fmt.Errorf("%w", ErrICMPNotPermitted)
|
||||
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
|
||||
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
|
||||
case timedCtx.Err() != nil:
|
||||
err = timedCtx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package pmtud
|
||||
|
||||
type Logger interface {
|
||||
Debug(msg string)
|
||||
Debugf(msg string, args ...any)
|
||||
Warnf(msg string, args ...any)
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
const (
|
||||
// see https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
|
||||
minIPv4MTU uint32 = 68
|
||||
icmpv4Protocol int = 1
|
||||
)
|
||||
|
||||
func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
|
||||
var listenConfig net.ListenConfig
|
||||
listenConfig.Control = func(_, _ string, rawConn syscall.RawConn) error {
|
||||
var setDFErr error
|
||||
err := rawConn.Control(func(fd uintptr) {
|
||||
setDFErr = setDontFragment(fd) // runs when calling ListenPacket
|
||||
})
|
||||
if err == nil {
|
||||
err = setDFErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
const listenAddress = ""
|
||||
packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress)
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
|
||||
packetConn = ipv4ToNetPacketConn(ipv4.NewPacketConn(packetConn))
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
||||
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger,
|
||||
) (mtu uint32, err error) {
|
||||
if ip.Is6() {
|
||||
panic("IP address is not v4")
|
||||
}
|
||||
conn, err := listenICMPv4(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, pingTimeout)
|
||||
defer cancel()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
// First try to send a packet which is too big to get the maximum MTU
|
||||
// directly.
|
||||
outboundID, outboundMessage := buildMessageToSend("v4", physicalLinkMTU)
|
||||
encodedMessage, err := outboundMessage.Marshal(nil)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("encoding ICMP message: %w", err)
|
||||
}
|
||||
|
||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
||||
if err != nil {
|
||||
err = wrapConnErr(err, ctx, pingTimeout)
|
||||
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||
}
|
||||
|
||||
buffer := make([]byte, physicalLinkMTU)
|
||||
|
||||
for { // for loop in case we read an echo reply for another ICMP request
|
||||
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||
// must be large enough to read the entire reply packet. See:
|
||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
err = wrapConnErr(err, ctx, pingTimeout)
|
||||
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
|
||||
}
|
||||
packetBytes := buffer[:bytesRead]
|
||||
// Side note: echo reply should be at most the number of bytes
|
||||
// sent, and can be lower, more precisely 576-ipHeader bytes,
|
||||
// in case the next hop we are reaching replies with a destination
|
||||
// unreachable and wants to ensure the response makes it way back
|
||||
// by keeping a low packet size, see:
|
||||
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
|
||||
|
||||
inboundMessage, err := icmp.ParseMessage(icmpv4Protocol, packetBytes)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing message: %w", err)
|
||||
}
|
||||
|
||||
switch typedBody := inboundMessage.Body.(type) {
|
||||
case *icmp.DstUnreach:
|
||||
const fragmentationRequiredAndDFFlagSetCode = 4
|
||||
const communicationAdministrativelyProhibitedCode = 13
|
||||
switch inboundMessage.Code {
|
||||
case fragmentationRequiredAndDFFlagSetCode:
|
||||
case communicationAdministrativelyProhibitedCode:
|
||||
return 0, fmt.Errorf("%w: %w (code %d)",
|
||||
ErrICMPDestinationUnreachable,
|
||||
ErrICMPCommunicationAdministrativelyProhibited,
|
||||
inboundMessage.Code)
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: code %d",
|
||||
ErrICMPDestinationUnreachable, inboundMessage.Code)
|
||||
}
|
||||
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
|
||||
// Note: the go library does not handle this NextHopMTU section.
|
||||
nextHopMTU := packetBytes[6:8]
|
||||
mtu = uint32(binary.BigEndian.Uint16(nextHopMTU))
|
||||
err = checkMTU(mtu, minIPv4MTU, physicalLinkMTU)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking next-hop-mtu found: %w", err)
|
||||
}
|
||||
|
||||
// The code below is really for sanity checks
|
||||
packetBytes = packetBytes[8:]
|
||||
header, err := ipv4.ParseHeader(packetBytes)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing IPv4 header: %w", err)
|
||||
}
|
||||
packetBytes = packetBytes[header.Len:] // truncated original datagram
|
||||
|
||||
const truncated = true
|
||||
err = checkEchoReply(icmpv4Protocol, packetBytes, outboundMessage, truncated)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking echo reply: %w", err)
|
||||
}
|
||||
return mtu, nil
|
||||
case *icmp.Echo:
|
||||
inboundID := uint16(typedBody.ID) //nolint:gosec
|
||||
if inboundID == outboundID {
|
||||
return physicalLinkMTU, nil
|
||||
}
|
||||
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
|
||||
inboundID, outboundID)
|
||||
continue
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
const (
|
||||
minIPv6MTU = 1280
|
||||
icmpv6Protocol = 58
|
||||
)
|
||||
|
||||
func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) {
|
||||
var listenConfig net.ListenConfig
|
||||
const listenAddress = ""
|
||||
packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress)
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
||||
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger,
|
||||
) (mtu uint32, err error) {
|
||||
if ip.Is4() {
|
||||
panic("IP address is not v6")
|
||||
}
|
||||
conn, err := listenICMPv6(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, pingTimeout)
|
||||
defer cancel()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
// First try to send a packet which is too big to get the maximum MTU
|
||||
// directly.
|
||||
outboundID, outboundMessage := buildMessageToSend("v6", physicalLinkMTU)
|
||||
encodedMessage, err := outboundMessage.Marshal(nil)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("encoding ICMP message: %w", err)
|
||||
}
|
||||
|
||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice(), Zone: ip.Zone()})
|
||||
if err != nil {
|
||||
err = wrapConnErr(err, ctx, pingTimeout)
|
||||
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||
}
|
||||
|
||||
buffer := make([]byte, physicalLinkMTU)
|
||||
|
||||
for { // for loop if we encounter another ICMP packet with an unknown id.
|
||||
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||
// must be large enough to read the entire reply packet. See:
|
||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
err = wrapConnErr(err, ctx, pingTimeout)
|
||||
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
|
||||
}
|
||||
packetBytes := buffer[:bytesRead]
|
||||
|
||||
packetBytes = packetBytes[ipv6.HeaderLen:]
|
||||
|
||||
inboundMessage, err := icmp.ParseMessage(icmpv6Protocol, packetBytes)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing message: %w", err)
|
||||
}
|
||||
|
||||
switch typedBody := inboundMessage.Body.(type) {
|
||||
case *icmp.PacketTooBig:
|
||||
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.2
|
||||
mtu = uint32(typedBody.MTU) //nolint:gosec
|
||||
err = checkMTU(mtu, minIPv6MTU, physicalLinkMTU)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking MTU: %w", err)
|
||||
}
|
||||
|
||||
// Sanity checks
|
||||
const truncatedBody = true
|
||||
err = checkEchoReply(icmpv6Protocol, typedBody.Data, outboundMessage, truncatedBody)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking invoking message: %w", err)
|
||||
}
|
||||
return uint32(typedBody.MTU), nil //nolint:gosec
|
||||
case *icmp.DstUnreach:
|
||||
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.1
|
||||
idMatch, err := checkInvokingReplyIDMatch(icmpv6Protocol, packetBytes, outboundMessage)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking invoking message id: %w", err)
|
||||
} else if idMatch {
|
||||
return 0, fmt.Errorf("%w", ErrICMPDestinationUnreachable)
|
||||
}
|
||||
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
|
||||
continue
|
||||
case *icmp.Echo:
|
||||
inboundID := uint16(typedBody.ID) //nolint:gosec
|
||||
if inboundID == outboundID {
|
||||
return physicalLinkMTU, nil
|
||||
}
|
||||
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
|
||||
inboundID, outboundID)
|
||||
continue
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
cryptorand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
func buildMessageToSend(ipVersion string, mtu uint32) (id uint16, message *icmp.Message) {
|
||||
var seed [32]byte
|
||||
_, _ = cryptorand.Read(seed[:])
|
||||
randomSource := rand.NewChaCha8(seed)
|
||||
|
||||
const uint16Bytes = 2
|
||||
idBytes := make([]byte, uint16Bytes)
|
||||
_, _ = randomSource.Read(idBytes)
|
||||
id = binary.BigEndian.Uint16(idBytes)
|
||||
|
||||
var ipHeaderLength uint32
|
||||
var icmpType icmp.Type
|
||||
switch ipVersion {
|
||||
case "v4":
|
||||
ipHeaderLength = ipv4.HeaderLen
|
||||
icmpType = ipv4.ICMPTypeEcho
|
||||
case "v6":
|
||||
ipHeaderLength = ipv6.HeaderLen
|
||||
icmpType = ipv6.ICMPTypeEchoRequest
|
||||
default:
|
||||
panic(fmt.Sprintf("IP version %q not supported", ipVersion))
|
||||
}
|
||||
const pingHeaderLength = 0 +
|
||||
1 + // type
|
||||
1 + // code
|
||||
2 + // checksum
|
||||
2 + // identifier
|
||||
2 // sequence number
|
||||
pingBodyDataSize := mtu - ipHeaderLength - pingHeaderLength
|
||||
messageBodyData := make([]byte, pingBodyDataSize)
|
||||
_, _ = randomSource.Read(messageBodyData)
|
||||
|
||||
// See https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml#icmp-parameters-types
|
||||
message = &icmp.Message{
|
||||
Type: icmpType, // echo request
|
||||
Code: 0, // no code
|
||||
Checksum: 0, // calculated at encoding (ipv4) or sending (ipv6)
|
||||
Body: &icmp.Echo{
|
||||
ID: int(id),
|
||||
Seq: 0, // only one packet
|
||||
Data: messageBodyData,
|
||||
},
|
||||
}
|
||||
return id, message
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package pmtud
|
||||
|
||||
type noopLogger struct{}
|
||||
|
||||
func (noopLogger) Debug(_ string) {}
|
||||
func (noopLogger) Debugf(_ string, _ ...any) {}
|
||||
func (noopLogger) Warnf(_ string, _ ...any) {}
|
||||
@@ -0,0 +1,271 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
)
|
||||
|
||||
var ErrMTUNotFound = errors.New("path MTU discovery failed to find MTU")
|
||||
|
||||
// PathMTUDiscover discovers the maximum MTU for the path to the given ip address.
|
||||
// If the physicalLinkMTU is zero, it defaults to 1500 which is the ethernet standard MTU.
|
||||
// If the pingTimeout is zero, it defaults to 1 second.
|
||||
// If the logger is nil, a no-op logger is used.
|
||||
// It returns [ErrMTUNotFound] if the MTU could not be determined.
|
||||
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
|
||||
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger) (
|
||||
mtu uint32, err error,
|
||||
) {
|
||||
if physicalLinkMTU == 0 {
|
||||
const ethernetStandardMTU = 1500
|
||||
physicalLinkMTU = ethernetStandardMTU
|
||||
}
|
||||
if pingTimeout == 0 {
|
||||
pingTimeout = time.Second
|
||||
}
|
||||
if logger == nil {
|
||||
logger = &noopLogger{}
|
||||
}
|
||||
|
||||
if ip.Is4() {
|
||||
logger.Debug("finding IPv4 next hop MTU")
|
||||
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, pingTimeout, logger)
|
||||
switch {
|
||||
case err == nil:
|
||||
return mtu, nil
|
||||
case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrICMPCommunicationAdministrativelyProhibited): // blackhole
|
||||
default:
|
||||
return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err)
|
||||
}
|
||||
} else {
|
||||
logger.Debug("requesting IPv6 ICMP packet-too-big reply")
|
||||
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, pingTimeout, logger)
|
||||
switch {
|
||||
case err == nil:
|
||||
return mtu, nil
|
||||
case errors.Is(err, net.ErrClosed): // blackhole
|
||||
default:
|
||||
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back method: send echo requests with different packet
|
||||
// sizes and check which ones succeed to find the maximum MTU.
|
||||
logger.Debug("falling back to sending different sized echo packets")
|
||||
minMTU := minIPv4MTU
|
||||
if ip.Is6() {
|
||||
minMTU = minIPv6MTU
|
||||
}
|
||||
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, pingTimeout, logger)
|
||||
}
|
||||
|
||||
type pmtudTestUnit struct {
|
||||
mtu uint32
|
||||
echoID uint16
|
||||
sentBytes int
|
||||
ok bool
|
||||
}
|
||||
|
||||
func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
|
||||
minMTU, maxPossibleMTU uint32, pingTimeout time.Duration,
|
||||
logger Logger,
|
||||
) (maxMTU uint32, err error) {
|
||||
var ipVersion string
|
||||
var conn net.PacketConn
|
||||
if ip.Is4() {
|
||||
ipVersion = "v4"
|
||||
conn, err = listenICMPv4(ctx)
|
||||
} else {
|
||||
ipVersion = "v6"
|
||||
conn, err = listenICMPv6(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
|
||||
}
|
||||
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
|
||||
}
|
||||
|
||||
mtusToTest := makeMTUsToTest(minMTU, maxPossibleMTU)
|
||||
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
|
||||
return minMTU, nil
|
||||
}
|
||||
logger.Debugf("testing the following MTUs: %v", mtusToTest)
|
||||
|
||||
tests := make([]pmtudTestUnit, len(mtusToTest))
|
||||
for i := range mtusToTest {
|
||||
tests[i] = pmtudTestUnit{mtu: mtusToTest[i]}
|
||||
}
|
||||
|
||||
timedCtx, cancel := context.WithTimeout(ctx, pingTimeout)
|
||||
defer cancel()
|
||||
go func() {
|
||||
<-timedCtx.Done()
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
for i := range tests {
|
||||
id, message := buildMessageToSend(ipVersion, tests[i].mtu)
|
||||
tests[i].echoID = id
|
||||
|
||||
encodedMessage, err := message.Marshal(nil)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("encoding ICMP message: %w", err)
|
||||
}
|
||||
tests[i].sentBytes = len(encodedMessage)
|
||||
|
||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
|
||||
err = fmt.Errorf("%w", ErrICMPNotPermitted)
|
||||
}
|
||||
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = collectReplies(conn, ipVersion, tests, logger)
|
||||
switch {
|
||||
case err == nil: // max possible MTU is working
|
||||
return tests[len(tests)-1].mtu, nil
|
||||
case err != nil && errors.Is(err, net.ErrClosed):
|
||||
// we have timeouts (IPv4 testing or IPv6 PMTUD blackholes)
|
||||
// so find the highest MTU which worked.
|
||||
// Note we start from index len(tests) - 2 since the max MTU
|
||||
// cannot be working if we had a timeout.
|
||||
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
|
||||
if tests[i].ok {
|
||||
return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1,
|
||||
pingTimeout, logger)
|
||||
}
|
||||
}
|
||||
|
||||
// All MTUs failed.
|
||||
return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound)
|
||||
case err != nil:
|
||||
return 0, fmt.Errorf("collecting ICMP echo replies: %w", err)
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
// Create the MTU slice of length 11 such that:
|
||||
// - the first element is the minMTU
|
||||
// - the last element is the maxMTU
|
||||
// - elements in-between are separated as close to each other
|
||||
// The number 11 is chosen to find the final MTU in 3 searches,
|
||||
// with a total search space of 1728 MTUs which is enough;
|
||||
// to find it in 2 searches requires 37 parallel queries which
|
||||
// could be blocked by firewalls.
|
||||
func makeMTUsToTest(minMTU, maxMTU uint32) (mtus []uint32) {
|
||||
const mtusLength = 11 // find the final MTU in 3 searches
|
||||
diff := maxMTU - minMTU
|
||||
switch {
|
||||
case minMTU > maxMTU:
|
||||
panic("minMTU > maxMTU")
|
||||
case diff <= mtusLength:
|
||||
mtus = make([]uint32, 0, diff)
|
||||
for mtu := minMTU; mtu <= maxMTU; mtu++ {
|
||||
mtus = append(mtus, mtu)
|
||||
}
|
||||
default:
|
||||
step := float64(diff) / float64(mtusLength-1)
|
||||
mtus = make([]uint32, 0, mtusLength)
|
||||
for mtu := float64(minMTU); len(mtus) < mtusLength-1; mtu += step {
|
||||
mtus = append(mtus, uint32(math.Round(mtu)))
|
||||
}
|
||||
mtus = append(mtus, maxMTU) // last element is the maxMTU
|
||||
}
|
||||
|
||||
return mtus
|
||||
}
|
||||
|
||||
func collectReplies(conn net.PacketConn, ipVersion string,
|
||||
tests []pmtudTestUnit, logger Logger,
|
||||
) (err error) {
|
||||
echoIDToTestIndex := make(map[uint16]int, len(tests))
|
||||
for i, test := range tests {
|
||||
echoIDToTestIndex[test.echoID] = i
|
||||
}
|
||||
|
||||
// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would
|
||||
// create huge buffers which we don't really want to support anyway.
|
||||
// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with
|
||||
// a conventional maximum of 9000 bytes. However, some manufacturers support up
|
||||
// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to
|
||||
// match eventual Jumbo frames. More information at:
|
||||
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
|
||||
const maxPossibleMTU = 9196
|
||||
buffer := make([]byte, maxPossibleMTU)
|
||||
|
||||
idsFound := 0
|
||||
for idsFound < len(tests) {
|
||||
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||
// must be large enough to read the entire reply packet. See:
|
||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading from ICMP connection: %w", err)
|
||||
}
|
||||
packetBytes := buffer[:bytesRead]
|
||||
|
||||
ipPacketLength := len(packetBytes)
|
||||
|
||||
var icmpProtocol int
|
||||
switch ipVersion {
|
||||
case "v4":
|
||||
icmpProtocol = icmpv4Protocol
|
||||
case "v6":
|
||||
icmpProtocol = icmpv6Protocol
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
|
||||
}
|
||||
|
||||
// Parse the ICMP message
|
||||
// Note: this parsing works for a truncated 556 bytes ICMP reply packet.
|
||||
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing message: %w", err)
|
||||
}
|
||||
|
||||
echoBody, ok := message.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, message.Body)
|
||||
}
|
||||
|
||||
id := uint16(echoBody.ID) //nolint:gosec
|
||||
testIndex, testing := echoIDToTestIndex[id]
|
||||
if !testing { // not an id we expected so ignore it
|
||||
logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)",
|
||||
echoBody.ID, message.Type, message.Code, ipPacketLength)
|
||||
continue
|
||||
}
|
||||
idsFound++
|
||||
sentBytes := tests[testIndex].sentBytes
|
||||
|
||||
// echo reply should be at most the number of bytes sent,
|
||||
// and can be lower, more precisely 556 bytes, in case
|
||||
// the host we are reaching wants to stay out of trouble
|
||||
// and ensure its echo reply goes through without
|
||||
// fragmentation, see the following page:
|
||||
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
|
||||
const conservativeReplyLength = 556
|
||||
truncated := ipPacketLength < sentBytes &&
|
||||
ipPacketLength == conservativeReplyLength
|
||||
// Check the packet size is the same if the reply is not truncated
|
||||
if !truncated && sentBytes != ipPacketLength {
|
||||
return fmt.Errorf("%w: sent %dB and received %dB",
|
||||
ErrICMPEchoDataMismatch, sentBytes, ipPacketLength)
|
||||
}
|
||||
// Truncated reply or matching reply size
|
||||
tests[testIndex].ok = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
//go:build integration
|
||||
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_PathMTUDiscover(t *testing.T) {
|
||||
t.Parallel()
|
||||
const physicalLinkMTU = 1500
|
||||
const timeout = time.Second
|
||||
mtu, err := PathMTUDiscover(context.Background(), netip.MustParseAddr("1.1.1.1"),
|
||||
physicalLinkMTU, timeout, nil)
|
||||
require.NoError(t, err)
|
||||
t.Log("MTU found:", mtu)
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_makeMTUsToTest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
minMTU uint32
|
||||
maxMTU uint32
|
||||
mtus []uint32
|
||||
}{
|
||||
"0_0": {
|
||||
mtus: []uint32{0},
|
||||
},
|
||||
"0_1": {
|
||||
maxMTU: 1,
|
||||
mtus: []uint32{0, 1},
|
||||
},
|
||||
"0_8": {
|
||||
maxMTU: 8,
|
||||
mtus: []uint32{0, 1, 2, 3, 4, 5, 6, 7, 8},
|
||||
},
|
||||
"0_12": {
|
||||
maxMTU: 12,
|
||||
mtus: []uint32{0, 1, 2, 4, 5, 6, 7, 8, 10, 11, 12},
|
||||
},
|
||||
"0_80": {
|
||||
maxMTU: 80,
|
||||
mtus: []uint32{0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80},
|
||||
},
|
||||
"0_100": {
|
||||
maxMTU: 100,
|
||||
mtus: []uint32{0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100},
|
||||
},
|
||||
"1280_1500": {
|
||||
minMTU: 1280,
|
||||
maxMTU: 1500,
|
||||
mtus: []uint32{1280, 1302, 1324, 1346, 1368, 1390, 1412, 1434, 1456, 1478, 1500},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mtus := makeMTUsToTest(testCase.minMTU, testCase.maxMTU)
|
||||
assert.Equal(t, testCase.mtus, mtus)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,7 @@ type Service interface {
|
||||
|
||||
type Routing interface {
|
||||
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
|
||||
AssignedIP(interfaceName string, family int) (ip netip.Addr, err error)
|
||||
AssignedIP(interfaceName string, family uint8) (ip netip.Addr, err error)
|
||||
}
|
||||
|
||||
type PortAllower interface {
|
||||
|
||||
@@ -17,7 +17,7 @@ type PortAllower interface {
|
||||
|
||||
type Routing interface {
|
||||
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
|
||||
AssignedIP(interfaceName string, family int) (ip netip.Addr, err error)
|
||||
AssignedIP(interfaceName string, family uint8) (ip netip.Addr, err error)
|
||||
}
|
||||
|
||||
type Logger interface {
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var ErrRouteDefaultNotFound = errors.New("default route not found")
|
||||
@@ -15,7 +14,7 @@ type DefaultRoute struct {
|
||||
NetInterface string
|
||||
Gateway netip.Addr
|
||||
AssignedIP netip.Addr
|
||||
Family int
|
||||
Family uint8
|
||||
}
|
||||
|
||||
func (d DefaultRoute) String() string {
|
||||
@@ -30,7 +29,7 @@ func (r *Routing) DefaultRoutes() (defaultRoutes []DefaultRoute, err error) {
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
if route.Table != unix.RT_TABLE_MAIN {
|
||||
if route.Table != tableMain {
|
||||
// ignore non-main table
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
inboundTable = 200
|
||||
inboundPriority = 100
|
||||
inboundTable uint32 = 200
|
||||
inboundPriority uint32 = 100
|
||||
)
|
||||
|
||||
func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err error) {
|
||||
@@ -60,7 +60,7 @@ func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err e
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) {
|
||||
func (r *Routing) addRuleInboundFromDefault(table uint32, defaultRoutes []DefaultRoute) (err error) {
|
||||
for _, defaultRoute := range defaultRoutes {
|
||||
assignedIP := defaultRoute.AssignedIP
|
||||
bits := 32
|
||||
@@ -78,7 +78,7 @@ func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRo
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Routing) delRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) {
|
||||
func (r *Routing) delRuleInboundFromDefault(table uint32, defaultRoutes []DefaultRoute) (err error) {
|
||||
for _, defaultRoute := range defaultRoutes {
|
||||
assignedIP := defaultRoute.AssignedIP
|
||||
bits := 32
|
||||
|
||||
@@ -16,12 +16,12 @@ func ipIsPrivate(ip netip.Addr) bool {
|
||||
|
||||
var errInterfaceIPNotFound = errors.New("IP address not found for interface")
|
||||
|
||||
func ipMatchesFamily(ip netip.Addr, family int) bool {
|
||||
func ipMatchesFamily(ip netip.Addr, family uint8) bool {
|
||||
return (family == netlink.FamilyV4 && ip.Is4()) ||
|
||||
(family == netlink.FamilyV6 && ip.Is6())
|
||||
}
|
||||
|
||||
func (r *Routing) AssignedIP(interfaceName string, family int) (ip netip.Addr, err error) {
|
||||
func (r *Routing) AssignedIP(interfaceName string, family uint8) (ip netip.Addr, err error) {
|
||||
iface, err := net.InterfaceByName(interfaceName)
|
||||
if err != nil {
|
||||
return ip, fmt.Errorf("network interface %s not found: %w", interfaceName, err)
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -27,10 +26,10 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
|
||||
return localNetworks, fmt.Errorf("listing links: %w", err)
|
||||
}
|
||||
|
||||
localLinks := make(map[int]struct{})
|
||||
localLinks := make(map[uint32]struct{})
|
||||
|
||||
for _, link := range links {
|
||||
if link.EncapType != "ether" {
|
||||
if link.DeviceType != netlink.DeviceTypeEthernet {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -48,7 +47,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
if route.Table != unix.RT_TABLE_MAIN ||
|
||||
if route.Table != tableMain ||
|
||||
(route.Gw.IsValid() && !route.Gw.IsUnspecified()) ||
|
||||
(route.Dst.IsValid() && route.Dst.Addr().IsUnspecified()) {
|
||||
continue
|
||||
@@ -96,7 +95,7 @@ func (r *Routing) AddLocalRules(subnets []LocalNetwork) (err error) {
|
||||
|
||||
// Local has higher priority then outbound(99) and inbound(100) as the
|
||||
// local routes might be necessary to reach the outbound/inbound routes.
|
||||
const localPriority = 98
|
||||
const localPriority uint32 = 98
|
||||
|
||||
// Main table was setup correctly by Docker, just need to add rules to use it
|
||||
src := netip.Prefix{}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
netip "net/netip"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
@@ -35,10 +36,10 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
|
||||
}
|
||||
|
||||
// AddrList mocks base method.
|
||||
func (m *MockNetLinker) AddrList(arg0 netlink.Link, arg1 int) ([]netlink.Addr, error) {
|
||||
func (m *MockNetLinker) AddrList(arg0 uint32, arg1 byte) ([]netip.Prefix, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddrList", arg0, arg1)
|
||||
ret0, _ := ret[0].([]netlink.Addr)
|
||||
ret0, _ := ret[0].([]netip.Prefix)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -50,7 +51,7 @@ func (mr *MockNetLinkerMockRecorder) AddrList(arg0, arg1 interface{}) *gomock.Ca
|
||||
}
|
||||
|
||||
// AddrReplace mocks base method.
|
||||
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error {
|
||||
func (m *MockNetLinker) AddrReplace(arg0 uint32, arg1 netip.Prefix) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
@@ -64,10 +65,10 @@ func (mr *MockNetLinkerMockRecorder) AddrReplace(arg0, arg1 interface{}) *gomock
|
||||
}
|
||||
|
||||
// LinkAdd mocks base method.
|
||||
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) {
|
||||
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (uint32, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkAdd", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
ret0, _ := ret[0].(uint32)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -79,7 +80,7 @@ func (mr *MockNetLinkerMockRecorder) LinkAdd(arg0 interface{}) *gomock.Call {
|
||||
}
|
||||
|
||||
// LinkByIndex mocks base method.
|
||||
func (m *MockNetLinker) LinkByIndex(arg0 int) (netlink.Link, error) {
|
||||
func (m *MockNetLinker) LinkByIndex(arg0 uint32) (netlink.Link, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkByIndex", arg0)
|
||||
ret0, _ := ret[0].(netlink.Link)
|
||||
@@ -109,7 +110,7 @@ func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call {
|
||||
}
|
||||
|
||||
// LinkDel mocks base method.
|
||||
func (m *MockNetLinker) LinkDel(arg0 netlink.Link) error {
|
||||
func (m *MockNetLinker) LinkDel(arg0 uint32) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkDel", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
@@ -138,7 +139,7 @@ func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call {
|
||||
}
|
||||
|
||||
// LinkSetDown mocks base method.
|
||||
func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error {
|
||||
func (m *MockNetLinker) LinkSetDown(arg0 uint32) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkSetDown", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
@@ -152,12 +153,11 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
|
||||
}
|
||||
|
||||
// LinkSetUp mocks base method.
|
||||
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) {
|
||||
func (m *MockNetLinker) LinkSetUp(arg0 uint32) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// LinkSetUp indicates an expected call of LinkSetUp.
|
||||
@@ -195,7 +195,7 @@ func (mr *MockNetLinkerMockRecorder) RouteDel(arg0 interface{}) *gomock.Call {
|
||||
}
|
||||
|
||||
// RouteList mocks base method.
|
||||
func (m *MockNetLinker) RouteList(arg0 int) ([]netlink.Route, error) {
|
||||
func (m *MockNetLinker) RouteList(arg0 byte) ([]netlink.Route, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RouteList", arg0)
|
||||
ret0, _ := ret[0].([]netlink.Route)
|
||||
@@ -252,7 +252,7 @@ func (mr *MockNetLinkerMockRecorder) RuleDel(arg0 interface{}) *gomock.Call {
|
||||
}
|
||||
|
||||
// RuleList mocks base method.
|
||||
func (m *MockNetLinker) RuleList(arg0 int) ([]netlink.Rule, error) {
|
||||
func (m *MockNetLinker) RuleList(arg0 byte) ([]netlink.Rule, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RuleList", arg0)
|
||||
ret0, _ := ret[0].([]netlink.Rule)
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
outboundTable = 199
|
||||
outboundPriority = 99
|
||||
outboundTable uint32 = 199
|
||||
outboundPriority uint32 = 99
|
||||
)
|
||||
|
||||
func (r *Routing) SetOutboundRoutes(outboundSubnets []netip.Prefix) error {
|
||||
|
||||
@@ -9,25 +9,33 @@ import (
|
||||
)
|
||||
|
||||
func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr,
|
||||
iface string, table int,
|
||||
iface string, table uint32,
|
||||
) error {
|
||||
destinationStr := destination.String()
|
||||
r.logger.Info("adding route for " + destinationStr)
|
||||
r.logger.Debug("ip route replace " + destinationStr +
|
||||
" via " + gateway.String() +
|
||||
" dev " + iface +
|
||||
" table " + strconv.Itoa(table))
|
||||
" table " + strconv.Itoa(int(table)))
|
||||
|
||||
link, err := r.netLinker.LinkByName(iface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("finding link for interface %s: %w", iface, err)
|
||||
}
|
||||
|
||||
family := netlink.FamilyV4
|
||||
if destination.Addr().Is6() {
|
||||
family = netlink.FamilyV6
|
||||
}
|
||||
route := netlink.Route{
|
||||
Dst: destination,
|
||||
Gw: gateway,
|
||||
LinkIndex: link.Index,
|
||||
Family: family,
|
||||
Table: table,
|
||||
Type: netlink.RouteTypeUnicast,
|
||||
Scope: netlink.ScopeUniverse,
|
||||
Proto: netlink.ProtoStatic,
|
||||
}
|
||||
if err := r.netLinker.RouteReplace(route); err != nil {
|
||||
return fmt.Errorf("replacing route for subnet %s at interface %s: %w",
|
||||
@@ -38,24 +46,29 @@ func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr,
|
||||
}
|
||||
|
||||
func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway netip.Addr,
|
||||
iface string, table int,
|
||||
iface string, table uint32,
|
||||
) (err error) {
|
||||
destinationStr := destination.String()
|
||||
r.logger.Info("deleting route for " + destinationStr)
|
||||
r.logger.Debug("ip route delete " + destinationStr +
|
||||
" via " + gateway.String() +
|
||||
" dev " + iface +
|
||||
" table " + strconv.Itoa(table))
|
||||
" table " + strconv.Itoa(int(table)))
|
||||
|
||||
link, err := r.netLinker.LinkByName(iface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("finding link for interface %s: %w", iface, err)
|
||||
}
|
||||
|
||||
family := netlink.FamilyV4
|
||||
if destination.Addr().Is6() {
|
||||
family = netlink.FamilyV6
|
||||
}
|
||||
route := netlink.Route{
|
||||
Dst: destination,
|
||||
Gw: gateway,
|
||||
LinkIndex: link.Index,
|
||||
Family: family,
|
||||
Table: table,
|
||||
}
|
||||
if err := r.netLinker.RouteDel(route); err != nil {
|
||||
|
||||
+10
-10
@@ -15,20 +15,20 @@ type NetLinker interface {
|
||||
}
|
||||
|
||||
type Addresser interface {
|
||||
AddrList(link netlink.Link, family int) (
|
||||
addresses []netlink.Addr, err error)
|
||||
AddrReplace(link netlink.Link, addr netlink.Addr) error
|
||||
AddrList(linkIndex uint32, family uint8) (
|
||||
addresses []netip.Prefix, err error)
|
||||
AddrReplace(linkIndex uint32, prefix netip.Prefix) error
|
||||
}
|
||||
|
||||
type Router interface {
|
||||
RouteList(family int) (routes []netlink.Route, err error)
|
||||
RouteList(family uint8) (routes []netlink.Route, err error)
|
||||
RouteAdd(route netlink.Route) error
|
||||
RouteDel(route netlink.Route) error
|
||||
RouteReplace(route netlink.Route) error
|
||||
}
|
||||
|
||||
type Ruler interface {
|
||||
RuleList(family int) (rules []netlink.Rule, err error)
|
||||
RuleList(family uint8) (rules []netlink.Rule, err error)
|
||||
RuleAdd(rule netlink.Rule) error
|
||||
RuleDel(rule netlink.Rule) error
|
||||
}
|
||||
@@ -36,11 +36,11 @@ type Ruler interface {
|
||||
type Linker interface {
|
||||
LinkList() (links []netlink.Link, err error)
|
||||
LinkByName(name string) (link netlink.Link, err error)
|
||||
LinkByIndex(index int) (link netlink.Link, err error)
|
||||
LinkAdd(link netlink.Link) (linkIndex int, err error)
|
||||
LinkDel(link netlink.Link) (err error)
|
||||
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
||||
LinkSetDown(link netlink.Link) (err error)
|
||||
LinkByIndex(index uint32) (link netlink.Link, err error)
|
||||
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
|
||||
LinkDel(index uint32) (err error)
|
||||
LinkSetUp(index uint32) (err error)
|
||||
LinkSetDown(index uint32) (err error)
|
||||
}
|
||||
|
||||
type Routing struct {
|
||||
|
||||
+39
-13
@@ -7,12 +7,19 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority int) error {
|
||||
rule := netlink.NewRule()
|
||||
rule.Src = src
|
||||
rule.Dst = dst
|
||||
rule.Priority = priority
|
||||
rule.Table = table
|
||||
func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority uint32) error {
|
||||
family := netlink.FamilyV4
|
||||
if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) {
|
||||
family = netlink.FamilyV6
|
||||
}
|
||||
rule := netlink.Rule{
|
||||
Priority: &priority,
|
||||
Family: family,
|
||||
Table: table,
|
||||
Src: src,
|
||||
Dst: dst,
|
||||
Action: netlink.ActionToTable,
|
||||
}
|
||||
|
||||
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
|
||||
if err != nil {
|
||||
@@ -31,12 +38,19 @@ func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Routing) deleteIPRule(src, dst netip.Prefix, table, priority int) error {
|
||||
rule := netlink.NewRule()
|
||||
rule.Src = src
|
||||
rule.Dst = dst
|
||||
rule.Priority = priority
|
||||
rule.Table = table
|
||||
func (r *Routing) deleteIPRule(src, dst netip.Prefix, table uint32, priority uint32) error {
|
||||
family := netlink.FamilyV4
|
||||
if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) {
|
||||
family = netlink.FamilyV6
|
||||
}
|
||||
rule := netlink.Rule{
|
||||
Priority: &priority,
|
||||
Family: family,
|
||||
Table: table,
|
||||
Src: src,
|
||||
Dst: dst,
|
||||
Action: netlink.ActionToTable,
|
||||
}
|
||||
|
||||
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
|
||||
if err != nil {
|
||||
@@ -53,10 +67,12 @@ func (r *Routing) deleteIPRule(src, dst netip.Prefix, table, priority int) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// rulesAreEqual checks whether two rules are equal
|
||||
// only according to src, dst, priority and table.
|
||||
func rulesAreEqual(a, b netlink.Rule) bool {
|
||||
return ipPrefixesAreEqual(a.Src, b.Src) &&
|
||||
ipPrefixesAreEqual(a.Dst, b.Dst) &&
|
||||
a.Priority == b.Priority &&
|
||||
ptrsEqual(a.Priority, b.Priority) &&
|
||||
a.Table == b.Table
|
||||
}
|
||||
|
||||
@@ -70,3 +86,13 @@ func ipPrefixesAreEqual(a, b netip.Prefix) bool {
|
||||
return a.Bits() == b.Bits() &&
|
||||
a.Addr().Compare(b.Addr()) == 0
|
||||
}
|
||||
|
||||
func ptrsEqual(a, b *uint32) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return *a == *b
|
||||
}
|
||||
|
||||
@@ -17,14 +17,20 @@ func makeNetipPrefix(n byte) netip.Prefix {
|
||||
}
|
||||
|
||||
func makeIPRule(src, dst netip.Prefix,
|
||||
table, priority int,
|
||||
table uint32, priority uint32,
|
||||
) netlink.Rule {
|
||||
rule := netlink.NewRule()
|
||||
rule.Src = src
|
||||
rule.Dst = dst
|
||||
rule.Table = table
|
||||
rule.Priority = priority
|
||||
return rule
|
||||
family := netlink.FamilyV4
|
||||
if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) {
|
||||
family = netlink.FamilyV6
|
||||
}
|
||||
return netlink.Rule{
|
||||
Priority: &priority,
|
||||
Family: family,
|
||||
Table: table,
|
||||
Src: src,
|
||||
Dst: dst,
|
||||
Action: netlink.ActionToTable,
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Routing_addIPRule(t *testing.T) {
|
||||
@@ -46,8 +52,8 @@ func Test_Routing_addIPRule(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
src netip.Prefix
|
||||
dst netip.Prefix
|
||||
table int
|
||||
priority int
|
||||
table uint32
|
||||
priority uint32
|
||||
ruleList ruleListCall
|
||||
ruleAdd ruleAddCall
|
||||
err error
|
||||
@@ -149,8 +155,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
src netip.Prefix
|
||||
dst netip.Prefix
|
||||
table int
|
||||
priority int
|
||||
table uint32
|
||||
priority uint32
|
||||
ruleList ruleListCall
|
||||
ruleDel ruleDelCall
|
||||
err error
|
||||
@@ -238,6 +244,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func ptrTo[T any](v T) *T { return &v }
|
||||
|
||||
func Test_rulesAreEqual(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -253,13 +261,13 @@ func Test_rulesAreEqual(t *testing.T) {
|
||||
a: netlink.Rule{
|
||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 24),
|
||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||
Priority: 100,
|
||||
Priority: ptrTo(uint32(100)),
|
||||
Table: 101,
|
||||
},
|
||||
b: netlink.Rule{
|
||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||
Priority: 100,
|
||||
Priority: ptrTo(uint32(100)),
|
||||
Table: 101,
|
||||
},
|
||||
},
|
||||
@@ -267,13 +275,13 @@ func Test_rulesAreEqual(t *testing.T) {
|
||||
a: netlink.Rule{
|
||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 32),
|
||||
Priority: 100,
|
||||
Priority: ptrTo(uint32(100)),
|
||||
Table: 101,
|
||||
},
|
||||
b: netlink.Rule{
|
||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||
Priority: 100,
|
||||
Priority: ptrTo(uint32(100)),
|
||||
Table: 101,
|
||||
},
|
||||
},
|
||||
@@ -281,13 +289,13 @@ func Test_rulesAreEqual(t *testing.T) {
|
||||
a: netlink.Rule{
|
||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||
Priority: 999,
|
||||
Priority: ptrTo(uint32(999)),
|
||||
Table: 101,
|
||||
},
|
||||
b: netlink.Rule{
|
||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||
Priority: 100,
|
||||
Priority: ptrTo(uint32(100)),
|
||||
Table: 101,
|
||||
},
|
||||
},
|
||||
@@ -295,13 +303,13 @@ func Test_rulesAreEqual(t *testing.T) {
|
||||
a: netlink.Rule{
|
||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||
Priority: 100,
|
||||
Table: 999,
|
||||
Priority: ptrTo(uint32(100)),
|
||||
Table: 102,
|
||||
},
|
||||
b: netlink.Rule{
|
||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||
Priority: 100,
|
||||
Priority: ptrTo(uint32(100)),
|
||||
Table: 101,
|
||||
},
|
||||
},
|
||||
@@ -309,13 +317,13 @@ func Test_rulesAreEqual(t *testing.T) {
|
||||
a: netlink.Rule{
|
||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||
Priority: 100,
|
||||
Priority: ptrTo(uint32(100)),
|
||||
Table: 101,
|
||||
},
|
||||
b: netlink.Rule{
|
||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||
Priority: 100,
|
||||
Priority: ptrTo(uint32(100)),
|
||||
Table: 101,
|
||||
},
|
||||
equal: true,
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
package routing
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
const (
|
||||
tableMain = unix.RT_TABLE_MAIN
|
||||
tableLocal = unix.RT_TABLE_LOCAL
|
||||
)
|
||||
@@ -0,0 +1,8 @@
|
||||
//go:build !linux
|
||||
|
||||
package routing
|
||||
|
||||
const (
|
||||
tableMain = 0
|
||||
tableLocal = 0
|
||||
)
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -34,13 +33,12 @@ func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
|
||||
case route.Dst.IsValid() && route.Dst.Addr().IsUnspecified() && route.Gw.IsValid(): // OpenVPN
|
||||
return route.Gw, nil
|
||||
case route.Dst.IsSingleIP() &&
|
||||
route.Dst.Addr().Compare(route.Src) == 0 &&
|
||||
route.Table == unix.RT_TABLE_LOCAL: // Wireguard
|
||||
route.Src = route.Src.Unmap()
|
||||
if route.Src.Is6() {
|
||||
route.Dst.Addr().Compare(route.Src.Addr()) == 0 &&
|
||||
route.Table == tableLocal: // Wireguard
|
||||
if route.Src.Addr().Is6() {
|
||||
return netip.Addr{}, fmt.Errorf("%w: %s", ErrVPNLocalGatewayIPv6NotSupported, route.Src)
|
||||
}
|
||||
bytes := route.Src.As4()
|
||||
bytes := route.Src.Addr().As4()
|
||||
// force last byte to 1 to get the VPN gateway IP
|
||||
// This is not necessarily bullet proof but it seems to work.
|
||||
bytes[3] = 1
|
||||
|
||||
+6346
-8186
File diff suppressed because it is too large
Load Diff
@@ -57,15 +57,15 @@ type Storage interface {
|
||||
}
|
||||
|
||||
type NetLinker interface {
|
||||
AddrReplace(link netlink.Link, addr netlink.Addr) error
|
||||
AddrReplace(linkIndex uint32, addr netip.Prefix) error
|
||||
Router
|
||||
Ruler
|
||||
Linker
|
||||
IsWireguardSupported() bool
|
||||
IsWireguardSupported() (ok bool, err error)
|
||||
}
|
||||
|
||||
type Router interface {
|
||||
RouteList(family int) (routes []netlink.Route, err error)
|
||||
RouteList(family uint8) (routes []netlink.Route, err error)
|
||||
RouteAdd(route netlink.Route) error
|
||||
}
|
||||
|
||||
@@ -77,10 +77,11 @@ type Ruler interface {
|
||||
type Linker interface {
|
||||
LinkList() (links []netlink.Link, err error)
|
||||
LinkByName(name string) (link netlink.Link, err error)
|
||||
LinkAdd(link netlink.Link) (linkIndex int, err error)
|
||||
LinkDel(link netlink.Link) (err error)
|
||||
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
||||
LinkSetDown(link netlink.Link) (err error)
|
||||
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
|
||||
LinkDel(linkIndex uint32) error
|
||||
LinkSetUp(linkIndex uint32) error
|
||||
LinkSetDown(linkIndex uint32) error
|
||||
LinkSetMTU(linkIndex, mtu uint32) error
|
||||
}
|
||||
|
||||
type DNSLoop interface {
|
||||
|
||||
@@ -47,6 +47,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
continue
|
||||
}
|
||||
tunnelUpData := tunnelUpData{
|
||||
vpnType: settings.Type,
|
||||
serverIP: connection.IP,
|
||||
serverName: connection.ServerName,
|
||||
canPortForward: connection.PortForward,
|
||||
|
||||
@@ -2,16 +2,24 @@ package vpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/dns/v2/pkg/check"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/pmtud"
|
||||
"github.com/qdm12/gluetun/internal/version"
|
||||
"github.com/qdm12/log"
|
||||
)
|
||||
|
||||
type tunnelUpData struct {
|
||||
// Healthcheck
|
||||
serverIP netip.Addr
|
||||
// vpnType is used for path MTU discovery to find the protocol overhead.
|
||||
// It can be "wireguard" or "openvpn".
|
||||
vpnType string
|
||||
// Port forwarding
|
||||
vpnIntf string
|
||||
serverName string // used for PIA
|
||||
@@ -31,6 +39,13 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
|
||||
}
|
||||
}
|
||||
|
||||
mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
|
||||
err := updateToMaxMTU(ctx, data.vpnIntf, data.vpnType,
|
||||
l.netLinker, l.routing, mtuLogger)
|
||||
if err != nil {
|
||||
mtuLogger.Error(err.Error())
|
||||
}
|
||||
|
||||
icmpTargetIPs := l.healthSettings.ICMPTargetIPs
|
||||
if len(icmpTargetIPs) == 1 && icmpTargetIPs[0].IsUnspecified() {
|
||||
icmpTargetIPs = []netip.Addr{data.serverIP}
|
||||
@@ -120,3 +135,65 @@ func (l *Loop) restartVPN(ctx context.Context, healthErr error) {
|
||||
_, _ = l.ApplyStatus(ctx, constants.Stopped)
|
||||
_, _ = l.ApplyStatus(ctx, constants.Running)
|
||||
}
|
||||
|
||||
var errVPNTypeUnknown = errors.New("unknown VPN type")
|
||||
|
||||
func updateToMaxMTU(ctx context.Context, vpnInterface string,
|
||||
vpnType string, netlinker NetLinker, routing Routing, logger *log.Logger,
|
||||
) error {
|
||||
logger.Info("finding maximum MTU, this can take up to 4 seconds")
|
||||
|
||||
vpnGatewayIP, err := routing.VPNLocalGatewayIP(vpnInterface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting VPN gateway IP address: %w", err)
|
||||
}
|
||||
|
||||
link, err := netlinker.LinkByName(vpnInterface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting VPN interface by name: %w", err)
|
||||
}
|
||||
|
||||
originalMTU := link.MTU
|
||||
|
||||
// Note: no point testing for an MTU of 1500, it will never work due to the VPN
|
||||
// protocol overhead, so start lower than 1500 according to the protocol used.
|
||||
const physicalLinkMTU uint32 = 1500
|
||||
vpnLinkMTU := physicalLinkMTU
|
||||
switch vpnType {
|
||||
case "wireguard":
|
||||
vpnLinkMTU -= 60 // Wireguard overhead
|
||||
case "openvpn":
|
||||
vpnLinkMTU -= 41 // OpenVPN overhead
|
||||
default:
|
||||
return fmt.Errorf("%w: %q", errVPNTypeUnknown, vpnType)
|
||||
}
|
||||
|
||||
// Setting the VPN link MTU to 1500 might interrupt the connection until
|
||||
// the new MTU is set again, but this is necessary to find the highest valid MTU.
|
||||
logger.Debugf("VPN interface %s MTU temporarily set to %d", vpnInterface, vpnLinkMTU)
|
||||
|
||||
err = netlinker.LinkSetMTU(link.Index, vpnLinkMTU)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
|
||||
}
|
||||
|
||||
const pingTimeout = time.Second
|
||||
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, vpnGatewayIP, vpnLinkMTU, pingTimeout, logger)
|
||||
switch {
|
||||
case err == nil:
|
||||
logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
|
||||
case errors.Is(err, pmtud.ErrMTUNotFound) || errors.Is(err, pmtud.ErrICMPNotPermitted):
|
||||
vpnLinkMTU = originalMTU
|
||||
logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)",
|
||||
vpnInterface, originalMTU, err)
|
||||
default:
|
||||
return fmt.Errorf("path MTU discovering: %w", err)
|
||||
}
|
||||
|
||||
err = netlinker.LinkSetMTU(link.Index, vpnLinkMTU)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,26 +3,20 @@ package wireguard
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
func (w *Wireguard) addAddresses(link netlink.Link,
|
||||
func (w *Wireguard) addAddresses(linkIndex uint32,
|
||||
addresses []netip.Prefix,
|
||||
) (err error) {
|
||||
for _, ipNet := range addresses {
|
||||
if !*w.settings.IPv6 && ipNet.Addr().Is6() {
|
||||
for _, address := range addresses {
|
||||
if !*w.settings.IPv6 && address.Addr().Is6() {
|
||||
continue
|
||||
}
|
||||
|
||||
address := netlink.Addr{
|
||||
Network: ipNet,
|
||||
}
|
||||
|
||||
err = w.netlink.AddrReplace(link, address)
|
||||
err = w.netlink.AddrReplace(linkIndex, address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: when adding address %s to link %s",
|
||||
err, address, link.Name)
|
||||
return fmt.Errorf("%w: when adding address %s to link with index %d",
|
||||
err, address, linkIndex)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -20,21 +19,21 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
errDummy := errors.New("dummy")
|
||||
|
||||
testCases := map[string]struct {
|
||||
link netlink.Link
|
||||
linkIndex uint32
|
||||
addrs []netip.Prefix
|
||||
wgBuilder func(ctrl *gomock.Controller, link netlink.Link) *Wireguard
|
||||
wgBuilder func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard
|
||||
err error
|
||||
}{
|
||||
"success": {
|
||||
link: netlink.Link{Type: "wireguard"},
|
||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
|
||||
linkIndex: 1,
|
||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
firstCall := netLinker.EXPECT().
|
||||
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
|
||||
AddrReplace(linkIndex, ipNetOne).
|
||||
Return(nil)
|
||||
netLinker.EXPECT().
|
||||
AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
|
||||
AddrReplace(linkIndex, ipNetTwo).
|
||||
Return(nil).After(firstCall)
|
||||
return &Wireguard{
|
||||
netlink: netLinker,
|
||||
@@ -45,12 +44,12 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
},
|
||||
},
|
||||
"first add error": {
|
||||
link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
|
||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
|
||||
linkIndex: 1,
|
||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
netLinker.EXPECT().
|
||||
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
|
||||
AddrReplace(linkIndex, ipNetOne).
|
||||
Return(errDummy)
|
||||
return &Wireguard{
|
||||
netlink: netLinker,
|
||||
@@ -59,18 +58,18 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
},
|
||||
}
|
||||
},
|
||||
err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"),
|
||||
err: errors.New("dummy: when adding address 1.2.3.4/32 to link with index 1"),
|
||||
},
|
||||
"second add error": {
|
||||
link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
|
||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
|
||||
linkIndex: 1,
|
||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
firstCall := netLinker.EXPECT().
|
||||
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
|
||||
AddrReplace(linkIndex, ipNetOne).
|
||||
Return(nil)
|
||||
netLinker.EXPECT().
|
||||
AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
|
||||
AddrReplace(linkIndex, ipNetTwo).
|
||||
Return(errDummy).After(firstCall)
|
||||
return &Wireguard{
|
||||
netlink: netLinker,
|
||||
@@ -79,11 +78,11 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
},
|
||||
}
|
||||
},
|
||||
err: errors.New("dummy: when adding address ::1234/64 to link a_bridge"),
|
||||
err: errors.New("dummy: when adding address ::1234/64 to link with index 1"),
|
||||
},
|
||||
"ignore IPv6": {
|
||||
addrs: []netip.Prefix{ipNetTwo},
|
||||
wgBuilder: func(_ *gomock.Controller, _ netlink.Link) *Wireguard {
|
||||
wgBuilder: func(_ *gomock.Controller, _ uint32) *Wireguard {
|
||||
return &Wireguard{
|
||||
settings: Settings{
|
||||
IPv6: ptrTo(false),
|
||||
@@ -98,9 +97,9 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
wg := testCase.wgBuilder(ctrl, testCase.link)
|
||||
wg := testCase.wgBuilder(ctrl, testCase.linkIndex)
|
||||
|
||||
err := wg.addAddresses(testCase.link, testCase.addrs)
|
||||
err := wg.addAddresses(testCase.linkIndex, testCase.addrs)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -1,3 +1,53 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
func ptrTo[T any](x T) *T { return &x }
|
||||
|
||||
var rng = rand.New(rand.NewChaCha8([32]byte{})) //nolint:gosec,gochecknoglobals
|
||||
|
||||
func makeLinkName() string {
|
||||
const alphabet = "abcdefghijklmnopqrstuvwxyz"
|
||||
b := make([]byte, 8)
|
||||
for i := range b {
|
||||
b[i] = alphabet[rng.IntN(len(alphabet))]
|
||||
}
|
||||
return "test" + string(b)
|
||||
}
|
||||
|
||||
func rulesAreEqual(a, b netlink.Rule) bool {
|
||||
return ipPrefixesAreEqual(a.Src, b.Src) &&
|
||||
ipPrefixesAreEqual(a.Dst, b.Dst) &&
|
||||
ptrsEqual(a.Priority, b.Priority) &&
|
||||
a.Table == b.Table &&
|
||||
a.Family == b.Family &&
|
||||
a.Flags == b.Flags &&
|
||||
a.Action == b.Action &&
|
||||
ptrsEqual(a.Mark, b.Mark)
|
||||
}
|
||||
|
||||
func ipPrefixesAreEqual(a, b netip.Prefix) bool {
|
||||
if !a.IsValid() && !b.IsValid() {
|
||||
return true
|
||||
}
|
||||
if !a.IsValid() || !b.IsValid() {
|
||||
return false
|
||||
}
|
||||
return a.Bits() == b.Bits() &&
|
||||
a.Addr().Compare(b.Addr()) == 0
|
||||
}
|
||||
|
||||
func ptrsEqual(a, b *uint32) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return *a == *b
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build netlink && linux
|
||||
//go:build linux
|
||||
|
||||
package wireguard
|
||||
|
||||
@@ -10,13 +10,16 @@ import (
|
||||
"github.com/qdm12/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type noopDebugLogger struct{}
|
||||
|
||||
func (n noopDebugLogger) Debugf(format string, args ...any) {}
|
||||
func (n noopDebugLogger) Patch(options ...log.Option) {}
|
||||
func (n noopDebugLogger) Debug(_ string) {}
|
||||
func (n noopDebugLogger) Debugf(_ string, _ ...any) {}
|
||||
func (n noopDebugLogger) Info(_ string) {}
|
||||
func (n noopDebugLogger) Error(_ string) {}
|
||||
func (n noopDebugLogger) Errorf(_ string, _ ...any) {}
|
||||
func (n noopDebugLogger) Patch(_ ...log.Option) {}
|
||||
|
||||
func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -24,15 +27,9 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
||||
netlinker := netlink.New(&noopDebugLogger{})
|
||||
|
||||
link := netlink.Link{
|
||||
Type: "bridge",
|
||||
Name: "test_8081",
|
||||
}
|
||||
|
||||
// Remove any previously created test interface from a crashed/panic
|
||||
// test or test suite run.
|
||||
err := netlinker.LinkDel(link)
|
||||
if err != nil && err.Error() != "invalid argument" {
|
||||
require.NoError(t, err)
|
||||
DeviceType: netlink.DeviceTypeNone,
|
||||
VirtualType: "bridge",
|
||||
Name: makeLinkName(),
|
||||
}
|
||||
|
||||
linkIndex, err := netlinker.LinkAdd(link)
|
||||
@@ -40,7 +37,7 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
||||
link.Index = linkIndex
|
||||
|
||||
defer func() {
|
||||
err = netlinker.LinkDel(link)
|
||||
err = netlinker.LinkDel(linkIndex)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
@@ -57,17 +54,15 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
||||
}
|
||||
|
||||
const addIterations = 2 // initial + replace
|
||||
|
||||
for i := 0; i < addIterations; i++ {
|
||||
err = wg.addAddresses(link, addresses)
|
||||
for range addIterations {
|
||||
err = wg.addAddresses(link.Index, addresses)
|
||||
require.NoError(t, err)
|
||||
|
||||
netlinkAddresses, err := netlinker.AddrList(link, netlink.FamilyAll)
|
||||
ipPrefixes, err := netlinker.AddrList(link.Index, netlink.FamilyAll)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, len(addresses), len(netlinkAddresses))
|
||||
for i, netlinkAddress := range netlinkAddresses {
|
||||
require.NotNil(t, netlinkAddress.Network)
|
||||
assert.Equal(t, addresses[i], netlinkAddress.Network)
|
||||
require.Equal(t, len(addresses), len(ipPrefixes))
|
||||
for i, ipPrefix := range ipPrefixes {
|
||||
assert.Equal(t, addresses[i], ipPrefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -78,38 +73,41 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
|
||||
netlinker := netlink.New(&noopDebugLogger{})
|
||||
wg := &Wireguard{
|
||||
netlink: netlinker,
|
||||
logger: &noopDebugLogger{},
|
||||
}
|
||||
|
||||
rulePriority := 10000
|
||||
const firewallMark = 999
|
||||
const family = unix.AF_INET // ipv4
|
||||
// Unique combination for this test
|
||||
const rulePriority uint32 = 10000
|
||||
const firewallMark uint32 = 12345
|
||||
const family = netlink.FamilyV4
|
||||
|
||||
cleanup, err := wg.addRule(rulePriority,
|
||||
firewallMark, family)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
t.Cleanup(func() {
|
||||
err := cleanup()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
})
|
||||
|
||||
rules, err := netlinker.RuleList(netlink.FamilyV4)
|
||||
require.NoError(t, err)
|
||||
expectedRule := netlink.Rule{
|
||||
Priority: ptrTo(rulePriority),
|
||||
Family: netlink.FamilyV4,
|
||||
Table: firewallMark,
|
||||
Mark: ptrTo(firewallMark),
|
||||
Flags: netlink.FlagInvert,
|
||||
Action: netlink.ActionToTable,
|
||||
}
|
||||
var rule netlink.Rule
|
||||
var ruleFound bool
|
||||
for _, rule = range rules {
|
||||
if rule.Mark == firewallMark {
|
||||
if rulesAreEqual(rule, expectedRule) {
|
||||
ruleFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, ruleFound)
|
||||
expectedRule := netlink.Rule{
|
||||
Invert: true,
|
||||
Priority: rulePriority,
|
||||
Mark: firewallMark,
|
||||
Table: firewallMark,
|
||||
}
|
||||
assert.Equal(t, expectedRule, rule)
|
||||
|
||||
// Existing rule cannot be added
|
||||
nilCleanup, err := wg.addRule(rulePriority,
|
||||
@@ -118,5 +116,5 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
|
||||
_ = nilCleanup() // in case it succeeds
|
||||
}
|
||||
require.Error(t, err)
|
||||
assert.EqualError(t, err, "adding ip rule 10000: from all to all table 999: file exists")
|
||||
assert.EqualError(t, err, "adding ip rule 10000: from all to all table 12345: netlink receive: file exists")
|
||||
}
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
package wireguard
|
||||
|
||||
import "github.com/qdm12/gluetun/internal/netlink"
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
|
||||
|
||||
type NetLinker interface {
|
||||
AddrReplace(link netlink.Link, addr netlink.Addr) error
|
||||
AddrReplace(linkIndex uint32, addr netip.Prefix) error
|
||||
Router
|
||||
Ruler
|
||||
Linker
|
||||
IsWireguardSupported() bool
|
||||
IsWireguardSupported() (ok bool, err error)
|
||||
}
|
||||
|
||||
type Router interface {
|
||||
RouteList(family int) (routes []netlink.Route, err error)
|
||||
RouteList(family uint8) (routes []netlink.Route, err error)
|
||||
RouteAdd(route netlink.Route) error
|
||||
}
|
||||
|
||||
@@ -23,10 +27,10 @@ type Ruler interface {
|
||||
}
|
||||
|
||||
type Linker interface {
|
||||
LinkAdd(link netlink.Link) (linkIndex int, err error)
|
||||
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
|
||||
LinkList() (links []netlink.Link, err error)
|
||||
LinkByName(name string) (link netlink.Link, err error)
|
||||
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
||||
LinkSetDown(link netlink.Link) error
|
||||
LinkDel(link netlink.Link) error
|
||||
LinkSetUp(linkIndex uint32) error
|
||||
LinkSetDown(linkIndex uint32) error
|
||||
LinkDel(linkIndex uint32) error
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
netip "net/netip"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
@@ -35,7 +36,7 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
|
||||
}
|
||||
|
||||
// AddrReplace mocks base method.
|
||||
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error {
|
||||
func (m *MockNetLinker) AddrReplace(arg0 uint32, arg1 netip.Prefix) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
@@ -49,11 +50,12 @@ func (mr *MockNetLinkerMockRecorder) AddrReplace(arg0, arg1 interface{}) *gomock
|
||||
}
|
||||
|
||||
// IsWireguardSupported mocks base method.
|
||||
func (m *MockNetLinker) IsWireguardSupported() bool {
|
||||
func (m *MockNetLinker) IsWireguardSupported() (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsWireguardSupported")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// IsWireguardSupported indicates an expected call of IsWireguardSupported.
|
||||
@@ -63,10 +65,10 @@ func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
|
||||
}
|
||||
|
||||
// LinkAdd mocks base method.
|
||||
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) {
|
||||
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (uint32, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkAdd", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
ret0, _ := ret[0].(uint32)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -93,7 +95,7 @@ func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call {
|
||||
}
|
||||
|
||||
// LinkDel mocks base method.
|
||||
func (m *MockNetLinker) LinkDel(arg0 netlink.Link) error {
|
||||
func (m *MockNetLinker) LinkDel(arg0 uint32) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkDel", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
@@ -122,7 +124,7 @@ func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call {
|
||||
}
|
||||
|
||||
// LinkSetDown mocks base method.
|
||||
func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error {
|
||||
func (m *MockNetLinker) LinkSetDown(arg0 uint32) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkSetDown", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
@@ -136,12 +138,11 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
|
||||
}
|
||||
|
||||
// LinkSetUp mocks base method.
|
||||
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) {
|
||||
func (m *MockNetLinker) LinkSetUp(arg0 uint32) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// LinkSetUp indicates an expected call of LinkSetUp.
|
||||
@@ -165,7 +166,7 @@ func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
|
||||
}
|
||||
|
||||
// RouteList mocks base method.
|
||||
func (m *MockNetLinker) RouteList(arg0 int) ([]netlink.Route, error) {
|
||||
func (m *MockNetLinker) RouteList(arg0 byte) ([]netlink.Route, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RouteList", arg0)
|
||||
ret0, _ := ret[0].([]netlink.Route)
|
||||
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix,
|
||||
func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix,
|
||||
firewallMark uint32,
|
||||
) (err error) {
|
||||
for _, dst := range destinations {
|
||||
err = w.addRoute(link, dst, firewallMark)
|
||||
err = w.addRoute(linkIndex, dst, firewallMark)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
@@ -29,7 +29,7 @@ func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
|
||||
func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix,
|
||||
firewallMark uint32,
|
||||
) (err error) {
|
||||
family := netlink.FamilyV4
|
||||
@@ -37,17 +37,20 @@ func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
|
||||
family = netlink.FamilyV6
|
||||
}
|
||||
route := netlink.Route{
|
||||
LinkIndex: link.Index,
|
||||
LinkIndex: linkIndex,
|
||||
Dst: dst,
|
||||
Family: family,
|
||||
Table: int(firewallMark),
|
||||
Table: firewallMark,
|
||||
Type: netlink.RouteTypeUnicast,
|
||||
Scope: netlink.ScopeUniverse,
|
||||
Proto: netlink.ProtoStatic,
|
||||
}
|
||||
|
||||
err = w.netlink.RouteAdd(route)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"adding route for link %s, destination %s and table %d: %w",
|
||||
link.Name, dst, firewallMark, err)
|
||||
"adding route for link with index %d, destination %s and table %d: %w",
|
||||
linkIndex, dst, firewallMark, err)
|
||||
}
|
||||
|
||||
return err
|
||||
|
||||
@@ -23,38 +23,36 @@ func Test_Wireguard_addRoute(t *testing.T) {
|
||||
errDummy := errors.New("dummy")
|
||||
|
||||
testCases := map[string]struct {
|
||||
link netlink.Link
|
||||
dst netip.Prefix
|
||||
expectedRoute netlink.Route
|
||||
routeAddErr error
|
||||
err error
|
||||
}{
|
||||
"success": {
|
||||
link: netlink.Link{
|
||||
Index: linkIndex,
|
||||
},
|
||||
dst: ipPrefix,
|
||||
expectedRoute: netlink.Route{
|
||||
LinkIndex: linkIndex,
|
||||
Dst: ipPrefix,
|
||||
Family: netlink.FamilyV4,
|
||||
Table: firewallMark,
|
||||
Type: netlink.RouteTypeUnicast,
|
||||
Scope: netlink.ScopeUniverse,
|
||||
Proto: netlink.ProtoStatic,
|
||||
},
|
||||
},
|
||||
"route add error": {
|
||||
link: netlink.Link{
|
||||
Name: "a_bridge",
|
||||
Index: linkIndex,
|
||||
},
|
||||
dst: ipPrefix,
|
||||
expectedRoute: netlink.Route{
|
||||
LinkIndex: linkIndex,
|
||||
Dst: ipPrefix,
|
||||
Family: netlink.FamilyV4,
|
||||
Table: firewallMark,
|
||||
Type: netlink.RouteTypeUnicast,
|
||||
Scope: netlink.ScopeUniverse,
|
||||
Proto: netlink.ProtoStatic,
|
||||
},
|
||||
routeAddErr: errDummy,
|
||||
err: errors.New("adding route for link a_bridge, destination 1.2.3.4/32 and table 51820: dummy"), //nolint:lll
|
||||
err: errors.New("adding route for link with index 88, destination 1.2.3.4/32 and table 51820: dummy"), //nolint:lll
|
||||
},
|
||||
}
|
||||
|
||||
@@ -72,7 +70,7 @@ func Test_Wireguard_addRoute(t *testing.T) {
|
||||
RouteAdd(testCase.expectedRoute).
|
||||
Return(testCase.routeAddErr)
|
||||
|
||||
err := wg.addRoute(testCase.link, testCase.dst, firewallMark)
|
||||
err := wg.addRoute(linkIndex, testCase.dst, firewallMark)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -7,15 +7,17 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
func (w *Wireguard) addRule(rulePriority int, firewallMark uint32,
|
||||
family int,
|
||||
func (w *Wireguard) addRule(rulePriority, firewallMark uint32,
|
||||
family uint8,
|
||||
) (cleanup func() error, err error) {
|
||||
rule := netlink.NewRule()
|
||||
rule.Invert = true
|
||||
rule.Priority = rulePriority
|
||||
rule.Mark = firewallMark
|
||||
rule.Table = int(firewallMark)
|
||||
rule.Family = family
|
||||
rule := netlink.Rule{
|
||||
Priority: &rulePriority,
|
||||
Family: family,
|
||||
Table: firewallMark,
|
||||
Mark: &firewallMark,
|
||||
Flags: netlink.FlagInvert,
|
||||
Action: netlink.ActionToTable,
|
||||
}
|
||||
if err := w.netlink.RuleAdd(rule); err != nil {
|
||||
if strings.HasSuffix(err.Error(), "file exists") {
|
||||
w.logger.Info("if you are using Kubernetes, this may fix the error below: " +
|
||||
|
||||
@@ -8,15 +8,14 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func Test_Wireguard_addRule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const rulePriority = 987
|
||||
const firewallMark = 456
|
||||
const family = unix.AF_INET
|
||||
const rulePriority uint32 = 987
|
||||
const firewallMark uint32 = 456
|
||||
const family = netlink.FamilyV4
|
||||
|
||||
errDummy := errors.New("dummy")
|
||||
|
||||
@@ -29,31 +28,34 @@ func Test_Wireguard_addRule(t *testing.T) {
|
||||
}{
|
||||
"success": {
|
||||
expectedRule: netlink.Rule{
|
||||
Invert: true,
|
||||
Priority: rulePriority,
|
||||
Mark: firewallMark,
|
||||
Priority: ptrTo(rulePriority),
|
||||
Mark: ptrTo(firewallMark),
|
||||
Table: firewallMark,
|
||||
Family: family,
|
||||
Flags: netlink.FlagInvert,
|
||||
Action: netlink.ActionToTable,
|
||||
},
|
||||
},
|
||||
"rule add error": {
|
||||
expectedRule: netlink.Rule{
|
||||
Invert: true,
|
||||
Priority: rulePriority,
|
||||
Mark: firewallMark,
|
||||
Priority: ptrTo(rulePriority),
|
||||
Mark: ptrTo(firewallMark),
|
||||
Table: firewallMark,
|
||||
Family: family,
|
||||
Flags: netlink.FlagInvert,
|
||||
Action: netlink.ActionToTable,
|
||||
},
|
||||
ruleAddErr: errDummy,
|
||||
err: errors.New("adding ip rule 987: from all to all table 456: dummy"),
|
||||
},
|
||||
"rule delete error": {
|
||||
expectedRule: netlink.Rule{
|
||||
Invert: true,
|
||||
Priority: rulePriority,
|
||||
Mark: firewallMark,
|
||||
Priority: ptrTo(rulePriority),
|
||||
Mark: ptrTo(firewallMark),
|
||||
Table: firewallMark,
|
||||
Family: family,
|
||||
Flags: netlink.FlagInvert,
|
||||
Action: netlink.ActionToTable,
|
||||
},
|
||||
ruleDelErr: errDummy,
|
||||
cleanupErr: errors.New("deleting rule ip rule 987: from all to all table 456: dummy"),
|
||||
|
||||
+42
-41
@@ -7,15 +7,14 @@ import (
|
||||
"net"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDetectKernel = errors.New("cannot detect Kernel support")
|
||||
ErrCreateTun = errors.New("cannot create TUN device")
|
||||
ErrAddLink = errors.New("cannot add Wireguard link")
|
||||
ErrFindLink = errors.New("cannot find link")
|
||||
@@ -34,7 +33,11 @@ var (
|
||||
|
||||
// See https://git.zx2c4.com/wireguard-go/tree/main.go
|
||||
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
|
||||
kernelSupported := w.netlink.IsWireguardSupported()
|
||||
kernelSupported, err := w.netlink.IsWireguardSupported()
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err)
|
||||
return
|
||||
}
|
||||
|
||||
setupFunction := setupUserSpace
|
||||
switch w.settings.Implementation {
|
||||
@@ -67,14 +70,14 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
||||
|
||||
defer closers.cleanup(w.logger)
|
||||
|
||||
link, waitAndCleanup, err := setupFunction(ctx,
|
||||
linkIndex, waitAndCleanup, err := setupFunction(ctx,
|
||||
w.settings.InterfaceName, w.netlink, w.settings.MTU, &closers, w.logger)
|
||||
if err != nil {
|
||||
waitError <- err
|
||||
return
|
||||
}
|
||||
|
||||
err = w.addAddresses(link, w.settings.Addresses)
|
||||
err = w.addAddresses(linkIndex, w.settings.Addresses)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err)
|
||||
return
|
||||
@@ -87,17 +90,16 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
||||
return
|
||||
}
|
||||
|
||||
linkIndex, err := w.netlink.LinkSetUp(link)
|
||||
err = w.netlink.LinkSetUp(linkIndex)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
|
||||
return
|
||||
}
|
||||
link.Index = linkIndex
|
||||
closers.add("shutting down link", stepFour, func() error {
|
||||
return w.netlink.LinkSetDown(link)
|
||||
return w.netlink.LinkSetDown(linkIndex)
|
||||
})
|
||||
|
||||
err = w.addRoutes(link, w.settings.AllowedIPs, w.settings.FirewallMark)
|
||||
err = w.addRoutes(linkIndex, w.settings.AllowedIPs, w.settings.FirewallMark)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
|
||||
return
|
||||
@@ -106,7 +108,7 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
||||
if *w.settings.IPv6 {
|
||||
// requires net.ipv6.conf.all.disable_ipv6=0
|
||||
ruleCleanup6, err := w.addRule(w.settings.RulePriority,
|
||||
w.settings.FirewallMark, unix.AF_INET6)
|
||||
w.settings.FirewallMark, netlink.FamilyV6)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("adding IPv6 rule: %w", err)
|
||||
return
|
||||
@@ -115,7 +117,7 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
||||
}
|
||||
|
||||
ruleCleanup, err := w.addRule(w.settings.RulePriority,
|
||||
w.settings.FirewallMark, unix.AF_INET)
|
||||
w.settings.FirewallMark, netlink.FamilyV4)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("adding IPv4 rule: %w", err)
|
||||
return
|
||||
@@ -133,39 +135,38 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
||||
type waitAndCleanupFunc func() error
|
||||
|
||||
func setupKernelSpace(ctx context.Context,
|
||||
interfaceName string, netLinker NetLinker, mtu uint16,
|
||||
interfaceName string, netLinker NetLinker, mtu uint32,
|
||||
closers *closers, logger Logger) (
|
||||
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error,
|
||||
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
|
||||
) {
|
||||
link = netlink.Link{
|
||||
Type: "wireguard",
|
||||
Name: interfaceName,
|
||||
MTU: mtu,
|
||||
}
|
||||
links, err := netLinker.LinkList()
|
||||
if err != nil {
|
||||
return link, nil, fmt.Errorf("listing links: %w", err)
|
||||
return 0, nil, fmt.Errorf("listing links: %w", err)
|
||||
}
|
||||
|
||||
// Cleanup any previous Wireguard interface with the same name
|
||||
// See https://github.com/qdm12/gluetun/issues/1669
|
||||
for _, link := range links {
|
||||
if link.Type == "wireguard" && link.Name == interfaceName {
|
||||
err = netLinker.LinkDel(link)
|
||||
if link.VirtualType == "wireguard" && link.Name == interfaceName {
|
||||
err = netLinker.LinkDel(link.Index)
|
||||
if err != nil {
|
||||
return link, nil, fmt.Errorf("deleting previous Wireguard link %s: %w",
|
||||
return 0, nil, fmt.Errorf("deleting previous Wireguard link %s: %w",
|
||||
interfaceName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
linkIndex, err := netLinker.LinkAdd(link)
|
||||
if err != nil {
|
||||
return link, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
|
||||
link := netlink.Link{
|
||||
VirtualType: "wireguard",
|
||||
Name: interfaceName,
|
||||
MTU: mtu,
|
||||
}
|
||||
linkIndex, err = netLinker.LinkAdd(link)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
|
||||
}
|
||||
link.Index = linkIndex
|
||||
closers.add("deleting link", stepFive, func() error {
|
||||
return netLinker.LinkDel(link)
|
||||
return netLinker.LinkDel(linkIndex)
|
||||
})
|
||||
|
||||
waitAndCleanup = func() error {
|
||||
@@ -174,35 +175,35 @@ func setupKernelSpace(ctx context.Context,
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
return link, waitAndCleanup, nil
|
||||
return linkIndex, waitAndCleanup, nil
|
||||
}
|
||||
|
||||
func setupUserSpace(ctx context.Context,
|
||||
interfaceName string, netLinker NetLinker, mtu uint16,
|
||||
interfaceName string, netLinker NetLinker, mtu uint32,
|
||||
closers *closers, logger Logger) (
|
||||
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error,
|
||||
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
|
||||
) {
|
||||
tun, err := tun.CreateTUN(interfaceName, int(mtu))
|
||||
if err != nil {
|
||||
return link, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
|
||||
return 0, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
|
||||
}
|
||||
|
||||
closers.add("closing TUN device", stepSeven, tun.Close)
|
||||
|
||||
tunName, err := tun.Name()
|
||||
if err != nil {
|
||||
return link, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
|
||||
return 0, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
|
||||
} else if tunName != interfaceName {
|
||||
return link, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
|
||||
return 0, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
|
||||
ErrCreateTun, interfaceName, tunName)
|
||||
}
|
||||
|
||||
link, err = netLinker.LinkByName(interfaceName)
|
||||
link, err := netLinker.LinkByName(interfaceName)
|
||||
if err != nil {
|
||||
return link, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
|
||||
return 0, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
|
||||
}
|
||||
closers.add("deleting link", stepFive, func() error {
|
||||
return netLinker.LinkDel(link)
|
||||
return netLinker.LinkDel(link.Index)
|
||||
})
|
||||
|
||||
bind := conn.NewDefaultBind()
|
||||
@@ -217,16 +218,16 @@ func setupUserSpace(ctx context.Context,
|
||||
return nil
|
||||
})
|
||||
|
||||
uapiFile, err := ipc.UAPIOpen(interfaceName)
|
||||
uapiFile, err := uapiOpen(interfaceName)
|
||||
if err != nil {
|
||||
return link, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
|
||||
return 0, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
|
||||
}
|
||||
|
||||
closers.add("closing UAPI file", stepThree, uapiFile.Close)
|
||||
|
||||
uapiListener, err := ipc.UAPIListen(interfaceName, uapiFile)
|
||||
uapiListener, err := uapiListen(interfaceName, uapiFile)
|
||||
if err != nil {
|
||||
return link, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
|
||||
return 0, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
|
||||
}
|
||||
|
||||
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
|
||||
@@ -251,7 +252,7 @@ func setupUserSpace(ctx context.Context,
|
||||
return err
|
||||
}
|
||||
|
||||
return link, waitAndCleanup, nil
|
||||
return link.Index, waitAndCleanup, nil
|
||||
}
|
||||
|
||||
func acceptAndHandle(uapi net.Listener, device *device.Device,
|
||||
|
||||
@@ -38,10 +38,10 @@ type Settings struct {
|
||||
FirewallMark uint32
|
||||
// Maximum Transmission Unit (MTU) setting for the network interface.
|
||||
// It defaults to device.DefaultMTU from wireguard-go which is 1420
|
||||
MTU uint16
|
||||
MTU uint32
|
||||
// RulePriority is the priority for the rule created with the
|
||||
// FirewallMark.
|
||||
RulePriority int
|
||||
RulePriority uint32
|
||||
// IPv6 can bet set to true if IPv6 should be handled.
|
||||
// It defaults to false if left unset.
|
||||
IPv6 *bool
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
)
|
||||
|
||||
func uapiOpen(name string) (*os.File, error) {
|
||||
return ipc.UAPIOpen(name)
|
||||
}
|
||||
|
||||
func uapiListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
|
||||
return ipc.UAPIListen(interfaceName, uapiFile)
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
//go:build !linux
|
||||
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
)
|
||||
|
||||
func uapiOpen(name string) (*os.File, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func uapiListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
Reference in New Issue
Block a user