chore: do not use sentinel errors when unneeded

- main reason being it's a burden to always define sentinel errors at global scope, wrap them with `%w` instead of using a string directly
- only use sentinel errors when it has to be checked using `errors.Is`
- replace all usage of these sentinel errors in `fmt.Errorf` with direct strings that were in the sentinel error
- exclude the sentinel error definition requirement from .golangci.yml
- update unit tests to use ContainersError instead of ErrorIs so it stays as a "not a change detector test" without requiring a sentinel error
This commit is contained in:
Quentin McGaw
2026-05-02 00:50:16 +00:00
parent 9b6f048fe8
commit 4a78989d9d
172 changed files with 666 additions and 1433 deletions
+2 -5
View File
@@ -1,7 +1,6 @@
package alpine
import (
"errors"
"fmt"
"io/fs"
"os"
@@ -9,8 +8,6 @@ import (
"strconv"
)
var ErrUserAlreadyExists = errors.New("user already exists")
// CreateUser creates a user in Alpine with the given UID.
func (a *Alpine) CreateUser(username string, uid int) (createdUsername string, err error) {
UIDStr := strconv.Itoa(uid)
@@ -34,8 +31,8 @@ func (a *Alpine) CreateUser(username string, uid int) (createdUsername string, e
}
if u != nil {
return "", fmt.Errorf("%w: with name %s for ID %s instead of %d",
ErrUserAlreadyExists, username, u.Uid, uid)
return "", fmt.Errorf("user already exists: with name %s for ID %s instead of %d",
username, u.Uid, uid)
}
const permission = fs.FileMode(0o644)
+2 -1
View File
@@ -1,6 +1,7 @@
package amneziawg
import (
"errors"
"net/netip"
"testing"
@@ -28,7 +29,7 @@ func Test_New(t *testing.T) {
PrivateKey: "",
},
},
err: wireguard.ErrPrivateKeyMissing,
err: errors.New("private key is missing"),
},
"minimal valid settings": {
settings: Settings{
+2 -8
View File
@@ -13,11 +13,6 @@ import (
"github.com/qdm12/gluetun/internal/wireguard"
)
var (
errTunNameMismatch = errors.New("TUN device name is mismatching")
errDeviceWaited = errors.New("device waited for")
)
// Run runs the amneziawg interface and waits until the context is done, then it cleans up the
// interface and returns any error that occurred during setup or waiting. It sends an error to
// waitError if any error occurs during setup or waiting, otherwise it sends nil when the context
@@ -52,8 +47,7 @@ func setupUserspace(ctx context.Context,
if err != nil {
return 0, nil, fmt.Errorf("getting created TUN device name: %w", err)
} else if tunName != interfaceName {
return 0, nil, fmt.Errorf("%w: expected %q and got %q",
errTunNameMismatch, interfaceName, tunName)
return 0, nil, fmt.Errorf("TUN device name is mismatching: expected %q and got %q", interfaceName, tunName)
}
link, err := netLinker.LinkByName(interfaceName)
@@ -106,7 +100,7 @@ func setupUserspace(ctx context.Context,
case err = <-uapiAcceptErrorCh:
close(uapiAcceptErrorCh)
case <-device.Wait():
err = errDeviceWaited
err = errors.New("device waited for")
}
cleanups.Cleanup(logger)
+2 -8
View File
@@ -16,11 +16,6 @@ import (
"golang.org/x/text/language"
)
var (
ErrProviderUnspecified = errors.New("VPN provider to format was not specified")
ErrMultipleProvidersToFormat = errors.New("more than one VPN provider to format were specified")
)
func addProviderFlag(flagSet *flag.FlagSet, providerToFormat map[string]*bool,
provider string, titleCaser cases.Caser,
) {
@@ -65,11 +60,10 @@ func (c *CLI) FormatServers(args []string) error {
}
switch len(providers) {
case 0:
return fmt.Errorf("%w", ErrProviderUnspecified)
return errors.New("VPN provider to format was not specified")
case 1:
default:
return fmt.Errorf("%w: %d specified: %s",
ErrMultipleProvidersToFormat, len(providers),
return fmt.Errorf("more than one VPN provider to format were specified: %d specified: %s", len(providers),
strings.Join(providers, ", "))
}
+2 -9
View File
@@ -24,13 +24,6 @@ import (
"github.com/qdm12/gluetun/internal/updater/unzip"
)
var (
ErrModeUnspecified = errors.New("at least one of -enduser or -maintainer must be specified")
ErrNoProviderSpecified = errors.New("no provider was specified")
ErrUsernameMissing = errors.New("username is required for this provider")
ErrPasswordMissing = errors.New("password is required for this provider")
)
type UpdaterLogger interface {
Info(s string)
Warn(s string)
@@ -65,14 +58,14 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
}
if !endUserMode && !maintainerMode {
return fmt.Errorf("%w", ErrModeUnspecified)
return errors.New("at least one of -enduser or -maintainer must be specified")
}
if updateAll {
options.Providers = providers.All()
} else {
if csvProviders == "" {
return fmt.Errorf("%w", ErrNoProviderSpecified)
return errors.New("no provider was specified")
}
options.Providers = strings.Split(csvProviders, ",")
}
+5 -12
View File
@@ -8,13 +8,6 @@ import (
"unicode/utf8"
)
var (
errCommandEmpty = errors.New("command is empty")
errSingleQuoteUnterminated = errors.New("unterminated single-quoted string")
errDoubleQuoteUnterminated = errors.New("unterminated double-quoted string")
errEscapeUnterminated = errors.New("unterminated backslash-escape")
)
// split splits a command string into a slice of arguments.
// This is especially important for commands such as:
// /bin/sh -c "echo hello"
@@ -25,7 +18,7 @@ var (
// - expansion (brace, shell or pathname).
func split(command string) (words []string, err error) {
if command == "" {
return nil, fmt.Errorf("%w", errCommandEmpty)
return nil, errors.New("command is empty")
}
const bufferSize = 1024
@@ -42,7 +35,7 @@ func split(command string) (words []string, err error) {
case character == '\\':
// Look ahead to eventually skip an escaped newline
if command[startIndex+runeSize:] == "" {
return nil, fmt.Errorf("%w: %q", errEscapeUnterminated, command)
return nil, fmt.Errorf("unterminated backslash-escape: %q", command)
}
character, runeSize := utf8.DecodeRuneInString(command[startIndex+runeSize:])
if character == '\n' {
@@ -119,7 +112,7 @@ func handleDoubleQuoted(input string, startIndex int, buffer *bytes.Buffer) (
startIndex = cursor
}
}
return "", 0, fmt.Errorf("%w", errDoubleQuoteUnterminated)
return "", 0, errors.New("unterminated double-quoted string")
}
func handleSingleQuoted(input string, startIndex int, buffer *bytes.Buffer) (
@@ -127,7 +120,7 @@ func handleSingleQuoted(input string, startIndex int, buffer *bytes.Buffer) (
) {
closingQuoteIndex := strings.IndexRune(input[startIndex:], '\'')
if closingQuoteIndex == -1 {
return "", 0, fmt.Errorf("%w", errSingleQuoteUnterminated)
return "", 0, errors.New("unterminated single-quoted string")
}
buffer.WriteString(input[startIndex : startIndex+closingQuoteIndex])
const singleQuoteRuneLength = 1
@@ -139,7 +132,7 @@ func handleEscaped(input string, startIndex int, buffer *bytes.Buffer) (
word string, newStartIndex int, err error,
) {
if input[startIndex:] == "" {
return "", 0, fmt.Errorf("%w", errEscapeUnterminated)
return "", 0, errors.New("unterminated backslash-escape")
}
character, runeLength := utf8.DecodeRuneInString(input[startIndex:])
if character != '\n' { // backslash-escaped newline is ignored
+4 -9
View File
@@ -12,12 +12,10 @@ func Test_split(t *testing.T) {
testCases := map[string]struct {
command string
words []string
errWrapped error
errMessage string
}{
"empty": {
command: "",
errWrapped: errCommandEmpty,
errMessage: "command is empty",
},
"concrete_sh_command": {
@@ -74,22 +72,18 @@ func Test_split(t *testing.T) {
},
"unterminated_single_quote": {
command: "'abc'\\''def",
errWrapped: errSingleQuoteUnterminated,
errMessage: `splitting word in "'abc'\\''def": unterminated single-quoted string`,
},
"unterminated_double_quote": {
command: "\"abc'def",
errWrapped: errDoubleQuoteUnterminated,
errMessage: `splitting word in "\"abc'def": unterminated double-quoted string`,
},
"unterminated_escape": {
command: "abc\\",
errWrapped: errEscapeUnterminated,
errMessage: `splitting word in "abc\\": unterminated backslash-escape`,
},
"unterminated_escape_only": {
command: " \\",
errWrapped: errEscapeUnterminated,
errMessage: `unterminated backslash-escape: " \\"`,
},
}
@@ -101,9 +95,10 @@ func Test_split(t *testing.T) {
words, err := split(testCase.command)
assert.Equal(t, testCase.words, words)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
if testCase.errMessage != "" {
assert.ErrorContains(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
+12 -21
View File
@@ -1,7 +1,6 @@
package settings
import (
"errors"
"fmt"
"strconv"
"strings"
@@ -177,14 +176,6 @@ func (a AmneziaWg) toLinesNode() (node *gotree.Node) {
return node
}
var (
ErrAmenziawgImplementationNotValid = errors.New("AmneziaWG implementation is not valid")
ErrJunkPacketBounds = errors.New("junk packet minimum must be lower than or equal to maximum")
ErrJunkPacketMinMaxNotSet = errors.New("junk packet min and max must be set when junk packet count is set")
ErrJunkPacketCountNotSet = errors.New("junk packet count must be set when junk packet min or max is set")
ErrHeaderRangeMalformed = errors.New("header range is malformed")
)
func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
const amneziaWG = true
err := a.Wireguard.validate(vpnProvider, ipv6Supported, amneziaWG)
@@ -194,16 +185,16 @@ func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
if *a.JunkPacketCount == 0 {
if *a.JunkPacketMin != 0 || *a.JunkPacketMax != 0 {
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
ErrJunkPacketCountNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
return fmt.Errorf("junk packet count must be set when junk packet min or max is set: "+
"jc=%d and jmin=%d and jmax=%d", a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
}
} else {
if *a.JunkPacketMin == 0 || *a.JunkPacketMax == 0 {
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
ErrJunkPacketMinMaxNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
return fmt.Errorf("junk packet min and max must be set when junk packet count is set: "+
"jc=%d and jmin=%d and jmax=%d", a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
} else if *a.JunkPacketMin > *a.JunkPacketMax {
return fmt.Errorf("%w: jmin=%d and jmax=%d",
ErrJunkPacketBounds, *a.JunkPacketMin, *a.JunkPacketMax)
return fmt.Errorf("junk packet minimum must be lower than or equal to maximum: "+
"jmin=%d and jmax=%d", *a.JunkPacketMin, *a.JunkPacketMax)
}
}
@@ -222,20 +213,20 @@ func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
case 1:
_, err := strconv.Atoi(fields[0])
if err != nil {
return fmt.Errorf("%w: %s value %s is not a number",
ErrHeaderRangeMalformed, name, headerRange)
return fmt.Errorf("header range is malformed: "+
"%s value %s is not a number", name, headerRange)
}
case 2: //nolint:mnd
for _, field := range fields {
_, err := strconv.Atoi(field)
if err != nil {
return fmt.Errorf("%w: %s value %s is not a valid range",
ErrHeaderRangeMalformed, name, headerRange)
return fmt.Errorf("header range is malformed: "+
"%s value %s is not a valid range", name, headerRange)
}
}
default:
return fmt.Errorf("%w: %s value %s must be in the form n or n-m",
ErrHeaderRangeMalformed, name, headerRange)
return fmt.Errorf("header range is malformed: "+
"%s value %s must be in the form n or n-m", name, headerRange)
}
}
+7 -13
View File
@@ -1,7 +1,6 @@
package settings
import (
"errors"
"fmt"
"net/netip"
"time"
@@ -48,22 +47,15 @@ type DNS struct {
UpstreamPlainAddresses []netip.AddrPort
}
var (
ErrDNSUpstreamTypeNotValid = errors.New("DNS upstream type is not valid")
ErrDNSUpdatePeriodTooShort = errors.New("update period is too short")
ErrDNSUpstreamPlainNoIPv6 = errors.New("upstream plain addresses do not contain any IPv6 address")
ErrDNSUpstreamPlainNoIPv4 = errors.New("upstream plain addresses do not contain any IPv4 address")
)
func (d DNS) validate() (err error) {
if !helpers.IsOneOf(d.UpstreamType, DNSUpstreamTypeDot, DNSUpstreamTypeDoh, DNSUpstreamTypePlain) {
return fmt.Errorf("%w: %s", ErrDNSUpstreamTypeNotValid, d.UpstreamType)
return fmt.Errorf("DNS upstream type is not valid: %s", d.UpstreamType)
}
const minUpdatePeriod = 30 * time.Second
if *d.UpdatePeriod != 0 && *d.UpdatePeriod < minUpdatePeriod {
return fmt.Errorf("%w: %s must be bigger than %s",
ErrDNSUpdatePeriodTooShort, *d.UpdatePeriod, minUpdatePeriod)
return fmt.Errorf("update period is too short: %s must be bigger than %s",
*d.UpdatePeriod, minUpdatePeriod)
}
if d.UpstreamType == DNSUpstreamTypePlain {
@@ -81,9 +73,11 @@ func (d DNS) validate() (err error) {
}
switch {
case *d.IPv6 && !selectedHasPlainIPv6:
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv6, len(d.UpstreamPlainAddresses))
return fmt.Errorf("upstream plain addresses do not contain any IPv6 address: "+
"in %d addresses", len(d.UpstreamPlainAddresses))
case !*d.IPv6 && !selectedHasPlainIPv4:
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv4, len(d.UpstreamPlainAddresses))
return fmt.Errorf("upstream plain addresses do not contain any IPv4 address: "+
"in %d addresses", len(d.UpstreamPlainAddresses))
}
}
// Note: all DNS built in providers have both IPv4 and IPv6 addresses for all modes
@@ -1,7 +1,6 @@
package settings
import (
"errors"
"fmt"
"net/http"
"net/netip"
@@ -37,22 +36,16 @@ func (b *DNSBlacklist) setDefaults() {
var hostRegex = regexp.MustCompile(`^([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9_])(\.([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9]))*$`) //nolint:lll
var (
ErrAllowedHostNotValid = errors.New("allowed host is not valid")
ErrBlockedHostNotValid = errors.New("blocked host is not valid")
ErrRebindingProtectionExemptHostNotValid = errors.New("rebinding protection exempt host is not valid")
)
func (b DNSBlacklist) validate() (err error) {
for _, host := range b.AllowedHosts {
if !hostRegex.MatchString(host) {
return fmt.Errorf("%w: %s", ErrAllowedHostNotValid, host)
return fmt.Errorf("allowed host is not valid: %s", host)
}
}
for _, host := range b.AddBlockedHosts {
if !hostRegex.MatchString(host) {
return fmt.Errorf("%w: %s", ErrBlockedHostNotValid, host)
return fmt.Errorf("blocked host is not valid: %s", host)
}
}
@@ -61,7 +54,7 @@ func (b DNSBlacklist) validate() (err error) {
host = host[2:]
}
if !hostRegex.MatchString(host) {
return fmt.Errorf("%w: %s", ErrRebindingProtectionExemptHostNotValid, host)
return fmt.Errorf("rebinding protection exempt host is not valid: %s", host)
}
}
@@ -209,8 +202,6 @@ func readDNSBlockedIPs(r *reader.Reader) (ips []netip.Addr,
return ips, ipPrefixes, nil
}
var ErrPrivateAddressNotValid = errors.New("private address is not a valid IP or CIDR range")
func readDNSPrivateAddresses(r *reader.Reader) (ips []netip.Addr,
ipPrefixes []netip.Prefix, err error,
) {
@@ -236,8 +227,9 @@ func readDNSPrivateAddresses(r *reader.Reader) (ips []netip.Addr,
}
return nil, nil, fmt.Errorf(
"environment variable DOT_PRIVATE_ADDRESS: %w: %s",
ErrPrivateAddressNotValid, privateAddress)
"environment variable DOT_PRIVATE_ADDRESS: "+
"private address is not a valid IP or CIDR range: %s",
privateAddress)
}
return ips, ipPrefixes, nil
-58
View File
@@ -1,58 +0,0 @@
package settings
import "errors"
var (
ErrValueUnknown = errors.New("value is unknown")
ErrCityNotValid = errors.New("the city specified is not valid")
ErrControlServerPrivilegedPort = errors.New("cannot use privileged port without running as root")
ErrCategoryNotValid = errors.New("the category specified is not valid")
ErrCountryNotValid = errors.New("the country specified is not valid")
ErrFilepathMissing = errors.New("filepath is missing")
ErrFirewallZeroPort = errors.New("cannot have a zero port")
ErrFirewallPublicOutboundSubnet = errors.New("outbound subnet has an unspecified address")
ErrHostnameNotValid = errors.New("the hostname specified is not valid")
ErrISPNotValid = errors.New("the ISP specified is not valid")
ErrMinRatioNotValid = errors.New("minimum ratio is not valid")
ErrMissingValue = errors.New("missing value")
ErrNameNotValid = errors.New("the server name specified is not valid")
ErrOpenVPNClientKeyMissing = errors.New("client key is missing")
ErrOpenVPNCustomPortNotAllowed = errors.New("custom endpoint port is not allowed")
ErrOpenVPNEncryptionPresetNotValid = errors.New("PIA encryption preset is not valid")
ErrOpenVPNInterfaceNotValid = errors.New("interface name is not valid")
ErrOpenVPNKeyPassphraseIsEmpty = errors.New("key passphrase is empty")
ErrOpenVPNMSSFixIsTooHigh = errors.New("mssfix option value is too high")
ErrOpenVPNPasswordIsEmpty = errors.New("password is empty")
ErrOpenVPNTCPNotSupported = errors.New("TCP protocol is not supported")
ErrOpenVPNUserIsEmpty = errors.New("user is empty")
ErrOpenVPNVerbosityIsOutOfBounds = errors.New("verbosity value is out of bounds")
ErrOpenVPNVersionIsNotValid = errors.New("version is not valid")
ErrPortForwardingEnabled = errors.New("port forwarding cannot be enabled")
ErrPortForwardingUserEmpty = errors.New("port forwarding username is empty")
ErrPortForwardingPasswordEmpty = errors.New("port forwarding password is empty")
ErrRegionNotValid = errors.New("the region specified is not valid")
ErrServerAddressNotValid = errors.New("server listening address is not valid")
ErrSystemPGIDNotValid = errors.New("process group id is not valid")
ErrSystemPUIDNotValid = errors.New("process user id is not valid")
ErrSystemTimezoneNotValid = errors.New("timezone is not valid")
ErrUpdaterPeriodTooSmall = errors.New("VPN server data updater period is too small")
ErrUpdaterProtonPasswordMissing = errors.New("proton password is missing")
ErrUpdaterProtonEmailMissing = errors.New("proton email is missing")
ErrVPNProviderNameNotValid = errors.New("VPN provider name is not valid")
ErrVPNTypeNotValid = errors.New("VPN type is not valid")
ErrWireguardAllowedIPNotSet = errors.New("allowed IP is not set")
ErrWireguardAllowedIPsNotSet = errors.New("allowed IPs is not set")
ErrWireguardEndpointIPNotSet = errors.New("endpoint IP is not set")
ErrWireguardEndpointPortNotAllowed = errors.New("endpoint port is not allowed")
ErrWireguardEndpointPortNotSet = errors.New("endpoint port is not set")
ErrWireguardEndpointPortSet = errors.New("endpoint port is set")
ErrWireguardInterfaceAddressNotSet = errors.New("interface address is not set")
ErrWireguardInterfaceAddressIPv6 = errors.New("interface address is IPv6 but IPv6 is not supported")
ErrWireguardInterfaceNotValid = errors.New("interface name is not valid")
ErrWireguardPreSharedKeyNotSet = errors.New("pre-shared key is not set")
ErrWireguardPrivateKeyNotSet = errors.New("private key is not set")
ErrWireguardPublicKeyNotSet = errors.New("public key is not set")
ErrWireguardPublicKeyNotValid = errors.New("public key is not valid")
ErrWireguardKeepAliveNegative = errors.New("persistent keep alive interval is negative")
ErrWireguardImplementationNotValid = errors.New("implementation is not valid")
)
+4 -3
View File
@@ -1,6 +1,7 @@
package settings
import (
"errors"
"fmt"
"net/netip"
@@ -20,16 +21,16 @@ type Firewall struct {
func (f Firewall) validate() (err error) {
if hasZeroPort(f.VPNInputPorts) {
return fmt.Errorf("VPN input ports: %w", ErrFirewallZeroPort)
return errors.New("VPN input ports: cannot have a zero port")
}
if hasZeroPort(f.InputPorts) {
return fmt.Errorf("input ports: %w", ErrFirewallZeroPort)
return errors.New("input ports: cannot have a zero port")
}
for _, subnet := range f.OutboundSubnets {
if subnet.Addr().IsUnspecified() {
return fmt.Errorf("%w: %s", ErrFirewallPublicOutboundSubnet, subnet)
return fmt.Errorf("outbound subnet has an unspecified address: %s", subnet)
}
}
@@ -13,25 +13,21 @@ func Test_Firewall_validate(t *testing.T) {
testCases := map[string]struct {
firewall Firewall
errWrapped error
errMessage string
}{
"empty": {
errWrapped: log.ErrLevelNotRecognized,
errMessage: "iptables settings: log level: level is not recognized: ",
},
"zero_vpn_input_port": {
firewall: Firewall{
VPNInputPorts: []uint16{0},
},
errWrapped: ErrFirewallZeroPort,
errMessage: "VPN input ports: cannot have a zero port",
},
"zero_input_port": {
firewall: Firewall{
InputPorts: []uint16{0},
},
errWrapped: ErrFirewallZeroPort,
errMessage: "input ports: cannot have a zero port",
},
"unspecified_outbound_subnet": {
@@ -40,7 +36,6 @@ func Test_Firewall_validate(t *testing.T) {
netip.MustParsePrefix("0.0.0.0/0"),
},
},
errWrapped: ErrFirewallPublicOutboundSubnet,
errMessage: "outbound subnet has an unspecified address: 0.0.0.0/0",
},
"public_outbound_subnet": {
@@ -70,9 +65,10 @@ func Test_Firewall_validate(t *testing.T) {
err := testCase.firewall.validate()
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
+4 -10
View File
@@ -38,12 +38,6 @@ type Health struct {
RestartVPN *bool
}
var (
ErrICMPTargetIPNotValid = errors.New("ICMP target IP address is not valid")
ErrICMPTargetIPsNotCompatible = errors.New("ICMP target IP addresses are not compatible")
ErrSmallCheckTypeNotValid = errors.New("small check type is not valid")
)
func (h Health) Validate() (err error) {
err = validate.ListeningAddress(h.ServerAddress, os.Getuid())
if err != nil {
@@ -53,16 +47,16 @@ func (h Health) Validate() (err error) {
for _, ip := range h.ICMPTargetIPs {
switch {
case !ip.IsValid():
return fmt.Errorf("%w: %s", ErrICMPTargetIPNotValid, ip)
return fmt.Errorf("ICMP target IP address is not valid: %s", ip)
case ip.IsUnspecified() && len(h.ICMPTargetIPs) > 1:
return fmt.Errorf("%w: only a single IP address must be set if it is to be unspecified",
ErrICMPTargetIPsNotCompatible)
return errors.New("ICMP target IP addresses are not compatible: " +
"only a single IP address must be set if it is to be unspecified")
}
}
err = validate.IsOneOf(h.SmallCheckType, "icmp", "dns")
if err != nil {
return fmt.Errorf("%w: %s", ErrSmallCheckTypeNotValid, err)
return fmt.Errorf("small check type is not valid: %w", err)
}
return nil
+2 -3
View File
@@ -48,7 +48,7 @@ func (h HTTPProxy) validate() (err error) {
// Do not validate user and password
err = validate.ListeningAddress(h.ListeningAddress, os.Getuid())
if err != nil {
return fmt.Errorf("%w: %s", ErrServerAddressNotValid, h.ListeningAddress)
return fmt.Errorf("server listening address is not valid: %w", err)
}
return nil
@@ -176,7 +176,6 @@ func readHTTProxyLog(r *reader.Reader) (enabled *bool, err error) {
case "disabled", "no", "off":
return ptrTo(false), nil
default:
return nil, fmt.Errorf("HTTP retro-compatible proxy log setting: %w: %s",
ErrValueUnknown, value)
return nil, fmt.Errorf("HTTP retro-compatible proxy log setting: value is unknown: %s", value)
}
}
+12 -14
View File
@@ -2,6 +2,7 @@ package settings
import (
"encoding/base64"
"errors"
"fmt"
"regexp"
"strings"
@@ -92,7 +93,7 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
// Validate version
validVersions := []string{openvpn.Openvpn25, openvpn.Openvpn26}
if err = validate.IsOneOf(o.Version, validVersions...); err != nil {
return fmt.Errorf("%w: %w", ErrOpenVPNVersionIsNotValid, err)
return fmt.Errorf("version is not valid: %w", err)
}
isCustom := vpnProvider == providers.Custom
@@ -101,14 +102,14 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
vpnProvider != providers.VPNSecure
if isUserRequired && *o.User == "" {
return fmt.Errorf("%w", ErrOpenVPNUserIsEmpty)
return errors.New("user is empty")
}
passwordRequired := isUserRequired &&
(vpnProvider != providers.Ivpn || !ivpnAccountID.MatchString(*o.User))
if passwordRequired && *o.Password == "" {
return fmt.Errorf("%w", ErrOpenVPNPasswordIsEmpty)
return errors.New("password is empty")
}
err = validateOpenVPNConfigFilepath(isCustom, *o.ConfFile)
@@ -132,23 +133,20 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
}
if *o.EncryptedKey != "" && *o.KeyPassphrase == "" {
return fmt.Errorf("%w", ErrOpenVPNKeyPassphraseIsEmpty)
return errors.New("key passphrase is empty")
}
const maxMSSFix = 10000
if *o.MSSFix > maxMSSFix {
return fmt.Errorf("%w: %d is over the maximum value of %d",
ErrOpenVPNMSSFixIsTooHigh, *o.MSSFix, maxMSSFix)
return fmt.Errorf("mssfix option value is too high: %d is over the maximum value of %d", *o.MSSFix, maxMSSFix)
}
if !regexpInterfaceName.MatchString(o.Interface) {
return fmt.Errorf("%w: '%s' does not match regex '%s'",
ErrOpenVPNInterfaceNotValid, o.Interface, regexpInterfaceName)
return fmt.Errorf("interface name is not valid: '%s' does not match regex '%s'", o.Interface, regexpInterfaceName)
}
if *o.Verbosity < 0 || *o.Verbosity > 6 {
return fmt.Errorf("%w: %d can only be between 0 and 5",
ErrOpenVPNVerbosityIsOutOfBounds, o.Verbosity)
return fmt.Errorf("verbosity value is out of bounds: %d can only be between 0 and 5", o.Verbosity)
}
return nil
@@ -162,7 +160,7 @@ func validateOpenVPNConfigFilepath(isCustom bool,
}
if confFile == "" {
return fmt.Errorf("%w", ErrFilepathMissing)
return errors.New("filepath is missing")
}
err = validate.FileExists(confFile)
@@ -189,7 +187,7 @@ func validateOpenVPNClientCertificate(vpnProvider,
providers.VPNSecure,
providers.VPNUnlimited:
if clientCert == "" {
return fmt.Errorf("%w", ErrMissingValue)
return errors.New("missing value")
}
}
@@ -211,7 +209,7 @@ func validateOpenVPNClientKey(vpnProvider, clientKey string) (err error) {
providers.Cyberghost,
providers.VPNUnlimited:
if clientKey == "" {
return fmt.Errorf("%w", ErrMissingValue)
return errors.New("missing value")
}
}
@@ -230,7 +228,7 @@ func validateOpenVPNEncryptedKey(vpnProvider,
encryptedPrivateKey string,
) (err error) {
if vpnProvider == providers.VPNSecure && encryptedPrivateKey == "" {
return fmt.Errorf("%w", ErrMissingValue)
return errors.New("missing value")
}
if encryptedPrivateKey == "" {
@@ -62,8 +62,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
providers.Perfectprivacy,
providers.Vyprvpn,
) {
return fmt.Errorf("%w: for VPN service provider %s",
ErrOpenVPNTCPNotSupported, vpnProvider)
return fmt.Errorf("TCP protocol is not supported: for VPN service provider %s", vpnProvider)
}
// Validate CustomPort
@@ -78,8 +77,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
providers.Nordvpn, providers.Purevpn,
providers.Surfshark, providers.VPNSecure,
providers.VPNUnlimited, providers.Vyprvpn:
return fmt.Errorf("%w: for VPN service provider %s",
ErrOpenVPNCustomPortNotAllowed, vpnProvider)
return fmt.Errorf("custom endpoint port is not allowed: for VPN service provider %s", vpnProvider)
default:
var allowedTCP, allowedUDP []uint16
switch vpnProvider {
@@ -123,8 +121,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
}
err = validate.IsOneOf(*o.CustomPort, allowedPorts...)
if err != nil {
return fmt.Errorf("%w: for VPN service provider %s: %w",
ErrOpenVPNCustomPortNotAllowed, vpnProvider, err)
return fmt.Errorf("custom endpoint port is not allowed: for VPN service provider %s: %w", vpnProvider, err)
}
}
}
@@ -136,7 +133,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
presets.Strong,
}
if err = validate.IsOneOf(*o.PIAEncPreset, validEncryptionPresets...); err != nil {
return fmt.Errorf("%w: %w", ErrOpenVPNEncryptionPresetNotValid, err)
return fmt.Errorf("PIA encryption preset is not valid: %w", err)
}
}
+2 -8
View File
@@ -1,7 +1,6 @@
package settings
import (
"errors"
"fmt"
"net/netip"
@@ -24,21 +23,16 @@ type PMTUD struct {
TCPAddresses []netip.AddrPort `json:"tcp_addresses"`
}
var (
ErrPMTUDICMPAddressNotValid = errors.New("PMTUD ICMP address is not valid")
ErrPMTUDTCPAddressNotValid = errors.New("PMTUD TCP address is not valid")
)
// Validate validates PMTUD settings.
func (p PMTUD) validate() (err error) {
for i, addr := range p.ICMPAddresses {
if !addr.IsValid() {
return fmt.Errorf("%w: at index %d", ErrPMTUDICMPAddressNotValid, i)
return fmt.Errorf("PMTUD ICMP address is not valid: at index %d", i)
}
}
for i, addr := range p.TCPAddresses {
if !addr.IsValid() {
return fmt.Errorf("%w: at index %d", ErrPMTUDTCPAddressNotValid, i)
return fmt.Errorf("PMTUD TCP address is not valid: at index %d", i)
}
}
return nil
+9 -14
View File
@@ -55,12 +55,6 @@ type PortForwarding struct {
Password string `json:"password"`
}
var (
ErrPortsCountTooHigh = errors.New("ports count too high")
ErrListeningPortsLen = errors.New("listening ports length must be equal to ports count")
ErrListeningPortZero = errors.New("listening port cannot be 0")
)
func (p PortForwarding) Validate(vpnProvider string) (err error) {
if !*p.Enabled {
return nil
@@ -78,7 +72,7 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) {
providers.Protonvpn,
}
if err = validate.IsOneOf(providerSelected, validProviders...); err != nil {
return fmt.Errorf("%w: %w", ErrPortForwardingEnabled, err)
return fmt.Errorf("port forwarding cannot be enabled: %w", err)
}
// Validate Filepath
@@ -94,30 +88,31 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) {
const maxPortsCount = 1
switch {
case p.PortsCount > maxPortsCount:
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount)
return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
case p.Username == "":
return fmt.Errorf("%w", ErrPortForwardingUserEmpty)
return errors.New("port forwarding username is empty")
case p.Password == "":
return fmt.Errorf("%w", ErrPortForwardingPasswordEmpty)
return errors.New("port forwarding password is empty")
}
case providers.Protonvpn:
const maxPortsCount = 4
if p.PortsCount > maxPortsCount {
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount)
return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
}
default:
const maxPortsCount = 1
if p.PortsCount > maxPortsCount {
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount)
return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
}
}
if !slices.Equal(p.ListeningPorts, []uint16{0}) {
switch {
case len(p.ListeningPorts) != int(p.PortsCount):
return fmt.Errorf("%w: %d != %d", ErrListeningPortsLen, len(p.ListeningPorts), p.PortsCount)
return fmt.Errorf("listening ports length must be equal to ports count: "+
"%d != %d", len(p.ListeningPorts), p.PortsCount)
case slices.Contains(p.ListeningPorts, 0):
return fmt.Errorf("%w: in %v", ErrListeningPortZero, p.ListeningPorts)
return fmt.Errorf("listening port cannot be 0: in %v", p.ListeningPorts)
}
}
+1 -1
View File
@@ -55,7 +55,7 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
}
}
if err = validate.IsOneOf(p.Name, validNames...); err != nil {
return fmt.Errorf("%w for %s: %w", ErrVPNProviderNameNotValid, vpnType, err)
return fmt.Errorf("VPN provider name is not valid for %s: %w", vpnType, err)
}
err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner)
@@ -15,7 +15,6 @@ func Test_PublicIP_read(t *testing.T) {
makeReader func(ctrl *gomock.Controller) *reader.Reader
makeWarner func(ctrl *gomock.Controller) Warner
settings PublicIP
errWrapped error
errMessage string
}{
"nothing_read": {
@@ -152,9 +151,10 @@ func Test_PublicIP_read(t *testing.T) {
err := settings.read(reader, warner)
assert.Equal(t, testCase.settings, settings)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
+1 -2
View File
@@ -46,8 +46,7 @@ func (c ControlServer) validate() (err error) {
uid := os.Getuid()
const maxPrivilegedPort = 1023
if uid != 0 && port != 0 && port <= maxPrivilegedPort {
return fmt.Errorf("%w: %d when running with user ID %d",
ErrControlServerPrivilegedPort, port, uid)
return fmt.Errorf("cannot use privileged port without running as root: %d when running with user ID %d", port, uid)
}
jsonDecoder := json.NewDecoder(bytes.NewBufferString(c.AuthDefaultRole))
@@ -71,25 +71,13 @@ type ServerSelection struct {
Wireguard WireguardSelection `json:"wireguard"`
}
var (
ErrOwnedOnlyNotSupported = errors.New("owned only filter is not supported")
ErrFreeOnlyNotSupported = errors.New("free only filter is not supported")
ErrPremiumOnlyNotSupported = errors.New("premium only filter is not supported")
ErrStreamOnlyNotSupported = errors.New("stream only filter is not supported")
ErrMultiHopOnlyNotSupported = errors.New("multi hop only filter is not supported")
ErrPortForwardOnlyNotSupported = errors.New("port forwarding only filter is not supported")
ErrFreePremiumBothSet = errors.New("free only and premium only filters are both set")
ErrSecureCoreOnlyNotSupported = errors.New("secure core only filter is not supported")
ErrTorOnlyNotSupported = errors.New("tor only filter is not supported")
)
func (ss *ServerSelection) validate(vpnServiceProvider string,
filterChoicesGetter FilterChoicesGetter, warner Warner,
) (err error) {
switch ss.VPN {
case vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard:
default:
return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN)
return fmt.Errorf("VPN type is not valid: %s", ss.VPN)
}
filterChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, filterChoicesGetter, warner)
@@ -150,7 +138,7 @@ func getLocationFilterChoices(vpnServiceProvider string,
// Only return error comparing with newer regions, we don't want to confuse the user
// with the retro regions in the error message.
err = atLeastOneIsOneOfCaseInsensitive(ss.Regions, filterChoices.Regions, warner)
return models.FilterChoices{}, fmt.Errorf("%w: %w", ErrRegionNotValid, err)
return models.FilterChoices{}, fmt.Errorf("the region specified is not valid: %w", err)
}
}
@@ -164,27 +152,27 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
) (err error) {
err = atLeastOneIsOneOfCaseInsensitive(settings.Countries, filterChoices.Countries, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrCountryNotValid, err)
return fmt.Errorf("the country specified is not valid: %w", err)
}
err = atLeastOneIsOneOfCaseInsensitive(settings.Regions, filterChoices.Regions, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrRegionNotValid, err)
return fmt.Errorf("the region specified is not valid: %w", err)
}
err = atLeastOneIsOneOfCaseInsensitive(settings.Cities, filterChoices.Cities, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrCityNotValid, err)
return fmt.Errorf("the city specified is not valid: %w", err)
}
err = atLeastOneIsOneOfCaseInsensitive(settings.ISPs, filterChoices.ISPs, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrISPNotValid, err)
return fmt.Errorf("the ISP specified is not valid: %w", err)
}
err = atLeastOneIsOneOfCaseInsensitive(settings.Hostnames, filterChoices.Hostnames, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrHostnameNotValid, err)
return fmt.Errorf("the hostname specified is not valid: %w", err)
}
if vpnServiceProvider == providers.Custom {
@@ -196,19 +184,19 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
// which requires a server name for TLS verification.
filterChoices.Names = settings.Names
default:
return fmt.Errorf("%w: %d names specified instead of "+
"0 or 1 for the custom provider",
ErrNameNotValid, len(settings.Names))
return fmt.Errorf("name is not valid: "+
"%d names specified instead of 0 or 1 for the custom provider",
len(settings.Names))
}
}
err = atLeastOneIsOneOfCaseInsensitive(settings.Names, filterChoices.Names, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrNameNotValid, err)
return fmt.Errorf("the server name specified is not valid: %w", err)
}
err = atLeastOneIsOneOfCaseInsensitive(settings.Categories, filterChoices.Categories, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrCategoryNotValid, err)
return fmt.Errorf("the category specified is not valid: %w", err)
}
return nil
@@ -255,12 +243,12 @@ func validateSubscriptionTierFilters(settings ServerSelection, vpnServiceProvide
switch {
case *settings.FreeOnly &&
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
return fmt.Errorf("%w", ErrFreeOnlyNotSupported)
return errors.New("free only filter is not supported")
case *settings.PremiumOnly &&
!helpers.IsOneOf(vpnServiceProvider, providers.VPNSecure):
return fmt.Errorf("%w", ErrPremiumOnlyNotSupported)
return errors.New("premium only filter is not supported")
case *settings.FreeOnly && *settings.PremiumOnly:
return fmt.Errorf("%w", ErrFreePremiumBothSet)
return errors.New("free only and premium only filters are both set")
default:
return nil
}
@@ -269,21 +257,21 @@ func validateSubscriptionTierFilters(settings ServerSelection, vpnServiceProvide
func validateFeatureFilters(settings ServerSelection, vpnServiceProvider string) error {
switch {
case *settings.OwnedOnly && vpnServiceProvider != providers.Mullvad:
return fmt.Errorf("%w", ErrOwnedOnlyNotSupported)
return errors.New("owned only filter is not supported")
case vpnServiceProvider == providers.Protonvpn && *settings.FreeOnly && *settings.PortForwardOnly:
return fmt.Errorf("%w: together with free only filter", ErrPortForwardOnlyNotSupported)
return errors.New("port forwarding only filter is not supported: together with free only filter")
case *settings.StreamOnly &&
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
return fmt.Errorf("%w", ErrStreamOnlyNotSupported)
return errors.New("stream only filter is not supported")
case *settings.MultiHopOnly && vpnServiceProvider != providers.Surfshark:
return fmt.Errorf("%w", ErrMultiHopOnlyNotSupported)
return errors.New("multi hop only filter is not supported")
case *settings.PortForwardOnly &&
!helpers.IsOneOf(vpnServiceProvider, providers.PrivateInternetAccess, providers.Protonvpn):
return fmt.Errorf("%w", ErrPortForwardOnlyNotSupported)
return errors.New("port forwarding only filter is not supported")
case *settings.SecureCoreOnly && vpnServiceProvider != providers.Protonvpn:
return fmt.Errorf("%w", ErrSecureCoreOnlyNotSupported)
return errors.New("secure core only filter is not supported")
case *settings.TorOnly && vpnServiceProvider != providers.Protonvpn:
return fmt.Errorf("%w", ErrTorOnlyNotSupported)
return errors.New("tor only filter is not supported")
default:
return nil
}
+8 -7
View File
@@ -1,6 +1,7 @@
package settings
import (
"errors"
"fmt"
"slices"
"strings"
@@ -37,20 +38,20 @@ type Updater struct {
func (u Updater) Validate() (err error) {
const minPeriod = time.Minute
if *u.Period > 0 && *u.Period < minPeriod {
return fmt.Errorf("%w: %d must be larger than %s",
ErrUpdaterPeriodTooSmall, *u.Period, minPeriod)
return fmt.Errorf("VPN server data updater period is too small: "+
"%d must be larger than %s", *u.Period, minPeriod)
}
if u.MinRatio <= 0 || u.MinRatio > 1 {
return fmt.Errorf("%w: %.2f must be between 0+ and 1",
ErrMinRatioNotValid, u.MinRatio)
return fmt.Errorf("minimum ratio is not valid: "+
"%.2f must be between 0+ and 1", u.MinRatio)
}
validProviders := providers.All()
for _, provider := range u.Providers {
err = validate.IsOneOf(provider, validProviders...)
if err != nil {
return fmt.Errorf("%w: %w", ErrVPNProviderNameNotValid, err)
return fmt.Errorf("VPN provider name is not valid: %w", err)
}
if provider == providers.Protonvpn {
@@ -58,9 +59,9 @@ func (u Updater) Validate() (err error) {
if authenticatedAPI {
switch {
case *u.ProtonEmail == "":
return fmt.Errorf("%w", ErrUpdaterProtonEmailMissing)
return errors.New("proton email is missing")
case *u.ProtonPassword == "":
return fmt.Errorf("%w", ErrUpdaterProtonPasswordMissing)
return errors.New("proton password is missing")
}
}
}
+1 -1
View File
@@ -37,7 +37,7 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo
// Validate Type
validVPNTypes := []string{vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard}
if err = validate.IsOneOf(v.Type, validVPNTypes...); err != nil {
return fmt.Errorf("%w: %w", ErrVPNTypeNotValid, err)
return fmt.Errorf("VPN type is not valid: %w", err)
}
err = v.Provider.validate(v.Type, filterChoicesGetter, warner)
+11 -15
View File
@@ -1,6 +1,7 @@
package settings
import (
"errors"
"fmt"
"net/netip"
"regexp"
@@ -54,7 +55,7 @@ var regexpInterfaceName = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (err error) {
// Validate PrivateKey
if *w.PrivateKey == "" {
return fmt.Errorf("%w", ErrWireguardPrivateKeyNotSet)
return errors.New("private key is not set")
}
_, err = wgtypes.ParseKey(*w.PrivateKey)
if err != nil {
@@ -68,7 +69,7 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (
if vpnProvider == providers.Airvpn {
if *w.PreSharedKey == "" {
return fmt.Errorf("%w", ErrWireguardPreSharedKeyNotSet)
return errors.New("pre-shared key is not set")
}
}
@@ -82,17 +83,15 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (
// Validate Addresses
if len(w.Addresses) == 0 {
return fmt.Errorf("%w", ErrWireguardInterfaceAddressNotSet)
return errors.New("interface address is not set")
}
for i, ipNet := range w.Addresses {
if !ipNet.IsValid() {
return fmt.Errorf("%w: for address at index %d",
ErrWireguardInterfaceAddressNotSet, i)
return fmt.Errorf("interface address is not set: for address at index %d", i)
}
if !ipv6Supported && ipNet.Addr().Is6() {
return fmt.Errorf("%w: address %s",
ErrWireguardInterfaceAddressIPv6, ipNet.String())
return fmt.Errorf("interface address is IPv6 but IPv6 is not supported: address %s", ipNet.String())
}
}
@@ -100,30 +99,27 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (
// WARNING: do not check for IPv6 networks in the allowed IPs,
// the wireguard code will take care to ignore it.
if len(w.AllowedIPs) == 0 {
return fmt.Errorf("%w", ErrWireguardAllowedIPsNotSet)
return errors.New("allowed IPs is not set")
}
for i, allowedIP := range w.AllowedIPs {
if !allowedIP.IsValid() {
return fmt.Errorf("%w: for allowed ip %d of %d",
ErrWireguardAllowedIPNotSet, i+1, len(w.AllowedIPs))
return fmt.Errorf("allowed IP is not set: for allowed ip %d of %d", i+1, len(w.AllowedIPs))
}
}
if *w.PersistentKeepaliveInterval < 0 {
return fmt.Errorf("%w: %s", ErrWireguardKeepAliveNegative,
*w.PersistentKeepaliveInterval)
return fmt.Errorf("persistent keep alive interval is negative: %s", *w.PersistentKeepaliveInterval)
}
// Validate interface
if !regexpInterfaceName.MatchString(w.Interface) {
return fmt.Errorf("%w: '%s' does not match regex '%s'",
ErrWireguardInterfaceNotValid, w.Interface, regexpInterfaceName)
return fmt.Errorf("interface name is not valid: '%s' does not match regex '%s'", w.Interface, regexpInterfaceName)
}
if !amneziawg { // amneziawg should have its own Implementation field and ignore this one
validImplementations := []string{"auto", "userspace", "kernelspace"}
if err := validate.IsOneOf(w.Implementation, validImplementations...); err != nil {
return fmt.Errorf("%w: %w", ErrWireguardImplementationNotValid, err)
return fmt.Errorf("implementation is not valid: %w", err)
}
}
@@ -1,6 +1,7 @@
package settings
import (
"errors"
"fmt"
"net/netip"
@@ -44,7 +45,7 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
// endpoint IP addresses are baked in
case providers.Custom:
if !w.EndpointIP.IsValid() || w.EndpointIP.IsUnspecified() {
return fmt.Errorf("%w", ErrWireguardEndpointIPNotSet)
return errors.New("endpoint IP is not set")
}
default: // Providers not supporting Wireguard
}
@@ -54,13 +55,13 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
// EndpointPort is required
case providers.Custom:
if *w.EndpointPort == 0 {
return fmt.Errorf("%w", ErrWireguardEndpointPortNotSet)
return errors.New("endpoint port is not set")
}
// EndpointPort cannot be set
case providers.Fastestvpn, providers.Nordvpn,
providers.Protonvpn, providers.Surfshark:
if *w.EndpointPort != 0 {
return fmt.Errorf("%w", ErrWireguardEndpointPortSet)
return errors.New("endpoint port is set")
}
case providers.Airvpn, providers.Ivpn, providers.Mullvad, providers.Windscribe:
// EndpointPort is optional and can be 0
@@ -84,8 +85,7 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
if err == nil {
break
}
return fmt.Errorf("%w: for VPN service provider %s: %w",
ErrWireguardEndpointPortNotAllowed, vpnProvider, err)
return fmt.Errorf("endpoint port is not allowed: for VPN service provider %s: %w", vpnProvider, err)
default: // Providers not supporting Wireguard
}
@@ -96,15 +96,14 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
// public keys are baked in
case providers.Custom:
if w.PublicKey == "" {
return fmt.Errorf("%w", ErrWireguardPublicKeyNotSet)
return errors.New("public key is not set")
}
default: // Providers not supporting Wireguard
}
if w.PublicKey != "" {
_, err := wgtypes.ParseKey(w.PublicKey)
if err != nil {
return fmt.Errorf("%w: %s: %s",
ErrWireguardPublicKeyNotValid, w.PublicKey, err)
return fmt.Errorf("public key is not valid: %s: %s", w.PublicKey, err)
}
}
@@ -74,8 +74,6 @@ func parseWireguardInterfaceSection(interfaceSection *ini.Section) (
return privateKey, addresses
}
var ErrEndpointHostNotIP = errors.New("endpoint host is not an IP")
func parseWireguardPeerSection(peerSection *ini.Section) (
preSharedKey, publicKey, endpointIP, endpointPort *string,
) {
+1 -4
View File
@@ -3,7 +3,6 @@ package dns
import (
"context"
"encoding/json"
"errors"
"fmt"
"math"
"math/rand/v2"
@@ -63,8 +62,6 @@ func generateRandomString(length uint) string {
return string(b)
}
var errIPLeakSessionMismatch = errors.New("ipleak.net session mismatch")
func triggerDNSQuery(ctx context.Context, client *http.Client, session string) (
dnsToCount map[string]uint, err error,
) {
@@ -93,7 +90,7 @@ func triggerDNSQuery(ctx context.Context, client *http.Client, session string) (
if err != nil {
return nil, fmt.Errorf("decoding response: %w", err)
} else if data.Session != session {
return nil, fmt.Errorf("%w: expected %s, got %s", errIPLeakSessionMismatch, session, data.Session)
return nil, fmt.Errorf("ipleak.net session mismatch: expected %s, got %s", session, data.Session)
}
return data.IP, nil
+5 -9
View File
@@ -57,18 +57,15 @@ func Test_deleteIPTablesRule(t *testing.T) {
t.Parallel()
const iptablesBinary = "/sbin/iptables"
errTest := errors.New("test error")
testCases := map[string]struct {
instruction string
makeRunner func(ctrl *gomock.Controller) *MockCmdRunner
makeLogger func(ctrl *gomock.Controller) *MockLogger
errWrapped error
errMessage string
}{
"invalid_instruction": {
instruction: "invalid",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing iptables command: parsing \"invalid\": " +
"iptables command is malformed: flag \"invalid\" requires a value, but got none",
},
@@ -78,7 +75,7 @@ func Test_deleteIPTablesRule(t *testing.T) {
runner := NewMockCmdRunner(ctrl)
runner.EXPECT().
Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return("", errTest)
Return("", errors.New("test error"))
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
@@ -86,7 +83,6 @@ func Test_deleteIPTablesRule(t *testing.T) {
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
return logger
},
errWrapped: errTest,
errMessage: `finding iptables chain rule line number: command failed: ` +
`"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`,
},
@@ -120,7 +116,7 @@ func Test_deleteIPTablesRule(t *testing.T) {
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
nil)
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
"^-D$", "^PREROUTING$", "^2$")).Return("details", errTest)
"^-D$", "^PREROUTING$", "^2$")).Return("details", errors.New("test error"))
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
@@ -131,7 +127,6 @@ func Test_deleteIPTablesRule(t *testing.T) {
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
return logger
},
errWrapped: errTest,
errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details",
},
"rule_found_delete_success": {
@@ -177,9 +172,10 @@ func Test_deleteIPTablesRule(t *testing.T) {
err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
+1 -3
View File
@@ -82,13 +82,11 @@ func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction
return nil
}
var ErrPolicyNotValid = errors.New("policy is not valid")
func (c *Config) SetIPv6AllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP":
default:
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
return fmt.Errorf("policy is not valid: %s", policy)
}
return c.runIP6tablesInstructions(ctx, []string{
"--policy INPUT " + policy,
+9 -12
View File
@@ -2,7 +2,6 @@ package iptables
import (
"context"
"errors"
"fmt"
"io"
"net/netip"
@@ -13,10 +12,8 @@ import (
"github.com/qdm12/gluetun/internal/models"
)
var (
ErrIPTablesVersionTooShort = errors.New("iptables version string is too short")
ErrPolicyUnknown = errors.New("unknown policy")
ErrNeedIP6Tables = errors.New("ip6tables is required, please upgrade your kernel to support it")
const (
needIP6Tables = "ip6tables is required, please upgrade your kernel"
)
func appendOrDelete(remove bool) string {
@@ -36,7 +33,7 @@ func (c *Config) Version(ctx context.Context) (string, error) {
words := strings.Fields(output)
const minWords = 2
if len(words) < minWords {
return "", fmt.Errorf("%w: %s", ErrIPTablesVersionTooShort, output)
return "", fmt.Errorf("iptables version string is too short: %s", output)
}
return "iptables " + words[1], nil
}
@@ -102,7 +99,7 @@ func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP":
default:
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
return fmt.Errorf("unknown policy: %s", policy)
}
return c.runIptablesInstructions(ctx, []string{
"--policy INPUT " + policy,
@@ -129,7 +126,7 @@ func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, destinati
return c.runIptablesInstruction(ctx, instruction)
}
if c.ip6Tables == "" {
return fmt.Errorf("accept input to subnet %s: %w", destination, ErrNeedIP6Tables)
return fmt.Errorf("accept input to subnet %s: %s", destination, needIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
@@ -157,7 +154,7 @@ func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context,
if connection.IP.Is4() {
return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" {
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
return fmt.Errorf("accept output to VPN server %s: %s", connection.IP, needIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
@@ -175,7 +172,7 @@ func (c *Config) AcceptOutput(ctx context.Context,
if ip.Is4() {
return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" {
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
return fmt.Errorf("accept output to VPN server %s: %s", ip, needIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
@@ -200,7 +197,7 @@ func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
if doIPv4 {
return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" {
return fmt.Errorf("accept output from %s to %s: %w", sourceIP, destinationSubnet, ErrNeedIP6Tables)
return fmt.Errorf("accept output from %s to %s: %s", sourceIP, destinationSubnet, needIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
@@ -350,7 +347,7 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
case ipv4:
err = c.runIptablesInstructionNoSave(ctx, rule)
case c.ip6Tables == "":
err = fmt.Errorf("running user ip6tables rule: %w", ErrNeedIP6Tables)
err = fmt.Errorf("running user ip6tables rule: %s", needIP6Tables)
default: // ipv6
err = c.runIP6tablesInstructionNoSave(ctx, rule)
}
+27 -47
View File
@@ -40,8 +40,6 @@ type mark struct {
value uint
}
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
func parseChain(iptablesOutput string) (c chain, err error) {
// Text example:
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
@@ -63,8 +61,8 @@ func parseChain(iptablesOutput string) (c chain, err error) {
const minLines = 2 // chain general information line + legend line
if len(lines) < minLines {
return chain{}, fmt.Errorf("%w: not enough lines to process in: %s",
ErrChainListMalformed, iptablesOutput)
return chain{}, fmt.Errorf("iptables chain list output is malformed: not enough lines to process in: %s",
iptablesOutput)
}
c, err = parseChainGeneralDataLine(lines[0])
@@ -77,8 +75,8 @@ func parseChain(iptablesOutput string) (c chain, err error) {
legendLine := strings.TrimSpace(lines[1])
legendFields := strings.Fields(legendLine)
if !slices.Equal(expectedLegendFields, legendFields) {
return chain{}, fmt.Errorf("%w: legend %q is not the expected %q",
ErrChainListMalformed, legendLine, strings.Join(expectedLegendFields, " "))
return chain{}, fmt.Errorf("iptables chain list output is malformed: legend %q is not the expected %q",
legendLine, strings.Join(expectedLegendFields, " "))
}
lines = lines[2:] // remove chain general information line and legend line
@@ -111,8 +109,8 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
fields := strings.Fields(line)
const expectedNumberOfFields = 8
if len(fields) != expectedNumberOfFields {
return chain{}, fmt.Errorf("%w: expected %d fields in %q",
ErrChainListMalformed, expectedNumberOfFields, line)
return chain{}, fmt.Errorf("iptables chain list output is malformed: expected %d fields in %q",
expectedNumberOfFields, line)
}
// Sanity checks
@@ -126,8 +124,8 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
if fields[index] == expectedValue {
continue
}
return chain{}, fmt.Errorf("%w: expected %q for field %d in %q",
ErrChainListMalformed, expectedValue, index, line)
return chain{}, fmt.Errorf("iptables chain list output is malformed: expected %q for field %d in %q",
expectedValue, index, line)
}
base.name = fields[1] // chain name could be custom
@@ -152,19 +150,17 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
return base, nil
}
var ErrChainRuleMalformed = errors.New("chain rule is malformed")
func parseChainRuleLine(line string) (rule chainRule, err error) {
line = strings.TrimSpace(line)
if line == "" {
return chainRule{}, fmt.Errorf("%w: empty line", ErrChainRuleMalformed)
return chainRule{}, errors.New("chain rule is malformed: empty line")
}
fields := strings.Fields(line)
const minFields = 10
if len(fields) < minFields {
return chainRule{}, fmt.Errorf("%w: not enough fields", ErrChainRuleMalformed)
return chainRule{}, errors.New("chain rule is malformed: not enough fields")
}
for fieldIndex, field := range fields[:minFields] {
@@ -186,7 +182,7 @@ func parseChainRuleLine(line string) (rule chainRule, err error) {
func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) {
if field == "" {
return fmt.Errorf("%w: empty field at index %d", ErrChainRuleMalformed, fieldIndex)
return fmt.Errorf("chain rule is malformed: empty field at index %d", fieldIndex)
}
const (
@@ -278,8 +274,8 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
rule.redirPorts = ports
i++
default:
return fmt.Errorf("%w: unexpected %q after redir",
ErrChainRuleMalformed, optionalFields[1])
return fmt.Errorf("chain rule is malformed: unexpected %q after redir",
optionalFields[1])
}
case "ctstate":
i++
@@ -294,15 +290,13 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
rule.mark = mark
i += consumed
default:
return fmt.Errorf("%w: unexpected optional field: %s",
ErrChainRuleMalformed, optionalFields[i])
return fmt.Errorf("chain rule is malformed: unexpected optional field: %s",
optionalFields[i])
}
}
return nil
}
var errUDPOptionalUnknown = errors.New("unknown UDP optional field")
func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
for _, value := range optionalFields {
if !strings.ContainsRune(value, ':') {
@@ -323,14 +317,12 @@ func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, e
}
consumed++
default:
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
return 0, fmt.Errorf("unknown UDP optional field: %s", value)
}
}
return consumed, nil
}
var errTCPOptionalUnknown = errors.New("unknown TCP optional field")
func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
for _, value := range optionalFields {
if !strings.ContainsRune(value, ':') {
@@ -357,7 +349,7 @@ func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, e
}
consumed++
default:
return 0, fmt.Errorf("%w: %s", errTCPOptionalUnknown, value)
return 0, fmt.Errorf("unknown TCP optional field: %s", value)
}
}
return consumed, nil
@@ -373,15 +365,13 @@ func parseSourcePort(value string) (port uint16, err error) {
return parsePort(value)
}
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
func parseTCPFlags(value string) (tcpFlags, error) {
value = strings.TrimPrefix(value, "flags:")
fields := strings.Split(value, "/")
const expectedFields = 2
if len(fields) != expectedFields {
return tcpFlags{}, fmt.Errorf("%w: expected format 'flags:<mask>/<comparison>' in %q",
errTCPFlagsMalformed, value)
return tcpFlags{}, fmt.Errorf("TCP flags are malformed: expected format 'flags:<mask>/<comparison>' in %q",
value)
}
maskFlags := strings.Split(fields[0], ",")
mask := make([]tcpFlag, len(maskFlags))
@@ -422,8 +412,6 @@ func parsePortsCSV(s string) (ports []uint16, err error) {
return ports, nil
}
var errMarkValueMalformed = errors.New("mark value is malformed")
func parseMark(optionalFields []string) (m mark, consumed int, err error) {
switch optionalFields[consumed] {
case "match":
@@ -437,42 +425,36 @@ func parseMark(optionalFields []string) (m mark, consumed int, err error) {
const bits = 32
value, err := strconv.ParseUint(optionalFields[consumed], base, bits)
if err != nil {
return mark{}, 0, fmt.Errorf("%w: %s", errMarkValueMalformed, optionalFields[consumed])
return mark{}, 0, fmt.Errorf("mark value is malformed: %s", optionalFields[consumed])
}
m.value = uint(value)
consumed++
default:
return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s",
ErrChainRuleMalformed, optionalFields[consumed])
return mark{}, 0, fmt.Errorf("chain rule is malformed: unexpected mark mode field: %s",
optionalFields[consumed])
}
return m, consumed, nil
}
var ErrLineNumberIsZero = errors.New("line number is zero")
func parseLineNumber(s string) (n uint16, err error) {
const base, bitLength = 10, 16
lineNumber, err := strconv.ParseUint(s, base, bitLength)
if err != nil {
return 0, err
} else if lineNumber == 0 {
return 0, fmt.Errorf("%w", ErrLineNumberIsZero)
return 0, errors.New("line number is zero")
}
return uint16(lineNumber), nil
}
var ErrTargetUnknown = errors.New("unknown target")
func checkTarget(target string) (err error) {
switch target {
case "ACCEPT", "DROP", "REJECT", "REDIRECT":
return nil
}
return fmt.Errorf("%w: %s", ErrTargetUnknown, target)
return fmt.Errorf("unknown target: %s", target)
}
var ErrProtocolUnknown = errors.New("unknown protocol")
func parseProtocol(s string) (protocol string, err error) {
switch s {
case "0", "all":
@@ -483,18 +465,16 @@ func parseProtocol(s string) (protocol string, err error) {
case "17", "udp":
protocol = "udp"
default:
return "", fmt.Errorf("%w: %s", ErrProtocolUnknown, s)
return "", fmt.Errorf("unknown protocol: %s", s)
}
return protocol, nil
}
var ErrMetricSizeMalformed = errors.New("metric size is malformed")
// parseMetricSize parses a metric size string like 140K or 226M and
// returns the raw integer matching it.
func parseMetricSize(size string) (n uint64, err error) {
if size == "" {
return n, fmt.Errorf("%w: empty string", ErrMetricSizeMalformed)
return n, errors.New("metric size is malformed: empty string")
}
//nolint:mnd
@@ -516,7 +496,7 @@ func parseMetricSize(size string) (n uint64, err error) {
const base, bitLength = 10, 64
n, err = strconv.ParseUint(size, base, bitLength)
if err != nil {
return n, fmt.Errorf("%w: %w", ErrMetricSizeMalformed, err)
return n, fmt.Errorf("metric size is malformed: %w", err)
}
n *= multiplier
return n, nil
+3 -7
View File
@@ -13,30 +13,25 @@ func Test_parseChain(t *testing.T) {
testCases := map[string]struct {
iptablesOutput string
table chain
errWrapped error
errMessage string
}{
"no_output": {
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: not enough lines to process in: ",
},
"single_line_only": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)`,
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: not enough lines to process in: " +
"Chain INPUT (policy ACCEPT 140K packets, 226M bytes)",
},
"malformed_general_data_line": {
iptablesOutput: `Chain INPUT
num pkts bytes target prot opt in out source destination`,
errWrapped: ErrChainListMalformed,
errMessage: "parsing chain general data line: iptables chain list output is malformed: " +
"expected 8 fields in \"Chain INPUT\"",
},
"malformed_legend": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
num pkts bytes target prot opt in out source`,
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: legend " +
"\"num pkts bytes target prot opt in out source\" " +
"is not the expected \"num pkts bytes target prot opt in out source destination\"",
@@ -135,9 +130,10 @@ num pkts bytes target prot opt in out source destinati
table, err := parseChain(testCase.iptablesOutput)
assert.Equal(t, testCase.table, table)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
+10 -12
View File
@@ -80,11 +80,9 @@ func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
}
var ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
if s == "" {
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
return iptablesInstruction{}, errors.New("iptables command is malformed: empty instruction")
}
fields := strings.Fields(s)
@@ -173,7 +171,7 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
return 0, fmt.Errorf("parsing TCP flags: %w", err)
}
default:
return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag)
return 0, fmt.Errorf("iptables command is malformed: unknown key %q", flag)
}
return consumed, nil
}
@@ -185,15 +183,15 @@ func preCheckInstructionFields(fields []string) (consumed int, err error) {
case "--tcp-flags": // -m can have 1 or 2 values
const expected = 3
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
return 0, fmt.Errorf("iptables command is malformed: flag %q requires at least 2 values, but got %s",
flag, strings.Join(fields, " "))
}
return expected, nil
default:
const expected = 2
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
ErrIptablesCommandMalformed, flag)
return 0, fmt.Errorf("iptables command is malformed: flag %q requires a value, but got none",
flag)
}
return expected, nil
}
@@ -239,12 +237,12 @@ func parseMatchModule(fields []string, instruction *iptablesInstruction) (
consumed++
instruction.mark.invert = true
default:
return consumed, fmt.Errorf("%w: unsupported match mark with value: %s",
ErrIptablesCommandMalformed, fields[2])
return consumed, fmt.Errorf("iptables command is malformed: unsupported match mark with value: %s",
fields[2])
}
default:
return 0, fmt.Errorf("%w: unknown match value: %s",
ErrIptablesCommandMalformed, fields[consumed])
return 0, fmt.Errorf("iptables command is malformed: unknown match value: %s",
fields[consumed])
}
return consumed, nil
}
+3 -6
View File
@@ -13,21 +13,17 @@ func Test_parseIptablesInstruction(t *testing.T) {
testCases := map[string]struct {
s string
instruction iptablesInstruction
errWrapped error
errMessage string
}{
"no_instruction": {
errWrapped: ErrIptablesCommandMalformed,
errMessage: "iptables command is malformed: empty instruction",
},
"uneven_fields": {
s: "-A",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
},
"unknown_key": {
s: "-x something",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
},
"one_pair": {
@@ -74,9 +70,10 @@ func Test_parseIptablesInstruction(t *testing.T) {
rule, err := parseIptablesInstruction(testCase.s)
assert.Equal(t, testCase.instruction, rule)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
+4 -9
View File
@@ -10,12 +10,7 @@ import (
"strings"
)
var (
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
ErrInputPolicyNotFound = errors.New("input policy not found")
ErrNotSupported = errors.New("no iptables supported found")
)
var ErrNotSupported = errors.New("no iptables supported found")
func checkIptablesSupport(ctx context.Context, runner CmdRunner,
iptablesPathsToTry ...string,
@@ -53,7 +48,7 @@ func checkIptablesSupport(ctx context.Context, runner CmdRunner,
if allArePermissionDenied {
// If the error is related to a denied permission for all iptables path,
// return an error describing what to do from an end-user perspective.
return "", fmt.Errorf("%w: %s", ErrNetAdminMissing, strings.Join(allUnsupportedMessages, "; "))
return "", fmt.Errorf("NET_ADMIN capability is missing: %s", strings.Join(allUnsupportedMessages, "; "))
}
return "", fmt.Errorf("%w: errors encountered are: %s",
@@ -85,7 +80,7 @@ func testIptablesPath(ctx context.Context, path string,
output, err = runner.Run(cmd)
if err != nil {
// this is a critical error, we want to make sure our test rule gets removed.
criticalErr = fmt.Errorf("%w: %s (%s)", ErrTestRuleCleanup, output, err)
criticalErr = fmt.Errorf("failed cleaning up test rule: %s (%s)", output, err)
return false, "", criticalErr
}
@@ -108,7 +103,7 @@ func testIptablesPath(ctx context.Context, path string,
}
if inputPolicy == "" {
criticalErr = fmt.Errorf("%w: in INPUT rules: %s", ErrInputPolicyNotFound, output)
criticalErr = fmt.Errorf("input policy not found: in INPUT rules: %s", output)
return false, "", criticalErr
}
+6 -12
View File
@@ -7,7 +7,6 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newAppendTestRuleMatcher(path string) *cmdMatcher {
@@ -43,7 +42,6 @@ func Test_checkIptablesSupport(t *testing.T) {
buildRunner func(ctrl *gomock.Controller) CmdRunner
iptablesPathsToTry []string
iptablesPath string
errSentinel error
errMessage string
}{
"critical error when checking": {
@@ -56,7 +54,6 @@ func Test_checkIptablesSupport(t *testing.T) {
return runner
},
iptablesPathsToTry: []string{"path1", "path2"},
errSentinel: ErrTestRuleCleanup,
errMessage: "for path1: failed cleaning up test rule: " +
"output (exit code 4)",
},
@@ -86,7 +83,6 @@ func Test_checkIptablesSupport(t *testing.T) {
return runner
},
iptablesPathsToTry: []string{"path1", "path2"},
errSentinel: ErrNetAdminMissing,
errMessage: "NET_ADMIN capability is missing: " +
"path1: Permission denied (you must be root) more context (exit code 4); " +
"path2: context: Permission denied (you must be root) (exit code 4)",
@@ -101,7 +97,6 @@ func Test_checkIptablesSupport(t *testing.T) {
return runner
},
iptablesPathsToTry: []string{"path1", "path2"},
errSentinel: ErrNotSupported,
errMessage: "no iptables supported found: " +
"errors encountered are: " +
"path1: output 1 (exit code 4); " +
@@ -118,9 +113,10 @@ func Test_checkIptablesSupport(t *testing.T) {
iptablesPath, err := checkIptablesSupport(ctx, runner, testCase.iptablesPathsToTry...)
require.ErrorIs(t, err, testCase.errSentinel)
if testCase.errSentinel != nil {
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.iptablesPath, iptablesPath)
})
@@ -139,7 +135,6 @@ func Test_testIptablesPath(t *testing.T) {
buildRunner func(ctrl *gomock.Controller) CmdRunner
ok bool
unsupportedMessage string
criticalErrWrapped error
criticalErrMessage string
}{
"append test rule permission denied": {
@@ -168,7 +163,6 @@ func Test_testIptablesPath(t *testing.T) {
Return("some output", errDummy)
return runner
},
criticalErrWrapped: ErrTestRuleCleanup,
criticalErrMessage: "failed cleaning up test rule: some output (exit code 4)",
},
"list input rules permission denied": {
@@ -202,7 +196,6 @@ func Test_testIptablesPath(t *testing.T) {
Return("some\noutput", nil)
return runner
},
criticalErrWrapped: ErrInputPolicyNotFound,
criticalErrMessage: "input policy not found: in INPUT rules: some\noutput",
},
"set policy permission denied": {
@@ -257,9 +250,10 @@ func Test_testIptablesPath(t *testing.T) {
assert.Equal(t, testCase.ok, ok)
assert.Equal(t, testCase.unsupportedMessage, unsupportedMessage)
assert.ErrorIs(t, criticalErr, testCase.criticalErrWrapped)
if testCase.criticalErrWrapped != nil {
if testCase.criticalErrMessage != "" {
assert.EqualError(t, criticalErr, testCase.criticalErrMessage)
} else {
assert.NoError(t, criticalErr)
}
})
}
+2 -4
View File
@@ -45,12 +45,10 @@ func (f tcpFlag) String() string {
case tcpFlagCWR:
return "CWR"
default:
panic(fmt.Sprintf("%s: %d", errTCPFlagUnknown, f))
panic(fmt.Sprintf("unknown TCP flag: %d", f))
}
}
var errTCPFlagUnknown = errors.New("unknown TCP flag")
func parseTCPFlag(s string) (tcpFlag, error) {
allFlags := []tcpFlag{
tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH,
@@ -61,7 +59,7 @@ func parseTCPFlag(s string) (tcpFlag, error) {
return flag, nil
}
}
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
return 0, fmt.Errorf("unknown TCP flag: %s", s)
}
var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so")
+2 -4
View File
@@ -266,8 +266,6 @@ func makeAddressToDial(address string) (addressToDial string, err error) {
return address, nil
}
var ErrAllCheckTriesFailed = errors.New("all check tries failed")
func withRetries(ctx context.Context, tryTimeouts []time.Duration,
logger Logger, checkName string, check func(ctx context.Context, try int) error,
) error {
@@ -297,7 +295,7 @@ func withRetries(ctx context.Context, tryTimeouts []time.Duration,
for i, err := range errs {
errStrings[i] = fmt.Sprintf("attempt %d (%dms): %s", i+1, err.durationMS, err.err)
}
return fmt.Errorf("%w:\n\t%s", ErrAllCheckTriesFailed, strings.Join(errStrings, "\n\t"))
return fmt.Errorf("all check tries failed:\n\t%s", strings.Join(errStrings, "\n\t"))
}
func (c *Checker) startupCheck(ctx context.Context) error {
@@ -342,7 +340,7 @@ func (c *Checker) startupCheck(ctx context.Context) error {
for i, err := range errs {
errStrings[i] = fmt.Sprintf("parallel attempt %d/%d failed: %s", i+1, len(errs), err)
}
return fmt.Errorf("%w: %s", ErrAllCheckTriesFailed, strings.Join(errStrings, ", "))
return fmt.Errorf("all check tries failed: %s", strings.Join(errStrings, ", "))
}
const (
+5 -6
View File
@@ -2,7 +2,6 @@ package healthcheck
import (
"context"
"fmt"
"net"
"testing"
"time"
@@ -68,7 +67,7 @@ func Test_makeAddressToDial(t *testing.T) {
testCases := map[string]struct {
address string
addressToDial string
err error
errMessage string
}{
"host without port": {
address: "test.com",
@@ -79,8 +78,8 @@ func Test_makeAddressToDial(t *testing.T) {
addressToDial: "test.com:80",
},
"bad address": {
address: "test.com::",
err: fmt.Errorf("splitting host and port from address: address test.com::: too many colons in address"), //nolint:lll
address: "test.com::",
errMessage: "splitting host and port from address: address test.com::: too many colons in address",
},
}
@@ -91,8 +90,8 @@ func Test_makeAddressToDial(t *testing.T) {
addressToDial, err := makeAddressToDial(testCase.address)
assert.Equal(t, testCase.addressToDial, addressToDial)
if testCase.err != nil {
assert.EqualError(t, err, testCase.err.Error())
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
+1 -4
View File
@@ -2,15 +2,12 @@ package healthcheck
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"time"
)
var ErrHTTPStatusNotOK = errors.New("HTTP response status is not OK")
type Client struct {
httpClient *http.Client
}
@@ -41,6 +38,6 @@ func (c *Client) Check(ctx context.Context, url string) error {
if err != nil {
return err
}
return fmt.Errorf("%w: %d %s: %s", ErrHTTPStatusNotOK,
return fmt.Errorf("HTTP response status is not OK: %d %s: %s",
response.StatusCode, response.Status, string(b))
}
+1 -4
View File
@@ -2,7 +2,6 @@ package dns
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
@@ -41,8 +40,6 @@ func concatAddrPorts(addrs [][]netip.AddrPort) []netip.AddrPort {
return result
}
var ErrLookupNoIPs = errors.New("no IPs found from DNS lookup")
func (c *Client) Check(ctx context.Context) error {
dnsAddr := c.serverAddrs[c.dnsIPIndex].String()
resolver := &net.Resolver{
@@ -59,7 +56,7 @@ func (c *Client) Check(ctx context.Context) error {
return fmt.Errorf("with DNS server %s: %w", dnsAddr, err)
case len(ips) == 0:
c.dnsIPIndex = (c.dnsIPIndex + 1) % len(c.serverAddrs)
return fmt.Errorf("with DNS server %s: %w", dnsAddr, ErrLookupNoIPs)
return fmt.Errorf("with DNS server %s: no IPs found from DNS lookup", dnsAddr)
default:
return nil
}
+1 -3
View File
@@ -12,11 +12,9 @@ type handler struct {
logger Logger
}
var errHealthcheckNotRunYet = errors.New("healthcheck did not run yet")
func newHandler(logger Logger) *handler {
return &handler{
healthErr: errHealthcheckNotRunYet,
healthErr: errors.New("healthcheck did not run yet"),
logger: logger,
}
}
+6 -13
View File
@@ -19,11 +19,6 @@ import (
"golang.org/x/net/ipv6"
)
var (
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")
)
type Echoer struct {
buffer []byte
randomSource io.Reader
@@ -60,10 +55,7 @@ func (e *Echoer) Reset() {
e.seqStart = time.Now()
}
var (
ErrTimedOut = errors.New("timed out waiting for ICMP echo reply")
ErrNotPermitted = errors.New("not permitted")
)
var ErrNotPermitted = errors.New("not permitted")
func (e *Echoer) Echo(ctx context.Context, ip netip.Addr) (err error) {
var ipVersion string
@@ -114,14 +106,14 @@ func (e *Echoer) Echo(ctx context.Context, ip netip.Addr) (err error) {
receivedData, err := receiveEchoReply(conn, e.id, e.seq, e.buffer, ipVersion, e.logger)
if err != nil {
if errors.Is(err, net.ErrClosed) && ctx.Err() != nil {
return fmt.Errorf("%w from %s", ErrTimedOut, ip)
return fmt.Errorf("timed out waiting for ICMP echo reply from %s", ip)
}
return fmt.Errorf("receiving ICMP echo reply from %s: %w", ip, err)
}
sentData := message.Body.(*icmp.Echo).Data //nolint:forcetypeassert
if !bytes.Equal(receivedData, sentData) {
return fmt.Errorf("%w: sent %x to %s and received %x", ErrICMPEchoDataMismatch, sentData, ip, receivedData)
return fmt.Errorf("ICMP data mismatch: sent %x to %s and received %x", sentData, ip, receivedData)
}
return nil
@@ -216,8 +208,9 @@ func receiveEchoReply(conn net.PacketConn, id, seq int, buffer []byte, ipVersion
message.Code, returnAddr, id, seq)
continue
default:
return nil, fmt.Errorf("%w: %T (type %d, code %d, return address %s, expected id %d and seq %d)",
ErrICMPBodyUnsupported, body, message.Type, message.Code, returnAddr, id, seq)
return nil, fmt.Errorf("ICMP body type is not supported: "+
"%T (type %d, code %d, return address %s, expected id %d and seq %d)",
body, message.Type, message.Code, returnAddr, id, seq)
}
}
}
+4 -6
View File
@@ -6,7 +6,6 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
//go:generate mockgen -destination=logger_mock_test.go -package $GOPACKAGE . Logger
@@ -20,11 +19,9 @@ func Test_New(t *testing.T) {
testCases := map[string]struct {
settings Settings
expected *Server
errWrapped error
errMessage string
}{
"empty settings": {
errWrapped: ErrHandlerIsNotSet,
errMessage: "http server settings validation failed: HTTP handler cannot be left unset",
},
"filled settings": {
@@ -52,9 +49,10 @@ func Test_New(t *testing.T) {
t.Parallel()
server, err := New(testCase.settings)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
require.EqualError(t, err, testCase.errMessage)
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
if server != nil {
+5 -19
View File
@@ -64,14 +64,6 @@ func (s *Settings) OverrideWith(other Settings) {
s.ShutdownTimeout = gosettings.OverrideWithComparable(s.ShutdownTimeout, other.ShutdownTimeout)
}
var (
ErrHandlerIsNotSet = errors.New("HTTP handler cannot be left unset")
ErrLoggerIsNotSet = errors.New("logger cannot be left unset")
ErrReadHeaderTimeoutTooSmall = errors.New("read header timeout is too small")
ErrReadTimeoutTooSmall = errors.New("read timeout is too small")
ErrShutdownTimeoutTooSmall = errors.New("shutdown timeout is too small")
)
func (s Settings) Validate() (err error) {
err = validate.ListeningAddress(s.Address, os.Getuid())
if err != nil {
@@ -79,31 +71,25 @@ func (s Settings) Validate() (err error) {
}
if s.Handler == nil {
return fmt.Errorf("%w", ErrHandlerIsNotSet)
return errors.New("HTTP handler cannot be left unset")
}
if s.Logger == nil {
return fmt.Errorf("%w", ErrLoggerIsNotSet)
return errors.New("logger cannot be left unset")
}
const minReadTimeout = time.Millisecond
if s.ReadHeaderTimeout < minReadTimeout {
return fmt.Errorf("%w: %s must be at least %s",
ErrReadHeaderTimeoutTooSmall,
s.ReadHeaderTimeout, minReadTimeout)
return fmt.Errorf("read header timeout is too small: %s must be at least %s", s.ReadHeaderTimeout, minReadTimeout)
}
if s.ReadTimeout < minReadTimeout {
return fmt.Errorf("%w: %s must be at least %s",
ErrReadTimeoutTooSmall,
s.ReadTimeout, minReadTimeout)
return fmt.Errorf("read timeout is too small: %s must be at least %s", s.ReadTimeout, minReadTimeout)
}
const minShutdownTimeout = 5 * time.Millisecond
if s.ShutdownTimeout < minShutdownTimeout {
return fmt.Errorf("%w: %s must be at least %s",
ErrShutdownTimeoutTooSmall,
s.ShutdownTimeout, minShutdownTimeout)
return fmt.Errorf("shutdown timeout is too small: %s must be at least %s", s.ShutdownTimeout, minShutdownTimeout)
}
return nil
+4 -11
View File
@@ -5,7 +5,6 @@ import (
"testing"
"time"
"github.com/qdm12/gosettings/validate"
"github.com/stretchr/testify/assert"
)
@@ -189,30 +188,26 @@ func Test_Settings_Validate(t *testing.T) {
testCases := map[string]struct {
settings Settings
errWrapped error
errMessage string
}{
"bad_address": {
settings: Settings{
Address: "address:notanint",
},
errWrapped: validate.ErrPortNotAnInteger,
errMessage: "port value is not an integer: notanint",
},
"nil handler": {
settings: Settings{
Address: ":8000",
},
errWrapped: ErrHandlerIsNotSet,
errMessage: ErrHandlerIsNotSet.Error(),
errMessage: "HTTP handler cannot be left unset",
},
"nil logger": {
settings: Settings{
Address: ":8000",
Handler: someHandler,
},
errWrapped: ErrLoggerIsNotSet,
errMessage: ErrLoggerIsNotSet.Error(),
errMessage: "logger cannot be left unset",
},
"read header timeout too small": {
settings: Settings{
@@ -221,7 +216,6 @@ func Test_Settings_Validate(t *testing.T) {
Logger: someLogger,
ReadHeaderTimeout: time.Nanosecond,
},
errWrapped: ErrReadHeaderTimeoutTooSmall,
errMessage: "read header timeout is too small: 1ns must be at least 1ms",
},
"read timeout too small": {
@@ -232,7 +226,6 @@ func Test_Settings_Validate(t *testing.T) {
ReadHeaderTimeout: time.Millisecond,
ReadTimeout: time.Nanosecond,
},
errWrapped: ErrReadTimeoutTooSmall,
errMessage: "read timeout is too small: 1ns must be at least 1ms",
},
"shutdown timeout too small": {
@@ -244,7 +237,6 @@ func Test_Settings_Validate(t *testing.T) {
ReadTimeout: time.Millisecond,
ShutdownTimeout: time.Millisecond,
},
errWrapped: ErrShutdownTimeoutTooSmall,
errMessage: "shutdown timeout is too small: 1ms must be at least 5ms",
},
"valid settings": {
@@ -265,9 +257,10 @@ func Test_Settings_Validate(t *testing.T) {
err := testCase.settings.Validate()
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
+2 -5
View File
@@ -2,15 +2,12 @@ package loopstate
import (
"context"
"errors"
"fmt"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
)
var ErrInvalidStatus = errors.New("invalid status")
// ApplyStatus sends signals to the running loop depending on the
// current status and status requested, such that its next status
// matches the requested one. It is thread safe and a synchronous call
@@ -73,7 +70,7 @@ func (s *State) ApplyStatus(ctx context.Context, status models.LoopStatus) (
return newStatus.String(), nil
default:
s.statusMu.Unlock()
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
ErrInvalidStatus, status, constants.Running, constants.Stopped)
return "", fmt.Errorf("invalid status: %s: it can only be one of: %s, %s",
status, constants.Running, constants.Stopped)
}
}
+4 -12
View File
@@ -3,19 +3,11 @@ package mod
import (
"bufio"
"compress/gzip"
"errors"
"fmt"
"os"
"strings"
)
var (
errModuleNameUnknown = errors.New("unknown module name")
errKernelFeatureIsModule = errors.New("kernel feature is a module, not built-in")
errKernelFeatureNotSet = errors.New("kernel feature not set")
errKernelFeatureNotFound = errors.New("kernel feature not found")
)
// checkProcConfig checks /proc/config.gz for a the kernel feature corresponding
// to the given module name. If the kernel feature is found and set to "y", it returns nil.
// If the kernel feature is found and set to "m", it returns an error indicating that the kernel
@@ -39,7 +31,7 @@ func checkProcConfig(moduleName string) error {
// If any group of kernel features is satisfied, then the module is considered supported.
kernelFeatureGroups, ok := moduleNameToKernelFeatureGroups(moduleName)
if !ok {
return fmt.Errorf("%w: %s", errModuleNameUnknown, moduleName)
return fmt.Errorf("unknown module name: %s", moduleName)
}
groups := make([]map[string]bool, len(kernelFeatureGroups))
for i, group := range kernelFeatureGroups {
@@ -58,20 +50,20 @@ func checkProcConfig(moduleName string) error {
switch {
case ok:
case strings.HasPrefix(line, name+"=m"):
return fmt.Errorf("%w: %s", errKernelFeatureIsModule, name)
return fmt.Errorf("kernel feature is a module, not built-in: %s", name)
case strings.HasPrefix(line, name+"=y"):
featureToOK[name] = true
if allFeaturesOK(featureToOK) {
return nil
}
case strings.HasPrefix(line, "# "+name+" is not set"):
return fmt.Errorf("%w: %s", errKernelFeatureNotSet, name)
return fmt.Errorf("kernel feature not set: %s", name)
}
}
}
}
return fmt.Errorf("%w: for module name %s", errKernelFeatureNotFound, moduleName)
return fmt.Errorf("kernel feature not found: for module name %s", moduleName)
}
func moduleNameToKernelFeatureGroups(moduleName string) (featureGroups [][]string, ok bool) {
+1 -3
View File
@@ -181,8 +181,6 @@ func getLoadedModules(modulesInfo map[string]moduleInfo) (err error) {
return nil
}
var ErrModulePathNotFound = errors.New("module path not found")
func findModulePath(moduleName string, modulesInfo map[string]moduleInfo) (modulePath string, err error) {
// Kernel module names can have underscores or hyphens in their names,
// but only one or the other in one particular name.
@@ -205,5 +203,5 @@ func findModulePath(moduleName string, modulesInfo map[string]moduleInfo) (modul
}
}
return "", fmt.Errorf("%w: for %q", ErrModulePathNotFound, moduleName)
return "", fmt.Errorf("module path not found: for %q", moduleName)
}
+2 -9
View File
@@ -1,7 +1,6 @@
package mod
import (
"errors"
"fmt"
"io"
"os"
@@ -14,15 +13,10 @@ import (
"golang.org/x/sys/unix"
)
var (
ErrModuleInfoNotFound = errors.New("module info not found")
ErrCircularDependency = errors.New("circular dependency")
)
func initDependencies(path string, modulesInfo map[string]moduleInfo) (err error) {
info, ok := modulesInfo[path]
if !ok {
return fmt.Errorf("%w: %s", ErrModuleInfoNotFound, path)
return fmt.Errorf("module info not found: %s", path)
}
switch info.state {
@@ -30,8 +24,7 @@ func initDependencies(path string, modulesInfo map[string]moduleInfo) (err error
case loaded, builtin:
return nil
case loading:
return fmt.Errorf("%w: %s is already in the loading state",
ErrCircularDependency, path)
return fmt.Errorf("circular dependency: %s is already in the loading state", path)
}
info.state = loading
+1 -4
View File
@@ -1,7 +1,6 @@
package models
import (
"errors"
"fmt"
"strings"
@@ -109,8 +108,6 @@ func (s *Servers) toMarkdown(vpnProvider string) (formatted string, err error) {
return formatted, nil
}
var ErrMarkdownHeadersNotDefined = errors.New("markdown headers not defined")
func getMarkdownHeaders(vpnProvider string) (headers []string, err error) {
switch vpnProvider {
case providers.Airvpn:
@@ -169,6 +166,6 @@ func getMarkdownHeaders(vpnProvider string) (headers []string, err error) {
case providers.Windscribe:
return []string{regionHeader, cityHeader, hostnameHeader, vpnHeader}, nil
default:
return nil, fmt.Errorf("%w: for %s", ErrMarkdownHeadersNotDefined, vpnProvider)
return nil, fmt.Errorf("markdown headers not defined: for %s", vpnProvider)
}
}
+3 -4
View File
@@ -15,12 +15,10 @@ func Test_Servers_ToMarkdown(t *testing.T) {
provider string
servers Servers
formatted string
errWrapped error
errMessage string
}{
"unsupported_provider": {
provider: "unsupported",
errWrapped: ErrMarkdownHeadersNotDefined,
errMessage: "getting markdown headers: markdown headers not defined: for unsupported",
},
providers.Cyberghost: {
@@ -58,9 +56,10 @@ func Test_Servers_ToMarkdown(t *testing.T) {
markdown, err := testCase.servers.toMarkdown(testCase.provider)
assert.Equal(t, testCase.formatted, markdown)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
+5 -14
View File
@@ -38,27 +38,18 @@ type Server struct {
IPs []netip.Addr `json:"ips,omitempty"`
}
var (
ErrVPNFieldEmpty = errors.New("vpn field is empty")
ErrHostnameFieldEmpty = errors.New("hostname field is empty")
ErrIPsFieldEmpty = errors.New("ips field is empty")
ErrNoNetworkProtocol = errors.New("both TCP and UDP fields are false for OpenVPN")
ErrNetworkProtocolSet = errors.New("no network protocol should be set")
ErrWireguardPublicKeyEmpty = errors.New("wireguard public key field is empty")
)
func (s *Server) HasMinimumInformation() (err error) {
switch {
case s.VPN == "":
return fmt.Errorf("%w", ErrVPNFieldEmpty)
return errors.New("vpn field is empty")
case len(s.IPs) == 0:
return fmt.Errorf("%w", ErrIPsFieldEmpty)
return errors.New("ips field is empty")
case s.VPN == vpn.Wireguard && (s.TCP || s.UDP):
return fmt.Errorf("%w", ErrNetworkProtocolSet)
return errors.New("no network protocol should be set")
case s.VPN == vpn.OpenVPN && !s.TCP && !s.UDP:
return fmt.Errorf("%w", ErrNoNetworkProtocol)
return errors.New("both TCP and UDP fields are false for OpenVPN")
case s.VPN == vpn.Wireguard && s.WgPubKey == "":
return fmt.Errorf("%w", ErrWireguardPublicKeyEmpty)
return errors.New("wireguard public key field is empty")
default:
return nil
}
+1 -4
View File
@@ -3,7 +3,6 @@ package models
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"math"
"reflect"
@@ -158,8 +157,6 @@ type Servers struct {
Servers []Server `json:"servers,omitempty"`
}
var ErrServersFormatNotSupported = errors.New("servers format not supported")
func (s *Servers) Format(vpnProvider, format string) (formatted string, err error) {
switch format {
case "markdown":
@@ -167,7 +164,7 @@ func (s *Servers) Format(vpnProvider, format string) (formatted string, err erro
case "json":
return s.toJSON()
default:
return "", fmt.Errorf("%w: %s", ErrServersFormatNotSupported, format)
return "", fmt.Errorf("servers format not supported: %s", format)
}
}
+9 -8
View File
@@ -16,7 +16,6 @@ func Test_AllServers_MarshalJSON(t *testing.T) {
testCases := map[string]struct {
allServers *AllServers
dataString string
errWrapped error
errMessage string
}{
"no provider": {
@@ -58,16 +57,18 @@ func Test_AllServers_MarshalJSON(t *testing.T) {
t.Parallel()
data, err := testCase.allServers.MarshalJSON()
assert.ErrorIs(t, err, testCase.errWrapped)
if err != nil {
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
require.Equal(t, testCase.dataString, string(data))
data, err = json.Marshal(testCase.allServers)
assert.ErrorIs(t, err, testCase.errWrapped)
if err != nil {
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
require.Equal(t, testCase.dataString, string(data))
@@ -87,7 +88,6 @@ func Test_AllServers_UnmarshalJSON(t *testing.T) {
testCases := map[string]struct {
dataString string
allServers AllServers
errWrapped error
errMessage string
}{
"empty": {
@@ -131,9 +131,10 @@ func Test_AllServers_UnmarshalJSON(t *testing.T) {
err := json.Unmarshal(data, &allServers)
assert.ErrorIs(t, err, testCase.errWrapped)
if err != nil {
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.allServers, allServers)
})
+16 -33
View File
@@ -6,48 +6,40 @@ import (
"fmt"
)
var ErrRequestSizeTooSmall = errors.New("message size is too small")
func checkRequest(request []byte) (err error) {
const minMessageSize = 2 // version number + operation code
if len(request) < minMessageSize {
return fmt.Errorf("%w: need at least %d bytes and got %d byte(s)",
ErrRequestSizeTooSmall, minMessageSize, len(request))
return fmt.Errorf("message size is too small: need at least %d bytes and got %d byte(s)",
minMessageSize, len(request))
}
return nil
}
var (
ErrResponseSizeTooSmall = errors.New("response size is too small")
ErrResponseSizeUnexpected = errors.New("response size is unexpected")
ErrProtocolVersionUnknown = errors.New("protocol version is unknown")
ErrOperationCodeUnexpected = errors.New("operation code is unexpected")
)
func checkResponse(response []byte, expectedOperationCode byte,
expectedResponseSize uint,
) (err error) {
const minResponseSize = 4
if len(response) < minResponseSize {
return fmt.Errorf("%w: need at least %d bytes and got %d byte(s)",
ErrResponseSizeTooSmall, minResponseSize, len(response))
return fmt.Errorf("response size is too small: "+
"need at least %d bytes and got %d byte(s)",
minResponseSize, len(response))
}
if uint(len(response)) != expectedResponseSize {
return fmt.Errorf("%w: expected %d bytes and got %d byte(s)",
ErrResponseSizeUnexpected, expectedResponseSize, len(response))
return fmt.Errorf("response size is unexpected: "+
"expected %d bytes and got %d byte(s)",
expectedResponseSize, len(response))
}
protocolVersion := response[0]
if protocolVersion != 0 {
return fmt.Errorf("%w: %d", ErrProtocolVersionUnknown, protocolVersion)
return fmt.Errorf("protocol version is unknown: %d", protocolVersion)
}
operationCode := response[1]
if operationCode != expectedOperationCode {
return fmt.Errorf("%w: expected 0x%x and got 0x%x",
ErrOperationCodeUnexpected, expectedOperationCode, operationCode)
return fmt.Errorf("operation code is unexpected: expected 0x%x and got 0x%x", expectedOperationCode, operationCode)
}
resultCode := binary.BigEndian.Uint16(response[2:4])
@@ -59,15 +51,6 @@ func checkResponse(response []byte, expectedOperationCode byte,
return nil
}
var (
ErrVersionNotSupported = errors.New("version is not supported")
ErrNotAuthorized = errors.New("not authorized")
ErrNetworkFailure = errors.New("network failure")
ErrOutOfResources = errors.New("out of resources")
ErrOperationCodeNotSupported = errors.New("operation code is not supported")
ErrResultCodeUnknown = errors.New("result code is unknown")
)
// checkResultCode checks the result code and returns an error
// if the result code is not a success (0).
// See https://www.ietf.org/rfc/rfc6886.html#section-3.5
@@ -78,16 +61,16 @@ func checkResultCode(resultCode uint16) (err error) {
case 0:
return nil
case 1:
return fmt.Errorf("%w", ErrVersionNotSupported)
return errors.New("version is not supported")
case 2:
return fmt.Errorf("%w", ErrNotAuthorized)
return errors.New("not authorized")
case 3:
return fmt.Errorf("%w", ErrNetworkFailure)
return errors.New("network failure")
case 4:
return fmt.Errorf("%w", ErrOutOfResources)
return errors.New("out of resources")
case 5:
return fmt.Errorf("%w", ErrOperationCodeNotSupported)
return errors.New("operation code is not supported")
default:
return fmt.Errorf("%w: %d", ErrResultCodeUnknown, resultCode)
return fmt.Errorf("result code is unknown: %d", resultCode)
}
}
+21 -17
View File
@@ -1,6 +1,7 @@
package natpmp
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
@@ -11,12 +12,10 @@ func Test_checkRequest(t *testing.T) {
testCases := map[string]struct {
request []byte
err error
errMessage string
}{
"too_short": {
request: []byte{1},
err: ErrRequestSizeTooSmall,
errMessage: "message size is too small: need at least 2 bytes and got 1 byte(s)",
},
"success": {
@@ -30,9 +29,10 @@ func Test_checkRequest(t *testing.T) {
err := checkRequest(testCase.request)
assert.ErrorIs(t, err, testCase.err)
if testCase.err != nil {
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
@@ -50,33 +50,33 @@ func Test_checkResponse(t *testing.T) {
}{
"too_short": {
response: []byte{1},
err: ErrResponseSizeTooSmall,
err: errors.New("response size is too small"),
errMessage: "response size is too small: need at least 4 bytes and got 1 byte(s)",
},
"size_mismatch": {
response: []byte{0, 0, 0, 0},
expectedResponseSize: 5,
err: ErrResponseSizeUnexpected,
err: errors.New("response size is unexpected"),
errMessage: "response size is unexpected: expected 5 bytes and got 4 byte(s)",
},
"protocol_unknown": {
response: []byte{1, 0, 0, 0},
expectedResponseSize: 4,
err: ErrProtocolVersionUnknown,
err: errors.New("protocol version is unknown"),
errMessage: "protocol version is unknown: 1",
},
"operation_code_unexpected": {
response: []byte{0, 2, 0, 0},
expectedOperationCode: 1,
expectedResponseSize: 4,
err: ErrOperationCodeUnexpected,
err: errors.New("operation code is unexpected"),
errMessage: "operation code is unexpected: expected 0x1 and got 0x2",
},
"result_code_failure": {
response: []byte{0, 1, 0, 1},
expectedOperationCode: 1,
expectedResponseSize: 4,
err: ErrVersionNotSupported,
err: errors.New("version is not supported"),
errMessage: "result code: version is not supported",
},
"success": {
@@ -94,9 +94,11 @@ func Test_checkResponse(t *testing.T) {
testCase.expectedOperationCode,
testCase.expectedResponseSize)
assert.ErrorIs(t, err, testCase.err)
if testCase.err != nil {
assert.ErrorContains(t, err, testCase.err.Error())
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
@@ -113,32 +115,32 @@ func Test_checkResultCode(t *testing.T) {
"success": {},
"version_unsupported": {
resultCode: 1,
err: ErrVersionNotSupported,
err: errors.New("version is not supported"),
errMessage: "version is not supported",
},
"not_authorized": {
resultCode: 2,
err: ErrNotAuthorized,
err: errors.New("not authorized"),
errMessage: "not authorized",
},
"network_failure": {
resultCode: 3,
err: ErrNetworkFailure,
err: errors.New("network failure"),
errMessage: "network failure",
},
"out_of_resources": {
resultCode: 4,
err: ErrOutOfResources,
err: errors.New("out of resources"),
errMessage: "out of resources",
},
"unsupported_operation_code": {
resultCode: 5,
err: ErrOperationCodeNotSupported,
err: errors.New("operation code is not supported"),
errMessage: "operation code is not supported",
},
"unknown": {
resultCode: 6,
err: ErrResultCodeUnknown,
err: errors.New("result code is unknown"),
errMessage: "result code is unknown: 6",
},
}
@@ -149,9 +151,11 @@ func Test_checkResultCode(t *testing.T) {
err := checkResultCode(testCase.resultCode)
assert.ErrorIs(t, err, testCase.err)
if testCase.err != nil {
assert.ErrorContains(t, err, testCase.err.Error())
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
+4 -9
View File
@@ -3,17 +3,11 @@ package natpmp
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net/netip"
"time"
)
var (
ErrNetworkProtocolUnknown = errors.New("network protocol is unknown")
ErrLifetimeTooLong = errors.New("lifetime is too long")
)
// Add or delete a port mapping. To delete a mapping, set both the
// requestedExternalPort and lifetime to 0.
// See https://www.ietf.org/rfc/rfc6886.html#section-3.3
@@ -26,8 +20,9 @@ func (c *Client) AddPortMapping(ctx context.Context, gateway netip.Addr,
lifetimeSecondsFloat := lifetime.Seconds()
const maxLifetimeSeconds = uint64(^uint32(0))
if uint64(lifetimeSecondsFloat) > maxLifetimeSeconds {
return 0, 0, 0, 0, fmt.Errorf("%w: %d seconds must at most %d seconds",
ErrLifetimeTooLong, uint64(lifetimeSecondsFloat), maxLifetimeSeconds)
return 0, 0, 0, 0, fmt.Errorf("lifetime is too long: "+
"%d seconds must at most %d seconds",
uint64(lifetimeSecondsFloat), maxLifetimeSeconds)
}
const messageSize = 12
message := make([]byte, messageSize)
@@ -38,7 +33,7 @@ func (c *Client) AddPortMapping(ctx context.Context, gateway netip.Addr,
case "tcp":
message[1] = 2 // operationCode 2
default:
return 0, 0, 0, 0, fmt.Errorf("%w: %s", ErrNetworkProtocolUnknown, protocol)
return 0, 0, 0, 0, fmt.Errorf("network protocol is unknown: %s", protocol)
}
// [2:3] are reserved.
binary.BigEndian.PutUint16(message[4:6], internalPort)
-7
View File
@@ -25,18 +25,15 @@ func Test_Client_AddPortMapping(t *testing.T) {
assignedInternalPort uint16
assignedExternalPort uint16
assignedLifetime time.Duration
err error
errMessage string
}{
"lifetime_too_long": {
lifetime: time.Duration(uint64(^uint32(0))+1) * time.Second,
err: ErrLifetimeTooLong,
errMessage: "lifetime is too long: 4294967296 seconds must at most 4294967295 seconds",
},
"protocol_unknown": {
lifetime: time.Second,
protocol: "xyz",
err: ErrNetworkProtocolUnknown,
errMessage: "network protocol is unknown: xyz",
},
"rpc_error": {
@@ -48,7 +45,6 @@ func Test_Client_AddPortMapping(t *testing.T) {
lifetime: 1200 * time.Second,
initialConnectionDuration: time.Millisecond,
exchanges: []udpExchange{{close: true}},
err: ErrConnectionTimeout,
errMessage: "executing remote procedure call: connection timeout: failed attempts: " +
"read udp 127.0.0.1:[1-9][0-9]{0,4}->127.0.0.1:[1-9][0-9]{0,4}: i/o timeout \\(try 1\\)",
},
@@ -136,9 +132,6 @@ func Test_Client_AddPortMapping(t *testing.T) {
assert.Equal(t, testCase.assignedExternalPort, assignedExternalPort)
assert.Equal(t, testCase.assignedLifetime, assignedLifetime)
if testCase.errMessage != "" {
if testCase.err != nil {
assert.ErrorIs(t, err, testCase.err)
}
assert.Regexp(t, "^"+testCase.errMessage+"$", err.Error())
} else {
assert.NoError(t, err)
+2 -8
View File
@@ -11,17 +11,12 @@ import (
"time"
)
var (
ErrGatewayIPUnspecified = errors.New("gateway IP is unspecified")
ErrConnectionTimeout = errors.New("connection timeout")
)
func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
request []byte, responseSize uint) (
response []byte, err error,
) {
if gateway.IsUnspecified() || !gateway.IsValid() {
return nil, fmt.Errorf("%w", ErrGatewayIPUnspecified)
return nil, errors.New("gateway IP is unspecified")
}
err = checkRequest(request)
@@ -114,8 +109,7 @@ func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
}
if retryCount == c.maxRetries {
return nil, fmt.Errorf("%w: failed attempts: %s",
ErrConnectionTimeout, dedupFailedAttempts(failedAttempts))
return nil, fmt.Errorf("connection timeout: failed attempts: %s", dedupFailedAttempts(failedAttempts))
}
// Opcodes between 0 and 127 are client requests. Opcodes from 128 to
-12
View File
@@ -20,20 +20,17 @@ func Test_Client_rpc(t *testing.T) {
initialConnectionDuration time.Duration
exchanges []udpExchange
expectedResponse []byte
err error
errMessage string
}{
"gateway_ip_unspecified": {
gateway: netip.IPv6Unspecified(),
request: []byte{0, 0},
err: ErrGatewayIPUnspecified,
errMessage: "gateway IP is unspecified",
},
"request_too_small": {
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
request: []byte{0},
initialConnectionDuration: time.Nanosecond, // doesn't matter
err: ErrRequestSizeTooSmall,
errMessage: `checking request: message size is too small: ` +
`need at least 2 bytes and got 1 byte\(s\)`,
},
@@ -53,7 +50,6 @@ func Test_Client_rpc(t *testing.T) {
exchanges: []udpExchange{
{request: []byte{0, 1}, close: true},
},
err: ErrConnectionTimeout,
errMessage: "connection timeout: failed attempts: " +
"read udp 127.0.0.1:[1-9][0-9]{0,4}->127.0.0.1:[1-9][0-9]{0,4}: i/o timeout \\(try 1\\)",
},
@@ -66,7 +62,6 @@ func Test_Client_rpc(t *testing.T) {
request: []byte{0, 0},
response: []byte{1},
}},
err: ErrResponseSizeTooSmall,
errMessage: `checking response: response size is too small: ` +
`need at least 4 bytes and got 1 byte\(s\)`,
},
@@ -80,7 +75,6 @@ func Test_Client_rpc(t *testing.T) {
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
response: []byte{0, 1, 2, 3}, // size 4
}},
err: ErrResponseSizeUnexpected,
errMessage: `checking response: response size is unexpected: ` +
`expected 5 bytes and got 4 byte\(s\)`,
},
@@ -94,7 +88,6 @@ func Test_Client_rpc(t *testing.T) {
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
response: []byte{0x1, 0x82, 0x0, 0x0, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
}},
err: ErrProtocolVersionUnknown,
errMessage: "checking response: protocol version is unknown: 1",
},
"unexpected_operation_code": {
@@ -107,7 +100,6 @@ func Test_Client_rpc(t *testing.T) {
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
response: []byte{0x0, 0x88, 0x0, 0x0, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
}},
err: ErrOperationCodeUnexpected,
errMessage: "checking response: operation code is unexpected: expected 0x82 and got 0x88",
},
"failure_result_code": {
@@ -120,7 +112,6 @@ func Test_Client_rpc(t *testing.T) {
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
response: []byte{0x0, 0x82, 0x0, 0x11, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
}},
err: ErrResultCodeUnknown,
errMessage: "checking response: result code: result code is unknown: 17",
},
"success": {
@@ -153,9 +144,6 @@ func Test_Client_rpc(t *testing.T) {
testCase.request, testCase.responseSize)
if testCase.errMessage != "" {
if testCase.err != nil {
assert.ErrorIs(t, err, testCase.err)
}
assert.Regexp(t, "^"+testCase.errMessage+"$", err.Error())
} else {
assert.NoError(t, err)
+2 -4
View File
@@ -41,8 +41,6 @@ func findAvailableTCPPort(t *testing.T) (port uint16) {
func Test_dialAddrThroughFirewall(t *testing.T) {
t.Parallel()
errTest := errors.New("test error")
const ipv6InternetWorks = false
testCases := map[string]struct {
@@ -102,7 +100,7 @@ func Test_dialAddrThroughFirewall(t *testing.T) {
},
},
"firewall_add_error": {
firewallAddErr: errTest,
firewallAddErr: errors.New("test error"),
errMessageRegex: func() string {
return "accepting output traffic: test error"
},
@@ -122,7 +120,7 @@ func Test_dialAddrThroughFirewall(t *testing.T) {
addrPort := netip.MustParseAddrPort(listener.Addr().String())
return netip.AddrPortFrom(loopback, addrPort.Port())
},
firewallRemoveErr: errTest,
firewallRemoveErr: errors.New("test error"),
errMessageRegex: func() string {
return "removing output traffic rule: test error"
},
+3 -6
View File
@@ -1,7 +1,6 @@
package netlink
import (
"errors"
"fmt"
"github.com/jsimonetti/rtnetlink"
@@ -47,8 +46,6 @@ func (n *NetLink) LinkList() (links []Link, err error) {
return links, nil
}
var ErrLinkNotFound = errors.New("link not found")
func (n *NetLink) LinkByName(name string) (link Link, err error) {
links, err := n.LinkList()
if err != nil {
@@ -61,7 +58,7 @@ func (n *NetLink) LinkByName(name string) (link Link, err error) {
}
}
return Link{}, fmt.Errorf("%w: for name %s", ErrLinkNotFound, name)
return Link{}, fmt.Errorf("link not found: for name %s", name)
}
func (n *NetLink) LinkByIndex(index uint32) (link Link, err error) {
@@ -76,7 +73,7 @@ func (n *NetLink) LinkByIndex(index uint32) (link Link, err error) {
}
}
return Link{}, fmt.Errorf("%w: for index %d", ErrLinkNotFound, index)
return Link{}, fmt.Errorf("link not found: for index %d", index)
}
func (n *NetLink) LinkAdd(link Link) (linkIndex uint32, err error) {
@@ -114,7 +111,7 @@ func (n *NetLink) LinkAdd(link Link) (linkIndex uint32, err error) {
}
}
return 0, fmt.Errorf("%w: matching name %s", ErrLinkNotFound, link.Name)
return 0, fmt.Errorf("link not found: matching name %s", link.Name)
}
func (n *NetLink) LinkDel(linkIndex uint32) (err error) {
-6
View File
@@ -1,17 +1,11 @@
package extract
import (
"errors"
"fmt"
"github.com/qdm12/gluetun/internal/models"
)
var (
ErrRead = errors.New("cannot read file")
ErrExtractConnection = errors.New("cannot extract connection from file")
)
// Data extracts the lines and connection from the OpenVPN configuration file.
func (e *Extractor) Data(filepath string) (lines []string,
connection models.Connection, err error,
+10 -24
View File
@@ -11,8 +11,6 @@ import (
"github.com/qdm12/gluetun/internal/models"
)
var errRemoteLineNotFound = errors.New("remote line not found")
func extractDataFromLines(lines []string) (
connection models.Connection, err error,
) {
@@ -35,7 +33,7 @@ func extractDataFromLines(lines []string) (
}
if !connection.IP.IsValid() {
return connection, errRemoteLineNotFound
return connection, errors.New("remote line not found")
}
if connection.Protocol == "" {
@@ -81,19 +79,15 @@ func extractDataFromLine(line string) (
return ip, 0, "", nil
}
var errProtoLineFieldsCount = errors.New("proto line has not 2 fields as expected")
func extractProto(line string) (protocol string, err error) {
fields := strings.Fields(line)
if len(fields) != 2 { //nolint:mnd
return "", fmt.Errorf("%w: %s", errProtoLineFieldsCount, line)
return "", fmt.Errorf("proto line has not 2 fields as expected: %s", line)
}
return parseProto(fields[1])
}
var errProtocolNotSupported = errors.New("network protocol not supported")
func parseProto(field string) (protocol string, err error) {
switch field {
case "tcp", "tcp4", "tcp6", "tcp-client":
@@ -106,16 +100,10 @@ func parseProto(field string) (protocol string, err error) {
// determined by the remote IP address version.
return constants.UDP, nil
default:
return "", fmt.Errorf("%w: %s", errProtocolNotSupported, field)
return "", fmt.Errorf("network protocol not supported: %s", field)
}
}
var (
errRemoteLineFieldsCount = errors.New("remote line has not 2 fields as expected")
errHostNotIP = errors.New("host is not an IP address")
errPortNotValid = errors.New("port is not valid")
)
func extractRemote(line string) (ip netip.Addr, port uint16,
protocol string, err error,
) {
@@ -123,13 +111,13 @@ func extractRemote(line string) (ip netip.Addr, port uint16,
n := len(fields)
if n < 2 || n > 4 {
return netip.Addr{}, 0, "", fmt.Errorf("%w: %s", errRemoteLineFieldsCount, line)
return netip.Addr{}, 0, "", fmt.Errorf("remote line has not 2 fields as expected: %s", line)
}
host := fields[1]
ip, err = netip.ParseAddr(host)
if err != nil {
return netip.Addr{}, 0, "", fmt.Errorf("%w: %s", errHostNotIP, host)
return netip.Addr{}, 0, "", fmt.Errorf("host is not an IP address: %s", host)
// TODO resolve hostname once there is an option to allow it through
// the firewall before the VPN is up.
}
@@ -137,9 +125,9 @@ func extractRemote(line string) (ip netip.Addr, port uint16,
if n > 2 { //nolint:mnd
portInt, err := strconv.Atoi(fields[2])
if err != nil {
return netip.Addr{}, 0, "", fmt.Errorf("%w: %s", errPortNotValid, line)
return netip.Addr{}, 0, "", fmt.Errorf("port is not valid: %s", line)
} else if portInt < 1 || portInt > 65535 {
return netip.Addr{}, 0, "", fmt.Errorf("%w: %d must be between 1 and 65535", errPortNotValid, portInt)
return netip.Addr{}, 0, "", fmt.Errorf("port is not valid: %d must be between 1 and 65535", portInt)
}
port = uint16(portInt)
}
@@ -154,20 +142,18 @@ func extractRemote(line string) (ip netip.Addr, port uint16,
return ip, port, protocol, nil
}
var errPostLineFieldsCount = errors.New("post line has not 2 fields as expected")
func extractPort(line string) (port uint16, err error) {
fields := strings.Fields(line)
const expectedFieldsCount = 2
if len(fields) != expectedFieldsCount {
return 0, fmt.Errorf("%w: %s", errPostLineFieldsCount, line)
return 0, fmt.Errorf("post line has not 2 fields as expected: %s", line)
}
portInt, err := strconv.Atoi(fields[1])
if err != nil {
return 0, fmt.Errorf("%w: %s", errPortNotValid, line)
return 0, fmt.Errorf("port is not valid: %s", line)
} else if portInt < 1 || portInt > 65535 {
return 0, fmt.Errorf("%w: %d must be between 1 and 65535", errPortNotValid, portInt)
return 0, fmt.Errorf("port is not valid: %d must be between 1 and 65535", portInt)
}
port = uint16(portInt)
+19 -20
View File
@@ -17,7 +17,7 @@ func Test_extractDataFromLines(t *testing.T) {
testCases := map[string]struct {
lines []string
connection models.Connection
err error
errMessage string
}{
"success": {
lines: []string{"bla", "proto tcp", "remote 1.2.3.4 1194 tcp", "dev tun6"},
@@ -28,8 +28,8 @@ func Test_extractDataFromLines(t *testing.T) {
},
},
"extraction error": {
lines: []string{"bla", "proto bad", "remote 1.2.3.4 1194 tcp"},
err: errors.New("on line 2: extracting protocol from proto line: network protocol not supported: bad"),
lines: []string{"bla", "proto bad", "remote 1.2.3.4 1194 tcp"},
errMessage: "on line 2: extracting protocol from proto line: network protocol not supported: bad",
},
"only use first values found": {
lines: []string{"proto udp", "proto tcp", "remote 1.2.3.4 443 tcp", "remote 5.2.3.4 1194 udp"},
@@ -44,7 +44,7 @@ func Test_extractDataFromLines(t *testing.T) {
connection: models.Connection{
Protocol: constants.TCP,
},
err: errRemoteLineNotFound,
errMessage: "remote line not found",
},
"default TCP port": {
lines: []string{"remote 1.2.3.4", "proto tcp"},
@@ -70,9 +70,8 @@ func Test_extractDataFromLines(t *testing.T) {
connection, err := extractDataFromLines(testCase.lines)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
@@ -86,18 +85,18 @@ func Test_extractDataFromLine(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
line string
ip netip.Addr
port uint16
protocol string
isErr error
line string
ip netip.Addr
port uint16
protocol string
errMessage string
}{
"irrelevant line": {
line: "bla",
},
"extract proto error": {
line: "proto bad",
isErr: errProtocolNotSupported,
line: "proto bad",
errMessage: "network protocol not supported",
},
"extract proto success": {
line: "proto tcp",
@@ -108,8 +107,8 @@ func Test_extractDataFromLine(t *testing.T) {
protocol: constants.TCP,
},
"extract remote error": {
line: "remote bad",
isErr: errHostNotIP,
line: "remote bad",
errMessage: "host is not an IP address",
},
"extract remote success": {
line: "remote 1.2.3.4 1194 udp",
@@ -118,8 +117,8 @@ func Test_extractDataFromLine(t *testing.T) {
protocol: constants.UDP,
},
"extract_port_fail": {
line: "port a",
isErr: errPortNotValid,
line: "port a",
errMessage: "port is not valid",
},
"extract_port_success": {
line: "port 1194",
@@ -133,8 +132,8 @@ func Test_extractDataFromLine(t *testing.T) {
ip, port, protocol, err := extractDataFromLine(testCase.line)
if testCase.isErr != nil {
assert.ErrorIs(t, err, testCase.isErr)
if testCase.errMessage != "" {
assert.ErrorContains(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
+1 -4
View File
@@ -4,15 +4,12 @@ import (
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
)
var errPEMDecode = errors.New("cannot decode PEM encoded block")
func PEM(b []byte) (encodedData string, err error) {
pemBlock, _ := pem.Decode(b)
if pemBlock == nil {
return "", fmt.Errorf("%w", errPEMDecode)
return "", errors.New("cannot decode PEM encoded block")
}
der := pemBlock.Bytes
+3 -5
View File
@@ -13,16 +13,13 @@ func Test_PEM(t *testing.T) {
testCases := map[string]struct {
b []byte
encodedData string
errWrapped error
errMessage string
}{
"no input": {
errWrapped: errPEMDecode,
errMessage: "cannot decode PEM encoded block",
},
"bad input": {
b: []byte{1, 2, 3},
errWrapped: errPEMDecode,
errMessage: "cannot decode PEM encoded block",
},
"valid data with extras": {
@@ -46,9 +43,10 @@ func Test_PEM(t *testing.T) {
encodedData, err := PEM(testCase.b)
assert.Equal(t, testCase.encodedData, encodedData)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
+2 -5
View File
@@ -3,7 +3,6 @@ package pkcs8
import (
"crypto/x509/pkix"
"encoding/asn1"
"errors"
"fmt"
)
@@ -11,8 +10,6 @@ import (
// https://www.ibm.com/docs/en/zos/2.3.0?topic=programming-object-identifiers
var oidDESCBC = asn1.ObjectIdentifier{1, 3, 14, 3, 2, 7} //nolint:gochecknoglobals
var ErrEncryptionAlgorithmNotPBES2 = errors.New("encryption algorithm is not PBES2")
type encryptedPrivateKey struct {
EncryptionAlgorithm pkix.AlgorithmIdentifier
EncryptedData []byte
@@ -35,8 +32,8 @@ func getEncryptionAlgorithmOid(der []byte) (
oidPBES2 := asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 13}
oidAlgorithm := encryptedPrivateKeyData.EncryptionAlgorithm.Algorithm
if !oidAlgorithm.Equal(oidPBES2) {
return nil, fmt.Errorf("%w: %s instead of PBES2 %s",
ErrEncryptionAlgorithmNotPBES2, oidAlgorithm, oidPBES2)
return nil, fmt.Errorf("encryption algorithm is not PBES2: %s instead of PBES2 %s",
oidAlgorithm, oidPBES2)
}
var encryptionAlgorithmParams encryptedAlgorithmParams
-3
View File
@@ -2,14 +2,11 @@ package pkcs8
import (
"encoding/base64"
"errors"
"fmt"
pkcs8lib "github.com/youmark/pkcs8"
)
var ErrUnsupportedKeyType = errors.New("unsupported key type")
// UpgradeEncryptedKey eventually upgrades an encrypted key to a newer encryption
// if its encryption is too weak for Openvpn/Openssl.
// If the key is encrypted using DES-CBC, it is decrypted and re-encrypted using AES-256-CBC.
+1 -4
View File
@@ -2,15 +2,12 @@ package openvpn
import (
"context"
"errors"
"fmt"
"os/exec"
"github.com/qdm12/gluetun/internal/constants/openvpn"
)
var ErrVersionUnknown = errors.New("OpenVPN version is unknown")
const (
binOpenvpn25 = "openvpn2.5"
binOpenvpn26 = "openvpn2.6"
@@ -26,7 +23,7 @@ func start(ctx context.Context, starter CmdStarter, version string, flags []stri
case openvpn.Openvpn26:
bin = binOpenvpn26
default:
return nil, nil, nil, fmt.Errorf("%w: %s", ErrVersionUnknown, version)
return nil, nil, nil, fmt.Errorf("OpenVPN version is unknown: %s", version)
}
args := []string{"--config", configPath}
+1 -4
View File
@@ -2,7 +2,6 @@ package openvpn
import (
"context"
"errors"
"fmt"
"os/exec"
"strings"
@@ -16,8 +15,6 @@ func (c *Configurator) Version26(ctx context.Context) (version string, err error
return c.version(ctx, binOpenvpn26)
}
var ErrVersionTooShort = errors.New("version output is too short")
func (c *Configurator) version(ctx context.Context, binName string) (version string, err error) {
cmd := exec.CommandContext(ctx, binName, "--version")
output, err := c.cmder.Run(cmd)
@@ -28,7 +25,7 @@ func (c *Configurator) version(ctx context.Context, binName string) (version str
words := strings.Fields(firstLine)
const minWords = 2
if len(words) < minWords {
return "", fmt.Errorf("%w: %s", ErrVersionTooShort, firstLine)
return "", fmt.Errorf("version output is too short: %s", firstLine)
}
return words[1], nil
}
+10 -21
View File
@@ -2,24 +2,17 @@ package icmp
import (
"bytes"
"errors"
"fmt"
"golang.org/x/net/icmp"
)
var (
ErrNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
ErrNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
)
func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) {
switch {
case mtu < minMTU:
return fmt.Errorf("%w: %d", ErrNextHopMTUTooLow, mtu)
return fmt.Errorf("ICMP Next Hop MTU is too low: %d", mtu)
case mtu > physicalLinkMTU:
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
ErrNextHopMTUTooHigh, mtu, physicalLinkMTU)
return fmt.Errorf("ICMP Next Hop MTU is too high: %d is larger than physical link MTU %d", mtu, physicalLinkMTU)
default:
return nil
}
@@ -34,14 +27,12 @@ func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return false, fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
return false, fmt.Errorf("ICMP body type is not supported: %T", inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
return inboundBody.ID == outboundBody.ID, nil
}
var ErrIDMismatch = errors.New("ICMP id mismatch")
func checkEchoReply(icmpProtocol int, received []byte,
outboundMessage *icmp.Message, truncatedBody bool,
) (err error) {
@@ -51,12 +42,12 @@ func checkEchoReply(icmpProtocol int, received []byte,
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
return fmt.Errorf("ICMP body type is not supported: %T", inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
if inboundBody.ID != outboundBody.ID {
return fmt.Errorf("%w: sent id %d and received id %d",
ErrIDMismatch, outboundBody.ID, inboundBody.ID)
return fmt.Errorf("ICMP id mismatch: sent id %d and received id %d",
outboundBody.ID, inboundBody.ID)
}
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
if err != nil {
@@ -65,19 +56,17 @@ func checkEchoReply(icmpProtocol int, received []byte,
return nil
}
var ErrEchoDataMismatch = errors.New("ICMP data mismatch")
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
if len(received) > len(sent) {
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
ErrEchoDataMismatch, len(sent), len(received))
return fmt.Errorf("ICMP data mismatch: sent %d bytes and received %d bytes",
len(sent), len(received))
}
if receivedTruncated {
sent = sent[:len(received)]
}
if !bytes.Equal(received, sent) {
return fmt.Errorf("%w: sent %x and received %x",
ErrEchoDataMismatch, sent, received)
return fmt.Errorf("ICMP data mismatch: sent %x and received %x",
sent, received)
}
return nil
}
+1 -3
View File
@@ -10,9 +10,7 @@ import (
var (
ErrNotPermitted = errors.New("ICMP not permitted")
ErrDestinationUnreachable = errors.New("ICMP destination unreachable")
ErrCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
ErrBodyUnsupported = errors.New("ICMP body type is not supported")
errCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
ErrMTUNotFound = errors.New("MTU not found")
errTimeout = errors.New("operation timed out")
)
+1 -1
View File
@@ -25,7 +25,7 @@ func PathMTUDiscover(ctx context.Context, ip netip.Addr,
switch {
case err == nil:
return mtu, nil
case errors.Is(err, errTimeout) || errors.Is(err, ErrCommunicationAdministrativelyProhibited): // blackhole
case errors.Is(err, errTimeout) || errors.Is(err, errCommunicationAdministrativelyProhibited): // blackhole
default:
return 0, fmt.Errorf("finding IPv4 next hop MTU to %s: %w", ip, err)
}
+3 -6
View File
@@ -117,13 +117,10 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
case portUnreachable: // triggered by TCP or UDP from applications
continue // ignore and wait for the next message
case communicationAdministrativelyProhibitedCode:
return 0, fmt.Errorf("%w: %w (code %d)",
ErrDestinationUnreachable,
ErrCommunicationAdministrativelyProhibited,
return 0, fmt.Errorf("ICMP destination unreachable: %w (code %d)", errCommunicationAdministrativelyProhibited,
inboundMessage.Code)
default:
return 0, fmt.Errorf("%w: code %d",
ErrDestinationUnreachable, inboundMessage.Code)
return 0, fmt.Errorf("ICMP destination unreachable: code %d", inboundMessage.Code)
}
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
@@ -158,7 +155,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
inboundID, outboundID)
continue
default:
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
return 0, fmt.Errorf("ICMP body type is not supported: %T", typedBody)
}
}
}
+3 -2
View File
@@ -2,6 +2,7 @@ package icmp
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
@@ -115,7 +116,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
if err != nil {
return 0, fmt.Errorf("checking invoking message id: %w", err)
} else if idMatch {
return 0, fmt.Errorf("%w", ErrDestinationUnreachable)
return 0, errors.New("ICMP destination unreachable")
}
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
continue
@@ -128,7 +129,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
inboundID, outboundID)
continue
default:
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
return 0, fmt.Errorf("ICMP body type %T is not supported", typedBody)
}
}
}
+4 -4
View File
@@ -71,7 +71,7 @@ func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil {
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
err = fmt.Errorf("%w", ErrNotPermitted)
err = ErrNotPermitted
}
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
@@ -157,7 +157,7 @@ func collectReplies(conn net.PacketConn, ipVersion string,
logger.Debugf("ignoring ICMP message (type: %d, code: %d)", message.Type, message.Code)
continue
default:
return fmt.Errorf("%w: %T", ErrBodyUnsupported, message.Body)
return fmt.Errorf("ICMP body type is not supported: %T", message.Body)
}
echoBody, _ := message.Body.(*icmp.Echo)
@@ -183,8 +183,8 @@ func collectReplies(conn net.PacketConn, ipVersion string,
ipPacketLength == conservativeReplyLength
// Check the packet size is the same if the reply is not truncated
if !truncated && sentBytes != ipPacketLength {
return fmt.Errorf("%w: sent %dB and received %dB",
ErrEchoDataMismatch, sentBytes, ipPacketLength)
return fmt.Errorf("ICMP data mismatch: sent %dB and received %dB",
sentBytes, ipPacketLength)
}
// Truncated reply or matching reply size
tests[testIndex].ok = true
+1 -1
View File
@@ -29,7 +29,7 @@ func SrcAddr(dst netip.AddrPort, proto int) (src netip.AddrPort, cleanup func(),
}
var (
errNoRoute = fmt.Errorf("no route to destination")
errNoRoute = errors.New("no route to destination")
ErrNetworkUnreachable = errors.New("network unreachable")
)
+3 -8
View File
@@ -13,11 +13,6 @@ import (
"github.com/qdm12/gluetun/internal/pmtud/tcp"
)
var (
ErrICMPOkTCPFail = errors.New("PMTUD succeeded with ICMP but failed with TCP")
ErrICMPFailTCPFail = errors.New("PMTUD failed with both ICMP and TCP")
)
// PathMTUDiscover discovers the maximum MTU using both ICMP and TCP.
// Multiple ICMP addresses and TCP addresses can be specified for redundancy.
// ICMP PMTUD is run first. If successful, the range of possible MTU values to
@@ -81,10 +76,10 @@ func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []net
}
}
if icmpSuccess {
return 0, fmt.Errorf("%w - discarding ICMP obtained MTU %d",
ErrICMPOkTCPFail, maxPossibleMTU)
return 0, fmt.Errorf("PMTUD succeeded with ICMP but failed with TCP "+
"- discarding ICMP obtained MTU %d", maxPossibleMTU)
}
return 0, fmt.Errorf("%w", ErrICMPFailTCPFail)
return 0, errors.New("PMTUD failed with both ICMP and TCP")
}
logger.Debugf("TCP path MTU discovery found maximum valid MTU %d", mtu)
return mtu, nil
+2 -4
View File
@@ -57,8 +57,6 @@ func (l *noopLogger) Warn(_ string) {}
func (l *noopLogger) Warnf(_ string, _ ...any) {}
func (l *noopLogger) Error(_ string) {}
var errRouteNotFound = errors.New("route not found")
func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
routes, err := netlinker.RouteList(netlink.FamilyV4)
if err != nil {
@@ -76,7 +74,7 @@ func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
return min(link.MTU, maxMTU), nil
}
}
return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound)
return 0, errors.New("route not found: no loopback route found")
}
func findDefaultRouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
@@ -100,7 +98,7 @@ func findDefaultRouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
}
}
if mtu == 0 {
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound)
return 0, errors.New("route not found: no default route found")
}
return mtu, nil
}
+5 -8
View File
@@ -12,8 +12,6 @@ import (
"github.com/qdm12/gluetun/internal/pmtud/ip"
)
var errTCPServersUnreachable = errors.New("all TCP servers are unreachable")
// findHighestMSSDestination finds the destination with the highest
// MSS amongst the provided destinations.
func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescriptor,
@@ -68,7 +66,7 @@ func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescr
}
if mss == 0 { // no MSS found for any destination
return netip.AddrPort{}, 0, fmt.Errorf("%w (%d servers)", errTCPServersUnreachable, len(dsts))
return netip.AddrPort{}, 0, fmt.Errorf("all %d TCP servers are unreachable", len(dsts))
}
maxPossibleMTU = ip.HeaderLength(dst.Addr().Is4()) + constants.BaseTCPHeaderLength + mss
@@ -77,8 +75,6 @@ func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescr
return dst, mss, nil
}
var errMSSNotFound = errors.New("MSS option not found in reply")
func findMSS(ctx context.Context, fd fileDescriptor, dst netip.AddrPort,
excludeMark int, tracker *tracker, firewall Firewall, logger Logger) (
mss uint32, err error,
@@ -132,11 +128,12 @@ func findMSS(ctx context.Context, fd fileDescriptor, dst netip.AddrPort,
case err != nil:
return 0, fmt.Errorf("parsing reply TCP header: %w", err)
case replyHeader.typ != packetTypeSYNACK:
return 0, fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, replyHeader.typ)
return 0, fmt.Errorf("TCP packet is not a SYN-ACK: unexpected packet type %s", replyHeader.typ)
case replyHeader.ack != synSeq+1:
return 0, fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, replyHeader.ack)
return 0, fmt.Errorf("TCP SYN-ACK ACK number %d does not match expected value %d",
replyHeader.ack, synSeq+1)
case replyHeader.options.mss == 0:
return 0, fmt.Errorf("%w: MSS option not found in reply", errMSSNotFound)
return 0, errors.New("MSS option not found in reply")
}
err = sendRST(fd, src, dst, replyHeader.ack)
+1 -6
View File
@@ -12,11 +12,6 @@ import (
"github.com/qdm12/gluetun/internal/pmtud/test"
)
var (
ErrMTUNotFound = errors.New("MTU not found")
ErrMSSTooSmall = errors.New("TCP MSS is too small to find the MTU")
)
type testUnit struct {
mtu uint32
ok bool
@@ -178,5 +173,5 @@ func pathMTUDiscover(ctx context.Context, fd fileDescriptor,
}
}
return 0, fmt.Errorf("%w: your connection might not be working at all", ErrMTUNotFound)
return 0, errors.New("MTU not found: your connection might not be working at all")
}
+6 -14
View File
@@ -75,13 +75,6 @@ func startRawSocket(family, excludeMark int) (fd fileDescriptor, stop func(), er
return fileDescriptor(fdPlatform), stop, nil
}
var (
errTCPPacketNotSynAck = errors.New("TCP packet is not a SYN-ACK")
errTCPSynAckAckMismatch = errors.New("TCP SYN-ACK ACK number does not match expected value")
errFinalPacketTypeUnexpected = errors.New("final TCP packet type is unexpected")
errTCPPacketLost = errors.New("TCP packet was lost")
)
// Craft and send a raw TCP packet to test the MTU.
// It expects either an RST reply (if no server is listening)
// or a SYN-ACK/ACK reply (if a server is listening).
@@ -142,9 +135,10 @@ func runTest(ctx context.Context, dst netip.AddrPort, mtu uint32,
// server actively closed the connection, try sending a SYN with data
return handleRSTReply(ctx, fd, ch, src, dst, mtu)
case firstReplyHeader.typ != packetTypeSYNACK:
return fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, firstReplyHeader.typ)
return fmt.Errorf("TCP packet is not a SYN-ACK: unexpected packet type %s", firstReplyHeader.typ)
case firstReplyHeader.ack != synSeq+1:
return fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, firstReplyHeader.ack)
return fmt.Errorf("TCP SYN-ACK ACK number does not match expected value: "+
"expected %d, got %d", synSeq+1, firstReplyHeader.ack)
}
if firstReplyHeader.options.mss != 0 {
@@ -191,15 +185,13 @@ func runTest(ctx context.Context, dst netip.AddrPort, mtu uint32,
}
return nil
case packetTypeSYNACK: // server never received our MTU-test ACK packet
return fmt.Errorf("%w: server responded with second SYN-ACK packet", errTCPPacketLost)
return errors.New("TCP packet was lost: server responded with second SYN-ACK packet")
default:
_ = sendRST(fd, src, dst, finalPacketHeader.ack)
return fmt.Errorf("%w: %s", errFinalPacketTypeUnexpected, finalPacketHeader.typ)
return fmt.Errorf("final TCP packet type is unexpected: %s", finalPacketHeader.typ)
}
}
var errTCPPacketNotRST = errors.New("TCP packet is not an RST")
func handleRSTReply(ctx context.Context, fd fileDescriptor, ch <-chan []byte,
src, dst netip.AddrPort, mtu uint32,
) error {
@@ -223,7 +215,7 @@ func handleRSTReply(ctx context.Context, fd fileDescriptor, ch <-chan []byte,
return fmt.Errorf("parsing reply TCP header: %w", err)
} else if replyPacketHeader.typ != packetTypeRST &&
replyPacketHeader.typ != packetTypeRSTACK {
return fmt.Errorf("%w: %s", errTCPPacketNotRST, replyPacketHeader.typ)
return fmt.Errorf("TCP packet is not an RST: %s", replyPacketHeader.typ)
}
return nil
}
+18 -34
View File
@@ -2,7 +2,6 @@ package tcp
import (
"encoding/binary"
"errors"
"fmt"
"github.com/qdm12/gluetun/internal/pmtud/constants"
@@ -120,17 +119,11 @@ type tcpHeader struct {
options options
}
var (
errTCPHeaderTooShort = errors.New("TCP header is too short")
errTCPHeaderDataOffset = errors.New("TCP header data offset is invalid")
errTCPPacketTypeUnknown = errors.New("TCP packet type is unknown")
)
// parseTCPHeader parses the TCP header from b.
// b should be the entire TCP packet bytes.
func parseTCPHeader(b []byte) (header tcpHeader, err error) {
if len(b) < int(constants.BaseTCPHeaderLength) {
return tcpHeader{}, fmt.Errorf("%w: %d bytes", errTCPHeaderTooShort, len(b))
return tcpHeader{}, fmt.Errorf("TCP header is too short: %d bytes", len(b))
}
header.srcPort = binary.BigEndian.Uint16(b[0:2])
@@ -146,11 +139,11 @@ func parseTCPHeader(b []byte) (header tcpHeader, err error) {
switch {
case uint32(header.dataOffset) < constants.BaseTCPHeaderLength:
return tcpHeader{}, fmt.Errorf("%w: data offset is %d bytes, expected at least %d bytes",
errTCPHeaderDataOffset, header.dataOffset, constants.BaseTCPHeaderLength)
return tcpHeader{}, fmt.Errorf("TCP header data offset is invalid: "+
"data offset is %d bytes, expected at least %d bytes", header.dataOffset, constants.BaseTCPHeaderLength)
case int(header.dataOffset) > len(b):
return tcpHeader{}, fmt.Errorf("%w: data offset is %d bytes, but packet is only %d bytes",
errTCPHeaderDataOffset, header.dataOffset, len(b))
return tcpHeader{}, fmt.Errorf("TCP header data offset is invalid: "+
"data offset is %d bytes, but packet is only %d bytes", header.dataOffset, len(b))
}
if uint32(header.dataOffset) > constants.BaseTCPHeaderLength {
@@ -186,7 +179,7 @@ func parseTCPHeader(b []byte) (header tcpHeader, err error) {
case flags&ackFlag != 0:
header.typ = packetTypeACK
default:
return tcpHeader{}, fmt.Errorf("%w: flags are 0x%02x", errTCPPacketTypeUnknown, flags)
return tcpHeader{}, fmt.Errorf("TCP packet type is unknown: flags are 0x%02x", flags)
}
header.seq = binary.BigEndian.Uint32(b[4:8])
@@ -206,15 +199,6 @@ type optionTimestamps struct {
echo uint32
}
var (
errTCPOptionLengthTruncated = errors.New("TCP option length is truncated")
ErrTCPOptionLengthInvalid = errors.New("TCP option length is invalid")
ErrTCPOptionMSSInvalid = errors.New("TCP option MSS value is invalid")
ErrTCPOptionWindowScaleInvalid = errors.New("TCP option Window Scale value is invalid")
ErrTCPOptionTimestampsInvalid = errors.New("TCP option Timestamps value is invalid")
errTCPOptionTypeUnknown = errors.New("TCP option type is unknown")
)
func parseTCPOptions(b []byte) (parsed options, err error) {
i := 0
for i < len(b) {
@@ -232,7 +216,7 @@ func parseTCPOptions(b []byte) (parsed options, err error) {
// Handle TLV (Type-Length-Value) options
if i+1 >= len(b) {
// This should not happen for DF packets.
return options{}, fmt.Errorf("%w: at offset %d", errTCPOptionLengthTruncated, i)
return options{}, fmt.Errorf("TCP option length is truncated: at offset %d", i)
}
length := int(b[i+1])
@@ -240,11 +224,11 @@ func parseTCPOptions(b []byte) (parsed options, err error) {
maxLength := len(b) - i
switch {
case length < minLength:
return options{}, fmt.Errorf("%w: type %d at offset %d has length %d < %d",
ErrTCPOptionLengthInvalid, optionType, i, length, minLength)
return options{}, fmt.Errorf("TCP option length is invalid: "+
"type %d at offset %d has length %d < %d", optionType, i, length, minLength)
case length > maxLength:
return options{}, fmt.Errorf("%w: type %d at offset %d has length %d > %d",
ErrTCPOptionLengthInvalid, optionType, i, length, maxLength)
return options{}, fmt.Errorf("TCP option length is invalid: "+
"type %d at offset %d has length %d > %d", optionType, i, length, maxLength)
}
data := b[i+2 : i+length]
@@ -259,15 +243,15 @@ func parseTCPOptions(b []byte) (parsed options, err error) {
case optionTypeMSS:
const expectedLength = 4
if length != expectedLength {
return options{}, fmt.Errorf("%w: MSS option at offset %d has length %d, expected %d",
ErrTCPOptionMSSInvalid, i, length, expectedLength)
return options{}, fmt.Errorf("TCP option MSS value is invalid: "+
"MSS option at offset %d has length %d, expected %d", i, length, expectedLength)
}
parsed.mss = uint32(binary.BigEndian.Uint16(data))
case optionTypeWindowScale:
const expectedLength = 3
if length != expectedLength {
return options{}, fmt.Errorf("%w: window scale option at offset %d has length %d, expected %d",
ErrTCPOptionWindowScaleInvalid, i, length, expectedLength)
return options{}, fmt.Errorf("TCP option Window Scale value is invalid: "+
"window scale option at offset %d has length %d, expected %d", i, length, expectedLength)
}
windowScale := data[0]
parsed.windowScale = &windowScale
@@ -276,15 +260,15 @@ func parseTCPOptions(b []byte) (parsed options, err error) {
case optionTypeTimestamps:
const expectedLength = 10
if length != expectedLength {
return options{}, fmt.Errorf("%w: timestamps option at offset %d has length %d, expected %d",
ErrTCPOptionTimestampsInvalid, i, length, expectedLength)
return options{}, fmt.Errorf("TCP option Timestamps value is invalid: "+
"timestamps option at offset %d has length %d, expected %d", i, length, expectedLength)
}
parsed.timestamps = &optionTimestamps{
value: binary.BigEndian.Uint32(data[:4]),
echo: binary.BigEndian.Uint32(data[4:]),
}
default:
return options{}, fmt.Errorf("%w: type %d", errTCPOptionTypeUnknown, optionType)
return options{}, fmt.Errorf("TCP option type is unknown: type %d", optionType)
}
i += length
+1 -3
View File
@@ -177,11 +177,9 @@ func (l *Loop) GetPortsForwarded() (ports []uint16) {
return l.service.GetPortsForwarded()
}
var ErrServiceNotStarted = errors.New("port forwarding service not started")
func (l *Loop) SetPortsForwarded(ports []uint16) (err error) {
if l.service == nil {
return fmt.Errorf("%w", ErrServiceNotStarted)
return errors.New("port forwarding service not started")
}
return l.service.SetPortsForwarded(l.runCtx, ports)
+12 -24
View File
@@ -55,23 +55,10 @@ func (s *Settings) OverrideWith(update Settings) {
s.Password = gosettings.OverrideWithComparable(s.Password, update.Password)
}
var (
ErrPortForwarderNotSet = errors.New("port forwarder not set")
ErrServerNameNotSet = errors.New("server name not set")
ErrUsernameNotSet = errors.New("username not set")
ErrPasswordNotSet = errors.New("password not set")
ErrFilepathNotSet = errors.New("file path not set")
ErrInterfaceNotSet = errors.New("interface not set")
ErrPortsCountZero = errors.New("ports count cannot be zero")
ErrPortsCountTooHigh = errors.New("ports count too high")
ErrListeningPortsLen = errors.New("listening ports length must be equal to ports count")
ErrListeningPortZero = errors.New("listening port cannot be 0")
)
func (s *Settings) Validate(forStartup bool) (err error) {
// Minimal validation
if s.Filepath == "" {
return fmt.Errorf("%w", ErrFilepathNotSet)
return errors.New("file path not set")
}
if !forStartup {
@@ -83,41 +70,42 @@ func (s *Settings) Validate(forStartup bool) (err error) {
// Startup validation requires additional fields set.
switch {
case s.PortForwarder == nil:
return fmt.Errorf("%w", ErrPortForwarderNotSet)
return errors.New("port forwarder not set")
case s.Interface == "":
return fmt.Errorf("%w", ErrInterfaceNotSet)
return errors.New("interface not set")
case s.PortsCount == 0:
return fmt.Errorf("%w", ErrPortsCountZero)
return errors.New("ports count cannot be zero")
}
switch s.PortForwarder.Name() {
case providers.PrivateInternetAccess:
switch {
case s.ServerName == "":
return fmt.Errorf("%w", ErrServerNameNotSet)
return errors.New("server name not set")
case s.Username == "":
return fmt.Errorf("%w", ErrUsernameNotSet)
return errors.New("username not set")
case s.Password == "":
return fmt.Errorf("%w", ErrPasswordNotSet)
return errors.New("password not set")
}
case providers.Protonvpn:
const maxPortsCount = 4
if s.PortsCount > maxPortsCount {
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, s.PortsCount, maxPortsCount)
return fmt.Errorf("ports count too high: %d > %d", s.PortsCount, maxPortsCount)
}
default:
const maxPortsCount = 1
if s.PortsCount > maxPortsCount {
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, s.PortsCount, maxPortsCount)
return fmt.Errorf("ports count too high: %d > %d", s.PortsCount, maxPortsCount)
}
}
if !slices.Equal(s.ListeningPorts, []uint16{0}) {
switch {
case len(s.ListeningPorts) != int(s.PortsCount):
return fmt.Errorf("%w: %d != %d", ErrListeningPortsLen, len(s.ListeningPorts), s.PortsCount)
return fmt.Errorf("listening ports length must be equal to ports count: %d != %d",
len(s.ListeningPorts), s.PortsCount)
case slices.Contains(s.ListeningPorts, 0):
return fmt.Errorf("%w: in %v", ErrListeningPortZero, s.ListeningPorts)
return fmt.Errorf("listening port cannot be 0: in %v", s.ListeningPorts)
}
}
+1 -1
View File
@@ -120,7 +120,7 @@ func Test_Server_BadSettings(t *testing.T) {
server, err := New(settings)
assert.Nil(t, server)
assert.ErrorIs(t, err, ErrBlockProfileRateNegative)
assert.ErrorContains(t, err, "block profile rate cannot be negative")
const expectedErrMessage = "pprof settings failed validation: block profile rate cannot be negative"
assert.EqualError(t, err, expectedErrMessage)
}
+2 -8
View File
@@ -2,7 +2,6 @@ package pprof
import (
"errors"
"fmt"
"time"
"github.com/qdm12/gluetun/internal/httpserver"
@@ -51,18 +50,13 @@ func (s *Settings) OverrideWith(other Settings) {
s.HTTPServer.OverrideWith(other.HTTPServer)
}
var (
ErrBlockProfileRateNegative = errors.New("block profile rate cannot be negative")
ErrMutexProfileRateNegative = errors.New("mutex profile rate cannot be negative")
)
func (s Settings) Validate() (err error) {
if *s.BlockProfileRate < 0 {
return fmt.Errorf("%w", ErrBlockProfileRateNegative)
return errors.New("block profile rate cannot be negative")
}
if *s.MutexProfileRate < 0 {
return fmt.Errorf("%w", ErrMutexProfileRateNegative)
return errors.New("mutex profile rate cannot be negative")
}
return s.HTTPServer.Validate()
+4 -8
View File
@@ -6,7 +6,6 @@ import (
"time"
"github.com/qdm12/gluetun/internal/httpserver"
"github.com/qdm12/gosettings/validate"
"github.com/stretchr/testify/assert"
)
@@ -195,7 +194,6 @@ func Test_Settings_Validate(t *testing.T) {
testCases := map[string]struct {
settings Settings
errWrapped error
errMessage string
}{
"negative block profile rate": {
@@ -203,16 +201,14 @@ func Test_Settings_Validate(t *testing.T) {
BlockProfileRate: intPtr(-1),
MutexProfileRate: intPtr(0),
},
errWrapped: ErrBlockProfileRateNegative,
errMessage: ErrBlockProfileRateNegative.Error(),
errMessage: "block profile rate cannot be negative",
},
"negative mutex profile rate": {
settings: Settings{
BlockProfileRate: intPtr(0),
MutexProfileRate: intPtr(-1),
},
errWrapped: ErrMutexProfileRateNegative,
errMessage: ErrMutexProfileRateNegative.Error(),
errMessage: "mutex profile rate cannot be negative",
},
"http server validation error": {
settings: Settings{
@@ -222,7 +218,6 @@ func Test_Settings_Validate(t *testing.T) {
Address: ":x",
},
},
errWrapped: validate.ErrPortNotAnInteger,
errMessage: "port value is not an integer: x",
},
"valid settings": {
@@ -247,9 +242,10 @@ func Test_Settings_Validate(t *testing.T) {
err := testCase.settings.Validate()
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
+2 -4
View File
@@ -6,8 +6,6 @@ import (
"fmt"
"net/http"
"net/netip"
"github.com/qdm12/gluetun/internal/provider/common"
)
type apiData struct {
@@ -48,8 +46,8 @@ func fetchAPI(ctx context.Context, client *http.Client) (
if response.StatusCode != http.StatusOK {
_ = response.Body.Close()
return data, fmt.Errorf("%w: %d %s",
common.ErrHTTPStatusCodeNotOK, response.StatusCode, response.Status)
return data, fmt.Errorf("HTTP status code not OK: %d %s",
response.StatusCode, response.Status)
}
decoder := json.NewDecoder(response.Body)
-5
View File
@@ -1,5 +0,0 @@
package common
import "errors"
var ErrPortForwardNotSupported = errors.New("port forwarding not supported")
+2 -4
View File
@@ -10,10 +10,8 @@ import (
)
var (
ErrNotEnoughServers = errors.New("not enough servers found")
ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
ErrIPFetcherUnsupported = errors.New("IP fetcher not supported")
ErrCredentialsMissing = errors.New("credentials missing")
ErrNotEnoughServers = errors.New("not enough servers found")
ErrCredentialsMissing = errors.New("credentials are missing")
)
type Fetcher interface {
+1 -4
View File
@@ -1,7 +1,6 @@
package custom
import (
"errors"
"fmt"
"github.com/qdm12/gluetun/internal/configuration/settings"
@@ -10,8 +9,6 @@ import (
"github.com/qdm12/gluetun/internal/models"
)
var ErrVPNTypeNotSupported = errors.New("VPN type not supported for custom provider")
// GetConnection gets the connection from the OpenVPN configuration file.
func (p *Provider) GetConnection(selection settings.ServerSelection, _ bool) (
connection models.Connection, err error,
@@ -22,7 +19,7 @@ func (p *Provider) GetConnection(selection settings.ServerSelection, _ bool) (
case vpn.Wireguard, vpn.AmneziaWg:
return getWireguardConnection(selection), nil
default:
return connection, fmt.Errorf("%w: %s", ErrVPNTypeNotSupported, selection.VPN)
return connection, fmt.Errorf("VPN type not supported for custom provider: %s", selection.VPN)
}
}
-3
View File
@@ -1,7 +1,6 @@
package custom
import (
"errors"
"fmt"
"strconv"
"strings"
@@ -12,8 +11,6 @@ import (
"github.com/qdm12/gluetun/internal/provider/utils"
)
var ErrExtractData = errors.New("failed extracting information from custom configuration file")
func (p *Provider) OpenVPNConfig(connection models.Connection,
settings settings.OpenVPN, ipv6Supported bool,
) (lines []string) {
+2 -5
View File
@@ -3,13 +3,10 @@ package updater
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
)
var errHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
type apiData struct {
Servers []apiServer `json:"servers"`
}
@@ -42,8 +39,8 @@ func fetchAPI(ctx context.Context, client *http.Client) (
if response.StatusCode != http.StatusOK {
_ = response.Body.Close()
return data, fmt.Errorf("%w: %d %s",
errHTTPStatusCodeNotOK, response.StatusCode, response.Status)
return data, fmt.Errorf("HTTP status code not OK: %d %s",
response.StatusCode, response.Status)
}
decoder := json.NewDecoder(response.Body)
@@ -21,21 +21,17 @@ func Test_Provider_GetConnection(t *testing.T) {
const provider = providers.Expressvpn
errTest := errors.New("test error")
testCases := map[string]struct {
filteredServers []models.Server
storageErr error
selection settings.ServerSelection
ipv6Supported bool
connection models.Connection
errWrapped error
errMessage string
panicMessage string
}{
"error": {
storageErr: errTest,
errWrapped: errTest,
storageErr: errors.New("test error"),
errMessage: "filtering servers: test error",
},
"default OpenVPN TCP port": {
@@ -100,9 +96,10 @@ func Test_Provider_GetConnection(t *testing.T) {
connection, err := provider.GetConnection(testCase.selection, testCase.ipv6Supported)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.connection, connection)
+3 -8
View File
@@ -3,14 +3,11 @@ package updater
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"github.com/qdm12/gluetun/internal/provider/common"
)
type apiServer struct {
@@ -19,8 +16,6 @@ type apiServer struct {
hostname string
}
var ErrDataMalformed = errors.New("data is malformed")
const apiURL = "https://support.fastestvpn.com/wp-admin/admin-ajax.php"
// The API URL and requests are shamelessly taken from network operations
@@ -49,7 +44,7 @@ func fetchAPIServers(ctx context.Context, client *http.Client, protocol string)
if response.StatusCode != http.StatusOK {
_ = response.Body.Close()
return nil, fmt.Errorf("%w: %d", common.ErrHTTPStatusCodeNotOK, response.StatusCode)
return nil, fmt.Errorf("HTTP status code not OK: %d", response.StatusCode)
}
data, err := io.ReadAll(response.Body)
@@ -79,8 +74,8 @@ func fetchAPIServers(ctx context.Context, client *http.Client, protocol string)
for i := range numberOfTDBlocks {
tdBlock := getNextTDBlock(trBlock)
if tdBlock == nil {
return nil, fmt.Errorf("%w: expected 3 <td> blocks in <tr> block %q",
ErrDataMalformed, string(trBlock))
return nil, fmt.Errorf("data is malformed: expected 3 <td> blocks in <tr> block %q",
string(trBlock))
}
trBlock = trBlock[len(tdBlock):]

Some files were not shown because too many files have changed in this diff Show More