mirror of
https://github.com/qdm12/gluetun.git
synced 2026-07-05 18:19:51 +02:00
Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| db947c17a8 | |||
| b0a75673bd | |||
| 5f0c499808 | |||
| bdd69a1fb7 | |||
| 1af75bb30c | |||
| 9c1cd7e8b1 | |||
| facc6df3be | |||
| e292a4c9be | |||
| 9e4dd61c19 | |||
| fe3d4a94d4 | |||
| de38d759a4 | |||
| fba60af772 | |||
| 9b9b723887 |
@@ -59,10 +59,13 @@ jobs:
|
|||||||
- name: Run tests in test container
|
- name: Run tests in test container
|
||||||
run: |
|
run: |
|
||||||
touch coverage.txt
|
touch coverage.txt
|
||||||
docker run --rm --device /dev/net/tun \
|
docker run --rm --cap-add=NET_ADMIN --device /dev/net/tun \
|
||||||
-v "$(pwd)/coverage.txt:/tmp/gobuild/coverage.txt" \
|
-v "$(pwd)/coverage.txt:/tmp/gobuild/coverage.txt" \
|
||||||
test-container
|
test-container
|
||||||
|
|
||||||
|
- name: Verify dev cross platform compatibility
|
||||||
|
run: docker build --target xcompile .
|
||||||
|
|
||||||
- name: Build final image
|
- name: Build final image
|
||||||
run: docker build -t final-image .
|
run: docker build -t final-image .
|
||||||
|
|
||||||
|
|||||||
@@ -46,6 +46,10 @@ RUN git init && \
|
|||||||
git diff --exit-code && \
|
git diff --exit-code && \
|
||||||
rm -rf .git/
|
rm -rf .git/
|
||||||
|
|
||||||
|
FROM --platform=${BUILDPLATFORM} base AS xcompile
|
||||||
|
RUN GOOS=darwin go build -o /dev/null ./...
|
||||||
|
RUN GOOS=windows go build -o /dev/null ./...
|
||||||
|
|
||||||
FROM --platform=${BUILDPLATFORM} base AS build
|
FROM --platform=${BUILDPLATFORM} base AS build
|
||||||
ARG TARGETPLATFORM
|
ARG TARGETPLATFORM
|
||||||
ARG VERSION=unknown
|
ARG VERSION=unknown
|
||||||
|
|||||||
+14
-12
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
@@ -393,7 +394,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
}
|
}
|
||||||
|
|
||||||
dnsLogger := logger.New(log.SetComponent("dns"))
|
dnsLogger := logger.New(log.SetComponent("dns"))
|
||||||
dnsLooper, err := dns.NewLoop(allSettings.DNS, httpClient,
|
dnsLooper, err := dns.NewLoop(allSettings.DNS, httpClient, firewallConf,
|
||||||
dnsLogger)
|
dnsLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("creating DNS loop: %w", err)
|
return fmt.Errorf("creating DNS loop: %w", err)
|
||||||
@@ -553,26 +554,26 @@ type netLinker interface {
|
|||||||
Router
|
Router
|
||||||
Ruler
|
Ruler
|
||||||
Linker
|
Linker
|
||||||
IsWireguardSupported() bool
|
IsWireguardSupported() (ok bool, err error)
|
||||||
IsIPv6Supported() (ok bool, err error)
|
IsIPv6Supported() (ok bool, err error)
|
||||||
PatchLoggerLevel(level log.Level)
|
PatchLoggerLevel(level log.Level)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Addresser interface {
|
type Addresser interface {
|
||||||
AddrList(link netlink.Link, family int) (
|
AddrList(linkIndex uint32, family uint8) (
|
||||||
addresses []netlink.Addr, err error)
|
addresses []netip.Prefix, err error)
|
||||||
AddrReplace(link netlink.Link, addr netlink.Addr) error
|
AddrReplace(linkIndex uint32, addr netip.Prefix) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Router interface {
|
type Router interface {
|
||||||
RouteList(family int) (routes []netlink.Route, err error)
|
RouteList(family uint8) (routes []netlink.Route, err error)
|
||||||
RouteAdd(route netlink.Route) error
|
RouteAdd(route netlink.Route) error
|
||||||
RouteDel(route netlink.Route) error
|
RouteDel(route netlink.Route) error
|
||||||
RouteReplace(route netlink.Route) error
|
RouteReplace(route netlink.Route) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Ruler interface {
|
type Ruler interface {
|
||||||
RuleList(family int) (rules []netlink.Rule, err error)
|
RuleList(family uint8) (rules []netlink.Rule, err error)
|
||||||
RuleAdd(rule netlink.Rule) error
|
RuleAdd(rule netlink.Rule) error
|
||||||
RuleDel(rule netlink.Rule) error
|
RuleDel(rule netlink.Rule) error
|
||||||
}
|
}
|
||||||
@@ -580,11 +581,12 @@ type Ruler interface {
|
|||||||
type Linker interface {
|
type Linker interface {
|
||||||
LinkList() (links []netlink.Link, err error)
|
LinkList() (links []netlink.Link, err error)
|
||||||
LinkByName(name string) (link netlink.Link, err error)
|
LinkByName(name string) (link netlink.Link, err error)
|
||||||
LinkByIndex(index int) (link netlink.Link, err error)
|
LinkByIndex(index uint32) (link netlink.Link, err error)
|
||||||
LinkAdd(link netlink.Link) (linkIndex int, err error)
|
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
|
||||||
LinkDel(link netlink.Link) (err error)
|
LinkDel(linkIndex uint32) (err error)
|
||||||
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
LinkSetUp(linkIndex uint32) (err error)
|
||||||
LinkSetDown(link netlink.Link) (err error)
|
LinkSetDown(linkIndex uint32) (err error)
|
||||||
|
LinkSetMTU(linkIndex, mtu uint32) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type clier interface {
|
type clier interface {
|
||||||
|
|||||||
@@ -7,8 +7,10 @@ require (
|
|||||||
github.com/breml/rootcerts v0.3.3
|
github.com/breml/rootcerts v0.3.3
|
||||||
github.com/fatih/color v1.18.0
|
github.com/fatih/color v1.18.0
|
||||||
github.com/golang/mock v1.6.0
|
github.com/golang/mock v1.6.0
|
||||||
|
github.com/jsimonetti/rtnetlink v1.4.2
|
||||||
github.com/klauspost/compress v1.18.1
|
github.com/klauspost/compress v1.18.1
|
||||||
github.com/klauspost/pgzip v1.2.6
|
github.com/klauspost/pgzip v1.2.6
|
||||||
|
github.com/mdlayher/genetlink v1.3.2
|
||||||
github.com/pelletier/go-toml/v2 v2.2.4
|
github.com/pelletier/go-toml/v2 v2.2.4
|
||||||
github.com/qdm12/dns/v2 v2.0.0-rc10
|
github.com/qdm12/dns/v2 v2.0.0-rc10
|
||||||
github.com/qdm12/gosettings v0.4.4
|
github.com/qdm12/gosettings v0.4.4
|
||||||
@@ -19,12 +21,11 @@ require (
|
|||||||
github.com/qdm12/ss-server v0.6.0
|
github.com/qdm12/ss-server v0.6.0
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/ulikunitz/xz v0.5.15
|
github.com/ulikunitz/xz v0.5.15
|
||||||
github.com/vishvananda/netlink v1.3.1
|
|
||||||
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a
|
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a
|
||||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
|
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
|
||||||
golang.org/x/net v0.47.0
|
golang.org/x/net v0.49.0
|
||||||
golang.org/x/sys v0.38.0
|
golang.org/x/sys v0.40.0
|
||||||
golang.org/x/text v0.31.0
|
golang.org/x/text v0.33.0
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||||
gopkg.in/ini.v1 v1.67.0
|
gopkg.in/ini.v1 v1.67.0
|
||||||
@@ -38,13 +39,12 @@ require (
|
|||||||
github.com/cloudflare/circl v1.6.1 // indirect
|
github.com/cloudflare/circl v1.6.1 // indirect
|
||||||
github.com/cronokirby/saferith v0.33.0 // indirect
|
github.com/cronokirby/saferith v0.33.0 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/google/go-cmp v0.6.0 // indirect
|
github.com/google/go-cmp v0.7.0 // indirect
|
||||||
github.com/josharian/native v1.1.0 // indirect
|
github.com/josharian/native v1.1.0 // indirect
|
||||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
github.com/mdlayher/genetlink v1.3.2 // indirect
|
|
||||||
github.com/mdlayher/netlink v1.7.2 // indirect
|
github.com/mdlayher/netlink v1.7.2 // indirect
|
||||||
github.com/mdlayher/socket v0.4.1 // indirect
|
github.com/mdlayher/socket v0.5.1 // indirect
|
||||||
github.com/miekg/dns v1.1.62 // indirect
|
github.com/miekg/dns v1.1.62 // indirect
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
@@ -55,12 +55,11 @@ require (
|
|||||||
github.com/prometheus/procfs v0.15.1 // indirect
|
github.com/prometheus/procfs v0.15.1 // indirect
|
||||||
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 // indirect
|
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 // indirect
|
||||||
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
|
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
|
||||||
github.com/vishvananda/netns v0.0.5 // indirect
|
golang.org/x/crypto v0.47.0 // indirect
|
||||||
golang.org/x/crypto v0.45.0 // indirect
|
golang.org/x/mod v0.31.0 // indirect
|
||||||
golang.org/x/mod v0.29.0 // indirect
|
golang.org/x/sync v0.19.0 // indirect
|
||||||
golang.org/x/sync v0.18.0 // indirect
|
|
||||||
golang.org/x/time v0.3.0 // indirect
|
golang.org/x/time v0.3.0 // indirect
|
||||||
golang.org/x/tools v0.38.0 // indirect
|
golang.org/x/tools v0.40.0 // indirect
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||||
google.golang.org/protobuf v1.35.1 // indirect
|
google.golang.org/protobuf v1.35.1 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ github.com/breml/rootcerts v0.3.3/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXB
|
|||||||
github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
|
github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
|
github.com/cilium/ebpf v0.12.3 h1:8ht6F9MquybnY97at+VDZb3eQQr8ev79RueWeVaEcG4=
|
||||||
|
github.com/cilium/ebpf v0.12.3/go.mod h1:TctK1ivibvI3znr66ljgi4hqOT8EYQjz1KWBfb1UVgM=
|
||||||
github.com/cloudflare/circl v1.1.0/go.mod h1:prBCrKB9DV4poKZY1l9zBXg2QJY7mvgRvtMxxK7fi4I=
|
github.com/cloudflare/circl v1.1.0/go.mod h1:prBCrKB9DV4poKZY1l9zBXg2QJY7mvgRvtMxxK7fi4I=
|
||||||
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
|
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
|
||||||
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
|
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
|
||||||
@@ -26,10 +28,12 @@ github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
|||||||
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
|
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
|
||||||
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
|
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
|
||||||
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
|
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
|
||||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
|
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
|
||||||
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||||
|
github.com/jsimonetti/rtnetlink v1.4.2 h1:Df9w9TZ3npHTyDn0Ev9e1uzmN2odmXd0QX+J5GTEn90=
|
||||||
|
github.com/jsimonetti/rtnetlink v1.4.2/go.mod h1:92s6LJdE+1iOrw+F2/RO7LYI2Qd8pPpFNNUYW06gcoM=
|
||||||
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
|
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
|
||||||
github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0=
|
github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0=
|
||||||
github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU=
|
github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU=
|
||||||
@@ -47,8 +51,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5
|
|||||||
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
|
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
|
||||||
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
||||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
||||||
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
|
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
||||||
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
|
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
||||||
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
|
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
|
||||||
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
|
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
|
||||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
|
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
|
||||||
@@ -93,10 +97,6 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
|
|||||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
|
github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
|
||||||
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
|
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
|
||||||
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
|
|
||||||
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
|
||||||
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
|
||||||
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
|
||||||
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk=
|
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk=
|
||||||
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4=
|
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4=
|
||||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||||
@@ -106,15 +106,15 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
|
|||||||
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
|
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
|
||||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
|
||||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
|
||||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
|
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
|
||||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
|
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
|
||||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||||
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
||||||
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
||||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
@@ -122,14 +122,14 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b
|
|||||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||||
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
|
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
|
||||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
|
||||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
|
||||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
@@ -140,12 +140,10 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||||
@@ -155,8 +153,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
|||||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||||
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
||||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
||||||
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
|
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
|
||||||
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
@@ -164,8 +162,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
|
|||||||
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
|||||||
@@ -48,6 +48,10 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
|
|||||||
return fmt.Errorf("%w for Wireguard: %w", ErrVPNProviderNameNotValid, err)
|
return fmt.Errorf("%w for Wireguard: %w", ErrVPNProviderNameNotValid, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if p.Name == providers.Mullvad && vpnType == vpn.OpenVPN {
|
||||||
|
warner.Warn("https://mullvad.net/en/blog/removing-openvpn-15th-january-2026")
|
||||||
|
}
|
||||||
|
|
||||||
err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner)
|
err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("server selection: %w", err)
|
return fmt.Errorf("server selection: %w", err)
|
||||||
|
|||||||
@@ -45,7 +45,8 @@ type Wireguard struct {
|
|||||||
// It has been lowered to 1320 following quite a bit of
|
// It has been lowered to 1320 following quite a bit of
|
||||||
// investigation in the issue:
|
// investigation in the issue:
|
||||||
// https://github.com/qdm12/gluetun/issues/2533.
|
// https://github.com/qdm12/gluetun/issues/2533.
|
||||||
MTU uint16 `json:"mtu"`
|
// Note this should now be replaced with the PMTUD feature.
|
||||||
|
MTU uint32 `json:"mtu"`
|
||||||
// Implementation is the Wireguard implementation to use.
|
// Implementation is the Wireguard implementation to use.
|
||||||
// It can be "auto", "userspace" or "kernelspace".
|
// It can be "auto", "userspace" or "kernelspace".
|
||||||
// It defaults to "auto" and cannot be the empty string
|
// It defaults to "auto" and cannot be the empty string
|
||||||
@@ -272,7 +273,7 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
mtuPtr, err := r.Uint16Ptr("WIREGUARD_MTU")
|
mtuPtr, err := r.Uint32Ptr("WIREGUARD_MTU")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if mtuPtr != nil {
|
} else if mtuPtr != nil {
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Logger interface {
|
||||||
|
Debug(s string)
|
||||||
|
Info(s string)
|
||||||
|
Warn(s string)
|
||||||
|
Error(s string)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Firewall interface {
|
||||||
|
RestrictOutputAddrPort(ctx context.Context, addrPort netip.AddrPort) (err error)
|
||||||
|
}
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
type Logger interface {
|
|
||||||
Debug(s string)
|
|
||||||
Info(s string)
|
|
||||||
Warn(s string)
|
|
||||||
Error(s string)
|
|
||||||
}
|
|
||||||
@@ -24,6 +24,7 @@ type Loop struct {
|
|||||||
localResolvers []netip.Addr
|
localResolvers []netip.Addr
|
||||||
resolvConf string
|
resolvConf string
|
||||||
client *http.Client
|
client *http.Client
|
||||||
|
firewall Firewall
|
||||||
logger Logger
|
logger Logger
|
||||||
userTrigger bool
|
userTrigger bool
|
||||||
start <-chan struct{}
|
start <-chan struct{}
|
||||||
@@ -39,7 +40,7 @@ type Loop struct {
|
|||||||
const defaultBackoffTime = 10 * time.Second
|
const defaultBackoffTime = 10 * time.Second
|
||||||
|
|
||||||
func NewLoop(settings settings.DNS,
|
func NewLoop(settings settings.DNS,
|
||||||
client *http.Client, logger Logger,
|
client *http.Client, firewall Firewall, logger Logger,
|
||||||
) (loop *Loop, err error) {
|
) (loop *Loop, err error) {
|
||||||
start := make(chan struct{})
|
start := make(chan struct{})
|
||||||
running := make(chan models.LoopStatus)
|
running := make(chan models.LoopStatus)
|
||||||
@@ -64,6 +65,7 @@ func NewLoop(settings settings.DNS,
|
|||||||
filter: filter,
|
filter: filter,
|
||||||
resolvConf: "/etc/resolv.conf",
|
resolvConf: "/etc/resolv.conf",
|
||||||
client: client,
|
client: client,
|
||||||
|
firewall: firewall,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
userTrigger: true,
|
userTrigger: true,
|
||||||
start: start,
|
start: start,
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/dns/v2/pkg/nameserver"
|
"github.com/qdm12/dns/v2/pkg/nameserver"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (l *Loop) useUnencryptedDNS(fallback bool) {
|
func (l *Loop) useUnencryptedDNS(ctx context.Context, fallback bool) {
|
||||||
settings := l.GetSettings()
|
settings := l.GetSettings()
|
||||||
|
|
||||||
targetIP := settings.GetFirstPlaintextIPv4()
|
targetIP := settings.GetFirstPlaintextIPv4()
|
||||||
@@ -20,8 +21,9 @@ func (l *Loop) useUnencryptedDNS(fallback bool) {
|
|||||||
|
|
||||||
const dialTimeout = 3 * time.Second
|
const dialTimeout = 3 * time.Second
|
||||||
const defaultDNSPort = 53
|
const defaultDNSPort = 53
|
||||||
|
addrPort := netip.AddrPortFrom(targetIP, defaultDNSPort)
|
||||||
settingsInternalDNS := nameserver.SettingsInternalDNS{
|
settingsInternalDNS := nameserver.SettingsInternalDNS{
|
||||||
AddrPort: netip.AddrPortFrom(targetIP, defaultDNSPort),
|
AddrPort: addrPort,
|
||||||
Timeout: dialTimeout,
|
Timeout: dialTimeout,
|
||||||
}
|
}
|
||||||
nameserver.UseDNSInternally(settingsInternalDNS)
|
nameserver.UseDNSInternally(settingsInternalDNS)
|
||||||
@@ -34,4 +36,9 @@ func (l *Loop) useUnencryptedDNS(fallback bool) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
l.logger.Error(err.Error())
|
l.logger.Error(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = l.firewall.RestrictOutputAddrPort(ctx, addrPort)
|
||||||
|
if err != nil {
|
||||||
|
l.logger.Error("restricting plain DNS traffic to " + targetIP.String() + ": " + err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+5
-5
@@ -24,7 +24,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
|||||||
"and go through your container network DNS outside the VPN tunnel!")
|
"and go through your container network DNS outside the VPN tunnel!")
|
||||||
} else {
|
} else {
|
||||||
const fallback = false
|
const fallback = false
|
||||||
l.useUnencryptedDNS(fallback)
|
l.useUnencryptedDNS(ctx, fallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -56,7 +56,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
|||||||
|
|
||||||
if !errors.Is(err, errUpdateBlockLists) {
|
if !errors.Is(err, errUpdateBlockLists) {
|
||||||
const fallback = true
|
const fallback = true
|
||||||
l.useUnencryptedDNS(fallback)
|
l.useUnencryptedDNS(ctx, fallback)
|
||||||
}
|
}
|
||||||
l.logAndWait(ctx, err)
|
l.logAndWait(ctx, err)
|
||||||
settings = l.GetSettings()
|
settings = l.GetSettings()
|
||||||
@@ -66,7 +66,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
|||||||
settings = l.GetSettings()
|
settings = l.GetSettings()
|
||||||
if !*settings.KeepNameserver && !*settings.ServerEnabled {
|
if !*settings.KeepNameserver && !*settings.ServerEnabled {
|
||||||
const fallback = false
|
const fallback = false
|
||||||
l.useUnencryptedDNS(fallback)
|
l.useUnencryptedDNS(ctx, fallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
l.userTrigger = false
|
l.userTrigger = false
|
||||||
@@ -94,7 +94,7 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
|
|||||||
settings := l.GetSettings()
|
settings := l.GetSettings()
|
||||||
if !*settings.KeepNameserver && *settings.ServerEnabled {
|
if !*settings.KeepNameserver && *settings.ServerEnabled {
|
||||||
const fallback = false
|
const fallback = false
|
||||||
l.useUnencryptedDNS(fallback)
|
l.useUnencryptedDNS(ctx, fallback)
|
||||||
l.stopServer()
|
l.stopServer()
|
||||||
}
|
}
|
||||||
l.stopped <- struct{}{}
|
l.stopped <- struct{}{}
|
||||||
@@ -105,7 +105,7 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
|
|||||||
case err := <-runError: // unexpected error
|
case err := <-runError: // unexpected error
|
||||||
l.statusManager.SetStatus(constants.Crashed)
|
l.statusManager.SetStatus(constants.Crashed)
|
||||||
const fallback = true
|
const fallback = true
|
||||||
l.useUnencryptedDNS(fallback)
|
l.useUnencryptedDNS(ctx, fallback)
|
||||||
l.logAndWait(ctx, err)
|
l.logAndWait(ctx, err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,8 +39,9 @@ func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err erro
|
|||||||
|
|
||||||
// use internal DNS server
|
// use internal DNS server
|
||||||
const defaultDNSPort = 53
|
const defaultDNSPort = 53
|
||||||
|
addrPort := netip.AddrPortFrom(settings.ServerAddress, defaultDNSPort)
|
||||||
nameserver.UseDNSInternally(nameserver.SettingsInternalDNS{
|
nameserver.UseDNSInternally(nameserver.SettingsInternalDNS{
|
||||||
AddrPort: netip.AddrPortFrom(settings.ServerAddress, defaultDNSPort),
|
AddrPort: addrPort,
|
||||||
})
|
})
|
||||||
err = nameserver.UseDNSSystemWide(nameserver.SettingsSystemDNS{
|
err = nameserver.UseDNSSystemWide(nameserver.SettingsSystemDNS{
|
||||||
IPs: []netip.Addr{settings.ServerAddress},
|
IPs: []netip.Addr{settings.ServerAddress},
|
||||||
@@ -50,6 +51,11 @@ func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err erro
|
|||||||
l.logger.Error(err.Error())
|
l.logger.Error(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = l.firewall.RestrictOutputAddrPort(ctx, addrPort)
|
||||||
|
if err != nil {
|
||||||
|
l.logger.Error("restricting plain DNS traffic to " + addrPort.Addr().String() + ": " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
err = check.WaitForDNS(ctx, check.Settings{})
|
err = check.WaitForDNS(ctx, check.Settings{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.stopServer()
|
l.stopServer()
|
||||||
|
|||||||
@@ -69,8 +69,8 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
|||||||
"invalid_instruction": {
|
"invalid_instruction": {
|
||||||
instruction: "invalid",
|
instruction: "invalid",
|
||||||
errWrapped: ErrIptablesCommandMalformed,
|
errWrapped: ErrIptablesCommandMalformed,
|
||||||
errMessage: "parsing iptables command: iptables command is malformed: " +
|
errMessage: "parsing iptables command: parsing \"invalid\": " +
|
||||||
"fields count 1 is not even: \"invalid\"",
|
"iptables command is malformed: flag \"invalid\" requires a value, but got none",
|
||||||
},
|
},
|
||||||
"list_error": {
|
"list_error": {
|
||||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ type Config struct {
|
|||||||
outboundSubnets []netip.Prefix
|
outboundSubnets []netip.Prefix
|
||||||
allowedInputPorts map[uint16]map[string]struct{} // port to interfaces set mapping
|
allowedInputPorts map[uint16]map[string]struct{} // port to interfaces set mapping
|
||||||
portRedirections portRedirections
|
portRedirections portRedirections
|
||||||
|
outputAddrPort map[uint16]netip.Addr
|
||||||
stateMutex sync.Mutex
|
stateMutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,6 +53,7 @@ func NewConfig(ctx context.Context, logger Logger,
|
|||||||
runner: runner,
|
runner: runner,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
allowedInputPorts: make(map[uint16]map[string]struct{}),
|
allowedInputPorts: make(map[uint16]map[string]struct{}),
|
||||||
|
outputAddrPort: make(map[uint16]netip.Addr),
|
||||||
ipTables: iptables,
|
ipTables: iptables,
|
||||||
ip6Tables: ip6tables,
|
ip6Tables: ip6tables,
|
||||||
customRulesPath: "/iptables/post-rules.txt",
|
customRulesPath: "/iptables/post-rules.txt",
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package firewall
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
|
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
|
||||||
@@ -19,3 +20,15 @@ func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction st
|
|||||||
}
|
}
|
||||||
return c.runIP6tablesInstruction(ctx, instruction)
|
return c.runIP6tablesInstruction(ctx, instruction)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Config) runIPv4AndV6IptablesInstructions(ctx context.Context,
|
||||||
|
ipv4Instructions, ipv6Instructions []string,
|
||||||
|
) error {
|
||||||
|
if err := c.runIptablesInstructions(ctx, ipv4Instructions); err != nil {
|
||||||
|
return fmt.Errorf("running iptables instructions: %w", err)
|
||||||
|
}
|
||||||
|
if err := c.runIP6tablesInstructions(ctx, ipv6Instructions); err != nil {
|
||||||
|
return fmt.Errorf("running ip6tables instructions: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
+111
-19
@@ -9,9 +9,19 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type operation uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
opNone operation = iota
|
||||||
|
opAppend
|
||||||
|
opDelete
|
||||||
|
opInsert
|
||||||
|
opReplace
|
||||||
|
)
|
||||||
|
|
||||||
type iptablesInstruction struct {
|
type iptablesInstruction struct {
|
||||||
table string // defaults to "filter", and can be "nat" for example.
|
table string // defaults to "filter", and can be "nat" for example.
|
||||||
append bool
|
operation operation
|
||||||
chain string // for example INPUT, PREROUTING. Cannot be empty.
|
chain string // for example INPUT, PREROUTING. Cannot be empty.
|
||||||
target string // for example ACCEPT. Can be empty.
|
target string // for example ACCEPT. Can be empty.
|
||||||
protocol string // "tcp" or "udp" or "" for all protocols.
|
protocol string // "tcp" or "udp" or "" for all protocols.
|
||||||
@@ -22,6 +32,7 @@ type iptablesInstruction struct {
|
|||||||
destinationPort uint16 // if zero, there is no destination port
|
destinationPort uint16 // if zero, there is no destination port
|
||||||
toPorts []uint16 // if empty, there is no redirection
|
toPorts []uint16 // if empty, there is no redirection
|
||||||
ctstate []string // if empty, there is no ctstate
|
ctstate []string // if empty, there is no ctstate
|
||||||
|
lineNumber uint16 // for replace operation, the line number to replace
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iptablesInstruction) setDefaults() {
|
func (i *iptablesInstruction) setDefaults() {
|
||||||
@@ -60,6 +71,58 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *iptablesInstruction) String() string {
|
||||||
|
var sb strings.Builder
|
||||||
|
if i.table != "" && i.table != "filter" {
|
||||||
|
sb.WriteString(fmt.Sprintf("-t %s ", i.table))
|
||||||
|
}
|
||||||
|
switch i.operation {
|
||||||
|
case opNone:
|
||||||
|
panic("no operation specified")
|
||||||
|
case opAppend:
|
||||||
|
sb.WriteString(fmt.Sprintf("--append %s ", i.chain))
|
||||||
|
case opDelete:
|
||||||
|
sb.WriteString(fmt.Sprintf("--delete %s ", i.chain))
|
||||||
|
case opInsert:
|
||||||
|
sb.WriteString(fmt.Sprintf("--insert %s ", i.chain))
|
||||||
|
case opReplace:
|
||||||
|
sb.WriteString(fmt.Sprintf("--replace %s %d ", i.chain, i.lineNumber))
|
||||||
|
}
|
||||||
|
if i.inputInterface != "" {
|
||||||
|
sb.WriteString(fmt.Sprintf("-i %s ", i.inputInterface))
|
||||||
|
}
|
||||||
|
if i.outputInterface != "" {
|
||||||
|
sb.WriteString(fmt.Sprintf("-o %s ", i.outputInterface))
|
||||||
|
}
|
||||||
|
if i.protocol != "" {
|
||||||
|
sb.WriteString(fmt.Sprintf("-p %s ", i.protocol))
|
||||||
|
}
|
||||||
|
if i.source.IsValid() {
|
||||||
|
sb.WriteString(fmt.Sprintf("-s %s ", i.source.String()))
|
||||||
|
}
|
||||||
|
if i.destination.IsValid() {
|
||||||
|
sb.WriteString(fmt.Sprintf("-d %s ", i.destination.String()))
|
||||||
|
}
|
||||||
|
if i.destinationPort != 0 {
|
||||||
|
sb.WriteString(fmt.Sprintf("--dport %d ", i.destinationPort))
|
||||||
|
}
|
||||||
|
if len(i.ctstate) > 0 {
|
||||||
|
sb.WriteString(fmt.Sprintf("--ctstate %s ", strings.Join(i.ctstate, ",")))
|
||||||
|
}
|
||||||
|
if len(i.toPorts) > 0 {
|
||||||
|
var portStrings []string
|
||||||
|
for _, port := range i.toPorts {
|
||||||
|
portStrings = append(portStrings, strconv.FormatUint(uint64(port), 10))
|
||||||
|
}
|
||||||
|
sb.WriteString(fmt.Sprintf("--to-ports %s ", strings.Join(portStrings, ",")))
|
||||||
|
}
|
||||||
|
if i.target != "" {
|
||||||
|
sb.WriteString(fmt.Sprintf("-j %s ", i.target))
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimSpace(sb.String())
|
||||||
|
}
|
||||||
|
|
||||||
// instruction can be "" which equivalent to the "*" chain rule interface.
|
// instruction can be "" which equivalent to the "*" chain rule interface.
|
||||||
func networkInterfacesEqual(instruction, chainRule string) bool {
|
func networkInterfacesEqual(instruction, chainRule string) bool {
|
||||||
return instruction == chainRule || (instruction == "" && chainRule == "*")
|
return instruction == chainRule || (instruction == "" && chainRule == "*")
|
||||||
@@ -77,34 +140,63 @@ func parseIptablesInstruction(s string) (instruction iptablesInstruction, err er
|
|||||||
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
||||||
}
|
}
|
||||||
fields := strings.Fields(s)
|
fields := strings.Fields(s)
|
||||||
if len(fields)%2 != 0 {
|
|
||||||
return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q",
|
|
||||||
ErrIptablesCommandMalformed, len(fields), s)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < len(fields); i += 2 {
|
i := 0
|
||||||
key := fields[i]
|
for i < len(fields) {
|
||||||
value := fields[i+1]
|
consumed, err := parseInstructionFlag(fields[i:], &instruction)
|
||||||
err = parseInstructionFlag(key, value, &instruction)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
|
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
|
||||||
}
|
}
|
||||||
|
i += consumed
|
||||||
}
|
}
|
||||||
|
|
||||||
instruction.setDefaults()
|
instruction.setDefaults()
|
||||||
return instruction, nil
|
return instruction, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) {
|
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
|
||||||
switch key {
|
flag := fields[0]
|
||||||
|
|
||||||
|
// All flags use one value after the flag, except the following:
|
||||||
|
switch flag {
|
||||||
|
case "-R", "--replace":
|
||||||
|
const expected = 3
|
||||||
|
if len(fields) < expected {
|
||||||
|
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
||||||
|
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
|
||||||
|
}
|
||||||
|
consumed = expected
|
||||||
|
default:
|
||||||
|
const expected = 2
|
||||||
|
if len(fields) < expected {
|
||||||
|
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
|
||||||
|
ErrIptablesCommandMalformed, flag)
|
||||||
|
}
|
||||||
|
consumed = expected
|
||||||
|
}
|
||||||
|
value := fields[1]
|
||||||
|
|
||||||
|
switch flag {
|
||||||
case "-t", "--table":
|
case "-t", "--table":
|
||||||
instruction.table = value
|
instruction.table = value
|
||||||
case "-D", "--delete":
|
case "-D", "--delete":
|
||||||
instruction.append = false
|
instruction.operation = opDelete
|
||||||
instruction.chain = value
|
instruction.chain = value
|
||||||
case "-A", "--append":
|
case "-A", "--append":
|
||||||
instruction.append = true
|
instruction.operation = opAppend
|
||||||
instruction.chain = value
|
instruction.chain = value
|
||||||
|
case "-I", "--insert":
|
||||||
|
instruction.operation = opInsert
|
||||||
|
instruction.chain = value
|
||||||
|
case "-R", "--replace":
|
||||||
|
instruction.operation = opReplace
|
||||||
|
instruction.chain = value
|
||||||
|
const base, bits = 10, 16
|
||||||
|
n, err := strconv.ParseUint(fields[2], base, bits)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("parsing line number for --replace operation: %w", err)
|
||||||
|
}
|
||||||
|
instruction.lineNumber = uint16(n)
|
||||||
case "-j", "--jump":
|
case "-j", "--jump":
|
||||||
instruction.target = value
|
instruction.target = value
|
||||||
case "-p", "--protocol":
|
case "-p", "--protocol":
|
||||||
@@ -117,18 +209,18 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (
|
|||||||
case "-s", "--source":
|
case "-s", "--source":
|
||||||
instruction.source, err = parseIPPrefix(value)
|
instruction.source, err = parseIPPrefix(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing source IP CIDR: %w", err)
|
return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||||
}
|
}
|
||||||
case "-d", "--destination":
|
case "-d", "--destination":
|
||||||
instruction.destination, err = parseIPPrefix(value)
|
instruction.destination, err = parseIPPrefix(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing destination IP CIDR: %w", err)
|
return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||||
}
|
}
|
||||||
case "--dport":
|
case "--dport":
|
||||||
const base, bitLength = 10, 16
|
const base, bitLength = 10, 16
|
||||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing destination port: %w", err)
|
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||||
}
|
}
|
||||||
instruction.destinationPort = uint16(destinationPort)
|
instruction.destinationPort = uint16(destinationPort)
|
||||||
case "--ctstate":
|
case "--ctstate":
|
||||||
@@ -140,14 +232,14 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (
|
|||||||
const base, bitLength = 10, 16
|
const base, bitLength = 10, 16
|
||||||
port, err := strconv.ParseUint(portString, base, bitLength)
|
port, err := strconv.ParseUint(portString, base, bitLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing port redirection: %w", err)
|
return 0, fmt.Errorf("parsing port redirection: %w", err)
|
||||||
}
|
}
|
||||||
instruction.toPorts[i] = uint16(port)
|
instruction.toPorts[i] = uint16(port)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key)
|
return 0, fmt.Errorf("%w: unknown flag %q", ErrIptablesCommandMalformed, flag)
|
||||||
}
|
}
|
||||||
return nil
|
return consumed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
||||||
|
|||||||
@@ -23,19 +23,19 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
|||||||
"uneven_fields": {
|
"uneven_fields": {
|
||||||
s: "-A",
|
s: "-A",
|
||||||
errWrapped: ErrIptablesCommandMalformed,
|
errWrapped: ErrIptablesCommandMalformed,
|
||||||
errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"",
|
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
|
||||||
},
|
},
|
||||||
"unknown_key": {
|
"unknown_key": {
|
||||||
s: "-x something",
|
s: "-x something",
|
||||||
errWrapped: ErrIptablesCommandMalformed,
|
errWrapped: ErrIptablesCommandMalformed,
|
||||||
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
|
errMessage: "parsing \"-x something\": iptables command is malformed: unknown flag \"-x\"",
|
||||||
},
|
},
|
||||||
"one_pair": {
|
"one_pair": {
|
||||||
s: "-A INPUT",
|
s: "-I INPUT",
|
||||||
instruction: iptablesInstruction{
|
instruction: iptablesInstruction{
|
||||||
table: "filter",
|
table: "filter",
|
||||||
chain: "INPUT",
|
chain: "INPUT",
|
||||||
append: true,
|
operation: opInsert,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"instruction_A": {
|
"instruction_A": {
|
||||||
@@ -43,7 +43,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
|||||||
instruction: iptablesInstruction{
|
instruction: iptablesInstruction{
|
||||||
table: "filter",
|
table: "filter",
|
||||||
chain: "INPUT",
|
chain: "INPUT",
|
||||||
append: true,
|
operation: opAppend,
|
||||||
inputInterface: "tun0",
|
inputInterface: "tun0",
|
||||||
protocol: "tcp",
|
protocol: "tcp",
|
||||||
source: netip.MustParsePrefix("1.2.3.4/32"),
|
source: netip.MustParsePrefix("1.2.3.4/32"),
|
||||||
@@ -57,7 +57,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
|||||||
instruction: iptablesInstruction{
|
instruction: iptablesInstruction{
|
||||||
table: "nat",
|
table: "nat",
|
||||||
chain: "PREROUTING",
|
chain: "PREROUTING",
|
||||||
append: false,
|
operation: opDelete,
|
||||||
inputInterface: "tun0",
|
inputInterface: "tun0",
|
||||||
protocol: "tcp",
|
protocol: "tcp",
|
||||||
destinationPort: 43716,
|
destinationPort: 43716,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package firewall
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -81,3 +82,133 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RestrictOutputAddrPort allows outgoing traffic to a specific IP and port for both tcp and udp,
|
||||||
|
// while blocking other tcp or udp traffic to that port going to other IP addresses, both IPv4 and IPv6.
|
||||||
|
// If the port was previously allowed for another IP address, that previous allowance will be removed.
|
||||||
|
// Giving an invalid address will remove any existing restrictions for the port specified.
|
||||||
|
func (c *Config) RestrictOutputAddrPort(ctx context.Context, addrPort netip.AddrPort) (err error) {
|
||||||
|
c.stateMutex.Lock()
|
||||||
|
defer c.stateMutex.Unlock()
|
||||||
|
existingIP := c.outputAddrPort[addrPort.Port()]
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case existingIP == addrPort.Addr():
|
||||||
|
return nil
|
||||||
|
case !addrPort.Addr().IsValid():
|
||||||
|
// invalid address, remove any existing rules for the port
|
||||||
|
return c.removeOutputAddrPortRestriction(ctx, existingIP, addrPort.Port())
|
||||||
|
case !existingIP.IsValid():
|
||||||
|
// no previous existing address for the port
|
||||||
|
return c.insertOutputAddrPortRestriction(ctx, addrPort)
|
||||||
|
default:
|
||||||
|
// existing rule in the same IP family or different family
|
||||||
|
return c.replaceOutputAddrPortRestriction(ctx, existingIP, addrPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) removeOutputAddrPortRestriction(ctx context.Context, existingIP netip.Addr, port uint16) (err error) {
|
||||||
|
commonInstructions := []string{
|
||||||
|
fmt.Sprintf("--delete OUTPUT -p udp --dport %d -j DROP", port),
|
||||||
|
fmt.Sprintf("--delete OUTPUT -p tcp --dport %d -j DROP", port),
|
||||||
|
}
|
||||||
|
ipv4Instructions := commonInstructions
|
||||||
|
ipv6Instructions := commonInstructions
|
||||||
|
|
||||||
|
familySpecificInstructions := []string{
|
||||||
|
fmt.Sprintf("--delete OUTPUT -p udp --dport %d -d %s -j ACCEPT", port, existingIP),
|
||||||
|
fmt.Sprintf("--delete OUTPUT -p tcp --dport %d -d %s -j ACCEPT", port, existingIP),
|
||||||
|
}
|
||||||
|
if existingIP.Is4() {
|
||||||
|
ipv4Instructions = append(ipv4Instructions, familySpecificInstructions...)
|
||||||
|
} else {
|
||||||
|
ipv6Instructions = append(ipv6Instructions, familySpecificInstructions...)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.runIPv4AndV6IptablesInstructions(ctx, ipv4Instructions, ipv6Instructions)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
delete(c.outputAddrPort, port)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) insertOutputAddrPortRestriction(ctx context.Context, addrPort netip.AddrPort) (err error) {
|
||||||
|
commonInstructions := []string{
|
||||||
|
fmt.Sprintf("--insert OUTPUT -p udp --dport %d -j DROP", addrPort.Port()),
|
||||||
|
fmt.Sprintf("--insert OUTPUT -p tcp --dport %d -j DROP", addrPort.Port()),
|
||||||
|
}
|
||||||
|
ipv4Instructions := commonInstructions
|
||||||
|
ipv6Instructions := commonInstructions
|
||||||
|
|
||||||
|
familySpecificInstructions := []string{
|
||||||
|
fmt.Sprintf("--insert OUTPUT -p udp --dport %d -d %s -j ACCEPT", addrPort.Port(), addrPort.Addr()),
|
||||||
|
fmt.Sprintf("--insert OUTPUT -p tcp --dport %d -d %s -j ACCEPT", addrPort.Port(), addrPort.Addr()),
|
||||||
|
}
|
||||||
|
if addrPort.Addr().Is4() {
|
||||||
|
ipv4Instructions = append(ipv4Instructions, familySpecificInstructions...)
|
||||||
|
} else {
|
||||||
|
ipv6Instructions = append(ipv6Instructions, familySpecificInstructions...)
|
||||||
|
}
|
||||||
|
err = c.runIPv4AndV6IptablesInstructions(ctx, ipv4Instructions, ipv6Instructions)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.outputAddrPort[addrPort.Port()] = addrPort.Addr()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) replaceOutputAddrPortRestriction(ctx context.Context,
|
||||||
|
existingIP netip.Addr, addrPort netip.AddrPort,
|
||||||
|
) (err error) {
|
||||||
|
for _, protocol := range [...]string{"udp", "tcp"} {
|
||||||
|
switch {
|
||||||
|
case existingIP.Is4() && addrPort.Addr().Is4():
|
||||||
|
oldInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||||
|
protocol, addrPort.Port(), existingIP)
|
||||||
|
newInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||||
|
protocol, addrPort.Port(), addrPort.Addr())
|
||||||
|
err = c.replaceIptablesRule(ctx, oldInstruction, newInstruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("replacing existing IPv4 rule: %w", err)
|
||||||
|
}
|
||||||
|
case existingIP.Is6() && addrPort.Addr().Is6():
|
||||||
|
oldInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||||
|
protocol, addrPort.Port(), existingIP)
|
||||||
|
newInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||||
|
protocol, addrPort.Port(), addrPort.Addr())
|
||||||
|
err = c.replaceIP6tablesRule(ctx, oldInstruction, newInstruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("replacing existing IPv6 rule: %w", err)
|
||||||
|
}
|
||||||
|
case existingIP.Is4() && addrPort.Addr().Is6():
|
||||||
|
instruction := fmt.Sprintf("--delete OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||||
|
protocol, addrPort.Port(), existingIP)
|
||||||
|
err = c.runIptablesInstruction(ctx, instruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("removing existing IPv4 rule: %w", err)
|
||||||
|
}
|
||||||
|
instruction = fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||||
|
protocol, addrPort.Port(), addrPort.Addr())
|
||||||
|
err = c.runIP6tablesInstruction(ctx, instruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("inserting new IPv6 rule: %w", err)
|
||||||
|
}
|
||||||
|
case existingIP.Is6() && addrPort.Addr().Is4():
|
||||||
|
instruction := fmt.Sprintf("--delete OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||||
|
protocol, addrPort.Port(), existingIP)
|
||||||
|
err = c.runIP6tablesInstruction(ctx, instruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("removing existing IPv6 rule: %w", err)
|
||||||
|
}
|
||||||
|
instruction = fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
|
||||||
|
protocol, addrPort.Port(), addrPort.Addr())
|
||||||
|
err = c.runIptablesInstruction(ctx, instruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("inserting new IPv4 rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.outputAddrPort[addrPort.Port()] = addrPort.Addr()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,51 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errRuleNotFound = errors.New("rule not found")
|
||||||
|
|
||||||
|
func (c *Config) replaceIptablesRule(ctx context.Context, oldInstruction, newInstruction string) error {
|
||||||
|
targetRule, err := parseIptablesInstruction(oldInstruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing iptables command to replace: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lineNumber, err := findLineNumber(ctx, c.ipTables, targetRule, c.runner, c.logger)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("finding to-be-replaced chain rule line number: %w", err)
|
||||||
|
} else if lineNumber == 0 {
|
||||||
|
return fmt.Errorf("%w: matching to-be-replaced instruction %q", errRuleNotFound, oldInstruction)
|
||||||
|
}
|
||||||
|
parsed, err := parseIptablesInstruction(newInstruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing replacement iptables command: %w", err)
|
||||||
|
}
|
||||||
|
parsed.operation = opReplace
|
||||||
|
parsed.lineNumber = lineNumber
|
||||||
|
return c.runIptablesInstruction(ctx, parsed.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) replaceIP6tablesRule(ctx context.Context, oldInstruction, newInstruction string) error {
|
||||||
|
targetRule, err := parseIptablesInstruction(oldInstruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing iptables command to replace: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lineNumber, err := findLineNumber(ctx, c.ip6Tables, targetRule, c.runner, c.logger)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("finding to-be-replaced chain rule line number: %w", err)
|
||||||
|
} else if lineNumber == 0 {
|
||||||
|
return fmt.Errorf("%w: matching to-be-replaced instruction %q", errRuleNotFound, oldInstruction)
|
||||||
|
}
|
||||||
|
parsed, err := parseIptablesInstruction(newInstruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing replacement iptables command: %w", err)
|
||||||
|
}
|
||||||
|
parsed.operation = opReplace
|
||||||
|
parsed.lineNumber = lineNumber
|
||||||
|
return c.runIP6tablesInstruction(ctx, parsed.String())
|
||||||
|
}
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
package mod
|
package mod
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package mod
|
||||||
|
|
||||||
|
func Probe(moduleName string) error {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
+59
-17
@@ -1,33 +1,75 @@
|
|||||||
//go:build linux || darwin
|
|
||||||
|
|
||||||
package netlink
|
package netlink
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/vishvananda/netlink"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/jsimonetti/rtnetlink/rtnl"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (n *NetLink) AddrList(link Link, family int) (
|
func (n *NetLink) AddrList(linkIndex uint32, family uint8) (
|
||||||
addresses []Addr, err error,
|
ipPrefixes []netip.Prefix, err error,
|
||||||
) {
|
) {
|
||||||
netlinkLink := linkToNetlinkLink(&link)
|
conn, err := rtnl.Dial(nil)
|
||||||
netlinkAddresses, err := netlink.AddrList(netlinkLink, family)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("dialing netlink: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
ifc := &net.Interface{
|
||||||
|
Index: int(linkIndex),
|
||||||
|
}
|
||||||
|
ipNets, err := conn.Addrs(ifc, int(family))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list addresses: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
addresses = make([]Addr, len(netlinkAddresses))
|
ipPrefixes = make([]netip.Prefix, len(ipNets))
|
||||||
for i := range netlinkAddresses {
|
for i := range ipNets {
|
||||||
addresses[i].Network = netIPNetToNetipPrefix(netlinkAddresses[i].IPNet)
|
ipPrefixes[i] = netIPNetToNetipPrefix(ipNets[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
return addresses, nil
|
return ipPrefixes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) AddrReplace(link Link, addr Addr) error {
|
func (n *NetLink) AddrReplace(linkIndex uint32, prefix netip.Prefix) error {
|
||||||
netlinkLink := linkToNetlinkLink(&link)
|
conn, err := rtnl.Dial(nil)
|
||||||
netlinkAddress := netlink.Addr{
|
if err != nil {
|
||||||
IPNet: netipPrefixToIPNet(addr.Network),
|
return fmt.Errorf("dialing netlink: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
ipNet := netipPrefixToIPNet(prefix)
|
||||||
|
|
||||||
|
// Remove any address identical to the one we want to add
|
||||||
|
family := FamilyV4
|
||||||
|
if prefix.Addr().Is6() {
|
||||||
|
family = FamilyV6
|
||||||
|
}
|
||||||
|
ifc := &net.Interface{
|
||||||
|
Index: int(linkIndex),
|
||||||
|
}
|
||||||
|
addresses, err := conn.Addrs(ifc, int(family))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("listing addresses: %w", err)
|
||||||
|
}
|
||||||
|
for _, address := range addresses {
|
||||||
|
if address.IP.Equal(ipNet.IP) &&
|
||||||
|
net.IP(address.Mask).String() == net.IP(ipNet.Mask).String() {
|
||||||
|
err = conn.AddrDel(ifc, address)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("deleting address from interface: %w", err)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return netlink.AddrReplace(netlinkLink, &netlinkAddress)
|
// Add the new address to the interface
|
||||||
|
err = conn.AddrAdd(ifc, ipNet)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("adding address to interface: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
//go:build !linux && !darwin
|
|
||||||
|
|
||||||
package netlink
|
|
||||||
|
|
||||||
func (n *NetLink) AddrList(link Link, family int) (
|
|
||||||
addresses []Addr, err error,
|
|
||||||
) {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *NetLink) AddrReplace(Link, Addr) error {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
@@ -36,6 +36,30 @@ func netIPNetToNetipPrefix(ipNet *net.IPNet) (prefix netip.Prefix) {
|
|||||||
return netip.PrefixFrom(ip, bits)
|
return netip.PrefixFrom(ip, bits)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ipAndLengthToPrefix(ip *net.IP, length uint8) netip.Prefix {
|
||||||
|
if ip == nil || len(*ip) == 0 {
|
||||||
|
return netip.Prefix{}
|
||||||
|
}
|
||||||
|
var dstIP netip.Addr
|
||||||
|
if ipv4 := ip.To4(); ipv4 != nil { // IPv6
|
||||||
|
dstIP = netip.AddrFrom4([4]byte(*ip))
|
||||||
|
} else {
|
||||||
|
dstIP = netip.AddrFrom16([16]byte(*ip))
|
||||||
|
}
|
||||||
|
return netip.PrefixFrom(dstIP, int(length))
|
||||||
|
}
|
||||||
|
|
||||||
|
func prefixToIPAndLength(prefix netip.Prefix) (ip *net.IP, length uint8) {
|
||||||
|
if !prefix.IsValid() {
|
||||||
|
return nil, 0
|
||||||
|
}
|
||||||
|
prefixIP := prefix.Addr().Unmap()
|
||||||
|
ip = new(net.IP)
|
||||||
|
*ip = netipAddrToNetIP(prefixIP)
|
||||||
|
length = uint8(prefix.Bits()) //nolint:gosec
|
||||||
|
return ip, length
|
||||||
|
}
|
||||||
|
|
||||||
func netipAddrToNetIP(address netip.Addr) (ip net.IP) {
|
func netipAddrToNetIP(address netip.Addr) (ip net.IP) {
|
||||||
switch {
|
switch {
|
||||||
case !address.IsValid():
|
case !address.IsValid():
|
||||||
|
|||||||
@@ -4,13 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
func FamilyToString(family uint8) string {
|
||||||
FamilyAll = 0
|
|
||||||
FamilyV4 = 2
|
|
||||||
FamilyV6 = 10
|
|
||||||
)
|
|
||||||
|
|
||||||
func FamilyToString(family int) string {
|
|
||||||
switch family {
|
switch family {
|
||||||
case FamilyAll:
|
case FamilyAll:
|
||||||
return "all"
|
return "all"
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
package netlink
|
||||||
|
|
||||||
|
import "golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
const (
|
||||||
|
FamilyAll uint8 = unix.AF_UNSPEC
|
||||||
|
FamilyV4 uint8 = unix.AF_INET
|
||||||
|
FamilyV6 uint8 = unix.AF_INET6
|
||||||
|
)
|
||||||
@@ -1,16 +1,30 @@
|
|||||||
package netlink
|
package netlink
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"math/rand/v2"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/qdm12/log"
|
"github.com/qdm12/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func ptrTo[T any](v T) *T { return &v }
|
||||||
|
|
||||||
func makeNetipPrefix(n byte) netip.Prefix {
|
func makeNetipPrefix(n byte) netip.Prefix {
|
||||||
const bits = 24
|
const bits = 24
|
||||||
return netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits)
|
return netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var rng = rand.New(rand.NewChaCha8([32]byte{})) //nolint:gosec,gochecknoglobals
|
||||||
|
|
||||||
|
func makeLinkName() string {
|
||||||
|
const alphabet = "abcdefghijklmnopqrstuvwxyz"
|
||||||
|
name := make([]byte, 8)
|
||||||
|
for i := range name {
|
||||||
|
name[i] = alphabet[rng.IntN(len(alphabet))]
|
||||||
|
}
|
||||||
|
return "test" + string(name)
|
||||||
|
}
|
||||||
|
|
||||||
type noopLogger struct{}
|
type noopLogger struct{}
|
||||||
|
|
||||||
func (l *noopLogger) Debug(_ string) {}
|
func (l *noopLogger) Debug(_ string) {}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ func (n *NetLink) IsIPv6Supported() (supported bool, err error) {
|
|||||||
return false, fmt.Errorf("finding link corresponding to route: %w", err)
|
return false, fmt.Errorf("finding link corresponding to route: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sourceIsIPv6 := route.Src.IsValid() && route.Src.Is6()
|
sourceIsIPv6 := route.Src.Addr().IsValid() && route.Src.Addr().Is6()
|
||||||
destinationIsIPv6 := route.Dst.IsValid() && route.Dst.Addr().Is6()
|
destinationIsIPv6 := route.Dst.IsValid() && route.Dst.Addr().Is6()
|
||||||
switch {
|
switch {
|
||||||
case !sourceIsIPv6 && !destinationIsIPv6,
|
case !sourceIsIPv6 && !destinationIsIPv6,
|
||||||
|
|||||||
+163
-77
@@ -1,105 +1,191 @@
|
|||||||
//go:build linux || darwin
|
|
||||||
|
|
||||||
package netlink
|
package netlink
|
||||||
|
|
||||||
import "github.com/vishvananda/netlink"
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
func (n *NetLink) LinkList() (links []Link, err error) {
|
"github.com/jsimonetti/rtnetlink"
|
||||||
netlinkLinks, err := netlink.LinkList()
|
)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
type DeviceType uint16
|
||||||
|
|
||||||
|
type Link struct {
|
||||||
|
Index uint32
|
||||||
|
Name string
|
||||||
|
DeviceType DeviceType
|
||||||
|
VirtualType string
|
||||||
|
MTU uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
links = make([]Link, len(netlinkLinks))
|
func (n *NetLink) LinkList() (links []Link, err error) {
|
||||||
for i := range netlinkLinks {
|
conn, err := rtnetlink.Dial(nil)
|
||||||
links[i] = netlinkLinkToLink(netlinkLinks[i])
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dialing netlink: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
linkMessages, err := conn.Link.List()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("listing interfaces: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
links = make([]Link, len(linkMessages))
|
||||||
|
for i, message := range linkMessages {
|
||||||
|
virtualType := ""
|
||||||
|
if message.Attributes.Info != nil {
|
||||||
|
virtualType = message.Attributes.Info.Kind
|
||||||
|
}
|
||||||
|
links[i] = Link{
|
||||||
|
Index: message.Index,
|
||||||
|
Name: message.Attributes.Name,
|
||||||
|
DeviceType: DeviceType(message.Type),
|
||||||
|
VirtualType: virtualType,
|
||||||
|
MTU: message.Attributes.MTU,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return links, nil
|
return links, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ErrLinkNotFound = errors.New("link not found")
|
||||||
|
|
||||||
func (n *NetLink) LinkByName(name string) (link Link, err error) {
|
func (n *NetLink) LinkByName(name string) (link Link, err error) {
|
||||||
netlinkLink, err := netlink.LinkByName(name)
|
links, err := n.LinkList()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Link{}, err
|
return Link{}, fmt.Errorf("listing links: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return netlinkLinkToLink(netlinkLink), nil
|
for _, link := range links {
|
||||||
|
if link.Name == name {
|
||||||
|
return link, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) LinkByIndex(index int) (link Link, err error) {
|
return Link{}, fmt.Errorf("%w: for name %s", ErrLinkNotFound, name)
|
||||||
netlinkLink, err := netlink.LinkByIndex(index)
|
}
|
||||||
|
|
||||||
|
func (n *NetLink) LinkByIndex(index uint32) (link Link, err error) {
|
||||||
|
links, err := n.LinkList()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Link{}, err
|
return Link{}, fmt.Errorf("listing links: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return netlinkLinkToLink(netlinkLink), nil
|
for _, link = range links {
|
||||||
|
if link.Index == index {
|
||||||
|
return link, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) LinkAdd(link Link) (linkIndex int, err error) {
|
return Link{}, fmt.Errorf("%w: for index %d", ErrLinkNotFound, index)
|
||||||
netlinkLink := linkToNetlinkLink(&link)
|
}
|
||||||
err = netlink.LinkAdd(netlinkLink)
|
|
||||||
|
func (n *NetLink) LinkAdd(link Link) (linkIndex uint32, err error) {
|
||||||
|
conn, err := rtnetlink.Dial(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, fmt.Errorf("dialing netlink: %w", err)
|
||||||
}
|
|
||||||
return netlinkLink.Attrs().Index, nil
|
|
||||||
}
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
func (n *NetLink) LinkDel(link Link) (err error) {
|
tx := &rtnetlink.LinkMessage{
|
||||||
return netlink.LinkDel(linkToNetlinkLink(&link))
|
Type: uint16(link.DeviceType),
|
||||||
}
|
Attributes: &rtnetlink.LinkAttributes{
|
||||||
|
MTU: link.MTU,
|
||||||
func (n *NetLink) LinkSetUp(link Link) (linkIndex int, err error) {
|
|
||||||
netlinkLink := linkToNetlinkLink(&link)
|
|
||||||
err = netlink.LinkSetUp(netlinkLink)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return netlinkLink.Attrs().Index, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *NetLink) LinkSetDown(link Link) (err error) {
|
|
||||||
return netlink.LinkSetDown(linkToNetlinkLink(&link))
|
|
||||||
}
|
|
||||||
|
|
||||||
type netlinkLinkImpl struct {
|
|
||||||
attrs *netlink.LinkAttrs
|
|
||||||
linkType string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *netlinkLinkImpl) Attrs() *netlink.LinkAttrs {
|
|
||||||
return n.attrs
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *netlinkLinkImpl) Type() string {
|
|
||||||
return n.linkType
|
|
||||||
}
|
|
||||||
|
|
||||||
func netlinkLinkToLink(netlinkLink netlink.Link) Link {
|
|
||||||
attributes := netlinkLink.Attrs()
|
|
||||||
return Link{
|
|
||||||
Type: netlinkLink.Type(),
|
|
||||||
Name: attributes.Name,
|
|
||||||
Index: attributes.Index,
|
|
||||||
EncapType: attributes.EncapType,
|
|
||||||
MTU: uint16(attributes.MTU), //nolint:gosec
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Warning: we must return `netlink.Link` and not `netlinkLinkImpl`
|
|
||||||
// so that the vishvananda/netlink package can compare the returned
|
|
||||||
// value against an untyped nil.
|
|
||||||
func linkToNetlinkLink(link *Link) netlink.Link {
|
|
||||||
if link == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &netlinkLinkImpl{
|
|
||||||
linkType: link.Type,
|
|
||||||
attrs: &netlink.LinkAttrs{
|
|
||||||
Name: link.Name,
|
Name: link.Name,
|
||||||
Index: link.Index,
|
|
||||||
EncapType: link.EncapType,
|
|
||||||
MTU: int(link.MTU),
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
if link.VirtualType != "" {
|
||||||
|
tx.Attributes.Info = &rtnetlink.LinkInfo{
|
||||||
|
Kind: link.VirtualType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = conn.Link.New(tx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("creating new link: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
linkMessages, err := conn.Link.List()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("listing links: %w", err)
|
||||||
|
}
|
||||||
|
for _, linkMessage := range linkMessages {
|
||||||
|
if linkMessage.Attributes.Name == link.Name {
|
||||||
|
return linkMessage.Index, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, fmt.Errorf("%w: matching name %s", ErrLinkNotFound, link.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetLink) LinkDel(linkIndex uint32) (err error) {
|
||||||
|
conn, err := rtnetlink.Dial(nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("dialing netlink: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
return conn.Link.Delete(linkIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetLink) LinkSetUp(linkIndex uint32) (err error) {
|
||||||
|
conn, err := rtnetlink.Dial(nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("dialing netlink: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
rx, err := conn.Link.Get(linkIndex)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("getting link: %w", err)
|
||||||
|
}
|
||||||
|
tx := &rtnetlink.LinkMessage{
|
||||||
|
Type: rx.Type,
|
||||||
|
Index: linkIndex,
|
||||||
|
Flags: iffUp,
|
||||||
|
Change: iffUp,
|
||||||
|
}
|
||||||
|
return conn.Link.Set(tx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetLink) LinkSetDown(linkIndex uint32) (err error) {
|
||||||
|
conn, err := rtnetlink.Dial(nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("dialing netlink: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
linkInfo, err := conn.Link.Get(linkIndex)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("getting link: %w", err)
|
||||||
|
}
|
||||||
|
message := &rtnetlink.LinkMessage{
|
||||||
|
Type: linkInfo.Type,
|
||||||
|
Index: linkIndex,
|
||||||
|
Flags: 0,
|
||||||
|
Change: iffUp,
|
||||||
|
}
|
||||||
|
return conn.Link.Set(message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetLink) LinkSetMTU(linkIndex, mtu uint32) error {
|
||||||
|
conn, err := rtnetlink.Dial(nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("dialing netlink: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
message := &rtnetlink.LinkMessage{
|
||||||
|
Index: linkIndex,
|
||||||
|
Attributes: &rtnetlink.LinkAttributes{
|
||||||
|
MTU: mtu,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = conn.Link.Set(message)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("setting MTU to %d for link at index %d: %w",
|
||||||
|
mtu, linkIndex, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package netlink
|
||||||
|
|
||||||
|
import "golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
const (
|
||||||
|
DeviceTypeEthernet DeviceType = unix.ARPHRD_ETHER
|
||||||
|
DeviceTypeLoopback DeviceType = unix.ARPHRD_LOOPBACK
|
||||||
|
DeviceTypeNone DeviceType = unix.ARPHRD_NONE
|
||||||
|
|
||||||
|
iffUp = unix.IFF_UP
|
||||||
|
)
|
||||||
@@ -0,0 +1,85 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package netlink
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_NetLink_LinkList(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
netlink := &NetLink{}
|
||||||
|
|
||||||
|
initialLinks, err := netlink.LinkList()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, initialLinks)
|
||||||
|
|
||||||
|
loopbackFound := false
|
||||||
|
for _, link := range initialLinks {
|
||||||
|
if link.Name != "lo" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
loopbackFound = true
|
||||||
|
assert.Equal(t, DeviceTypeLoopback, link.DeviceType)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
assert.True(t, loopbackFound, "loopback interface not found")
|
||||||
|
|
||||||
|
testLink := Link{
|
||||||
|
Name: makeLinkName(),
|
||||||
|
// note if [Link.VirtualType] is set, [Link.DeviceType]
|
||||||
|
// is ignored and gets set to [DeviceTypeNone] in LinkAdd.
|
||||||
|
DeviceType: DeviceTypeNone,
|
||||||
|
VirtualType: "wireguard",
|
||||||
|
MTU: 1420,
|
||||||
|
}
|
||||||
|
index, err := netlink.LinkAdd(testLink)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = netlink.LinkDel(index)
|
||||||
|
})
|
||||||
|
|
||||||
|
links, err := netlink.LinkList()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testLink.Index = index
|
||||||
|
for _, link := range links {
|
||||||
|
if link.Name != testLink.Name {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
assert.Equal(t, testLink, link)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Errorf("created link %q not found", testLink.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_NetLink_LinkSetMTU(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
netlink := &NetLink{}
|
||||||
|
|
||||||
|
testLink := Link{
|
||||||
|
Name: makeLinkName(),
|
||||||
|
DeviceType: DeviceTypeNone,
|
||||||
|
VirtualType: "wireguard",
|
||||||
|
MTU: 1420,
|
||||||
|
}
|
||||||
|
index, err := netlink.LinkAdd(testLink)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = netlink.LinkDel(index)
|
||||||
|
})
|
||||||
|
testLink.Index = index
|
||||||
|
|
||||||
|
err = netlink.LinkSetMTU(index, 1500)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
link, err := netlink.LinkByIndex(index)
|
||||||
|
require.NoError(t, err)
|
||||||
|
testLink.MTU = 1500
|
||||||
|
assert.Equal(t, testLink, link)
|
||||||
|
}
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
//go:build !linux && !darwin
|
|
||||||
|
|
||||||
package netlink
|
|
||||||
|
|
||||||
func (n *NetLink) LinkList() (links []Link, err error) {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *NetLink) LinkByName(name string) (link Link, err error) {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *NetLink) LinkByIndex(index int) (link Link, err error) {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *NetLink) LinkAdd(link Link) (linkIndex int, err error) {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *NetLink) LinkDel(link Link) (err error) {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *NetLink) LinkSetUp(link Link) (linkIndex int, err error) {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *NetLink) LinkSetDown(link Link) (err error) {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package netlink
|
||||||
|
|
||||||
|
const (
|
||||||
|
// FamilyAll is a placeholder only and should not
|
||||||
|
// be used.
|
||||||
|
FamilyAll uint8 = iota
|
||||||
|
// FamilyV4 is a placeholder only and should not
|
||||||
|
// be used.
|
||||||
|
FamilyV4
|
||||||
|
// FamilyV6 is a placeholder only and should not
|
||||||
|
// be used.
|
||||||
|
FamilyV6
|
||||||
|
|
||||||
|
// DeviceTypeEthernet is a placeholder only and should not be used.
|
||||||
|
DeviceTypeEthernet DeviceType = 0
|
||||||
|
// DeviceTypeLoopback is a placeholder only and should not be used.
|
||||||
|
DeviceTypeLoopback DeviceType = 0
|
||||||
|
// DeviceTypeNone is a placeholder only and should not be used.
|
||||||
|
DeviceTypeNone DeviceType = 0
|
||||||
|
|
||||||
|
// iffUp is a placeholder only and should not be used.
|
||||||
|
iffUp = 0
|
||||||
|
|
||||||
|
// RouteTypeUnicast is a placeholder only and should not be used.
|
||||||
|
RouteTypeUnicast = 0
|
||||||
|
// ScopeUniverse is a placeholder only and should not be used.
|
||||||
|
ScopeUniverse = 0
|
||||||
|
// ProtoStatic is a placeholder only and should not be used.
|
||||||
|
ProtoStatic = 0
|
||||||
|
|
||||||
|
// FlagInvert is a placeholder only and should not be used.
|
||||||
|
FlagInvert = 0
|
||||||
|
// ActionToTable is a placeholder only and should not be used.
|
||||||
|
ActionToTable = 0
|
||||||
|
|
||||||
|
// rtTableCompat is a placeholder only and should not be used.
|
||||||
|
rtTableCompat = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
func (n *NetLink) RuleList(family uint8) (rules []Rule, err error) {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetLink) RuleAdd(rule Rule) error {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetLink) RuleDel(rule Rule) error {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetLink) IsWireguardSupported() (bool, error) {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
+104
-48
@@ -1,69 +1,125 @@
|
|||||||
//go:build linux || darwin
|
|
||||||
|
|
||||||
package netlink
|
package netlink
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/vishvananda/netlink"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/jsimonetti/rtnetlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (n *NetLink) RouteList(family int) (routes []Route, err error) {
|
type Route struct {
|
||||||
// We set the filter to netlink.RT_FILTER_TABLE so that
|
LinkIndex uint32
|
||||||
// routes from all tables are listed, as long as the filter
|
Dst netip.Prefix
|
||||||
// table is set to 0.
|
Src netip.Prefix
|
||||||
const filterMask = netlink.RT_FILTER_TABLE
|
Gw netip.Addr
|
||||||
// The filter is not left to `nil` otherwise non-main tables
|
Priority uint32
|
||||||
// are ignored.
|
Family uint8
|
||||||
filter := &netlink.Route{}
|
Table uint32
|
||||||
|
Type uint8
|
||||||
netlinkRoutes, err := netlink.RouteListFiltered(family, filter, filterMask)
|
Scope uint8
|
||||||
if err != nil {
|
Proto uint8
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
routes = make([]Route, len(netlinkRoutes))
|
func (r *Route) fromMessage(message rtnetlink.RouteMessage) {
|
||||||
for i := range netlinkRoutes {
|
table := uint32(message.Table)
|
||||||
routes[i] = netlinkRouteToRoute(netlinkRoutes[i])
|
if table == 0 || table == rtTableCompat {
|
||||||
|
table = message.Attributes.Table
|
||||||
|
}
|
||||||
|
r.LinkIndex = message.Attributes.OutIface
|
||||||
|
r.Dst = ipAndLengthToPrefix(&message.Attributes.Dst, message.DstLength)
|
||||||
|
r.Src = ipAndLengthToPrefix(&message.Attributes.Src, message.SrcLength)
|
||||||
|
r.Gw = netIPToNetipAddress(message.Attributes.Gateway)
|
||||||
|
r.Priority = message.Attributes.Priority
|
||||||
|
r.Family = message.Family
|
||||||
|
r.Table = table
|
||||||
|
r.Type = message.Type
|
||||||
|
r.Scope = message.Scope
|
||||||
|
r.Proto = message.Protocol
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Route) message() *rtnetlink.RouteMessage {
|
||||||
|
dst, dstLength := prefixToIPAndLength(r.Dst)
|
||||||
|
src, srcLength := prefixToIPAndLength(r.Src)
|
||||||
|
var table uint8
|
||||||
|
var extendedTable uint32
|
||||||
|
if r.Table <= uint32(^uint8(0)) {
|
||||||
|
table = uint8(r.Table)
|
||||||
|
} else {
|
||||||
|
table = rtTableCompat
|
||||||
|
extendedTable = r.Table
|
||||||
|
}
|
||||||
|
message := &rtnetlink.RouteMessage{
|
||||||
|
Family: r.Family,
|
||||||
|
DstLength: dstLength,
|
||||||
|
SrcLength: srcLength,
|
||||||
|
Table: table,
|
||||||
|
Type: r.Type,
|
||||||
|
Scope: r.Scope,
|
||||||
|
Protocol: r.Proto,
|
||||||
|
Attributes: rtnetlink.RouteAttributes{
|
||||||
|
OutIface: r.LinkIndex,
|
||||||
|
Dst: *dst, // there should always be a dst for routes
|
||||||
|
Gateway: netipAddrToNetIP(r.Gw),
|
||||||
|
Priority: r.Priority,
|
||||||
|
Table: extendedTable,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if src != nil { // src is optional
|
||||||
|
message.Attributes.Src = *src
|
||||||
|
}
|
||||||
|
return message
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetLink) RouteList(family uint8) (routes []Route, err error) {
|
||||||
|
conn, err := rtnetlink.Dial(nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dialing netlink: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
routeMessages, err := conn.Route.List()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("listing interfaces: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
routes = make([]Route, 0, len(routeMessages))
|
||||||
|
for _, routeMessage := range routeMessages {
|
||||||
|
if family != FamilyAll && routeMessage.Family != family {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var route Route
|
||||||
|
route.fromMessage(routeMessage)
|
||||||
|
routes = append(routes, route)
|
||||||
}
|
}
|
||||||
return routes, nil
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) RouteAdd(route Route) error {
|
func (n *NetLink) RouteAdd(route Route) error {
|
||||||
netlinkRoute := routeToNetlinkRoute(route)
|
conn, err := rtnetlink.Dial(nil)
|
||||||
return netlink.RouteAdd(&netlinkRoute)
|
if err != nil {
|
||||||
|
return fmt.Errorf("dialing netlink: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
return conn.Route.Add(route.message())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) RouteDel(route Route) error {
|
func (n *NetLink) RouteDel(route Route) error {
|
||||||
netlinkRoute := routeToNetlinkRoute(route)
|
conn, err := rtnetlink.Dial(nil)
|
||||||
return netlink.RouteDel(&netlinkRoute)
|
if err != nil {
|
||||||
|
return fmt.Errorf("dialing netlink: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
return conn.Route.Delete(route.message())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) RouteReplace(route Route) error {
|
func (n *NetLink) RouteReplace(route Route) error {
|
||||||
netlinkRoute := routeToNetlinkRoute(route)
|
conn, err := rtnetlink.Dial(nil)
|
||||||
return netlink.RouteReplace(&netlinkRoute)
|
if err != nil {
|
||||||
|
return fmt.Errorf("dialing netlink: %w", err)
|
||||||
}
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
func netlinkRouteToRoute(netlinkRoute netlink.Route) (route Route) {
|
return conn.Route.Replace(route.message())
|
||||||
return Route{
|
|
||||||
LinkIndex: netlinkRoute.LinkIndex,
|
|
||||||
Dst: netIPNetToNetipPrefix(netlinkRoute.Dst),
|
|
||||||
Src: netIPToNetipAddress(netlinkRoute.Src),
|
|
||||||
Gw: netIPToNetipAddress(netlinkRoute.Gw),
|
|
||||||
Priority: netlinkRoute.Priority,
|
|
||||||
Family: netlinkRoute.Family,
|
|
||||||
Table: netlinkRoute.Table,
|
|
||||||
Type: netlinkRoute.Type,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func routeToNetlinkRoute(route Route) (netlinkRoute netlink.Route) {
|
|
||||||
return netlink.Route{
|
|
||||||
LinkIndex: route.LinkIndex,
|
|
||||||
Dst: netipPrefixToIPNet(route.Dst),
|
|
||||||
Src: netipAddrToNetIP(route.Src),
|
|
||||||
Gw: netipAddrToNetIP(route.Gw),
|
|
||||||
Priority: route.Priority,
|
|
||||||
Family: route.Family,
|
|
||||||
Table: route.Table,
|
|
||||||
Type: route.Type,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package netlink
|
||||||
|
|
||||||
|
import "golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
const (
|
||||||
|
RouteTypeUnicast = unix.RTN_UNICAST
|
||||||
|
ScopeUniverse = unix.RT_SCOPE_UNIVERSE
|
||||||
|
ProtoStatic = unix.RTPROT_STATIC
|
||||||
|
|
||||||
|
rtTableCompat = unix.RT_TABLE_COMPAT
|
||||||
|
)
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
//go:build !linux && !darwin
|
|
||||||
|
|
||||||
package netlink
|
|
||||||
|
|
||||||
func (n *NetLink) RouteList(family int) (
|
|
||||||
routes []Route, err error,
|
|
||||||
) {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *NetLink) RouteAdd(route Route) error {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *NetLink) RouteDel(route Route) error {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *NetLink) RouteReplace(route Route) error {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
+72
-67
@@ -1,91 +1,96 @@
|
|||||||
//go:build linux
|
|
||||||
|
|
||||||
package netlink
|
package netlink
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/jsimonetti/rtnetlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewRule() Rule {
|
type Rule struct {
|
||||||
// defaults found from netlink.NewRule() for fields we use,
|
Priority *uint32
|
||||||
// the rest of the defaults is set when converting from a `Rule`
|
Family uint8
|
||||||
// to a `netlink.Rule`
|
Table uint32
|
||||||
return Rule{
|
Mark *uint32
|
||||||
Priority: -1,
|
Src netip.Prefix
|
||||||
Mark: 0,
|
Dst netip.Prefix
|
||||||
}
|
Flags uint32
|
||||||
|
Action uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
|
func (r *Rule) fromMessage(message rtnetlink.RuleMessage) {
|
||||||
switch family {
|
table := uint32(message.Table)
|
||||||
case FamilyAll:
|
if table == 0 || table == rtTableCompat {
|
||||||
n.debugLogger.Debug("ip -4 rule list")
|
table = *message.Attributes.Table
|
||||||
n.debugLogger.Debug("ip -6 rule list")
|
|
||||||
case FamilyV4:
|
|
||||||
n.debugLogger.Debug("ip -4 rule list")
|
|
||||||
case FamilyV6:
|
|
||||||
n.debugLogger.Debug("ip -6 rule list")
|
|
||||||
}
|
}
|
||||||
netlinkRules, err := netlink.RuleList(family)
|
r.Priority = message.Attributes.Priority
|
||||||
if err != nil {
|
r.Family = message.Family
|
||||||
return nil, err
|
r.Table = table
|
||||||
|
r.Mark = message.Attributes.FwMark
|
||||||
|
r.Src = ipAndLengthToPrefix(message.Attributes.Src, message.SrcLength)
|
||||||
|
r.Dst = ipAndLengthToPrefix(message.Attributes.Dst, message.DstLength)
|
||||||
|
r.Flags = message.Flags
|
||||||
|
r.Action = message.Action
|
||||||
}
|
}
|
||||||
|
|
||||||
rules = make([]Rule, len(netlinkRules))
|
func (r Rule) message() *rtnetlink.RuleMessage {
|
||||||
for i := range netlinkRules {
|
src, srcLength := prefixToIPAndLength(r.Src)
|
||||||
rules[i] = netlinkRuleToRule(netlinkRules[i])
|
dst, dstLength := prefixToIPAndLength(r.Dst)
|
||||||
}
|
|
||||||
return rules, nil
|
message := &rtnetlink.RuleMessage{
|
||||||
|
Family: r.Family,
|
||||||
|
SrcLength: srcLength,
|
||||||
|
DstLength: dstLength,
|
||||||
|
Flags: r.Flags,
|
||||||
|
Action: r.Action,
|
||||||
|
Attributes: &rtnetlink.RuleAttributes{
|
||||||
|
Priority: r.Priority,
|
||||||
|
FwMark: r.Mark,
|
||||||
|
Src: src,
|
||||||
|
Dst: dst,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) RuleAdd(rule Rule) error {
|
if r.Table <= uint32(^uint8(0)) {
|
||||||
n.debugLogger.Debug(ruleDbgMsg(true, rule))
|
message.Table = uint8(r.Table)
|
||||||
netlinkRule := ruleToNetlinkRule(rule)
|
} else {
|
||||||
return netlink.RuleAdd(&netlinkRule)
|
message.Table = rtTableCompat
|
||||||
|
message.Attributes.Table = &r.Table
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) RuleDel(rule Rule) error {
|
return message
|
||||||
n.debugLogger.Debug(ruleDbgMsg(false, rule))
|
|
||||||
netlinkRule := ruleToNetlinkRule(rule)
|
|
||||||
return netlink.RuleDel(&netlinkRule)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ruleToNetlinkRule(rule Rule) (netlinkRule netlink.Rule) {
|
func (r Rule) String() string {
|
||||||
netlinkRule = *netlink.NewRule()
|
from := "all"
|
||||||
netlinkRule.Priority = rule.Priority
|
if r.Src.IsValid() && !r.Src.Addr().IsUnspecified() {
|
||||||
netlinkRule.Family = rule.Family
|
from = r.Src.String()
|
||||||
netlinkRule.Table = rule.Table
|
|
||||||
netlinkRule.Mark = rule.Mark
|
|
||||||
netlinkRule.Src = netipPrefixToIPNet(rule.Src)
|
|
||||||
netlinkRule.Dst = netipPrefixToIPNet(rule.Dst)
|
|
||||||
netlinkRule.Invert = rule.Invert
|
|
||||||
return netlinkRule
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func netlinkRuleToRule(netlinkRule netlink.Rule) (rule Rule) {
|
to := "all"
|
||||||
return Rule{
|
if r.Dst.IsValid() && !r.Dst.Addr().IsUnspecified() {
|
||||||
Priority: netlinkRule.Priority,
|
to = r.Dst.String()
|
||||||
Family: netlinkRule.Family,
|
|
||||||
Table: netlinkRule.Table,
|
|
||||||
Mark: netlinkRule.Mark,
|
|
||||||
Src: netIPNetToNetipPrefix(netlinkRule.Src),
|
|
||||||
Dst: netIPNetToNetipPrefix(netlinkRule.Dst),
|
|
||||||
Invert: netlinkRule.Invert,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ruleDbgMsg(add bool, rule Rule) (debugMessage string) {
|
priority := ""
|
||||||
|
if r.Priority != nil {
|
||||||
|
priority = fmt.Sprintf(" %d", *r.Priority)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("ip rule%s: from %s to %s table %d",
|
||||||
|
priority, from, to, r.Table)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Rule) debugMessage(add bool) (debugMessage string) {
|
||||||
debugMessage = "ip"
|
debugMessage = "ip"
|
||||||
|
|
||||||
switch rule.Family {
|
switch r.Family {
|
||||||
case FamilyV4:
|
case FamilyV4:
|
||||||
debugMessage += " -f inet"
|
debugMessage += " -f inet"
|
||||||
case FamilyV6:
|
case FamilyV6:
|
||||||
debugMessage += " -f inet6"
|
debugMessage += " -f inet6"
|
||||||
default:
|
default:
|
||||||
debugMessage += " -f " + fmt.Sprint(rule.Family)
|
debugMessage += " -f " + fmt.Sprint(r.Family)
|
||||||
}
|
}
|
||||||
|
|
||||||
debugMessage += " rule"
|
debugMessage += " rule"
|
||||||
@@ -96,20 +101,20 @@ func ruleDbgMsg(add bool, rule Rule) (debugMessage string) {
|
|||||||
debugMessage += " del"
|
debugMessage += " del"
|
||||||
}
|
}
|
||||||
|
|
||||||
if rule.Src.IsValid() {
|
if r.Src.IsValid() {
|
||||||
debugMessage += " from " + rule.Src.String()
|
debugMessage += " from " + r.Src.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
if rule.Dst.IsValid() {
|
if r.Dst.IsValid() {
|
||||||
debugMessage += " to " + rule.Dst.String()
|
debugMessage += " to " + r.Dst.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
if rule.Table != 0 {
|
if r.Table != 0 {
|
||||||
debugMessage += " lookup " + fmt.Sprint(rule.Table)
|
debugMessage += " lookup " + fmt.Sprint(r.Table)
|
||||||
}
|
}
|
||||||
|
|
||||||
if rule.Priority != -1 {
|
if r.Priority != nil {
|
||||||
debugMessage += " pref " + fmt.Sprint(rule.Priority)
|
debugMessage += " pref " + fmt.Sprint(*r.Priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
return debugMessage
|
return debugMessage
|
||||||
|
|||||||
@@ -0,0 +1,69 @@
|
|||||||
|
package netlink
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/jsimonetti/rtnetlink"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
FlagInvert = unix.FIB_RULE_INVERT
|
||||||
|
ActionToTable = unix.FR_ACT_TO_TBL
|
||||||
|
)
|
||||||
|
|
||||||
|
func (n *NetLink) RuleList(family uint8) (rules []Rule, err error) {
|
||||||
|
switch family {
|
||||||
|
case FamilyAll:
|
||||||
|
n.debugLogger.Debug("ip -4 rule list")
|
||||||
|
n.debugLogger.Debug("ip -6 rule list")
|
||||||
|
case FamilyV4:
|
||||||
|
n.debugLogger.Debug("ip -4 rule list")
|
||||||
|
case FamilyV6:
|
||||||
|
n.debugLogger.Debug("ip -6 rule list")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := rtnetlink.Dial(nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dialing netlink: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
ruleMessages, err := conn.Rule.List()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rules = make([]Rule, 0, len(ruleMessages))
|
||||||
|
for _, message := range ruleMessages {
|
||||||
|
if family != FamilyAll && family != message.Family {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var rule Rule
|
||||||
|
rule.fromMessage(message)
|
||||||
|
rules = append(rules, rule)
|
||||||
|
}
|
||||||
|
return rules, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetLink) RuleAdd(rule Rule) error {
|
||||||
|
n.debugLogger.Debug(rule.debugMessage(true))
|
||||||
|
|
||||||
|
conn, err := rtnetlink.Dial(nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("dialing netlink: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
return conn.Rule.Add(rule.message())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetLink) RuleDel(rule Rule) error {
|
||||||
|
n.debugLogger.Debug(rule.debugMessage(false))
|
||||||
|
|
||||||
|
conn, err := rtnetlink.Dial(nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("dialing netlink: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
return conn.Rule.Delete(rule.message())
|
||||||
|
}
|
||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_ruleDbgMsg(t *testing.T) {
|
func Test_Rule_debugMessage(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
@@ -15,7 +15,7 @@ func Test_ruleDbgMsg(t *testing.T) {
|
|||||||
dbgMsg string
|
dbgMsg string
|
||||||
}{
|
}{
|
||||||
"default values": {
|
"default values": {
|
||||||
dbgMsg: "ip -f 0 rule del pref 0",
|
dbgMsg: "ip -f 0 rule del",
|
||||||
},
|
},
|
||||||
"add rule": {
|
"add rule": {
|
||||||
add: true,
|
add: true,
|
||||||
@@ -24,7 +24,7 @@ func Test_ruleDbgMsg(t *testing.T) {
|
|||||||
Src: makeNetipPrefix(1),
|
Src: makeNetipPrefix(1),
|
||||||
Dst: makeNetipPrefix(2),
|
Dst: makeNetipPrefix(2),
|
||||||
Table: 100,
|
Table: 100,
|
||||||
Priority: 101,
|
Priority: ptrTo(uint32(101)),
|
||||||
},
|
},
|
||||||
dbgMsg: "ip -f inet rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
|
dbgMsg: "ip -f inet rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
|
||||||
},
|
},
|
||||||
@@ -34,7 +34,7 @@ func Test_ruleDbgMsg(t *testing.T) {
|
|||||||
Src: makeNetipPrefix(1),
|
Src: makeNetipPrefix(1),
|
||||||
Dst: makeNetipPrefix(2),
|
Dst: makeNetipPrefix(2),
|
||||||
Table: 100,
|
Table: 100,
|
||||||
Priority: 101,
|
Priority: ptrTo(uint32(101)),
|
||||||
},
|
},
|
||||||
dbgMsg: "ip -f inet rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
|
dbgMsg: "ip -f inet rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
|
||||||
},
|
},
|
||||||
@@ -44,7 +44,7 @@ func Test_ruleDbgMsg(t *testing.T) {
|
|||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
dbgMsg := ruleDbgMsg(testCase.add, testCase.rule)
|
dbgMsg := testCase.rule.debugMessage(testCase.add)
|
||||||
|
|
||||||
assert.Equal(t, testCase.dbgMsg, dbgMsg)
|
assert.Equal(t, testCase.dbgMsg, dbgMsg)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,19 +0,0 @@
|
|||||||
//go:build !linux
|
|
||||||
|
|
||||||
package netlink
|
|
||||||
|
|
||||||
func NewRule() Rule {
|
|
||||||
return Rule{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *NetLink) RuleAdd(rule Rule) error {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *NetLink) RuleDel(rule Rule) error {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
package netlink
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Addr struct {
|
|
||||||
Network netip.Prefix
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a Addr) String() string {
|
|
||||||
return a.Network.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
type Link struct {
|
|
||||||
Type string
|
|
||||||
Name string
|
|
||||||
Index int
|
|
||||||
EncapType string
|
|
||||||
MTU uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
type Route struct {
|
|
||||||
LinkIndex int
|
|
||||||
Dst netip.Prefix
|
|
||||||
Src netip.Addr
|
|
||||||
Gw netip.Addr
|
|
||||||
Priority int
|
|
||||||
Family int
|
|
||||||
Table int
|
|
||||||
Type int
|
|
||||||
}
|
|
||||||
|
|
||||||
type Rule struct {
|
|
||||||
Priority int
|
|
||||||
Family int
|
|
||||||
Table int
|
|
||||||
Mark uint32
|
|
||||||
Src netip.Prefix
|
|
||||||
Dst netip.Prefix
|
|
||||||
Invert bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r Rule) String() string {
|
|
||||||
from := "all"
|
|
||||||
if r.Src.IsValid() && !r.Src.Addr().IsUnspecified() {
|
|
||||||
from = r.Src.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
to := "all"
|
|
||||||
if r.Dst.IsValid() && !r.Dst.Addr().IsUnspecified() {
|
|
||||||
to = r.Dst.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("ip rule %d: from %s to %s table %d",
|
|
||||||
r.Priority, from, to, r.Table)
|
|
||||||
}
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
//go:build linux
|
|
||||||
|
|
||||||
package netlink
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/qdm12/gluetun/internal/mod"
|
|
||||||
"github.com/vishvananda/netlink"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (n *NetLink) IsWireguardSupported() bool {
|
|
||||||
// Check for Wireguard family without loading the wireguard module.
|
|
||||||
// Some kernels have the wireguard module built-in, and don't have a
|
|
||||||
// modules directory, such as WSL2 kernels.
|
|
||||||
ok := hasWireguardFamily()
|
|
||||||
if ok {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try loading the wireguard module, since some systems do not load
|
|
||||||
// it after a boot. If this fails, wireguard is assumed to not be supported.
|
|
||||||
n.debugLogger.Debugf("wireguard family not found, trying to load wireguard kernel module")
|
|
||||||
err := mod.Probe("wireguard")
|
|
||||||
if err != nil {
|
|
||||||
n.debugLogger.Debugf("failed loading wireguard kernel module: %s", err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
n.debugLogger.Debugf("wireguard kernel module loaded successfully")
|
|
||||||
|
|
||||||
// Re-check if the Wireguard family is now available, after loading
|
|
||||||
// the wireguard kernel module.
|
|
||||||
return hasWireguardFamily()
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasWireguardFamily() bool {
|
|
||||||
_, err := netlink.GenlFamilyGet("wireguard")
|
|
||||||
return err == nil
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
package netlink
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/mdlayher/genetlink"
|
||||||
|
"github.com/qdm12/gluetun/internal/mod"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (n *NetLink) IsWireguardSupported() (ok bool, err error) {
|
||||||
|
// Check for Wireguard family without loading the wireguard module.
|
||||||
|
// Some kernels have the wireguard module built-in, and don't have a
|
||||||
|
// modules directory, such as WSL2 kernels.
|
||||||
|
ok, err = hasWireguardFamily()
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("checking wireguard family: %w", err)
|
||||||
|
} else if ok {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try loading the wireguard module, since some systems do not load
|
||||||
|
// it after a boot. If this fails, wireguard is assumed to not be supported.
|
||||||
|
n.debugLogger.Debugf("wireguard family not found, trying to load wireguard kernel module")
|
||||||
|
err = mod.Probe("wireguard")
|
||||||
|
if err != nil {
|
||||||
|
n.debugLogger.Debugf("failed loading wireguard kernel module: %s", err)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
n.debugLogger.Debugf("wireguard kernel module loaded successfully")
|
||||||
|
|
||||||
|
// Re-check if the Wireguard family is now available, after loading
|
||||||
|
// the wireguard kernel module.
|
||||||
|
ok, err = hasWireguardFamily()
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("checking wireguard family: %w", err)
|
||||||
|
}
|
||||||
|
return ok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasWireguardFamily() (ok bool, err error) {
|
||||||
|
conn, err := genetlink.Dial(nil)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("dialing netlink: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
_, err = conn.GetFamily("wireguard")
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, fmt.Errorf("getting wireguard family: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
@@ -4,6 +4,8 @@ package netlink
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_NetLink_IsWireguardSupported(t *testing.T) {
|
func Test_NetLink_IsWireguardSupported(t *testing.T) {
|
||||||
@@ -12,7 +14,8 @@ func Test_NetLink_IsWireguardSupported(t *testing.T) {
|
|||||||
netLink := &NetLink{
|
netLink := &NetLink{
|
||||||
debugLogger: &noopLogger{},
|
debugLogger: &noopLogger{},
|
||||||
}
|
}
|
||||||
ok := netLink.IsWireguardSupported()
|
ok, err := netLink.IsWireguardSupported()
|
||||||
|
require.NoError(t, err)
|
||||||
if ok { // cannot assert since this depends on kernel
|
if ok { // cannot assert since this depends on kernel
|
||||||
t.Log("wireguard is supported")
|
t.Log("wireguard is supported")
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
//go:build !linux
|
|
||||||
|
|
||||||
package netlink
|
|
||||||
|
|
||||||
func (n *NetLink) IsWireguardSupported() (ok bool, err error) {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/constants/openvpn"
|
"github.com/qdm12/gluetun/internal/constants/openvpn"
|
||||||
)
|
)
|
||||||
@@ -33,7 +32,7 @@ func start(ctx context.Context, starter CmdStarter, version string, flags []stri
|
|||||||
args := []string{"--config", configPath}
|
args := []string{"--config", configPath}
|
||||||
args = append(args, flags...)
|
args = append(args, flags...)
|
||||||
cmd := exec.CommandContext(ctx, bin, args...)
|
cmd := exec.CommandContext(ctx, bin, args...)
|
||||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
setCmdSysProcAttr(cmd)
|
||||||
|
|
||||||
return starter.Start(cmd)
|
return starter.Start(cmd)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,10 @@
|
|||||||
|
package openvpn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setCmdSysProcAttr(cmd *exec.Cmd) {
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||||
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package openvpn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setCmdSysProcAttr(cmd *exec.Cmd) {
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{}
|
||||||
|
}
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package pmtud
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ net.PacketConn = &ipv4Wrapper{}
|
||||||
|
|
||||||
|
// ipv4Wrapper is a wrapper around ipv4.PacketConn to implement
|
||||||
|
// the net.PacketConn interface. It's only used for Darwin or iOS.
|
||||||
|
type ipv4Wrapper struct {
|
||||||
|
ipv4Conn *ipv4.PacketConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func ipv4ToNetPacketConn(ipv4 *ipv4.PacketConn) *ipv4Wrapper {
|
||||||
|
return &ipv4Wrapper{ipv4Conn: ipv4}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ipv4Wrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||||
|
n, _, addr, err = i.ipv4Conn.ReadFrom(p)
|
||||||
|
return n, addr, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ipv4Wrapper) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||||
|
return i.ipv4Conn.WriteTo(p, nil, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ipv4Wrapper) Close() error {
|
||||||
|
return i.ipv4Conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ipv4Wrapper) LocalAddr() net.Addr {
|
||||||
|
return i.ipv4Conn.LocalAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ipv4Wrapper) SetDeadline(t time.Time) error {
|
||||||
|
return i.ipv4Conn.SetDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ipv4Wrapper) SetReadDeadline(t time.Time) error {
|
||||||
|
return i.ipv4Conn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ipv4Wrapper) SetWriteDeadline(t time.Time) error {
|
||||||
|
return i.ipv4Conn.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
@@ -0,0 +1,83 @@
|
|||||||
|
package pmtud
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"golang.org/x/net/icmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrICMPNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
|
||||||
|
ErrICMPNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
|
||||||
|
)
|
||||||
|
|
||||||
|
func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) {
|
||||||
|
switch {
|
||||||
|
case mtu < minMTU:
|
||||||
|
return fmt.Errorf("%w: %d", ErrICMPNextHopMTUTooLow, mtu)
|
||||||
|
case mtu > physicalLinkMTU:
|
||||||
|
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
|
||||||
|
ErrICMPNextHopMTUTooHigh, mtu, physicalLinkMTU)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
|
||||||
|
outboundMessage *icmp.Message,
|
||||||
|
) (match bool, err error) {
|
||||||
|
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("parsing invoking packet: %w", err)
|
||||||
|
}
|
||||||
|
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
||||||
|
if !ok {
|
||||||
|
return false, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
|
||||||
|
}
|
||||||
|
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
||||||
|
return inboundBody.ID == outboundBody.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var ErrICMPIDMismatch = errors.New("ICMP id mismatch")
|
||||||
|
|
||||||
|
func checkEchoReply(icmpProtocol int, received []byte,
|
||||||
|
outboundMessage *icmp.Message, truncatedBody bool,
|
||||||
|
) (err error) {
|
||||||
|
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing invoking packet: %w", err)
|
||||||
|
}
|
||||||
|
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
|
||||||
|
}
|
||||||
|
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
||||||
|
if inboundBody.ID != outboundBody.ID {
|
||||||
|
return fmt.Errorf("%w: sent id %d and received id %d",
|
||||||
|
ErrICMPIDMismatch, outboundBody.ID, inboundBody.ID)
|
||||||
|
}
|
||||||
|
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("checking sent and received bodies: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")
|
||||||
|
|
||||||
|
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
|
||||||
|
if len(received) > len(sent) {
|
||||||
|
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
|
||||||
|
ErrICMPEchoDataMismatch, len(sent), len(received))
|
||||||
|
}
|
||||||
|
if receivedTruncated {
|
||||||
|
sent = sent[:len(received)]
|
||||||
|
}
|
||||||
|
if !bytes.Equal(received, sent) {
|
||||||
|
return fmt.Errorf("%w: sent %x and received %x",
|
||||||
|
ErrICMPEchoDataMismatch, sent, received)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
//go:build !linux && !windows
|
||||||
|
|
||||||
|
package pmtud
|
||||||
|
|
||||||
|
// setDontFragment for platforms other than Linux and Windows
|
||||||
|
// is not implemented, so we just return assuming the don't
|
||||||
|
// fragment flag is set on IP packets.
|
||||||
|
func setDontFragment(fd uintptr) (err error) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
package pmtud
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setDontFragment(fd uintptr) (err error) {
|
||||||
|
return syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP,
|
||||||
|
syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_PROBE)
|
||||||
|
}
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package pmtud
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setDontFragment(fd uintptr) (err error) {
|
||||||
|
// https://docs.microsoft.com/en-us/troubleshoot/windows/win32/header-library-requirement-socket-ipproto-ip
|
||||||
|
// #define IP_DONTFRAGMENT 14 /* don't fragment IP datagrams */
|
||||||
|
return syscall.SetsockoptInt(syscall.Handle(fd), syscall.IPPROTO_IP, 14, 1)
|
||||||
|
}
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
package pmtud
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrICMPNotPermitted = errors.New("ICMP not permitted")
|
||||||
|
ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable")
|
||||||
|
ErrICMPCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
|
||||||
|
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
|
||||||
|
)
|
||||||
|
|
||||||
|
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
|
||||||
|
switch {
|
||||||
|
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
|
||||||
|
err = fmt.Errorf("%w", ErrICMPNotPermitted)
|
||||||
|
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
|
||||||
|
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
|
||||||
|
case timedCtx.Err() != nil:
|
||||||
|
err = timedCtx.Err()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
package pmtud
|
||||||
|
|
||||||
|
type Logger interface {
|
||||||
|
Debug(msg string)
|
||||||
|
Debugf(msg string, args ...any)
|
||||||
|
Warnf(msg string, args ...any)
|
||||||
|
}
|
||||||
@@ -0,0 +1,159 @@
|
|||||||
|
package pmtud
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/icmp"
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// see https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
|
||||||
|
minIPv4MTU uint32 = 68
|
||||||
|
icmpv4Protocol int = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
|
||||||
|
var listenConfig net.ListenConfig
|
||||||
|
listenConfig.Control = func(_, _ string, rawConn syscall.RawConn) error {
|
||||||
|
var setDFErr error
|
||||||
|
err := rawConn.Control(func(fd uintptr) {
|
||||||
|
setDFErr = setDontFragment(fd) // runs when calling ListenPacket
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
err = setDFErr
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const listenAddress = ""
|
||||||
|
packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress)
|
||||||
|
if err != nil {
|
||||||
|
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||||
|
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
|
||||||
|
packetConn = ipv4ToNetPacketConn(ipv4.NewPacketConn(packetConn))
|
||||||
|
}
|
||||||
|
|
||||||
|
return packetConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
||||||
|
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger,
|
||||||
|
) (mtu uint32, err error) {
|
||||||
|
if ip.Is6() {
|
||||||
|
panic("IP address is not v4")
|
||||||
|
}
|
||||||
|
conn, err := listenICMPv4(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, pingTimeout)
|
||||||
|
defer cancel()
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
conn.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// First try to send a packet which is too big to get the maximum MTU
|
||||||
|
// directly.
|
||||||
|
outboundID, outboundMessage := buildMessageToSend("v4", physicalLinkMTU)
|
||||||
|
encodedMessage, err := outboundMessage.Marshal(nil)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("encoding ICMP message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
||||||
|
if err != nil {
|
||||||
|
err = wrapConnErr(err, ctx, pingTimeout)
|
||||||
|
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer := make([]byte, physicalLinkMTU)
|
||||||
|
|
||||||
|
for { // for loop in case we read an echo reply for another ICMP request
|
||||||
|
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||||
|
// must be large enough to read the entire reply packet. See:
|
||||||
|
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||||
|
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||||
|
if err != nil {
|
||||||
|
err = wrapConnErr(err, ctx, pingTimeout)
|
||||||
|
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
|
||||||
|
}
|
||||||
|
packetBytes := buffer[:bytesRead]
|
||||||
|
// Side note: echo reply should be at most the number of bytes
|
||||||
|
// sent, and can be lower, more precisely 576-ipHeader bytes,
|
||||||
|
// in case the next hop we are reaching replies with a destination
|
||||||
|
// unreachable and wants to ensure the response makes it way back
|
||||||
|
// by keeping a low packet size, see:
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
|
||||||
|
|
||||||
|
inboundMessage, err := icmp.ParseMessage(icmpv4Protocol, packetBytes)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("parsing message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch typedBody := inboundMessage.Body.(type) {
|
||||||
|
case *icmp.DstUnreach:
|
||||||
|
const fragmentationRequiredAndDFFlagSetCode = 4
|
||||||
|
const communicationAdministrativelyProhibitedCode = 13
|
||||||
|
switch inboundMessage.Code {
|
||||||
|
case fragmentationRequiredAndDFFlagSetCode:
|
||||||
|
case communicationAdministrativelyProhibitedCode:
|
||||||
|
return 0, fmt.Errorf("%w: %w (code %d)",
|
||||||
|
ErrICMPDestinationUnreachable,
|
||||||
|
ErrICMPCommunicationAdministrativelyProhibited,
|
||||||
|
inboundMessage.Code)
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("%w: code %d",
|
||||||
|
ErrICMPDestinationUnreachable, inboundMessage.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
|
||||||
|
// Note: the go library does not handle this NextHopMTU section.
|
||||||
|
nextHopMTU := packetBytes[6:8]
|
||||||
|
mtu = uint32(binary.BigEndian.Uint16(nextHopMTU))
|
||||||
|
err = checkMTU(mtu, minIPv4MTU, physicalLinkMTU)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("checking next-hop-mtu found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The code below is really for sanity checks
|
||||||
|
packetBytes = packetBytes[8:]
|
||||||
|
header, err := ipv4.ParseHeader(packetBytes)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("parsing IPv4 header: %w", err)
|
||||||
|
}
|
||||||
|
packetBytes = packetBytes[header.Len:] // truncated original datagram
|
||||||
|
|
||||||
|
const truncated = true
|
||||||
|
err = checkEchoReply(icmpv4Protocol, packetBytes, outboundMessage, truncated)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("checking echo reply: %w", err)
|
||||||
|
}
|
||||||
|
return mtu, nil
|
||||||
|
case *icmp.Echo:
|
||||||
|
inboundID := uint16(typedBody.ID) //nolint:gosec
|
||||||
|
if inboundID == outboundID {
|
||||||
|
return physicalLinkMTU, nil
|
||||||
|
}
|
||||||
|
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
|
||||||
|
inboundID, outboundID)
|
||||||
|
continue
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,122 @@
|
|||||||
|
package pmtud
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/icmp"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
minIPv6MTU = 1280
|
||||||
|
icmpv6Protocol = 58
|
||||||
|
)
|
||||||
|
|
||||||
|
func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) {
|
||||||
|
var listenConfig net.ListenConfig
|
||||||
|
const listenAddress = ""
|
||||||
|
packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress)
|
||||||
|
if err != nil {
|
||||||
|
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||||
|
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return packetConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
||||||
|
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger,
|
||||||
|
) (mtu uint32, err error) {
|
||||||
|
if ip.Is4() {
|
||||||
|
panic("IP address is not v6")
|
||||||
|
}
|
||||||
|
conn, err := listenICMPv6(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, pingTimeout)
|
||||||
|
defer cancel()
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
conn.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// First try to send a packet which is too big to get the maximum MTU
|
||||||
|
// directly.
|
||||||
|
outboundID, outboundMessage := buildMessageToSend("v6", physicalLinkMTU)
|
||||||
|
encodedMessage, err := outboundMessage.Marshal(nil)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("encoding ICMP message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice(), Zone: ip.Zone()})
|
||||||
|
if err != nil {
|
||||||
|
err = wrapConnErr(err, ctx, pingTimeout)
|
||||||
|
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer := make([]byte, physicalLinkMTU)
|
||||||
|
|
||||||
|
for { // for loop if we encounter another ICMP packet with an unknown id.
|
||||||
|
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||||
|
// must be large enough to read the entire reply packet. See:
|
||||||
|
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||||
|
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||||
|
if err != nil {
|
||||||
|
err = wrapConnErr(err, ctx, pingTimeout)
|
||||||
|
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
|
||||||
|
}
|
||||||
|
packetBytes := buffer[:bytesRead]
|
||||||
|
|
||||||
|
packetBytes = packetBytes[ipv6.HeaderLen:]
|
||||||
|
|
||||||
|
inboundMessage, err := icmp.ParseMessage(icmpv6Protocol, packetBytes)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("parsing message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch typedBody := inboundMessage.Body.(type) {
|
||||||
|
case *icmp.PacketTooBig:
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.2
|
||||||
|
mtu = uint32(typedBody.MTU) //nolint:gosec
|
||||||
|
err = checkMTU(mtu, minIPv6MTU, physicalLinkMTU)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("checking MTU: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanity checks
|
||||||
|
const truncatedBody = true
|
||||||
|
err = checkEchoReply(icmpv6Protocol, typedBody.Data, outboundMessage, truncatedBody)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("checking invoking message: %w", err)
|
||||||
|
}
|
||||||
|
return uint32(typedBody.MTU), nil //nolint:gosec
|
||||||
|
case *icmp.DstUnreach:
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.1
|
||||||
|
idMatch, err := checkInvokingReplyIDMatch(icmpv6Protocol, packetBytes, outboundMessage)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("checking invoking message id: %w", err)
|
||||||
|
} else if idMatch {
|
||||||
|
return 0, fmt.Errorf("%w", ErrICMPDestinationUnreachable)
|
||||||
|
}
|
||||||
|
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
|
||||||
|
continue
|
||||||
|
case *icmp.Echo:
|
||||||
|
inboundID := uint16(typedBody.ID) //nolint:gosec
|
||||||
|
if inboundID == outboundID {
|
||||||
|
return physicalLinkMTU, nil
|
||||||
|
}
|
||||||
|
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
|
||||||
|
inboundID, outboundID)
|
||||||
|
continue
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
package pmtud
|
||||||
|
|
||||||
|
import (
|
||||||
|
cryptorand "crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"math/rand/v2"
|
||||||
|
|
||||||
|
"golang.org/x/net/icmp"
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
)
|
||||||
|
|
||||||
|
func buildMessageToSend(ipVersion string, mtu uint32) (id uint16, message *icmp.Message) {
|
||||||
|
var seed [32]byte
|
||||||
|
_, _ = cryptorand.Read(seed[:])
|
||||||
|
randomSource := rand.NewChaCha8(seed)
|
||||||
|
|
||||||
|
const uint16Bytes = 2
|
||||||
|
idBytes := make([]byte, uint16Bytes)
|
||||||
|
_, _ = randomSource.Read(idBytes)
|
||||||
|
id = binary.BigEndian.Uint16(idBytes)
|
||||||
|
|
||||||
|
var ipHeaderLength uint32
|
||||||
|
var icmpType icmp.Type
|
||||||
|
switch ipVersion {
|
||||||
|
case "v4":
|
||||||
|
ipHeaderLength = ipv4.HeaderLen
|
||||||
|
icmpType = ipv4.ICMPTypeEcho
|
||||||
|
case "v6":
|
||||||
|
ipHeaderLength = ipv6.HeaderLen
|
||||||
|
icmpType = ipv6.ICMPTypeEchoRequest
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("IP version %q not supported", ipVersion))
|
||||||
|
}
|
||||||
|
const pingHeaderLength = 0 +
|
||||||
|
1 + // type
|
||||||
|
1 + // code
|
||||||
|
2 + // checksum
|
||||||
|
2 + // identifier
|
||||||
|
2 // sequence number
|
||||||
|
pingBodyDataSize := mtu - ipHeaderLength - pingHeaderLength
|
||||||
|
messageBodyData := make([]byte, pingBodyDataSize)
|
||||||
|
_, _ = randomSource.Read(messageBodyData)
|
||||||
|
|
||||||
|
// See https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml#icmp-parameters-types
|
||||||
|
message = &icmp.Message{
|
||||||
|
Type: icmpType, // echo request
|
||||||
|
Code: 0, // no code
|
||||||
|
Checksum: 0, // calculated at encoding (ipv4) or sending (ipv6)
|
||||||
|
Body: &icmp.Echo{
|
||||||
|
ID: int(id),
|
||||||
|
Seq: 0, // only one packet
|
||||||
|
Data: messageBodyData,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return id, message
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
package pmtud
|
||||||
|
|
||||||
|
type noopLogger struct{}
|
||||||
|
|
||||||
|
func (noopLogger) Debug(_ string) {}
|
||||||
|
func (noopLogger) Debugf(_ string, _ ...any) {}
|
||||||
|
func (noopLogger) Warnf(_ string, _ ...any) {}
|
||||||
@@ -0,0 +1,271 @@
|
|||||||
|
package pmtud
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/icmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrMTUNotFound = errors.New("path MTU discovery failed to find MTU")
|
||||||
|
|
||||||
|
// PathMTUDiscover discovers the maximum MTU for the path to the given ip address.
|
||||||
|
// If the physicalLinkMTU is zero, it defaults to 1500 which is the ethernet standard MTU.
|
||||||
|
// If the pingTimeout is zero, it defaults to 1 second.
|
||||||
|
// If the logger is nil, a no-op logger is used.
|
||||||
|
// It returns [ErrMTUNotFound] if the MTU could not be determined.
|
||||||
|
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
|
||||||
|
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger) (
|
||||||
|
mtu uint32, err error,
|
||||||
|
) {
|
||||||
|
if physicalLinkMTU == 0 {
|
||||||
|
const ethernetStandardMTU = 1500
|
||||||
|
physicalLinkMTU = ethernetStandardMTU
|
||||||
|
}
|
||||||
|
if pingTimeout == 0 {
|
||||||
|
pingTimeout = time.Second
|
||||||
|
}
|
||||||
|
if logger == nil {
|
||||||
|
logger = &noopLogger{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip.Is4() {
|
||||||
|
logger.Debug("finding IPv4 next hop MTU")
|
||||||
|
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, pingTimeout, logger)
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
return mtu, nil
|
||||||
|
case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrICMPCommunicationAdministrativelyProhibited): // blackhole
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Debug("requesting IPv6 ICMP packet-too-big reply")
|
||||||
|
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, pingTimeout, logger)
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
return mtu, nil
|
||||||
|
case errors.Is(err, net.ErrClosed): // blackhole
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back method: send echo requests with different packet
|
||||||
|
// sizes and check which ones succeed to find the maximum MTU.
|
||||||
|
logger.Debug("falling back to sending different sized echo packets")
|
||||||
|
minMTU := minIPv4MTU
|
||||||
|
if ip.Is6() {
|
||||||
|
minMTU = minIPv6MTU
|
||||||
|
}
|
||||||
|
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, pingTimeout, logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
type pmtudTestUnit struct {
|
||||||
|
mtu uint32
|
||||||
|
echoID uint16
|
||||||
|
sentBytes int
|
||||||
|
ok bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
|
||||||
|
minMTU, maxPossibleMTU uint32, pingTimeout time.Duration,
|
||||||
|
logger Logger,
|
||||||
|
) (maxMTU uint32, err error) {
|
||||||
|
var ipVersion string
|
||||||
|
var conn net.PacketConn
|
||||||
|
if ip.Is4() {
|
||||||
|
ipVersion = "v4"
|
||||||
|
conn, err = listenICMPv4(ctx)
|
||||||
|
} else {
|
||||||
|
ipVersion = "v6"
|
||||||
|
conn, err = listenICMPv6(ctx)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||||
|
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mtusToTest := makeMTUsToTest(minMTU, maxPossibleMTU)
|
||||||
|
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
|
||||||
|
return minMTU, nil
|
||||||
|
}
|
||||||
|
logger.Debugf("testing the following MTUs: %v", mtusToTest)
|
||||||
|
|
||||||
|
tests := make([]pmtudTestUnit, len(mtusToTest))
|
||||||
|
for i := range mtusToTest {
|
||||||
|
tests[i] = pmtudTestUnit{mtu: mtusToTest[i]}
|
||||||
|
}
|
||||||
|
|
||||||
|
timedCtx, cancel := context.WithTimeout(ctx, pingTimeout)
|
||||||
|
defer cancel()
|
||||||
|
go func() {
|
||||||
|
<-timedCtx.Done()
|
||||||
|
conn.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
for i := range tests {
|
||||||
|
id, message := buildMessageToSend(ipVersion, tests[i].mtu)
|
||||||
|
tests[i].echoID = id
|
||||||
|
|
||||||
|
encodedMessage, err := message.Marshal(nil)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("encoding ICMP message: %w", err)
|
||||||
|
}
|
||||||
|
tests[i].sentBytes = len(encodedMessage)
|
||||||
|
|
||||||
|
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
||||||
|
if err != nil {
|
||||||
|
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
|
||||||
|
err = fmt.Errorf("%w", ErrICMPNotPermitted)
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = collectReplies(conn, ipVersion, tests, logger)
|
||||||
|
switch {
|
||||||
|
case err == nil: // max possible MTU is working
|
||||||
|
return tests[len(tests)-1].mtu, nil
|
||||||
|
case err != nil && errors.Is(err, net.ErrClosed):
|
||||||
|
// we have timeouts (IPv4 testing or IPv6 PMTUD blackholes)
|
||||||
|
// so find the highest MTU which worked.
|
||||||
|
// Note we start from index len(tests) - 2 since the max MTU
|
||||||
|
// cannot be working if we had a timeout.
|
||||||
|
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
|
||||||
|
if tests[i].ok {
|
||||||
|
return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1,
|
||||||
|
pingTimeout, logger)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// All MTUs failed.
|
||||||
|
return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound)
|
||||||
|
case err != nil:
|
||||||
|
return 0, fmt.Errorf("collecting ICMP echo replies: %w", err)
|
||||||
|
default:
|
||||||
|
panic("unreachable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the MTU slice of length 11 such that:
|
||||||
|
// - the first element is the minMTU
|
||||||
|
// - the last element is the maxMTU
|
||||||
|
// - elements in-between are separated as close to each other
|
||||||
|
// The number 11 is chosen to find the final MTU in 3 searches,
|
||||||
|
// with a total search space of 1728 MTUs which is enough;
|
||||||
|
// to find it in 2 searches requires 37 parallel queries which
|
||||||
|
// could be blocked by firewalls.
|
||||||
|
func makeMTUsToTest(minMTU, maxMTU uint32) (mtus []uint32) {
|
||||||
|
const mtusLength = 11 // find the final MTU in 3 searches
|
||||||
|
diff := maxMTU - minMTU
|
||||||
|
switch {
|
||||||
|
case minMTU > maxMTU:
|
||||||
|
panic("minMTU > maxMTU")
|
||||||
|
case diff <= mtusLength:
|
||||||
|
mtus = make([]uint32, 0, diff)
|
||||||
|
for mtu := minMTU; mtu <= maxMTU; mtu++ {
|
||||||
|
mtus = append(mtus, mtu)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
step := float64(diff) / float64(mtusLength-1)
|
||||||
|
mtus = make([]uint32, 0, mtusLength)
|
||||||
|
for mtu := float64(minMTU); len(mtus) < mtusLength-1; mtu += step {
|
||||||
|
mtus = append(mtus, uint32(math.Round(mtu)))
|
||||||
|
}
|
||||||
|
mtus = append(mtus, maxMTU) // last element is the maxMTU
|
||||||
|
}
|
||||||
|
|
||||||
|
return mtus
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectReplies(conn net.PacketConn, ipVersion string,
|
||||||
|
tests []pmtudTestUnit, logger Logger,
|
||||||
|
) (err error) {
|
||||||
|
echoIDToTestIndex := make(map[uint16]int, len(tests))
|
||||||
|
for i, test := range tests {
|
||||||
|
echoIDToTestIndex[test.echoID] = i
|
||||||
|
}
|
||||||
|
|
||||||
|
// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would
|
||||||
|
// create huge buffers which we don't really want to support anyway.
|
||||||
|
// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with
|
||||||
|
// a conventional maximum of 9000 bytes. However, some manufacturers support up
|
||||||
|
// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to
|
||||||
|
// match eventual Jumbo frames. More information at:
|
||||||
|
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
|
||||||
|
const maxPossibleMTU = 9196
|
||||||
|
buffer := make([]byte, maxPossibleMTU)
|
||||||
|
|
||||||
|
idsFound := 0
|
||||||
|
for idsFound < len(tests) {
|
||||||
|
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||||
|
// must be large enough to read the entire reply packet. See:
|
||||||
|
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||||
|
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("reading from ICMP connection: %w", err)
|
||||||
|
}
|
||||||
|
packetBytes := buffer[:bytesRead]
|
||||||
|
|
||||||
|
ipPacketLength := len(packetBytes)
|
||||||
|
|
||||||
|
var icmpProtocol int
|
||||||
|
switch ipVersion {
|
||||||
|
case "v4":
|
||||||
|
icmpProtocol = icmpv4Protocol
|
||||||
|
case "v6":
|
||||||
|
icmpProtocol = icmpv6Protocol
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the ICMP message
|
||||||
|
// Note: this parsing works for a truncated 556 bytes ICMP reply packet.
|
||||||
|
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
echoBody, ok := message.Body.(*icmp.Echo)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, message.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
id := uint16(echoBody.ID) //nolint:gosec
|
||||||
|
testIndex, testing := echoIDToTestIndex[id]
|
||||||
|
if !testing { // not an id we expected so ignore it
|
||||||
|
logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)",
|
||||||
|
echoBody.ID, message.Type, message.Code, ipPacketLength)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idsFound++
|
||||||
|
sentBytes := tests[testIndex].sentBytes
|
||||||
|
|
||||||
|
// echo reply should be at most the number of bytes sent,
|
||||||
|
// and can be lower, more precisely 556 bytes, in case
|
||||||
|
// the host we are reaching wants to stay out of trouble
|
||||||
|
// and ensure its echo reply goes through without
|
||||||
|
// fragmentation, see the following page:
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
|
||||||
|
const conservativeReplyLength = 556
|
||||||
|
truncated := ipPacketLength < sentBytes &&
|
||||||
|
ipPacketLength == conservativeReplyLength
|
||||||
|
// Check the packet size is the same if the reply is not truncated
|
||||||
|
if !truncated && sentBytes != ipPacketLength {
|
||||||
|
return fmt.Errorf("%w: sent %dB and received %dB",
|
||||||
|
ErrICMPEchoDataMismatch, sentBytes, ipPacketLength)
|
||||||
|
}
|
||||||
|
// Truncated reply or matching reply size
|
||||||
|
tests[testIndex].ok = true
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package pmtud
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_PathMTUDiscover(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
const physicalLinkMTU = 1500
|
||||||
|
const timeout = time.Second
|
||||||
|
mtu, err := PathMTUDiscover(context.Background(), netip.MustParseAddr("1.1.1.1"),
|
||||||
|
physicalLinkMTU, timeout, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Log("MTU found:", mtu)
|
||||||
|
}
|
||||||
@@ -0,0 +1,55 @@
|
|||||||
|
package pmtud
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_makeMTUsToTest(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
minMTU uint32
|
||||||
|
maxMTU uint32
|
||||||
|
mtus []uint32
|
||||||
|
}{
|
||||||
|
"0_0": {
|
||||||
|
mtus: []uint32{0},
|
||||||
|
},
|
||||||
|
"0_1": {
|
||||||
|
maxMTU: 1,
|
||||||
|
mtus: []uint32{0, 1},
|
||||||
|
},
|
||||||
|
"0_8": {
|
||||||
|
maxMTU: 8,
|
||||||
|
mtus: []uint32{0, 1, 2, 3, 4, 5, 6, 7, 8},
|
||||||
|
},
|
||||||
|
"0_12": {
|
||||||
|
maxMTU: 12,
|
||||||
|
mtus: []uint32{0, 1, 2, 4, 5, 6, 7, 8, 10, 11, 12},
|
||||||
|
},
|
||||||
|
"0_80": {
|
||||||
|
maxMTU: 80,
|
||||||
|
mtus: []uint32{0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80},
|
||||||
|
},
|
||||||
|
"0_100": {
|
||||||
|
maxMTU: 100,
|
||||||
|
mtus: []uint32{0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100},
|
||||||
|
},
|
||||||
|
"1280_1500": {
|
||||||
|
minMTU: 1280,
|
||||||
|
maxMTU: 1500,
|
||||||
|
mtus: []uint32{1280, 1302, 1324, 1346, 1368, 1390, 1412, 1434, 1456, 1478, 1500},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mtus := makeMTUsToTest(testCase.minMTU, testCase.maxMTU)
|
||||||
|
assert.Equal(t, testCase.mtus, mtus)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -14,7 +14,7 @@ type Service interface {
|
|||||||
|
|
||||||
type Routing interface {
|
type Routing interface {
|
||||||
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
|
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
|
||||||
AssignedIP(interfaceName string, family int) (ip netip.Addr, err error)
|
AssignedIP(interfaceName string, family uint8) (ip netip.Addr, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type PortAllower interface {
|
type PortAllower interface {
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ type PortAllower interface {
|
|||||||
|
|
||||||
type Routing interface {
|
type Routing interface {
|
||||||
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
|
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
|
||||||
AssignedIP(interfaceName string, family int) (ip netip.Addr, err error)
|
AssignedIP(interfaceName string, family uint8) (ip netip.Addr, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Logger interface {
|
type Logger interface {
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrRouteDefaultNotFound = errors.New("default route not found")
|
var ErrRouteDefaultNotFound = errors.New("default route not found")
|
||||||
@@ -15,7 +14,7 @@ type DefaultRoute struct {
|
|||||||
NetInterface string
|
NetInterface string
|
||||||
Gateway netip.Addr
|
Gateway netip.Addr
|
||||||
AssignedIP netip.Addr
|
AssignedIP netip.Addr
|
||||||
Family int
|
Family uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d DefaultRoute) String() string {
|
func (d DefaultRoute) String() string {
|
||||||
@@ -30,7 +29,7 @@ func (r *Routing) DefaultRoutes() (defaultRoutes []DefaultRoute, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
if route.Table != unix.RT_TABLE_MAIN {
|
if route.Table != tableMain {
|
||||||
// ignore non-main table
|
// ignore non-main table
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
inboundTable = 200
|
inboundTable uint32 = 200
|
||||||
inboundPriority = 100
|
inboundPriority uint32 = 100
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err error) {
|
func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err error) {
|
||||||
@@ -60,7 +60,7 @@ func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err e
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) {
|
func (r *Routing) addRuleInboundFromDefault(table uint32, defaultRoutes []DefaultRoute) (err error) {
|
||||||
for _, defaultRoute := range defaultRoutes {
|
for _, defaultRoute := range defaultRoutes {
|
||||||
assignedIP := defaultRoute.AssignedIP
|
assignedIP := defaultRoute.AssignedIP
|
||||||
bits := 32
|
bits := 32
|
||||||
@@ -78,7 +78,7 @@ func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRo
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) delRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) {
|
func (r *Routing) delRuleInboundFromDefault(table uint32, defaultRoutes []DefaultRoute) (err error) {
|
||||||
for _, defaultRoute := range defaultRoutes {
|
for _, defaultRoute := range defaultRoutes {
|
||||||
assignedIP := defaultRoute.AssignedIP
|
assignedIP := defaultRoute.AssignedIP
|
||||||
bits := 32
|
bits := 32
|
||||||
|
|||||||
@@ -16,12 +16,12 @@ func ipIsPrivate(ip netip.Addr) bool {
|
|||||||
|
|
||||||
var errInterfaceIPNotFound = errors.New("IP address not found for interface")
|
var errInterfaceIPNotFound = errors.New("IP address not found for interface")
|
||||||
|
|
||||||
func ipMatchesFamily(ip netip.Addr, family int) bool {
|
func ipMatchesFamily(ip netip.Addr, family uint8) bool {
|
||||||
return (family == netlink.FamilyV4 && ip.Is4()) ||
|
return (family == netlink.FamilyV4 && ip.Is4()) ||
|
||||||
(family == netlink.FamilyV6 && ip.Is6())
|
(family == netlink.FamilyV6 && ip.Is6())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) AssignedIP(interfaceName string, family int) (ip netip.Addr, err error) {
|
func (r *Routing) AssignedIP(interfaceName string, family uint8) (ip netip.Addr, err error) {
|
||||||
iface, err := net.InterfaceByName(interfaceName)
|
iface, err := net.InterfaceByName(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ip, fmt.Errorf("network interface %s not found: %w", interfaceName, err)
|
return ip, fmt.Errorf("network interface %s not found: %w", interfaceName, err)
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -27,10 +26,10 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
|
|||||||
return localNetworks, fmt.Errorf("listing links: %w", err)
|
return localNetworks, fmt.Errorf("listing links: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
localLinks := make(map[int]struct{})
|
localLinks := make(map[uint32]struct{})
|
||||||
|
|
||||||
for _, link := range links {
|
for _, link := range links {
|
||||||
if link.EncapType != "ether" {
|
if link.DeviceType != netlink.DeviceTypeEthernet {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,7 +47,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
if route.Table != unix.RT_TABLE_MAIN ||
|
if route.Table != tableMain ||
|
||||||
(route.Gw.IsValid() && !route.Gw.IsUnspecified()) ||
|
(route.Gw.IsValid() && !route.Gw.IsUnspecified()) ||
|
||||||
(route.Dst.IsValid() && route.Dst.Addr().IsUnspecified()) {
|
(route.Dst.IsValid() && route.Dst.Addr().IsUnspecified()) {
|
||||||
continue
|
continue
|
||||||
@@ -96,7 +95,7 @@ func (r *Routing) AddLocalRules(subnets []LocalNetwork) (err error) {
|
|||||||
|
|
||||||
// Local has higher priority then outbound(99) and inbound(100) as the
|
// Local has higher priority then outbound(99) and inbound(100) as the
|
||||||
// local routes might be necessary to reach the outbound/inbound routes.
|
// local routes might be necessary to reach the outbound/inbound routes.
|
||||||
const localPriority = 98
|
const localPriority uint32 = 98
|
||||||
|
|
||||||
// Main table was setup correctly by Docker, just need to add rules to use it
|
// Main table was setup correctly by Docker, just need to add rules to use it
|
||||||
src := netip.Prefix{}
|
src := netip.Prefix{}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
netip "net/netip"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
@@ -35,10 +36,10 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddrList mocks base method.
|
// AddrList mocks base method.
|
||||||
func (m *MockNetLinker) AddrList(arg0 netlink.Link, arg1 int) ([]netlink.Addr, error) {
|
func (m *MockNetLinker) AddrList(arg0 uint32, arg1 byte) ([]netip.Prefix, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AddrList", arg0, arg1)
|
ret := m.ctrl.Call(m, "AddrList", arg0, arg1)
|
||||||
ret0, _ := ret[0].([]netlink.Addr)
|
ret0, _ := ret[0].([]netip.Prefix)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
@@ -50,7 +51,7 @@ func (mr *MockNetLinkerMockRecorder) AddrList(arg0, arg1 interface{}) *gomock.Ca
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddrReplace mocks base method.
|
// AddrReplace mocks base method.
|
||||||
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error {
|
func (m *MockNetLinker) AddrReplace(arg0 uint32, arg1 netip.Prefix) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
|
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
@@ -64,10 +65,10 @@ func (mr *MockNetLinkerMockRecorder) AddrReplace(arg0, arg1 interface{}) *gomock
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LinkAdd mocks base method.
|
// LinkAdd mocks base method.
|
||||||
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) {
|
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (uint32, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "LinkAdd", arg0)
|
ret := m.ctrl.Call(m, "LinkAdd", arg0)
|
||||||
ret0, _ := ret[0].(int)
|
ret0, _ := ret[0].(uint32)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
@@ -79,7 +80,7 @@ func (mr *MockNetLinkerMockRecorder) LinkAdd(arg0 interface{}) *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LinkByIndex mocks base method.
|
// LinkByIndex mocks base method.
|
||||||
func (m *MockNetLinker) LinkByIndex(arg0 int) (netlink.Link, error) {
|
func (m *MockNetLinker) LinkByIndex(arg0 uint32) (netlink.Link, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "LinkByIndex", arg0)
|
ret := m.ctrl.Call(m, "LinkByIndex", arg0)
|
||||||
ret0, _ := ret[0].(netlink.Link)
|
ret0, _ := ret[0].(netlink.Link)
|
||||||
@@ -109,7 +110,7 @@ func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LinkDel mocks base method.
|
// LinkDel mocks base method.
|
||||||
func (m *MockNetLinker) LinkDel(arg0 netlink.Link) error {
|
func (m *MockNetLinker) LinkDel(arg0 uint32) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "LinkDel", arg0)
|
ret := m.ctrl.Call(m, "LinkDel", arg0)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
@@ -138,7 +139,7 @@ func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LinkSetDown mocks base method.
|
// LinkSetDown mocks base method.
|
||||||
func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error {
|
func (m *MockNetLinker) LinkSetDown(arg0 uint32) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "LinkSetDown", arg0)
|
ret := m.ctrl.Call(m, "LinkSetDown", arg0)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
@@ -152,12 +153,11 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LinkSetUp mocks base method.
|
// LinkSetUp mocks base method.
|
||||||
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) {
|
func (m *MockNetLinker) LinkSetUp(arg0 uint32) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
|
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
|
||||||
ret0, _ := ret[0].(int)
|
ret0, _ := ret[0].(error)
|
||||||
ret1, _ := ret[1].(error)
|
return ret0
|
||||||
return ret0, ret1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LinkSetUp indicates an expected call of LinkSetUp.
|
// LinkSetUp indicates an expected call of LinkSetUp.
|
||||||
@@ -195,7 +195,7 @@ func (mr *MockNetLinkerMockRecorder) RouteDel(arg0 interface{}) *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RouteList mocks base method.
|
// RouteList mocks base method.
|
||||||
func (m *MockNetLinker) RouteList(arg0 int) ([]netlink.Route, error) {
|
func (m *MockNetLinker) RouteList(arg0 byte) ([]netlink.Route, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "RouteList", arg0)
|
ret := m.ctrl.Call(m, "RouteList", arg0)
|
||||||
ret0, _ := ret[0].([]netlink.Route)
|
ret0, _ := ret[0].([]netlink.Route)
|
||||||
@@ -252,7 +252,7 @@ func (mr *MockNetLinkerMockRecorder) RuleDel(arg0 interface{}) *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RuleList mocks base method.
|
// RuleList mocks base method.
|
||||||
func (m *MockNetLinker) RuleList(arg0 int) ([]netlink.Rule, error) {
|
func (m *MockNetLinker) RuleList(arg0 byte) ([]netlink.Rule, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "RuleList", arg0)
|
ret := m.ctrl.Call(m, "RuleList", arg0)
|
||||||
ret0, _ := ret[0].([]netlink.Rule)
|
ret0, _ := ret[0].([]netlink.Rule)
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
outboundTable = 199
|
outboundTable uint32 = 199
|
||||||
outboundPriority = 99
|
outboundPriority uint32 = 99
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *Routing) SetOutboundRoutes(outboundSubnets []netip.Prefix) error {
|
func (r *Routing) SetOutboundRoutes(outboundSubnets []netip.Prefix) error {
|
||||||
|
|||||||
@@ -9,25 +9,33 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr,
|
func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr,
|
||||||
iface string, table int,
|
iface string, table uint32,
|
||||||
) error {
|
) error {
|
||||||
destinationStr := destination.String()
|
destinationStr := destination.String()
|
||||||
r.logger.Info("adding route for " + destinationStr)
|
r.logger.Info("adding route for " + destinationStr)
|
||||||
r.logger.Debug("ip route replace " + destinationStr +
|
r.logger.Debug("ip route replace " + destinationStr +
|
||||||
" via " + gateway.String() +
|
" via " + gateway.String() +
|
||||||
" dev " + iface +
|
" dev " + iface +
|
||||||
" table " + strconv.Itoa(table))
|
" table " + strconv.Itoa(int(table)))
|
||||||
|
|
||||||
link, err := r.netLinker.LinkByName(iface)
|
link, err := r.netLinker.LinkByName(iface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("finding link for interface %s: %w", iface, err)
|
return fmt.Errorf("finding link for interface %s: %w", iface, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
family := netlink.FamilyV4
|
||||||
|
if destination.Addr().Is6() {
|
||||||
|
family = netlink.FamilyV6
|
||||||
|
}
|
||||||
route := netlink.Route{
|
route := netlink.Route{
|
||||||
Dst: destination,
|
Dst: destination,
|
||||||
Gw: gateway,
|
Gw: gateway,
|
||||||
LinkIndex: link.Index,
|
LinkIndex: link.Index,
|
||||||
|
Family: family,
|
||||||
Table: table,
|
Table: table,
|
||||||
|
Type: netlink.RouteTypeUnicast,
|
||||||
|
Scope: netlink.ScopeUniverse,
|
||||||
|
Proto: netlink.ProtoStatic,
|
||||||
}
|
}
|
||||||
if err := r.netLinker.RouteReplace(route); err != nil {
|
if err := r.netLinker.RouteReplace(route); err != nil {
|
||||||
return fmt.Errorf("replacing route for subnet %s at interface %s: %w",
|
return fmt.Errorf("replacing route for subnet %s at interface %s: %w",
|
||||||
@@ -38,24 +46,29 @@ func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway netip.Addr,
|
func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway netip.Addr,
|
||||||
iface string, table int,
|
iface string, table uint32,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
destinationStr := destination.String()
|
destinationStr := destination.String()
|
||||||
r.logger.Info("deleting route for " + destinationStr)
|
r.logger.Info("deleting route for " + destinationStr)
|
||||||
r.logger.Debug("ip route delete " + destinationStr +
|
r.logger.Debug("ip route delete " + destinationStr +
|
||||||
" via " + gateway.String() +
|
" via " + gateway.String() +
|
||||||
" dev " + iface +
|
" dev " + iface +
|
||||||
" table " + strconv.Itoa(table))
|
" table " + strconv.Itoa(int(table)))
|
||||||
|
|
||||||
link, err := r.netLinker.LinkByName(iface)
|
link, err := r.netLinker.LinkByName(iface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("finding link for interface %s: %w", iface, err)
|
return fmt.Errorf("finding link for interface %s: %w", iface, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
family := netlink.FamilyV4
|
||||||
|
if destination.Addr().Is6() {
|
||||||
|
family = netlink.FamilyV6
|
||||||
|
}
|
||||||
route := netlink.Route{
|
route := netlink.Route{
|
||||||
Dst: destination,
|
Dst: destination,
|
||||||
Gw: gateway,
|
Gw: gateway,
|
||||||
LinkIndex: link.Index,
|
LinkIndex: link.Index,
|
||||||
|
Family: family,
|
||||||
Table: table,
|
Table: table,
|
||||||
}
|
}
|
||||||
if err := r.netLinker.RouteDel(route); err != nil {
|
if err := r.netLinker.RouteDel(route); err != nil {
|
||||||
|
|||||||
+10
-10
@@ -15,20 +15,20 @@ type NetLinker interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Addresser interface {
|
type Addresser interface {
|
||||||
AddrList(link netlink.Link, family int) (
|
AddrList(linkIndex uint32, family uint8) (
|
||||||
addresses []netlink.Addr, err error)
|
addresses []netip.Prefix, err error)
|
||||||
AddrReplace(link netlink.Link, addr netlink.Addr) error
|
AddrReplace(linkIndex uint32, prefix netip.Prefix) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Router interface {
|
type Router interface {
|
||||||
RouteList(family int) (routes []netlink.Route, err error)
|
RouteList(family uint8) (routes []netlink.Route, err error)
|
||||||
RouteAdd(route netlink.Route) error
|
RouteAdd(route netlink.Route) error
|
||||||
RouteDel(route netlink.Route) error
|
RouteDel(route netlink.Route) error
|
||||||
RouteReplace(route netlink.Route) error
|
RouteReplace(route netlink.Route) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Ruler interface {
|
type Ruler interface {
|
||||||
RuleList(family int) (rules []netlink.Rule, err error)
|
RuleList(family uint8) (rules []netlink.Rule, err error)
|
||||||
RuleAdd(rule netlink.Rule) error
|
RuleAdd(rule netlink.Rule) error
|
||||||
RuleDel(rule netlink.Rule) error
|
RuleDel(rule netlink.Rule) error
|
||||||
}
|
}
|
||||||
@@ -36,11 +36,11 @@ type Ruler interface {
|
|||||||
type Linker interface {
|
type Linker interface {
|
||||||
LinkList() (links []netlink.Link, err error)
|
LinkList() (links []netlink.Link, err error)
|
||||||
LinkByName(name string) (link netlink.Link, err error)
|
LinkByName(name string) (link netlink.Link, err error)
|
||||||
LinkByIndex(index int) (link netlink.Link, err error)
|
LinkByIndex(index uint32) (link netlink.Link, err error)
|
||||||
LinkAdd(link netlink.Link) (linkIndex int, err error)
|
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
|
||||||
LinkDel(link netlink.Link) (err error)
|
LinkDel(index uint32) (err error)
|
||||||
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
LinkSetUp(index uint32) (err error)
|
||||||
LinkSetDown(link netlink.Link) (err error)
|
LinkSetDown(index uint32) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Routing struct {
|
type Routing struct {
|
||||||
|
|||||||
+39
-13
@@ -7,12 +7,19 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority int) error {
|
func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority uint32) error {
|
||||||
rule := netlink.NewRule()
|
family := netlink.FamilyV4
|
||||||
rule.Src = src
|
if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) {
|
||||||
rule.Dst = dst
|
family = netlink.FamilyV6
|
||||||
rule.Priority = priority
|
}
|
||||||
rule.Table = table
|
rule := netlink.Rule{
|
||||||
|
Priority: &priority,
|
||||||
|
Family: family,
|
||||||
|
Table: table,
|
||||||
|
Src: src,
|
||||||
|
Dst: dst,
|
||||||
|
Action: netlink.ActionToTable,
|
||||||
|
}
|
||||||
|
|
||||||
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
|
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -31,12 +38,19 @@ func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority int) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) deleteIPRule(src, dst netip.Prefix, table, priority int) error {
|
func (r *Routing) deleteIPRule(src, dst netip.Prefix, table uint32, priority uint32) error {
|
||||||
rule := netlink.NewRule()
|
family := netlink.FamilyV4
|
||||||
rule.Src = src
|
if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) {
|
||||||
rule.Dst = dst
|
family = netlink.FamilyV6
|
||||||
rule.Priority = priority
|
}
|
||||||
rule.Table = table
|
rule := netlink.Rule{
|
||||||
|
Priority: &priority,
|
||||||
|
Family: family,
|
||||||
|
Table: table,
|
||||||
|
Src: src,
|
||||||
|
Dst: dst,
|
||||||
|
Action: netlink.ActionToTable,
|
||||||
|
}
|
||||||
|
|
||||||
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
|
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -53,10 +67,12 @@ func (r *Routing) deleteIPRule(src, dst netip.Prefix, table, priority int) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rulesAreEqual checks whether two rules are equal
|
||||||
|
// only according to src, dst, priority and table.
|
||||||
func rulesAreEqual(a, b netlink.Rule) bool {
|
func rulesAreEqual(a, b netlink.Rule) bool {
|
||||||
return ipPrefixesAreEqual(a.Src, b.Src) &&
|
return ipPrefixesAreEqual(a.Src, b.Src) &&
|
||||||
ipPrefixesAreEqual(a.Dst, b.Dst) &&
|
ipPrefixesAreEqual(a.Dst, b.Dst) &&
|
||||||
a.Priority == b.Priority &&
|
ptrsEqual(a.Priority, b.Priority) &&
|
||||||
a.Table == b.Table
|
a.Table == b.Table
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,3 +86,13 @@ func ipPrefixesAreEqual(a, b netip.Prefix) bool {
|
|||||||
return a.Bits() == b.Bits() &&
|
return a.Bits() == b.Bits() &&
|
||||||
a.Addr().Compare(b.Addr()) == 0
|
a.Addr().Compare(b.Addr()) == 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ptrsEqual(a, b *uint32) bool {
|
||||||
|
if a == nil && b == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if a == nil || b == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return *a == *b
|
||||||
|
}
|
||||||
|
|||||||
@@ -17,14 +17,20 @@ func makeNetipPrefix(n byte) netip.Prefix {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func makeIPRule(src, dst netip.Prefix,
|
func makeIPRule(src, dst netip.Prefix,
|
||||||
table, priority int,
|
table uint32, priority uint32,
|
||||||
) netlink.Rule {
|
) netlink.Rule {
|
||||||
rule := netlink.NewRule()
|
family := netlink.FamilyV4
|
||||||
rule.Src = src
|
if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) {
|
||||||
rule.Dst = dst
|
family = netlink.FamilyV6
|
||||||
rule.Table = table
|
}
|
||||||
rule.Priority = priority
|
return netlink.Rule{
|
||||||
return rule
|
Priority: &priority,
|
||||||
|
Family: family,
|
||||||
|
Table: table,
|
||||||
|
Src: src,
|
||||||
|
Dst: dst,
|
||||||
|
Action: netlink.ActionToTable,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Routing_addIPRule(t *testing.T) {
|
func Test_Routing_addIPRule(t *testing.T) {
|
||||||
@@ -46,8 +52,8 @@ func Test_Routing_addIPRule(t *testing.T) {
|
|||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
src netip.Prefix
|
src netip.Prefix
|
||||||
dst netip.Prefix
|
dst netip.Prefix
|
||||||
table int
|
table uint32
|
||||||
priority int
|
priority uint32
|
||||||
ruleList ruleListCall
|
ruleList ruleListCall
|
||||||
ruleAdd ruleAddCall
|
ruleAdd ruleAddCall
|
||||||
err error
|
err error
|
||||||
@@ -149,8 +155,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
|
|||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
src netip.Prefix
|
src netip.Prefix
|
||||||
dst netip.Prefix
|
dst netip.Prefix
|
||||||
table int
|
table uint32
|
||||||
priority int
|
priority uint32
|
||||||
ruleList ruleListCall
|
ruleList ruleListCall
|
||||||
ruleDel ruleDelCall
|
ruleDel ruleDelCall
|
||||||
err error
|
err error
|
||||||
@@ -238,6 +244,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ptrTo[T any](v T) *T { return &v }
|
||||||
|
|
||||||
func Test_rulesAreEqual(t *testing.T) {
|
func Test_rulesAreEqual(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -253,13 +261,13 @@ func Test_rulesAreEqual(t *testing.T) {
|
|||||||
a: netlink.Rule{
|
a: netlink.Rule{
|
||||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 24),
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 24),
|
||||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
Priority: 100,
|
Priority: ptrTo(uint32(100)),
|
||||||
Table: 101,
|
Table: 101,
|
||||||
},
|
},
|
||||||
b: netlink.Rule{
|
b: netlink.Rule{
|
||||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
Priority: 100,
|
Priority: ptrTo(uint32(100)),
|
||||||
Table: 101,
|
Table: 101,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -267,13 +275,13 @@ func Test_rulesAreEqual(t *testing.T) {
|
|||||||
a: netlink.Rule{
|
a: netlink.Rule{
|
||||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 32),
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 32),
|
||||||
Priority: 100,
|
Priority: ptrTo(uint32(100)),
|
||||||
Table: 101,
|
Table: 101,
|
||||||
},
|
},
|
||||||
b: netlink.Rule{
|
b: netlink.Rule{
|
||||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
Priority: 100,
|
Priority: ptrTo(uint32(100)),
|
||||||
Table: 101,
|
Table: 101,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -281,13 +289,13 @@ func Test_rulesAreEqual(t *testing.T) {
|
|||||||
a: netlink.Rule{
|
a: netlink.Rule{
|
||||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
Priority: 999,
|
Priority: ptrTo(uint32(999)),
|
||||||
Table: 101,
|
Table: 101,
|
||||||
},
|
},
|
||||||
b: netlink.Rule{
|
b: netlink.Rule{
|
||||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
Priority: 100,
|
Priority: ptrTo(uint32(100)),
|
||||||
Table: 101,
|
Table: 101,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -295,13 +303,13 @@ func Test_rulesAreEqual(t *testing.T) {
|
|||||||
a: netlink.Rule{
|
a: netlink.Rule{
|
||||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
Priority: 100,
|
Priority: ptrTo(uint32(100)),
|
||||||
Table: 999,
|
Table: 102,
|
||||||
},
|
},
|
||||||
b: netlink.Rule{
|
b: netlink.Rule{
|
||||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
Priority: 100,
|
Priority: ptrTo(uint32(100)),
|
||||||
Table: 101,
|
Table: 101,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -309,13 +317,13 @@ func Test_rulesAreEqual(t *testing.T) {
|
|||||||
a: netlink.Rule{
|
a: netlink.Rule{
|
||||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
Priority: 100,
|
Priority: ptrTo(uint32(100)),
|
||||||
Table: 101,
|
Table: 101,
|
||||||
},
|
},
|
||||||
b: netlink.Rule{
|
b: netlink.Rule{
|
||||||
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
Priority: 100,
|
Priority: ptrTo(uint32(100)),
|
||||||
Table: 101,
|
Table: 101,
|
||||||
},
|
},
|
||||||
equal: true,
|
equal: true,
|
||||||
|
|||||||
@@ -0,0 +1,8 @@
|
|||||||
|
package routing
|
||||||
|
|
||||||
|
import "golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
const (
|
||||||
|
tableMain = unix.RT_TABLE_MAIN
|
||||||
|
tableLocal = unix.RT_TABLE_LOCAL
|
||||||
|
)
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package routing
|
||||||
|
|
||||||
|
const (
|
||||||
|
tableMain = 0
|
||||||
|
tableLocal = 0
|
||||||
|
)
|
||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -34,13 +33,12 @@ func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
|
|||||||
case route.Dst.IsValid() && route.Dst.Addr().IsUnspecified() && route.Gw.IsValid(): // OpenVPN
|
case route.Dst.IsValid() && route.Dst.Addr().IsUnspecified() && route.Gw.IsValid(): // OpenVPN
|
||||||
return route.Gw, nil
|
return route.Gw, nil
|
||||||
case route.Dst.IsSingleIP() &&
|
case route.Dst.IsSingleIP() &&
|
||||||
route.Dst.Addr().Compare(route.Src) == 0 &&
|
route.Dst.Addr().Compare(route.Src.Addr()) == 0 &&
|
||||||
route.Table == unix.RT_TABLE_LOCAL: // Wireguard
|
route.Table == tableLocal: // Wireguard
|
||||||
route.Src = route.Src.Unmap()
|
if route.Src.Addr().Is6() {
|
||||||
if route.Src.Is6() {
|
|
||||||
return netip.Addr{}, fmt.Errorf("%w: %s", ErrVPNLocalGatewayIPv6NotSupported, route.Src)
|
return netip.Addr{}, fmt.Errorf("%w: %s", ErrVPNLocalGatewayIPv6NotSupported, route.Src)
|
||||||
}
|
}
|
||||||
bytes := route.Src.As4()
|
bytes := route.Src.Addr().As4()
|
||||||
// force last byte to 1 to get the VPN gateway IP
|
// force last byte to 1 to get the VPN gateway IP
|
||||||
// This is not necessarily bullet proof but it seems to work.
|
// This is not necessarily bullet proof but it seems to work.
|
||||||
bytes[3] = 1
|
bytes[3] = 1
|
||||||
|
|||||||
+6346
-8186
File diff suppressed because it is too large
Load Diff
@@ -57,15 +57,15 @@ type Storage interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type NetLinker interface {
|
type NetLinker interface {
|
||||||
AddrReplace(link netlink.Link, addr netlink.Addr) error
|
AddrReplace(linkIndex uint32, addr netip.Prefix) error
|
||||||
Router
|
Router
|
||||||
Ruler
|
Ruler
|
||||||
Linker
|
Linker
|
||||||
IsWireguardSupported() bool
|
IsWireguardSupported() (ok bool, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Router interface {
|
type Router interface {
|
||||||
RouteList(family int) (routes []netlink.Route, err error)
|
RouteList(family uint8) (routes []netlink.Route, err error)
|
||||||
RouteAdd(route netlink.Route) error
|
RouteAdd(route netlink.Route) error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,10 +77,11 @@ type Ruler interface {
|
|||||||
type Linker interface {
|
type Linker interface {
|
||||||
LinkList() (links []netlink.Link, err error)
|
LinkList() (links []netlink.Link, err error)
|
||||||
LinkByName(name string) (link netlink.Link, err error)
|
LinkByName(name string) (link netlink.Link, err error)
|
||||||
LinkAdd(link netlink.Link) (linkIndex int, err error)
|
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
|
||||||
LinkDel(link netlink.Link) (err error)
|
LinkDel(linkIndex uint32) error
|
||||||
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
LinkSetUp(linkIndex uint32) error
|
||||||
LinkSetDown(link netlink.Link) (err error)
|
LinkSetDown(linkIndex uint32) error
|
||||||
|
LinkSetMTU(linkIndex, mtu uint32) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type DNSLoop interface {
|
type DNSLoop interface {
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
tunnelUpData := tunnelUpData{
|
tunnelUpData := tunnelUpData{
|
||||||
|
vpnType: settings.Type,
|
||||||
serverIP: connection.IP,
|
serverIP: connection.IP,
|
||||||
serverName: connection.ServerName,
|
serverName: connection.ServerName,
|
||||||
canPortForward: connection.PortForward,
|
canPortForward: connection.PortForward,
|
||||||
|
|||||||
@@ -2,16 +2,24 @@ package vpn
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/dns/v2/pkg/check"
|
"github.com/qdm12/dns/v2/pkg/check"
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud"
|
||||||
"github.com/qdm12/gluetun/internal/version"
|
"github.com/qdm12/gluetun/internal/version"
|
||||||
|
"github.com/qdm12/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
type tunnelUpData struct {
|
type tunnelUpData struct {
|
||||||
// Healthcheck
|
// Healthcheck
|
||||||
serverIP netip.Addr
|
serverIP netip.Addr
|
||||||
|
// vpnType is used for path MTU discovery to find the protocol overhead.
|
||||||
|
// It can be "wireguard" or "openvpn".
|
||||||
|
vpnType string
|
||||||
// Port forwarding
|
// Port forwarding
|
||||||
vpnIntf string
|
vpnIntf string
|
||||||
serverName string // used for PIA
|
serverName string // used for PIA
|
||||||
@@ -31,6 +39,13 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
|
||||||
|
err := updateToMaxMTU(ctx, data.vpnIntf, data.vpnType,
|
||||||
|
l.netLinker, l.routing, mtuLogger)
|
||||||
|
if err != nil {
|
||||||
|
mtuLogger.Error(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
icmpTargetIPs := l.healthSettings.ICMPTargetIPs
|
icmpTargetIPs := l.healthSettings.ICMPTargetIPs
|
||||||
if len(icmpTargetIPs) == 1 && icmpTargetIPs[0].IsUnspecified() {
|
if len(icmpTargetIPs) == 1 && icmpTargetIPs[0].IsUnspecified() {
|
||||||
icmpTargetIPs = []netip.Addr{data.serverIP}
|
icmpTargetIPs = []netip.Addr{data.serverIP}
|
||||||
@@ -120,3 +135,65 @@ func (l *Loop) restartVPN(ctx context.Context, healthErr error) {
|
|||||||
_, _ = l.ApplyStatus(ctx, constants.Stopped)
|
_, _ = l.ApplyStatus(ctx, constants.Stopped)
|
||||||
_, _ = l.ApplyStatus(ctx, constants.Running)
|
_, _ = l.ApplyStatus(ctx, constants.Running)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var errVPNTypeUnknown = errors.New("unknown VPN type")
|
||||||
|
|
||||||
|
func updateToMaxMTU(ctx context.Context, vpnInterface string,
|
||||||
|
vpnType string, netlinker NetLinker, routing Routing, logger *log.Logger,
|
||||||
|
) error {
|
||||||
|
logger.Info("finding maximum MTU, this can take up to 4 seconds")
|
||||||
|
|
||||||
|
vpnGatewayIP, err := routing.VPNLocalGatewayIP(vpnInterface)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("getting VPN gateway IP address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
link, err := netlinker.LinkByName(vpnInterface)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("getting VPN interface by name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
originalMTU := link.MTU
|
||||||
|
|
||||||
|
// Note: no point testing for an MTU of 1500, it will never work due to the VPN
|
||||||
|
// protocol overhead, so start lower than 1500 according to the protocol used.
|
||||||
|
const physicalLinkMTU uint32 = 1500
|
||||||
|
vpnLinkMTU := physicalLinkMTU
|
||||||
|
switch vpnType {
|
||||||
|
case "wireguard":
|
||||||
|
vpnLinkMTU -= 60 // Wireguard overhead
|
||||||
|
case "openvpn":
|
||||||
|
vpnLinkMTU -= 41 // OpenVPN overhead
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%w: %q", errVPNTypeUnknown, vpnType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setting the VPN link MTU to 1500 might interrupt the connection until
|
||||||
|
// the new MTU is set again, but this is necessary to find the highest valid MTU.
|
||||||
|
logger.Debugf("VPN interface %s MTU temporarily set to %d", vpnInterface, vpnLinkMTU)
|
||||||
|
|
||||||
|
err = netlinker.LinkSetMTU(link.Index, vpnLinkMTU)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const pingTimeout = time.Second
|
||||||
|
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, vpnGatewayIP, vpnLinkMTU, pingTimeout, logger)
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
|
||||||
|
case errors.Is(err, pmtud.ErrMTUNotFound) || errors.Is(err, pmtud.ErrICMPNotPermitted):
|
||||||
|
vpnLinkMTU = originalMTU
|
||||||
|
logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)",
|
||||||
|
vpnInterface, originalMTU, err)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("path MTU discovering: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = netlinker.LinkSetMTU(link.Index, vpnLinkMTU)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,26 +3,20 @@ package wireguard
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (w *Wireguard) addAddresses(link netlink.Link,
|
func (w *Wireguard) addAddresses(linkIndex uint32,
|
||||||
addresses []netip.Prefix,
|
addresses []netip.Prefix,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
for _, ipNet := range addresses {
|
for _, address := range addresses {
|
||||||
if !*w.settings.IPv6 && ipNet.Addr().Is6() {
|
if !*w.settings.IPv6 && address.Addr().Is6() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
address := netlink.Addr{
|
err = w.netlink.AddrReplace(linkIndex, address)
|
||||||
Network: ipNet,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = w.netlink.AddrReplace(link, address)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: when adding address %s to link %s",
|
return fmt.Errorf("%w: when adding address %s to link with index %d",
|
||||||
err, address, link.Name)
|
err, address, linkIndex)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -20,21 +19,21 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
|||||||
errDummy := errors.New("dummy")
|
errDummy := errors.New("dummy")
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
link netlink.Link
|
linkIndex uint32
|
||||||
addrs []netip.Prefix
|
addrs []netip.Prefix
|
||||||
wgBuilder func(ctrl *gomock.Controller, link netlink.Link) *Wireguard
|
wgBuilder func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
"success": {
|
"success": {
|
||||||
link: netlink.Link{Type: "wireguard"},
|
linkIndex: 1,
|
||||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||||
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
|
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
|
||||||
netLinker := NewMockNetLinker(ctrl)
|
netLinker := NewMockNetLinker(ctrl)
|
||||||
firstCall := netLinker.EXPECT().
|
firstCall := netLinker.EXPECT().
|
||||||
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
|
AddrReplace(linkIndex, ipNetOne).
|
||||||
Return(nil)
|
Return(nil)
|
||||||
netLinker.EXPECT().
|
netLinker.EXPECT().
|
||||||
AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
|
AddrReplace(linkIndex, ipNetTwo).
|
||||||
Return(nil).After(firstCall)
|
Return(nil).After(firstCall)
|
||||||
return &Wireguard{
|
return &Wireguard{
|
||||||
netlink: netLinker,
|
netlink: netLinker,
|
||||||
@@ -45,12 +44,12 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
"first add error": {
|
"first add error": {
|
||||||
link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
|
linkIndex: 1,
|
||||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||||
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
|
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
|
||||||
netLinker := NewMockNetLinker(ctrl)
|
netLinker := NewMockNetLinker(ctrl)
|
||||||
netLinker.EXPECT().
|
netLinker.EXPECT().
|
||||||
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
|
AddrReplace(linkIndex, ipNetOne).
|
||||||
Return(errDummy)
|
Return(errDummy)
|
||||||
return &Wireguard{
|
return &Wireguard{
|
||||||
netlink: netLinker,
|
netlink: netLinker,
|
||||||
@@ -59,18 +58,18 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"),
|
err: errors.New("dummy: when adding address 1.2.3.4/32 to link with index 1"),
|
||||||
},
|
},
|
||||||
"second add error": {
|
"second add error": {
|
||||||
link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
|
linkIndex: 1,
|
||||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||||
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
|
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
|
||||||
netLinker := NewMockNetLinker(ctrl)
|
netLinker := NewMockNetLinker(ctrl)
|
||||||
firstCall := netLinker.EXPECT().
|
firstCall := netLinker.EXPECT().
|
||||||
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
|
AddrReplace(linkIndex, ipNetOne).
|
||||||
Return(nil)
|
Return(nil)
|
||||||
netLinker.EXPECT().
|
netLinker.EXPECT().
|
||||||
AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
|
AddrReplace(linkIndex, ipNetTwo).
|
||||||
Return(errDummy).After(firstCall)
|
Return(errDummy).After(firstCall)
|
||||||
return &Wireguard{
|
return &Wireguard{
|
||||||
netlink: netLinker,
|
netlink: netLinker,
|
||||||
@@ -79,11 +78,11 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
err: errors.New("dummy: when adding address ::1234/64 to link a_bridge"),
|
err: errors.New("dummy: when adding address ::1234/64 to link with index 1"),
|
||||||
},
|
},
|
||||||
"ignore IPv6": {
|
"ignore IPv6": {
|
||||||
addrs: []netip.Prefix{ipNetTwo},
|
addrs: []netip.Prefix{ipNetTwo},
|
||||||
wgBuilder: func(_ *gomock.Controller, _ netlink.Link) *Wireguard {
|
wgBuilder: func(_ *gomock.Controller, _ uint32) *Wireguard {
|
||||||
return &Wireguard{
|
return &Wireguard{
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
IPv6: ptrTo(false),
|
IPv6: ptrTo(false),
|
||||||
@@ -98,9 +97,9 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
|
|
||||||
wg := testCase.wgBuilder(ctrl, testCase.link)
|
wg := testCase.wgBuilder(ctrl, testCase.linkIndex)
|
||||||
|
|
||||||
err := wg.addAddresses(testCase.link, testCase.addrs)
|
err := wg.addAddresses(testCase.linkIndex, testCase.addrs)
|
||||||
|
|
||||||
if testCase.err != nil {
|
if testCase.err != nil {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|||||||
@@ -1,3 +1,53 @@
|
|||||||
package wireguard
|
package wireguard
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand/v2"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
|
)
|
||||||
|
|
||||||
func ptrTo[T any](x T) *T { return &x }
|
func ptrTo[T any](x T) *T { return &x }
|
||||||
|
|
||||||
|
var rng = rand.New(rand.NewChaCha8([32]byte{})) //nolint:gosec,gochecknoglobals
|
||||||
|
|
||||||
|
func makeLinkName() string {
|
||||||
|
const alphabet = "abcdefghijklmnopqrstuvwxyz"
|
||||||
|
b := make([]byte, 8)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = alphabet[rng.IntN(len(alphabet))]
|
||||||
|
}
|
||||||
|
return "test" + string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func rulesAreEqual(a, b netlink.Rule) bool {
|
||||||
|
return ipPrefixesAreEqual(a.Src, b.Src) &&
|
||||||
|
ipPrefixesAreEqual(a.Dst, b.Dst) &&
|
||||||
|
ptrsEqual(a.Priority, b.Priority) &&
|
||||||
|
a.Table == b.Table &&
|
||||||
|
a.Family == b.Family &&
|
||||||
|
a.Flags == b.Flags &&
|
||||||
|
a.Action == b.Action &&
|
||||||
|
ptrsEqual(a.Mark, b.Mark)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ipPrefixesAreEqual(a, b netip.Prefix) bool {
|
||||||
|
if !a.IsValid() && !b.IsValid() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if !a.IsValid() || !b.IsValid() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return a.Bits() == b.Bits() &&
|
||||||
|
a.Addr().Compare(b.Addr()) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func ptrsEqual(a, b *uint32) bool {
|
||||||
|
if a == nil && b == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if a == nil || b == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return *a == *b
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build netlink && linux
|
//go:build linux
|
||||||
|
|
||||||
package wireguard
|
package wireguard
|
||||||
|
|
||||||
@@ -10,13 +10,16 @@ import (
|
|||||||
"github.com/qdm12/log"
|
"github.com/qdm12/log"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type noopDebugLogger struct{}
|
type noopDebugLogger struct{}
|
||||||
|
|
||||||
func (n noopDebugLogger) Debugf(format string, args ...any) {}
|
func (n noopDebugLogger) Debug(_ string) {}
|
||||||
func (n noopDebugLogger) Patch(options ...log.Option) {}
|
func (n noopDebugLogger) Debugf(_ string, _ ...any) {}
|
||||||
|
func (n noopDebugLogger) Info(_ string) {}
|
||||||
|
func (n noopDebugLogger) Error(_ string) {}
|
||||||
|
func (n noopDebugLogger) Errorf(_ string, _ ...any) {}
|
||||||
|
func (n noopDebugLogger) Patch(_ ...log.Option) {}
|
||||||
|
|
||||||
func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
@@ -24,15 +27,9 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
|||||||
netlinker := netlink.New(&noopDebugLogger{})
|
netlinker := netlink.New(&noopDebugLogger{})
|
||||||
|
|
||||||
link := netlink.Link{
|
link := netlink.Link{
|
||||||
Type: "bridge",
|
DeviceType: netlink.DeviceTypeNone,
|
||||||
Name: "test_8081",
|
VirtualType: "bridge",
|
||||||
}
|
Name: makeLinkName(),
|
||||||
|
|
||||||
// Remove any previously created test interface from a crashed/panic
|
|
||||||
// test or test suite run.
|
|
||||||
err := netlinker.LinkDel(link)
|
|
||||||
if err != nil && err.Error() != "invalid argument" {
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
linkIndex, err := netlinker.LinkAdd(link)
|
linkIndex, err := netlinker.LinkAdd(link)
|
||||||
@@ -40,7 +37,7 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
|||||||
link.Index = linkIndex
|
link.Index = linkIndex
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err = netlinker.LinkDel(link)
|
err = netlinker.LinkDel(linkIndex)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -57,17 +54,15 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const addIterations = 2 // initial + replace
|
const addIterations = 2 // initial + replace
|
||||||
|
for range addIterations {
|
||||||
for i := 0; i < addIterations; i++ {
|
err = wg.addAddresses(link.Index, addresses)
|
||||||
err = wg.addAddresses(link, addresses)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
netlinkAddresses, err := netlinker.AddrList(link, netlink.FamilyAll)
|
ipPrefixes, err := netlinker.AddrList(link.Index, netlink.FamilyAll)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, len(addresses), len(netlinkAddresses))
|
require.Equal(t, len(addresses), len(ipPrefixes))
|
||||||
for i, netlinkAddress := range netlinkAddresses {
|
for i, ipPrefix := range ipPrefixes {
|
||||||
require.NotNil(t, netlinkAddress.Network)
|
assert.Equal(t, addresses[i], ipPrefix)
|
||||||
assert.Equal(t, addresses[i], netlinkAddress.Network)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -78,38 +73,41 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
|
|||||||
netlinker := netlink.New(&noopDebugLogger{})
|
netlinker := netlink.New(&noopDebugLogger{})
|
||||||
wg := &Wireguard{
|
wg := &Wireguard{
|
||||||
netlink: netlinker,
|
netlink: netlinker,
|
||||||
|
logger: &noopDebugLogger{},
|
||||||
}
|
}
|
||||||
|
|
||||||
rulePriority := 10000
|
// Unique combination for this test
|
||||||
const firewallMark = 999
|
const rulePriority uint32 = 10000
|
||||||
const family = unix.AF_INET // ipv4
|
const firewallMark uint32 = 12345
|
||||||
|
const family = netlink.FamilyV4
|
||||||
|
|
||||||
cleanup, err := wg.addRule(rulePriority,
|
cleanup, err := wg.addRule(rulePriority,
|
||||||
firewallMark, family)
|
firewallMark, family)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
t.Cleanup(func() {
|
||||||
err := cleanup()
|
err := cleanup()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}()
|
})
|
||||||
|
|
||||||
rules, err := netlinker.RuleList(netlink.FamilyV4)
|
rules, err := netlinker.RuleList(netlink.FamilyV4)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
expectedRule := netlink.Rule{
|
||||||
|
Priority: ptrTo(rulePriority),
|
||||||
|
Family: netlink.FamilyV4,
|
||||||
|
Table: firewallMark,
|
||||||
|
Mark: ptrTo(firewallMark),
|
||||||
|
Flags: netlink.FlagInvert,
|
||||||
|
Action: netlink.ActionToTable,
|
||||||
|
}
|
||||||
var rule netlink.Rule
|
var rule netlink.Rule
|
||||||
var ruleFound bool
|
var ruleFound bool
|
||||||
for _, rule = range rules {
|
for _, rule = range rules {
|
||||||
if rule.Mark == firewallMark {
|
if rulesAreEqual(rule, expectedRule) {
|
||||||
ruleFound = true
|
ruleFound = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
require.True(t, ruleFound)
|
require.True(t, ruleFound)
|
||||||
expectedRule := netlink.Rule{
|
|
||||||
Invert: true,
|
|
||||||
Priority: rulePriority,
|
|
||||||
Mark: firewallMark,
|
|
||||||
Table: firewallMark,
|
|
||||||
}
|
|
||||||
assert.Equal(t, expectedRule, rule)
|
|
||||||
|
|
||||||
// Existing rule cannot be added
|
// Existing rule cannot be added
|
||||||
nilCleanup, err := wg.addRule(rulePriority,
|
nilCleanup, err := wg.addRule(rulePriority,
|
||||||
@@ -118,5 +116,5 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
|
|||||||
_ = nilCleanup() // in case it succeeds
|
_ = nilCleanup() // in case it succeeds
|
||||||
}
|
}
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.EqualError(t, err, "adding ip rule 10000: from all to all table 999: file exists")
|
assert.EqualError(t, err, "adding ip rule 10000: from all to all table 12345: netlink receive: file exists")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,19 +1,23 @@
|
|||||||
package wireguard
|
package wireguard
|
||||||
|
|
||||||
import "github.com/qdm12/gluetun/internal/netlink"
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
|
)
|
||||||
|
|
||||||
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
|
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
|
||||||
|
|
||||||
type NetLinker interface {
|
type NetLinker interface {
|
||||||
AddrReplace(link netlink.Link, addr netlink.Addr) error
|
AddrReplace(linkIndex uint32, addr netip.Prefix) error
|
||||||
Router
|
Router
|
||||||
Ruler
|
Ruler
|
||||||
Linker
|
Linker
|
||||||
IsWireguardSupported() bool
|
IsWireguardSupported() (ok bool, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Router interface {
|
type Router interface {
|
||||||
RouteList(family int) (routes []netlink.Route, err error)
|
RouteList(family uint8) (routes []netlink.Route, err error)
|
||||||
RouteAdd(route netlink.Route) error
|
RouteAdd(route netlink.Route) error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -23,10 +27,10 @@ type Ruler interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Linker interface {
|
type Linker interface {
|
||||||
LinkAdd(link netlink.Link) (linkIndex int, err error)
|
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
|
||||||
LinkList() (links []netlink.Link, err error)
|
LinkList() (links []netlink.Link, err error)
|
||||||
LinkByName(name string) (link netlink.Link, err error)
|
LinkByName(name string) (link netlink.Link, err error)
|
||||||
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
LinkSetUp(linkIndex uint32) error
|
||||||
LinkSetDown(link netlink.Link) error
|
LinkSetDown(linkIndex uint32) error
|
||||||
LinkDel(link netlink.Link) error
|
LinkDel(linkIndex uint32) error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
package wireguard
|
package wireguard
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
netip "net/netip"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
@@ -35,7 +36,7 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddrReplace mocks base method.
|
// AddrReplace mocks base method.
|
||||||
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error {
|
func (m *MockNetLinker) AddrReplace(arg0 uint32, arg1 netip.Prefix) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
|
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
@@ -49,11 +50,12 @@ func (mr *MockNetLinkerMockRecorder) AddrReplace(arg0, arg1 interface{}) *gomock
|
|||||||
}
|
}
|
||||||
|
|
||||||
// IsWireguardSupported mocks base method.
|
// IsWireguardSupported mocks base method.
|
||||||
func (m *MockNetLinker) IsWireguardSupported() bool {
|
func (m *MockNetLinker) IsWireguardSupported() (bool, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "IsWireguardSupported")
|
ret := m.ctrl.Call(m, "IsWireguardSupported")
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsWireguardSupported indicates an expected call of IsWireguardSupported.
|
// IsWireguardSupported indicates an expected call of IsWireguardSupported.
|
||||||
@@ -63,10 +65,10 @@ func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LinkAdd mocks base method.
|
// LinkAdd mocks base method.
|
||||||
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) {
|
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (uint32, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "LinkAdd", arg0)
|
ret := m.ctrl.Call(m, "LinkAdd", arg0)
|
||||||
ret0, _ := ret[0].(int)
|
ret0, _ := ret[0].(uint32)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
@@ -93,7 +95,7 @@ func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LinkDel mocks base method.
|
// LinkDel mocks base method.
|
||||||
func (m *MockNetLinker) LinkDel(arg0 netlink.Link) error {
|
func (m *MockNetLinker) LinkDel(arg0 uint32) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "LinkDel", arg0)
|
ret := m.ctrl.Call(m, "LinkDel", arg0)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
@@ -122,7 +124,7 @@ func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LinkSetDown mocks base method.
|
// LinkSetDown mocks base method.
|
||||||
func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error {
|
func (m *MockNetLinker) LinkSetDown(arg0 uint32) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "LinkSetDown", arg0)
|
ret := m.ctrl.Call(m, "LinkSetDown", arg0)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
@@ -136,12 +138,11 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LinkSetUp mocks base method.
|
// LinkSetUp mocks base method.
|
||||||
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) {
|
func (m *MockNetLinker) LinkSetUp(arg0 uint32) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
|
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
|
||||||
ret0, _ := ret[0].(int)
|
ret0, _ := ret[0].(error)
|
||||||
ret1, _ := ret[1].(error)
|
return ret0
|
||||||
return ret0, ret1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LinkSetUp indicates an expected call of LinkSetUp.
|
// LinkSetUp indicates an expected call of LinkSetUp.
|
||||||
@@ -165,7 +166,7 @@ func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RouteList mocks base method.
|
// RouteList mocks base method.
|
||||||
func (m *MockNetLinker) RouteList(arg0 int) ([]netlink.Route, error) {
|
func (m *MockNetLinker) RouteList(arg0 byte) ([]netlink.Route, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "RouteList", arg0)
|
ret := m.ctrl.Call(m, "RouteList", arg0)
|
||||||
ret0, _ := ret[0].([]netlink.Route)
|
ret0, _ := ret[0].([]netlink.Route)
|
||||||
|
|||||||
@@ -8,11 +8,11 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix,
|
func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix,
|
||||||
firewallMark uint32,
|
firewallMark uint32,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
for _, dst := range destinations {
|
for _, dst := range destinations {
|
||||||
err = w.addRoute(link, dst, firewallMark)
|
err = w.addRoute(linkIndex, dst, firewallMark)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -29,7 +29,7 @@ func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
|
func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix,
|
||||||
firewallMark uint32,
|
firewallMark uint32,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
family := netlink.FamilyV4
|
family := netlink.FamilyV4
|
||||||
@@ -37,17 +37,20 @@ func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
|
|||||||
family = netlink.FamilyV6
|
family = netlink.FamilyV6
|
||||||
}
|
}
|
||||||
route := netlink.Route{
|
route := netlink.Route{
|
||||||
LinkIndex: link.Index,
|
LinkIndex: linkIndex,
|
||||||
Dst: dst,
|
Dst: dst,
|
||||||
Family: family,
|
Family: family,
|
||||||
Table: int(firewallMark),
|
Table: firewallMark,
|
||||||
|
Type: netlink.RouteTypeUnicast,
|
||||||
|
Scope: netlink.ScopeUniverse,
|
||||||
|
Proto: netlink.ProtoStatic,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = w.netlink.RouteAdd(route)
|
err = w.netlink.RouteAdd(route)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
"adding route for link %s, destination %s and table %d: %w",
|
"adding route for link with index %d, destination %s and table %d: %w",
|
||||||
link.Name, dst, firewallMark, err)
|
linkIndex, dst, firewallMark, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -23,38 +23,36 @@ func Test_Wireguard_addRoute(t *testing.T) {
|
|||||||
errDummy := errors.New("dummy")
|
errDummy := errors.New("dummy")
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
link netlink.Link
|
|
||||||
dst netip.Prefix
|
dst netip.Prefix
|
||||||
expectedRoute netlink.Route
|
expectedRoute netlink.Route
|
||||||
routeAddErr error
|
routeAddErr error
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
"success": {
|
"success": {
|
||||||
link: netlink.Link{
|
|
||||||
Index: linkIndex,
|
|
||||||
},
|
|
||||||
dst: ipPrefix,
|
dst: ipPrefix,
|
||||||
expectedRoute: netlink.Route{
|
expectedRoute: netlink.Route{
|
||||||
LinkIndex: linkIndex,
|
LinkIndex: linkIndex,
|
||||||
Dst: ipPrefix,
|
Dst: ipPrefix,
|
||||||
Family: netlink.FamilyV4,
|
Family: netlink.FamilyV4,
|
||||||
Table: firewallMark,
|
Table: firewallMark,
|
||||||
|
Type: netlink.RouteTypeUnicast,
|
||||||
|
Scope: netlink.ScopeUniverse,
|
||||||
|
Proto: netlink.ProtoStatic,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"route add error": {
|
"route add error": {
|
||||||
link: netlink.Link{
|
|
||||||
Name: "a_bridge",
|
|
||||||
Index: linkIndex,
|
|
||||||
},
|
|
||||||
dst: ipPrefix,
|
dst: ipPrefix,
|
||||||
expectedRoute: netlink.Route{
|
expectedRoute: netlink.Route{
|
||||||
LinkIndex: linkIndex,
|
LinkIndex: linkIndex,
|
||||||
Dst: ipPrefix,
|
Dst: ipPrefix,
|
||||||
Family: netlink.FamilyV4,
|
Family: netlink.FamilyV4,
|
||||||
Table: firewallMark,
|
Table: firewallMark,
|
||||||
|
Type: netlink.RouteTypeUnicast,
|
||||||
|
Scope: netlink.ScopeUniverse,
|
||||||
|
Proto: netlink.ProtoStatic,
|
||||||
},
|
},
|
||||||
routeAddErr: errDummy,
|
routeAddErr: errDummy,
|
||||||
err: errors.New("adding route for link a_bridge, destination 1.2.3.4/32 and table 51820: dummy"), //nolint:lll
|
err: errors.New("adding route for link with index 88, destination 1.2.3.4/32 and table 51820: dummy"), //nolint:lll
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,7 +70,7 @@ func Test_Wireguard_addRoute(t *testing.T) {
|
|||||||
RouteAdd(testCase.expectedRoute).
|
RouteAdd(testCase.expectedRoute).
|
||||||
Return(testCase.routeAddErr)
|
Return(testCase.routeAddErr)
|
||||||
|
|
||||||
err := wg.addRoute(testCase.link, testCase.dst, firewallMark)
|
err := wg.addRoute(linkIndex, testCase.dst, firewallMark)
|
||||||
|
|
||||||
if testCase.err != nil {
|
if testCase.err != nil {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|||||||
@@ -7,15 +7,17 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (w *Wireguard) addRule(rulePriority int, firewallMark uint32,
|
func (w *Wireguard) addRule(rulePriority, firewallMark uint32,
|
||||||
family int,
|
family uint8,
|
||||||
) (cleanup func() error, err error) {
|
) (cleanup func() error, err error) {
|
||||||
rule := netlink.NewRule()
|
rule := netlink.Rule{
|
||||||
rule.Invert = true
|
Priority: &rulePriority,
|
||||||
rule.Priority = rulePriority
|
Family: family,
|
||||||
rule.Mark = firewallMark
|
Table: firewallMark,
|
||||||
rule.Table = int(firewallMark)
|
Mark: &firewallMark,
|
||||||
rule.Family = family
|
Flags: netlink.FlagInvert,
|
||||||
|
Action: netlink.ActionToTable,
|
||||||
|
}
|
||||||
if err := w.netlink.RuleAdd(rule); err != nil {
|
if err := w.netlink.RuleAdd(rule); err != nil {
|
||||||
if strings.HasSuffix(err.Error(), "file exists") {
|
if strings.HasSuffix(err.Error(), "file exists") {
|
||||||
w.logger.Info("if you are using Kubernetes, this may fix the error below: " +
|
w.logger.Info("if you are using Kubernetes, this may fix the error below: " +
|
||||||
|
|||||||
@@ -8,15 +8,14 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_Wireguard_addRule(t *testing.T) {
|
func Test_Wireguard_addRule(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
const rulePriority = 987
|
const rulePriority uint32 = 987
|
||||||
const firewallMark = 456
|
const firewallMark uint32 = 456
|
||||||
const family = unix.AF_INET
|
const family = netlink.FamilyV4
|
||||||
|
|
||||||
errDummy := errors.New("dummy")
|
errDummy := errors.New("dummy")
|
||||||
|
|
||||||
@@ -29,31 +28,34 @@ func Test_Wireguard_addRule(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
"success": {
|
"success": {
|
||||||
expectedRule: netlink.Rule{
|
expectedRule: netlink.Rule{
|
||||||
Invert: true,
|
Priority: ptrTo(rulePriority),
|
||||||
Priority: rulePriority,
|
Mark: ptrTo(firewallMark),
|
||||||
Mark: firewallMark,
|
|
||||||
Table: firewallMark,
|
Table: firewallMark,
|
||||||
Family: family,
|
Family: family,
|
||||||
|
Flags: netlink.FlagInvert,
|
||||||
|
Action: netlink.ActionToTable,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"rule add error": {
|
"rule add error": {
|
||||||
expectedRule: netlink.Rule{
|
expectedRule: netlink.Rule{
|
||||||
Invert: true,
|
Priority: ptrTo(rulePriority),
|
||||||
Priority: rulePriority,
|
Mark: ptrTo(firewallMark),
|
||||||
Mark: firewallMark,
|
|
||||||
Table: firewallMark,
|
Table: firewallMark,
|
||||||
Family: family,
|
Family: family,
|
||||||
|
Flags: netlink.FlagInvert,
|
||||||
|
Action: netlink.ActionToTable,
|
||||||
},
|
},
|
||||||
ruleAddErr: errDummy,
|
ruleAddErr: errDummy,
|
||||||
err: errors.New("adding ip rule 987: from all to all table 456: dummy"),
|
err: errors.New("adding ip rule 987: from all to all table 456: dummy"),
|
||||||
},
|
},
|
||||||
"rule delete error": {
|
"rule delete error": {
|
||||||
expectedRule: netlink.Rule{
|
expectedRule: netlink.Rule{
|
||||||
Invert: true,
|
Priority: ptrTo(rulePriority),
|
||||||
Priority: rulePriority,
|
Mark: ptrTo(firewallMark),
|
||||||
Mark: firewallMark,
|
|
||||||
Table: firewallMark,
|
Table: firewallMark,
|
||||||
Family: family,
|
Family: family,
|
||||||
|
Flags: netlink.FlagInvert,
|
||||||
|
Action: netlink.ActionToTable,
|
||||||
},
|
},
|
||||||
ruleDelErr: errDummy,
|
ruleDelErr: errDummy,
|
||||||
cleanupErr: errors.New("deleting rule ip rule 987: from all to all table 456: dummy"),
|
cleanupErr: errors.New("deleting rule ip rule 987: from all to all table 456: dummy"),
|
||||||
|
|||||||
+42
-41
@@ -7,15 +7,14 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/ipc"
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
ErrDetectKernel = errors.New("cannot detect Kernel support")
|
||||||
ErrCreateTun = errors.New("cannot create TUN device")
|
ErrCreateTun = errors.New("cannot create TUN device")
|
||||||
ErrAddLink = errors.New("cannot add Wireguard link")
|
ErrAddLink = errors.New("cannot add Wireguard link")
|
||||||
ErrFindLink = errors.New("cannot find link")
|
ErrFindLink = errors.New("cannot find link")
|
||||||
@@ -34,7 +33,11 @@ var (
|
|||||||
|
|
||||||
// See https://git.zx2c4.com/wireguard-go/tree/main.go
|
// See https://git.zx2c4.com/wireguard-go/tree/main.go
|
||||||
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
|
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
|
||||||
kernelSupported := w.netlink.IsWireguardSupported()
|
kernelSupported, err := w.netlink.IsWireguardSupported()
|
||||||
|
if err != nil {
|
||||||
|
waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
setupFunction := setupUserSpace
|
setupFunction := setupUserSpace
|
||||||
switch w.settings.Implementation {
|
switch w.settings.Implementation {
|
||||||
@@ -67,14 +70,14 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
|||||||
|
|
||||||
defer closers.cleanup(w.logger)
|
defer closers.cleanup(w.logger)
|
||||||
|
|
||||||
link, waitAndCleanup, err := setupFunction(ctx,
|
linkIndex, waitAndCleanup, err := setupFunction(ctx,
|
||||||
w.settings.InterfaceName, w.netlink, w.settings.MTU, &closers, w.logger)
|
w.settings.InterfaceName, w.netlink, w.settings.MTU, &closers, w.logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
waitError <- err
|
waitError <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = w.addAddresses(link, w.settings.Addresses)
|
err = w.addAddresses(linkIndex, w.settings.Addresses)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err)
|
waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err)
|
||||||
return
|
return
|
||||||
@@ -87,17 +90,16 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
linkIndex, err := w.netlink.LinkSetUp(link)
|
err = w.netlink.LinkSetUp(linkIndex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
|
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
link.Index = linkIndex
|
|
||||||
closers.add("shutting down link", stepFour, func() error {
|
closers.add("shutting down link", stepFour, func() error {
|
||||||
return w.netlink.LinkSetDown(link)
|
return w.netlink.LinkSetDown(linkIndex)
|
||||||
})
|
})
|
||||||
|
|
||||||
err = w.addRoutes(link, w.settings.AllowedIPs, w.settings.FirewallMark)
|
err = w.addRoutes(linkIndex, w.settings.AllowedIPs, w.settings.FirewallMark)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
|
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
|
||||||
return
|
return
|
||||||
@@ -106,7 +108,7 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
|||||||
if *w.settings.IPv6 {
|
if *w.settings.IPv6 {
|
||||||
// requires net.ipv6.conf.all.disable_ipv6=0
|
// requires net.ipv6.conf.all.disable_ipv6=0
|
||||||
ruleCleanup6, err := w.addRule(w.settings.RulePriority,
|
ruleCleanup6, err := w.addRule(w.settings.RulePriority,
|
||||||
w.settings.FirewallMark, unix.AF_INET6)
|
w.settings.FirewallMark, netlink.FamilyV6)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
waitError <- fmt.Errorf("adding IPv6 rule: %w", err)
|
waitError <- fmt.Errorf("adding IPv6 rule: %w", err)
|
||||||
return
|
return
|
||||||
@@ -115,7 +117,7 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
|||||||
}
|
}
|
||||||
|
|
||||||
ruleCleanup, err := w.addRule(w.settings.RulePriority,
|
ruleCleanup, err := w.addRule(w.settings.RulePriority,
|
||||||
w.settings.FirewallMark, unix.AF_INET)
|
w.settings.FirewallMark, netlink.FamilyV4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
waitError <- fmt.Errorf("adding IPv4 rule: %w", err)
|
waitError <- fmt.Errorf("adding IPv4 rule: %w", err)
|
||||||
return
|
return
|
||||||
@@ -133,39 +135,38 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
|||||||
type waitAndCleanupFunc func() error
|
type waitAndCleanupFunc func() error
|
||||||
|
|
||||||
func setupKernelSpace(ctx context.Context,
|
func setupKernelSpace(ctx context.Context,
|
||||||
interfaceName string, netLinker NetLinker, mtu uint16,
|
interfaceName string, netLinker NetLinker, mtu uint32,
|
||||||
closers *closers, logger Logger) (
|
closers *closers, logger Logger) (
|
||||||
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error,
|
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
|
||||||
) {
|
) {
|
||||||
link = netlink.Link{
|
|
||||||
Type: "wireguard",
|
|
||||||
Name: interfaceName,
|
|
||||||
MTU: mtu,
|
|
||||||
}
|
|
||||||
links, err := netLinker.LinkList()
|
links, err := netLinker.LinkList()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return link, nil, fmt.Errorf("listing links: %w", err)
|
return 0, nil, fmt.Errorf("listing links: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cleanup any previous Wireguard interface with the same name
|
// Cleanup any previous Wireguard interface with the same name
|
||||||
// See https://github.com/qdm12/gluetun/issues/1669
|
// See https://github.com/qdm12/gluetun/issues/1669
|
||||||
for _, link := range links {
|
for _, link := range links {
|
||||||
if link.Type == "wireguard" && link.Name == interfaceName {
|
if link.VirtualType == "wireguard" && link.Name == interfaceName {
|
||||||
err = netLinker.LinkDel(link)
|
err = netLinker.LinkDel(link.Index)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return link, nil, fmt.Errorf("deleting previous Wireguard link %s: %w",
|
return 0, nil, fmt.Errorf("deleting previous Wireguard link %s: %w",
|
||||||
interfaceName, err)
|
interfaceName, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
linkIndex, err := netLinker.LinkAdd(link)
|
link := netlink.Link{
|
||||||
if err != nil {
|
VirtualType: "wireguard",
|
||||||
return link, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
|
Name: interfaceName,
|
||||||
|
MTU: mtu,
|
||||||
|
}
|
||||||
|
linkIndex, err = netLinker.LinkAdd(link)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
|
||||||
}
|
}
|
||||||
link.Index = linkIndex
|
|
||||||
closers.add("deleting link", stepFive, func() error {
|
closers.add("deleting link", stepFive, func() error {
|
||||||
return netLinker.LinkDel(link)
|
return netLinker.LinkDel(linkIndex)
|
||||||
})
|
})
|
||||||
|
|
||||||
waitAndCleanup = func() error {
|
waitAndCleanup = func() error {
|
||||||
@@ -174,35 +175,35 @@ func setupKernelSpace(ctx context.Context,
|
|||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
return link, waitAndCleanup, nil
|
return linkIndex, waitAndCleanup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupUserSpace(ctx context.Context,
|
func setupUserSpace(ctx context.Context,
|
||||||
interfaceName string, netLinker NetLinker, mtu uint16,
|
interfaceName string, netLinker NetLinker, mtu uint32,
|
||||||
closers *closers, logger Logger) (
|
closers *closers, logger Logger) (
|
||||||
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error,
|
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
|
||||||
) {
|
) {
|
||||||
tun, err := tun.CreateTUN(interfaceName, int(mtu))
|
tun, err := tun.CreateTUN(interfaceName, int(mtu))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return link, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
|
return 0, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
closers.add("closing TUN device", stepSeven, tun.Close)
|
closers.add("closing TUN device", stepSeven, tun.Close)
|
||||||
|
|
||||||
tunName, err := tun.Name()
|
tunName, err := tun.Name()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return link, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
|
return 0, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
|
||||||
} else if tunName != interfaceName {
|
} else if tunName != interfaceName {
|
||||||
return link, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
|
return 0, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
|
||||||
ErrCreateTun, interfaceName, tunName)
|
ErrCreateTun, interfaceName, tunName)
|
||||||
}
|
}
|
||||||
|
|
||||||
link, err = netLinker.LinkByName(interfaceName)
|
link, err := netLinker.LinkByName(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return link, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
|
return 0, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
|
||||||
}
|
}
|
||||||
closers.add("deleting link", stepFive, func() error {
|
closers.add("deleting link", stepFive, func() error {
|
||||||
return netLinker.LinkDel(link)
|
return netLinker.LinkDel(link.Index)
|
||||||
})
|
})
|
||||||
|
|
||||||
bind := conn.NewDefaultBind()
|
bind := conn.NewDefaultBind()
|
||||||
@@ -217,16 +218,16 @@ func setupUserSpace(ctx context.Context,
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
uapiFile, err := ipc.UAPIOpen(interfaceName)
|
uapiFile, err := uapiOpen(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return link, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
|
return 0, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
closers.add("closing UAPI file", stepThree, uapiFile.Close)
|
closers.add("closing UAPI file", stepThree, uapiFile.Close)
|
||||||
|
|
||||||
uapiListener, err := ipc.UAPIListen(interfaceName, uapiFile)
|
uapiListener, err := uapiListen(interfaceName, uapiFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return link, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
|
return 0, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
|
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
|
||||||
@@ -251,7 +252,7 @@ func setupUserSpace(ctx context.Context,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return link, waitAndCleanup, nil
|
return link.Index, waitAndCleanup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func acceptAndHandle(uapi net.Listener, device *device.Device,
|
func acceptAndHandle(uapi net.Listener, device *device.Device,
|
||||||
|
|||||||
@@ -38,10 +38,10 @@ type Settings struct {
|
|||||||
FirewallMark uint32
|
FirewallMark uint32
|
||||||
// Maximum Transmission Unit (MTU) setting for the network interface.
|
// Maximum Transmission Unit (MTU) setting for the network interface.
|
||||||
// It defaults to device.DefaultMTU from wireguard-go which is 1420
|
// It defaults to device.DefaultMTU from wireguard-go which is 1420
|
||||||
MTU uint16
|
MTU uint32
|
||||||
// RulePriority is the priority for the rule created with the
|
// RulePriority is the priority for the rule created with the
|
||||||
// FirewallMark.
|
// FirewallMark.
|
||||||
RulePriority int
|
RulePriority uint32
|
||||||
// IPv6 can bet set to true if IPv6 should be handled.
|
// IPv6 can bet set to true if IPv6 should be handled.
|
||||||
// It defaults to false if left unset.
|
// It defaults to false if left unset.
|
||||||
IPv6 *bool
|
IPv6 *bool
|
||||||
|
|||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package wireguard
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/ipc"
|
||||||
|
)
|
||||||
|
|
||||||
|
func uapiOpen(name string) (*os.File, error) {
|
||||||
|
return ipc.UAPIOpen(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func uapiListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
|
||||||
|
return ipc.UAPIListen(interfaceName, uapiFile)
|
||||||
|
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package wireguard
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
func uapiOpen(name string) (*os.File, error) {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func uapiListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user