Files
gluetun/pkg/updaters/updateall.go

259 lines
9.6 KiB
Go

package updaters
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"path/filepath"
"slices"
"strings"
"time"
"github.com/qdm12/dns/v2/pkg/doh"
dnsprovider "github.com/qdm12/dns/v2/pkg/provider"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
publicipapi "github.com/qdm12/gluetun/internal/publicip/api"
"github.com/qdm12/gluetun/internal/updater/resolver"
"github.com/qdm12/gluetun/internal/updater/unzip"
"github.com/qdm12/gluetun/pkg/updaters/airvpn"
"github.com/qdm12/gluetun/pkg/updaters/cyberghost"
"github.com/qdm12/gluetun/pkg/updaters/expressvpn"
"github.com/qdm12/gluetun/pkg/updaters/fastestvpn"
"github.com/qdm12/gluetun/pkg/updaters/giganews"
"github.com/qdm12/gluetun/pkg/updaters/hidemyass"
"github.com/qdm12/gluetun/pkg/updaters/ipvanish"
"github.com/qdm12/gluetun/pkg/updaters/ivpn"
"github.com/qdm12/gluetun/pkg/updaters/mullvad"
"github.com/qdm12/gluetun/pkg/updaters/nordvpn"
"github.com/qdm12/gluetun/pkg/updaters/perfectprivacy"
"github.com/qdm12/gluetun/pkg/updaters/privado"
"github.com/qdm12/gluetun/pkg/updaters/privateinternetaccess"
"github.com/qdm12/gluetun/pkg/updaters/privatevpn"
"github.com/qdm12/gluetun/pkg/updaters/protonvpn"
"github.com/qdm12/gluetun/pkg/updaters/purevpn"
"github.com/qdm12/gluetun/pkg/updaters/slickvpn"
"github.com/qdm12/gluetun/pkg/updaters/surfshark"
"github.com/qdm12/gluetun/pkg/updaters/torguard"
"github.com/qdm12/gluetun/pkg/updaters/vpnsecure"
"github.com/qdm12/gluetun/pkg/updaters/vpnunlimited"
"github.com/qdm12/gluetun/pkg/updaters/vyprvpn"
"github.com/qdm12/gluetun/pkg/updaters/windscribe"
"github.com/qdm12/gosettings"
)
// UpdateAllSettings contains the configuration for the [UpdateAll] function.
// All pointer fields may be left nil: [UpdateAll] fills in defaults on its own
// copy of the settings before validating them.
type UpdateAllSettings struct {
// OutputPath is the directory where the provider JSON files will be written.
// It defaults to the current directory if left unset.
// [UpdateAll] creates the directory, and any missing parent, before writing.
OutputPath *string
// ProtonEmail is the email for the ProtonVPN account, which is required
// to update ProtonVPN servers.
// Validation fails if it is left unset or empty.
ProtonEmail *string
// ProtonPassword is the password for the ProtonVPN account, which is required
// to update ProtonVPN servers.
// Validation fails if it is left unset or empty.
ProtonPassword *string
// IpinfoToken is the API token for the IPInfo public IP API.
// If not provided, the IP fetcher will still work but may be
// subject to stricter rate limits for ipinfo.io.
IpinfoToken *string
// MinServers is a map of provider name to minimum number of servers required for a successful update.
// If a provider has fewer servers than the specified minimum, the update will be considered a failure
// for that provider. If the provider name is not found in this map, it is assumed that there is no
// minimum server requirement for that provider.
// Every key must be one of the names returned by providers.All, otherwise validation fails.
MinServers map[string]uint
}
// setDefaults replaces every unset (nil) field with its default value so the
// pointer fields can be dereferenced safely afterwards:
//   - OutputPath defaults to ".", the current directory;
//   - ProtonEmail, ProtonPassword and IpinfoToken default to the empty string;
//   - MinServers defaults to an empty, non-nil map.
//
// Fields that are already set are left untouched.
func (s *UpdateAllSettings) setDefaults() {
	// The default must be "." and not "": os.MkdirAll fails on an empty path,
	// which would make UpdateAll fail whenever OutputPath is left unset.
	s.OutputPath = gosettings.DefaultPointer(s.OutputPath, ".")
	s.ProtonEmail = gosettings.DefaultPointer(s.ProtonEmail, "")
	s.ProtonPassword = gosettings.DefaultPointer(s.ProtonPassword, "")
	s.IpinfoToken = gosettings.DefaultPointer(s.IpinfoToken, "")
	if s.MinServers == nil {
		s.MinServers = make(map[string]uint)
	}
}
// Sentinel errors returned, wrapped, by the settings validation.
var (
// errProtonEmailRequired is returned when the ProtonVPN email is empty.
errProtonEmailRequired = errors.New("proton email is required for updating ProtonVPN servers")
// errProtonPasswordRequired is returned when the ProtonVPN password is empty.
errProtonPasswordRequired = errors.New("proton password is required for updating ProtonVPN servers")
// errMinServersProviderNotFound is returned when a key of MinServers is not
// one of the provider names returned by providers.All.
errMinServersProviderNotFound = errors.New("provider name in MinServers not found in list of all providers")
)
// validate checks that the settings can be used to run an update.
// It dereferences the pointer fields, so setDefaults must have been
// called on the settings beforehand.
//
// It returns an error wrapping:
//   - errProtonEmailRequired if the ProtonVPN email is empty;
//   - errProtonPasswordRequired if the ProtonVPN password is empty;
//   - errMinServersProviderNotFound if a MinServers key is not a provider
//     name returned by providers.All.
func (s *UpdateAllSettings) validate() error {
	if *s.ProtonEmail == "" {
		return fmt.Errorf("%w", errProtonEmailRequired)
	}
	if *s.ProtonPassword == "" {
		return fmt.Errorf("%w", errProtonPasswordRequired)
	}

	knownProviders := providers.All()
	for name := range s.MinServers {
		if slices.Contains(knownProviders, name) {
			continue
		}
		return fmt.Errorf("%w: %s", errMinServersProviderNotFound, name)
	}

	return nil
}
// Logger is the logging interface required by [UpdateAll].
// The logger given to [UpdateAll] is passed on to the public IP fetcher
// and to the provider updaters that accept one.
type Logger interface {
// Warn reports a warning message.
Warn(message string)
}
var (
	// errUpdateAllFailed wraps the errors of all the providers that failed to update.
	errUpdateAllFailed = errors.New("update failed for one or more providers")
	// errFetcherNotFound is returned for a provider that has no fetcher registered.
	errFetcherNotFound = errors.New("fetcher not found")
)

// UpdateAll fetches server data for all providers in parallel and writes each
// provider's data as <provider-name>.json into the output path specified.
// The output directory is created if it does not exist, and an empty output
// path is treated as the current directory.
//
// Settings are defaulted and validated first; a validation failure is returned
// before any provider is contacted.
//
// Errors for individual providers are collected and returned as a combined
// error wrapping errUpdateAllFailed. The individual provider errors stay in
// the error chain, so callers can inspect them with [errors.Is] and
// [errors.As]. The order of the provider errors in the message follows the
// order in which the updates complete and is therefore not deterministic.
func UpdateAll(ctx context.Context, client *http.Client, logger Logger, settings UpdateAllSettings) error {
	settings.setDefaults()
	err := settings.validate()
	if err != nil {
		return fmt.Errorf("validating settings: %w", err)
	}

	// os.MkdirAll fails on an empty path, so treat it as the current directory.
	outputPath := *settings.OutputPath
	if outputPath == "" {
		outputPath = "."
	}

	const permission = 0o755
	err = os.MkdirAll(outputPath, permission)
	if err != nil {
		return fmt.Errorf("creating output directory: %w", err)
	}

	dohDialer, err := doh.New(doh.Settings{
		UpstreamResolvers: []dnsprovider.Provider{
			dnsprovider.Cloudflare(),
			dnsprovider.Google(),
		},
	})
	if err != nil {
		return fmt.Errorf("creating updater DoH dialer: %w", err)
	}
	parallelResolver := resolver.NewParallelResolver(dohDialer)
	unzipper := unzip.New(client)
	ipFetcher, err := buildIPFetcher(client, logger, *settings.IpinfoToken)
	if err != nil {
		return fmt.Errorf("creating IP fetcher: %w", err)
	}

	fetchers := buildFetchers(client, parallelResolver, unzipper, ipFetcher,
		logger, *settings.ProtonEmail, *settings.ProtonPassword)

	// One goroutine per provider; each sends exactly one result, nil on
	// success, so receiving len(allProviders) results waits for all of them.
	results := make(chan error)
	allProviders := providers.All()
	for _, providerName := range allProviders {
		go func(providerName string) {
			fetcher, ok := fetchers[providerName]
			if !ok {
				// Calling a method on the nil fetcher would panic in this
				// goroutine and crash the whole program.
				results <- fmt.Errorf("provider %s: %w", providerName, errFetcherNotFound)
				return
			}
			minServers := settings.MinServers[providerName]
			err := fetchAndWrite(ctx, providerName, fetcher, minServers, outputPath)
			if err != nil {
				err = fmt.Errorf("provider %s: %w", providerName, err)
			}
			results <- err
		}(providerName)
	}

	errs := make([]error, 0, len(allProviders))
	for range allProviders {
		if err := <-results; err != nil {
			errs = append(errs, err)
		}
	}

	if len(errs) > 0 {
		// errors.Join separates the messages with newlines, and wrapping it
		// with %w keeps every provider error reachable in the error chain.
		return fmt.Errorf("%w:\n%w", errUpdateAllFailed, errors.Join(errs...))
	}
	return nil
}
// buildIPFetcher creates a resilient public IP fetcher backed by the IPInfo,
// IP2Location and ifconfig.co APIs, listed in that order. Only IPInfo is
// given a token, ipinfoToken, which may be empty. The logger is handed to
// the resilient fetcher.
func buildIPFetcher(client *http.Client, logger Logger, ipinfoToken string) (
	fetcher *publicipapi.ResilientFetcher, err error,
) {
	ipinfo := publicipapi.NameToken{Name: string(publicipapi.IPInfo), Token: ipinfoToken}
	ip2location := publicipapi.NameToken{Name: string(publicipapi.IP2Location)}
	ifconfigCo := publicipapi.NameToken{Name: string(publicipapi.IfConfigCo)}

	apiFetchers, err := publicipapi.New([]publicipapi.NameToken{ipinfo, ip2location, ifconfigCo}, client)
	if err != nil {
		return nil, fmt.Errorf("creating public IP fetchers: %w", err)
	}

	return publicipapi.NewResilient(apiFetchers, logger), nil
}
// buildFetchers constructs the servers fetcher of every supported provider
// and returns them keyed by provider name. Each provider constructor receives
// only the dependencies it needs among the HTTP client, parallel resolver,
// unzipper, public IP fetcher, logger and ProtonVPN credentials.
func buildFetchers(client *http.Client, parallelResolver common.ParallelResolver,
	unzipper common.Unzipper, ipFetcher common.IPFetcher,
	logger Logger, protonEmail, protonPassword string,
) map[string]common.Fetcher {
	const providersCount = 23
	fetchers := make(map[string]common.Fetcher, providersCount)

	fetchers[providers.Airvpn] = airvpn.New(client)
	fetchers[providers.Cyberghost] = cyberghost.New(parallelResolver, logger)
	fetchers[providers.Expressvpn] = expressvpn.New(unzipper, logger, parallelResolver)
	fetchers[providers.Fastestvpn] = fastestvpn.New(client, logger, parallelResolver)
	fetchers[providers.Giganews] = giganews.New(unzipper, logger, parallelResolver)
	fetchers[providers.HideMyAss] = hidemyass.New(client, logger, parallelResolver)
	fetchers[providers.Ipvanish] = ipvanish.New(unzipper, logger, parallelResolver)
	fetchers[providers.Ivpn] = ivpn.New(client, logger, parallelResolver)
	fetchers[providers.Mullvad] = mullvad.New(client)
	fetchers[providers.Nordvpn] = nordvpn.New(client, logger)
	fetchers[providers.Perfectprivacy] = perfectprivacy.New(unzipper, logger)
	fetchers[providers.Privado] = privado.New(client, logger)
	fetchers[providers.PrivateInternetAccess] = privateinternetaccess.New(client)
	fetchers[providers.Privatevpn] = privatevpn.New(unzipper, logger, parallelResolver)
	fetchers[providers.Protonvpn] = protonvpn.New(client, logger, protonEmail, protonPassword)
	fetchers[providers.Purevpn] = purevpn.New(ipFetcher, unzipper, logger, parallelResolver)
	fetchers[providers.SlickVPN] = slickvpn.New(client, logger, parallelResolver)
	fetchers[providers.Surfshark] = surfshark.New(client, unzipper, logger, parallelResolver)
	fetchers[providers.Torguard] = torguard.New(unzipper, logger, parallelResolver)
	fetchers[providers.VPNSecure] = vpnsecure.New(client, logger, parallelResolver)
	fetchers[providers.VPNUnlimited] = vpnunlimited.New(unzipper, logger, parallelResolver)
	fetchers[providers.Vyprvpn] = vyprvpn.New(unzipper, logger, parallelResolver)
	fetchers[providers.Windscribe] = windscribe.New(client, logger)

	return fetchers
}
// fetchAndWrite fetches the servers of a provider with fetcher and writes
// them as indented JSON, together with the fetcher version and the current
// Unix timestamp, to a file in outputDirPath. The file is named after
// providerName, lowercased and with spaces removed, with the .json extension.
// minServers is passed to the fetcher as the minimum number of servers.
//
// The servers are fetched before the output file is opened, so a failed fetch
// leaves a file written by a previous successful update untouched. If the
// JSON encoding fails, the partially written file is removed so no truncated
// JSON is left behind.
func fetchAndWrite(ctx context.Context, providerName string, fetcher common.Fetcher,
	minServers uint, outputDirPath string,
) error {
	servers, err := fetcher.FetchServers(ctx, int(minServers)) //nolint:gosec
	if err != nil {
		return fmt.Errorf("fetching servers: %w", err)
	}

	data := models.Servers{
		Version:   fetcher.Version(),
		Timestamp: time.Now().Unix(),
		Servers:   servers,
	}

	filename := strings.ToLower(strings.ReplaceAll(providerName, " ", "")) + ".json"
	destinationPath := filepath.Join(outputDirPath, filename)
	const permission = 0o644
	file, err := os.OpenFile(destinationPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, permission)
	if err != nil {
		return fmt.Errorf("opening output file: %w", err)
	}

	encoder := json.NewEncoder(file)
	encoder.SetIndent("", " ")
	err = encoder.Encode(data)
	if err != nil {
		// Best-effort cleanup: the encoding error is the one worth reporting,
		// and a missing file is better than a truncated JSON document.
		_ = file.Close()
		_ = os.Remove(destinationPath)
		return fmt.Errorf("encoding servers to JSON: %w", err)
	}

	err = file.Close()
	if err != nil {
		return fmt.Errorf("closing output file: %w", err)
	}
	return nil
}