mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-06 20:10:11 +02:00
chore: do not use sentinel errors when unneeded
- main reason being it's a burden to always define sentinel errors at global scope, wrap them with `%w` instead of using a string directly - only use sentinel errors when it has to be checked using `errors.Is` - replace all usage of these sentinel errors in `fmt.Errorf` with direct strings that were in the sentinel error - exclude the sentinel error definition requirement from .golangci.yml - update unit tests to use ContainersError instead of ErrorIs so it stays as a "not a change detector test" without requiring a sentinel error
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package alpine
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
@@ -9,8 +8,6 @@ import (
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var ErrUserAlreadyExists = errors.New("user already exists")
|
||||
|
||||
// CreateUser creates a user in Alpine with the given UID.
|
||||
func (a *Alpine) CreateUser(username string, uid int) (createdUsername string, err error) {
|
||||
UIDStr := strconv.Itoa(uid)
|
||||
@@ -34,8 +31,8 @@ func (a *Alpine) CreateUser(username string, uid int) (createdUsername string, e
|
||||
}
|
||||
|
||||
if u != nil {
|
||||
return "", fmt.Errorf("%w: with name %s for ID %s instead of %d",
|
||||
ErrUserAlreadyExists, username, u.Uid, uid)
|
||||
return "", fmt.Errorf("user already exists: with name %s for ID %s instead of %d",
|
||||
username, u.Uid, uid)
|
||||
}
|
||||
|
||||
const permission = fs.FileMode(0o644)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package amneziawg
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
@@ -28,7 +29,7 @@ func Test_New(t *testing.T) {
|
||||
PrivateKey: "",
|
||||
},
|
||||
},
|
||||
err: wireguard.ErrPrivateKeyMissing,
|
||||
err: errors.New("private key is missing"),
|
||||
},
|
||||
"minimal valid settings": {
|
||||
settings: Settings{
|
||||
|
||||
@@ -13,11 +13,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/wireguard"
|
||||
)
|
||||
|
||||
var (
|
||||
errTunNameMismatch = errors.New("TUN device name is mismatching")
|
||||
errDeviceWaited = errors.New("device waited for")
|
||||
)
|
||||
|
||||
// Run runs the amneziawg interface and waits until the context is done, then it cleans up the
|
||||
// interface and returns any error that occurred during setup or waiting. It sends an error to
|
||||
// waitError if any error occurs during setup or waiting, otherwise it sends nil when the context
|
||||
@@ -52,8 +47,7 @@ func setupUserspace(ctx context.Context,
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("getting created TUN device name: %w", err)
|
||||
} else if tunName != interfaceName {
|
||||
return 0, nil, fmt.Errorf("%w: expected %q and got %q",
|
||||
errTunNameMismatch, interfaceName, tunName)
|
||||
return 0, nil, fmt.Errorf("TUN device name is mismatching: expected %q and got %q", interfaceName, tunName)
|
||||
}
|
||||
|
||||
link, err := netLinker.LinkByName(interfaceName)
|
||||
@@ -106,7 +100,7 @@ func setupUserspace(ctx context.Context,
|
||||
case err = <-uapiAcceptErrorCh:
|
||||
close(uapiAcceptErrorCh)
|
||||
case <-device.Wait():
|
||||
err = errDeviceWaited
|
||||
err = errors.New("device waited for")
|
||||
}
|
||||
|
||||
cleanups.Cleanup(logger)
|
||||
|
||||
@@ -16,11 +16,6 @@ import (
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrProviderUnspecified = errors.New("VPN provider to format was not specified")
|
||||
ErrMultipleProvidersToFormat = errors.New("more than one VPN provider to format were specified")
|
||||
)
|
||||
|
||||
func addProviderFlag(flagSet *flag.FlagSet, providerToFormat map[string]*bool,
|
||||
provider string, titleCaser cases.Caser,
|
||||
) {
|
||||
@@ -65,11 +60,10 @@ func (c *CLI) FormatServers(args []string) error {
|
||||
}
|
||||
switch len(providers) {
|
||||
case 0:
|
||||
return fmt.Errorf("%w", ErrProviderUnspecified)
|
||||
return errors.New("VPN provider to format was not specified")
|
||||
case 1:
|
||||
default:
|
||||
return fmt.Errorf("%w: %d specified: %s",
|
||||
ErrMultipleProvidersToFormat, len(providers),
|
||||
return fmt.Errorf("more than one VPN provider to format were specified: %d specified: %s", len(providers),
|
||||
strings.Join(providers, ", "))
|
||||
}
|
||||
|
||||
|
||||
@@ -24,13 +24,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/updater/unzip"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrModeUnspecified = errors.New("at least one of -enduser or -maintainer must be specified")
|
||||
ErrNoProviderSpecified = errors.New("no provider was specified")
|
||||
ErrUsernameMissing = errors.New("username is required for this provider")
|
||||
ErrPasswordMissing = errors.New("password is required for this provider")
|
||||
)
|
||||
|
||||
type UpdaterLogger interface {
|
||||
Info(s string)
|
||||
Warn(s string)
|
||||
@@ -65,14 +58,14 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
|
||||
}
|
||||
|
||||
if !endUserMode && !maintainerMode {
|
||||
return fmt.Errorf("%w", ErrModeUnspecified)
|
||||
return errors.New("at least one of -enduser or -maintainer must be specified")
|
||||
}
|
||||
|
||||
if updateAll {
|
||||
options.Providers = providers.All()
|
||||
} else {
|
||||
if csvProviders == "" {
|
||||
return fmt.Errorf("%w", ErrNoProviderSpecified)
|
||||
return errors.New("no provider was specified")
|
||||
}
|
||||
options.Providers = strings.Split(csvProviders, ",")
|
||||
}
|
||||
|
||||
@@ -8,13 +8,6 @@ import (
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
var (
|
||||
errCommandEmpty = errors.New("command is empty")
|
||||
errSingleQuoteUnterminated = errors.New("unterminated single-quoted string")
|
||||
errDoubleQuoteUnterminated = errors.New("unterminated double-quoted string")
|
||||
errEscapeUnterminated = errors.New("unterminated backslash-escape")
|
||||
)
|
||||
|
||||
// split splits a command string into a slice of arguments.
|
||||
// This is especially important for commands such as:
|
||||
// /bin/sh -c "echo hello"
|
||||
@@ -25,7 +18,7 @@ var (
|
||||
// - expansion (brace, shell or pathname).
|
||||
func split(command string) (words []string, err error) {
|
||||
if command == "" {
|
||||
return nil, fmt.Errorf("%w", errCommandEmpty)
|
||||
return nil, errors.New("command is empty")
|
||||
}
|
||||
|
||||
const bufferSize = 1024
|
||||
@@ -42,7 +35,7 @@ func split(command string) (words []string, err error) {
|
||||
case character == '\\':
|
||||
// Look ahead to eventually skip an escaped newline
|
||||
if command[startIndex+runeSize:] == "" {
|
||||
return nil, fmt.Errorf("%w: %q", errEscapeUnterminated, command)
|
||||
return nil, fmt.Errorf("unterminated backslash-escape: %q", command)
|
||||
}
|
||||
character, runeSize := utf8.DecodeRuneInString(command[startIndex+runeSize:])
|
||||
if character == '\n' {
|
||||
@@ -119,7 +112,7 @@ func handleDoubleQuoted(input string, startIndex int, buffer *bytes.Buffer) (
|
||||
startIndex = cursor
|
||||
}
|
||||
}
|
||||
return "", 0, fmt.Errorf("%w", errDoubleQuoteUnterminated)
|
||||
return "", 0, errors.New("unterminated double-quoted string")
|
||||
}
|
||||
|
||||
func handleSingleQuoted(input string, startIndex int, buffer *bytes.Buffer) (
|
||||
@@ -127,7 +120,7 @@ func handleSingleQuoted(input string, startIndex int, buffer *bytes.Buffer) (
|
||||
) {
|
||||
closingQuoteIndex := strings.IndexRune(input[startIndex:], '\'')
|
||||
if closingQuoteIndex == -1 {
|
||||
return "", 0, fmt.Errorf("%w", errSingleQuoteUnterminated)
|
||||
return "", 0, errors.New("unterminated single-quoted string")
|
||||
}
|
||||
buffer.WriteString(input[startIndex : startIndex+closingQuoteIndex])
|
||||
const singleQuoteRuneLength = 1
|
||||
@@ -139,7 +132,7 @@ func handleEscaped(input string, startIndex int, buffer *bytes.Buffer) (
|
||||
word string, newStartIndex int, err error,
|
||||
) {
|
||||
if input[startIndex:] == "" {
|
||||
return "", 0, fmt.Errorf("%w", errEscapeUnterminated)
|
||||
return "", 0, errors.New("unterminated backslash-escape")
|
||||
}
|
||||
character, runeLength := utf8.DecodeRuneInString(input[startIndex:])
|
||||
if character != '\n' { // backslash-escaped newline is ignored
|
||||
|
||||
@@ -12,12 +12,10 @@ func Test_split(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
command string
|
||||
words []string
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"empty": {
|
||||
command: "",
|
||||
errWrapped: errCommandEmpty,
|
||||
errMessage: "command is empty",
|
||||
},
|
||||
"concrete_sh_command": {
|
||||
@@ -74,22 +72,18 @@ func Test_split(t *testing.T) {
|
||||
},
|
||||
"unterminated_single_quote": {
|
||||
command: "'abc'\\''def",
|
||||
errWrapped: errSingleQuoteUnterminated,
|
||||
errMessage: `splitting word in "'abc'\\''def": unterminated single-quoted string`,
|
||||
},
|
||||
"unterminated_double_quote": {
|
||||
command: "\"abc'def",
|
||||
errWrapped: errDoubleQuoteUnterminated,
|
||||
errMessage: `splitting word in "\"abc'def": unterminated double-quoted string`,
|
||||
},
|
||||
"unterminated_escape": {
|
||||
command: "abc\\",
|
||||
errWrapped: errEscapeUnterminated,
|
||||
errMessage: `splitting word in "abc\\": unterminated backslash-escape`,
|
||||
},
|
||||
"unterminated_escape_only": {
|
||||
command: " \\",
|
||||
errWrapped: errEscapeUnterminated,
|
||||
errMessage: `unterminated backslash-escape: " \\"`,
|
||||
},
|
||||
}
|
||||
@@ -101,9 +95,10 @@ func Test_split(t *testing.T) {
|
||||
words, err := split(testCase.command)
|
||||
|
||||
assert.Equal(t, testCase.words, words)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
if testCase.errMessage != "" {
|
||||
assert.ErrorContains(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -177,14 +176,6 @@ func (a AmneziaWg) toLinesNode() (node *gotree.Node) {
|
||||
return node
|
||||
}
|
||||
|
||||
var (
|
||||
ErrAmenziawgImplementationNotValid = errors.New("AmneziaWG implementation is not valid")
|
||||
ErrJunkPacketBounds = errors.New("junk packet minimum must be lower than or equal to maximum")
|
||||
ErrJunkPacketMinMaxNotSet = errors.New("junk packet min and max must be set when junk packet count is set")
|
||||
ErrJunkPacketCountNotSet = errors.New("junk packet count must be set when junk packet min or max is set")
|
||||
ErrHeaderRangeMalformed = errors.New("header range is malformed")
|
||||
)
|
||||
|
||||
func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
|
||||
const amneziaWG = true
|
||||
err := a.Wireguard.validate(vpnProvider, ipv6Supported, amneziaWG)
|
||||
@@ -194,16 +185,16 @@ func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
|
||||
|
||||
if *a.JunkPacketCount == 0 {
|
||||
if *a.JunkPacketMin != 0 || *a.JunkPacketMax != 0 {
|
||||
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
|
||||
ErrJunkPacketCountNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
|
||||
return fmt.Errorf("junk packet count must be set when junk packet min or max is set: "+
|
||||
"jc=%d and jmin=%d and jmax=%d", a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
|
||||
}
|
||||
} else {
|
||||
if *a.JunkPacketMin == 0 || *a.JunkPacketMax == 0 {
|
||||
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
|
||||
ErrJunkPacketMinMaxNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
|
||||
return fmt.Errorf("junk packet min and max must be set when junk packet count is set: "+
|
||||
"jc=%d and jmin=%d and jmax=%d", a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
|
||||
} else if *a.JunkPacketMin > *a.JunkPacketMax {
|
||||
return fmt.Errorf("%w: jmin=%d and jmax=%d",
|
||||
ErrJunkPacketBounds, *a.JunkPacketMin, *a.JunkPacketMax)
|
||||
return fmt.Errorf("junk packet minimum must be lower than or equal to maximum: "+
|
||||
"jmin=%d and jmax=%d", *a.JunkPacketMin, *a.JunkPacketMax)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,20 +213,20 @@ func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
|
||||
case 1:
|
||||
_, err := strconv.Atoi(fields[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s value %s is not a number",
|
||||
ErrHeaderRangeMalformed, name, headerRange)
|
||||
return fmt.Errorf("header range is malformed: "+
|
||||
"%s value %s is not a number", name, headerRange)
|
||||
}
|
||||
case 2: //nolint:mnd
|
||||
for _, field := range fields {
|
||||
_, err := strconv.Atoi(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s value %s is not a valid range",
|
||||
ErrHeaderRangeMalformed, name, headerRange)
|
||||
return fmt.Errorf("header range is malformed: "+
|
||||
"%s value %s is not a valid range", name, headerRange)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("%w: %s value %s must be in the form n or n-m",
|
||||
ErrHeaderRangeMalformed, name, headerRange)
|
||||
return fmt.Errorf("header range is malformed: "+
|
||||
"%s value %s must be in the form n or n-m", name, headerRange)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
@@ -48,22 +47,15 @@ type DNS struct {
|
||||
UpstreamPlainAddresses []netip.AddrPort
|
||||
}
|
||||
|
||||
var (
|
||||
ErrDNSUpstreamTypeNotValid = errors.New("DNS upstream type is not valid")
|
||||
ErrDNSUpdatePeriodTooShort = errors.New("update period is too short")
|
||||
ErrDNSUpstreamPlainNoIPv6 = errors.New("upstream plain addresses do not contain any IPv6 address")
|
||||
ErrDNSUpstreamPlainNoIPv4 = errors.New("upstream plain addresses do not contain any IPv4 address")
|
||||
)
|
||||
|
||||
func (d DNS) validate() (err error) {
|
||||
if !helpers.IsOneOf(d.UpstreamType, DNSUpstreamTypeDot, DNSUpstreamTypeDoh, DNSUpstreamTypePlain) {
|
||||
return fmt.Errorf("%w: %s", ErrDNSUpstreamTypeNotValid, d.UpstreamType)
|
||||
return fmt.Errorf("DNS upstream type is not valid: %s", d.UpstreamType)
|
||||
}
|
||||
|
||||
const minUpdatePeriod = 30 * time.Second
|
||||
if *d.UpdatePeriod != 0 && *d.UpdatePeriod < minUpdatePeriod {
|
||||
return fmt.Errorf("%w: %s must be bigger than %s",
|
||||
ErrDNSUpdatePeriodTooShort, *d.UpdatePeriod, minUpdatePeriod)
|
||||
return fmt.Errorf("update period is too short: %s must be bigger than %s",
|
||||
*d.UpdatePeriod, minUpdatePeriod)
|
||||
}
|
||||
|
||||
if d.UpstreamType == DNSUpstreamTypePlain {
|
||||
@@ -81,9 +73,11 @@ func (d DNS) validate() (err error) {
|
||||
}
|
||||
switch {
|
||||
case *d.IPv6 && !selectedHasPlainIPv6:
|
||||
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv6, len(d.UpstreamPlainAddresses))
|
||||
return fmt.Errorf("upstream plain addresses do not contain any IPv6 address: "+
|
||||
"in %d addresses", len(d.UpstreamPlainAddresses))
|
||||
case !*d.IPv6 && !selectedHasPlainIPv4:
|
||||
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv4, len(d.UpstreamPlainAddresses))
|
||||
return fmt.Errorf("upstream plain addresses do not contain any IPv4 address: "+
|
||||
"in %d addresses", len(d.UpstreamPlainAddresses))
|
||||
}
|
||||
}
|
||||
// Note: all DNS built in providers have both IPv4 and IPv6 addresses for all modes
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
@@ -37,22 +36,16 @@ func (b *DNSBlacklist) setDefaults() {
|
||||
|
||||
var hostRegex = regexp.MustCompile(`^([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9_])(\.([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9]))*$`) //nolint:lll
|
||||
|
||||
var (
|
||||
ErrAllowedHostNotValid = errors.New("allowed host is not valid")
|
||||
ErrBlockedHostNotValid = errors.New("blocked host is not valid")
|
||||
ErrRebindingProtectionExemptHostNotValid = errors.New("rebinding protection exempt host is not valid")
|
||||
)
|
||||
|
||||
func (b DNSBlacklist) validate() (err error) {
|
||||
for _, host := range b.AllowedHosts {
|
||||
if !hostRegex.MatchString(host) {
|
||||
return fmt.Errorf("%w: %s", ErrAllowedHostNotValid, host)
|
||||
return fmt.Errorf("allowed host is not valid: %s", host)
|
||||
}
|
||||
}
|
||||
|
||||
for _, host := range b.AddBlockedHosts {
|
||||
if !hostRegex.MatchString(host) {
|
||||
return fmt.Errorf("%w: %s", ErrBlockedHostNotValid, host)
|
||||
return fmt.Errorf("blocked host is not valid: %s", host)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +54,7 @@ func (b DNSBlacklist) validate() (err error) {
|
||||
host = host[2:]
|
||||
}
|
||||
if !hostRegex.MatchString(host) {
|
||||
return fmt.Errorf("%w: %s", ErrRebindingProtectionExemptHostNotValid, host)
|
||||
return fmt.Errorf("rebinding protection exempt host is not valid: %s", host)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,8 +202,6 @@ func readDNSBlockedIPs(r *reader.Reader) (ips []netip.Addr,
|
||||
return ips, ipPrefixes, nil
|
||||
}
|
||||
|
||||
var ErrPrivateAddressNotValid = errors.New("private address is not a valid IP or CIDR range")
|
||||
|
||||
func readDNSPrivateAddresses(r *reader.Reader) (ips []netip.Addr,
|
||||
ipPrefixes []netip.Prefix, err error,
|
||||
) {
|
||||
@@ -236,8 +227,9 @@ func readDNSPrivateAddresses(r *reader.Reader) (ips []netip.Addr,
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf(
|
||||
"environment variable DOT_PRIVATE_ADDRESS: %w: %s",
|
||||
ErrPrivateAddressNotValid, privateAddress)
|
||||
"environment variable DOT_PRIVATE_ADDRESS: "+
|
||||
"private address is not a valid IP or CIDR range: %s",
|
||||
privateAddress)
|
||||
}
|
||||
|
||||
return ips, ipPrefixes, nil
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
package settings
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrValueUnknown = errors.New("value is unknown")
|
||||
ErrCityNotValid = errors.New("the city specified is not valid")
|
||||
ErrControlServerPrivilegedPort = errors.New("cannot use privileged port without running as root")
|
||||
ErrCategoryNotValid = errors.New("the category specified is not valid")
|
||||
ErrCountryNotValid = errors.New("the country specified is not valid")
|
||||
ErrFilepathMissing = errors.New("filepath is missing")
|
||||
ErrFirewallZeroPort = errors.New("cannot have a zero port")
|
||||
ErrFirewallPublicOutboundSubnet = errors.New("outbound subnet has an unspecified address")
|
||||
ErrHostnameNotValid = errors.New("the hostname specified is not valid")
|
||||
ErrISPNotValid = errors.New("the ISP specified is not valid")
|
||||
ErrMinRatioNotValid = errors.New("minimum ratio is not valid")
|
||||
ErrMissingValue = errors.New("missing value")
|
||||
ErrNameNotValid = errors.New("the server name specified is not valid")
|
||||
ErrOpenVPNClientKeyMissing = errors.New("client key is missing")
|
||||
ErrOpenVPNCustomPortNotAllowed = errors.New("custom endpoint port is not allowed")
|
||||
ErrOpenVPNEncryptionPresetNotValid = errors.New("PIA encryption preset is not valid")
|
||||
ErrOpenVPNInterfaceNotValid = errors.New("interface name is not valid")
|
||||
ErrOpenVPNKeyPassphraseIsEmpty = errors.New("key passphrase is empty")
|
||||
ErrOpenVPNMSSFixIsTooHigh = errors.New("mssfix option value is too high")
|
||||
ErrOpenVPNPasswordIsEmpty = errors.New("password is empty")
|
||||
ErrOpenVPNTCPNotSupported = errors.New("TCP protocol is not supported")
|
||||
ErrOpenVPNUserIsEmpty = errors.New("user is empty")
|
||||
ErrOpenVPNVerbosityIsOutOfBounds = errors.New("verbosity value is out of bounds")
|
||||
ErrOpenVPNVersionIsNotValid = errors.New("version is not valid")
|
||||
ErrPortForwardingEnabled = errors.New("port forwarding cannot be enabled")
|
||||
ErrPortForwardingUserEmpty = errors.New("port forwarding username is empty")
|
||||
ErrPortForwardingPasswordEmpty = errors.New("port forwarding password is empty")
|
||||
ErrRegionNotValid = errors.New("the region specified is not valid")
|
||||
ErrServerAddressNotValid = errors.New("server listening address is not valid")
|
||||
ErrSystemPGIDNotValid = errors.New("process group id is not valid")
|
||||
ErrSystemPUIDNotValid = errors.New("process user id is not valid")
|
||||
ErrSystemTimezoneNotValid = errors.New("timezone is not valid")
|
||||
ErrUpdaterPeriodTooSmall = errors.New("VPN server data updater period is too small")
|
||||
ErrUpdaterProtonPasswordMissing = errors.New("proton password is missing")
|
||||
ErrUpdaterProtonEmailMissing = errors.New("proton email is missing")
|
||||
ErrVPNProviderNameNotValid = errors.New("VPN provider name is not valid")
|
||||
ErrVPNTypeNotValid = errors.New("VPN type is not valid")
|
||||
ErrWireguardAllowedIPNotSet = errors.New("allowed IP is not set")
|
||||
ErrWireguardAllowedIPsNotSet = errors.New("allowed IPs is not set")
|
||||
ErrWireguardEndpointIPNotSet = errors.New("endpoint IP is not set")
|
||||
ErrWireguardEndpointPortNotAllowed = errors.New("endpoint port is not allowed")
|
||||
ErrWireguardEndpointPortNotSet = errors.New("endpoint port is not set")
|
||||
ErrWireguardEndpointPortSet = errors.New("endpoint port is set")
|
||||
ErrWireguardInterfaceAddressNotSet = errors.New("interface address is not set")
|
||||
ErrWireguardInterfaceAddressIPv6 = errors.New("interface address is IPv6 but IPv6 is not supported")
|
||||
ErrWireguardInterfaceNotValid = errors.New("interface name is not valid")
|
||||
ErrWireguardPreSharedKeyNotSet = errors.New("pre-shared key is not set")
|
||||
ErrWireguardPrivateKeyNotSet = errors.New("private key is not set")
|
||||
ErrWireguardPublicKeyNotSet = errors.New("public key is not set")
|
||||
ErrWireguardPublicKeyNotValid = errors.New("public key is not valid")
|
||||
ErrWireguardKeepAliveNegative = errors.New("persistent keep alive interval is negative")
|
||||
ErrWireguardImplementationNotValid = errors.New("implementation is not valid")
|
||||
)
|
||||
@@ -1,6 +1,7 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
@@ -20,16 +21,16 @@ type Firewall struct {
|
||||
|
||||
func (f Firewall) validate() (err error) {
|
||||
if hasZeroPort(f.VPNInputPorts) {
|
||||
return fmt.Errorf("VPN input ports: %w", ErrFirewallZeroPort)
|
||||
return errors.New("VPN input ports: cannot have a zero port")
|
||||
}
|
||||
|
||||
if hasZeroPort(f.InputPorts) {
|
||||
return fmt.Errorf("input ports: %w", ErrFirewallZeroPort)
|
||||
return errors.New("input ports: cannot have a zero port")
|
||||
}
|
||||
|
||||
for _, subnet := range f.OutboundSubnets {
|
||||
if subnet.Addr().IsUnspecified() {
|
||||
return fmt.Errorf("%w: %s", ErrFirewallPublicOutboundSubnet, subnet)
|
||||
return fmt.Errorf("outbound subnet has an unspecified address: %s", subnet)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,25 +13,21 @@ func Test_Firewall_validate(t *testing.T) {
|
||||
|
||||
testCases := map[string]struct {
|
||||
firewall Firewall
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"empty": {
|
||||
errWrapped: log.ErrLevelNotRecognized,
|
||||
errMessage: "iptables settings: log level: level is not recognized: ",
|
||||
},
|
||||
"zero_vpn_input_port": {
|
||||
firewall: Firewall{
|
||||
VPNInputPorts: []uint16{0},
|
||||
},
|
||||
errWrapped: ErrFirewallZeroPort,
|
||||
errMessage: "VPN input ports: cannot have a zero port",
|
||||
},
|
||||
"zero_input_port": {
|
||||
firewall: Firewall{
|
||||
InputPorts: []uint16{0},
|
||||
},
|
||||
errWrapped: ErrFirewallZeroPort,
|
||||
errMessage: "input ports: cannot have a zero port",
|
||||
},
|
||||
"unspecified_outbound_subnet": {
|
||||
@@ -40,7 +36,6 @@ func Test_Firewall_validate(t *testing.T) {
|
||||
netip.MustParsePrefix("0.0.0.0/0"),
|
||||
},
|
||||
},
|
||||
errWrapped: ErrFirewallPublicOutboundSubnet,
|
||||
errMessage: "outbound subnet has an unspecified address: 0.0.0.0/0",
|
||||
},
|
||||
"public_outbound_subnet": {
|
||||
@@ -70,9 +65,10 @@ func Test_Firewall_validate(t *testing.T) {
|
||||
|
||||
err := testCase.firewall.validate()
|
||||
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -38,12 +38,6 @@ type Health struct {
|
||||
RestartVPN *bool
|
||||
}
|
||||
|
||||
var (
|
||||
ErrICMPTargetIPNotValid = errors.New("ICMP target IP address is not valid")
|
||||
ErrICMPTargetIPsNotCompatible = errors.New("ICMP target IP addresses are not compatible")
|
||||
ErrSmallCheckTypeNotValid = errors.New("small check type is not valid")
|
||||
)
|
||||
|
||||
func (h Health) Validate() (err error) {
|
||||
err = validate.ListeningAddress(h.ServerAddress, os.Getuid())
|
||||
if err != nil {
|
||||
@@ -53,16 +47,16 @@ func (h Health) Validate() (err error) {
|
||||
for _, ip := range h.ICMPTargetIPs {
|
||||
switch {
|
||||
case !ip.IsValid():
|
||||
return fmt.Errorf("%w: %s", ErrICMPTargetIPNotValid, ip)
|
||||
return fmt.Errorf("ICMP target IP address is not valid: %s", ip)
|
||||
case ip.IsUnspecified() && len(h.ICMPTargetIPs) > 1:
|
||||
return fmt.Errorf("%w: only a single IP address must be set if it is to be unspecified",
|
||||
ErrICMPTargetIPsNotCompatible)
|
||||
return errors.New("ICMP target IP addresses are not compatible: " +
|
||||
"only a single IP address must be set if it is to be unspecified")
|
||||
}
|
||||
}
|
||||
|
||||
err = validate.IsOneOf(h.SmallCheckType, "icmp", "dns")
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrSmallCheckTypeNotValid, err)
|
||||
return fmt.Errorf("small check type is not valid: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -48,7 +48,7 @@ func (h HTTPProxy) validate() (err error) {
|
||||
// Do not validate user and password
|
||||
err = validate.ListeningAddress(h.ListeningAddress, os.Getuid())
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrServerAddressNotValid, h.ListeningAddress)
|
||||
return fmt.Errorf("server listening address is not valid: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -176,7 +176,6 @@ func readHTTProxyLog(r *reader.Reader) (enabled *bool, err error) {
|
||||
case "disabled", "no", "off":
|
||||
return ptrTo(false), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("HTTP retro-compatible proxy log setting: %w: %s",
|
||||
ErrValueUnknown, value)
|
||||
return nil, fmt.Errorf("HTTP retro-compatible proxy log setting: value is unknown: %s", value)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package settings
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
@@ -92,7 +93,7 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
|
||||
// Validate version
|
||||
validVersions := []string{openvpn.Openvpn25, openvpn.Openvpn26}
|
||||
if err = validate.IsOneOf(o.Version, validVersions...); err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrOpenVPNVersionIsNotValid, err)
|
||||
return fmt.Errorf("version is not valid: %w", err)
|
||||
}
|
||||
|
||||
isCustom := vpnProvider == providers.Custom
|
||||
@@ -101,14 +102,14 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
|
||||
vpnProvider != providers.VPNSecure
|
||||
|
||||
if isUserRequired && *o.User == "" {
|
||||
return fmt.Errorf("%w", ErrOpenVPNUserIsEmpty)
|
||||
return errors.New("user is empty")
|
||||
}
|
||||
|
||||
passwordRequired := isUserRequired &&
|
||||
(vpnProvider != providers.Ivpn || !ivpnAccountID.MatchString(*o.User))
|
||||
|
||||
if passwordRequired && *o.Password == "" {
|
||||
return fmt.Errorf("%w", ErrOpenVPNPasswordIsEmpty)
|
||||
return errors.New("password is empty")
|
||||
}
|
||||
|
||||
err = validateOpenVPNConfigFilepath(isCustom, *o.ConfFile)
|
||||
@@ -132,23 +133,20 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
|
||||
}
|
||||
|
||||
if *o.EncryptedKey != "" && *o.KeyPassphrase == "" {
|
||||
return fmt.Errorf("%w", ErrOpenVPNKeyPassphraseIsEmpty)
|
||||
return errors.New("key passphrase is empty")
|
||||
}
|
||||
|
||||
const maxMSSFix = 10000
|
||||
if *o.MSSFix > maxMSSFix {
|
||||
return fmt.Errorf("%w: %d is over the maximum value of %d",
|
||||
ErrOpenVPNMSSFixIsTooHigh, *o.MSSFix, maxMSSFix)
|
||||
return fmt.Errorf("mssfix option value is too high: %d is over the maximum value of %d", *o.MSSFix, maxMSSFix)
|
||||
}
|
||||
|
||||
if !regexpInterfaceName.MatchString(o.Interface) {
|
||||
return fmt.Errorf("%w: '%s' does not match regex '%s'",
|
||||
ErrOpenVPNInterfaceNotValid, o.Interface, regexpInterfaceName)
|
||||
return fmt.Errorf("interface name is not valid: '%s' does not match regex '%s'", o.Interface, regexpInterfaceName)
|
||||
}
|
||||
|
||||
if *o.Verbosity < 0 || *o.Verbosity > 6 {
|
||||
return fmt.Errorf("%w: %d can only be between 0 and 5",
|
||||
ErrOpenVPNVerbosityIsOutOfBounds, o.Verbosity)
|
||||
return fmt.Errorf("verbosity value is out of bounds: %d can only be between 0 and 5", o.Verbosity)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -162,7 +160,7 @@ func validateOpenVPNConfigFilepath(isCustom bool,
|
||||
}
|
||||
|
||||
if confFile == "" {
|
||||
return fmt.Errorf("%w", ErrFilepathMissing)
|
||||
return errors.New("filepath is missing")
|
||||
}
|
||||
|
||||
err = validate.FileExists(confFile)
|
||||
@@ -189,7 +187,7 @@ func validateOpenVPNClientCertificate(vpnProvider,
|
||||
providers.VPNSecure,
|
||||
providers.VPNUnlimited:
|
||||
if clientCert == "" {
|
||||
return fmt.Errorf("%w", ErrMissingValue)
|
||||
return errors.New("missing value")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -211,7 +209,7 @@ func validateOpenVPNClientKey(vpnProvider, clientKey string) (err error) {
|
||||
providers.Cyberghost,
|
||||
providers.VPNUnlimited:
|
||||
if clientKey == "" {
|
||||
return fmt.Errorf("%w", ErrMissingValue)
|
||||
return errors.New("missing value")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -230,7 +228,7 @@ func validateOpenVPNEncryptedKey(vpnProvider,
|
||||
encryptedPrivateKey string,
|
||||
) (err error) {
|
||||
if vpnProvider == providers.VPNSecure && encryptedPrivateKey == "" {
|
||||
return fmt.Errorf("%w", ErrMissingValue)
|
||||
return errors.New("missing value")
|
||||
}
|
||||
|
||||
if encryptedPrivateKey == "" {
|
||||
|
||||
@@ -62,8 +62,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
|
||||
providers.Perfectprivacy,
|
||||
providers.Vyprvpn,
|
||||
) {
|
||||
return fmt.Errorf("%w: for VPN service provider %s",
|
||||
ErrOpenVPNTCPNotSupported, vpnProvider)
|
||||
return fmt.Errorf("TCP protocol is not supported: for VPN service provider %s", vpnProvider)
|
||||
}
|
||||
|
||||
// Validate CustomPort
|
||||
@@ -78,8 +77,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
|
||||
providers.Nordvpn, providers.Purevpn,
|
||||
providers.Surfshark, providers.VPNSecure,
|
||||
providers.VPNUnlimited, providers.Vyprvpn:
|
||||
return fmt.Errorf("%w: for VPN service provider %s",
|
||||
ErrOpenVPNCustomPortNotAllowed, vpnProvider)
|
||||
return fmt.Errorf("custom endpoint port is not allowed: for VPN service provider %s", vpnProvider)
|
||||
default:
|
||||
var allowedTCP, allowedUDP []uint16
|
||||
switch vpnProvider {
|
||||
@@ -123,8 +121,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
|
||||
}
|
||||
err = validate.IsOneOf(*o.CustomPort, allowedPorts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: for VPN service provider %s: %w",
|
||||
ErrOpenVPNCustomPortNotAllowed, vpnProvider, err)
|
||||
return fmt.Errorf("custom endpoint port is not allowed: for VPN service provider %s: %w", vpnProvider, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -136,7 +133,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
|
||||
presets.Strong,
|
||||
}
|
||||
if err = validate.IsOneOf(*o.PIAEncPreset, validEncryptionPresets...); err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrOpenVPNEncryptionPresetNotValid, err)
|
||||
return fmt.Errorf("PIA encryption preset is not valid: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
@@ -24,21 +23,16 @@ type PMTUD struct {
|
||||
TCPAddresses []netip.AddrPort `json:"tcp_addresses"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrPMTUDICMPAddressNotValid = errors.New("PMTUD ICMP address is not valid")
|
||||
ErrPMTUDTCPAddressNotValid = errors.New("PMTUD TCP address is not valid")
|
||||
)
|
||||
|
||||
// Validate validates PMTUD settings.
|
||||
func (p PMTUD) validate() (err error) {
|
||||
for i, addr := range p.ICMPAddresses {
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("%w: at index %d", ErrPMTUDICMPAddressNotValid, i)
|
||||
return fmt.Errorf("PMTUD ICMP address is not valid: at index %d", i)
|
||||
}
|
||||
}
|
||||
for i, addr := range p.TCPAddresses {
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("%w: at index %d", ErrPMTUDTCPAddressNotValid, i)
|
||||
return fmt.Errorf("PMTUD TCP address is not valid: at index %d", i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -55,12 +55,6 @@ type PortForwarding struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrPortsCountTooHigh = errors.New("ports count too high")
|
||||
ErrListeningPortsLen = errors.New("listening ports length must be equal to ports count")
|
||||
ErrListeningPortZero = errors.New("listening port cannot be 0")
|
||||
)
|
||||
|
||||
func (p PortForwarding) Validate(vpnProvider string) (err error) {
|
||||
if !*p.Enabled {
|
||||
return nil
|
||||
@@ -78,7 +72,7 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) {
|
||||
providers.Protonvpn,
|
||||
}
|
||||
if err = validate.IsOneOf(providerSelected, validProviders...); err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrPortForwardingEnabled, err)
|
||||
return fmt.Errorf("port forwarding cannot be enabled: %w", err)
|
||||
}
|
||||
|
||||
// Validate Filepath
|
||||
@@ -94,30 +88,31 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) {
|
||||
const maxPortsCount = 1
|
||||
switch {
|
||||
case p.PortsCount > maxPortsCount:
|
||||
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount)
|
||||
return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
|
||||
case p.Username == "":
|
||||
return fmt.Errorf("%w", ErrPortForwardingUserEmpty)
|
||||
return errors.New("port forwarding username is empty")
|
||||
case p.Password == "":
|
||||
return fmt.Errorf("%w", ErrPortForwardingPasswordEmpty)
|
||||
return errors.New("port forwarding password is empty")
|
||||
}
|
||||
case providers.Protonvpn:
|
||||
const maxPortsCount = 4
|
||||
if p.PortsCount > maxPortsCount {
|
||||
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount)
|
||||
return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
|
||||
}
|
||||
default:
|
||||
const maxPortsCount = 1
|
||||
if p.PortsCount > maxPortsCount {
|
||||
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount)
|
||||
return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
|
||||
}
|
||||
}
|
||||
|
||||
if !slices.Equal(p.ListeningPorts, []uint16{0}) {
|
||||
switch {
|
||||
case len(p.ListeningPorts) != int(p.PortsCount):
|
||||
return fmt.Errorf("%w: %d != %d", ErrListeningPortsLen, len(p.ListeningPorts), p.PortsCount)
|
||||
return fmt.Errorf("listening ports length must be equal to ports count: "+
|
||||
"%d != %d", len(p.ListeningPorts), p.PortsCount)
|
||||
case slices.Contains(p.ListeningPorts, 0):
|
||||
return fmt.Errorf("%w: in %v", ErrListeningPortZero, p.ListeningPorts)
|
||||
return fmt.Errorf("listening port cannot be 0: in %v", p.ListeningPorts)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
|
||||
}
|
||||
}
|
||||
if err = validate.IsOneOf(p.Name, validNames...); err != nil {
|
||||
return fmt.Errorf("%w for %s: %w", ErrVPNProviderNameNotValid, vpnType, err)
|
||||
return fmt.Errorf("VPN provider name is not valid for %s: %w", vpnType, err)
|
||||
}
|
||||
|
||||
err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner)
|
||||
|
||||
@@ -15,7 +15,6 @@ func Test_PublicIP_read(t *testing.T) {
|
||||
makeReader func(ctrl *gomock.Controller) *reader.Reader
|
||||
makeWarner func(ctrl *gomock.Controller) Warner
|
||||
settings PublicIP
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"nothing_read": {
|
||||
@@ -152,9 +151,10 @@ func Test_PublicIP_read(t *testing.T) {
|
||||
err := settings.read(reader, warner)
|
||||
|
||||
assert.Equal(t, testCase.settings, settings)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -46,8 +46,7 @@ func (c ControlServer) validate() (err error) {
|
||||
uid := os.Getuid()
|
||||
const maxPrivilegedPort = 1023
|
||||
if uid != 0 && port != 0 && port <= maxPrivilegedPort {
|
||||
return fmt.Errorf("%w: %d when running with user ID %d",
|
||||
ErrControlServerPrivilegedPort, port, uid)
|
||||
return fmt.Errorf("cannot use privileged port without running as root: %d when running with user ID %d", port, uid)
|
||||
}
|
||||
|
||||
jsonDecoder := json.NewDecoder(bytes.NewBufferString(c.AuthDefaultRole))
|
||||
|
||||
@@ -71,25 +71,13 @@ type ServerSelection struct {
|
||||
Wireguard WireguardSelection `json:"wireguard"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrOwnedOnlyNotSupported = errors.New("owned only filter is not supported")
|
||||
ErrFreeOnlyNotSupported = errors.New("free only filter is not supported")
|
||||
ErrPremiumOnlyNotSupported = errors.New("premium only filter is not supported")
|
||||
ErrStreamOnlyNotSupported = errors.New("stream only filter is not supported")
|
||||
ErrMultiHopOnlyNotSupported = errors.New("multi hop only filter is not supported")
|
||||
ErrPortForwardOnlyNotSupported = errors.New("port forwarding only filter is not supported")
|
||||
ErrFreePremiumBothSet = errors.New("free only and premium only filters are both set")
|
||||
ErrSecureCoreOnlyNotSupported = errors.New("secure core only filter is not supported")
|
||||
ErrTorOnlyNotSupported = errors.New("tor only filter is not supported")
|
||||
)
|
||||
|
||||
func (ss *ServerSelection) validate(vpnServiceProvider string,
|
||||
filterChoicesGetter FilterChoicesGetter, warner Warner,
|
||||
) (err error) {
|
||||
switch ss.VPN {
|
||||
case vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard:
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN)
|
||||
return fmt.Errorf("VPN type is not valid: %s", ss.VPN)
|
||||
}
|
||||
|
||||
filterChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, filterChoicesGetter, warner)
|
||||
@@ -150,7 +138,7 @@ func getLocationFilterChoices(vpnServiceProvider string,
|
||||
// Only return error comparing with newer regions, we don't want to confuse the user
|
||||
// with the retro regions in the error message.
|
||||
err = atLeastOneIsOneOfCaseInsensitive(ss.Regions, filterChoices.Regions, warner)
|
||||
return models.FilterChoices{}, fmt.Errorf("%w: %w", ErrRegionNotValid, err)
|
||||
return models.FilterChoices{}, fmt.Errorf("the region specified is not valid: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -164,27 +152,27 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
|
||||
) (err error) {
|
||||
err = atLeastOneIsOneOfCaseInsensitive(settings.Countries, filterChoices.Countries, warner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrCountryNotValid, err)
|
||||
return fmt.Errorf("the country specified is not valid: %w", err)
|
||||
}
|
||||
|
||||
err = atLeastOneIsOneOfCaseInsensitive(settings.Regions, filterChoices.Regions, warner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrRegionNotValid, err)
|
||||
return fmt.Errorf("the region specified is not valid: %w", err)
|
||||
}
|
||||
|
||||
err = atLeastOneIsOneOfCaseInsensitive(settings.Cities, filterChoices.Cities, warner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrCityNotValid, err)
|
||||
return fmt.Errorf("the city specified is not valid: %w", err)
|
||||
}
|
||||
|
||||
err = atLeastOneIsOneOfCaseInsensitive(settings.ISPs, filterChoices.ISPs, warner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrISPNotValid, err)
|
||||
return fmt.Errorf("the ISP specified is not valid: %w", err)
|
||||
}
|
||||
|
||||
err = atLeastOneIsOneOfCaseInsensitive(settings.Hostnames, filterChoices.Hostnames, warner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrHostnameNotValid, err)
|
||||
return fmt.Errorf("the hostname specified is not valid: %w", err)
|
||||
}
|
||||
|
||||
if vpnServiceProvider == providers.Custom {
|
||||
@@ -196,19 +184,19 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
|
||||
// which requires a server name for TLS verification.
|
||||
filterChoices.Names = settings.Names
|
||||
default:
|
||||
return fmt.Errorf("%w: %d names specified instead of "+
|
||||
"0 or 1 for the custom provider",
|
||||
ErrNameNotValid, len(settings.Names))
|
||||
return fmt.Errorf("name is not valid: "+
|
||||
"%d names specified instead of 0 or 1 for the custom provider",
|
||||
len(settings.Names))
|
||||
}
|
||||
}
|
||||
err = atLeastOneIsOneOfCaseInsensitive(settings.Names, filterChoices.Names, warner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrNameNotValid, err)
|
||||
return fmt.Errorf("the server name specified is not valid: %w", err)
|
||||
}
|
||||
|
||||
err = atLeastOneIsOneOfCaseInsensitive(settings.Categories, filterChoices.Categories, warner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrCategoryNotValid, err)
|
||||
return fmt.Errorf("the category specified is not valid: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -255,12 +243,12 @@ func validateSubscriptionTierFilters(settings ServerSelection, vpnServiceProvide
|
||||
switch {
|
||||
case *settings.FreeOnly &&
|
||||
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
|
||||
return fmt.Errorf("%w", ErrFreeOnlyNotSupported)
|
||||
return errors.New("free only filter is not supported")
|
||||
case *settings.PremiumOnly &&
|
||||
!helpers.IsOneOf(vpnServiceProvider, providers.VPNSecure):
|
||||
return fmt.Errorf("%w", ErrPremiumOnlyNotSupported)
|
||||
return errors.New("premium only filter is not supported")
|
||||
case *settings.FreeOnly && *settings.PremiumOnly:
|
||||
return fmt.Errorf("%w", ErrFreePremiumBothSet)
|
||||
return errors.New("free only and premium only filters are both set")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
@@ -269,21 +257,21 @@ func validateSubscriptionTierFilters(settings ServerSelection, vpnServiceProvide
|
||||
func validateFeatureFilters(settings ServerSelection, vpnServiceProvider string) error {
|
||||
switch {
|
||||
case *settings.OwnedOnly && vpnServiceProvider != providers.Mullvad:
|
||||
return fmt.Errorf("%w", ErrOwnedOnlyNotSupported)
|
||||
return errors.New("owned only filter is not supported")
|
||||
case vpnServiceProvider == providers.Protonvpn && *settings.FreeOnly && *settings.PortForwardOnly:
|
||||
return fmt.Errorf("%w: together with free only filter", ErrPortForwardOnlyNotSupported)
|
||||
return errors.New("port forwarding only filter is not supported: together with free only filter")
|
||||
case *settings.StreamOnly &&
|
||||
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
|
||||
return fmt.Errorf("%w", ErrStreamOnlyNotSupported)
|
||||
return errors.New("stream only filter is not supported")
|
||||
case *settings.MultiHopOnly && vpnServiceProvider != providers.Surfshark:
|
||||
return fmt.Errorf("%w", ErrMultiHopOnlyNotSupported)
|
||||
return errors.New("multi hop only filter is not supported")
|
||||
case *settings.PortForwardOnly &&
|
||||
!helpers.IsOneOf(vpnServiceProvider, providers.PrivateInternetAccess, providers.Protonvpn):
|
||||
return fmt.Errorf("%w", ErrPortForwardOnlyNotSupported)
|
||||
return errors.New("port forwarding only filter is not supported")
|
||||
case *settings.SecureCoreOnly && vpnServiceProvider != providers.Protonvpn:
|
||||
return fmt.Errorf("%w", ErrSecureCoreOnlyNotSupported)
|
||||
return errors.New("secure core only filter is not supported")
|
||||
case *settings.TorOnly && vpnServiceProvider != providers.Protonvpn:
|
||||
return fmt.Errorf("%w", ErrTorOnlyNotSupported)
|
||||
return errors.New("tor only filter is not supported")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -37,20 +38,20 @@ type Updater struct {
|
||||
func (u Updater) Validate() (err error) {
|
||||
const minPeriod = time.Minute
|
||||
if *u.Period > 0 && *u.Period < minPeriod {
|
||||
return fmt.Errorf("%w: %d must be larger than %s",
|
||||
ErrUpdaterPeriodTooSmall, *u.Period, minPeriod)
|
||||
return fmt.Errorf("VPN server data updater period is too small: "+
|
||||
"%d must be larger than %s", *u.Period, minPeriod)
|
||||
}
|
||||
|
||||
if u.MinRatio <= 0 || u.MinRatio > 1 {
|
||||
return fmt.Errorf("%w: %.2f must be between 0+ and 1",
|
||||
ErrMinRatioNotValid, u.MinRatio)
|
||||
return fmt.Errorf("minimum ratio is not valid: "+
|
||||
"%.2f must be between 0+ and 1", u.MinRatio)
|
||||
}
|
||||
|
||||
validProviders := providers.All()
|
||||
for _, provider := range u.Providers {
|
||||
err = validate.IsOneOf(provider, validProviders...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrVPNProviderNameNotValid, err)
|
||||
return fmt.Errorf("VPN provider name is not valid: %w", err)
|
||||
}
|
||||
|
||||
if provider == providers.Protonvpn {
|
||||
@@ -58,9 +59,9 @@ func (u Updater) Validate() (err error) {
|
||||
if authenticatedAPI {
|
||||
switch {
|
||||
case *u.ProtonEmail == "":
|
||||
return fmt.Errorf("%w", ErrUpdaterProtonEmailMissing)
|
||||
return errors.New("proton email is missing")
|
||||
case *u.ProtonPassword == "":
|
||||
return fmt.Errorf("%w", ErrUpdaterProtonPasswordMissing)
|
||||
return errors.New("proton password is missing")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo
|
||||
// Validate Type
|
||||
validVPNTypes := []string{vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard}
|
||||
if err = validate.IsOneOf(v.Type, validVPNTypes...); err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrVPNTypeNotValid, err)
|
||||
return fmt.Errorf("VPN type is not valid: %w", err)
|
||||
}
|
||||
|
||||
err = v.Provider.validate(v.Type, filterChoicesGetter, warner)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
@@ -54,7 +55,7 @@ var regexpInterfaceName = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
|
||||
func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (err error) {
|
||||
// Validate PrivateKey
|
||||
if *w.PrivateKey == "" {
|
||||
return fmt.Errorf("%w", ErrWireguardPrivateKeyNotSet)
|
||||
return errors.New("private key is not set")
|
||||
}
|
||||
_, err = wgtypes.ParseKey(*w.PrivateKey)
|
||||
if err != nil {
|
||||
@@ -68,7 +69,7 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (
|
||||
|
||||
if vpnProvider == providers.Airvpn {
|
||||
if *w.PreSharedKey == "" {
|
||||
return fmt.Errorf("%w", ErrWireguardPreSharedKeyNotSet)
|
||||
return errors.New("pre-shared key is not set")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,17 +83,15 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (
|
||||
|
||||
// Validate Addresses
|
||||
if len(w.Addresses) == 0 {
|
||||
return fmt.Errorf("%w", ErrWireguardInterfaceAddressNotSet)
|
||||
return errors.New("interface address is not set")
|
||||
}
|
||||
for i, ipNet := range w.Addresses {
|
||||
if !ipNet.IsValid() {
|
||||
return fmt.Errorf("%w: for address at index %d",
|
||||
ErrWireguardInterfaceAddressNotSet, i)
|
||||
return fmt.Errorf("interface address is not set: for address at index %d", i)
|
||||
}
|
||||
|
||||
if !ipv6Supported && ipNet.Addr().Is6() {
|
||||
return fmt.Errorf("%w: address %s",
|
||||
ErrWireguardInterfaceAddressIPv6, ipNet.String())
|
||||
return fmt.Errorf("interface address is IPv6 but IPv6 is not supported: address %s", ipNet.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,30 +99,27 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (
|
||||
// WARNING: do not check for IPv6 networks in the allowed IPs,
|
||||
// the wireguard code will take care to ignore it.
|
||||
if len(w.AllowedIPs) == 0 {
|
||||
return fmt.Errorf("%w", ErrWireguardAllowedIPsNotSet)
|
||||
return errors.New("allowed IPs is not set")
|
||||
}
|
||||
for i, allowedIP := range w.AllowedIPs {
|
||||
if !allowedIP.IsValid() {
|
||||
return fmt.Errorf("%w: for allowed ip %d of %d",
|
||||
ErrWireguardAllowedIPNotSet, i+1, len(w.AllowedIPs))
|
||||
return fmt.Errorf("allowed IP is not set: for allowed ip %d of %d", i+1, len(w.AllowedIPs))
|
||||
}
|
||||
}
|
||||
|
||||
if *w.PersistentKeepaliveInterval < 0 {
|
||||
return fmt.Errorf("%w: %s", ErrWireguardKeepAliveNegative,
|
||||
*w.PersistentKeepaliveInterval)
|
||||
return fmt.Errorf("persistent keep alive interval is negative: %s", *w.PersistentKeepaliveInterval)
|
||||
}
|
||||
|
||||
// Validate interface
|
||||
if !regexpInterfaceName.MatchString(w.Interface) {
|
||||
return fmt.Errorf("%w: '%s' does not match regex '%s'",
|
||||
ErrWireguardInterfaceNotValid, w.Interface, regexpInterfaceName)
|
||||
return fmt.Errorf("interface name is not valid: '%s' does not match regex '%s'", w.Interface, regexpInterfaceName)
|
||||
}
|
||||
|
||||
if !amneziawg { // amneziawg should have its own Implementation field and ignore this one
|
||||
validImplementations := []string{"auto", "userspace", "kernelspace"}
|
||||
if err := validate.IsOneOf(w.Implementation, validImplementations...); err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrWireguardImplementationNotValid, err)
|
||||
return fmt.Errorf("implementation is not valid: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
@@ -44,7 +45,7 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
|
||||
// endpoint IP addresses are baked in
|
||||
case providers.Custom:
|
||||
if !w.EndpointIP.IsValid() || w.EndpointIP.IsUnspecified() {
|
||||
return fmt.Errorf("%w", ErrWireguardEndpointIPNotSet)
|
||||
return errors.New("endpoint IP is not set")
|
||||
}
|
||||
default: // Providers not supporting Wireguard
|
||||
}
|
||||
@@ -54,13 +55,13 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
|
||||
// EndpointPort is required
|
||||
case providers.Custom:
|
||||
if *w.EndpointPort == 0 {
|
||||
return fmt.Errorf("%w", ErrWireguardEndpointPortNotSet)
|
||||
return errors.New("endpoint port is not set")
|
||||
}
|
||||
// EndpointPort cannot be set
|
||||
case providers.Fastestvpn, providers.Nordvpn,
|
||||
providers.Protonvpn, providers.Surfshark:
|
||||
if *w.EndpointPort != 0 {
|
||||
return fmt.Errorf("%w", ErrWireguardEndpointPortSet)
|
||||
return errors.New("endpoint port is set")
|
||||
}
|
||||
case providers.Airvpn, providers.Ivpn, providers.Mullvad, providers.Windscribe:
|
||||
// EndpointPort is optional and can be 0
|
||||
@@ -84,8 +85,7 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
return fmt.Errorf("%w: for VPN service provider %s: %w",
|
||||
ErrWireguardEndpointPortNotAllowed, vpnProvider, err)
|
||||
return fmt.Errorf("endpoint port is not allowed: for VPN service provider %s: %w", vpnProvider, err)
|
||||
default: // Providers not supporting Wireguard
|
||||
}
|
||||
|
||||
@@ -96,15 +96,14 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
|
||||
// public keys are baked in
|
||||
case providers.Custom:
|
||||
if w.PublicKey == "" {
|
||||
return fmt.Errorf("%w", ErrWireguardPublicKeyNotSet)
|
||||
return errors.New("public key is not set")
|
||||
}
|
||||
default: // Providers not supporting Wireguard
|
||||
}
|
||||
if w.PublicKey != "" {
|
||||
_, err := wgtypes.ParseKey(w.PublicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s: %s",
|
||||
ErrWireguardPublicKeyNotValid, w.PublicKey, err)
|
||||
return fmt.Errorf("public key is not valid: %s: %s", w.PublicKey, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -74,8 +74,6 @@ func parseWireguardInterfaceSection(interfaceSection *ini.Section) (
|
||||
return privateKey, addresses
|
||||
}
|
||||
|
||||
var ErrEndpointHostNotIP = errors.New("endpoint host is not an IP")
|
||||
|
||||
func parseWireguardPeerSection(peerSection *ini.Section) (
|
||||
preSharedKey, publicKey, endpointIP, endpointPort *string,
|
||||
) {
|
||||
|
||||
@@ -3,7 +3,6 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
@@ -63,8 +62,6 @@ func generateRandomString(length uint) string {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
var errIPLeakSessionMismatch = errors.New("ipleak.net session mismatch")
|
||||
|
||||
func triggerDNSQuery(ctx context.Context, client *http.Client, session string) (
|
||||
dnsToCount map[string]uint, err error,
|
||||
) {
|
||||
@@ -93,7 +90,7 @@ func triggerDNSQuery(ctx context.Context, client *http.Client, session string) (
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decoding response: %w", err)
|
||||
} else if data.Session != session {
|
||||
return nil, fmt.Errorf("%w: expected %s, got %s", errIPLeakSessionMismatch, session, data.Session)
|
||||
return nil, fmt.Errorf("ipleak.net session mismatch: expected %s, got %s", session, data.Session)
|
||||
}
|
||||
|
||||
return data.IP, nil
|
||||
|
||||
@@ -57,18 +57,15 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const iptablesBinary = "/sbin/iptables"
|
||||
errTest := errors.New("test error")
|
||||
|
||||
testCases := map[string]struct {
|
||||
instruction string
|
||||
makeRunner func(ctrl *gomock.Controller) *MockCmdRunner
|
||||
makeLogger func(ctrl *gomock.Controller) *MockLogger
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"invalid_instruction": {
|
||||
instruction: "invalid",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing iptables command: parsing \"invalid\": " +
|
||||
"iptables command is malformed: flag \"invalid\" requires a value, but got none",
|
||||
},
|
||||
@@ -78,7 +75,7 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().
|
||||
Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||
Return("", errTest)
|
||||
Return("", errors.New("test error"))
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
@@ -86,7 +83,6 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||
return logger
|
||||
},
|
||||
errWrapped: errTest,
|
||||
errMessage: `finding iptables chain rule line number: command failed: ` +
|
||||
`"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`,
|
||||
},
|
||||
@@ -120,7 +116,7 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
|
||||
nil)
|
||||
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
|
||||
"^-D$", "^PREROUTING$", "^2$")).Return("details", errTest)
|
||||
"^-D$", "^PREROUTING$", "^2$")).Return("details", errors.New("test error"))
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
@@ -131,7 +127,6 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
|
||||
return logger
|
||||
},
|
||||
errWrapped: errTest,
|
||||
errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details",
|
||||
},
|
||||
"rule_found_delete_success": {
|
||||
@@ -177,9 +172,10 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
|
||||
err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger)
|
||||
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -82,13 +82,11 @@ func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrPolicyNotValid = errors.New("policy is not valid")
|
||||
|
||||
func (c *Config) SetIPv6AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
|
||||
return fmt.Errorf("policy is not valid: %s", policy)
|
||||
}
|
||||
return c.runIP6tablesInstructions(ctx, []string{
|
||||
"--policy INPUT " + policy,
|
||||
|
||||
@@ -2,7 +2,6 @@ package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
@@ -13,10 +12,8 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrIPTablesVersionTooShort = errors.New("iptables version string is too short")
|
||||
ErrPolicyUnknown = errors.New("unknown policy")
|
||||
ErrNeedIP6Tables = errors.New("ip6tables is required, please upgrade your kernel to support it")
|
||||
const (
|
||||
needIP6Tables = "ip6tables is required, please upgrade your kernel"
|
||||
)
|
||||
|
||||
func appendOrDelete(remove bool) string {
|
||||
@@ -36,7 +33,7 @@ func (c *Config) Version(ctx context.Context) (string, error) {
|
||||
words := strings.Fields(output)
|
||||
const minWords = 2
|
||||
if len(words) < minWords {
|
||||
return "", fmt.Errorf("%w: %s", ErrIPTablesVersionTooShort, output)
|
||||
return "", fmt.Errorf("iptables version string is too short: %s", output)
|
||||
}
|
||||
return "iptables " + words[1], nil
|
||||
}
|
||||
@@ -102,7 +99,7 @@ func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
|
||||
return fmt.Errorf("unknown policy: %s", policy)
|
||||
}
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
"--policy INPUT " + policy,
|
||||
@@ -129,7 +126,7 @@ func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, destinati
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
}
|
||||
if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept input to subnet %s: %w", destination, ErrNeedIP6Tables)
|
||||
return fmt.Errorf("accept input to subnet %s: %s", destination, needIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
@@ -157,7 +154,7 @@ func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context,
|
||||
if connection.IP.Is4() {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
} else if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
|
||||
return fmt.Errorf("accept output to VPN server %s: %s", connection.IP, needIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
@@ -175,7 +172,7 @@ func (c *Config) AcceptOutput(ctx context.Context,
|
||||
if ip.Is4() {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
} else if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
|
||||
return fmt.Errorf("accept output to VPN server %s: %s", ip, needIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
@@ -200,7 +197,7 @@ func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
|
||||
if doIPv4 {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
} else if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept output from %s to %s: %w", sourceIP, destinationSubnet, ErrNeedIP6Tables)
|
||||
return fmt.Errorf("accept output from %s to %s: %s", sourceIP, destinationSubnet, needIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
@@ -350,7 +347,7 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
|
||||
case ipv4:
|
||||
err = c.runIptablesInstructionNoSave(ctx, rule)
|
||||
case c.ip6Tables == "":
|
||||
err = fmt.Errorf("running user ip6tables rule: %w", ErrNeedIP6Tables)
|
||||
err = fmt.Errorf("running user ip6tables rule: %s", needIP6Tables)
|
||||
default: // ipv6
|
||||
err = c.runIP6tablesInstructionNoSave(ctx, rule)
|
||||
}
|
||||
|
||||
@@ -40,8 +40,6 @@ type mark struct {
|
||||
value uint
|
||||
}
|
||||
|
||||
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
|
||||
|
||||
func parseChain(iptablesOutput string) (c chain, err error) {
|
||||
// Text example:
|
||||
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
@@ -63,8 +61,8 @@ func parseChain(iptablesOutput string) (c chain, err error) {
|
||||
|
||||
const minLines = 2 // chain general information line + legend line
|
||||
if len(lines) < minLines {
|
||||
return chain{}, fmt.Errorf("%w: not enough lines to process in: %s",
|
||||
ErrChainListMalformed, iptablesOutput)
|
||||
return chain{}, fmt.Errorf("iptables chain list output is malformed: not enough lines to process in: %s",
|
||||
iptablesOutput)
|
||||
}
|
||||
|
||||
c, err = parseChainGeneralDataLine(lines[0])
|
||||
@@ -77,8 +75,8 @@ func parseChain(iptablesOutput string) (c chain, err error) {
|
||||
legendLine := strings.TrimSpace(lines[1])
|
||||
legendFields := strings.Fields(legendLine)
|
||||
if !slices.Equal(expectedLegendFields, legendFields) {
|
||||
return chain{}, fmt.Errorf("%w: legend %q is not the expected %q",
|
||||
ErrChainListMalformed, legendLine, strings.Join(expectedLegendFields, " "))
|
||||
return chain{}, fmt.Errorf("iptables chain list output is malformed: legend %q is not the expected %q",
|
||||
legendLine, strings.Join(expectedLegendFields, " "))
|
||||
}
|
||||
|
||||
lines = lines[2:] // remove chain general information line and legend line
|
||||
@@ -111,8 +109,8 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
|
||||
fields := strings.Fields(line)
|
||||
const expectedNumberOfFields = 8
|
||||
if len(fields) != expectedNumberOfFields {
|
||||
return chain{}, fmt.Errorf("%w: expected %d fields in %q",
|
||||
ErrChainListMalformed, expectedNumberOfFields, line)
|
||||
return chain{}, fmt.Errorf("iptables chain list output is malformed: expected %d fields in %q",
|
||||
expectedNumberOfFields, line)
|
||||
}
|
||||
|
||||
// Sanity checks
|
||||
@@ -126,8 +124,8 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
|
||||
if fields[index] == expectedValue {
|
||||
continue
|
||||
}
|
||||
return chain{}, fmt.Errorf("%w: expected %q for field %d in %q",
|
||||
ErrChainListMalformed, expectedValue, index, line)
|
||||
return chain{}, fmt.Errorf("iptables chain list output is malformed: expected %q for field %d in %q",
|
||||
expectedValue, index, line)
|
||||
}
|
||||
|
||||
base.name = fields[1] // chain name could be custom
|
||||
@@ -152,19 +150,17 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
|
||||
return base, nil
|
||||
}
|
||||
|
||||
var ErrChainRuleMalformed = errors.New("chain rule is malformed")
|
||||
|
||||
func parseChainRuleLine(line string) (rule chainRule, err error) {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
return chainRule{}, fmt.Errorf("%w: empty line", ErrChainRuleMalformed)
|
||||
return chainRule{}, errors.New("chain rule is malformed: empty line")
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
|
||||
const minFields = 10
|
||||
if len(fields) < minFields {
|
||||
return chainRule{}, fmt.Errorf("%w: not enough fields", ErrChainRuleMalformed)
|
||||
return chainRule{}, errors.New("chain rule is malformed: not enough fields")
|
||||
}
|
||||
|
||||
for fieldIndex, field := range fields[:minFields] {
|
||||
@@ -186,7 +182,7 @@ func parseChainRuleLine(line string) (rule chainRule, err error) {
|
||||
|
||||
func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) {
|
||||
if field == "" {
|
||||
return fmt.Errorf("%w: empty field at index %d", ErrChainRuleMalformed, fieldIndex)
|
||||
return fmt.Errorf("chain rule is malformed: empty field at index %d", fieldIndex)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -278,8 +274,8 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
|
||||
rule.redirPorts = ports
|
||||
i++
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected %q after redir",
|
||||
ErrChainRuleMalformed, optionalFields[1])
|
||||
return fmt.Errorf("chain rule is malformed: unexpected %q after redir",
|
||||
optionalFields[1])
|
||||
}
|
||||
case "ctstate":
|
||||
i++
|
||||
@@ -294,15 +290,13 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
|
||||
rule.mark = mark
|
||||
i += consumed
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected optional field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[i])
|
||||
return fmt.Errorf("chain rule is malformed: unexpected optional field: %s",
|
||||
optionalFields[i])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var errUDPOptionalUnknown = errors.New("unknown UDP optional field")
|
||||
|
||||
func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
||||
for _, value := range optionalFields {
|
||||
if !strings.ContainsRune(value, ':') {
|
||||
@@ -323,14 +317,12 @@ func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, e
|
||||
}
|
||||
consumed++
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
|
||||
return 0, fmt.Errorf("unknown UDP optional field: %s", value)
|
||||
}
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
var errTCPOptionalUnknown = errors.New("unknown TCP optional field")
|
||||
|
||||
func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
||||
for _, value := range optionalFields {
|
||||
if !strings.ContainsRune(value, ':') {
|
||||
@@ -357,7 +349,7 @@ func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, e
|
||||
}
|
||||
consumed++
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %s", errTCPOptionalUnknown, value)
|
||||
return 0, fmt.Errorf("unknown TCP optional field: %s", value)
|
||||
}
|
||||
}
|
||||
return consumed, nil
|
||||
@@ -373,15 +365,13 @@ func parseSourcePort(value string) (port uint16, err error) {
|
||||
return parsePort(value)
|
||||
}
|
||||
|
||||
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
|
||||
|
||||
func parseTCPFlags(value string) (tcpFlags, error) {
|
||||
value = strings.TrimPrefix(value, "flags:")
|
||||
fields := strings.Split(value, "/")
|
||||
const expectedFields = 2
|
||||
if len(fields) != expectedFields {
|
||||
return tcpFlags{}, fmt.Errorf("%w: expected format 'flags:<mask>/<comparison>' in %q",
|
||||
errTCPFlagsMalformed, value)
|
||||
return tcpFlags{}, fmt.Errorf("TCP flags are malformed: expected format 'flags:<mask>/<comparison>' in %q",
|
||||
value)
|
||||
}
|
||||
maskFlags := strings.Split(fields[0], ",")
|
||||
mask := make([]tcpFlag, len(maskFlags))
|
||||
@@ -422,8 +412,6 @@ func parsePortsCSV(s string) (ports []uint16, err error) {
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
var errMarkValueMalformed = errors.New("mark value is malformed")
|
||||
|
||||
func parseMark(optionalFields []string) (m mark, consumed int, err error) {
|
||||
switch optionalFields[consumed] {
|
||||
case "match":
|
||||
@@ -437,42 +425,36 @@ func parseMark(optionalFields []string) (m mark, consumed int, err error) {
|
||||
const bits = 32
|
||||
value, err := strconv.ParseUint(optionalFields[consumed], base, bits)
|
||||
if err != nil {
|
||||
return mark{}, 0, fmt.Errorf("%w: %s", errMarkValueMalformed, optionalFields[consumed])
|
||||
return mark{}, 0, fmt.Errorf("mark value is malformed: %s", optionalFields[consumed])
|
||||
}
|
||||
m.value = uint(value)
|
||||
consumed++
|
||||
default:
|
||||
return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[consumed])
|
||||
return mark{}, 0, fmt.Errorf("chain rule is malformed: unexpected mark mode field: %s",
|
||||
optionalFields[consumed])
|
||||
}
|
||||
return m, consumed, nil
|
||||
}
|
||||
|
||||
var ErrLineNumberIsZero = errors.New("line number is zero")
|
||||
|
||||
func parseLineNumber(s string) (n uint16, err error) {
|
||||
const base, bitLength = 10, 16
|
||||
lineNumber, err := strconv.ParseUint(s, base, bitLength)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else if lineNumber == 0 {
|
||||
return 0, fmt.Errorf("%w", ErrLineNumberIsZero)
|
||||
return 0, errors.New("line number is zero")
|
||||
}
|
||||
return uint16(lineNumber), nil
|
||||
}
|
||||
|
||||
var ErrTargetUnknown = errors.New("unknown target")
|
||||
|
||||
func checkTarget(target string) (err error) {
|
||||
switch target {
|
||||
case "ACCEPT", "DROP", "REJECT", "REDIRECT":
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("%w: %s", ErrTargetUnknown, target)
|
||||
return fmt.Errorf("unknown target: %s", target)
|
||||
}
|
||||
|
||||
var ErrProtocolUnknown = errors.New("unknown protocol")
|
||||
|
||||
func parseProtocol(s string) (protocol string, err error) {
|
||||
switch s {
|
||||
case "0", "all":
|
||||
@@ -483,18 +465,16 @@ func parseProtocol(s string) (protocol string, err error) {
|
||||
case "17", "udp":
|
||||
protocol = "udp"
|
||||
default:
|
||||
return "", fmt.Errorf("%w: %s", ErrProtocolUnknown, s)
|
||||
return "", fmt.Errorf("unknown protocol: %s", s)
|
||||
}
|
||||
return protocol, nil
|
||||
}
|
||||
|
||||
var ErrMetricSizeMalformed = errors.New("metric size is malformed")
|
||||
|
||||
// parseMetricSize parses a metric size string like 140K or 226M and
|
||||
// returns the raw integer matching it.
|
||||
func parseMetricSize(size string) (n uint64, err error) {
|
||||
if size == "" {
|
||||
return n, fmt.Errorf("%w: empty string", ErrMetricSizeMalformed)
|
||||
return n, errors.New("metric size is malformed: empty string")
|
||||
}
|
||||
|
||||
//nolint:mnd
|
||||
@@ -516,7 +496,7 @@ func parseMetricSize(size string) (n uint64, err error) {
|
||||
const base, bitLength = 10, 64
|
||||
n, err = strconv.ParseUint(size, base, bitLength)
|
||||
if err != nil {
|
||||
return n, fmt.Errorf("%w: %w", ErrMetricSizeMalformed, err)
|
||||
return n, fmt.Errorf("metric size is malformed: %w", err)
|
||||
}
|
||||
n *= multiplier
|
||||
return n, nil
|
||||
|
||||
@@ -13,30 +13,25 @@ func Test_parseChain(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
iptablesOutput string
|
||||
table chain
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no_output": {
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: not enough lines to process in: ",
|
||||
},
|
||||
"single_line_only": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: not enough lines to process in: " +
|
||||
"Chain INPUT (policy ACCEPT 140K packets, 226M bytes)",
|
||||
},
|
||||
"malformed_general_data_line": {
|
||||
iptablesOutput: `Chain INPUT
|
||||
num pkts bytes target prot opt in out source destination`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "parsing chain general data line: iptables chain list output is malformed: " +
|
||||
"expected 8 fields in \"Chain INPUT\"",
|
||||
},
|
||||
"malformed_legend": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
num pkts bytes target prot opt in out source`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: legend " +
|
||||
"\"num pkts bytes target prot opt in out source\" " +
|
||||
"is not the expected \"num pkts bytes target prot opt in out source destination\"",
|
||||
@@ -135,9 +130,10 @@ num pkts bytes target prot opt in out source destinati
|
||||
table, err := parseChain(testCase.iptablesOutput)
|
||||
|
||||
assert.Equal(t, testCase.table, table)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -80,11 +80,9 @@ func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
|
||||
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
|
||||
}
|
||||
|
||||
var ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
|
||||
|
||||
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
|
||||
if s == "" {
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
||||
return iptablesInstruction{}, errors.New("iptables command is malformed: empty instruction")
|
||||
}
|
||||
fields := strings.Fields(s)
|
||||
|
||||
@@ -173,7 +171,7 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
|
||||
return 0, fmt.Errorf("parsing TCP flags: %w", err)
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag)
|
||||
return 0, fmt.Errorf("iptables command is malformed: unknown key %q", flag)
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
@@ -185,15 +183,15 @@ func preCheckInstructionFields(fields []string) (consumed int, err error) {
|
||||
case "--tcp-flags": // -m can have 1 or 2 values
|
||||
const expected = 3
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
||||
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
|
||||
return 0, fmt.Errorf("iptables command is malformed: flag %q requires at least 2 values, but got %s",
|
||||
flag, strings.Join(fields, " "))
|
||||
}
|
||||
return expected, nil
|
||||
default:
|
||||
const expected = 2
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
|
||||
ErrIptablesCommandMalformed, flag)
|
||||
return 0, fmt.Errorf("iptables command is malformed: flag %q requires a value, but got none",
|
||||
flag)
|
||||
}
|
||||
return expected, nil
|
||||
}
|
||||
@@ -239,12 +237,12 @@ func parseMatchModule(fields []string, instruction *iptablesInstruction) (
|
||||
consumed++
|
||||
instruction.mark.invert = true
|
||||
default:
|
||||
return consumed, fmt.Errorf("%w: unsupported match mark with value: %s",
|
||||
ErrIptablesCommandMalformed, fields[2])
|
||||
return consumed, fmt.Errorf("iptables command is malformed: unsupported match mark with value: %s",
|
||||
fields[2])
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: unknown match value: %s",
|
||||
ErrIptablesCommandMalformed, fields[consumed])
|
||||
return 0, fmt.Errorf("iptables command is malformed: unknown match value: %s",
|
||||
fields[consumed])
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
@@ -13,21 +13,17 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
s string
|
||||
instruction iptablesInstruction
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no_instruction": {
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "iptables command is malformed: empty instruction",
|
||||
},
|
||||
"uneven_fields": {
|
||||
s: "-A",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
|
||||
},
|
||||
"unknown_key": {
|
||||
s: "-x something",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
|
||||
},
|
||||
"one_pair": {
|
||||
@@ -74,9 +70,10 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
rule, err := parseIptablesInstruction(testCase.s)
|
||||
|
||||
assert.Equal(t, testCase.instruction, rule)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -10,12 +10,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
|
||||
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
|
||||
ErrInputPolicyNotFound = errors.New("input policy not found")
|
||||
ErrNotSupported = errors.New("no iptables supported found")
|
||||
)
|
||||
var ErrNotSupported = errors.New("no iptables supported found")
|
||||
|
||||
func checkIptablesSupport(ctx context.Context, runner CmdRunner,
|
||||
iptablesPathsToTry ...string,
|
||||
@@ -53,7 +48,7 @@ func checkIptablesSupport(ctx context.Context, runner CmdRunner,
|
||||
if allArePermissionDenied {
|
||||
// If the error is related to a denied permission for all iptables path,
|
||||
// return an error describing what to do from an end-user perspective.
|
||||
return "", fmt.Errorf("%w: %s", ErrNetAdminMissing, strings.Join(allUnsupportedMessages, "; "))
|
||||
return "", fmt.Errorf("NET_ADMIN capability is missing: %s", strings.Join(allUnsupportedMessages, "; "))
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("%w: errors encountered are: %s",
|
||||
@@ -85,7 +80,7 @@ func testIptablesPath(ctx context.Context, path string,
|
||||
output, err = runner.Run(cmd)
|
||||
if err != nil {
|
||||
// this is a critical error, we want to make sure our test rule gets removed.
|
||||
criticalErr = fmt.Errorf("%w: %s (%s)", ErrTestRuleCleanup, output, err)
|
||||
criticalErr = fmt.Errorf("failed cleaning up test rule: %s (%s)", output, err)
|
||||
return false, "", criticalErr
|
||||
}
|
||||
|
||||
@@ -108,7 +103,7 @@ func testIptablesPath(ctx context.Context, path string,
|
||||
}
|
||||
|
||||
if inputPolicy == "" {
|
||||
criticalErr = fmt.Errorf("%w: in INPUT rules: %s", ErrInputPolicyNotFound, output)
|
||||
criticalErr = fmt.Errorf("input policy not found: in INPUT rules: %s", output)
|
||||
return false, "", criticalErr
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newAppendTestRuleMatcher(path string) *cmdMatcher {
|
||||
@@ -43,7 +42,6 @@ func Test_checkIptablesSupport(t *testing.T) {
|
||||
buildRunner func(ctrl *gomock.Controller) CmdRunner
|
||||
iptablesPathsToTry []string
|
||||
iptablesPath string
|
||||
errSentinel error
|
||||
errMessage string
|
||||
}{
|
||||
"critical error when checking": {
|
||||
@@ -56,7 +54,6 @@ func Test_checkIptablesSupport(t *testing.T) {
|
||||
return runner
|
||||
},
|
||||
iptablesPathsToTry: []string{"path1", "path2"},
|
||||
errSentinel: ErrTestRuleCleanup,
|
||||
errMessage: "for path1: failed cleaning up test rule: " +
|
||||
"output (exit code 4)",
|
||||
},
|
||||
@@ -86,7 +83,6 @@ func Test_checkIptablesSupport(t *testing.T) {
|
||||
return runner
|
||||
},
|
||||
iptablesPathsToTry: []string{"path1", "path2"},
|
||||
errSentinel: ErrNetAdminMissing,
|
||||
errMessage: "NET_ADMIN capability is missing: " +
|
||||
"path1: Permission denied (you must be root) more context (exit code 4); " +
|
||||
"path2: context: Permission denied (you must be root) (exit code 4)",
|
||||
@@ -101,7 +97,6 @@ func Test_checkIptablesSupport(t *testing.T) {
|
||||
return runner
|
||||
},
|
||||
iptablesPathsToTry: []string{"path1", "path2"},
|
||||
errSentinel: ErrNotSupported,
|
||||
errMessage: "no iptables supported found: " +
|
||||
"errors encountered are: " +
|
||||
"path1: output 1 (exit code 4); " +
|
||||
@@ -118,9 +113,10 @@ func Test_checkIptablesSupport(t *testing.T) {
|
||||
|
||||
iptablesPath, err := checkIptablesSupport(ctx, runner, testCase.iptablesPathsToTry...)
|
||||
|
||||
require.ErrorIs(t, err, testCase.errSentinel)
|
||||
if testCase.errSentinel != nil {
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, testCase.iptablesPath, iptablesPath)
|
||||
})
|
||||
@@ -139,7 +135,6 @@ func Test_testIptablesPath(t *testing.T) {
|
||||
buildRunner func(ctrl *gomock.Controller) CmdRunner
|
||||
ok bool
|
||||
unsupportedMessage string
|
||||
criticalErrWrapped error
|
||||
criticalErrMessage string
|
||||
}{
|
||||
"append test rule permission denied": {
|
||||
@@ -168,7 +163,6 @@ func Test_testIptablesPath(t *testing.T) {
|
||||
Return("some output", errDummy)
|
||||
return runner
|
||||
},
|
||||
criticalErrWrapped: ErrTestRuleCleanup,
|
||||
criticalErrMessage: "failed cleaning up test rule: some output (exit code 4)",
|
||||
},
|
||||
"list input rules permission denied": {
|
||||
@@ -202,7 +196,6 @@ func Test_testIptablesPath(t *testing.T) {
|
||||
Return("some\noutput", nil)
|
||||
return runner
|
||||
},
|
||||
criticalErrWrapped: ErrInputPolicyNotFound,
|
||||
criticalErrMessage: "input policy not found: in INPUT rules: some\noutput",
|
||||
},
|
||||
"set policy permission denied": {
|
||||
@@ -257,9 +250,10 @@ func Test_testIptablesPath(t *testing.T) {
|
||||
|
||||
assert.Equal(t, testCase.ok, ok)
|
||||
assert.Equal(t, testCase.unsupportedMessage, unsupportedMessage)
|
||||
assert.ErrorIs(t, criticalErr, testCase.criticalErrWrapped)
|
||||
if testCase.criticalErrWrapped != nil {
|
||||
if testCase.criticalErrMessage != "" {
|
||||
assert.EqualError(t, criticalErr, testCase.criticalErrMessage)
|
||||
} else {
|
||||
assert.NoError(t, criticalErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -45,12 +45,10 @@ func (f tcpFlag) String() string {
|
||||
case tcpFlagCWR:
|
||||
return "CWR"
|
||||
default:
|
||||
panic(fmt.Sprintf("%s: %d", errTCPFlagUnknown, f))
|
||||
panic(fmt.Sprintf("unknown TCP flag: %d", f))
|
||||
}
|
||||
}
|
||||
|
||||
var errTCPFlagUnknown = errors.New("unknown TCP flag")
|
||||
|
||||
func parseTCPFlag(s string) (tcpFlag, error) {
|
||||
allFlags := []tcpFlag{
|
||||
tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH,
|
||||
@@ -61,7 +59,7 @@ func parseTCPFlag(s string) (tcpFlag, error) {
|
||||
return flag, nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
|
||||
return 0, fmt.Errorf("unknown TCP flag: %s", s)
|
||||
}
|
||||
|
||||
var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so")
|
||||
|
||||
@@ -266,8 +266,6 @@ func makeAddressToDial(address string) (addressToDial string, err error) {
|
||||
return address, nil
|
||||
}
|
||||
|
||||
var ErrAllCheckTriesFailed = errors.New("all check tries failed")
|
||||
|
||||
func withRetries(ctx context.Context, tryTimeouts []time.Duration,
|
||||
logger Logger, checkName string, check func(ctx context.Context, try int) error,
|
||||
) error {
|
||||
@@ -297,7 +295,7 @@ func withRetries(ctx context.Context, tryTimeouts []time.Duration,
|
||||
for i, err := range errs {
|
||||
errStrings[i] = fmt.Sprintf("attempt %d (%dms): %s", i+1, err.durationMS, err.err)
|
||||
}
|
||||
return fmt.Errorf("%w:\n\t%s", ErrAllCheckTriesFailed, strings.Join(errStrings, "\n\t"))
|
||||
return fmt.Errorf("all check tries failed:\n\t%s", strings.Join(errStrings, "\n\t"))
|
||||
}
|
||||
|
||||
func (c *Checker) startupCheck(ctx context.Context) error {
|
||||
@@ -342,7 +340,7 @@ func (c *Checker) startupCheck(ctx context.Context) error {
|
||||
for i, err := range errs {
|
||||
errStrings[i] = fmt.Sprintf("parallel attempt %d/%d failed: %s", i+1, len(errs), err)
|
||||
}
|
||||
return fmt.Errorf("%w: %s", ErrAllCheckTriesFailed, strings.Join(errStrings, ", "))
|
||||
return fmt.Errorf("all check tries failed: %s", strings.Join(errStrings, ", "))
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
@@ -2,7 +2,6 @@ package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -68,7 +67,7 @@ func Test_makeAddressToDial(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
address string
|
||||
addressToDial string
|
||||
err error
|
||||
errMessage string
|
||||
}{
|
||||
"host without port": {
|
||||
address: "test.com",
|
||||
@@ -79,8 +78,8 @@ func Test_makeAddressToDial(t *testing.T) {
|
||||
addressToDial: "test.com:80",
|
||||
},
|
||||
"bad address": {
|
||||
address: "test.com::",
|
||||
err: fmt.Errorf("splitting host and port from address: address test.com::: too many colons in address"), //nolint:lll
|
||||
address: "test.com::",
|
||||
errMessage: "splitting host and port from address: address test.com::: too many colons in address",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -91,8 +90,8 @@ func Test_makeAddressToDial(t *testing.T) {
|
||||
addressToDial, err := makeAddressToDial(testCase.address)
|
||||
|
||||
assert.Equal(t, testCase.addressToDial, addressToDial)
|
||||
if testCase.err != nil {
|
||||
assert.EqualError(t, err, testCase.err.Error())
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -2,15 +2,12 @@ package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrHTTPStatusNotOK = errors.New("HTTP response status is not OK")
|
||||
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
@@ -41,6 +38,6 @@ func (c *Client) Check(ctx context.Context, url string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("%w: %d %s: %s", ErrHTTPStatusNotOK,
|
||||
return fmt.Errorf("HTTP response status is not OK: %d %s: %s",
|
||||
response.StatusCode, response.Status, string(b))
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -41,8 +40,6 @@ func concatAddrPorts(addrs [][]netip.AddrPort) []netip.AddrPort {
|
||||
return result
|
||||
}
|
||||
|
||||
var ErrLookupNoIPs = errors.New("no IPs found from DNS lookup")
|
||||
|
||||
func (c *Client) Check(ctx context.Context) error {
|
||||
dnsAddr := c.serverAddrs[c.dnsIPIndex].String()
|
||||
resolver := &net.Resolver{
|
||||
@@ -59,7 +56,7 @@ func (c *Client) Check(ctx context.Context) error {
|
||||
return fmt.Errorf("with DNS server %s: %w", dnsAddr, err)
|
||||
case len(ips) == 0:
|
||||
c.dnsIPIndex = (c.dnsIPIndex + 1) % len(c.serverAddrs)
|
||||
return fmt.Errorf("with DNS server %s: %w", dnsAddr, ErrLookupNoIPs)
|
||||
return fmt.Errorf("with DNS server %s: no IPs found from DNS lookup", dnsAddr)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -12,11 +12,9 @@ type handler struct {
|
||||
logger Logger
|
||||
}
|
||||
|
||||
var errHealthcheckNotRunYet = errors.New("healthcheck did not run yet")
|
||||
|
||||
func newHandler(logger Logger) *handler {
|
||||
return &handler{
|
||||
healthErr: errHealthcheckNotRunYet,
|
||||
healthErr: errors.New("healthcheck did not run yet"),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,11 +19,6 @@ import (
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
|
||||
ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")
|
||||
)
|
||||
|
||||
type Echoer struct {
|
||||
buffer []byte
|
||||
randomSource io.Reader
|
||||
@@ -60,10 +55,7 @@ func (e *Echoer) Reset() {
|
||||
e.seqStart = time.Now()
|
||||
}
|
||||
|
||||
var (
|
||||
ErrTimedOut = errors.New("timed out waiting for ICMP echo reply")
|
||||
ErrNotPermitted = errors.New("not permitted")
|
||||
)
|
||||
var ErrNotPermitted = errors.New("not permitted")
|
||||
|
||||
func (e *Echoer) Echo(ctx context.Context, ip netip.Addr) (err error) {
|
||||
var ipVersion string
|
||||
@@ -114,14 +106,14 @@ func (e *Echoer) Echo(ctx context.Context, ip netip.Addr) (err error) {
|
||||
receivedData, err := receiveEchoReply(conn, e.id, e.seq, e.buffer, ipVersion, e.logger)
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) && ctx.Err() != nil {
|
||||
return fmt.Errorf("%w from %s", ErrTimedOut, ip)
|
||||
return fmt.Errorf("timed out waiting for ICMP echo reply from %s", ip)
|
||||
}
|
||||
return fmt.Errorf("receiving ICMP echo reply from %s: %w", ip, err)
|
||||
}
|
||||
|
||||
sentData := message.Body.(*icmp.Echo).Data //nolint:forcetypeassert
|
||||
if !bytes.Equal(receivedData, sentData) {
|
||||
return fmt.Errorf("%w: sent %x to %s and received %x", ErrICMPEchoDataMismatch, sentData, ip, receivedData)
|
||||
return fmt.Errorf("ICMP data mismatch: sent %x to %s and received %x", sentData, ip, receivedData)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -216,8 +208,9 @@ func receiveEchoReply(conn net.PacketConn, id, seq int, buffer []byte, ipVersion
|
||||
message.Code, returnAddr, id, seq)
|
||||
continue
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: %T (type %d, code %d, return address %s, expected id %d and seq %d)",
|
||||
ErrICMPBodyUnsupported, body, message.Type, message.Code, returnAddr, id, seq)
|
||||
return nil, fmt.Errorf("ICMP body type is not supported: "+
|
||||
"%T (type %d, code %d, return address %s, expected id %d and seq %d)",
|
||||
body, message.Type, message.Code, returnAddr, id, seq)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination=logger_mock_test.go -package $GOPACKAGE . Logger
|
||||
@@ -20,11 +19,9 @@ func Test_New(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
expected *Server
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"empty settings": {
|
||||
errWrapped: ErrHandlerIsNotSet,
|
||||
errMessage: "http server settings validation failed: HTTP handler cannot be left unset",
|
||||
},
|
||||
"filled settings": {
|
||||
@@ -52,9 +49,10 @@ func Test_New(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, err := New(testCase.settings)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
require.EqualError(t, err, testCase.errMessage)
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
if server != nil {
|
||||
|
||||
@@ -64,14 +64,6 @@ func (s *Settings) OverrideWith(other Settings) {
|
||||
s.ShutdownTimeout = gosettings.OverrideWithComparable(s.ShutdownTimeout, other.ShutdownTimeout)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrHandlerIsNotSet = errors.New("HTTP handler cannot be left unset")
|
||||
ErrLoggerIsNotSet = errors.New("logger cannot be left unset")
|
||||
ErrReadHeaderTimeoutTooSmall = errors.New("read header timeout is too small")
|
||||
ErrReadTimeoutTooSmall = errors.New("read timeout is too small")
|
||||
ErrShutdownTimeoutTooSmall = errors.New("shutdown timeout is too small")
|
||||
)
|
||||
|
||||
func (s Settings) Validate() (err error) {
|
||||
err = validate.ListeningAddress(s.Address, os.Getuid())
|
||||
if err != nil {
|
||||
@@ -79,31 +71,25 @@ func (s Settings) Validate() (err error) {
|
||||
}
|
||||
|
||||
if s.Handler == nil {
|
||||
return fmt.Errorf("%w", ErrHandlerIsNotSet)
|
||||
return errors.New("HTTP handler cannot be left unset")
|
||||
}
|
||||
|
||||
if s.Logger == nil {
|
||||
return fmt.Errorf("%w", ErrLoggerIsNotSet)
|
||||
return errors.New("logger cannot be left unset")
|
||||
}
|
||||
|
||||
const minReadTimeout = time.Millisecond
|
||||
if s.ReadHeaderTimeout < minReadTimeout {
|
||||
return fmt.Errorf("%w: %s must be at least %s",
|
||||
ErrReadHeaderTimeoutTooSmall,
|
||||
s.ReadHeaderTimeout, minReadTimeout)
|
||||
return fmt.Errorf("read header timeout is too small: %s must be at least %s", s.ReadHeaderTimeout, minReadTimeout)
|
||||
}
|
||||
|
||||
if s.ReadTimeout < minReadTimeout {
|
||||
return fmt.Errorf("%w: %s must be at least %s",
|
||||
ErrReadTimeoutTooSmall,
|
||||
s.ReadTimeout, minReadTimeout)
|
||||
return fmt.Errorf("read timeout is too small: %s must be at least %s", s.ReadTimeout, minReadTimeout)
|
||||
}
|
||||
|
||||
const minShutdownTimeout = 5 * time.Millisecond
|
||||
if s.ShutdownTimeout < minShutdownTimeout {
|
||||
return fmt.Errorf("%w: %s must be at least %s",
|
||||
ErrShutdownTimeoutTooSmall,
|
||||
s.ShutdownTimeout, minShutdownTimeout)
|
||||
return fmt.Errorf("shutdown timeout is too small: %s must be at least %s", s.ShutdownTimeout, minShutdownTimeout)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gosettings/validate"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -189,30 +188,26 @@ func Test_Settings_Validate(t *testing.T) {
|
||||
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"bad_address": {
|
||||
settings: Settings{
|
||||
Address: "address:notanint",
|
||||
},
|
||||
errWrapped: validate.ErrPortNotAnInteger,
|
||||
errMessage: "port value is not an integer: notanint",
|
||||
},
|
||||
"nil handler": {
|
||||
settings: Settings{
|
||||
Address: ":8000",
|
||||
},
|
||||
errWrapped: ErrHandlerIsNotSet,
|
||||
errMessage: ErrHandlerIsNotSet.Error(),
|
||||
errMessage: "HTTP handler cannot be left unset",
|
||||
},
|
||||
"nil logger": {
|
||||
settings: Settings{
|
||||
Address: ":8000",
|
||||
Handler: someHandler,
|
||||
},
|
||||
errWrapped: ErrLoggerIsNotSet,
|
||||
errMessage: ErrLoggerIsNotSet.Error(),
|
||||
errMessage: "logger cannot be left unset",
|
||||
},
|
||||
"read header timeout too small": {
|
||||
settings: Settings{
|
||||
@@ -221,7 +216,6 @@ func Test_Settings_Validate(t *testing.T) {
|
||||
Logger: someLogger,
|
||||
ReadHeaderTimeout: time.Nanosecond,
|
||||
},
|
||||
errWrapped: ErrReadHeaderTimeoutTooSmall,
|
||||
errMessage: "read header timeout is too small: 1ns must be at least 1ms",
|
||||
},
|
||||
"read timeout too small": {
|
||||
@@ -232,7 +226,6 @@ func Test_Settings_Validate(t *testing.T) {
|
||||
ReadHeaderTimeout: time.Millisecond,
|
||||
ReadTimeout: time.Nanosecond,
|
||||
},
|
||||
errWrapped: ErrReadTimeoutTooSmall,
|
||||
errMessage: "read timeout is too small: 1ns must be at least 1ms",
|
||||
},
|
||||
"shutdown timeout too small": {
|
||||
@@ -244,7 +237,6 @@ func Test_Settings_Validate(t *testing.T) {
|
||||
ReadTimeout: time.Millisecond,
|
||||
ShutdownTimeout: time.Millisecond,
|
||||
},
|
||||
errWrapped: ErrShutdownTimeoutTooSmall,
|
||||
errMessage: "shutdown timeout is too small: 1ms must be at least 5ms",
|
||||
},
|
||||
"valid settings": {
|
||||
@@ -265,9 +257,10 @@ func Test_Settings_Validate(t *testing.T) {
|
||||
|
||||
err := testCase.settings.Validate()
|
||||
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,15 +2,12 @@ package loopstate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
var ErrInvalidStatus = errors.New("invalid status")
|
||||
|
||||
// ApplyStatus sends signals to the running loop depending on the
|
||||
// current status and status requested, such that its next status
|
||||
// matches the requested one. It is thread safe and a synchronous call
|
||||
@@ -73,7 +70,7 @@ func (s *State) ApplyStatus(ctx context.Context, status models.LoopStatus) (
|
||||
return newStatus.String(), nil
|
||||
default:
|
||||
s.statusMu.Unlock()
|
||||
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
|
||||
ErrInvalidStatus, status, constants.Running, constants.Stopped)
|
||||
return "", fmt.Errorf("invalid status: %s: it can only be one of: %s, %s",
|
||||
status, constants.Running, constants.Stopped)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,19 +3,11 @@ package mod
|
||||
import (
|
||||
"bufio"
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
errModuleNameUnknown = errors.New("unknown module name")
|
||||
errKernelFeatureIsModule = errors.New("kernel feature is a module, not built-in")
|
||||
errKernelFeatureNotSet = errors.New("kernel feature not set")
|
||||
errKernelFeatureNotFound = errors.New("kernel feature not found")
|
||||
)
|
||||
|
||||
// checkProcConfig checks /proc/config.gz for a the kernel feature corresponding
|
||||
// to the given module name. If the kernel feature is found and set to "y", it returns nil.
|
||||
// If the kernel feature is found and set to "m", it returns an error indicating that the kernel
|
||||
@@ -39,7 +31,7 @@ func checkProcConfig(moduleName string) error {
|
||||
// If any group of kernel features is satisfied, then the module is considered supported.
|
||||
kernelFeatureGroups, ok := moduleNameToKernelFeatureGroups(moduleName)
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: %s", errModuleNameUnknown, moduleName)
|
||||
return fmt.Errorf("unknown module name: %s", moduleName)
|
||||
}
|
||||
groups := make([]map[string]bool, len(kernelFeatureGroups))
|
||||
for i, group := range kernelFeatureGroups {
|
||||
@@ -58,20 +50,20 @@ func checkProcConfig(moduleName string) error {
|
||||
switch {
|
||||
case ok:
|
||||
case strings.HasPrefix(line, name+"=m"):
|
||||
return fmt.Errorf("%w: %s", errKernelFeatureIsModule, name)
|
||||
return fmt.Errorf("kernel feature is a module, not built-in: %s", name)
|
||||
case strings.HasPrefix(line, name+"=y"):
|
||||
featureToOK[name] = true
|
||||
if allFeaturesOK(featureToOK) {
|
||||
return nil
|
||||
}
|
||||
case strings.HasPrefix(line, "# "+name+" is not set"):
|
||||
return fmt.Errorf("%w: %s", errKernelFeatureNotSet, name)
|
||||
return fmt.Errorf("kernel feature not set: %s", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("%w: for module name %s", errKernelFeatureNotFound, moduleName)
|
||||
return fmt.Errorf("kernel feature not found: for module name %s", moduleName)
|
||||
}
|
||||
|
||||
func moduleNameToKernelFeatureGroups(moduleName string) (featureGroups [][]string, ok bool) {
|
||||
|
||||
@@ -181,8 +181,6 @@ func getLoadedModules(modulesInfo map[string]moduleInfo) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrModulePathNotFound = errors.New("module path not found")
|
||||
|
||||
func findModulePath(moduleName string, modulesInfo map[string]moduleInfo) (modulePath string, err error) {
|
||||
// Kernel module names can have underscores or hyphens in their names,
|
||||
// but only one or the other in one particular name.
|
||||
@@ -205,5 +203,5 @@ func findModulePath(moduleName string, modulesInfo map[string]moduleInfo) (modul
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("%w: for %q", ErrModulePathNotFound, moduleName)
|
||||
return "", fmt.Errorf("module path not found: for %q", moduleName)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package mod
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@@ -14,15 +13,10 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrModuleInfoNotFound = errors.New("module info not found")
|
||||
ErrCircularDependency = errors.New("circular dependency")
|
||||
)
|
||||
|
||||
func initDependencies(path string, modulesInfo map[string]moduleInfo) (err error) {
|
||||
info, ok := modulesInfo[path]
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: %s", ErrModuleInfoNotFound, path)
|
||||
return fmt.Errorf("module info not found: %s", path)
|
||||
}
|
||||
|
||||
switch info.state {
|
||||
@@ -30,8 +24,7 @@ func initDependencies(path string, modulesInfo map[string]moduleInfo) (err error
|
||||
case loaded, builtin:
|
||||
return nil
|
||||
case loading:
|
||||
return fmt.Errorf("%w: %s is already in the loading state",
|
||||
ErrCircularDependency, path)
|
||||
return fmt.Errorf("circular dependency: %s is already in the loading state", path)
|
||||
}
|
||||
|
||||
info.state = loading
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -109,8 +108,6 @@ func (s *Servers) toMarkdown(vpnProvider string) (formatted string, err error) {
|
||||
return formatted, nil
|
||||
}
|
||||
|
||||
var ErrMarkdownHeadersNotDefined = errors.New("markdown headers not defined")
|
||||
|
||||
func getMarkdownHeaders(vpnProvider string) (headers []string, err error) {
|
||||
switch vpnProvider {
|
||||
case providers.Airvpn:
|
||||
@@ -169,6 +166,6 @@ func getMarkdownHeaders(vpnProvider string) (headers []string, err error) {
|
||||
case providers.Windscribe:
|
||||
return []string{regionHeader, cityHeader, hostnameHeader, vpnHeader}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: for %s", ErrMarkdownHeadersNotDefined, vpnProvider)
|
||||
return nil, fmt.Errorf("markdown headers not defined: for %s", vpnProvider)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,12 +15,10 @@ func Test_Servers_ToMarkdown(t *testing.T) {
|
||||
provider string
|
||||
servers Servers
|
||||
formatted string
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"unsupported_provider": {
|
||||
provider: "unsupported",
|
||||
errWrapped: ErrMarkdownHeadersNotDefined,
|
||||
errMessage: "getting markdown headers: markdown headers not defined: for unsupported",
|
||||
},
|
||||
providers.Cyberghost: {
|
||||
@@ -58,9 +56,10 @@ func Test_Servers_ToMarkdown(t *testing.T) {
|
||||
markdown, err := testCase.servers.toMarkdown(testCase.provider)
|
||||
|
||||
assert.Equal(t, testCase.formatted, markdown)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -38,27 +38,18 @@ type Server struct {
|
||||
IPs []netip.Addr `json:"ips,omitempty"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrVPNFieldEmpty = errors.New("vpn field is empty")
|
||||
ErrHostnameFieldEmpty = errors.New("hostname field is empty")
|
||||
ErrIPsFieldEmpty = errors.New("ips field is empty")
|
||||
ErrNoNetworkProtocol = errors.New("both TCP and UDP fields are false for OpenVPN")
|
||||
ErrNetworkProtocolSet = errors.New("no network protocol should be set")
|
||||
ErrWireguardPublicKeyEmpty = errors.New("wireguard public key field is empty")
|
||||
)
|
||||
|
||||
func (s *Server) HasMinimumInformation() (err error) {
|
||||
switch {
|
||||
case s.VPN == "":
|
||||
return fmt.Errorf("%w", ErrVPNFieldEmpty)
|
||||
return errors.New("vpn field is empty")
|
||||
case len(s.IPs) == 0:
|
||||
return fmt.Errorf("%w", ErrIPsFieldEmpty)
|
||||
return errors.New("ips field is empty")
|
||||
case s.VPN == vpn.Wireguard && (s.TCP || s.UDP):
|
||||
return fmt.Errorf("%w", ErrNetworkProtocolSet)
|
||||
return errors.New("no network protocol should be set")
|
||||
case s.VPN == vpn.OpenVPN && !s.TCP && !s.UDP:
|
||||
return fmt.Errorf("%w", ErrNoNetworkProtocol)
|
||||
return errors.New("both TCP and UDP fields are false for OpenVPN")
|
||||
case s.VPN == vpn.Wireguard && s.WgPubKey == "":
|
||||
return fmt.Errorf("%w", ErrWireguardPublicKeyEmpty)
|
||||
return errors.New("wireguard public key field is empty")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package models
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
@@ -158,8 +157,6 @@ type Servers struct {
|
||||
Servers []Server `json:"servers,omitempty"`
|
||||
}
|
||||
|
||||
var ErrServersFormatNotSupported = errors.New("servers format not supported")
|
||||
|
||||
func (s *Servers) Format(vpnProvider, format string) (formatted string, err error) {
|
||||
switch format {
|
||||
case "markdown":
|
||||
@@ -167,7 +164,7 @@ func (s *Servers) Format(vpnProvider, format string) (formatted string, err erro
|
||||
case "json":
|
||||
return s.toJSON()
|
||||
default:
|
||||
return "", fmt.Errorf("%w: %s", ErrServersFormatNotSupported, format)
|
||||
return "", fmt.Errorf("servers format not supported: %s", format)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ func Test_AllServers_MarshalJSON(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
allServers *AllServers
|
||||
dataString string
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no provider": {
|
||||
@@ -58,16 +57,18 @@ func Test_AllServers_MarshalJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
data, err := testCase.allServers.MarshalJSON()
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if err != nil {
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, testCase.dataString, string(data))
|
||||
|
||||
data, err = json.Marshal(testCase.allServers)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if err != nil {
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, testCase.dataString, string(data))
|
||||
|
||||
@@ -87,7 +88,6 @@ func Test_AllServers_UnmarshalJSON(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
dataString string
|
||||
allServers AllServers
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"empty": {
|
||||
@@ -131,9 +131,10 @@ func Test_AllServers_UnmarshalJSON(t *testing.T) {
|
||||
|
||||
err := json.Unmarshal(data, &allServers)
|
||||
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if err != nil {
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, testCase.allServers, allServers)
|
||||
})
|
||||
|
||||
+16
-33
@@ -6,48 +6,40 @@ import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var ErrRequestSizeTooSmall = errors.New("message size is too small")
|
||||
|
||||
func checkRequest(request []byte) (err error) {
|
||||
const minMessageSize = 2 // version number + operation code
|
||||
if len(request) < minMessageSize {
|
||||
return fmt.Errorf("%w: need at least %d bytes and got %d byte(s)",
|
||||
ErrRequestSizeTooSmall, minMessageSize, len(request))
|
||||
return fmt.Errorf("message size is too small: need at least %d bytes and got %d byte(s)",
|
||||
minMessageSize, len(request))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrResponseSizeTooSmall = errors.New("response size is too small")
|
||||
ErrResponseSizeUnexpected = errors.New("response size is unexpected")
|
||||
ErrProtocolVersionUnknown = errors.New("protocol version is unknown")
|
||||
ErrOperationCodeUnexpected = errors.New("operation code is unexpected")
|
||||
)
|
||||
|
||||
func checkResponse(response []byte, expectedOperationCode byte,
|
||||
expectedResponseSize uint,
|
||||
) (err error) {
|
||||
const minResponseSize = 4
|
||||
if len(response) < minResponseSize {
|
||||
return fmt.Errorf("%w: need at least %d bytes and got %d byte(s)",
|
||||
ErrResponseSizeTooSmall, minResponseSize, len(response))
|
||||
return fmt.Errorf("response size is too small: "+
|
||||
"need at least %d bytes and got %d byte(s)",
|
||||
minResponseSize, len(response))
|
||||
}
|
||||
|
||||
if uint(len(response)) != expectedResponseSize {
|
||||
return fmt.Errorf("%w: expected %d bytes and got %d byte(s)",
|
||||
ErrResponseSizeUnexpected, expectedResponseSize, len(response))
|
||||
return fmt.Errorf("response size is unexpected: "+
|
||||
"expected %d bytes and got %d byte(s)",
|
||||
expectedResponseSize, len(response))
|
||||
}
|
||||
|
||||
protocolVersion := response[0]
|
||||
if protocolVersion != 0 {
|
||||
return fmt.Errorf("%w: %d", ErrProtocolVersionUnknown, protocolVersion)
|
||||
return fmt.Errorf("protocol version is unknown: %d", protocolVersion)
|
||||
}
|
||||
|
||||
operationCode := response[1]
|
||||
if operationCode != expectedOperationCode {
|
||||
return fmt.Errorf("%w: expected 0x%x and got 0x%x",
|
||||
ErrOperationCodeUnexpected, expectedOperationCode, operationCode)
|
||||
return fmt.Errorf("operation code is unexpected: expected 0x%x and got 0x%x", expectedOperationCode, operationCode)
|
||||
}
|
||||
|
||||
resultCode := binary.BigEndian.Uint16(response[2:4])
|
||||
@@ -59,15 +51,6 @@ func checkResponse(response []byte, expectedOperationCode byte,
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrVersionNotSupported = errors.New("version is not supported")
|
||||
ErrNotAuthorized = errors.New("not authorized")
|
||||
ErrNetworkFailure = errors.New("network failure")
|
||||
ErrOutOfResources = errors.New("out of resources")
|
||||
ErrOperationCodeNotSupported = errors.New("operation code is not supported")
|
||||
ErrResultCodeUnknown = errors.New("result code is unknown")
|
||||
)
|
||||
|
||||
// checkResultCode checks the result code and returns an error
|
||||
// if the result code is not a success (0).
|
||||
// See https://www.ietf.org/rfc/rfc6886.html#section-3.5
|
||||
@@ -78,16 +61,16 @@ func checkResultCode(resultCode uint16) (err error) {
|
||||
case 0:
|
||||
return nil
|
||||
case 1:
|
||||
return fmt.Errorf("%w", ErrVersionNotSupported)
|
||||
return errors.New("version is not supported")
|
||||
case 2:
|
||||
return fmt.Errorf("%w", ErrNotAuthorized)
|
||||
return errors.New("not authorized")
|
||||
case 3:
|
||||
return fmt.Errorf("%w", ErrNetworkFailure)
|
||||
return errors.New("network failure")
|
||||
case 4:
|
||||
return fmt.Errorf("%w", ErrOutOfResources)
|
||||
return errors.New("out of resources")
|
||||
case 5:
|
||||
return fmt.Errorf("%w", ErrOperationCodeNotSupported)
|
||||
return errors.New("operation code is not supported")
|
||||
default:
|
||||
return fmt.Errorf("%w: %d", ErrResultCodeUnknown, resultCode)
|
||||
return fmt.Errorf("result code is unknown: %d", resultCode)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package natpmp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -11,12 +12,10 @@ func Test_checkRequest(t *testing.T) {
|
||||
|
||||
testCases := map[string]struct {
|
||||
request []byte
|
||||
err error
|
||||
errMessage string
|
||||
}{
|
||||
"too_short": {
|
||||
request: []byte{1},
|
||||
err: ErrRequestSizeTooSmall,
|
||||
errMessage: "message size is too small: need at least 2 bytes and got 1 byte(s)",
|
||||
},
|
||||
"success": {
|
||||
@@ -30,9 +29,10 @@ func Test_checkRequest(t *testing.T) {
|
||||
|
||||
err := checkRequest(testCase.request)
|
||||
|
||||
assert.ErrorIs(t, err, testCase.err)
|
||||
if testCase.err != nil {
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -50,33 +50,33 @@ func Test_checkResponse(t *testing.T) {
|
||||
}{
|
||||
"too_short": {
|
||||
response: []byte{1},
|
||||
err: ErrResponseSizeTooSmall,
|
||||
err: errors.New("response size is too small"),
|
||||
errMessage: "response size is too small: need at least 4 bytes and got 1 byte(s)",
|
||||
},
|
||||
"size_mismatch": {
|
||||
response: []byte{0, 0, 0, 0},
|
||||
expectedResponseSize: 5,
|
||||
err: ErrResponseSizeUnexpected,
|
||||
err: errors.New("response size is unexpected"),
|
||||
errMessage: "response size is unexpected: expected 5 bytes and got 4 byte(s)",
|
||||
},
|
||||
"protocol_unknown": {
|
||||
response: []byte{1, 0, 0, 0},
|
||||
expectedResponseSize: 4,
|
||||
err: ErrProtocolVersionUnknown,
|
||||
err: errors.New("protocol version is unknown"),
|
||||
errMessage: "protocol version is unknown: 1",
|
||||
},
|
||||
"operation_code_unexpected": {
|
||||
response: []byte{0, 2, 0, 0},
|
||||
expectedOperationCode: 1,
|
||||
expectedResponseSize: 4,
|
||||
err: ErrOperationCodeUnexpected,
|
||||
err: errors.New("operation code is unexpected"),
|
||||
errMessage: "operation code is unexpected: expected 0x1 and got 0x2",
|
||||
},
|
||||
"result_code_failure": {
|
||||
response: []byte{0, 1, 0, 1},
|
||||
expectedOperationCode: 1,
|
||||
expectedResponseSize: 4,
|
||||
err: ErrVersionNotSupported,
|
||||
err: errors.New("version is not supported"),
|
||||
errMessage: "result code: version is not supported",
|
||||
},
|
||||
"success": {
|
||||
@@ -94,9 +94,11 @@ func Test_checkResponse(t *testing.T) {
|
||||
testCase.expectedOperationCode,
|
||||
testCase.expectedResponseSize)
|
||||
|
||||
assert.ErrorIs(t, err, testCase.err)
|
||||
if testCase.err != nil {
|
||||
assert.ErrorContains(t, err, testCase.err.Error())
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -113,32 +115,32 @@ func Test_checkResultCode(t *testing.T) {
|
||||
"success": {},
|
||||
"version_unsupported": {
|
||||
resultCode: 1,
|
||||
err: ErrVersionNotSupported,
|
||||
err: errors.New("version is not supported"),
|
||||
errMessage: "version is not supported",
|
||||
},
|
||||
"not_authorized": {
|
||||
resultCode: 2,
|
||||
err: ErrNotAuthorized,
|
||||
err: errors.New("not authorized"),
|
||||
errMessage: "not authorized",
|
||||
},
|
||||
"network_failure": {
|
||||
resultCode: 3,
|
||||
err: ErrNetworkFailure,
|
||||
err: errors.New("network failure"),
|
||||
errMessage: "network failure",
|
||||
},
|
||||
"out_of_resources": {
|
||||
resultCode: 4,
|
||||
err: ErrOutOfResources,
|
||||
err: errors.New("out of resources"),
|
||||
errMessage: "out of resources",
|
||||
},
|
||||
"unsupported_operation_code": {
|
||||
resultCode: 5,
|
||||
err: ErrOperationCodeNotSupported,
|
||||
err: errors.New("operation code is not supported"),
|
||||
errMessage: "operation code is not supported",
|
||||
},
|
||||
"unknown": {
|
||||
resultCode: 6,
|
||||
err: ErrResultCodeUnknown,
|
||||
err: errors.New("result code is unknown"),
|
||||
errMessage: "result code is unknown: 6",
|
||||
},
|
||||
}
|
||||
@@ -149,9 +151,11 @@ func Test_checkResultCode(t *testing.T) {
|
||||
|
||||
err := checkResultCode(testCase.resultCode)
|
||||
|
||||
assert.ErrorIs(t, err, testCase.err)
|
||||
if testCase.err != nil {
|
||||
assert.ErrorContains(t, err, testCase.err.Error())
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,17 +3,11 @@ package natpmp
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNetworkProtocolUnknown = errors.New("network protocol is unknown")
|
||||
ErrLifetimeTooLong = errors.New("lifetime is too long")
|
||||
)
|
||||
|
||||
// Add or delete a port mapping. To delete a mapping, set both the
|
||||
// requestedExternalPort and lifetime to 0.
|
||||
// See https://www.ietf.org/rfc/rfc6886.html#section-3.3
|
||||
@@ -26,8 +20,9 @@ func (c *Client) AddPortMapping(ctx context.Context, gateway netip.Addr,
|
||||
lifetimeSecondsFloat := lifetime.Seconds()
|
||||
const maxLifetimeSeconds = uint64(^uint32(0))
|
||||
if uint64(lifetimeSecondsFloat) > maxLifetimeSeconds {
|
||||
return 0, 0, 0, 0, fmt.Errorf("%w: %d seconds must at most %d seconds",
|
||||
ErrLifetimeTooLong, uint64(lifetimeSecondsFloat), maxLifetimeSeconds)
|
||||
return 0, 0, 0, 0, fmt.Errorf("lifetime is too long: "+
|
||||
"%d seconds must at most %d seconds",
|
||||
uint64(lifetimeSecondsFloat), maxLifetimeSeconds)
|
||||
}
|
||||
const messageSize = 12
|
||||
message := make([]byte, messageSize)
|
||||
@@ -38,7 +33,7 @@ func (c *Client) AddPortMapping(ctx context.Context, gateway netip.Addr,
|
||||
case "tcp":
|
||||
message[1] = 2 // operationCode 2
|
||||
default:
|
||||
return 0, 0, 0, 0, fmt.Errorf("%w: %s", ErrNetworkProtocolUnknown, protocol)
|
||||
return 0, 0, 0, 0, fmt.Errorf("network protocol is unknown: %s", protocol)
|
||||
}
|
||||
// [2:3] are reserved.
|
||||
binary.BigEndian.PutUint16(message[4:6], internalPort)
|
||||
|
||||
@@ -25,18 +25,15 @@ func Test_Client_AddPortMapping(t *testing.T) {
|
||||
assignedInternalPort uint16
|
||||
assignedExternalPort uint16
|
||||
assignedLifetime time.Duration
|
||||
err error
|
||||
errMessage string
|
||||
}{
|
||||
"lifetime_too_long": {
|
||||
lifetime: time.Duration(uint64(^uint32(0))+1) * time.Second,
|
||||
err: ErrLifetimeTooLong,
|
||||
errMessage: "lifetime is too long: 4294967296 seconds must at most 4294967295 seconds",
|
||||
},
|
||||
"protocol_unknown": {
|
||||
lifetime: time.Second,
|
||||
protocol: "xyz",
|
||||
err: ErrNetworkProtocolUnknown,
|
||||
errMessage: "network protocol is unknown: xyz",
|
||||
},
|
||||
"rpc_error": {
|
||||
@@ -48,7 +45,6 @@ func Test_Client_AddPortMapping(t *testing.T) {
|
||||
lifetime: 1200 * time.Second,
|
||||
initialConnectionDuration: time.Millisecond,
|
||||
exchanges: []udpExchange{{close: true}},
|
||||
err: ErrConnectionTimeout,
|
||||
errMessage: "executing remote procedure call: connection timeout: failed attempts: " +
|
||||
"read udp 127.0.0.1:[1-9][0-9]{0,4}->127.0.0.1:[1-9][0-9]{0,4}: i/o timeout \\(try 1\\)",
|
||||
},
|
||||
@@ -136,9 +132,6 @@ func Test_Client_AddPortMapping(t *testing.T) {
|
||||
assert.Equal(t, testCase.assignedExternalPort, assignedExternalPort)
|
||||
assert.Equal(t, testCase.assignedLifetime, assignedLifetime)
|
||||
if testCase.errMessage != "" {
|
||||
if testCase.err != nil {
|
||||
assert.ErrorIs(t, err, testCase.err)
|
||||
}
|
||||
assert.Regexp(t, "^"+testCase.errMessage+"$", err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -11,17 +11,12 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrGatewayIPUnspecified = errors.New("gateway IP is unspecified")
|
||||
ErrConnectionTimeout = errors.New("connection timeout")
|
||||
)
|
||||
|
||||
func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
|
||||
request []byte, responseSize uint) (
|
||||
response []byte, err error,
|
||||
) {
|
||||
if gateway.IsUnspecified() || !gateway.IsValid() {
|
||||
return nil, fmt.Errorf("%w", ErrGatewayIPUnspecified)
|
||||
return nil, errors.New("gateway IP is unspecified")
|
||||
}
|
||||
|
||||
err = checkRequest(request)
|
||||
@@ -114,8 +109,7 @@ func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
|
||||
}
|
||||
|
||||
if retryCount == c.maxRetries {
|
||||
return nil, fmt.Errorf("%w: failed attempts: %s",
|
||||
ErrConnectionTimeout, dedupFailedAttempts(failedAttempts))
|
||||
return nil, fmt.Errorf("connection timeout: failed attempts: %s", dedupFailedAttempts(failedAttempts))
|
||||
}
|
||||
|
||||
// Opcodes between 0 and 127 are client requests. Opcodes from 128 to
|
||||
|
||||
@@ -20,20 +20,17 @@ func Test_Client_rpc(t *testing.T) {
|
||||
initialConnectionDuration time.Duration
|
||||
exchanges []udpExchange
|
||||
expectedResponse []byte
|
||||
err error
|
||||
errMessage string
|
||||
}{
|
||||
"gateway_ip_unspecified": {
|
||||
gateway: netip.IPv6Unspecified(),
|
||||
request: []byte{0, 0},
|
||||
err: ErrGatewayIPUnspecified,
|
||||
errMessage: "gateway IP is unspecified",
|
||||
},
|
||||
"request_too_small": {
|
||||
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
||||
request: []byte{0},
|
||||
initialConnectionDuration: time.Nanosecond, // doesn't matter
|
||||
err: ErrRequestSizeTooSmall,
|
||||
errMessage: `checking request: message size is too small: ` +
|
||||
`need at least 2 bytes and got 1 byte\(s\)`,
|
||||
},
|
||||
@@ -53,7 +50,6 @@ func Test_Client_rpc(t *testing.T) {
|
||||
exchanges: []udpExchange{
|
||||
{request: []byte{0, 1}, close: true},
|
||||
},
|
||||
err: ErrConnectionTimeout,
|
||||
errMessage: "connection timeout: failed attempts: " +
|
||||
"read udp 127.0.0.1:[1-9][0-9]{0,4}->127.0.0.1:[1-9][0-9]{0,4}: i/o timeout \\(try 1\\)",
|
||||
},
|
||||
@@ -66,7 +62,6 @@ func Test_Client_rpc(t *testing.T) {
|
||||
request: []byte{0, 0},
|
||||
response: []byte{1},
|
||||
}},
|
||||
err: ErrResponseSizeTooSmall,
|
||||
errMessage: `checking response: response size is too small: ` +
|
||||
`need at least 4 bytes and got 1 byte\(s\)`,
|
||||
},
|
||||
@@ -80,7 +75,6 @@ func Test_Client_rpc(t *testing.T) {
|
||||
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||
response: []byte{0, 1, 2, 3}, // size 4
|
||||
}},
|
||||
err: ErrResponseSizeUnexpected,
|
||||
errMessage: `checking response: response size is unexpected: ` +
|
||||
`expected 5 bytes and got 4 byte\(s\)`,
|
||||
},
|
||||
@@ -94,7 +88,6 @@ func Test_Client_rpc(t *testing.T) {
|
||||
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||
response: []byte{0x1, 0x82, 0x0, 0x0, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||
}},
|
||||
err: ErrProtocolVersionUnknown,
|
||||
errMessage: "checking response: protocol version is unknown: 1",
|
||||
},
|
||||
"unexpected_operation_code": {
|
||||
@@ -107,7 +100,6 @@ func Test_Client_rpc(t *testing.T) {
|
||||
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||
response: []byte{0x0, 0x88, 0x0, 0x0, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||
}},
|
||||
err: ErrOperationCodeUnexpected,
|
||||
errMessage: "checking response: operation code is unexpected: expected 0x82 and got 0x88",
|
||||
},
|
||||
"failure_result_code": {
|
||||
@@ -120,7 +112,6 @@ func Test_Client_rpc(t *testing.T) {
|
||||
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||
response: []byte{0x0, 0x82, 0x0, 0x11, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||
}},
|
||||
err: ErrResultCodeUnknown,
|
||||
errMessage: "checking response: result code: result code is unknown: 17",
|
||||
},
|
||||
"success": {
|
||||
@@ -153,9 +144,6 @@ func Test_Client_rpc(t *testing.T) {
|
||||
testCase.request, testCase.responseSize)
|
||||
|
||||
if testCase.errMessage != "" {
|
||||
if testCase.err != nil {
|
||||
assert.ErrorIs(t, err, testCase.err)
|
||||
}
|
||||
assert.Regexp(t, "^"+testCase.errMessage+"$", err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -41,8 +41,6 @@ func findAvailableTCPPort(t *testing.T) (port uint16) {
|
||||
func Test_dialAddrThroughFirewall(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errTest := errors.New("test error")
|
||||
|
||||
const ipv6InternetWorks = false
|
||||
|
||||
testCases := map[string]struct {
|
||||
@@ -102,7 +100,7 @@ func Test_dialAddrThroughFirewall(t *testing.T) {
|
||||
},
|
||||
},
|
||||
"firewall_add_error": {
|
||||
firewallAddErr: errTest,
|
||||
firewallAddErr: errors.New("test error"),
|
||||
errMessageRegex: func() string {
|
||||
return "accepting output traffic: test error"
|
||||
},
|
||||
@@ -122,7 +120,7 @@ func Test_dialAddrThroughFirewall(t *testing.T) {
|
||||
addrPort := netip.MustParseAddrPort(listener.Addr().String())
|
||||
return netip.AddrPortFrom(loopback, addrPort.Port())
|
||||
},
|
||||
firewallRemoveErr: errTest,
|
||||
firewallRemoveErr: errors.New("test error"),
|
||||
errMessageRegex: func() string {
|
||||
return "removing output traffic rule: test error"
|
||||
},
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jsimonetti/rtnetlink"
|
||||
@@ -47,8 +46,6 @@ func (n *NetLink) LinkList() (links []Link, err error) {
|
||||
return links, nil
|
||||
}
|
||||
|
||||
var ErrLinkNotFound = errors.New("link not found")
|
||||
|
||||
func (n *NetLink) LinkByName(name string) (link Link, err error) {
|
||||
links, err := n.LinkList()
|
||||
if err != nil {
|
||||
@@ -61,7 +58,7 @@ func (n *NetLink) LinkByName(name string) (link Link, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
return Link{}, fmt.Errorf("%w: for name %s", ErrLinkNotFound, name)
|
||||
return Link{}, fmt.Errorf("link not found: for name %s", name)
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkByIndex(index uint32) (link Link, err error) {
|
||||
@@ -76,7 +73,7 @@ func (n *NetLink) LinkByIndex(index uint32) (link Link, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
return Link{}, fmt.Errorf("%w: for index %d", ErrLinkNotFound, index)
|
||||
return Link{}, fmt.Errorf("link not found: for index %d", index)
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkAdd(link Link) (linkIndex uint32, err error) {
|
||||
@@ -114,7 +111,7 @@ func (n *NetLink) LinkAdd(link Link) (linkIndex uint32, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("%w: matching name %s", ErrLinkNotFound, link.Name)
|
||||
return 0, fmt.Errorf("link not found: matching name %s", link.Name)
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkDel(linkIndex uint32) (err error) {
|
||||
|
||||
@@ -1,17 +1,11 @@
|
||||
package extract
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrRead = errors.New("cannot read file")
|
||||
ErrExtractConnection = errors.New("cannot extract connection from file")
|
||||
)
|
||||
|
||||
// Data extracts the lines and connection from the OpenVPN configuration file.
|
||||
func (e *Extractor) Data(filepath string) (lines []string,
|
||||
connection models.Connection, err error,
|
||||
|
||||
@@ -11,8 +11,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
var errRemoteLineNotFound = errors.New("remote line not found")
|
||||
|
||||
func extractDataFromLines(lines []string) (
|
||||
connection models.Connection, err error,
|
||||
) {
|
||||
@@ -35,7 +33,7 @@ func extractDataFromLines(lines []string) (
|
||||
}
|
||||
|
||||
if !connection.IP.IsValid() {
|
||||
return connection, errRemoteLineNotFound
|
||||
return connection, errors.New("remote line not found")
|
||||
}
|
||||
|
||||
if connection.Protocol == "" {
|
||||
@@ -81,19 +79,15 @@ func extractDataFromLine(line string) (
|
||||
return ip, 0, "", nil
|
||||
}
|
||||
|
||||
var errProtoLineFieldsCount = errors.New("proto line has not 2 fields as expected")
|
||||
|
||||
func extractProto(line string) (protocol string, err error) {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) != 2 { //nolint:mnd
|
||||
return "", fmt.Errorf("%w: %s", errProtoLineFieldsCount, line)
|
||||
return "", fmt.Errorf("proto line has not 2 fields as expected: %s", line)
|
||||
}
|
||||
|
||||
return parseProto(fields[1])
|
||||
}
|
||||
|
||||
var errProtocolNotSupported = errors.New("network protocol not supported")
|
||||
|
||||
func parseProto(field string) (protocol string, err error) {
|
||||
switch field {
|
||||
case "tcp", "tcp4", "tcp6", "tcp-client":
|
||||
@@ -106,16 +100,10 @@ func parseProto(field string) (protocol string, err error) {
|
||||
// determined by the remote IP address version.
|
||||
return constants.UDP, nil
|
||||
default:
|
||||
return "", fmt.Errorf("%w: %s", errProtocolNotSupported, field)
|
||||
return "", fmt.Errorf("network protocol not supported: %s", field)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
errRemoteLineFieldsCount = errors.New("remote line has not 2 fields as expected")
|
||||
errHostNotIP = errors.New("host is not an IP address")
|
||||
errPortNotValid = errors.New("port is not valid")
|
||||
)
|
||||
|
||||
func extractRemote(line string) (ip netip.Addr, port uint16,
|
||||
protocol string, err error,
|
||||
) {
|
||||
@@ -123,13 +111,13 @@ func extractRemote(line string) (ip netip.Addr, port uint16,
|
||||
n := len(fields)
|
||||
|
||||
if n < 2 || n > 4 {
|
||||
return netip.Addr{}, 0, "", fmt.Errorf("%w: %s", errRemoteLineFieldsCount, line)
|
||||
return netip.Addr{}, 0, "", fmt.Errorf("remote line has not 2 fields as expected: %s", line)
|
||||
}
|
||||
|
||||
host := fields[1]
|
||||
ip, err = netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
return netip.Addr{}, 0, "", fmt.Errorf("%w: %s", errHostNotIP, host)
|
||||
return netip.Addr{}, 0, "", fmt.Errorf("host is not an IP address: %s", host)
|
||||
// TODO resolve hostname once there is an option to allow it through
|
||||
// the firewall before the VPN is up.
|
||||
}
|
||||
@@ -137,9 +125,9 @@ func extractRemote(line string) (ip netip.Addr, port uint16,
|
||||
if n > 2 { //nolint:mnd
|
||||
portInt, err := strconv.Atoi(fields[2])
|
||||
if err != nil {
|
||||
return netip.Addr{}, 0, "", fmt.Errorf("%w: %s", errPortNotValid, line)
|
||||
return netip.Addr{}, 0, "", fmt.Errorf("port is not valid: %s", line)
|
||||
} else if portInt < 1 || portInt > 65535 {
|
||||
return netip.Addr{}, 0, "", fmt.Errorf("%w: %d must be between 1 and 65535", errPortNotValid, portInt)
|
||||
return netip.Addr{}, 0, "", fmt.Errorf("port is not valid: %d must be between 1 and 65535", portInt)
|
||||
}
|
||||
port = uint16(portInt)
|
||||
}
|
||||
@@ -154,20 +142,18 @@ func extractRemote(line string) (ip netip.Addr, port uint16,
|
||||
return ip, port, protocol, nil
|
||||
}
|
||||
|
||||
var errPostLineFieldsCount = errors.New("post line has not 2 fields as expected")
|
||||
|
||||
func extractPort(line string) (port uint16, err error) {
|
||||
fields := strings.Fields(line)
|
||||
const expectedFieldsCount = 2
|
||||
if len(fields) != expectedFieldsCount {
|
||||
return 0, fmt.Errorf("%w: %s", errPostLineFieldsCount, line)
|
||||
return 0, fmt.Errorf("post line has not 2 fields as expected: %s", line)
|
||||
}
|
||||
|
||||
portInt, err := strconv.Atoi(fields[1])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("%w: %s", errPortNotValid, line)
|
||||
return 0, fmt.Errorf("port is not valid: %s", line)
|
||||
} else if portInt < 1 || portInt > 65535 {
|
||||
return 0, fmt.Errorf("%w: %d must be between 1 and 65535", errPortNotValid, portInt)
|
||||
return 0, fmt.Errorf("port is not valid: %d must be between 1 and 65535", portInt)
|
||||
}
|
||||
port = uint16(portInt)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ func Test_extractDataFromLines(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
lines []string
|
||||
connection models.Connection
|
||||
err error
|
||||
errMessage string
|
||||
}{
|
||||
"success": {
|
||||
lines: []string{"bla", "proto tcp", "remote 1.2.3.4 1194 tcp", "dev tun6"},
|
||||
@@ -28,8 +28,8 @@ func Test_extractDataFromLines(t *testing.T) {
|
||||
},
|
||||
},
|
||||
"extraction error": {
|
||||
lines: []string{"bla", "proto bad", "remote 1.2.3.4 1194 tcp"},
|
||||
err: errors.New("on line 2: extracting protocol from proto line: network protocol not supported: bad"),
|
||||
lines: []string{"bla", "proto bad", "remote 1.2.3.4 1194 tcp"},
|
||||
errMessage: "on line 2: extracting protocol from proto line: network protocol not supported: bad",
|
||||
},
|
||||
"only use first values found": {
|
||||
lines: []string{"proto udp", "proto tcp", "remote 1.2.3.4 443 tcp", "remote 5.2.3.4 1194 udp"},
|
||||
@@ -44,7 +44,7 @@ func Test_extractDataFromLines(t *testing.T) {
|
||||
connection: models.Connection{
|
||||
Protocol: constants.TCP,
|
||||
},
|
||||
err: errRemoteLineNotFound,
|
||||
errMessage: "remote line not found",
|
||||
},
|
||||
"default TCP port": {
|
||||
lines: []string{"remote 1.2.3.4", "proto tcp"},
|
||||
@@ -70,9 +70,8 @@ func Test_extractDataFromLines(t *testing.T) {
|
||||
|
||||
connection, err := extractDataFromLines(testCase.lines)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -86,18 +85,18 @@ func Test_extractDataFromLine(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
line string
|
||||
ip netip.Addr
|
||||
port uint16
|
||||
protocol string
|
||||
isErr error
|
||||
line string
|
||||
ip netip.Addr
|
||||
port uint16
|
||||
protocol string
|
||||
errMessage string
|
||||
}{
|
||||
"irrelevant line": {
|
||||
line: "bla",
|
||||
},
|
||||
"extract proto error": {
|
||||
line: "proto bad",
|
||||
isErr: errProtocolNotSupported,
|
||||
line: "proto bad",
|
||||
errMessage: "network protocol not supported",
|
||||
},
|
||||
"extract proto success": {
|
||||
line: "proto tcp",
|
||||
@@ -108,8 +107,8 @@ func Test_extractDataFromLine(t *testing.T) {
|
||||
protocol: constants.TCP,
|
||||
},
|
||||
"extract remote error": {
|
||||
line: "remote bad",
|
||||
isErr: errHostNotIP,
|
||||
line: "remote bad",
|
||||
errMessage: "host is not an IP address",
|
||||
},
|
||||
"extract remote success": {
|
||||
line: "remote 1.2.3.4 1194 udp",
|
||||
@@ -118,8 +117,8 @@ func Test_extractDataFromLine(t *testing.T) {
|
||||
protocol: constants.UDP,
|
||||
},
|
||||
"extract_port_fail": {
|
||||
line: "port a",
|
||||
isErr: errPortNotValid,
|
||||
line: "port a",
|
||||
errMessage: "port is not valid",
|
||||
},
|
||||
"extract_port_success": {
|
||||
line: "port 1194",
|
||||
@@ -133,8 +132,8 @@ func Test_extractDataFromLine(t *testing.T) {
|
||||
|
||||
ip, port, protocol, err := extractDataFromLine(testCase.line)
|
||||
|
||||
if testCase.isErr != nil {
|
||||
assert.ErrorIs(t, err, testCase.isErr)
|
||||
if testCase.errMessage != "" {
|
||||
assert.ErrorContains(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -4,15 +4,12 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var errPEMDecode = errors.New("cannot decode PEM encoded block")
|
||||
|
||||
func PEM(b []byte) (encodedData string, err error) {
|
||||
pemBlock, _ := pem.Decode(b)
|
||||
if pemBlock == nil {
|
||||
return "", fmt.Errorf("%w", errPEMDecode)
|
||||
return "", errors.New("cannot decode PEM encoded block")
|
||||
}
|
||||
|
||||
der := pemBlock.Bytes
|
||||
|
||||
@@ -13,16 +13,13 @@ func Test_PEM(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
b []byte
|
||||
encodedData string
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no input": {
|
||||
errWrapped: errPEMDecode,
|
||||
errMessage: "cannot decode PEM encoded block",
|
||||
},
|
||||
"bad input": {
|
||||
b: []byte{1, 2, 3},
|
||||
errWrapped: errPEMDecode,
|
||||
errMessage: "cannot decode PEM encoded block",
|
||||
},
|
||||
"valid data with extras": {
|
||||
@@ -46,9 +43,10 @@ func Test_PEM(t *testing.T) {
|
||||
encodedData, err := PEM(testCase.b)
|
||||
|
||||
assert.Equal(t, testCase.encodedData, encodedData)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package pkcs8
|
||||
import (
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
@@ -11,8 +10,6 @@ import (
|
||||
// https://www.ibm.com/docs/en/zos/2.3.0?topic=programming-object-identifiers
|
||||
var oidDESCBC = asn1.ObjectIdentifier{1, 3, 14, 3, 2, 7} //nolint:gochecknoglobals
|
||||
|
||||
var ErrEncryptionAlgorithmNotPBES2 = errors.New("encryption algorithm is not PBES2")
|
||||
|
||||
type encryptedPrivateKey struct {
|
||||
EncryptionAlgorithm pkix.AlgorithmIdentifier
|
||||
EncryptedData []byte
|
||||
@@ -35,8 +32,8 @@ func getEncryptionAlgorithmOid(der []byte) (
|
||||
oidPBES2 := asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 13}
|
||||
oidAlgorithm := encryptedPrivateKeyData.EncryptionAlgorithm.Algorithm
|
||||
if !oidAlgorithm.Equal(oidPBES2) {
|
||||
return nil, fmt.Errorf("%w: %s instead of PBES2 %s",
|
||||
ErrEncryptionAlgorithmNotPBES2, oidAlgorithm, oidPBES2)
|
||||
return nil, fmt.Errorf("encryption algorithm is not PBES2: %s instead of PBES2 %s",
|
||||
oidAlgorithm, oidPBES2)
|
||||
}
|
||||
|
||||
var encryptionAlgorithmParams encryptedAlgorithmParams
|
||||
|
||||
@@ -2,14 +2,11 @@ package pkcs8
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
pkcs8lib "github.com/youmark/pkcs8"
|
||||
)
|
||||
|
||||
var ErrUnsupportedKeyType = errors.New("unsupported key type")
|
||||
|
||||
// UpgradeEncryptedKey eventually upgrades an encrypted key to a newer encryption
|
||||
// if its encryption is too weak for Openvpn/Openssl.
|
||||
// If the key is encrypted using DES-CBC, it is decrypted and re-encrypted using AES-256-CBC.
|
||||
|
||||
@@ -2,15 +2,12 @@ package openvpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/openvpn"
|
||||
)
|
||||
|
||||
var ErrVersionUnknown = errors.New("OpenVPN version is unknown")
|
||||
|
||||
const (
|
||||
binOpenvpn25 = "openvpn2.5"
|
||||
binOpenvpn26 = "openvpn2.6"
|
||||
@@ -26,7 +23,7 @@ func start(ctx context.Context, starter CmdStarter, version string, flags []stri
|
||||
case openvpn.Openvpn26:
|
||||
bin = binOpenvpn26
|
||||
default:
|
||||
return nil, nil, nil, fmt.Errorf("%w: %s", ErrVersionUnknown, version)
|
||||
return nil, nil, nil, fmt.Errorf("OpenVPN version is unknown: %s", version)
|
||||
}
|
||||
|
||||
args := []string{"--config", configPath}
|
||||
|
||||
@@ -2,7 +2,6 @@ package openvpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
@@ -16,8 +15,6 @@ func (c *Configurator) Version26(ctx context.Context) (version string, err error
|
||||
return c.version(ctx, binOpenvpn26)
|
||||
}
|
||||
|
||||
var ErrVersionTooShort = errors.New("version output is too short")
|
||||
|
||||
func (c *Configurator) version(ctx context.Context, binName string) (version string, err error) {
|
||||
cmd := exec.CommandContext(ctx, binName, "--version")
|
||||
output, err := c.cmder.Run(cmd)
|
||||
@@ -28,7 +25,7 @@ func (c *Configurator) version(ctx context.Context, binName string) (version str
|
||||
words := strings.Fields(firstLine)
|
||||
const minWords = 2
|
||||
if len(words) < minWords {
|
||||
return "", fmt.Errorf("%w: %s", ErrVersionTooShort, firstLine)
|
||||
return "", fmt.Errorf("version output is too short: %s", firstLine)
|
||||
}
|
||||
return words[1], nil
|
||||
}
|
||||
|
||||
@@ -2,24 +2,17 @@ package icmp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
|
||||
ErrNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
|
||||
)
|
||||
|
||||
func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) {
|
||||
switch {
|
||||
case mtu < minMTU:
|
||||
return fmt.Errorf("%w: %d", ErrNextHopMTUTooLow, mtu)
|
||||
return fmt.Errorf("ICMP Next Hop MTU is too low: %d", mtu)
|
||||
case mtu > physicalLinkMTU:
|
||||
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
|
||||
ErrNextHopMTUTooHigh, mtu, physicalLinkMTU)
|
||||
return fmt.Errorf("ICMP Next Hop MTU is too high: %d is larger than physical link MTU %d", mtu, physicalLinkMTU)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
@@ -34,14 +27,12 @@ func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
|
||||
}
|
||||
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
|
||||
return false, fmt.Errorf("ICMP body type is not supported: %T", inboundMessage.Body)
|
||||
}
|
||||
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
||||
return inboundBody.ID == outboundBody.ID, nil
|
||||
}
|
||||
|
||||
var ErrIDMismatch = errors.New("ICMP id mismatch")
|
||||
|
||||
func checkEchoReply(icmpProtocol int, received []byte,
|
||||
outboundMessage *icmp.Message, truncatedBody bool,
|
||||
) (err error) {
|
||||
@@ -51,12 +42,12 @@ func checkEchoReply(icmpProtocol int, received []byte,
|
||||
}
|
||||
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
|
||||
return fmt.Errorf("ICMP body type is not supported: %T", inboundMessage.Body)
|
||||
}
|
||||
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
||||
if inboundBody.ID != outboundBody.ID {
|
||||
return fmt.Errorf("%w: sent id %d and received id %d",
|
||||
ErrIDMismatch, outboundBody.ID, inboundBody.ID)
|
||||
return fmt.Errorf("ICMP id mismatch: sent id %d and received id %d",
|
||||
outboundBody.ID, inboundBody.ID)
|
||||
}
|
||||
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
|
||||
if err != nil {
|
||||
@@ -65,19 +56,17 @@ func checkEchoReply(icmpProtocol int, received []byte,
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrEchoDataMismatch = errors.New("ICMP data mismatch")
|
||||
|
||||
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
|
||||
if len(received) > len(sent) {
|
||||
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
|
||||
ErrEchoDataMismatch, len(sent), len(received))
|
||||
return fmt.Errorf("ICMP data mismatch: sent %d bytes and received %d bytes",
|
||||
len(sent), len(received))
|
||||
}
|
||||
if receivedTruncated {
|
||||
sent = sent[:len(received)]
|
||||
}
|
||||
if !bytes.Equal(received, sent) {
|
||||
return fmt.Errorf("%w: sent %x and received %x",
|
||||
ErrEchoDataMismatch, sent, received)
|
||||
return fmt.Errorf("ICMP data mismatch: sent %x and received %x",
|
||||
sent, received)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -10,9 +10,7 @@ import (
|
||||
|
||||
var (
|
||||
ErrNotPermitted = errors.New("ICMP not permitted")
|
||||
ErrDestinationUnreachable = errors.New("ICMP destination unreachable")
|
||||
ErrCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
|
||||
ErrBodyUnsupported = errors.New("ICMP body type is not supported")
|
||||
errCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
|
||||
ErrMTUNotFound = errors.New("MTU not found")
|
||||
errTimeout = errors.New("operation timed out")
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@ func PathMTUDiscover(ctx context.Context, ip netip.Addr,
|
||||
switch {
|
||||
case err == nil:
|
||||
return mtu, nil
|
||||
case errors.Is(err, errTimeout) || errors.Is(err, ErrCommunicationAdministrativelyProhibited): // blackhole
|
||||
case errors.Is(err, errTimeout) || errors.Is(err, errCommunicationAdministrativelyProhibited): // blackhole
|
||||
default:
|
||||
return 0, fmt.Errorf("finding IPv4 next hop MTU to %s: %w", ip, err)
|
||||
}
|
||||
|
||||
@@ -117,13 +117,10 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
||||
case portUnreachable: // triggered by TCP or UDP from applications
|
||||
continue // ignore and wait for the next message
|
||||
case communicationAdministrativelyProhibitedCode:
|
||||
return 0, fmt.Errorf("%w: %w (code %d)",
|
||||
ErrDestinationUnreachable,
|
||||
ErrCommunicationAdministrativelyProhibited,
|
||||
return 0, fmt.Errorf("ICMP destination unreachable: %w (code %d)", errCommunicationAdministrativelyProhibited,
|
||||
inboundMessage.Code)
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: code %d",
|
||||
ErrDestinationUnreachable, inboundMessage.Code)
|
||||
return 0, fmt.Errorf("ICMP destination unreachable: code %d", inboundMessage.Code)
|
||||
}
|
||||
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
|
||||
@@ -158,7 +155,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
||||
inboundID, outboundID)
|
||||
continue
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
|
||||
return 0, fmt.Errorf("ICMP body type is not supported: %T", typedBody)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package icmp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -115,7 +116,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking invoking message id: %w", err)
|
||||
} else if idMatch {
|
||||
return 0, fmt.Errorf("%w", ErrDestinationUnreachable)
|
||||
return 0, errors.New("ICMP destination unreachable")
|
||||
}
|
||||
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
|
||||
continue
|
||||
@@ -128,7 +129,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
||||
inboundID, outboundID)
|
||||
continue
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
|
||||
return 0, fmt.Errorf("ICMP body type %T is not supported", typedBody)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
|
||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
|
||||
err = fmt.Errorf("%w", ErrNotPermitted)
|
||||
err = ErrNotPermitted
|
||||
}
|
||||
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||
}
|
||||
@@ -157,7 +157,7 @@ func collectReplies(conn net.PacketConn, ipVersion string,
|
||||
logger.Debugf("ignoring ICMP message (type: %d, code: %d)", message.Type, message.Code)
|
||||
continue
|
||||
default:
|
||||
return fmt.Errorf("%w: %T", ErrBodyUnsupported, message.Body)
|
||||
return fmt.Errorf("ICMP body type is not supported: %T", message.Body)
|
||||
}
|
||||
|
||||
echoBody, _ := message.Body.(*icmp.Echo)
|
||||
@@ -183,8 +183,8 @@ func collectReplies(conn net.PacketConn, ipVersion string,
|
||||
ipPacketLength == conservativeReplyLength
|
||||
// Check the packet size is the same if the reply is not truncated
|
||||
if !truncated && sentBytes != ipPacketLength {
|
||||
return fmt.Errorf("%w: sent %dB and received %dB",
|
||||
ErrEchoDataMismatch, sentBytes, ipPacketLength)
|
||||
return fmt.Errorf("ICMP data mismatch: sent %dB and received %dB",
|
||||
sentBytes, ipPacketLength)
|
||||
}
|
||||
// Truncated reply or matching reply size
|
||||
tests[testIndex].ok = true
|
||||
|
||||
@@ -29,7 +29,7 @@ func SrcAddr(dst netip.AddrPort, proto int) (src netip.AddrPort, cleanup func(),
|
||||
}
|
||||
|
||||
var (
|
||||
errNoRoute = fmt.Errorf("no route to destination")
|
||||
errNoRoute = errors.New("no route to destination")
|
||||
ErrNetworkUnreachable = errors.New("network unreachable")
|
||||
)
|
||||
|
||||
|
||||
@@ -13,11 +13,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/pmtud/tcp"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrICMPOkTCPFail = errors.New("PMTUD succeeded with ICMP but failed with TCP")
|
||||
ErrICMPFailTCPFail = errors.New("PMTUD failed with both ICMP and TCP")
|
||||
)
|
||||
|
||||
// PathMTUDiscover discovers the maximum MTU using both ICMP and TCP.
|
||||
// Multiple ICMP addresses and TCP addresses can be specified for redundancy.
|
||||
// ICMP PMTUD is run first. If successful, the range of possible MTU values to
|
||||
@@ -81,10 +76,10 @@ func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []net
|
||||
}
|
||||
}
|
||||
if icmpSuccess {
|
||||
return 0, fmt.Errorf("%w - discarding ICMP obtained MTU %d",
|
||||
ErrICMPOkTCPFail, maxPossibleMTU)
|
||||
return 0, fmt.Errorf("PMTUD succeeded with ICMP but failed with TCP "+
|
||||
"- discarding ICMP obtained MTU %d", maxPossibleMTU)
|
||||
}
|
||||
return 0, fmt.Errorf("%w", ErrICMPFailTCPFail)
|
||||
return 0, errors.New("PMTUD failed with both ICMP and TCP")
|
||||
}
|
||||
logger.Debugf("TCP path MTU discovery found maximum valid MTU %d", mtu)
|
||||
return mtu, nil
|
||||
|
||||
@@ -57,8 +57,6 @@ func (l *noopLogger) Warn(_ string) {}
|
||||
func (l *noopLogger) Warnf(_ string, _ ...any) {}
|
||||
func (l *noopLogger) Error(_ string) {}
|
||||
|
||||
var errRouteNotFound = errors.New("route not found")
|
||||
|
||||
func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||
routes, err := netlinker.RouteList(netlink.FamilyV4)
|
||||
if err != nil {
|
||||
@@ -76,7 +74,7 @@ func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||
return min(link.MTU, maxMTU), nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound)
|
||||
return 0, errors.New("route not found: no loopback route found")
|
||||
}
|
||||
|
||||
func findDefaultRouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||
@@ -100,7 +98,7 @@ func findDefaultRouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||
}
|
||||
}
|
||||
if mtu == 0 {
|
||||
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound)
|
||||
return 0, errors.New("route not found: no default route found")
|
||||
}
|
||||
return mtu, nil
|
||||
}
|
||||
|
||||
@@ -12,8 +12,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/pmtud/ip"
|
||||
)
|
||||
|
||||
var errTCPServersUnreachable = errors.New("all TCP servers are unreachable")
|
||||
|
||||
// findHighestMSSDestination finds the destination with the highest
|
||||
// MSS amongst the provided destinations.
|
||||
func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescriptor,
|
||||
@@ -68,7 +66,7 @@ func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescr
|
||||
}
|
||||
|
||||
if mss == 0 { // no MSS found for any destination
|
||||
return netip.AddrPort{}, 0, fmt.Errorf("%w (%d servers)", errTCPServersUnreachable, len(dsts))
|
||||
return netip.AddrPort{}, 0, fmt.Errorf("all %d TCP servers are unreachable", len(dsts))
|
||||
}
|
||||
|
||||
maxPossibleMTU = ip.HeaderLength(dst.Addr().Is4()) + constants.BaseTCPHeaderLength + mss
|
||||
@@ -77,8 +75,6 @@ func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescr
|
||||
return dst, mss, nil
|
||||
}
|
||||
|
||||
var errMSSNotFound = errors.New("MSS option not found in reply")
|
||||
|
||||
func findMSS(ctx context.Context, fd fileDescriptor, dst netip.AddrPort,
|
||||
excludeMark int, tracker *tracker, firewall Firewall, logger Logger) (
|
||||
mss uint32, err error,
|
||||
@@ -132,11 +128,12 @@ func findMSS(ctx context.Context, fd fileDescriptor, dst netip.AddrPort,
|
||||
case err != nil:
|
||||
return 0, fmt.Errorf("parsing reply TCP header: %w", err)
|
||||
case replyHeader.typ != packetTypeSYNACK:
|
||||
return 0, fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, replyHeader.typ)
|
||||
return 0, fmt.Errorf("TCP packet is not a SYN-ACK: unexpected packet type %s", replyHeader.typ)
|
||||
case replyHeader.ack != synSeq+1:
|
||||
return 0, fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, replyHeader.ack)
|
||||
return 0, fmt.Errorf("TCP SYN-ACK ACK number %d does not match expected value %d",
|
||||
replyHeader.ack, synSeq+1)
|
||||
case replyHeader.options.mss == 0:
|
||||
return 0, fmt.Errorf("%w: MSS option not found in reply", errMSSNotFound)
|
||||
return 0, errors.New("MSS option not found in reply")
|
||||
}
|
||||
|
||||
err = sendRST(fd, src, dst, replyHeader.ack)
|
||||
|
||||
@@ -12,11 +12,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/pmtud/test"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMTUNotFound = errors.New("MTU not found")
|
||||
ErrMSSTooSmall = errors.New("TCP MSS is too small to find the MTU")
|
||||
)
|
||||
|
||||
type testUnit struct {
|
||||
mtu uint32
|
||||
ok bool
|
||||
@@ -178,5 +173,5 @@ func pathMTUDiscover(ctx context.Context, fd fileDescriptor,
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("%w: your connection might not be working at all", ErrMTUNotFound)
|
||||
return 0, errors.New("MTU not found: your connection might not be working at all")
|
||||
}
|
||||
|
||||
@@ -75,13 +75,6 @@ func startRawSocket(family, excludeMark int) (fd fileDescriptor, stop func(), er
|
||||
return fileDescriptor(fdPlatform), stop, nil
|
||||
}
|
||||
|
||||
var (
|
||||
errTCPPacketNotSynAck = errors.New("TCP packet is not a SYN-ACK")
|
||||
errTCPSynAckAckMismatch = errors.New("TCP SYN-ACK ACK number does not match expected value")
|
||||
errFinalPacketTypeUnexpected = errors.New("final TCP packet type is unexpected")
|
||||
errTCPPacketLost = errors.New("TCP packet was lost")
|
||||
)
|
||||
|
||||
// Craft and send a raw TCP packet to test the MTU.
|
||||
// It expects either an RST reply (if no server is listening)
|
||||
// or a SYN-ACK/ACK reply (if a server is listening).
|
||||
@@ -142,9 +135,10 @@ func runTest(ctx context.Context, dst netip.AddrPort, mtu uint32,
|
||||
// server actively closed the connection, try sending a SYN with data
|
||||
return handleRSTReply(ctx, fd, ch, src, dst, mtu)
|
||||
case firstReplyHeader.typ != packetTypeSYNACK:
|
||||
return fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, firstReplyHeader.typ)
|
||||
return fmt.Errorf("TCP packet is not a SYN-ACK: unexpected packet type %s", firstReplyHeader.typ)
|
||||
case firstReplyHeader.ack != synSeq+1:
|
||||
return fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, firstReplyHeader.ack)
|
||||
return fmt.Errorf("TCP SYN-ACK ACK number does not match expected value: "+
|
||||
"expected %d, got %d", synSeq+1, firstReplyHeader.ack)
|
||||
}
|
||||
|
||||
if firstReplyHeader.options.mss != 0 {
|
||||
@@ -191,15 +185,13 @@ func runTest(ctx context.Context, dst netip.AddrPort, mtu uint32,
|
||||
}
|
||||
return nil
|
||||
case packetTypeSYNACK: // server never received our MTU-test ACK packet
|
||||
return fmt.Errorf("%w: server responded with second SYN-ACK packet", errTCPPacketLost)
|
||||
return errors.New("TCP packet was lost: server responded with second SYN-ACK packet")
|
||||
default:
|
||||
_ = sendRST(fd, src, dst, finalPacketHeader.ack)
|
||||
return fmt.Errorf("%w: %s", errFinalPacketTypeUnexpected, finalPacketHeader.typ)
|
||||
return fmt.Errorf("final TCP packet type is unexpected: %s", finalPacketHeader.typ)
|
||||
}
|
||||
}
|
||||
|
||||
var errTCPPacketNotRST = errors.New("TCP packet is not an RST")
|
||||
|
||||
func handleRSTReply(ctx context.Context, fd fileDescriptor, ch <-chan []byte,
|
||||
src, dst netip.AddrPort, mtu uint32,
|
||||
) error {
|
||||
@@ -223,7 +215,7 @@ func handleRSTReply(ctx context.Context, fd fileDescriptor, ch <-chan []byte,
|
||||
return fmt.Errorf("parsing reply TCP header: %w", err)
|
||||
} else if replyPacketHeader.typ != packetTypeRST &&
|
||||
replyPacketHeader.typ != packetTypeRSTACK {
|
||||
return fmt.Errorf("%w: %s", errTCPPacketNotRST, replyPacketHeader.typ)
|
||||
return fmt.Errorf("TCP packet is not an RST: %s", replyPacketHeader.typ)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package tcp
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
@@ -120,17 +119,11 @@ type tcpHeader struct {
|
||||
options options
|
||||
}
|
||||
|
||||
var (
|
||||
errTCPHeaderTooShort = errors.New("TCP header is too short")
|
||||
errTCPHeaderDataOffset = errors.New("TCP header data offset is invalid")
|
||||
errTCPPacketTypeUnknown = errors.New("TCP packet type is unknown")
|
||||
)
|
||||
|
||||
// parseTCPHeader parses the TCP header from b.
|
||||
// b should be the entire TCP packet bytes.
|
||||
func parseTCPHeader(b []byte) (header tcpHeader, err error) {
|
||||
if len(b) < int(constants.BaseTCPHeaderLength) {
|
||||
return tcpHeader{}, fmt.Errorf("%w: %d bytes", errTCPHeaderTooShort, len(b))
|
||||
return tcpHeader{}, fmt.Errorf("TCP header is too short: %d bytes", len(b))
|
||||
}
|
||||
|
||||
header.srcPort = binary.BigEndian.Uint16(b[0:2])
|
||||
@@ -146,11 +139,11 @@ func parseTCPHeader(b []byte) (header tcpHeader, err error) {
|
||||
|
||||
switch {
|
||||
case uint32(header.dataOffset) < constants.BaseTCPHeaderLength:
|
||||
return tcpHeader{}, fmt.Errorf("%w: data offset is %d bytes, expected at least %d bytes",
|
||||
errTCPHeaderDataOffset, header.dataOffset, constants.BaseTCPHeaderLength)
|
||||
return tcpHeader{}, fmt.Errorf("TCP header data offset is invalid: "+
|
||||
"data offset is %d bytes, expected at least %d bytes", header.dataOffset, constants.BaseTCPHeaderLength)
|
||||
case int(header.dataOffset) > len(b):
|
||||
return tcpHeader{}, fmt.Errorf("%w: data offset is %d bytes, but packet is only %d bytes",
|
||||
errTCPHeaderDataOffset, header.dataOffset, len(b))
|
||||
return tcpHeader{}, fmt.Errorf("TCP header data offset is invalid: "+
|
||||
"data offset is %d bytes, but packet is only %d bytes", header.dataOffset, len(b))
|
||||
}
|
||||
|
||||
if uint32(header.dataOffset) > constants.BaseTCPHeaderLength {
|
||||
@@ -186,7 +179,7 @@ func parseTCPHeader(b []byte) (header tcpHeader, err error) {
|
||||
case flags&ackFlag != 0:
|
||||
header.typ = packetTypeACK
|
||||
default:
|
||||
return tcpHeader{}, fmt.Errorf("%w: flags are 0x%02x", errTCPPacketTypeUnknown, flags)
|
||||
return tcpHeader{}, fmt.Errorf("TCP packet type is unknown: flags are 0x%02x", flags)
|
||||
}
|
||||
|
||||
header.seq = binary.BigEndian.Uint32(b[4:8])
|
||||
@@ -206,15 +199,6 @@ type optionTimestamps struct {
|
||||
echo uint32
|
||||
}
|
||||
|
||||
var (
|
||||
errTCPOptionLengthTruncated = errors.New("TCP option length is truncated")
|
||||
ErrTCPOptionLengthInvalid = errors.New("TCP option length is invalid")
|
||||
ErrTCPOptionMSSInvalid = errors.New("TCP option MSS value is invalid")
|
||||
ErrTCPOptionWindowScaleInvalid = errors.New("TCP option Window Scale value is invalid")
|
||||
ErrTCPOptionTimestampsInvalid = errors.New("TCP option Timestamps value is invalid")
|
||||
errTCPOptionTypeUnknown = errors.New("TCP option type is unknown")
|
||||
)
|
||||
|
||||
func parseTCPOptions(b []byte) (parsed options, err error) {
|
||||
i := 0
|
||||
for i < len(b) {
|
||||
@@ -232,7 +216,7 @@ func parseTCPOptions(b []byte) (parsed options, err error) {
|
||||
// Handle TLV (Type-Length-Value) options
|
||||
if i+1 >= len(b) {
|
||||
// This should not happen for DF packets.
|
||||
return options{}, fmt.Errorf("%w: at offset %d", errTCPOptionLengthTruncated, i)
|
||||
return options{}, fmt.Errorf("TCP option length is truncated: at offset %d", i)
|
||||
}
|
||||
|
||||
length := int(b[i+1])
|
||||
@@ -240,11 +224,11 @@ func parseTCPOptions(b []byte) (parsed options, err error) {
|
||||
maxLength := len(b) - i
|
||||
switch {
|
||||
case length < minLength:
|
||||
return options{}, fmt.Errorf("%w: type %d at offset %d has length %d < %d",
|
||||
ErrTCPOptionLengthInvalid, optionType, i, length, minLength)
|
||||
return options{}, fmt.Errorf("TCP option length is invalid: "+
|
||||
"type %d at offset %d has length %d < %d", optionType, i, length, minLength)
|
||||
case length > maxLength:
|
||||
return options{}, fmt.Errorf("%w: type %d at offset %d has length %d > %d",
|
||||
ErrTCPOptionLengthInvalid, optionType, i, length, maxLength)
|
||||
return options{}, fmt.Errorf("TCP option length is invalid: "+
|
||||
"type %d at offset %d has length %d > %d", optionType, i, length, maxLength)
|
||||
}
|
||||
|
||||
data := b[i+2 : i+length]
|
||||
@@ -259,15 +243,15 @@ func parseTCPOptions(b []byte) (parsed options, err error) {
|
||||
case optionTypeMSS:
|
||||
const expectedLength = 4
|
||||
if length != expectedLength {
|
||||
return options{}, fmt.Errorf("%w: MSS option at offset %d has length %d, expected %d",
|
||||
ErrTCPOptionMSSInvalid, i, length, expectedLength)
|
||||
return options{}, fmt.Errorf("TCP option MSS value is invalid: "+
|
||||
"MSS option at offset %d has length %d, expected %d", i, length, expectedLength)
|
||||
}
|
||||
parsed.mss = uint32(binary.BigEndian.Uint16(data))
|
||||
case optionTypeWindowScale:
|
||||
const expectedLength = 3
|
||||
if length != expectedLength {
|
||||
return options{}, fmt.Errorf("%w: window scale option at offset %d has length %d, expected %d",
|
||||
ErrTCPOptionWindowScaleInvalid, i, length, expectedLength)
|
||||
return options{}, fmt.Errorf("TCP option Window Scale value is invalid: "+
|
||||
"window scale option at offset %d has length %d, expected %d", i, length, expectedLength)
|
||||
}
|
||||
windowScale := data[0]
|
||||
parsed.windowScale = &windowScale
|
||||
@@ -276,15 +260,15 @@ func parseTCPOptions(b []byte) (parsed options, err error) {
|
||||
case optionTypeTimestamps:
|
||||
const expectedLength = 10
|
||||
if length != expectedLength {
|
||||
return options{}, fmt.Errorf("%w: timestamps option at offset %d has length %d, expected %d",
|
||||
ErrTCPOptionTimestampsInvalid, i, length, expectedLength)
|
||||
return options{}, fmt.Errorf("TCP option Timestamps value is invalid: "+
|
||||
"timestamps option at offset %d has length %d, expected %d", i, length, expectedLength)
|
||||
}
|
||||
parsed.timestamps = &optionTimestamps{
|
||||
value: binary.BigEndian.Uint32(data[:4]),
|
||||
echo: binary.BigEndian.Uint32(data[4:]),
|
||||
}
|
||||
default:
|
||||
return options{}, fmt.Errorf("%w: type %d", errTCPOptionTypeUnknown, optionType)
|
||||
return options{}, fmt.Errorf("TCP option type is unknown: type %d", optionType)
|
||||
}
|
||||
|
||||
i += length
|
||||
|
||||
@@ -177,11 +177,9 @@ func (l *Loop) GetPortsForwarded() (ports []uint16) {
|
||||
return l.service.GetPortsForwarded()
|
||||
}
|
||||
|
||||
var ErrServiceNotStarted = errors.New("port forwarding service not started")
|
||||
|
||||
func (l *Loop) SetPortsForwarded(ports []uint16) (err error) {
|
||||
if l.service == nil {
|
||||
return fmt.Errorf("%w", ErrServiceNotStarted)
|
||||
return errors.New("port forwarding service not started")
|
||||
}
|
||||
|
||||
return l.service.SetPortsForwarded(l.runCtx, ports)
|
||||
|
||||
@@ -55,23 +55,10 @@ func (s *Settings) OverrideWith(update Settings) {
|
||||
s.Password = gosettings.OverrideWithComparable(s.Password, update.Password)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrPortForwarderNotSet = errors.New("port forwarder not set")
|
||||
ErrServerNameNotSet = errors.New("server name not set")
|
||||
ErrUsernameNotSet = errors.New("username not set")
|
||||
ErrPasswordNotSet = errors.New("password not set")
|
||||
ErrFilepathNotSet = errors.New("file path not set")
|
||||
ErrInterfaceNotSet = errors.New("interface not set")
|
||||
ErrPortsCountZero = errors.New("ports count cannot be zero")
|
||||
ErrPortsCountTooHigh = errors.New("ports count too high")
|
||||
ErrListeningPortsLen = errors.New("listening ports length must be equal to ports count")
|
||||
ErrListeningPortZero = errors.New("listening port cannot be 0")
|
||||
)
|
||||
|
||||
func (s *Settings) Validate(forStartup bool) (err error) {
|
||||
// Minimal validation
|
||||
if s.Filepath == "" {
|
||||
return fmt.Errorf("%w", ErrFilepathNotSet)
|
||||
return errors.New("file path not set")
|
||||
}
|
||||
|
||||
if !forStartup {
|
||||
@@ -83,41 +70,42 @@ func (s *Settings) Validate(forStartup bool) (err error) {
|
||||
// Startup validation requires additional fields set.
|
||||
switch {
|
||||
case s.PortForwarder == nil:
|
||||
return fmt.Errorf("%w", ErrPortForwarderNotSet)
|
||||
return errors.New("port forwarder not set")
|
||||
case s.Interface == "":
|
||||
return fmt.Errorf("%w", ErrInterfaceNotSet)
|
||||
return errors.New("interface not set")
|
||||
case s.PortsCount == 0:
|
||||
return fmt.Errorf("%w", ErrPortsCountZero)
|
||||
return errors.New("ports count cannot be zero")
|
||||
}
|
||||
|
||||
switch s.PortForwarder.Name() {
|
||||
case providers.PrivateInternetAccess:
|
||||
switch {
|
||||
case s.ServerName == "":
|
||||
return fmt.Errorf("%w", ErrServerNameNotSet)
|
||||
return errors.New("server name not set")
|
||||
case s.Username == "":
|
||||
return fmt.Errorf("%w", ErrUsernameNotSet)
|
||||
return errors.New("username not set")
|
||||
case s.Password == "":
|
||||
return fmt.Errorf("%w", ErrPasswordNotSet)
|
||||
return errors.New("password not set")
|
||||
}
|
||||
case providers.Protonvpn:
|
||||
const maxPortsCount = 4
|
||||
if s.PortsCount > maxPortsCount {
|
||||
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, s.PortsCount, maxPortsCount)
|
||||
return fmt.Errorf("ports count too high: %d > %d", s.PortsCount, maxPortsCount)
|
||||
}
|
||||
default:
|
||||
const maxPortsCount = 1
|
||||
if s.PortsCount > maxPortsCount {
|
||||
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, s.PortsCount, maxPortsCount)
|
||||
return fmt.Errorf("ports count too high: %d > %d", s.PortsCount, maxPortsCount)
|
||||
}
|
||||
}
|
||||
|
||||
if !slices.Equal(s.ListeningPorts, []uint16{0}) {
|
||||
switch {
|
||||
case len(s.ListeningPorts) != int(s.PortsCount):
|
||||
return fmt.Errorf("%w: %d != %d", ErrListeningPortsLen, len(s.ListeningPorts), s.PortsCount)
|
||||
return fmt.Errorf("listening ports length must be equal to ports count: %d != %d",
|
||||
len(s.ListeningPorts), s.PortsCount)
|
||||
case slices.Contains(s.ListeningPorts, 0):
|
||||
return fmt.Errorf("%w: in %v", ErrListeningPortZero, s.ListeningPorts)
|
||||
return fmt.Errorf("listening port cannot be 0: in %v", s.ListeningPorts)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ func Test_Server_BadSettings(t *testing.T) {
|
||||
|
||||
server, err := New(settings)
|
||||
assert.Nil(t, server)
|
||||
assert.ErrorIs(t, err, ErrBlockProfileRateNegative)
|
||||
assert.ErrorContains(t, err, "block profile rate cannot be negative")
|
||||
const expectedErrMessage = "pprof settings failed validation: block profile rate cannot be negative"
|
||||
assert.EqualError(t, err, expectedErrMessage)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package pprof
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/httpserver"
|
||||
@@ -51,18 +50,13 @@ func (s *Settings) OverrideWith(other Settings) {
|
||||
s.HTTPServer.OverrideWith(other.HTTPServer)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrBlockProfileRateNegative = errors.New("block profile rate cannot be negative")
|
||||
ErrMutexProfileRateNegative = errors.New("mutex profile rate cannot be negative")
|
||||
)
|
||||
|
||||
func (s Settings) Validate() (err error) {
|
||||
if *s.BlockProfileRate < 0 {
|
||||
return fmt.Errorf("%w", ErrBlockProfileRateNegative)
|
||||
return errors.New("block profile rate cannot be negative")
|
||||
}
|
||||
|
||||
if *s.MutexProfileRate < 0 {
|
||||
return fmt.Errorf("%w", ErrMutexProfileRateNegative)
|
||||
return errors.New("mutex profile rate cannot be negative")
|
||||
}
|
||||
|
||||
return s.HTTPServer.Validate()
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/httpserver"
|
||||
"github.com/qdm12/gosettings/validate"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -195,7 +194,6 @@ func Test_Settings_Validate(t *testing.T) {
|
||||
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"negative block profile rate": {
|
||||
@@ -203,16 +201,14 @@ func Test_Settings_Validate(t *testing.T) {
|
||||
BlockProfileRate: intPtr(-1),
|
||||
MutexProfileRate: intPtr(0),
|
||||
},
|
||||
errWrapped: ErrBlockProfileRateNegative,
|
||||
errMessage: ErrBlockProfileRateNegative.Error(),
|
||||
errMessage: "block profile rate cannot be negative",
|
||||
},
|
||||
"negative mutex profile rate": {
|
||||
settings: Settings{
|
||||
BlockProfileRate: intPtr(0),
|
||||
MutexProfileRate: intPtr(-1),
|
||||
},
|
||||
errWrapped: ErrMutexProfileRateNegative,
|
||||
errMessage: ErrMutexProfileRateNegative.Error(),
|
||||
errMessage: "mutex profile rate cannot be negative",
|
||||
},
|
||||
"http server validation error": {
|
||||
settings: Settings{
|
||||
@@ -222,7 +218,6 @@ func Test_Settings_Validate(t *testing.T) {
|
||||
Address: ":x",
|
||||
},
|
||||
},
|
||||
errWrapped: validate.ErrPortNotAnInteger,
|
||||
errMessage: "port value is not an integer: x",
|
||||
},
|
||||
"valid settings": {
|
||||
@@ -247,9 +242,10 @@ func Test_Settings_Validate(t *testing.T) {
|
||||
|
||||
err := testCase.settings.Validate()
|
||||
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,8 +6,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
)
|
||||
|
||||
type apiData struct {
|
||||
@@ -48,8 +46,8 @@ func fetchAPI(ctx context.Context, client *http.Client) (
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
_ = response.Body.Close()
|
||||
return data, fmt.Errorf("%w: %d %s",
|
||||
common.ErrHTTPStatusCodeNotOK, response.StatusCode, response.Status)
|
||||
return data, fmt.Errorf("HTTP status code not OK: %d %s",
|
||||
response.StatusCode, response.Status)
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(response.Body)
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
package common
|
||||
|
||||
import "errors"
|
||||
|
||||
var ErrPortForwardNotSupported = errors.New("port forwarding not supported")
|
||||
@@ -10,10 +10,8 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotEnoughServers = errors.New("not enough servers found")
|
||||
ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
|
||||
ErrIPFetcherUnsupported = errors.New("IP fetcher not supported")
|
||||
ErrCredentialsMissing = errors.New("credentials missing")
|
||||
ErrNotEnoughServers = errors.New("not enough servers found")
|
||||
ErrCredentialsMissing = errors.New("credentials are missing")
|
||||
)
|
||||
|
||||
type Fetcher interface {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package custom
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
@@ -10,8 +9,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
var ErrVPNTypeNotSupported = errors.New("VPN type not supported for custom provider")
|
||||
|
||||
// GetConnection gets the connection from the OpenVPN configuration file.
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection, _ bool) (
|
||||
connection models.Connection, err error,
|
||||
@@ -22,7 +19,7 @@ func (p *Provider) GetConnection(selection settings.ServerSelection, _ bool) (
|
||||
case vpn.Wireguard, vpn.AmneziaWg:
|
||||
return getWireguardConnection(selection), nil
|
||||
default:
|
||||
return connection, fmt.Errorf("%w: %s", ErrVPNTypeNotSupported, selection.VPN)
|
||||
return connection, fmt.Errorf("VPN type not supported for custom provider: %s", selection.VPN)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package custom
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -12,8 +11,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
var ErrExtractData = errors.New("failed extracting information from custom configuration file")
|
||||
|
||||
func (p *Provider) OpenVPNConfig(connection models.Connection,
|
||||
settings settings.OpenVPN, ipv6Supported bool,
|
||||
) (lines []string) {
|
||||
|
||||
@@ -3,13 +3,10 @@ package updater
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var errHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
|
||||
|
||||
type apiData struct {
|
||||
Servers []apiServer `json:"servers"`
|
||||
}
|
||||
@@ -42,8 +39,8 @@ func fetchAPI(ctx context.Context, client *http.Client) (
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
_ = response.Body.Close()
|
||||
return data, fmt.Errorf("%w: %d %s",
|
||||
errHTTPStatusCodeNotOK, response.StatusCode, response.Status)
|
||||
return data, fmt.Errorf("HTTP status code not OK: %d %s",
|
||||
response.StatusCode, response.Status)
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(response.Body)
|
||||
|
||||
@@ -21,21 +21,17 @@ func Test_Provider_GetConnection(t *testing.T) {
|
||||
|
||||
const provider = providers.Expressvpn
|
||||
|
||||
errTest := errors.New("test error")
|
||||
|
||||
testCases := map[string]struct {
|
||||
filteredServers []models.Server
|
||||
storageErr error
|
||||
selection settings.ServerSelection
|
||||
ipv6Supported bool
|
||||
connection models.Connection
|
||||
errWrapped error
|
||||
errMessage string
|
||||
panicMessage string
|
||||
}{
|
||||
"error": {
|
||||
storageErr: errTest,
|
||||
errWrapped: errTest,
|
||||
storageErr: errors.New("test error"),
|
||||
errMessage: "filtering servers: test error",
|
||||
},
|
||||
"default OpenVPN TCP port": {
|
||||
@@ -100,9 +96,10 @@ func Test_Provider_GetConnection(t *testing.T) {
|
||||
|
||||
connection, err := provider.GetConnection(testCase.selection, testCase.ipv6Supported)
|
||||
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, testCase.connection, connection)
|
||||
|
||||
@@ -3,14 +3,11 @@ package updater
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
)
|
||||
|
||||
type apiServer struct {
|
||||
@@ -19,8 +16,6 @@ type apiServer struct {
|
||||
hostname string
|
||||
}
|
||||
|
||||
var ErrDataMalformed = errors.New("data is malformed")
|
||||
|
||||
const apiURL = "https://support.fastestvpn.com/wp-admin/admin-ajax.php"
|
||||
|
||||
// The API URL and requests are shamelessly taken from network operations
|
||||
@@ -49,7 +44,7 @@ func fetchAPIServers(ctx context.Context, client *http.Client, protocol string)
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
_ = response.Body.Close()
|
||||
return nil, fmt.Errorf("%w: %d", common.ErrHTTPStatusCodeNotOK, response.StatusCode)
|
||||
return nil, fmt.Errorf("HTTP status code not OK: %d", response.StatusCode)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(response.Body)
|
||||
@@ -79,8 +74,8 @@ func fetchAPIServers(ctx context.Context, client *http.Client, protocol string)
|
||||
for i := range numberOfTDBlocks {
|
||||
tdBlock := getNextTDBlock(trBlock)
|
||||
if tdBlock == nil {
|
||||
return nil, fmt.Errorf("%w: expected 3 <td> blocks in <tr> block %q",
|
||||
ErrDataMalformed, string(trBlock))
|
||||
return nil, fmt.Errorf("data is malformed: expected 3 <td> blocks in <tr> block %q",
|
||||
string(trBlock))
|
||||
}
|
||||
trBlock = trBlock[len(tdBlock):]
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user