feat(publicip/api): query all fetchers in parallel and pick most popular result

This commit is contained in:
Quentin McGaw
2025-12-23 16:23:22 +00:00
parent 617f1b764f
commit 10a7c75aa6
3 changed files with 276 additions and 32 deletions
+1 -1
View File
@@ -5,7 +5,7 @@ import (
)
type PublicIP struct {
IP netip.Addr `json:"public_ip,omitempty"`
IP netip.Addr `json:"public_ip"`
Region string `json:"region,omitempty"`
Country string `json:"country,omitempty"`
City string `json:"city,omitempty"`
+204 -30
View File
@@ -8,20 +8,29 @@ import (
"strings"
"sync"
"time"
"unicode"
"github.com/qdm12/gluetun/internal/models"
"golang.org/x/text/runes"
"golang.org/x/text/transform"
"golang.org/x/text/unicode/norm"
)
// ResilientFetcher is a fetcher implementation using multiple fetchers.
// Fetchers which are currently banned are skipped.
// To fetch public IP information for a specific IP address,
// it fetches from all sources to find the best result, since data
// from a single source can be wrong.
type ResilientFetcher struct {
// fetchers is the list of underlying fetchers; it is read under mutex.
fetchers []Fetcher
// logger is used to report warnings.
logger Warner
// fetcherToBanTime maps a banned fetcher to the time it got banned.
fetcherToBanTime map[Fetcher]time.Time
// banMutex guards fetcherToBanTime.
banMutex sync.RWMutex
// mutex guards fetchers.
// NOTE(review): presumably it also guards the replacement of
// fetcherToBanTime done when fetchers are updated; confirm.
mutex sync.RWMutex
// timeNow returns the current time and is used to timestamp bans.
// NOTE(review): presumably a field so tests can inject a clock; confirm.
timeNow func() time.Time
}
// NewResilient creates a 'resilient' fetcher given multiple fetchers.
// For example, it can handle bans and move on to another fetcher if one fails.
func NewResilient(fetchers []Fetcher, logger Warner) *ResilientFetcher {
return &ResilientFetcher{
fetchers: fetchers,
@@ -31,7 +40,15 @@ func NewResilient(fetchers []Fetcher, logger Warner) *ResilientFetcher {
}
}
// setBanned records the current time, as given by r.timeNow, as the
// moment the given fetcher got banned. The ban map is written while
// holding the ban mutex.
func (r *ResilientFetcher) setBanned(fetcher Fetcher) {
	r.banMutex.Lock()
	defer r.banMutex.Unlock()

	bannedAt := r.timeNow()
	r.fetcherToBanTime[fetcher] = bannedAt
}
func (r *ResilientFetcher) isBanned(fetcher Fetcher) (banned bool) {
r.banMutex.Lock()
defer r.banMutex.Unlock()
banTime, banned := r.fetcherToBanTime[fetcher]
if !banned {
return false
@@ -49,25 +66,21 @@ func (r *ResilientFetcher) isBanned(fetcher Fetcher) (banned bool) {
// String returns the names of all the fetchers which are not currently
// banned, joined together with "+". It returns "<all-banned>" if there
// is no such fetcher.
func (r *ResilientFetcher) String() string {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	names := make([]string, 0, len(r.fetchers))
	for _, fetcher := range r.fetchers {
		if r.isBanned(fetcher) {
			continue
		}
		// Accumulate every usable fetcher name instead of returning on the
		// first one, so the result reflects all the sources being queried.
		names = append(names, fetcher.String())
	}

	if len(names) == 0 {
		return "<all-banned>"
	}
	return strings.Join(names, "+")
}
// Token returns the token of the first fetcher which is not currently
// banned, or "<all-banned>" if there is no such fetcher.
// The fetchers are read while holding the read lock of r.mutex.
func (r *ResilientFetcher) Token() string {
r.mutex.RLock()
defer r.mutex.RUnlock()
for _, fetcher := range r.fetchers {
if r.isBanned(fetcher) {
continue
}
return fetcher.Token()
}
return "<all-banned>"
// NOTE(review): this panic is unreachable because of the return statement
// just above. Only one of "return the first token" and "panic on any call"
// can be intended for this method; confirm which and remove the other.
panic("invalid call")
}
// CanFetchAnyIP returns true if any of the fetchers
@@ -85,45 +98,161 @@ func (r *ResilientFetcher) CanFetchAnyIP() bool {
return false
}
// ErrFetchersAllRateLimited is the sentinel error indicating that all
// the fetchers are rate limited. Callers can match it with errors.Is.
var ErrFetchersAllRateLimited = errors.New("all fetchers are rate limited")
// FetchInfo obtains information on the ip address provided.
// If the ip is the zero value, the public IP address of the machine
// is used as the IP.
// If a fetcher gets banned, the next one is tried until all have been exhausted.
// Fetchers still within their banned period are skipped.
// If an error unrelated to being banned is encountered, it is returned and more
// fetchers are tried.
// It queries all non-banned fetchers in parallel to obtain the most popular result.
// It only returns an error if all fetchers fail to return information.
func (r *ResilientFetcher) FetchInfo(ctx context.Context, ip netip.Addr) (
result models.PublicIP, err error,
) {
r.mutex.RLock()
defer r.mutex.RUnlock()
for _, fetcher := range r.fetchers {
type resultData struct {
i int
result models.PublicIP
err error
}
resultsCh := make(chan resultData)
fetchersStarted := 0
for range r.fetchers {
fetcher := r.fetchers[fetchersStarted]
if r.isBanned(fetcher) ||
(ip.IsValid() && !fetcher.CanFetchAnyIP()) {
continue
}
result, err = fetcher.FetchInfo(ctx, ip)
if err == nil || !errors.Is(err, ErrTooManyRequests) {
return result, err
go func(i int, fetcher Fetcher) {
result, err := fetcher.FetchInfo(ctx, ip)
resultsCh <- resultData{
i: i,
result: result,
err: err,
}
}(fetchersStarted, fetcher)
fetchersStarted++
}
// Fetcher is banned
r.fetcherToBanTime[fetcher] = r.timeNow()
r.logger.Warn(fetcher.String() + ": " + err.Error())
// Collect resultDatas from goroutines first, which takes I/O time
// so that we don't lock the ban map mutex for too long.
resultDatas := make([]resultData, fetchersStarted)
for range resultDatas {
data := <-resultsCh
resultDatas[data.i] = data
}
fetcherNames := make([]string, len(r.fetchers))
for i, fetcher := range r.fetchers {
fetcherNames[i] = fetcher.String()
// Mutex lock ban map and process results
results := make([]models.PublicIP, 0, fetchersStarted)
errs := make([]error, 0, fetchersStarted)
for _, data := range resultDatas {
fetcher := r.fetchers[data.i]
if data.err != nil {
if errors.Is(data.err, ErrTooManyRequests) {
r.setBanned(fetcher)
}
errs = append(errs, fmt.Errorf("%s: %w", fetcher, data.err))
continue
}
results = append(results, data.result)
}
return result, fmt.Errorf("%w (%s)",
ErrFetchersAllRateLimited,
strings.Join(fetcherNames, ", "))
if len(results) == 0 { // all failed
return models.PublicIP{}, fmt.Errorf("all fetchers failed: %w", errors.Join(errs...))
}
return getMostPopularResult(results), nil
}
// getMostPopularResult picks the most popular [models.PublicIP] out of
// results. The candidates are narrowed down to the most popular country,
// then to the most popular region among those, and finally the first
// result of the most popular city among those is returned.
// All the other fields are ignored when comparing results.
// It panics if results is empty, and it reorders the elements of the
// results slice passed in.
func getMostPopularResult(results []models.PublicIP) models.PublicIP {
	if len(results) == 0 {
		panic("no results to choose from")
	}

	// collect extracts one string field from every remaining candidate.
	collect := func(field func(publicIP models.PublicIP) string) []string {
		values := make([]string, 0, len(results))
		for _, candidate := range results {
			values = append(values, field(candidate))
		}
		return values
	}

	// 1. Keep only the results from the most popular country.
	_, sameCountry := getMostPopularString(collect(
		func(publicIP models.PublicIP) string { return publicIP.Country }))
	results = filterInPlace(results, sameCountry)

	// 2. Keep only the results from the most popular region.
	_, sameRegion := getMostPopularString(collect(
		func(publicIP models.PublicIP) string { return publicIP.Region }))
	results = filterInPlace(results, sameRegion)

	// 3. Pick the representative of the most popular city.
	cityWinner, _ := getMostPopularString(collect(
		func(publicIP models.PublicIP) string { return publicIP.City }))
	return results[cityWinner]
}
// filterInPlace compacts results so that it only contains the elements
// found at the given indices, reusing the backing array of results.
// The indices must be in ascending order, otherwise an element could be
// overwritten before it is moved.
func filterInPlace(results []models.PublicIP, indices []int) []models.PublicIP {
	kept := 0
	for _, source := range indices {
		results[kept] = results[source]
		kept++
	}
	return results[:kept]
}
// getMostPopularString groups values into clusters of similar strings and
// returns the index of the first value of the largest cluster, together
// with the indexes of all the values belonging to that cluster, in
// ascending order. A value joins the first cluster whose representative,
// the normalized first value of that cluster, is at most one edit away
// from its own normalized form. On equal sizes, the cluster created first
// wins. It returns -1 and nil if values is empty.
func getMostPopularString(values []string) (winnerIdx int, memberIdxs []int) {
	if len(values) == 0 {
		return -1, nil
	}

	type cluster struct {
		representative string
		members        []int
	}

	var clusters []cluster
	for idx, raw := range values {
		normalized := normalize(raw)

		target := -1
		for c := range clusters {
			if levenshteinDistance(normalized, clusters[c].representative) <= 1 {
				target = c
				break
			}
		}

		if target < 0 {
			clusters = append(clusters, cluster{
				representative: normalized,
				members:        []int{idx},
			})
			continue
		}
		clusters[target].members = append(clusters[target].members, idx)
	}

	// values is not empty so there is at least one cluster.
	best := 0
	for c := 1; c < len(clusters); c++ {
		if len(clusters[c].members) > len(clusters[best].members) {
			best = c
		}
	}

	// The first member of a cluster is the value which created it.
	return clusters[best].members[0], clusters[best].members
}
func (r *ResilientFetcher) UpdateFetchers(fetchers []Fetcher) {
@@ -151,3 +280,48 @@ func (r *ResilientFetcher) UpdateFetchers(fetchers []Fetcher) {
r.fetchers = fetchers
r.fetcherToBanTime = newFetcherToBanTime
}
// normalize returns a canonical form of s suitable for fuzzy comparison:
// anything from the first " (" onwards is dropped, nonspacing marks such as
// accents are removed, surrounding spaces are trimmed and the result is
// lowercased. It panics if the Unicode transformation fails.
func normalize(s string) string {
	// Drop a parenthesized suffix; s is kept whole if there is none.
	base, _, _ := strings.Cut(s, " (")

	// Decompose, strip the nonspacing marks (unicode.Mn), then recompose.
	stripMarks := transform.Chain(norm.NFD, runes.Remove(runes.In(unicode.Mn)), norm.NFC)
	stripped, _, err := transform.String(stripMarks, base)
	if err != nil {
		panic(err)
	}

	return strings.ToLower(strings.TrimSpace(stripped))
}
// levenshteinDistance calculates the edit distance between two strings
// a and b, that is the minimum number of single character insertions,
// deletions and substitutions needed to turn a into b.
// The distance is counted in Unicode code points and not in bytes, so
// that a single differing non-ASCII character counts as a single edit.
func levenshteinDistance(a, b string) int {
	runesA, runesB := []rune(a), []rune(b)
	switch {
	case len(runesA) == 0:
		return len(runesB)
	case len(runesB) == 0:
		return len(runesA)
	}

	// column[j] holds the distance between the prefix of a processed so far
	// and the first j characters of b; only a single row is kept in memory.
	column := make([]int, len(runesB)+1)
	for j := range column {
		column[j] = j
	}

	for i := 1; i <= len(runesA); i++ {
		column[0] = i
		diagonal := i - 1 // value of column[j-1] from the previous row
		for j := 1; j <= len(runesB); j++ {
			above := column[j] // value of column[j] from the previous row
			cost := 0
			if runesA[i-1] != runesB[j-1] {
				cost = 1
			}
			column[j] = min(above+1, column[j-1]+1, diagonal+cost)
			diagonal = above
		}
	}

	return column[len(runesB)]
}
+70
View File
@@ -0,0 +1,70 @@
package api
import (
"testing"
"github.com/qdm12/gluetun/internal/models"
"github.com/stretchr/testify/assert"
)
// Test_GetMostPopularResult checks which result getMostPopularResult
// picks for various sets of results, comparing the country, region and
// city fields only.
func Test_GetMostPopularResult(t *testing.T) {
	t.Parallel()

	type scenario struct {
		name     string
		input    []models.PublicIP
		expected models.PublicIP
	}

	scenarios := []scenario{
		{
			name: "exact_matches",
			input: []models.PublicIP{
				{Country: "France", City: "Paris"},
				{Country: "USA", City: "New York"},
				{Country: "France", City: "Paris"},
			},
			expected: models.PublicIP{Country: "France", City: "Paris"},
		},
		{
			// Countries one edit apart belong to the same cluster.
			name: "fuzzy_country_matching",
			input: []models.PublicIP{
				{Country: "Germany", Region: "Bavaria", City: "Munich"},
				{Country: "Germani", Region: "Bavaria", City: "Munich"},
				{Country: "France", Region: "IDF", City: "Paris"},
			},
			expected: models.PublicIP{Country: "Germany", Region: "Bavaria", City: "Munich"},
		},
		{
			// The country is decided before the city is looked at.
			name: "hierarchy_priority",
			input: []models.PublicIP{
				{Country: "Italy", Region: "Sicily", City: "Syracuse"},
				{Country: "Italy", Region: "Sicily", City: "Syracuse"},
				{Country: "USA", Region: "New York", City: "Syracuse"},
				{Country: "Italy", Region: "Sicily", City: "Syracuse"},
			},
			expected: models.PublicIP{Country: "Italy", Region: "Sicily", City: "Syracuse"},
		},
		{
			// Accents and surrounding spaces are ignored when comparing.
			name: "normalization_check",
			input: []models.PublicIP{
				{Country: "Canada", City: "Montréal"},
				{Country: "Canada", City: "Montreal "},
				{Country: "UK", City: "London"},
			},
			expected: models.PublicIP{Country: "Canada", City: "Montréal"},
		},
		{
			// "US" and "UK" are a single edit apart once normalized, so
			// they form a cluster of two represented by its first value.
			name: "all_different",
			input: []models.PublicIP{
				{Country: "Canada", City: "Montréal"},
				{Country: "US", City: "New York"},
				{Country: "UK", City: "London"},
			},
			expected: models.PublicIP{Country: "US", City: "New York"},
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			t.Parallel()

			got := getMostPopularResult(sc.input)

			assert.Equal(t, sc.expected.Country, got.Country)
			assert.Equal(t, sc.expected.Region, got.Region)
			assert.Equal(t, sc.expected.City, got.City)
		})
	}
}