chore(all): replace netlink library for more flexibility (#3107)

This commit is contained in:
Quentin McGaw
2026-01-27 10:11:39 +01:00
committed by GitHub
parent e292a4c9be
commit facc6df3be
50 changed files with 1074 additions and 579 deletions
+1 -1
View File
@@ -14,7 +14,7 @@ type DefaultRoute struct {
NetInterface string
Gateway netip.Addr
AssignedIP netip.Addr
Family int
Family uint8
}
func (d DefaultRoute) String() string {
+4 -4
View File
@@ -8,8 +8,8 @@ import (
)
const (
inboundTable = 200
inboundPriority = 100
inboundTable uint32 = 200
inboundPriority uint32 = 100
)
func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err error) {
@@ -60,7 +60,7 @@ func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err e
return nil
}
func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) {
func (r *Routing) addRuleInboundFromDefault(table uint32, defaultRoutes []DefaultRoute) (err error) {
for _, defaultRoute := range defaultRoutes {
assignedIP := defaultRoute.AssignedIP
bits := 32
@@ -78,7 +78,7 @@ func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRo
return nil
}
func (r *Routing) delRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) {
func (r *Routing) delRuleInboundFromDefault(table uint32, defaultRoutes []DefaultRoute) (err error) {
for _, defaultRoute := range defaultRoutes {
assignedIP := defaultRoute.AssignedIP
bits := 32
+2 -2
View File
@@ -16,12 +16,12 @@ func ipIsPrivate(ip netip.Addr) bool {
var errInterfaceIPNotFound = errors.New("IP address not found for interface")
func ipMatchesFamily(ip netip.Addr, family int) bool {
func ipMatchesFamily(ip netip.Addr, family uint8) bool {
return (family == netlink.FamilyV4 && ip.Is4()) ||
(family == netlink.FamilyV6 && ip.Is6())
}
func (r *Routing) AssignedIP(interfaceName string, family int) (ip netip.Addr, err error) {
func (r *Routing) AssignedIP(interfaceName string, family uint8) (ip netip.Addr, err error) {
iface, err := net.InterfaceByName(interfaceName)
if err != nil {
return ip, fmt.Errorf("network interface %s not found: %w", interfaceName, err)
+3 -3
View File
@@ -26,10 +26,10 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
return localNetworks, fmt.Errorf("listing links: %w", err)
}
localLinks := make(map[int]struct{})
localLinks := make(map[uint32]struct{})
for _, link := range links {
if link.EncapType != "ether" {
if link.DeviceType != netlink.DeviceTypeEthernet {
continue
}
@@ -95,7 +95,7 @@ func (r *Routing) AddLocalRules(subnets []LocalNetwork) (err error) {
// Local has higher priority then outbound(99) and inbound(100) as the
// local routes might be necessary to reach the outbound/inbound routes.
const localPriority = 98
const localPriority uint32 = 98
// Main table was setup correctly by Docker, just need to add rules to use it
src := netip.Prefix{}
+14 -14
View File
@@ -5,6 +5,7 @@
package routing
import (
netip "net/netip"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
@@ -35,10 +36,10 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
}
// AddrList mocks base method.
func (m *MockNetLinker) AddrList(arg0 netlink.Link, arg1 int) ([]netlink.Addr, error) {
func (m *MockNetLinker) AddrList(arg0 uint32, arg1 byte) ([]netip.Prefix, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrList", arg0, arg1)
ret0, _ := ret[0].([]netlink.Addr)
ret0, _ := ret[0].([]netip.Prefix)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -50,7 +51,7 @@ func (mr *MockNetLinkerMockRecorder) AddrList(arg0, arg1 interface{}) *gomock.Ca
}
// AddrReplace mocks base method.
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error {
func (m *MockNetLinker) AddrReplace(arg0 uint32, arg1 netip.Prefix) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
ret0, _ := ret[0].(error)
@@ -64,10 +65,10 @@ func (mr *MockNetLinkerMockRecorder) AddrReplace(arg0, arg1 interface{}) *gomock
}
// LinkAdd mocks base method.
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) {
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (uint32, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkAdd", arg0)
ret0, _ := ret[0].(int)
ret0, _ := ret[0].(uint32)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -79,7 +80,7 @@ func (mr *MockNetLinkerMockRecorder) LinkAdd(arg0 interface{}) *gomock.Call {
}
// LinkByIndex mocks base method.
func (m *MockNetLinker) LinkByIndex(arg0 int) (netlink.Link, error) {
func (m *MockNetLinker) LinkByIndex(arg0 uint32) (netlink.Link, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkByIndex", arg0)
ret0, _ := ret[0].(netlink.Link)
@@ -109,7 +110,7 @@ func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call {
}
// LinkDel mocks base method.
func (m *MockNetLinker) LinkDel(arg0 netlink.Link) error {
func (m *MockNetLinker) LinkDel(arg0 uint32) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkDel", arg0)
ret0, _ := ret[0].(error)
@@ -138,7 +139,7 @@ func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call {
}
// LinkSetDown mocks base method.
func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error {
func (m *MockNetLinker) LinkSetDown(arg0 uint32) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetDown", arg0)
ret0, _ := ret[0].(error)
@@ -152,12 +153,11 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
}
// LinkSetUp mocks base method.
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) {
func (m *MockNetLinker) LinkSetUp(arg0 uint32) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
ret0, _ := ret[0].(error)
return ret0
}
// LinkSetUp indicates an expected call of LinkSetUp.
@@ -195,7 +195,7 @@ func (mr *MockNetLinkerMockRecorder) RouteDel(arg0 interface{}) *gomock.Call {
}
// RouteList mocks base method.
func (m *MockNetLinker) RouteList(arg0 int) ([]netlink.Route, error) {
func (m *MockNetLinker) RouteList(arg0 byte) ([]netlink.Route, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteList", arg0)
ret0, _ := ret[0].([]netlink.Route)
@@ -252,7 +252,7 @@ func (mr *MockNetLinkerMockRecorder) RuleDel(arg0 interface{}) *gomock.Call {
}
// RuleList mocks base method.
func (m *MockNetLinker) RuleList(arg0 int) ([]netlink.Rule, error) {
func (m *MockNetLinker) RuleList(arg0 byte) ([]netlink.Rule, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleList", arg0)
ret0, _ := ret[0].([]netlink.Rule)
+2 -2
View File
@@ -9,8 +9,8 @@ import (
)
const (
outboundTable = 199
outboundPriority = 99
outboundTable uint32 = 199
outboundPriority uint32 = 99
)
func (r *Routing) SetOutboundRoutes(outboundSubnets []netip.Prefix) error {
+17 -4
View File
@@ -9,25 +9,33 @@ import (
)
func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr,
iface string, table int,
iface string, table uint32,
) error {
destinationStr := destination.String()
r.logger.Info("adding route for " + destinationStr)
r.logger.Debug("ip route replace " + destinationStr +
" via " + gateway.String() +
" dev " + iface +
" table " + strconv.Itoa(table))
" table " + strconv.Itoa(int(table)))
link, err := r.netLinker.LinkByName(iface)
if err != nil {
return fmt.Errorf("finding link for interface %s: %w", iface, err)
}
family := netlink.FamilyV4
if destination.Addr().Is6() {
family = netlink.FamilyV6
}
route := netlink.Route{
Dst: destination,
Gw: gateway,
LinkIndex: link.Index,
Family: family,
Table: table,
Type: netlink.RouteTypeUnicast,
Scope: netlink.ScopeUniverse,
Proto: netlink.ProtoStatic,
}
if err := r.netLinker.RouteReplace(route); err != nil {
return fmt.Errorf("replacing route for subnet %s at interface %s: %w",
@@ -38,24 +46,29 @@ func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr,
}
func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway netip.Addr,
iface string, table int,
iface string, table uint32,
) (err error) {
destinationStr := destination.String()
r.logger.Info("deleting route for " + destinationStr)
r.logger.Debug("ip route delete " + destinationStr +
" via " + gateway.String() +
" dev " + iface +
" table " + strconv.Itoa(table))
" table " + strconv.Itoa(int(table)))
link, err := r.netLinker.LinkByName(iface)
if err != nil {
return fmt.Errorf("finding link for interface %s: %w", iface, err)
}
family := netlink.FamilyV4
if destination.Addr().Is6() {
family = netlink.FamilyV6
}
route := netlink.Route{
Dst: destination,
Gw: gateway,
LinkIndex: link.Index,
Family: family,
Table: table,
}
if err := r.netLinker.RouteDel(route); err != nil {
+10 -10
View File
@@ -15,20 +15,20 @@ type NetLinker interface {
}
type Addresser interface {
AddrList(link netlink.Link, family int) (
addresses []netlink.Addr, err error)
AddrReplace(link netlink.Link, addr netlink.Addr) error
AddrList(linkIndex uint32, family uint8) (
addresses []netip.Prefix, err error)
AddrReplace(linkIndex uint32, prefix netip.Prefix) error
}
type Router interface {
RouteList(family int) (routes []netlink.Route, err error)
RouteList(family uint8) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error
RouteDel(route netlink.Route) error
RouteReplace(route netlink.Route) error
}
type Ruler interface {
RuleList(family int) (rules []netlink.Rule, err error)
RuleList(family uint8) (rules []netlink.Rule, err error)
RuleAdd(rule netlink.Rule) error
RuleDel(rule netlink.Rule) error
}
@@ -36,11 +36,11 @@ type Ruler interface {
type Linker interface {
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkByIndex(index int) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error)
LinkByIndex(index uint32) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
LinkDel(index uint32) (err error)
LinkSetUp(index uint32) (err error)
LinkSetDown(index uint32) (err error)
}
type Routing struct {
+39 -13
View File
@@ -7,12 +7,19 @@ import (
"github.com/qdm12/gluetun/internal/netlink"
)
func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority int) error {
rule := netlink.NewRule()
rule.Src = src
rule.Dst = dst
rule.Priority = priority
rule.Table = table
func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority uint32) error {
family := netlink.FamilyV4
if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) {
family = netlink.FamilyV6
}
rule := netlink.Rule{
Priority: &priority,
Family: family,
Table: table,
Src: src,
Dst: dst,
Action: netlink.ActionToTable,
}
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
if err != nil {
@@ -31,12 +38,19 @@ func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority int) error {
return nil
}
func (r *Routing) deleteIPRule(src, dst netip.Prefix, table, priority int) error {
rule := netlink.NewRule()
rule.Src = src
rule.Dst = dst
rule.Priority = priority
rule.Table = table
func (r *Routing) deleteIPRule(src, dst netip.Prefix, table uint32, priority uint32) error {
family := netlink.FamilyV4
if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) {
family = netlink.FamilyV6
}
rule := netlink.Rule{
Priority: &priority,
Family: family,
Table: table,
Src: src,
Dst: dst,
Action: netlink.ActionToTable,
}
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
if err != nil {
@@ -53,10 +67,12 @@ func (r *Routing) deleteIPRule(src, dst netip.Prefix, table, priority int) error
return nil
}
// rulesAreEqual checks whether two rules are equal
// only according to src, dst, priority and table.
func rulesAreEqual(a, b netlink.Rule) bool {
return ipPrefixesAreEqual(a.Src, b.Src) &&
ipPrefixesAreEqual(a.Dst, b.Dst) &&
a.Priority == b.Priority &&
ptrsEqual(a.Priority, b.Priority) &&
a.Table == b.Table
}
@@ -70,3 +86,13 @@ func ipPrefixesAreEqual(a, b netip.Prefix) bool {
return a.Bits() == b.Bits() &&
a.Addr().Compare(b.Addr()) == 0
}
func ptrsEqual(a, b *uint32) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return *a == *b
}
+30 -22
View File
@@ -17,14 +17,20 @@ func makeNetipPrefix(n byte) netip.Prefix {
}
func makeIPRule(src, dst netip.Prefix,
table, priority int,
table uint32, priority uint32,
) netlink.Rule {
rule := netlink.NewRule()
rule.Src = src
rule.Dst = dst
rule.Table = table
rule.Priority = priority
return rule
family := netlink.FamilyV4
if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) {
family = netlink.FamilyV6
}
return netlink.Rule{
Priority: &priority,
Family: family,
Table: table,
Src: src,
Dst: dst,
Action: netlink.ActionToTable,
}
}
func Test_Routing_addIPRule(t *testing.T) {
@@ -46,8 +52,8 @@ func Test_Routing_addIPRule(t *testing.T) {
testCases := map[string]struct {
src netip.Prefix
dst netip.Prefix
table int
priority int
table uint32
priority uint32
ruleList ruleListCall
ruleAdd ruleAddCall
err error
@@ -149,8 +155,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
testCases := map[string]struct {
src netip.Prefix
dst netip.Prefix
table int
priority int
table uint32
priority uint32
ruleList ruleListCall
ruleDel ruleDelCall
err error
@@ -238,6 +244,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
}
}
func ptrTo[T any](v T) *T { return &v }
func Test_rulesAreEqual(t *testing.T) {
t.Parallel()
@@ -253,13 +261,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Priority: ptrTo(uint32(100)),
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Priority: ptrTo(uint32(100)),
Table: 101,
},
},
@@ -267,13 +275,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 32),
Priority: 100,
Priority: ptrTo(uint32(100)),
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Priority: ptrTo(uint32(100)),
Table: 101,
},
},
@@ -281,13 +289,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 999,
Priority: ptrTo(uint32(999)),
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Priority: ptrTo(uint32(100)),
Table: 101,
},
},
@@ -295,13 +303,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 999,
Priority: ptrTo(uint32(100)),
Table: 102,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Priority: ptrTo(uint32(100)),
Table: 101,
},
},
@@ -309,13 +317,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Priority: ptrTo(uint32(100)),
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Priority: ptrTo(uint32(100)),
Table: 101,
},
equal: true,
+3 -4
View File
@@ -33,13 +33,12 @@ func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
case route.Dst.IsValid() && route.Dst.Addr().IsUnspecified() && route.Gw.IsValid(): // OpenVPN
return route.Gw, nil
case route.Dst.IsSingleIP() &&
route.Dst.Addr().Compare(route.Src) == 0 &&
route.Dst.Addr().Compare(route.Src.Addr()) == 0 &&
route.Table == tableLocal: // Wireguard
route.Src = route.Src.Unmap()
if route.Src.Is6() {
if route.Src.Addr().Is6() {
return netip.Addr{}, fmt.Errorf("%w: %s", ErrVPNLocalGatewayIPv6NotSupported, route.Src)
}
bytes := route.Src.As4()
bytes := route.Src.Addr().As4()
// force last byte to 1 to get the VPN gateway IP
// This is not necessarily bullet proof but it seems to work.
bytes[3] = 1