diff --git a/internal/models/publicip.go b/internal/models/publicip.go
index 0978f4d9..4f1c2555 100644
--- a/internal/models/publicip.go
+++ b/internal/models/publicip.go
@@ -5,7 +5,7 @@ import (
 )
 
 type PublicIP struct {
-	IP      netip.Addr `json:"public_ip,omitempty"`
+	IP      netip.Addr `json:"public_ip"`
 	Region  string     `json:"region,omitempty"`
 	Country string     `json:"country,omitempty"`
 	City    string     `json:"city,omitempty"`
diff --git a/internal/publicip/api/resilient.go b/internal/publicip/api/resilient.go
index 0aefd677..ca66a6ad 100644
--- a/internal/publicip/api/resilient.go
+++ b/internal/publicip/api/resilient.go
@@ -8,20 +8,29 @@ import (
 	"strings"
 	"sync"
 	"time"
+	"unicode"
 
 	"github.com/qdm12/gluetun/internal/models"
+	"golang.org/x/text/runes"
+	"golang.org/x/text/transform"
+	"golang.org/x/text/unicode/norm"
 )
 
+// ResilientFetcher is a fetcher implementation using multiple fetchers.
+// If a fetcher fails, it tries the next one.
+// To fetch public IP information for a specific IP address,
+// it fetches from all sources to find the best result, since data
+// from a single source can be wrong.
 type ResilientFetcher struct {
 	fetchers         []Fetcher
 	logger           Warner
 	fetcherToBanTime map[Fetcher]time.Time
+	banMutex         sync.RWMutex
 	mutex            sync.RWMutex
 	timeNow          func() time.Time
 }
 
 // NewResilient creates a 'resilient' fetcher given multiple fetchers.
-// For example, it can handle bans and move on to another fetcher if one fails.
func NewResilient(fetchers []Fetcher, logger Warner) *ResilientFetcher { return &ResilientFetcher{ fetchers: fetchers, @@ -31,7 +40,15 @@ func NewResilient(fetchers []Fetcher, logger Warner) *ResilientFetcher { } } +func (r *ResilientFetcher) setBanned(fetcher Fetcher) { + r.banMutex.Lock() + defer r.banMutex.Unlock() + r.fetcherToBanTime[fetcher] = r.timeNow() +} + func (r *ResilientFetcher) isBanned(fetcher Fetcher) (banned bool) { + r.banMutex.Lock() + defer r.banMutex.Unlock() banTime, banned := r.fetcherToBanTime[fetcher] if !banned { return false @@ -49,25 +66,21 @@ func (r *ResilientFetcher) isBanned(fetcher Fetcher) (banned bool) { func (r *ResilientFetcher) String() string { r.mutex.RLock() defer r.mutex.RUnlock() + names := make([]string, 0, len(r.fetchers)) for _, fetcher := range r.fetchers { if r.isBanned(fetcher) { continue } - return fetcher.String() + names = append(names, fetcher.String()) } - return "" + if len(names) == 0 { + return "" + } + return strings.Join(names, "+") } func (r *ResilientFetcher) Token() string { - r.mutex.RLock() - defer r.mutex.RUnlock() - for _, fetcher := range r.fetchers { - if r.isBanned(fetcher) { - continue - } - return fetcher.Token() - } - return "" + panic("invalid call") } // CanFetchAnyIP returns true if any of the fetchers @@ -85,45 +98,161 @@ func (r *ResilientFetcher) CanFetchAnyIP() bool { return false } -var ErrFetchersAllRateLimited = errors.New("all fetchers are rate limited") - // FetchInfo obtains information on the ip address provided. // If the ip is the zero value, the public IP address of the machine // is used as the IP. -// If a fetcher gets banned, the next one is tried – until all have been exhausted. -// Fetchers still within their banned period are skipped. -// If an error unrelated to being banned is encountered, it is returned and more -// fetchers are tried. +// It queries all non-banned fetchers in parallel to obtain the most popular result. 
+// It only returns an error if all fetchers fail to return information.
 func (r *ResilientFetcher) FetchInfo(ctx context.Context, ip netip.Addr) (
 	result models.PublicIP, err error,
 ) {
 	r.mutex.RLock()
 	defer r.mutex.RUnlock()
-	for _, fetcher := range r.fetchers {
+	type resultData struct {
+		i      int
+		result models.PublicIP
+		err    error
+	}
+
+	// Collect the usable fetchers up front so that the goroutine result
+	// index i maps back to the fetcher that produced it. Indexing
+	// r.fetchers directly by the started-count is wrong once any fetcher
+	// is skipped: the skip path would re-check the same banned fetcher on
+	// every iteration (never reaching later fetchers), and error handling
+	// would ban/blame the wrong fetcher.
+	usable := make([]Fetcher, 0, len(r.fetchers))
+	for _, fetcher := range r.fetchers {
 		if r.isBanned(fetcher) || (ip.IsValid() && !fetcher.CanFetchAnyIP()) {
 			continue
 		}
-		result, err = fetcher.FetchInfo(ctx, ip)
-		if err == nil || !errors.Is(err, ErrTooManyRequests) {
-			return result, err
+		usable = append(usable, fetcher)
+	}
+	if len(usable) == 0 {
+		// Avoid wrapping a nil errors.Join result below when nothing ran.
+		return models.PublicIP{}, fmt.Errorf("no fetcher is usable out of %d fetchers", len(r.fetchers))
+	}
+
+	resultsCh := make(chan resultData)
+	for i, fetcher := range usable {
+		go func(i int, fetcher Fetcher) {
+			result, err := fetcher.FetchInfo(ctx, ip)
+			resultsCh <- resultData{
+				i:      i,
+				result: result,
+				err:    err,
+			}
+		}(i, fetcher)
+	}
+
+	// Collect results from all goroutines first, which takes I/O time,
+	// so that we don't lock the ban map mutex for too long.
+	resultDatas := make([]resultData, len(usable))
+	for range resultDatas {
+		data := <-resultsCh
+		resultDatas[data.i] = data
+	}
+
+	// Process results, banning fetchers which reported rate limiting.
+	results := make([]models.PublicIP, 0, len(usable))
+	errs := make([]error, 0, len(usable))
+	for _, data := range resultDatas {
+		fetcher := usable[data.i] // index into usable, NOT r.fetchers
+		if data.err != nil {
+			if errors.Is(data.err, ErrTooManyRequests) {
+				r.setBanned(fetcher)
+			}
+			errs = append(errs, fmt.Errorf("%s: %w", fetcher, data.err))
+			continue
+		}
+		results = append(results, data.result)
+	}
+
+	if len(results) == 0 { // all failed
+		return models.PublicIP{}, fmt.Errorf("all fetchers failed: %w", errors.Join(errs...))
+	}
+
+	return getMostPopularResult(results), nil
+}
+
+// getMostPopularResult finds the most popular [models.PublicIP] from
+// a slice of results. It does so by first checking the country, then
+// region, then city fields. The other fields are ignored in this comparison.
+func getMostPopularResult(results []models.PublicIP) models.PublicIP {
+	if len(results) == 0 {
+		panic("no results to choose from")
+	}
+
+	// Narrow the candidates hierarchically: country first, then region.
+	results = filterByMostPopular(results, func(r models.PublicIP) string { return r.Country })
+	results = filterByMostPopular(results, func(r models.PublicIP) string { return r.Region })
+
+	// Finally pick the representative of the most popular city cluster.
+	cities := make([]string, len(results))
+	for i, r := range results {
+		cities[i] = r.City
+	}
+	winnerIdx, _ := getMostPopularString(cities)
+
+	return results[winnerIdx]
+}
+
+// filterByMostPopular keeps only the results whose field value belongs to
+// the most popular (fuzzy-matched) cluster of values for that field.
+func filterByMostPopular(results []models.PublicIP,
+	field func(r models.PublicIP) string,
+) []models.PublicIP {
+	values := make([]string, len(results))
+	for i, r := range results {
+		values[i] = field(r)
+	}
+	_, members := getMostPopularString(values)
+	return filterInPlace(results, members)
+}
+
+// filterInPlace moves the results at the given ascending indices to the
+// front of the slice and trims it, reusing the backing array.
+func filterInPlace(results []models.PublicIP, indices []int) []models.PublicIP {
+	for i, originalIdx := range indices {
+		results[i] = results[originalIdx]
+	}
+	return results[:len(indices)]
+}
+
+// getMostPopularString returns the index of the representative winner
+// and a slice of all indexes that belong to that winner's cluster.
+func getMostPopularString(values []string) (winnerIdx int, memberIdxs []int) {
+	if len(values) == 0 {
+		return -1, nil
+	}
+
+	// A cluster groups fuzzy-equal values: its representative string plus
+	// the indexes of every member value.
+	type cluster struct {
+		firstIndex int
+		normRep    string
+		members    []int
+	}
+
+	var groups []cluster
+	for i, value := range values {
+		normalized := normalize(value)
+		matched := false
+		// NOTE(review): a distance threshold of 1 also merges distinct
+		// short names such as "US"/"UK" — confirm this is intended.
+		for j := range groups {
+			if levenshteinDistance(normalized, groups[j].normRep) <= 1 {
+				groups[j].members = append(groups[j].members, i)
+				matched = true
+				break
+			}
 		}
-		// Fetcher is banned
-		r.fetcherToBanTime[fetcher] = r.timeNow()
-		r.logger.Warn(fetcher.String() + ": " + err.Error())
+		if !matched {
+			groups = append(groups, cluster{
+				firstIndex: i,
+				normRep:    normalized,
+				members:    []int{i},
+			})
+		}
 	}
-	fetcherNames := make([]string, len(r.fetchers))
-	for i, fetcher := range r.fetchers {
-		fetcherNames[i] = fetcher.String()
+	// Pick the biggest cluster, ties going to the earliest created one.
+	best := groups[0]
+	for _, g := range groups[1:] {
+		if len(g.members) > len(best.members) {
+			best = g
+		}
 	}
-	return result, fmt.Errorf("%w (%s)",
-		ErrFetchersAllRateLimited,
-		strings.Join(fetcherNames, ", "))
+	return best.firstIndex, best.members
 }
 
 func (r *ResilientFetcher) UpdateFetchers(fetchers []Fetcher) {
@@ -151,3 +280,48 @@ func (r *ResilientFetcher) UpdateFetchers(fetchers []Fetcher) {
 	r.fetchers = fetchers
 	r.fetcherToBanTime = newFetcherToBanTime
 }
+
+// normalize lowercases s, trims surrounding space, strips accents
+// (Unicode combining marks) and drops any trailing " (...)" qualifier.
+func normalize(s string) string {
+	// Cut off a trailing parenthesized qualifier, e.g. "Paris (Ile-de-France)".
+	if i := strings.Index(s, " ("); i != -1 {
+		s = s[:i]
+	}
+	// Decompose (NFD), drop combining marks, recompose (NFC). The chain
+	// is built per call: chained transformers carry internal state, so a
+	// shared package-level transformer would not be concurrency-safe.
+	transformer := transform.Chain(norm.NFD, runes.Remove(runes.In(unicode.Mn)), norm.NFC)
+	result, _, err := transform.String(transformer, s)
+	if err != nil {
+		panic(err) // should not happen for these transformers
+	}
+	return strings.ToLower(strings.TrimSpace(result))
+}
+
+// levenshteinDistance calculates the edit distance
+// between two strings a and b.
+func levenshteinDistance(a, b string) int {
+	// The comparison is byte-wise, not rune-wise, so a multi-byte rune
+	// counts as several edits — acceptable here since inputs are
+	// normalized, mostly-ASCII location names (assumption to confirm).
+	switch {
+	case len(a) == 0:
+		return len(b)
+	case len(b) == 0:
+		return len(a)
+	}
+
+	// Single-column dynamic programming over the edit distance matrix.
+	column := make([]int, len(b)+1)
+	for i := range column {
+		column[i] = i
+	}
+
+	for i := 1; i <= len(a); i++ {
+		column[0] = i
+		diagonal := i - 1 // column[j-1] value from the previous row
+		for j := 1; j <= len(b); j++ {
+			previous := column[j]
+			cost := 0
+			if a[i-1] != b[j-1] {
+				cost = 1
+			}
+			// deletion, insertion, substitution — builtin min is
+			// variadic (Go 1.21+), no need to nest two calls.
+			column[j] = min(column[j]+1, column[j-1]+1, diagonal+cost)
+			diagonal = previous
+		}
+	}
+	return column[len(b)]
+}
diff --git a/internal/publicip/api/resilient_test.go b/internal/publicip/api/resilient_test.go
new file mode 100644
index 00000000..15bfd740
--- /dev/null
+++ b/internal/publicip/api/resilient_test.go
@@ -0,0 +1,70 @@
+package api
+
+import (
+	"testing"
+
+	"github.com/qdm12/gluetun/internal/models"
+	"github.com/stretchr/testify/assert"
+)
+
+func Test_GetMostPopularResult(t *testing.T) {
+	t.Parallel()
+
+	testCases := map[string]struct {
+		input    []models.PublicIP
+		expected models.PublicIP
+	}{
+		"exact_matches": {
+			input: []models.PublicIP{
+				{Country: "France", City: "Paris"},
+				{Country: "USA", City: "New York"},
+				{Country: "France", City: "Paris"},
+			},
+			expected: models.PublicIP{Country: "France", City: "Paris"},
+		},
+		"fuzzy_country_matching": {
+			input: []models.PublicIP{
+				{Country: "Germany", Region: "Bavaria", City: "Munich"},
+				{Country: "Germani", Region: "Bavaria", City: "Munich"},
+				{Country: "France", Region: "IDF", City: "Paris"},
+			},
+			expected: models.PublicIP{Country: "Germany", Region: "Bavaria", City: "Munich"},
+		},
+		"hierarchy_priority": {
+			input: []models.PublicIP{
+				{Country: "Italy", Region: "Sicily", City: "Syracuse"},
+				{Country: "Italy", Region: "Sicily", City: "Syracuse"},
+				{Country: "USA", Region: "New York", City: "Syracuse"},
+				{Country: "Italy", Region: "Sicily", City: "Syracuse"},
+			},
+			expected: models.PublicIP{Country: "Italy", Region: "Sicily", City: "Syracuse"},
+		},
+		"normalization_check": {
+			input: []models.PublicIP{
+				{Country: "Canada", City: "Montréal"},
+				{Country: "Canada", City: "Montreal "},
+				{Country: "UK", City: "London"},
+			},
+			expected: models.PublicIP{Country: "Canada", City: "Montréal"},
+		},
+		"all_different": {
+			input: []models.PublicIP{
+				{Country: "Canada", City: "Montréal"},
+				{Country: "US", City: "New York"},
+				{Country: "UK", City: "London"},
+			},
+			// "US" and "UK" fuzzy-match (edit distance 1) into a single
+			// country cluster of two, outweighing "Canada"; the cluster's
+			// first member then wins on the city comparison.
+			expected: models.PublicIP{Country: "US", City: "New York"},
+		},
+	}
+
+	for name, testCase := range testCases {
+		t.Run(name, func(t *testing.T) {
+			t.Parallel()
+			result := getMostPopularResult(testCase.input)
+
+			assert.Equal(t, testCase.expected.Country, result.Country)
+			assert.Equal(t, testCase.expected.Region, result.Region)
+			assert.Equal(t, testCase.expected.City, result.City)
+		})
+	}
+}