feat(vpn): rotate filtered servers on internal vpn restarts

- Fix #290
This commit is contained in:
Quentin McGaw
2026-05-04 03:28:48 +00:00
parent 4b819b4dbb
commit fed09562e5
57 changed files with 345 additions and 220 deletions
+3 -3
View File
@@ -2,7 +2,6 @@ package utils
import (
"fmt"
"math/rand"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/vpn"
@@ -35,7 +34,8 @@ func GetConnection(provider string,
selection settings.ServerSelection,
defaults ConnectionDefaults,
ipv6Supported bool,
randSource rand.Source) (
connPicker *ConnectionPicker,
) (
connection models.Connection, err error,
) {
servers, err := storage.FilterServers(provider, selection)
@@ -75,5 +75,5 @@ func GetConnection(provider string,
}
}
return pickConnection(connections, selection, randSource)
return pickConnection(connections, selection, connPicker)
}
+2 -1
View File
@@ -183,6 +183,7 @@ func Test_GetConnection(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
connPicker := NewConnectionPicker()
storage := common.NewMockStorage(ctrl)
storage.EXPECT().
@@ -191,7 +192,7 @@ func Test_GetConnection(t *testing.T) {
connection, err := GetConnection(testCase.provider, storage,
testCase.serverSelection, testCase.defaults, testCase.ipv6Supported,
testCase.randSource)
connPicker)
assert.Equal(t, testCase.connection, connection)
if testCase.errMessage != "" {
+67 -10
View File
@@ -1,23 +1,86 @@
package utils
import (
"encoding/binary"
"errors"
"fmt"
"math/rand"
"hash/fnv"
"net/netip"
"sync"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
)
// ConnectionPicker is a struct that holds the state of the connection pool cycler.
type ConnectionPicker struct {
mutex sync.Mutex
fingerprint uint64
nextIndex uint
}
func NewConnectionPicker() *ConnectionPicker {
return &ConnectionPicker{}
}
func (c *ConnectionPicker) pickConnection(connections []models.Connection,
) models.Connection {
fingerprint := fingerprintPool(connections)
c.mutex.Lock()
defer c.mutex.Unlock()
if c.fingerprint != fingerprint || c.nextIndex >= uint(len(connections)) {
c.fingerprint = fingerprint
c.nextIndex = 0
}
connection := connections[c.nextIndex]
c.nextIndex++
if c.nextIndex >= uint(len(connections)) {
c.nextIndex = 0
}
return connection
}
func fingerprintPool(connections []models.Connection) uint64 {
hasher := fnv.New64a()
for _, connection := range connections {
_, _ = hasher.Write([]byte(connection.Type))
_, _ = hasher.Write([]byte("|"))
_, _ = hasher.Write(connection.IP.AsSlice())
_, _ = hasher.Write([]byte("|"))
_, _ = hasher.Write(binary.BigEndian.AppendUint16(nil, connection.Port))
_, _ = hasher.Write([]byte("|"))
_, _ = hasher.Write([]byte(connection.Protocol))
_, _ = hasher.Write([]byte("|"))
_, _ = hasher.Write([]byte(connection.Hostname))
_, _ = hasher.Write([]byte("|"))
_, _ = hasher.Write([]byte(connection.PubKey))
_, _ = hasher.Write([]byte("|"))
_, _ = hasher.Write([]byte(connection.ServerName))
_, _ = hasher.Write([]byte("|"))
if connection.PortForward {
_, _ = hasher.Write([]byte("1"))
} else {
_, _ = hasher.Write([]byte("0"))
}
_, _ = hasher.Write([]byte("\n"))
}
return hasher.Sum64()
}
// pickConnection picks a connection from a pool of connections.
// If the VPN protocol is Wireguard and the target IP is set,
// it finds the connection corresponding to this target IP.
// Otherwise, it picks a random connection from the pool of connections
// Otherwise, it cycles through the pool of connections.
// and sets the target IP address as the IP if this one is set.
func pickConnection(connections []models.Connection,
selection settings.ServerSelection, randSource rand.Source) (
selection settings.ServerSelection, picker *ConnectionPicker) (
connection models.Connection, err error,
) {
if len(connections) == 0 {
@@ -40,7 +103,7 @@ func pickConnection(connections []models.Connection,
return getTargetIPConnection(connections, targetIP)
}
connection = pickRandomConnection(connections, randSource)
connection = picker.pickConnection(connections)
if targetIPSet {
connection.IP = targetIP
}
@@ -48,12 +111,6 @@ func pickConnection(connections []models.Connection,
return connection, nil
}
func pickRandomConnection(connections []models.Connection,
source rand.Source,
) models.Connection {
return connections[rand.New(source).Intn(len(connections))] //nolint:gosec
}
func getTargetIPConnection(connections []models.Connection,
targetIP netip.Addr,
) (connection models.Connection, err error) {
+123 -12
View File
@@ -1,26 +1,137 @@
package utils
import (
"math/rand"
"net/netip"
"testing"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/stretchr/testify/assert"
)
func Test_pickRandomConnection(t *testing.T) {
func Test_ConnectionPicker_pickConnection(t *testing.T) {
t.Parallel()
connections := []models.Connection{
{Port: 1}, {Port: 2}, {Port: 3}, {Port: 4},
picker := NewConnectionPicker()
poolA := []models.Connection{
{Port: 1}, {Port: 2}, {Port: 3},
}
source := rand.NewSource(0)
connection := picker.pickConnection(poolA)
assert.Equal(t, models.Connection{Port: 1}, connection)
connection := pickRandomConnection(connections, source)
assert.Equal(t, models.Connection{Port: 3}, connection)
connection = pickRandomConnection(connections, source)
assert.Equal(t, models.Connection{Port: 3}, connection)
connection = pickRandomConnection(connections, source)
connection = picker.pickConnection(poolA)
assert.Equal(t, models.Connection{Port: 2}, connection)
connection = picker.pickConnection(poolA)
assert.Equal(t, models.Connection{Port: 3}, connection)
connection = picker.pickConnection(poolA)
assert.Equal(t, models.Connection{Port: 1}, connection)
poolB := []models.Connection{
{Port: 10}, {Port: 20},
}
connection = picker.pickConnection(poolB)
assert.Equal(t, models.Connection{Port: 10}, connection)
connection = picker.pickConnection(poolB)
assert.Equal(t, models.Connection{Port: 20}, connection)
connection = picker.pickConnection(poolB)
assert.Equal(t, models.Connection{Port: 10}, connection)
}
func Test_pickConnection(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
connections []models.Connection
selection settings.ServerSelection
connection1 models.Connection
connection2 models.Connection
errMessage string
}{
"empty_connections": {
errMessage: "no connection to pick from",
},
"openvpn_cycles": {
connections: []models.Connection{
{Type: vpn.OpenVPN, Port: 1, Hostname: "one"},
{Type: vpn.OpenVPN, Port: 2, Hostname: "two"},
},
selection: settings.ServerSelection{VPN: vpn.OpenVPN},
connection1: models.Connection{
Type: vpn.OpenVPN, Port: 1,
Hostname: "one",
},
connection2: models.Connection{
Type: vpn.OpenVPN, Port: 2,
Hostname: "two",
},
},
"openvpn_endpoint_ip_overrides_cycle_pick": {
connections: []models.Connection{
{Type: vpn.OpenVPN, Hostname: "one", IP: netip.AddrFrom4([4]byte{1, 1, 1, 1})},
{Type: vpn.OpenVPN, Hostname: "two", IP: netip.AddrFrom4([4]byte{2, 2, 2, 2})},
},
selection: settings.ServerSelection{
VPN: vpn.OpenVPN,
OpenVPN: settings.OpenVPNSelection{
EndpointIP: netip.AddrFrom4([4]byte{9, 9, 9, 9}),
},
},
connection1: models.Connection{
Type: vpn.OpenVPN, Hostname: "one",
IP: netip.AddrFrom4([4]byte{9, 9, 9, 9}),
},
connection2: models.Connection{
Type: vpn.OpenVPN, Hostname: "two",
IP: netip.AddrFrom4([4]byte{9, 9, 9, 9}),
},
},
"wireguard_endpoint_ip_picks_target": {
connections: []models.Connection{
{Type: vpn.Wireguard, Hostname: "one", IP: netip.AddrFrom4([4]byte{1, 1, 1, 1})},
{Type: vpn.Wireguard, Hostname: "two", IP: netip.AddrFrom4([4]byte{2, 2, 2, 2})},
},
selection: settings.ServerSelection{
VPN: vpn.Wireguard,
Wireguard: settings.WireguardSelection{
EndpointIP: netip.AddrFrom4([4]byte{2, 2, 2, 2}),
},
},
connection1: models.Connection{
Type: vpn.Wireguard, Hostname: "two",
IP: netip.AddrFrom4([4]byte{2, 2, 2, 2}),
},
connection2: models.Connection{
Type: vpn.Wireguard, Hostname: "two",
IP: netip.AddrFrom4([4]byte{2, 2, 2, 2}),
},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
connPicker := NewConnectionPicker()
connection, err := pickConnection(testCase.connections,
testCase.selection, connPicker)
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
assert.Equal(t, models.Connection{}, connection)
return
}
assert.NoError(t, err)
assert.Equal(t, testCase.connection1, connection)
connection, err = pickConnection(testCase.connections,
testCase.selection, connPicker)
assert.NoError(t, err)
assert.Equal(t, testCase.connection2, connection)
})
}
}