mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-06 20:10:11 +02:00
feat(publicip/api): query all fetchers in parallel and pick most popular result
This commit is contained in:
@@ -5,7 +5,7 @@ import (
|
||||
)
|
||||
|
||||
type PublicIP struct {
|
||||
IP netip.Addr `json:"public_ip,omitempty"`
|
||||
IP netip.Addr `json:"public_ip"`
|
||||
Region string `json:"region,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
|
||||
@@ -8,20 +8,29 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"golang.org/x/text/runes"
|
||||
"golang.org/x/text/transform"
|
||||
"golang.org/x/text/unicode/norm"
|
||||
)
|
||||
|
||||
// ResilientFetcher is a fetcher implementation using multiple fetchers.
|
||||
// If a fetcher fails, it tries the next one.
|
||||
// To fetch public IP information for a specific IP address,
|
||||
// it fetches from all sources to find the best result, since data
|
||||
// from a single source can be wrong.
|
||||
type ResilientFetcher struct {
|
||||
fetchers []Fetcher
|
||||
logger Warner
|
||||
fetcherToBanTime map[Fetcher]time.Time
|
||||
banMutex sync.RWMutex
|
||||
mutex sync.RWMutex
|
||||
timeNow func() time.Time
|
||||
}
|
||||
|
||||
// NewResilient creates a 'resilient' fetcher given multiple fetchers.
|
||||
// For example, it can handle bans and move on to another fetcher if one fails.
|
||||
func NewResilient(fetchers []Fetcher, logger Warner) *ResilientFetcher {
|
||||
return &ResilientFetcher{
|
||||
fetchers: fetchers,
|
||||
@@ -31,7 +40,15 @@ func NewResilient(fetchers []Fetcher, logger Warner) *ResilientFetcher {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ResilientFetcher) setBanned(fetcher Fetcher) {
|
||||
r.banMutex.Lock()
|
||||
defer r.banMutex.Unlock()
|
||||
r.fetcherToBanTime[fetcher] = r.timeNow()
|
||||
}
|
||||
|
||||
func (r *ResilientFetcher) isBanned(fetcher Fetcher) (banned bool) {
|
||||
r.banMutex.Lock()
|
||||
defer r.banMutex.Unlock()
|
||||
banTime, banned := r.fetcherToBanTime[fetcher]
|
||||
if !banned {
|
||||
return false
|
||||
@@ -49,25 +66,21 @@ func (r *ResilientFetcher) isBanned(fetcher Fetcher) (banned bool) {
|
||||
func (r *ResilientFetcher) String() string {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
names := make([]string, 0, len(r.fetchers))
|
||||
for _, fetcher := range r.fetchers {
|
||||
if r.isBanned(fetcher) {
|
||||
continue
|
||||
}
|
||||
return fetcher.String()
|
||||
names = append(names, fetcher.String())
|
||||
}
|
||||
if len(names) == 0 {
|
||||
return "<all-banned>"
|
||||
}
|
||||
return strings.Join(names, "+")
|
||||
}
|
||||
|
||||
func (r *ResilientFetcher) Token() string {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
for _, fetcher := range r.fetchers {
|
||||
if r.isBanned(fetcher) {
|
||||
continue
|
||||
}
|
||||
return fetcher.Token()
|
||||
}
|
||||
return "<all-banned>"
|
||||
panic("invalid call")
|
||||
}
|
||||
|
||||
// CanFetchAnyIP returns true if any of the fetchers
|
||||
@@ -85,45 +98,161 @@ func (r *ResilientFetcher) CanFetchAnyIP() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
var ErrFetchersAllRateLimited = errors.New("all fetchers are rate limited")
|
||||
|
||||
// FetchInfo obtains information on the ip address provided.
|
||||
// If the ip is the zero value, the public IP address of the machine
|
||||
// is used as the IP.
|
||||
// If a fetcher gets banned, the next one is tried – until all have been exhausted.
|
||||
// Fetchers still within their banned period are skipped.
|
||||
// If an error unrelated to being banned is encountered, it is returned and more
|
||||
// fetchers are tried.
|
||||
// It queries all non-banned fetchers in parallel to obtain the most popular result.
|
||||
// It only returns an error if all fetchers fail to return information.
|
||||
func (r *ResilientFetcher) FetchInfo(ctx context.Context, ip netip.Addr) (
|
||||
result models.PublicIP, err error,
|
||||
) {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
for _, fetcher := range r.fetchers {
|
||||
type resultData struct {
|
||||
i int
|
||||
result models.PublicIP
|
||||
err error
|
||||
}
|
||||
resultsCh := make(chan resultData)
|
||||
fetchersStarted := 0
|
||||
for range r.fetchers {
|
||||
fetcher := r.fetchers[fetchersStarted]
|
||||
if r.isBanned(fetcher) ||
|
||||
(ip.IsValid() && !fetcher.CanFetchAnyIP()) {
|
||||
continue
|
||||
}
|
||||
|
||||
result, err = fetcher.FetchInfo(ctx, ip)
|
||||
if err == nil || !errors.Is(err, ErrTooManyRequests) {
|
||||
return result, err
|
||||
go func(i int, fetcher Fetcher) {
|
||||
result, err := fetcher.FetchInfo(ctx, ip)
|
||||
resultsCh <- resultData{
|
||||
i: i,
|
||||
result: result,
|
||||
err: err,
|
||||
}
|
||||
}(fetchersStarted, fetcher)
|
||||
fetchersStarted++
|
||||
}
|
||||
|
||||
// Fetcher is banned
|
||||
r.fetcherToBanTime[fetcher] = r.timeNow()
|
||||
r.logger.Warn(fetcher.String() + ": " + err.Error())
|
||||
// Collect resultDatas from goroutines first, which takes I/O time
|
||||
// so that we don't lock the ban map mutex for too long.
|
||||
resultDatas := make([]resultData, fetchersStarted)
|
||||
for range resultDatas {
|
||||
data := <-resultsCh
|
||||
resultDatas[data.i] = data
|
||||
}
|
||||
|
||||
fetcherNames := make([]string, len(r.fetchers))
|
||||
for i, fetcher := range r.fetchers {
|
||||
fetcherNames[i] = fetcher.String()
|
||||
// Mutex lock ban map and process results
|
||||
results := make([]models.PublicIP, 0, fetchersStarted)
|
||||
errs := make([]error, 0, fetchersStarted)
|
||||
for _, data := range resultDatas {
|
||||
fetcher := r.fetchers[data.i]
|
||||
if data.err != nil {
|
||||
if errors.Is(data.err, ErrTooManyRequests) {
|
||||
r.setBanned(fetcher)
|
||||
}
|
||||
errs = append(errs, fmt.Errorf("%s: %w", fetcher, data.err))
|
||||
continue
|
||||
}
|
||||
results = append(results, data.result)
|
||||
}
|
||||
|
||||
return result, fmt.Errorf("%w (%s)",
|
||||
ErrFetchersAllRateLimited,
|
||||
strings.Join(fetcherNames, ", "))
|
||||
if len(results) == 0 { // all failed
|
||||
return models.PublicIP{}, fmt.Errorf("all fetchers failed: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
return getMostPopularResult(results), nil
|
||||
}
|
||||
|
||||
// getMostPopularResult finds the most popular [models.PublicIP] from
|
||||
// a slice of results. It does so by first checking the country, then
|
||||
// region, then city fields. The other fields are ignored in this comparison.
|
||||
func getMostPopularResult(results []models.PublicIP) models.PublicIP {
|
||||
if len(results) == 0 {
|
||||
panic("no results to choose from")
|
||||
}
|
||||
|
||||
// 1. Filter by Country
|
||||
countries := make([]string, len(results))
|
||||
for i, r := range results {
|
||||
countries[i] = r.Country
|
||||
}
|
||||
_, countryMembers := getMostPopularString(countries)
|
||||
results = filterInPlace(results, countryMembers)
|
||||
|
||||
// 2. Filter by Region
|
||||
regions := make([]string, len(results))
|
||||
for i, r := range results {
|
||||
regions[i] = r.Region
|
||||
}
|
||||
_, regionMembers := getMostPopularString(regions)
|
||||
results = filterInPlace(results, regionMembers)
|
||||
|
||||
// 3. Filter by City
|
||||
cities := make([]string, len(results))
|
||||
for i, r := range results {
|
||||
cities[i] = r.City
|
||||
}
|
||||
winnerIdx, _ := getMostPopularString(cities)
|
||||
|
||||
return results[winnerIdx]
|
||||
}
|
||||
|
||||
// filterInPlace moves selected indices to the front and trims the slice.
|
||||
func filterInPlace(results []models.PublicIP, indices []int) []models.PublicIP {
|
||||
for i, originalIdx := range indices {
|
||||
results[i] = results[originalIdx]
|
||||
}
|
||||
return results[:len(indices)]
|
||||
}
|
||||
|
||||
// getMostPopularString returns the index of the representative winner
|
||||
// and a slice of all indexes that belong to that winner's cluster.
|
||||
func getMostPopularString(values []string) (winnerIdx int, memberIdxs []int) {
|
||||
if len(values) == 0 {
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
type cluster struct {
|
||||
firstIndex int
|
||||
normRep string
|
||||
members []int
|
||||
}
|
||||
|
||||
var groups []cluster
|
||||
|
||||
for i, value := range values {
|
||||
normP := normalize(value)
|
||||
found := false
|
||||
|
||||
for j := range groups {
|
||||
if levenshteinDistance(normP, groups[j].normRep) <= 1 {
|
||||
groups[j].members = append(groups[j].members, i)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
groups = append(groups, cluster{
|
||||
firstIndex: i,
|
||||
normRep: normP,
|
||||
members: []int{i},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
maxCount := -1
|
||||
var bestGroup cluster
|
||||
|
||||
for _, g := range groups {
|
||||
if len(g.members) > maxCount {
|
||||
maxCount = len(g.members)
|
||||
bestGroup = g
|
||||
}
|
||||
}
|
||||
|
||||
return bestGroup.firstIndex, bestGroup.members
|
||||
}
|
||||
|
||||
func (r *ResilientFetcher) UpdateFetchers(fetchers []Fetcher) {
|
||||
@@ -151,3 +280,48 @@ func (r *ResilientFetcher) UpdateFetchers(fetchers []Fetcher) {
|
||||
r.fetchers = fetchers
|
||||
r.fetcherToBanTime = newFetcherToBanTime
|
||||
}
|
||||
|
||||
// normalize removes accents, trims space, and lowercases the string.
|
||||
func normalize(s string) string {
|
||||
firstParentheseIndex := strings.Index(s, " (")
|
||||
if firstParentheseIndex != -1 {
|
||||
s = s[:firstParentheseIndex]
|
||||
}
|
||||
transformer := transform.Chain(norm.NFD, runes.Remove(runes.In(unicode.Mn)), norm.NFC)
|
||||
result, _, err := transform.String(transformer, s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return strings.ToLower(strings.TrimSpace(result))
|
||||
}
|
||||
|
||||
// levenshteinDistance calculates the edit distance (insertions,
// deletions and substitutions) between two strings a and b,
// compared byte by byte, using a single-row dynamic programming
// table of size len(b)+1.
func levenshteinDistance(a, b string) int {
	if len(a) == 0 {
		return len(b)
	}
	if len(b) == 0 {
		return len(a)
	}

	row := make([]int, len(b)+1)
	for j := range row {
		row[j] = j
	}

	for i := 1; i <= len(a); i++ {
		diag := row[0] // previous row's value at column 0
		row[0] = i
		for j := 1; j <= len(b); j++ {
			up := row[j]
			substitution := diag
			if a[i-1] != b[j-1] {
				substitution++
			}
			row[j] = min(up+1, row[j-1]+1, substitution)
			diag = up
		}
	}
	return row[len(b)]
}
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_GetMostPopularResult(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
input []models.PublicIP
|
||||
expected models.PublicIP
|
||||
}{
|
||||
"exact_matches": {
|
||||
input: []models.PublicIP{
|
||||
{Country: "France", City: "Paris"},
|
||||
{Country: "USA", City: "New York"},
|
||||
{Country: "France", City: "Paris"},
|
||||
},
|
||||
expected: models.PublicIP{Country: "France", City: "Paris"},
|
||||
},
|
||||
"fuzzy_country_matching": {
|
||||
input: []models.PublicIP{
|
||||
{Country: "Germany", Region: "Bavaria", City: "Munich"},
|
||||
{Country: "Germani", Region: "Bavaria", City: "Munich"},
|
||||
{Country: "France", Region: "IDF", City: "Paris"},
|
||||
},
|
||||
expected: models.PublicIP{Country: "Germany", Region: "Bavaria", City: "Munich"},
|
||||
},
|
||||
"hierarchy_priority": {
|
||||
input: []models.PublicIP{
|
||||
{Country: "Italy", Region: "Sicily", City: "Syracuse"},
|
||||
{Country: "Italy", Region: "Sicily", City: "Syracuse"},
|
||||
{Country: "USA", Region: "New York", City: "Syracuse"},
|
||||
{Country: "Italy", Region: "Sicily", City: "Syracuse"},
|
||||
},
|
||||
expected: models.PublicIP{Country: "Italy", Region: "Sicily", City: "Syracuse"},
|
||||
},
|
||||
"normalization_check": {
|
||||
input: []models.PublicIP{
|
||||
{Country: "Canada", City: "Montréal"},
|
||||
{Country: "Canada", City: "Montreal "},
|
||||
{Country: "UK", City: "London"},
|
||||
},
|
||||
expected: models.PublicIP{Country: "Canada", City: "Montréal"},
|
||||
},
|
||||
"all_different": {
|
||||
input: []models.PublicIP{
|
||||
{Country: "Canada", City: "Montréal"},
|
||||
{Country: "US", City: "New York"},
|
||||
{Country: "UK", City: "London"},
|
||||
},
|
||||
expected: models.PublicIP{Country: "US", City: "New York"},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := getMostPopularResult(testCase.input)
|
||||
|
||||
assert.Equal(t, testCase.expected.Country, result.Country)
|
||||
assert.Equal(t, testCase.expected.Region, result.Region)
|
||||
assert.Equal(t, testCase.expected.City, result.City)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user