chore: do not use sentinel errors when unneeded

- main reason being it's a burden to always define sentinel errors at global scope, wrap them with `%w` instead of using a string directly
- only use sentinel errors when it has to be checked using `errors.Is`
- replace all usage of these sentinel errors in `fmt.Errorf` with direct strings that were in the sentinel error
- exclude the sentinel error definition requirement from .golangci.yml
- update unit tests to use ContainersError instead of ErrorIs so it stays as a "not a change detector test" without requiring a sentinel error
This commit is contained in:
Quentin McGaw
2026-05-02 00:50:16 +00:00
parent 9b6f048fe8
commit 4a78989d9d
172 changed files with 666 additions and 1433 deletions
+3
View File
@@ -68,6 +68,9 @@ linters:
- err113 - err113
- mnd - mnd
path: ci\/.+\.go path: ci\/.+\.go
- linters:
- err113
text: "do not define dynamic errors, use wrapped static errors instead"
paths: paths:
- third_party$ - third_party$
+1 -3
View File
@@ -142,8 +142,6 @@ func main() {
} }
} }
var errCommandUnknown = errors.New("command is unknown")
//nolint:gocognit,gocyclo,maintidx //nolint:gocognit,gocyclo,maintidx
func _main(ctx context.Context, buildInfo models.BuildInformation, func _main(ctx context.Context, buildInfo models.BuildInformation,
args []string, logger log.LoggerInterface, reader *reader.Reader, args []string, logger log.LoggerInterface, reader *reader.Reader,
@@ -165,7 +163,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
case "genkey": case "genkey":
return cli.GenKey(args[2:]) return cli.GenKey(args[2:])
default: default:
return fmt.Errorf("%w: %s", errCommandUnknown, args[1]) return fmt.Errorf("command is unknown: %s", args[1])
} }
} }
+2 -5
View File
@@ -1,7 +1,6 @@
package alpine package alpine
import ( import (
"errors"
"fmt" "fmt"
"io/fs" "io/fs"
"os" "os"
@@ -9,8 +8,6 @@ import (
"strconv" "strconv"
) )
var ErrUserAlreadyExists = errors.New("user already exists")
// CreateUser creates a user in Alpine with the given UID. // CreateUser creates a user in Alpine with the given UID.
func (a *Alpine) CreateUser(username string, uid int) (createdUsername string, err error) { func (a *Alpine) CreateUser(username string, uid int) (createdUsername string, err error) {
UIDStr := strconv.Itoa(uid) UIDStr := strconv.Itoa(uid)
@@ -34,8 +31,8 @@ func (a *Alpine) CreateUser(username string, uid int) (createdUsername string, e
} }
if u != nil { if u != nil {
return "", fmt.Errorf("%w: with name %s for ID %s instead of %d", return "", fmt.Errorf("user already exists: with name %s for ID %s instead of %d",
ErrUserAlreadyExists, username, u.Uid, uid) username, u.Uid, uid)
} }
const permission = fs.FileMode(0o644) const permission = fs.FileMode(0o644)
+2 -1
View File
@@ -1,6 +1,7 @@
package amneziawg package amneziawg
import ( import (
"errors"
"net/netip" "net/netip"
"testing" "testing"
@@ -28,7 +29,7 @@ func Test_New(t *testing.T) {
PrivateKey: "", PrivateKey: "",
}, },
}, },
err: wireguard.ErrPrivateKeyMissing, err: errors.New("private key is missing"),
}, },
"minimal valid settings": { "minimal valid settings": {
settings: Settings{ settings: Settings{
+2 -8
View File
@@ -13,11 +13,6 @@ import (
"github.com/qdm12/gluetun/internal/wireguard" "github.com/qdm12/gluetun/internal/wireguard"
) )
var (
errTunNameMismatch = errors.New("TUN device name is mismatching")
errDeviceWaited = errors.New("device waited for")
)
// Run runs the amneziawg interface and waits until the context is done, then it cleans up the // Run runs the amneziawg interface and waits until the context is done, then it cleans up the
// interface and returns any error that occurred during setup or waiting. It sends an error to // interface and returns any error that occurred during setup or waiting. It sends an error to
// waitError if any error occurs during setup or waiting, otherwise it sends nil when the context // waitError if any error occurs during setup or waiting, otherwise it sends nil when the context
@@ -52,8 +47,7 @@ func setupUserspace(ctx context.Context,
if err != nil { if err != nil {
return 0, nil, fmt.Errorf("getting created TUN device name: %w", err) return 0, nil, fmt.Errorf("getting created TUN device name: %w", err)
} else if tunName != interfaceName { } else if tunName != interfaceName {
return 0, nil, fmt.Errorf("%w: expected %q and got %q", return 0, nil, fmt.Errorf("TUN device name is mismatching: expected %q and got %q", interfaceName, tunName)
errTunNameMismatch, interfaceName, tunName)
} }
link, err := netLinker.LinkByName(interfaceName) link, err := netLinker.LinkByName(interfaceName)
@@ -106,7 +100,7 @@ func setupUserspace(ctx context.Context,
case err = <-uapiAcceptErrorCh: case err = <-uapiAcceptErrorCh:
close(uapiAcceptErrorCh) close(uapiAcceptErrorCh)
case <-device.Wait(): case <-device.Wait():
err = errDeviceWaited err = errors.New("device waited for")
} }
cleanups.Cleanup(logger) cleanups.Cleanup(logger)
+2 -8
View File
@@ -16,11 +16,6 @@ import (
"golang.org/x/text/language" "golang.org/x/text/language"
) )
var (
ErrProviderUnspecified = errors.New("VPN provider to format was not specified")
ErrMultipleProvidersToFormat = errors.New("more than one VPN provider to format were specified")
)
func addProviderFlag(flagSet *flag.FlagSet, providerToFormat map[string]*bool, func addProviderFlag(flagSet *flag.FlagSet, providerToFormat map[string]*bool,
provider string, titleCaser cases.Caser, provider string, titleCaser cases.Caser,
) { ) {
@@ -65,11 +60,10 @@ func (c *CLI) FormatServers(args []string) error {
} }
switch len(providers) { switch len(providers) {
case 0: case 0:
return fmt.Errorf("%w", ErrProviderUnspecified) return errors.New("VPN provider to format was not specified")
case 1: case 1:
default: default:
return fmt.Errorf("%w: %d specified: %s", return fmt.Errorf("more than one VPN provider to format were specified: %d specified: %s", len(providers),
ErrMultipleProvidersToFormat, len(providers),
strings.Join(providers, ", ")) strings.Join(providers, ", "))
} }
+2 -9
View File
@@ -24,13 +24,6 @@ import (
"github.com/qdm12/gluetun/internal/updater/unzip" "github.com/qdm12/gluetun/internal/updater/unzip"
) )
var (
ErrModeUnspecified = errors.New("at least one of -enduser or -maintainer must be specified")
ErrNoProviderSpecified = errors.New("no provider was specified")
ErrUsernameMissing = errors.New("username is required for this provider")
ErrPasswordMissing = errors.New("password is required for this provider")
)
type UpdaterLogger interface { type UpdaterLogger interface {
Info(s string) Info(s string)
Warn(s string) Warn(s string)
@@ -65,14 +58,14 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
} }
if !endUserMode && !maintainerMode { if !endUserMode && !maintainerMode {
return fmt.Errorf("%w", ErrModeUnspecified) return errors.New("at least one of -enduser or -maintainer must be specified")
} }
if updateAll { if updateAll {
options.Providers = providers.All() options.Providers = providers.All()
} else { } else {
if csvProviders == "" { if csvProviders == "" {
return fmt.Errorf("%w", ErrNoProviderSpecified) return errors.New("no provider was specified")
} }
options.Providers = strings.Split(csvProviders, ",") options.Providers = strings.Split(csvProviders, ",")
} }
+5 -12
View File
@@ -8,13 +8,6 @@ import (
"unicode/utf8" "unicode/utf8"
) )
var (
errCommandEmpty = errors.New("command is empty")
errSingleQuoteUnterminated = errors.New("unterminated single-quoted string")
errDoubleQuoteUnterminated = errors.New("unterminated double-quoted string")
errEscapeUnterminated = errors.New("unterminated backslash-escape")
)
// split splits a command string into a slice of arguments. // split splits a command string into a slice of arguments.
// This is especially important for commands such as: // This is especially important for commands such as:
// /bin/sh -c "echo hello" // /bin/sh -c "echo hello"
@@ -25,7 +18,7 @@ var (
// - expansion (brace, shell or pathname). // - expansion (brace, shell or pathname).
func split(command string) (words []string, err error) { func split(command string) (words []string, err error) {
if command == "" { if command == "" {
return nil, fmt.Errorf("%w", errCommandEmpty) return nil, errors.New("command is empty")
} }
const bufferSize = 1024 const bufferSize = 1024
@@ -42,7 +35,7 @@ func split(command string) (words []string, err error) {
case character == '\\': case character == '\\':
// Look ahead to eventually skip an escaped newline // Look ahead to eventually skip an escaped newline
if command[startIndex+runeSize:] == "" { if command[startIndex+runeSize:] == "" {
return nil, fmt.Errorf("%w: %q", errEscapeUnterminated, command) return nil, fmt.Errorf("unterminated backslash-escape: %q", command)
} }
character, runeSize := utf8.DecodeRuneInString(command[startIndex+runeSize:]) character, runeSize := utf8.DecodeRuneInString(command[startIndex+runeSize:])
if character == '\n' { if character == '\n' {
@@ -119,7 +112,7 @@ func handleDoubleQuoted(input string, startIndex int, buffer *bytes.Buffer) (
startIndex = cursor startIndex = cursor
} }
} }
return "", 0, fmt.Errorf("%w", errDoubleQuoteUnterminated) return "", 0, errors.New("unterminated double-quoted string")
} }
func handleSingleQuoted(input string, startIndex int, buffer *bytes.Buffer) ( func handleSingleQuoted(input string, startIndex int, buffer *bytes.Buffer) (
@@ -127,7 +120,7 @@ func handleSingleQuoted(input string, startIndex int, buffer *bytes.Buffer) (
) { ) {
closingQuoteIndex := strings.IndexRune(input[startIndex:], '\'') closingQuoteIndex := strings.IndexRune(input[startIndex:], '\'')
if closingQuoteIndex == -1 { if closingQuoteIndex == -1 {
return "", 0, fmt.Errorf("%w", errSingleQuoteUnterminated) return "", 0, errors.New("unterminated single-quoted string")
} }
buffer.WriteString(input[startIndex : startIndex+closingQuoteIndex]) buffer.WriteString(input[startIndex : startIndex+closingQuoteIndex])
const singleQuoteRuneLength = 1 const singleQuoteRuneLength = 1
@@ -139,7 +132,7 @@ func handleEscaped(input string, startIndex int, buffer *bytes.Buffer) (
word string, newStartIndex int, err error, word string, newStartIndex int, err error,
) { ) {
if input[startIndex:] == "" { if input[startIndex:] == "" {
return "", 0, fmt.Errorf("%w", errEscapeUnterminated) return "", 0, errors.New("unterminated backslash-escape")
} }
character, runeLength := utf8.DecodeRuneInString(input[startIndex:]) character, runeLength := utf8.DecodeRuneInString(input[startIndex:])
if character != '\n' { // backslash-escaped newline is ignored if character != '\n' { // backslash-escaped newline is ignored
+4 -9
View File
@@ -12,12 +12,10 @@ func Test_split(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
command string command string
words []string words []string
errWrapped error
errMessage string errMessage string
}{ }{
"empty": { "empty": {
command: "", command: "",
errWrapped: errCommandEmpty,
errMessage: "command is empty", errMessage: "command is empty",
}, },
"concrete_sh_command": { "concrete_sh_command": {
@@ -74,22 +72,18 @@ func Test_split(t *testing.T) {
}, },
"unterminated_single_quote": { "unterminated_single_quote": {
command: "'abc'\\''def", command: "'abc'\\''def",
errWrapped: errSingleQuoteUnterminated,
errMessage: `splitting word in "'abc'\\''def": unterminated single-quoted string`, errMessage: `splitting word in "'abc'\\''def": unterminated single-quoted string`,
}, },
"unterminated_double_quote": { "unterminated_double_quote": {
command: "\"abc'def", command: "\"abc'def",
errWrapped: errDoubleQuoteUnterminated,
errMessage: `splitting word in "\"abc'def": unterminated double-quoted string`, errMessage: `splitting word in "\"abc'def": unterminated double-quoted string`,
}, },
"unterminated_escape": { "unterminated_escape": {
command: "abc\\", command: "abc\\",
errWrapped: errEscapeUnterminated,
errMessage: `splitting word in "abc\\": unterminated backslash-escape`, errMessage: `splitting word in "abc\\": unterminated backslash-escape`,
}, },
"unterminated_escape_only": { "unterminated_escape_only": {
command: " \\", command: " \\",
errWrapped: errEscapeUnterminated,
errMessage: `unterminated backslash-escape: " \\"`, errMessage: `unterminated backslash-escape: " \\"`,
}, },
} }
@@ -101,9 +95,10 @@ func Test_split(t *testing.T) {
words, err := split(testCase.command) words, err := split(testCase.command)
assert.Equal(t, testCase.words, words) assert.Equal(t, testCase.words, words)
assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errMessage != "" {
if testCase.errWrapped != nil { assert.ErrorContains(t, err, testCase.errMessage)
assert.EqualError(t, err, testCase.errMessage) } else {
assert.NoError(t, err)
} }
}) })
} }
+12 -21
View File
@@ -1,7 +1,6 @@
package settings package settings
import ( import (
"errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
@@ -177,14 +176,6 @@ func (a AmneziaWg) toLinesNode() (node *gotree.Node) {
return node return node
} }
var (
ErrAmenziawgImplementationNotValid = errors.New("AmneziaWG implementation is not valid")
ErrJunkPacketBounds = errors.New("junk packet minimum must be lower than or equal to maximum")
ErrJunkPacketMinMaxNotSet = errors.New("junk packet min and max must be set when junk packet count is set")
ErrJunkPacketCountNotSet = errors.New("junk packet count must be set when junk packet min or max is set")
ErrHeaderRangeMalformed = errors.New("header range is malformed")
)
func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error { func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
const amneziaWG = true const amneziaWG = true
err := a.Wireguard.validate(vpnProvider, ipv6Supported, amneziaWG) err := a.Wireguard.validate(vpnProvider, ipv6Supported, amneziaWG)
@@ -194,16 +185,16 @@ func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
if *a.JunkPacketCount == 0 { if *a.JunkPacketCount == 0 {
if *a.JunkPacketMin != 0 || *a.JunkPacketMax != 0 { if *a.JunkPacketMin != 0 || *a.JunkPacketMax != 0 {
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d", return fmt.Errorf("junk packet count must be set when junk packet min or max is set: "+
ErrJunkPacketCountNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax) "jc=%d and jmin=%d and jmax=%d", a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
} }
} else { } else {
if *a.JunkPacketMin == 0 || *a.JunkPacketMax == 0 { if *a.JunkPacketMin == 0 || *a.JunkPacketMax == 0 {
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d", return fmt.Errorf("junk packet min and max must be set when junk packet count is set: "+
ErrJunkPacketMinMaxNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax) "jc=%d and jmin=%d and jmax=%d", a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
} else if *a.JunkPacketMin > *a.JunkPacketMax { } else if *a.JunkPacketMin > *a.JunkPacketMax {
return fmt.Errorf("%w: jmin=%d and jmax=%d", return fmt.Errorf("junk packet minimum must be lower than or equal to maximum: "+
ErrJunkPacketBounds, *a.JunkPacketMin, *a.JunkPacketMax) "jmin=%d and jmax=%d", *a.JunkPacketMin, *a.JunkPacketMax)
} }
} }
@@ -222,20 +213,20 @@ func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
case 1: case 1:
_, err := strconv.Atoi(fields[0]) _, err := strconv.Atoi(fields[0])
if err != nil { if err != nil {
return fmt.Errorf("%w: %s value %s is not a number", return fmt.Errorf("header range is malformed: "+
ErrHeaderRangeMalformed, name, headerRange) "%s value %s is not a number", name, headerRange)
} }
case 2: //nolint:mnd case 2: //nolint:mnd
for _, field := range fields { for _, field := range fields {
_, err := strconv.Atoi(field) _, err := strconv.Atoi(field)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s value %s is not a valid range", return fmt.Errorf("header range is malformed: "+
ErrHeaderRangeMalformed, name, headerRange) "%s value %s is not a valid range", name, headerRange)
} }
} }
default: default:
return fmt.Errorf("%w: %s value %s must be in the form n or n-m", return fmt.Errorf("header range is malformed: "+
ErrHeaderRangeMalformed, name, headerRange) "%s value %s must be in the form n or n-m", name, headerRange)
} }
} }
+7 -13
View File
@@ -1,7 +1,6 @@
package settings package settings
import ( import (
"errors"
"fmt" "fmt"
"net/netip" "net/netip"
"time" "time"
@@ -48,22 +47,15 @@ type DNS struct {
UpstreamPlainAddresses []netip.AddrPort UpstreamPlainAddresses []netip.AddrPort
} }
var (
ErrDNSUpstreamTypeNotValid = errors.New("DNS upstream type is not valid")
ErrDNSUpdatePeriodTooShort = errors.New("update period is too short")
ErrDNSUpstreamPlainNoIPv6 = errors.New("upstream plain addresses do not contain any IPv6 address")
ErrDNSUpstreamPlainNoIPv4 = errors.New("upstream plain addresses do not contain any IPv4 address")
)
func (d DNS) validate() (err error) { func (d DNS) validate() (err error) {
if !helpers.IsOneOf(d.UpstreamType, DNSUpstreamTypeDot, DNSUpstreamTypeDoh, DNSUpstreamTypePlain) { if !helpers.IsOneOf(d.UpstreamType, DNSUpstreamTypeDot, DNSUpstreamTypeDoh, DNSUpstreamTypePlain) {
return fmt.Errorf("%w: %s", ErrDNSUpstreamTypeNotValid, d.UpstreamType) return fmt.Errorf("DNS upstream type is not valid: %s", d.UpstreamType)
} }
const minUpdatePeriod = 30 * time.Second const minUpdatePeriod = 30 * time.Second
if *d.UpdatePeriod != 0 && *d.UpdatePeriod < minUpdatePeriod { if *d.UpdatePeriod != 0 && *d.UpdatePeriod < minUpdatePeriod {
return fmt.Errorf("%w: %s must be bigger than %s", return fmt.Errorf("update period is too short: %s must be bigger than %s",
ErrDNSUpdatePeriodTooShort, *d.UpdatePeriod, minUpdatePeriod) *d.UpdatePeriod, minUpdatePeriod)
} }
if d.UpstreamType == DNSUpstreamTypePlain { if d.UpstreamType == DNSUpstreamTypePlain {
@@ -81,9 +73,11 @@ func (d DNS) validate() (err error) {
} }
switch { switch {
case *d.IPv6 && !selectedHasPlainIPv6: case *d.IPv6 && !selectedHasPlainIPv6:
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv6, len(d.UpstreamPlainAddresses)) return fmt.Errorf("upstream plain addresses do not contain any IPv6 address: "+
"in %d addresses", len(d.UpstreamPlainAddresses))
case !*d.IPv6 && !selectedHasPlainIPv4: case !*d.IPv6 && !selectedHasPlainIPv4:
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv4, len(d.UpstreamPlainAddresses)) return fmt.Errorf("upstream plain addresses do not contain any IPv4 address: "+
"in %d addresses", len(d.UpstreamPlainAddresses))
} }
} }
// Note: all DNS built in providers have both IPv4 and IPv6 addresses for all modes // Note: all DNS built in providers have both IPv4 and IPv6 addresses for all modes
@@ -1,7 +1,6 @@
package settings package settings
import ( import (
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/netip" "net/netip"
@@ -37,22 +36,16 @@ func (b *DNSBlacklist) setDefaults() {
var hostRegex = regexp.MustCompile(`^([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9_])(\.([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9]))*$`) //nolint:lll var hostRegex = regexp.MustCompile(`^([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9_])(\.([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9]))*$`) //nolint:lll
var (
ErrAllowedHostNotValid = errors.New("allowed host is not valid")
ErrBlockedHostNotValid = errors.New("blocked host is not valid")
ErrRebindingProtectionExemptHostNotValid = errors.New("rebinding protection exempt host is not valid")
)
func (b DNSBlacklist) validate() (err error) { func (b DNSBlacklist) validate() (err error) {
for _, host := range b.AllowedHosts { for _, host := range b.AllowedHosts {
if !hostRegex.MatchString(host) { if !hostRegex.MatchString(host) {
return fmt.Errorf("%w: %s", ErrAllowedHostNotValid, host) return fmt.Errorf("allowed host is not valid: %s", host)
} }
} }
for _, host := range b.AddBlockedHosts { for _, host := range b.AddBlockedHosts {
if !hostRegex.MatchString(host) { if !hostRegex.MatchString(host) {
return fmt.Errorf("%w: %s", ErrBlockedHostNotValid, host) return fmt.Errorf("blocked host is not valid: %s", host)
} }
} }
@@ -61,7 +54,7 @@ func (b DNSBlacklist) validate() (err error) {
host = host[2:] host = host[2:]
} }
if !hostRegex.MatchString(host) { if !hostRegex.MatchString(host) {
return fmt.Errorf("%w: %s", ErrRebindingProtectionExemptHostNotValid, host) return fmt.Errorf("rebinding protection exempt host is not valid: %s", host)
} }
} }
@@ -209,8 +202,6 @@ func readDNSBlockedIPs(r *reader.Reader) (ips []netip.Addr,
return ips, ipPrefixes, nil return ips, ipPrefixes, nil
} }
var ErrPrivateAddressNotValid = errors.New("private address is not a valid IP or CIDR range")
func readDNSPrivateAddresses(r *reader.Reader) (ips []netip.Addr, func readDNSPrivateAddresses(r *reader.Reader) (ips []netip.Addr,
ipPrefixes []netip.Prefix, err error, ipPrefixes []netip.Prefix, err error,
) { ) {
@@ -236,8 +227,9 @@ func readDNSPrivateAddresses(r *reader.Reader) (ips []netip.Addr,
} }
return nil, nil, fmt.Errorf( return nil, nil, fmt.Errorf(
"environment variable DOT_PRIVATE_ADDRESS: %w: %s", "environment variable DOT_PRIVATE_ADDRESS: "+
ErrPrivateAddressNotValid, privateAddress) "private address is not a valid IP or CIDR range: %s",
privateAddress)
} }
return ips, ipPrefixes, nil return ips, ipPrefixes, nil
-58
View File
@@ -1,58 +0,0 @@
package settings
import "errors"
var (
ErrValueUnknown = errors.New("value is unknown")
ErrCityNotValid = errors.New("the city specified is not valid")
ErrControlServerPrivilegedPort = errors.New("cannot use privileged port without running as root")
ErrCategoryNotValid = errors.New("the category specified is not valid")
ErrCountryNotValid = errors.New("the country specified is not valid")
ErrFilepathMissing = errors.New("filepath is missing")
ErrFirewallZeroPort = errors.New("cannot have a zero port")
ErrFirewallPublicOutboundSubnet = errors.New("outbound subnet has an unspecified address")
ErrHostnameNotValid = errors.New("the hostname specified is not valid")
ErrISPNotValid = errors.New("the ISP specified is not valid")
ErrMinRatioNotValid = errors.New("minimum ratio is not valid")
ErrMissingValue = errors.New("missing value")
ErrNameNotValid = errors.New("the server name specified is not valid")
ErrOpenVPNClientKeyMissing = errors.New("client key is missing")
ErrOpenVPNCustomPortNotAllowed = errors.New("custom endpoint port is not allowed")
ErrOpenVPNEncryptionPresetNotValid = errors.New("PIA encryption preset is not valid")
ErrOpenVPNInterfaceNotValid = errors.New("interface name is not valid")
ErrOpenVPNKeyPassphraseIsEmpty = errors.New("key passphrase is empty")
ErrOpenVPNMSSFixIsTooHigh = errors.New("mssfix option value is too high")
ErrOpenVPNPasswordIsEmpty = errors.New("password is empty")
ErrOpenVPNTCPNotSupported = errors.New("TCP protocol is not supported")
ErrOpenVPNUserIsEmpty = errors.New("user is empty")
ErrOpenVPNVerbosityIsOutOfBounds = errors.New("verbosity value is out of bounds")
ErrOpenVPNVersionIsNotValid = errors.New("version is not valid")
ErrPortForwardingEnabled = errors.New("port forwarding cannot be enabled")
ErrPortForwardingUserEmpty = errors.New("port forwarding username is empty")
ErrPortForwardingPasswordEmpty = errors.New("port forwarding password is empty")
ErrRegionNotValid = errors.New("the region specified is not valid")
ErrServerAddressNotValid = errors.New("server listening address is not valid")
ErrSystemPGIDNotValid = errors.New("process group id is not valid")
ErrSystemPUIDNotValid = errors.New("process user id is not valid")
ErrSystemTimezoneNotValid = errors.New("timezone is not valid")
ErrUpdaterPeriodTooSmall = errors.New("VPN server data updater period is too small")
ErrUpdaterProtonPasswordMissing = errors.New("proton password is missing")
ErrUpdaterProtonEmailMissing = errors.New("proton email is missing")
ErrVPNProviderNameNotValid = errors.New("VPN provider name is not valid")
ErrVPNTypeNotValid = errors.New("VPN type is not valid")
ErrWireguardAllowedIPNotSet = errors.New("allowed IP is not set")
ErrWireguardAllowedIPsNotSet = errors.New("allowed IPs is not set")
ErrWireguardEndpointIPNotSet = errors.New("endpoint IP is not set")
ErrWireguardEndpointPortNotAllowed = errors.New("endpoint port is not allowed")
ErrWireguardEndpointPortNotSet = errors.New("endpoint port is not set")
ErrWireguardEndpointPortSet = errors.New("endpoint port is set")
ErrWireguardInterfaceAddressNotSet = errors.New("interface address is not set")
ErrWireguardInterfaceAddressIPv6 = errors.New("interface address is IPv6 but IPv6 is not supported")
ErrWireguardInterfaceNotValid = errors.New("interface name is not valid")
ErrWireguardPreSharedKeyNotSet = errors.New("pre-shared key is not set")
ErrWireguardPrivateKeyNotSet = errors.New("private key is not set")
ErrWireguardPublicKeyNotSet = errors.New("public key is not set")
ErrWireguardPublicKeyNotValid = errors.New("public key is not valid")
ErrWireguardKeepAliveNegative = errors.New("persistent keep alive interval is negative")
ErrWireguardImplementationNotValid = errors.New("implementation is not valid")
)
+4 -3
View File
@@ -1,6 +1,7 @@
package settings package settings
import ( import (
"errors"
"fmt" "fmt"
"net/netip" "net/netip"
@@ -20,16 +21,16 @@ type Firewall struct {
func (f Firewall) validate() (err error) { func (f Firewall) validate() (err error) {
if hasZeroPort(f.VPNInputPorts) { if hasZeroPort(f.VPNInputPorts) {
return fmt.Errorf("VPN input ports: %w", ErrFirewallZeroPort) return errors.New("VPN input ports: cannot have a zero port")
} }
if hasZeroPort(f.InputPorts) { if hasZeroPort(f.InputPorts) {
return fmt.Errorf("input ports: %w", ErrFirewallZeroPort) return errors.New("input ports: cannot have a zero port")
} }
for _, subnet := range f.OutboundSubnets { for _, subnet := range f.OutboundSubnets {
if subnet.Addr().IsUnspecified() { if subnet.Addr().IsUnspecified() {
return fmt.Errorf("%w: %s", ErrFirewallPublicOutboundSubnet, subnet) return fmt.Errorf("outbound subnet has an unspecified address: %s", subnet)
} }
} }
@@ -13,25 +13,21 @@ func Test_Firewall_validate(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
firewall Firewall firewall Firewall
errWrapped error
errMessage string errMessage string
}{ }{
"empty": { "empty": {
errWrapped: log.ErrLevelNotRecognized,
errMessage: "iptables settings: log level: level is not recognized: ", errMessage: "iptables settings: log level: level is not recognized: ",
}, },
"zero_vpn_input_port": { "zero_vpn_input_port": {
firewall: Firewall{ firewall: Firewall{
VPNInputPorts: []uint16{0}, VPNInputPorts: []uint16{0},
}, },
errWrapped: ErrFirewallZeroPort,
errMessage: "VPN input ports: cannot have a zero port", errMessage: "VPN input ports: cannot have a zero port",
}, },
"zero_input_port": { "zero_input_port": {
firewall: Firewall{ firewall: Firewall{
InputPorts: []uint16{0}, InputPorts: []uint16{0},
}, },
errWrapped: ErrFirewallZeroPort,
errMessage: "input ports: cannot have a zero port", errMessage: "input ports: cannot have a zero port",
}, },
"unspecified_outbound_subnet": { "unspecified_outbound_subnet": {
@@ -40,7 +36,6 @@ func Test_Firewall_validate(t *testing.T) {
netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("0.0.0.0/0"),
}, },
}, },
errWrapped: ErrFirewallPublicOutboundSubnet,
errMessage: "outbound subnet has an unspecified address: 0.0.0.0/0", errMessage: "outbound subnet has an unspecified address: 0.0.0.0/0",
}, },
"public_outbound_subnet": { "public_outbound_subnet": {
@@ -70,9 +65,10 @@ func Test_Firewall_validate(t *testing.T) {
err := testCase.firewall.validate() err := testCase.firewall.validate()
assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errMessage != "" {
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage) assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
} }
}) })
} }
+4 -10
View File
@@ -38,12 +38,6 @@ type Health struct {
RestartVPN *bool RestartVPN *bool
} }
var (
ErrICMPTargetIPNotValid = errors.New("ICMP target IP address is not valid")
ErrICMPTargetIPsNotCompatible = errors.New("ICMP target IP addresses are not compatible")
ErrSmallCheckTypeNotValid = errors.New("small check type is not valid")
)
func (h Health) Validate() (err error) { func (h Health) Validate() (err error) {
err = validate.ListeningAddress(h.ServerAddress, os.Getuid()) err = validate.ListeningAddress(h.ServerAddress, os.Getuid())
if err != nil { if err != nil {
@@ -53,16 +47,16 @@ func (h Health) Validate() (err error) {
for _, ip := range h.ICMPTargetIPs { for _, ip := range h.ICMPTargetIPs {
switch { switch {
case !ip.IsValid(): case !ip.IsValid():
return fmt.Errorf("%w: %s", ErrICMPTargetIPNotValid, ip) return fmt.Errorf("ICMP target IP address is not valid: %s", ip)
case ip.IsUnspecified() && len(h.ICMPTargetIPs) > 1: case ip.IsUnspecified() && len(h.ICMPTargetIPs) > 1:
return fmt.Errorf("%w: only a single IP address must be set if it is to be unspecified", return errors.New("ICMP target IP addresses are not compatible: " +
ErrICMPTargetIPsNotCompatible) "only a single IP address must be set if it is to be unspecified")
} }
} }
err = validate.IsOneOf(h.SmallCheckType, "icmp", "dns") err = validate.IsOneOf(h.SmallCheckType, "icmp", "dns")
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrSmallCheckTypeNotValid, err) return fmt.Errorf("small check type is not valid: %w", err)
} }
return nil return nil
+2 -3
View File
@@ -48,7 +48,7 @@ func (h HTTPProxy) validate() (err error) {
// Do not validate user and password // Do not validate user and password
err = validate.ListeningAddress(h.ListeningAddress, os.Getuid()) err = validate.ListeningAddress(h.ListeningAddress, os.Getuid())
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrServerAddressNotValid, h.ListeningAddress) return fmt.Errorf("server listening address is not valid: %w", err)
} }
return nil return nil
@@ -176,7 +176,6 @@ func readHTTProxyLog(r *reader.Reader) (enabled *bool, err error) {
case "disabled", "no", "off": case "disabled", "no", "off":
return ptrTo(false), nil return ptrTo(false), nil
default: default:
return nil, fmt.Errorf("HTTP retro-compatible proxy log setting: %w: %s", return nil, fmt.Errorf("HTTP retro-compatible proxy log setting: value is unknown: %s", value)
ErrValueUnknown, value)
} }
} }
+12 -14
View File
@@ -2,6 +2,7 @@ package settings
import ( import (
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"regexp" "regexp"
"strings" "strings"
@@ -92,7 +93,7 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
// Validate version // Validate version
validVersions := []string{openvpn.Openvpn25, openvpn.Openvpn26} validVersions := []string{openvpn.Openvpn25, openvpn.Openvpn26}
if err = validate.IsOneOf(o.Version, validVersions...); err != nil { if err = validate.IsOneOf(o.Version, validVersions...); err != nil {
return fmt.Errorf("%w: %w", ErrOpenVPNVersionIsNotValid, err) return fmt.Errorf("version is not valid: %w", err)
} }
isCustom := vpnProvider == providers.Custom isCustom := vpnProvider == providers.Custom
@@ -101,14 +102,14 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
vpnProvider != providers.VPNSecure vpnProvider != providers.VPNSecure
if isUserRequired && *o.User == "" { if isUserRequired && *o.User == "" {
return fmt.Errorf("%w", ErrOpenVPNUserIsEmpty) return errors.New("user is empty")
} }
passwordRequired := isUserRequired && passwordRequired := isUserRequired &&
(vpnProvider != providers.Ivpn || !ivpnAccountID.MatchString(*o.User)) (vpnProvider != providers.Ivpn || !ivpnAccountID.MatchString(*o.User))
if passwordRequired && *o.Password == "" { if passwordRequired && *o.Password == "" {
return fmt.Errorf("%w", ErrOpenVPNPasswordIsEmpty) return errors.New("password is empty")
} }
err = validateOpenVPNConfigFilepath(isCustom, *o.ConfFile) err = validateOpenVPNConfigFilepath(isCustom, *o.ConfFile)
@@ -132,23 +133,20 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
} }
if *o.EncryptedKey != "" && *o.KeyPassphrase == "" { if *o.EncryptedKey != "" && *o.KeyPassphrase == "" {
return fmt.Errorf("%w", ErrOpenVPNKeyPassphraseIsEmpty) return errors.New("key passphrase is empty")
} }
const maxMSSFix = 10000 const maxMSSFix = 10000
if *o.MSSFix > maxMSSFix { if *o.MSSFix > maxMSSFix {
return fmt.Errorf("%w: %d is over the maximum value of %d", return fmt.Errorf("mssfix option value is too high: %d is over the maximum value of %d", *o.MSSFix, maxMSSFix)
ErrOpenVPNMSSFixIsTooHigh, *o.MSSFix, maxMSSFix)
} }
if !regexpInterfaceName.MatchString(o.Interface) { if !regexpInterfaceName.MatchString(o.Interface) {
return fmt.Errorf("%w: '%s' does not match regex '%s'", return fmt.Errorf("interface name is not valid: '%s' does not match regex '%s'", o.Interface, regexpInterfaceName)
ErrOpenVPNInterfaceNotValid, o.Interface, regexpInterfaceName)
} }
if *o.Verbosity < 0 || *o.Verbosity > 6 { if *o.Verbosity < 0 || *o.Verbosity > 6 {
return fmt.Errorf("%w: %d can only be between 0 and 5", return fmt.Errorf("verbosity value is out of bounds: %d can only be between 0 and 5", o.Verbosity)
ErrOpenVPNVerbosityIsOutOfBounds, o.Verbosity)
} }
return nil return nil
@@ -162,7 +160,7 @@ func validateOpenVPNConfigFilepath(isCustom bool,
} }
if confFile == "" { if confFile == "" {
return fmt.Errorf("%w", ErrFilepathMissing) return errors.New("filepath is missing")
} }
err = validate.FileExists(confFile) err = validate.FileExists(confFile)
@@ -189,7 +187,7 @@ func validateOpenVPNClientCertificate(vpnProvider,
providers.VPNSecure, providers.VPNSecure,
providers.VPNUnlimited: providers.VPNUnlimited:
if clientCert == "" { if clientCert == "" {
return fmt.Errorf("%w", ErrMissingValue) return errors.New("missing value")
} }
} }
@@ -211,7 +209,7 @@ func validateOpenVPNClientKey(vpnProvider, clientKey string) (err error) {
providers.Cyberghost, providers.Cyberghost,
providers.VPNUnlimited: providers.VPNUnlimited:
if clientKey == "" { if clientKey == "" {
return fmt.Errorf("%w", ErrMissingValue) return errors.New("missing value")
} }
} }
@@ -230,7 +228,7 @@ func validateOpenVPNEncryptedKey(vpnProvider,
encryptedPrivateKey string, encryptedPrivateKey string,
) (err error) { ) (err error) {
if vpnProvider == providers.VPNSecure && encryptedPrivateKey == "" { if vpnProvider == providers.VPNSecure && encryptedPrivateKey == "" {
return fmt.Errorf("%w", ErrMissingValue) return errors.New("missing value")
} }
if encryptedPrivateKey == "" { if encryptedPrivateKey == "" {
@@ -62,8 +62,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
providers.Perfectprivacy, providers.Perfectprivacy,
providers.Vyprvpn, providers.Vyprvpn,
) { ) {
return fmt.Errorf("%w: for VPN service provider %s", return fmt.Errorf("TCP protocol is not supported: for VPN service provider %s", vpnProvider)
ErrOpenVPNTCPNotSupported, vpnProvider)
} }
// Validate CustomPort // Validate CustomPort
@@ -78,8 +77,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
providers.Nordvpn, providers.Purevpn, providers.Nordvpn, providers.Purevpn,
providers.Surfshark, providers.VPNSecure, providers.Surfshark, providers.VPNSecure,
providers.VPNUnlimited, providers.Vyprvpn: providers.VPNUnlimited, providers.Vyprvpn:
return fmt.Errorf("%w: for VPN service provider %s", return fmt.Errorf("custom endpoint port is not allowed: for VPN service provider %s", vpnProvider)
ErrOpenVPNCustomPortNotAllowed, vpnProvider)
default: default:
var allowedTCP, allowedUDP []uint16 var allowedTCP, allowedUDP []uint16
switch vpnProvider { switch vpnProvider {
@@ -123,8 +121,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
} }
err = validate.IsOneOf(*o.CustomPort, allowedPorts...) err = validate.IsOneOf(*o.CustomPort, allowedPorts...)
if err != nil { if err != nil {
return fmt.Errorf("%w: for VPN service provider %s: %w", return fmt.Errorf("custom endpoint port is not allowed: for VPN service provider %s: %w", vpnProvider, err)
ErrOpenVPNCustomPortNotAllowed, vpnProvider, err)
} }
} }
} }
@@ -136,7 +133,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
presets.Strong, presets.Strong,
} }
if err = validate.IsOneOf(*o.PIAEncPreset, validEncryptionPresets...); err != nil { if err = validate.IsOneOf(*o.PIAEncPreset, validEncryptionPresets...); err != nil {
return fmt.Errorf("%w: %w", ErrOpenVPNEncryptionPresetNotValid, err) return fmt.Errorf("PIA encryption preset is not valid: %w", err)
} }
} }
+2 -8
View File
@@ -1,7 +1,6 @@
package settings package settings
import ( import (
"errors"
"fmt" "fmt"
"net/netip" "net/netip"
@@ -24,21 +23,16 @@ type PMTUD struct {
TCPAddresses []netip.AddrPort `json:"tcp_addresses"` TCPAddresses []netip.AddrPort `json:"tcp_addresses"`
} }
var (
ErrPMTUDICMPAddressNotValid = errors.New("PMTUD ICMP address is not valid")
ErrPMTUDTCPAddressNotValid = errors.New("PMTUD TCP address is not valid")
)
// Validate validates PMTUD settings. // Validate validates PMTUD settings.
func (p PMTUD) validate() (err error) { func (p PMTUD) validate() (err error) {
for i, addr := range p.ICMPAddresses { for i, addr := range p.ICMPAddresses {
if !addr.IsValid() { if !addr.IsValid() {
return fmt.Errorf("%w: at index %d", ErrPMTUDICMPAddressNotValid, i) return fmt.Errorf("PMTUD ICMP address is not valid: at index %d", i)
} }
} }
for i, addr := range p.TCPAddresses { for i, addr := range p.TCPAddresses {
if !addr.IsValid() { if !addr.IsValid() {
return fmt.Errorf("%w: at index %d", ErrPMTUDTCPAddressNotValid, i) return fmt.Errorf("PMTUD TCP address is not valid: at index %d", i)
} }
} }
return nil return nil
+9 -14
View File
@@ -55,12 +55,6 @@ type PortForwarding struct {
Password string `json:"password"` Password string `json:"password"`
} }
var (
ErrPortsCountTooHigh = errors.New("ports count too high")
ErrListeningPortsLen = errors.New("listening ports length must be equal to ports count")
ErrListeningPortZero = errors.New("listening port cannot be 0")
)
func (p PortForwarding) Validate(vpnProvider string) (err error) { func (p PortForwarding) Validate(vpnProvider string) (err error) {
if !*p.Enabled { if !*p.Enabled {
return nil return nil
@@ -78,7 +72,7 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) {
providers.Protonvpn, providers.Protonvpn,
} }
if err = validate.IsOneOf(providerSelected, validProviders...); err != nil { if err = validate.IsOneOf(providerSelected, validProviders...); err != nil {
return fmt.Errorf("%w: %w", ErrPortForwardingEnabled, err) return fmt.Errorf("port forwarding cannot be enabled: %w", err)
} }
// Validate Filepath // Validate Filepath
@@ -94,30 +88,31 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) {
const maxPortsCount = 1 const maxPortsCount = 1
switch { switch {
case p.PortsCount > maxPortsCount: case p.PortsCount > maxPortsCount:
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount) return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
case p.Username == "": case p.Username == "":
return fmt.Errorf("%w", ErrPortForwardingUserEmpty) return errors.New("port forwarding username is empty")
case p.Password == "": case p.Password == "":
return fmt.Errorf("%w", ErrPortForwardingPasswordEmpty) return errors.New("port forwarding password is empty")
} }
case providers.Protonvpn: case providers.Protonvpn:
const maxPortsCount = 4 const maxPortsCount = 4
if p.PortsCount > maxPortsCount { if p.PortsCount > maxPortsCount {
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount) return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
} }
default: default:
const maxPortsCount = 1 const maxPortsCount = 1
if p.PortsCount > maxPortsCount { if p.PortsCount > maxPortsCount {
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount) return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
} }
} }
if !slices.Equal(p.ListeningPorts, []uint16{0}) { if !slices.Equal(p.ListeningPorts, []uint16{0}) {
switch { switch {
case len(p.ListeningPorts) != int(p.PortsCount): case len(p.ListeningPorts) != int(p.PortsCount):
return fmt.Errorf("%w: %d != %d", ErrListeningPortsLen, len(p.ListeningPorts), p.PortsCount) return fmt.Errorf("listening ports length must be equal to ports count: "+
"%d != %d", len(p.ListeningPorts), p.PortsCount)
case slices.Contains(p.ListeningPorts, 0): case slices.Contains(p.ListeningPorts, 0):
return fmt.Errorf("%w: in %v", ErrListeningPortZero, p.ListeningPorts) return fmt.Errorf("listening port cannot be 0: in %v", p.ListeningPorts)
} }
} }
+1 -1
View File
@@ -55,7 +55,7 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
} }
} }
if err = validate.IsOneOf(p.Name, validNames...); err != nil { if err = validate.IsOneOf(p.Name, validNames...); err != nil {
return fmt.Errorf("%w for %s: %w", ErrVPNProviderNameNotValid, vpnType, err) return fmt.Errorf("VPN provider name is not valid for %s: %w", vpnType, err)
} }
err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner) err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner)
@@ -15,7 +15,6 @@ func Test_PublicIP_read(t *testing.T) {
makeReader func(ctrl *gomock.Controller) *reader.Reader makeReader func(ctrl *gomock.Controller) *reader.Reader
makeWarner func(ctrl *gomock.Controller) Warner makeWarner func(ctrl *gomock.Controller) Warner
settings PublicIP settings PublicIP
errWrapped error
errMessage string errMessage string
}{ }{
"nothing_read": { "nothing_read": {
@@ -152,9 +151,10 @@ func Test_PublicIP_read(t *testing.T) {
err := settings.read(reader, warner) err := settings.read(reader, warner)
assert.Equal(t, testCase.settings, settings) assert.Equal(t, testCase.settings, settings)
assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errMessage != "" {
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage) assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
} }
}) })
} }
+1 -2
View File
@@ -46,8 +46,7 @@ func (c ControlServer) validate() (err error) {
uid := os.Getuid() uid := os.Getuid()
const maxPrivilegedPort = 1023 const maxPrivilegedPort = 1023
if uid != 0 && port != 0 && port <= maxPrivilegedPort { if uid != 0 && port != 0 && port <= maxPrivilegedPort {
return fmt.Errorf("%w: %d when running with user ID %d", return fmt.Errorf("cannot use privileged port without running as root: %d when running with user ID %d", port, uid)
ErrControlServerPrivilegedPort, port, uid)
} }
jsonDecoder := json.NewDecoder(bytes.NewBufferString(c.AuthDefaultRole)) jsonDecoder := json.NewDecoder(bytes.NewBufferString(c.AuthDefaultRole))
@@ -71,25 +71,13 @@ type ServerSelection struct {
Wireguard WireguardSelection `json:"wireguard"` Wireguard WireguardSelection `json:"wireguard"`
} }
var (
ErrOwnedOnlyNotSupported = errors.New("owned only filter is not supported")
ErrFreeOnlyNotSupported = errors.New("free only filter is not supported")
ErrPremiumOnlyNotSupported = errors.New("premium only filter is not supported")
ErrStreamOnlyNotSupported = errors.New("stream only filter is not supported")
ErrMultiHopOnlyNotSupported = errors.New("multi hop only filter is not supported")
ErrPortForwardOnlyNotSupported = errors.New("port forwarding only filter is not supported")
ErrFreePremiumBothSet = errors.New("free only and premium only filters are both set")
ErrSecureCoreOnlyNotSupported = errors.New("secure core only filter is not supported")
ErrTorOnlyNotSupported = errors.New("tor only filter is not supported")
)
func (ss *ServerSelection) validate(vpnServiceProvider string, func (ss *ServerSelection) validate(vpnServiceProvider string,
filterChoicesGetter FilterChoicesGetter, warner Warner, filterChoicesGetter FilterChoicesGetter, warner Warner,
) (err error) { ) (err error) {
switch ss.VPN { switch ss.VPN {
case vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard: case vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard:
default: default:
return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN) return fmt.Errorf("VPN type is not valid: %s", ss.VPN)
} }
filterChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, filterChoicesGetter, warner) filterChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, filterChoicesGetter, warner)
@@ -150,7 +138,7 @@ func getLocationFilterChoices(vpnServiceProvider string,
// Only return error comparing with newer regions, we don't want to confuse the user // Only return error comparing with newer regions, we don't want to confuse the user
// with the retro regions in the error message. // with the retro regions in the error message.
err = atLeastOneIsOneOfCaseInsensitive(ss.Regions, filterChoices.Regions, warner) err = atLeastOneIsOneOfCaseInsensitive(ss.Regions, filterChoices.Regions, warner)
return models.FilterChoices{}, fmt.Errorf("%w: %w", ErrRegionNotValid, err) return models.FilterChoices{}, fmt.Errorf("the region specified is not valid: %w", err)
} }
} }
@@ -164,27 +152,27 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
) (err error) { ) (err error) {
err = atLeastOneIsOneOfCaseInsensitive(settings.Countries, filterChoices.Countries, warner) err = atLeastOneIsOneOfCaseInsensitive(settings.Countries, filterChoices.Countries, warner)
if err != nil { if err != nil {
return fmt.Errorf("%w: %w", ErrCountryNotValid, err) return fmt.Errorf("the country specified is not valid: %w", err)
} }
err = atLeastOneIsOneOfCaseInsensitive(settings.Regions, filterChoices.Regions, warner) err = atLeastOneIsOneOfCaseInsensitive(settings.Regions, filterChoices.Regions, warner)
if err != nil { if err != nil {
return fmt.Errorf("%w: %w", ErrRegionNotValid, err) return fmt.Errorf("the region specified is not valid: %w", err)
} }
err = atLeastOneIsOneOfCaseInsensitive(settings.Cities, filterChoices.Cities, warner) err = atLeastOneIsOneOfCaseInsensitive(settings.Cities, filterChoices.Cities, warner)
if err != nil { if err != nil {
return fmt.Errorf("%w: %w", ErrCityNotValid, err) return fmt.Errorf("the city specified is not valid: %w", err)
} }
err = atLeastOneIsOneOfCaseInsensitive(settings.ISPs, filterChoices.ISPs, warner) err = atLeastOneIsOneOfCaseInsensitive(settings.ISPs, filterChoices.ISPs, warner)
if err != nil { if err != nil {
return fmt.Errorf("%w: %w", ErrISPNotValid, err) return fmt.Errorf("the ISP specified is not valid: %w", err)
} }
err = atLeastOneIsOneOfCaseInsensitive(settings.Hostnames, filterChoices.Hostnames, warner) err = atLeastOneIsOneOfCaseInsensitive(settings.Hostnames, filterChoices.Hostnames, warner)
if err != nil { if err != nil {
return fmt.Errorf("%w: %w", ErrHostnameNotValid, err) return fmt.Errorf("the hostname specified is not valid: %w", err)
} }
if vpnServiceProvider == providers.Custom { if vpnServiceProvider == providers.Custom {
@@ -196,19 +184,19 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
// which requires a server name for TLS verification. // which requires a server name for TLS verification.
filterChoices.Names = settings.Names filterChoices.Names = settings.Names
default: default:
return fmt.Errorf("%w: %d names specified instead of "+ return fmt.Errorf("name is not valid: "+
"0 or 1 for the custom provider", "%d names specified instead of 0 or 1 for the custom provider",
ErrNameNotValid, len(settings.Names)) len(settings.Names))
} }
} }
err = atLeastOneIsOneOfCaseInsensitive(settings.Names, filterChoices.Names, warner) err = atLeastOneIsOneOfCaseInsensitive(settings.Names, filterChoices.Names, warner)
if err != nil { if err != nil {
return fmt.Errorf("%w: %w", ErrNameNotValid, err) return fmt.Errorf("the server name specified is not valid: %w", err)
} }
err = atLeastOneIsOneOfCaseInsensitive(settings.Categories, filterChoices.Categories, warner) err = atLeastOneIsOneOfCaseInsensitive(settings.Categories, filterChoices.Categories, warner)
if err != nil { if err != nil {
return fmt.Errorf("%w: %w", ErrCategoryNotValid, err) return fmt.Errorf("the category specified is not valid: %w", err)
} }
return nil return nil
@@ -255,12 +243,12 @@ func validateSubscriptionTierFilters(settings ServerSelection, vpnServiceProvide
switch { switch {
case *settings.FreeOnly && case *settings.FreeOnly &&
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited): !helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
return fmt.Errorf("%w", ErrFreeOnlyNotSupported) return errors.New("free only filter is not supported")
case *settings.PremiumOnly && case *settings.PremiumOnly &&
!helpers.IsOneOf(vpnServiceProvider, providers.VPNSecure): !helpers.IsOneOf(vpnServiceProvider, providers.VPNSecure):
return fmt.Errorf("%w", ErrPremiumOnlyNotSupported) return errors.New("premium only filter is not supported")
case *settings.FreeOnly && *settings.PremiumOnly: case *settings.FreeOnly && *settings.PremiumOnly:
return fmt.Errorf("%w", ErrFreePremiumBothSet) return errors.New("free only and premium only filters are both set")
default: default:
return nil return nil
} }
@@ -269,21 +257,21 @@ func validateSubscriptionTierFilters(settings ServerSelection, vpnServiceProvide
func validateFeatureFilters(settings ServerSelection, vpnServiceProvider string) error { func validateFeatureFilters(settings ServerSelection, vpnServiceProvider string) error {
switch { switch {
case *settings.OwnedOnly && vpnServiceProvider != providers.Mullvad: case *settings.OwnedOnly && vpnServiceProvider != providers.Mullvad:
return fmt.Errorf("%w", ErrOwnedOnlyNotSupported) return errors.New("owned only filter is not supported")
case vpnServiceProvider == providers.Protonvpn && *settings.FreeOnly && *settings.PortForwardOnly: case vpnServiceProvider == providers.Protonvpn && *settings.FreeOnly && *settings.PortForwardOnly:
return fmt.Errorf("%w: together with free only filter", ErrPortForwardOnlyNotSupported) return errors.New("port forwarding only filter is not supported: together with free only filter")
case *settings.StreamOnly && case *settings.StreamOnly &&
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited): !helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
return fmt.Errorf("%w", ErrStreamOnlyNotSupported) return errors.New("stream only filter is not supported")
case *settings.MultiHopOnly && vpnServiceProvider != providers.Surfshark: case *settings.MultiHopOnly && vpnServiceProvider != providers.Surfshark:
return fmt.Errorf("%w", ErrMultiHopOnlyNotSupported) return errors.New("multi hop only filter is not supported")
case *settings.PortForwardOnly && case *settings.PortForwardOnly &&
!helpers.IsOneOf(vpnServiceProvider, providers.PrivateInternetAccess, providers.Protonvpn): !helpers.IsOneOf(vpnServiceProvider, providers.PrivateInternetAccess, providers.Protonvpn):
return fmt.Errorf("%w", ErrPortForwardOnlyNotSupported) return errors.New("port forwarding only filter is not supported")
case *settings.SecureCoreOnly && vpnServiceProvider != providers.Protonvpn: case *settings.SecureCoreOnly && vpnServiceProvider != providers.Protonvpn:
return fmt.Errorf("%w", ErrSecureCoreOnlyNotSupported) return errors.New("secure core only filter is not supported")
case *settings.TorOnly && vpnServiceProvider != providers.Protonvpn: case *settings.TorOnly && vpnServiceProvider != providers.Protonvpn:
return fmt.Errorf("%w", ErrTorOnlyNotSupported) return errors.New("tor only filter is not supported")
default: default:
return nil return nil
} }
+8 -7
View File
@@ -1,6 +1,7 @@
package settings package settings
import ( import (
"errors"
"fmt" "fmt"
"slices" "slices"
"strings" "strings"
@@ -37,20 +38,20 @@ type Updater struct {
func (u Updater) Validate() (err error) { func (u Updater) Validate() (err error) {
const minPeriod = time.Minute const minPeriod = time.Minute
if *u.Period > 0 && *u.Period < minPeriod { if *u.Period > 0 && *u.Period < minPeriod {
return fmt.Errorf("%w: %d must be larger than %s", return fmt.Errorf("VPN server data updater period is too small: "+
ErrUpdaterPeriodTooSmall, *u.Period, minPeriod) "%d must be larger than %s", *u.Period, minPeriod)
} }
if u.MinRatio <= 0 || u.MinRatio > 1 { if u.MinRatio <= 0 || u.MinRatio > 1 {
return fmt.Errorf("%w: %.2f must be between 0+ and 1", return fmt.Errorf("minimum ratio is not valid: "+
ErrMinRatioNotValid, u.MinRatio) "%.2f must be between 0+ and 1", u.MinRatio)
} }
validProviders := providers.All() validProviders := providers.All()
for _, provider := range u.Providers { for _, provider := range u.Providers {
err = validate.IsOneOf(provider, validProviders...) err = validate.IsOneOf(provider, validProviders...)
if err != nil { if err != nil {
return fmt.Errorf("%w: %w", ErrVPNProviderNameNotValid, err) return fmt.Errorf("VPN provider name is not valid: %w", err)
} }
if provider == providers.Protonvpn { if provider == providers.Protonvpn {
@@ -58,9 +59,9 @@ func (u Updater) Validate() (err error) {
if authenticatedAPI { if authenticatedAPI {
switch { switch {
case *u.ProtonEmail == "": case *u.ProtonEmail == "":
return fmt.Errorf("%w", ErrUpdaterProtonEmailMissing) return errors.New("proton email is missing")
case *u.ProtonPassword == "": case *u.ProtonPassword == "":
return fmt.Errorf("%w", ErrUpdaterProtonPasswordMissing) return errors.New("proton password is missing")
} }
} }
} }
+1 -1
View File
@@ -37,7 +37,7 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo
// Validate Type // Validate Type
validVPNTypes := []string{vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard} validVPNTypes := []string{vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard}
if err = validate.IsOneOf(v.Type, validVPNTypes...); err != nil { if err = validate.IsOneOf(v.Type, validVPNTypes...); err != nil {
return fmt.Errorf("%w: %w", ErrVPNTypeNotValid, err) return fmt.Errorf("VPN type is not valid: %w", err)
} }
err = v.Provider.validate(v.Type, filterChoicesGetter, warner) err = v.Provider.validate(v.Type, filterChoicesGetter, warner)
+11 -15
View File
@@ -1,6 +1,7 @@
package settings package settings
import ( import (
"errors"
"fmt" "fmt"
"net/netip" "net/netip"
"regexp" "regexp"
@@ -54,7 +55,7 @@ var regexpInterfaceName = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (err error) { func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (err error) {
// Validate PrivateKey // Validate PrivateKey
if *w.PrivateKey == "" { if *w.PrivateKey == "" {
return fmt.Errorf("%w", ErrWireguardPrivateKeyNotSet) return errors.New("private key is not set")
} }
_, err = wgtypes.ParseKey(*w.PrivateKey) _, err = wgtypes.ParseKey(*w.PrivateKey)
if err != nil { if err != nil {
@@ -68,7 +69,7 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (
if vpnProvider == providers.Airvpn { if vpnProvider == providers.Airvpn {
if *w.PreSharedKey == "" { if *w.PreSharedKey == "" {
return fmt.Errorf("%w", ErrWireguardPreSharedKeyNotSet) return errors.New("pre-shared key is not set")
} }
} }
@@ -82,17 +83,15 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (
// Validate Addresses // Validate Addresses
if len(w.Addresses) == 0 { if len(w.Addresses) == 0 {
return fmt.Errorf("%w", ErrWireguardInterfaceAddressNotSet) return errors.New("interface address is not set")
} }
for i, ipNet := range w.Addresses { for i, ipNet := range w.Addresses {
if !ipNet.IsValid() { if !ipNet.IsValid() {
return fmt.Errorf("%w: for address at index %d", return fmt.Errorf("interface address is not set: for address at index %d", i)
ErrWireguardInterfaceAddressNotSet, i)
} }
if !ipv6Supported && ipNet.Addr().Is6() { if !ipv6Supported && ipNet.Addr().Is6() {
return fmt.Errorf("%w: address %s", return fmt.Errorf("interface address is IPv6 but IPv6 is not supported: address %s", ipNet.String())
ErrWireguardInterfaceAddressIPv6, ipNet.String())
} }
} }
@@ -100,30 +99,27 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (
// WARNING: do not check for IPv6 networks in the allowed IPs, // WARNING: do not check for IPv6 networks in the allowed IPs,
// the wireguard code will take care to ignore it. // the wireguard code will take care to ignore it.
if len(w.AllowedIPs) == 0 { if len(w.AllowedIPs) == 0 {
return fmt.Errorf("%w", ErrWireguardAllowedIPsNotSet) return errors.New("allowed IPs is not set")
} }
for i, allowedIP := range w.AllowedIPs { for i, allowedIP := range w.AllowedIPs {
if !allowedIP.IsValid() { if !allowedIP.IsValid() {
return fmt.Errorf("%w: for allowed ip %d of %d", return fmt.Errorf("allowed IP is not set: for allowed ip %d of %d", i+1, len(w.AllowedIPs))
ErrWireguardAllowedIPNotSet, i+1, len(w.AllowedIPs))
} }
} }
if *w.PersistentKeepaliveInterval < 0 { if *w.PersistentKeepaliveInterval < 0 {
return fmt.Errorf("%w: %s", ErrWireguardKeepAliveNegative, return fmt.Errorf("persistent keep alive interval is negative: %s", *w.PersistentKeepaliveInterval)
*w.PersistentKeepaliveInterval)
} }
// Validate interface // Validate interface
if !regexpInterfaceName.MatchString(w.Interface) { if !regexpInterfaceName.MatchString(w.Interface) {
return fmt.Errorf("%w: '%s' does not match regex '%s'", return fmt.Errorf("interface name is not valid: '%s' does not match regex '%s'", w.Interface, regexpInterfaceName)
ErrWireguardInterfaceNotValid, w.Interface, regexpInterfaceName)
} }
if !amneziawg { // amneziawg should have its own Implementation field and ignore this one if !amneziawg { // amneziawg should have its own Implementation field and ignore this one
validImplementations := []string{"auto", "userspace", "kernelspace"} validImplementations := []string{"auto", "userspace", "kernelspace"}
if err := validate.IsOneOf(w.Implementation, validImplementations...); err != nil { if err := validate.IsOneOf(w.Implementation, validImplementations...); err != nil {
return fmt.Errorf("%w: %w", ErrWireguardImplementationNotValid, err) return fmt.Errorf("implementation is not valid: %w", err)
} }
} }
@@ -1,6 +1,7 @@
package settings package settings
import ( import (
"errors"
"fmt" "fmt"
"net/netip" "net/netip"
@@ -44,7 +45,7 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
// endpoint IP addresses are baked in // endpoint IP addresses are baked in
case providers.Custom: case providers.Custom:
if !w.EndpointIP.IsValid() || w.EndpointIP.IsUnspecified() { if !w.EndpointIP.IsValid() || w.EndpointIP.IsUnspecified() {
return fmt.Errorf("%w", ErrWireguardEndpointIPNotSet) return errors.New("endpoint IP is not set")
} }
default: // Providers not supporting Wireguard default: // Providers not supporting Wireguard
} }
@@ -54,13 +55,13 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
// EndpointPort is required // EndpointPort is required
case providers.Custom: case providers.Custom:
if *w.EndpointPort == 0 { if *w.EndpointPort == 0 {
return fmt.Errorf("%w", ErrWireguardEndpointPortNotSet) return errors.New("endpoint port is not set")
} }
// EndpointPort cannot be set // EndpointPort cannot be set
case providers.Fastestvpn, providers.Nordvpn, case providers.Fastestvpn, providers.Nordvpn,
providers.Protonvpn, providers.Surfshark: providers.Protonvpn, providers.Surfshark:
if *w.EndpointPort != 0 { if *w.EndpointPort != 0 {
return fmt.Errorf("%w", ErrWireguardEndpointPortSet) return errors.New("endpoint port is set")
} }
case providers.Airvpn, providers.Ivpn, providers.Mullvad, providers.Windscribe: case providers.Airvpn, providers.Ivpn, providers.Mullvad, providers.Windscribe:
// EndpointPort is optional and can be 0 // EndpointPort is optional and can be 0
@@ -84,8 +85,7 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
if err == nil { if err == nil {
break break
} }
return fmt.Errorf("%w: for VPN service provider %s: %w", return fmt.Errorf("endpoint port is not allowed: for VPN service provider %s: %w", vpnProvider, err)
ErrWireguardEndpointPortNotAllowed, vpnProvider, err)
default: // Providers not supporting Wireguard default: // Providers not supporting Wireguard
} }
@@ -96,15 +96,14 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
// public keys are baked in // public keys are baked in
case providers.Custom: case providers.Custom:
if w.PublicKey == "" { if w.PublicKey == "" {
return fmt.Errorf("%w", ErrWireguardPublicKeyNotSet) return errors.New("public key is not set")
} }
default: // Providers not supporting Wireguard default: // Providers not supporting Wireguard
} }
if w.PublicKey != "" { if w.PublicKey != "" {
_, err := wgtypes.ParseKey(w.PublicKey) _, err := wgtypes.ParseKey(w.PublicKey)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s: %s", return fmt.Errorf("public key is not valid: %s: %s", w.PublicKey, err)
ErrWireguardPublicKeyNotValid, w.PublicKey, err)
} }
} }
@@ -74,8 +74,6 @@ func parseWireguardInterfaceSection(interfaceSection *ini.Section) (
return privateKey, addresses return privateKey, addresses
} }
var ErrEndpointHostNotIP = errors.New("endpoint host is not an IP")
func parseWireguardPeerSection(peerSection *ini.Section) ( func parseWireguardPeerSection(peerSection *ini.Section) (
preSharedKey, publicKey, endpointIP, endpointPort *string, preSharedKey, publicKey, endpointIP, endpointPort *string,
) { ) {
+1 -4
View File
@@ -3,7 +3,6 @@ package dns
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"math" "math"
"math/rand/v2" "math/rand/v2"
@@ -63,8 +62,6 @@ func generateRandomString(length uint) string {
return string(b) return string(b)
} }
var errIPLeakSessionMismatch = errors.New("ipleak.net session mismatch")
func triggerDNSQuery(ctx context.Context, client *http.Client, session string) ( func triggerDNSQuery(ctx context.Context, client *http.Client, session string) (
dnsToCount map[string]uint, err error, dnsToCount map[string]uint, err error,
) { ) {
@@ -93,7 +90,7 @@ func triggerDNSQuery(ctx context.Context, client *http.Client, session string) (
if err != nil { if err != nil {
return nil, fmt.Errorf("decoding response: %w", err) return nil, fmt.Errorf("decoding response: %w", err)
} else if data.Session != session { } else if data.Session != session {
return nil, fmt.Errorf("%w: expected %s, got %s", errIPLeakSessionMismatch, session, data.Session) return nil, fmt.Errorf("ipleak.net session mismatch: expected %s, got %s", session, data.Session)
} }
return data.IP, nil return data.IP, nil
+5 -9
View File
@@ -57,18 +57,15 @@ func Test_deleteIPTablesRule(t *testing.T) {
t.Parallel() t.Parallel()
const iptablesBinary = "/sbin/iptables" const iptablesBinary = "/sbin/iptables"
errTest := errors.New("test error")
testCases := map[string]struct { testCases := map[string]struct {
instruction string instruction string
makeRunner func(ctrl *gomock.Controller) *MockCmdRunner makeRunner func(ctrl *gomock.Controller) *MockCmdRunner
makeLogger func(ctrl *gomock.Controller) *MockLogger makeLogger func(ctrl *gomock.Controller) *MockLogger
errWrapped error
errMessage string errMessage string
}{ }{
"invalid_instruction": { "invalid_instruction": {
instruction: "invalid", instruction: "invalid",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing iptables command: parsing \"invalid\": " + errMessage: "parsing iptables command: parsing \"invalid\": " +
"iptables command is malformed: flag \"invalid\" requires a value, but got none", "iptables command is malformed: flag \"invalid\" requires a value, but got none",
}, },
@@ -78,7 +75,7 @@ func Test_deleteIPTablesRule(t *testing.T) {
runner := NewMockCmdRunner(ctrl) runner := NewMockCmdRunner(ctrl)
runner.EXPECT(). runner.EXPECT().
Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")). Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return("", errTest) Return("", errors.New("test error"))
return runner return runner
}, },
makeLogger: func(ctrl *gomock.Controller) *MockLogger { makeLogger: func(ctrl *gomock.Controller) *MockLogger {
@@ -86,7 +83,6 @@ func Test_deleteIPTablesRule(t *testing.T) {
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v") logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
return logger return logger
}, },
errWrapped: errTest,
errMessage: `finding iptables chain rule line number: command failed: ` + errMessage: `finding iptables chain rule line number: command failed: ` +
`"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`, `"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`,
}, },
@@ -120,7 +116,7 @@ func Test_deleteIPTablesRule(t *testing.T) {
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll "2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
nil) nil)
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$", runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
"^-D$", "^PREROUTING$", "^2$")).Return("details", errTest) "^-D$", "^PREROUTING$", "^2$")).Return("details", errors.New("test error"))
return runner return runner
}, },
makeLogger: func(ctrl *gomock.Controller) *MockLogger { makeLogger: func(ctrl *gomock.Controller) *MockLogger {
@@ -131,7 +127,6 @@ func Test_deleteIPTablesRule(t *testing.T) {
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2") logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
return logger return logger
}, },
errWrapped: errTest,
errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details", errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details",
}, },
"rule_found_delete_success": { "rule_found_delete_success": {
@@ -177,9 +172,10 @@ func Test_deleteIPTablesRule(t *testing.T) {
err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger) err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger)
assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errMessage != "" {
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage) assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
} }
}) })
} }
+1 -3
View File
@@ -82,13 +82,11 @@ func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction
return nil return nil
} }
var ErrPolicyNotValid = errors.New("policy is not valid")
func (c *Config) SetIPv6AllPolicies(ctx context.Context, policy string) error { func (c *Config) SetIPv6AllPolicies(ctx context.Context, policy string) error {
switch policy { switch policy {
case "ACCEPT", "DROP": case "ACCEPT", "DROP":
default: default:
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy) return fmt.Errorf("policy is not valid: %s", policy)
} }
return c.runIP6tablesInstructions(ctx, []string{ return c.runIP6tablesInstructions(ctx, []string{
"--policy INPUT " + policy, "--policy INPUT " + policy,
+9 -12
View File
@@ -2,7 +2,6 @@ package iptables
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"net/netip" "net/netip"
@@ -13,10 +12,8 @@ import (
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
) )
var ( const (
ErrIPTablesVersionTooShort = errors.New("iptables version string is too short") needIP6Tables = "ip6tables is required, please upgrade your kernel"
ErrPolicyUnknown = errors.New("unknown policy")
ErrNeedIP6Tables = errors.New("ip6tables is required, please upgrade your kernel to support it")
) )
func appendOrDelete(remove bool) string { func appendOrDelete(remove bool) string {
@@ -36,7 +33,7 @@ func (c *Config) Version(ctx context.Context) (string, error) {
words := strings.Fields(output) words := strings.Fields(output)
const minWords = 2 const minWords = 2
if len(words) < minWords { if len(words) < minWords {
return "", fmt.Errorf("%w: %s", ErrIPTablesVersionTooShort, output) return "", fmt.Errorf("iptables version string is too short: %s", output)
} }
return "iptables " + words[1], nil return "iptables " + words[1], nil
} }
@@ -102,7 +99,7 @@ func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error {
switch policy { switch policy {
case "ACCEPT", "DROP": case "ACCEPT", "DROP":
default: default:
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy) return fmt.Errorf("unknown policy: %s", policy)
} }
return c.runIptablesInstructions(ctx, []string{ return c.runIptablesInstructions(ctx, []string{
"--policy INPUT " + policy, "--policy INPUT " + policy,
@@ -129,7 +126,7 @@ func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, destinati
return c.runIptablesInstruction(ctx, instruction) return c.runIptablesInstruction(ctx, instruction)
} }
if c.ip6Tables == "" { if c.ip6Tables == "" {
return fmt.Errorf("accept input to subnet %s: %w", destination, ErrNeedIP6Tables) return fmt.Errorf("accept input to subnet %s: %s", destination, needIP6Tables)
} }
return c.runIP6tablesInstruction(ctx, instruction) return c.runIP6tablesInstruction(ctx, instruction)
} }
@@ -157,7 +154,7 @@ func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context,
if connection.IP.Is4() { if connection.IP.Is4() {
return c.runIptablesInstruction(ctx, instruction) return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" { } else if c.ip6Tables == "" {
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables) return fmt.Errorf("accept output to VPN server %s: %s", connection.IP, needIP6Tables)
} }
return c.runIP6tablesInstruction(ctx, instruction) return c.runIP6tablesInstruction(ctx, instruction)
} }
@@ -175,7 +172,7 @@ func (c *Config) AcceptOutput(ctx context.Context,
if ip.Is4() { if ip.Is4() {
return c.runIptablesInstruction(ctx, instruction) return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" { } else if c.ip6Tables == "" {
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables) return fmt.Errorf("accept output to VPN server %s: %s", ip, needIP6Tables)
} }
return c.runIP6tablesInstruction(ctx, instruction) return c.runIP6tablesInstruction(ctx, instruction)
} }
@@ -200,7 +197,7 @@ func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
if doIPv4 { if doIPv4 {
return c.runIptablesInstruction(ctx, instruction) return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" { } else if c.ip6Tables == "" {
return fmt.Errorf("accept output from %s to %s: %w", sourceIP, destinationSubnet, ErrNeedIP6Tables) return fmt.Errorf("accept output from %s to %s: %s", sourceIP, destinationSubnet, needIP6Tables)
} }
return c.runIP6tablesInstruction(ctx, instruction) return c.runIP6tablesInstruction(ctx, instruction)
} }
@@ -350,7 +347,7 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
case ipv4: case ipv4:
err = c.runIptablesInstructionNoSave(ctx, rule) err = c.runIptablesInstructionNoSave(ctx, rule)
case c.ip6Tables == "": case c.ip6Tables == "":
err = fmt.Errorf("running user ip6tables rule: %w", ErrNeedIP6Tables) err = fmt.Errorf("running user ip6tables rule: %s", needIP6Tables)
default: // ipv6 default: // ipv6
err = c.runIP6tablesInstructionNoSave(ctx, rule) err = c.runIP6tablesInstructionNoSave(ctx, rule)
} }
+27 -47
View File
@@ -40,8 +40,6 @@ type mark struct {
value uint value uint
} }
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
func parseChain(iptablesOutput string) (c chain, err error) { func parseChain(iptablesOutput string) (c chain, err error) {
// Text example: // Text example:
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes) // Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
@@ -63,8 +61,8 @@ func parseChain(iptablesOutput string) (c chain, err error) {
const minLines = 2 // chain general information line + legend line const minLines = 2 // chain general information line + legend line
if len(lines) < minLines { if len(lines) < minLines {
return chain{}, fmt.Errorf("%w: not enough lines to process in: %s", return chain{}, fmt.Errorf("iptables chain list output is malformed: not enough lines to process in: %s",
ErrChainListMalformed, iptablesOutput) iptablesOutput)
} }
c, err = parseChainGeneralDataLine(lines[0]) c, err = parseChainGeneralDataLine(lines[0])
@@ -77,8 +75,8 @@ func parseChain(iptablesOutput string) (c chain, err error) {
legendLine := strings.TrimSpace(lines[1]) legendLine := strings.TrimSpace(lines[1])
legendFields := strings.Fields(legendLine) legendFields := strings.Fields(legendLine)
if !slices.Equal(expectedLegendFields, legendFields) { if !slices.Equal(expectedLegendFields, legendFields) {
return chain{}, fmt.Errorf("%w: legend %q is not the expected %q", return chain{}, fmt.Errorf("iptables chain list output is malformed: legend %q is not the expected %q",
ErrChainListMalformed, legendLine, strings.Join(expectedLegendFields, " ")) legendLine, strings.Join(expectedLegendFields, " "))
} }
lines = lines[2:] // remove chain general information line and legend line lines = lines[2:] // remove chain general information line and legend line
@@ -111,8 +109,8 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
fields := strings.Fields(line) fields := strings.Fields(line)
const expectedNumberOfFields = 8 const expectedNumberOfFields = 8
if len(fields) != expectedNumberOfFields { if len(fields) != expectedNumberOfFields {
return chain{}, fmt.Errorf("%w: expected %d fields in %q", return chain{}, fmt.Errorf("iptables chain list output is malformed: expected %d fields in %q",
ErrChainListMalformed, expectedNumberOfFields, line) expectedNumberOfFields, line)
} }
// Sanity checks // Sanity checks
@@ -126,8 +124,8 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
if fields[index] == expectedValue { if fields[index] == expectedValue {
continue continue
} }
return chain{}, fmt.Errorf("%w: expected %q for field %d in %q", return chain{}, fmt.Errorf("iptables chain list output is malformed: expected %q for field %d in %q",
ErrChainListMalformed, expectedValue, index, line) expectedValue, index, line)
} }
base.name = fields[1] // chain name could be custom base.name = fields[1] // chain name could be custom
@@ -152,19 +150,17 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
return base, nil return base, nil
} }
var ErrChainRuleMalformed = errors.New("chain rule is malformed")
func parseChainRuleLine(line string) (rule chainRule, err error) { func parseChainRuleLine(line string) (rule chainRule, err error) {
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
if line == "" { if line == "" {
return chainRule{}, fmt.Errorf("%w: empty line", ErrChainRuleMalformed) return chainRule{}, errors.New("chain rule is malformed: empty line")
} }
fields := strings.Fields(line) fields := strings.Fields(line)
const minFields = 10 const minFields = 10
if len(fields) < minFields { if len(fields) < minFields {
return chainRule{}, fmt.Errorf("%w: not enough fields", ErrChainRuleMalformed) return chainRule{}, errors.New("chain rule is malformed: not enough fields")
} }
for fieldIndex, field := range fields[:minFields] { for fieldIndex, field := range fields[:minFields] {
@@ -186,7 +182,7 @@ func parseChainRuleLine(line string) (rule chainRule, err error) {
func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) { func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) {
if field == "" { if field == "" {
return fmt.Errorf("%w: empty field at index %d", ErrChainRuleMalformed, fieldIndex) return fmt.Errorf("chain rule is malformed: empty field at index %d", fieldIndex)
} }
const ( const (
@@ -278,8 +274,8 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
rule.redirPorts = ports rule.redirPorts = ports
i++ i++
default: default:
return fmt.Errorf("%w: unexpected %q after redir", return fmt.Errorf("chain rule is malformed: unexpected %q after redir",
ErrChainRuleMalformed, optionalFields[1]) optionalFields[1])
} }
case "ctstate": case "ctstate":
i++ i++
@@ -294,15 +290,13 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
rule.mark = mark rule.mark = mark
i += consumed i += consumed
default: default:
return fmt.Errorf("%w: unexpected optional field: %s", return fmt.Errorf("chain rule is malformed: unexpected optional field: %s",
ErrChainRuleMalformed, optionalFields[i]) optionalFields[i])
} }
} }
return nil return nil
} }
var errUDPOptionalUnknown = errors.New("unknown UDP optional field")
func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) { func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
for _, value := range optionalFields { for _, value := range optionalFields {
if !strings.ContainsRune(value, ':') { if !strings.ContainsRune(value, ':') {
@@ -323,14 +317,12 @@ func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, e
} }
consumed++ consumed++
default: default:
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value) return 0, fmt.Errorf("unknown UDP optional field: %s", value)
} }
} }
return consumed, nil return consumed, nil
} }
var errTCPOptionalUnknown = errors.New("unknown TCP optional field")
func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) { func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
for _, value := range optionalFields { for _, value := range optionalFields {
if !strings.ContainsRune(value, ':') { if !strings.ContainsRune(value, ':') {
@@ -357,7 +349,7 @@ func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, e
} }
consumed++ consumed++
default: default:
return 0, fmt.Errorf("%w: %s", errTCPOptionalUnknown, value) return 0, fmt.Errorf("unknown TCP optional field: %s", value)
} }
} }
return consumed, nil return consumed, nil
@@ -373,15 +365,13 @@ func parseSourcePort(value string) (port uint16, err error) {
return parsePort(value) return parsePort(value)
} }
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
func parseTCPFlags(value string) (tcpFlags, error) { func parseTCPFlags(value string) (tcpFlags, error) {
value = strings.TrimPrefix(value, "flags:") value = strings.TrimPrefix(value, "flags:")
fields := strings.Split(value, "/") fields := strings.Split(value, "/")
const expectedFields = 2 const expectedFields = 2
if len(fields) != expectedFields { if len(fields) != expectedFields {
return tcpFlags{}, fmt.Errorf("%w: expected format 'flags:<mask>/<comparison>' in %q", return tcpFlags{}, fmt.Errorf("TCP flags are malformed: expected format 'flags:<mask>/<comparison>' in %q",
errTCPFlagsMalformed, value) value)
} }
maskFlags := strings.Split(fields[0], ",") maskFlags := strings.Split(fields[0], ",")
mask := make([]tcpFlag, len(maskFlags)) mask := make([]tcpFlag, len(maskFlags))
@@ -422,8 +412,6 @@ func parsePortsCSV(s string) (ports []uint16, err error) {
return ports, nil return ports, nil
} }
var errMarkValueMalformed = errors.New("mark value is malformed")
func parseMark(optionalFields []string) (m mark, consumed int, err error) { func parseMark(optionalFields []string) (m mark, consumed int, err error) {
switch optionalFields[consumed] { switch optionalFields[consumed] {
case "match": case "match":
@@ -437,42 +425,36 @@ func parseMark(optionalFields []string) (m mark, consumed int, err error) {
const bits = 32 const bits = 32
value, err := strconv.ParseUint(optionalFields[consumed], base, bits) value, err := strconv.ParseUint(optionalFields[consumed], base, bits)
if err != nil { if err != nil {
return mark{}, 0, fmt.Errorf("%w: %s", errMarkValueMalformed, optionalFields[consumed]) return mark{}, 0, fmt.Errorf("mark value is malformed: %s", optionalFields[consumed])
} }
m.value = uint(value) m.value = uint(value)
consumed++ consumed++
default: default:
return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s", return mark{}, 0, fmt.Errorf("chain rule is malformed: unexpected mark mode field: %s",
ErrChainRuleMalformed, optionalFields[consumed]) optionalFields[consumed])
} }
return m, consumed, nil return m, consumed, nil
} }
var ErrLineNumberIsZero = errors.New("line number is zero")
func parseLineNumber(s string) (n uint16, err error) { func parseLineNumber(s string) (n uint16, err error) {
const base, bitLength = 10, 16 const base, bitLength = 10, 16
lineNumber, err := strconv.ParseUint(s, base, bitLength) lineNumber, err := strconv.ParseUint(s, base, bitLength)
if err != nil { if err != nil {
return 0, err return 0, err
} else if lineNumber == 0 { } else if lineNumber == 0 {
return 0, fmt.Errorf("%w", ErrLineNumberIsZero) return 0, errors.New("line number is zero")
} }
return uint16(lineNumber), nil return uint16(lineNumber), nil
} }
var ErrTargetUnknown = errors.New("unknown target")
func checkTarget(target string) (err error) { func checkTarget(target string) (err error) {
switch target { switch target {
case "ACCEPT", "DROP", "REJECT", "REDIRECT": case "ACCEPT", "DROP", "REJECT", "REDIRECT":
return nil return nil
} }
return fmt.Errorf("%w: %s", ErrTargetUnknown, target) return fmt.Errorf("unknown target: %s", target)
} }
var ErrProtocolUnknown = errors.New("unknown protocol")
func parseProtocol(s string) (protocol string, err error) { func parseProtocol(s string) (protocol string, err error) {
switch s { switch s {
case "0", "all": case "0", "all":
@@ -483,18 +465,16 @@ func parseProtocol(s string) (protocol string, err error) {
case "17", "udp": case "17", "udp":
protocol = "udp" protocol = "udp"
default: default:
return "", fmt.Errorf("%w: %s", ErrProtocolUnknown, s) return "", fmt.Errorf("unknown protocol: %s", s)
} }
return protocol, nil return protocol, nil
} }
var ErrMetricSizeMalformed = errors.New("metric size is malformed")
// parseMetricSize parses a metric size string like 140K or 226M and // parseMetricSize parses a metric size string like 140K or 226M and
// returns the raw integer matching it. // returns the raw integer matching it.
func parseMetricSize(size string) (n uint64, err error) { func parseMetricSize(size string) (n uint64, err error) {
if size == "" { if size == "" {
return n, fmt.Errorf("%w: empty string", ErrMetricSizeMalformed) return n, errors.New("metric size is malformed: empty string")
} }
//nolint:mnd //nolint:mnd
@@ -516,7 +496,7 @@ func parseMetricSize(size string) (n uint64, err error) {
const base, bitLength = 10, 64 const base, bitLength = 10, 64
n, err = strconv.ParseUint(size, base, bitLength) n, err = strconv.ParseUint(size, base, bitLength)
if err != nil { if err != nil {
return n, fmt.Errorf("%w: %w", ErrMetricSizeMalformed, err) return n, fmt.Errorf("metric size is malformed: %w", err)
} }
n *= multiplier n *= multiplier
return n, nil return n, nil
+3 -7
View File
@@ -13,30 +13,25 @@ func Test_parseChain(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
iptablesOutput string iptablesOutput string
table chain table chain
errWrapped error
errMessage string errMessage string
}{ }{
"no_output": { "no_output": {
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: not enough lines to process in: ", errMessage: "iptables chain list output is malformed: not enough lines to process in: ",
}, },
"single_line_only": { "single_line_only": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)`, iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)`,
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: not enough lines to process in: " + errMessage: "iptables chain list output is malformed: not enough lines to process in: " +
"Chain INPUT (policy ACCEPT 140K packets, 226M bytes)", "Chain INPUT (policy ACCEPT 140K packets, 226M bytes)",
}, },
"malformed_general_data_line": { "malformed_general_data_line": {
iptablesOutput: `Chain INPUT iptablesOutput: `Chain INPUT
num pkts bytes target prot opt in out source destination`, num pkts bytes target prot opt in out source destination`,
errWrapped: ErrChainListMalformed,
errMessage: "parsing chain general data line: iptables chain list output is malformed: " + errMessage: "parsing chain general data line: iptables chain list output is malformed: " +
"expected 8 fields in \"Chain INPUT\"", "expected 8 fields in \"Chain INPUT\"",
}, },
"malformed_legend": { "malformed_legend": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes) iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
num pkts bytes target prot opt in out source`, num pkts bytes target prot opt in out source`,
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: legend " + errMessage: "iptables chain list output is malformed: legend " +
"\"num pkts bytes target prot opt in out source\" " + "\"num pkts bytes target prot opt in out source\" " +
"is not the expected \"num pkts bytes target prot opt in out source destination\"", "is not the expected \"num pkts bytes target prot opt in out source destination\"",
@@ -135,9 +130,10 @@ num pkts bytes target prot opt in out source destinati
table, err := parseChain(testCase.iptablesOutput) table, err := parseChain(testCase.iptablesOutput)
assert.Equal(t, testCase.table, table) assert.Equal(t, testCase.table, table)
assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errMessage != "" {
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage) assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
} }
}) })
} }
+10 -12
View File
@@ -80,11 +80,9 @@ func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified()) (!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
} }
var ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) { func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
if s == "" { if s == "" {
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed) return iptablesInstruction{}, errors.New("iptables command is malformed: empty instruction")
} }
fields := strings.Fields(s) fields := strings.Fields(s)
@@ -173,7 +171,7 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
return 0, fmt.Errorf("parsing TCP flags: %w", err) return 0, fmt.Errorf("parsing TCP flags: %w", err)
} }
default: default:
return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag) return 0, fmt.Errorf("iptables command is malformed: unknown key %q", flag)
} }
return consumed, nil return consumed, nil
} }
@@ -185,15 +183,15 @@ func preCheckInstructionFields(fields []string) (consumed int, err error) {
case "--tcp-flags": // -m can have 1 or 2 values case "--tcp-flags": // -m can have 1 or 2 values
const expected = 3 const expected = 3
if len(fields) < expected { if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s", return 0, fmt.Errorf("iptables command is malformed: flag %q requires at least 2 values, but got %s",
ErrIptablesCommandMalformed, flag, strings.Join(fields, " ")) flag, strings.Join(fields, " "))
} }
return expected, nil return expected, nil
default: default:
const expected = 2 const expected = 2
if len(fields) < expected { if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires a value, but got none", return 0, fmt.Errorf("iptables command is malformed: flag %q requires a value, but got none",
ErrIptablesCommandMalformed, flag) flag)
} }
return expected, nil return expected, nil
} }
@@ -239,12 +237,12 @@ func parseMatchModule(fields []string, instruction *iptablesInstruction) (
consumed++ consumed++
instruction.mark.invert = true instruction.mark.invert = true
default: default:
return consumed, fmt.Errorf("%w: unsupported match mark with value: %s", return consumed, fmt.Errorf("iptables command is malformed: unsupported match mark with value: %s",
ErrIptablesCommandMalformed, fields[2]) fields[2])
} }
default: default:
return 0, fmt.Errorf("%w: unknown match value: %s", return 0, fmt.Errorf("iptables command is malformed: unknown match value: %s",
ErrIptablesCommandMalformed, fields[consumed]) fields[consumed])
} }
return consumed, nil return consumed, nil
} }
+3 -6
View File
@@ -13,21 +13,17 @@ func Test_parseIptablesInstruction(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
s string s string
instruction iptablesInstruction instruction iptablesInstruction
errWrapped error
errMessage string errMessage string
}{ }{
"no_instruction": { "no_instruction": {
errWrapped: ErrIptablesCommandMalformed,
errMessage: "iptables command is malformed: empty instruction", errMessage: "iptables command is malformed: empty instruction",
}, },
"uneven_fields": { "uneven_fields": {
s: "-A", s: "-A",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none", errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
}, },
"unknown_key": { "unknown_key": {
s: "-x something", s: "-x something",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"", errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
}, },
"one_pair": { "one_pair": {
@@ -74,9 +70,10 @@ func Test_parseIptablesInstruction(t *testing.T) {
rule, err := parseIptablesInstruction(testCase.s) rule, err := parseIptablesInstruction(testCase.s)
assert.Equal(t, testCase.instruction, rule) assert.Equal(t, testCase.instruction, rule)
assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errMessage != "" {
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage) assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
} }
}) })
} }
+4 -9
View File
@@ -10,12 +10,7 @@ import (
"strings" "strings"
) )
var ( var ErrNotSupported = errors.New("no iptables supported found")
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
ErrInputPolicyNotFound = errors.New("input policy not found")
ErrNotSupported = errors.New("no iptables supported found")
)
func checkIptablesSupport(ctx context.Context, runner CmdRunner, func checkIptablesSupport(ctx context.Context, runner CmdRunner,
iptablesPathsToTry ...string, iptablesPathsToTry ...string,
@@ -53,7 +48,7 @@ func checkIptablesSupport(ctx context.Context, runner CmdRunner,
if allArePermissionDenied { if allArePermissionDenied {
// If the error is related to a denied permission for all iptables path, // If the error is related to a denied permission for all iptables path,
// return an error describing what to do from an end-user perspective. // return an error describing what to do from an end-user perspective.
return "", fmt.Errorf("%w: %s", ErrNetAdminMissing, strings.Join(allUnsupportedMessages, "; ")) return "", fmt.Errorf("NET_ADMIN capability is missing: %s", strings.Join(allUnsupportedMessages, "; "))
} }
return "", fmt.Errorf("%w: errors encountered are: %s", return "", fmt.Errorf("%w: errors encountered are: %s",
@@ -85,7 +80,7 @@ func testIptablesPath(ctx context.Context, path string,
output, err = runner.Run(cmd) output, err = runner.Run(cmd)
if err != nil { if err != nil {
// this is a critical error, we want to make sure our test rule gets removed. // this is a critical error, we want to make sure our test rule gets removed.
criticalErr = fmt.Errorf("%w: %s (%s)", ErrTestRuleCleanup, output, err) criticalErr = fmt.Errorf("failed cleaning up test rule: %s (%s)", output, err)
return false, "", criticalErr return false, "", criticalErr
} }
@@ -108,7 +103,7 @@ func testIptablesPath(ctx context.Context, path string,
} }
if inputPolicy == "" { if inputPolicy == "" {
criticalErr = fmt.Errorf("%w: in INPUT rules: %s", ErrInputPolicyNotFound, output) criticalErr = fmt.Errorf("input policy not found: in INPUT rules: %s", output)
return false, "", criticalErr return false, "", criticalErr
} }
+6 -12
View File
@@ -7,7 +7,6 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func newAppendTestRuleMatcher(path string) *cmdMatcher { func newAppendTestRuleMatcher(path string) *cmdMatcher {
@@ -43,7 +42,6 @@ func Test_checkIptablesSupport(t *testing.T) {
buildRunner func(ctrl *gomock.Controller) CmdRunner buildRunner func(ctrl *gomock.Controller) CmdRunner
iptablesPathsToTry []string iptablesPathsToTry []string
iptablesPath string iptablesPath string
errSentinel error
errMessage string errMessage string
}{ }{
"critical error when checking": { "critical error when checking": {
@@ -56,7 +54,6 @@ func Test_checkIptablesSupport(t *testing.T) {
return runner return runner
}, },
iptablesPathsToTry: []string{"path1", "path2"}, iptablesPathsToTry: []string{"path1", "path2"},
errSentinel: ErrTestRuleCleanup,
errMessage: "for path1: failed cleaning up test rule: " + errMessage: "for path1: failed cleaning up test rule: " +
"output (exit code 4)", "output (exit code 4)",
}, },
@@ -86,7 +83,6 @@ func Test_checkIptablesSupport(t *testing.T) {
return runner return runner
}, },
iptablesPathsToTry: []string{"path1", "path2"}, iptablesPathsToTry: []string{"path1", "path2"},
errSentinel: ErrNetAdminMissing,
errMessage: "NET_ADMIN capability is missing: " + errMessage: "NET_ADMIN capability is missing: " +
"path1: Permission denied (you must be root) more context (exit code 4); " + "path1: Permission denied (you must be root) more context (exit code 4); " +
"path2: context: Permission denied (you must be root) (exit code 4)", "path2: context: Permission denied (you must be root) (exit code 4)",
@@ -101,7 +97,6 @@ func Test_checkIptablesSupport(t *testing.T) {
return runner return runner
}, },
iptablesPathsToTry: []string{"path1", "path2"}, iptablesPathsToTry: []string{"path1", "path2"},
errSentinel: ErrNotSupported,
errMessage: "no iptables supported found: " + errMessage: "no iptables supported found: " +
"errors encountered are: " + "errors encountered are: " +
"path1: output 1 (exit code 4); " + "path1: output 1 (exit code 4); " +
@@ -118,9 +113,10 @@ func Test_checkIptablesSupport(t *testing.T) {
iptablesPath, err := checkIptablesSupport(ctx, runner, testCase.iptablesPathsToTry...) iptablesPath, err := checkIptablesSupport(ctx, runner, testCase.iptablesPathsToTry...)
require.ErrorIs(t, err, testCase.errSentinel) if testCase.errMessage != "" {
if testCase.errSentinel != nil {
assert.EqualError(t, err, testCase.errMessage) assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
} }
assert.Equal(t, testCase.iptablesPath, iptablesPath) assert.Equal(t, testCase.iptablesPath, iptablesPath)
}) })
@@ -139,7 +135,6 @@ func Test_testIptablesPath(t *testing.T) {
buildRunner func(ctrl *gomock.Controller) CmdRunner buildRunner func(ctrl *gomock.Controller) CmdRunner
ok bool ok bool
unsupportedMessage string unsupportedMessage string
criticalErrWrapped error
criticalErrMessage string criticalErrMessage string
}{ }{
"append test rule permission denied": { "append test rule permission denied": {
@@ -168,7 +163,6 @@ func Test_testIptablesPath(t *testing.T) {
Return("some output", errDummy) Return("some output", errDummy)
return runner return runner
}, },
criticalErrWrapped: ErrTestRuleCleanup,
criticalErrMessage: "failed cleaning up test rule: some output (exit code 4)", criticalErrMessage: "failed cleaning up test rule: some output (exit code 4)",
}, },
"list input rules permission denied": { "list input rules permission denied": {
@@ -202,7 +196,6 @@ func Test_testIptablesPath(t *testing.T) {
Return("some\noutput", nil) Return("some\noutput", nil)
return runner return runner
}, },
criticalErrWrapped: ErrInputPolicyNotFound,
criticalErrMessage: "input policy not found: in INPUT rules: some\noutput", criticalErrMessage: "input policy not found: in INPUT rules: some\noutput",
}, },
"set policy permission denied": { "set policy permission denied": {
@@ -257,9 +250,10 @@ func Test_testIptablesPath(t *testing.T) {
assert.Equal(t, testCase.ok, ok) assert.Equal(t, testCase.ok, ok)
assert.Equal(t, testCase.unsupportedMessage, unsupportedMessage) assert.Equal(t, testCase.unsupportedMessage, unsupportedMessage)
assert.ErrorIs(t, criticalErr, testCase.criticalErrWrapped) if testCase.criticalErrMessage != "" {
if testCase.criticalErrWrapped != nil {
assert.EqualError(t, criticalErr, testCase.criticalErrMessage) assert.EqualError(t, criticalErr, testCase.criticalErrMessage)
} else {
assert.NoError(t, criticalErr)
} }
}) })
} }
+2 -4
View File
@@ -45,12 +45,10 @@ func (f tcpFlag) String() string {
case tcpFlagCWR: case tcpFlagCWR:
return "CWR" return "CWR"
default: default:
panic(fmt.Sprintf("%s: %d", errTCPFlagUnknown, f)) panic(fmt.Sprintf("unknown TCP flag: %d", f))
} }
} }
var errTCPFlagUnknown = errors.New("unknown TCP flag")
func parseTCPFlag(s string) (tcpFlag, error) { func parseTCPFlag(s string) (tcpFlag, error) {
allFlags := []tcpFlag{ allFlags := []tcpFlag{
tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH, tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH,
@@ -61,7 +59,7 @@ func parseTCPFlag(s string) (tcpFlag, error) {
return flag, nil return flag, nil
} }
} }
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s) return 0, fmt.Errorf("unknown TCP flag: %s", s)
} }
var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so") var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so")
+2 -4
View File
@@ -266,8 +266,6 @@ func makeAddressToDial(address string) (addressToDial string, err error) {
return address, nil return address, nil
} }
var ErrAllCheckTriesFailed = errors.New("all check tries failed")
func withRetries(ctx context.Context, tryTimeouts []time.Duration, func withRetries(ctx context.Context, tryTimeouts []time.Duration,
logger Logger, checkName string, check func(ctx context.Context, try int) error, logger Logger, checkName string, check func(ctx context.Context, try int) error,
) error { ) error {
@@ -297,7 +295,7 @@ func withRetries(ctx context.Context, tryTimeouts []time.Duration,
for i, err := range errs { for i, err := range errs {
errStrings[i] = fmt.Sprintf("attempt %d (%dms): %s", i+1, err.durationMS, err.err) errStrings[i] = fmt.Sprintf("attempt %d (%dms): %s", i+1, err.durationMS, err.err)
} }
return fmt.Errorf("%w:\n\t%s", ErrAllCheckTriesFailed, strings.Join(errStrings, "\n\t")) return fmt.Errorf("all check tries failed:\n\t%s", strings.Join(errStrings, "\n\t"))
} }
func (c *Checker) startupCheck(ctx context.Context) error { func (c *Checker) startupCheck(ctx context.Context) error {
@@ -342,7 +340,7 @@ func (c *Checker) startupCheck(ctx context.Context) error {
for i, err := range errs { for i, err := range errs {
errStrings[i] = fmt.Sprintf("parallel attempt %d/%d failed: %s", i+1, len(errs), err) errStrings[i] = fmt.Sprintf("parallel attempt %d/%d failed: %s", i+1, len(errs), err)
} }
return fmt.Errorf("%w: %s", ErrAllCheckTriesFailed, strings.Join(errStrings, ", ")) return fmt.Errorf("all check tries failed: %s", strings.Join(errStrings, ", "))
} }
const ( const (
+4 -5
View File
@@ -2,7 +2,6 @@ package healthcheck
import ( import (
"context" "context"
"fmt"
"net" "net"
"testing" "testing"
"time" "time"
@@ -68,7 +67,7 @@ func Test_makeAddressToDial(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
address string address string
addressToDial string addressToDial string
err error errMessage string
}{ }{
"host without port": { "host without port": {
address: "test.com", address: "test.com",
@@ -80,7 +79,7 @@ func Test_makeAddressToDial(t *testing.T) {
}, },
"bad address": { "bad address": {
address: "test.com::", address: "test.com::",
err: fmt.Errorf("splitting host and port from address: address test.com::: too many colons in address"), //nolint:lll errMessage: "splitting host and port from address: address test.com::: too many colons in address",
}, },
} }
@@ -91,8 +90,8 @@ func Test_makeAddressToDial(t *testing.T) {
addressToDial, err := makeAddressToDial(testCase.address) addressToDial, err := makeAddressToDial(testCase.address)
assert.Equal(t, testCase.addressToDial, addressToDial) assert.Equal(t, testCase.addressToDial, addressToDial)
if testCase.err != nil { if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.err.Error()) assert.EqualError(t, err, testCase.errMessage)
} else { } else {
assert.NoError(t, err) assert.NoError(t, err)
} }
+1 -4
View File
@@ -2,15 +2,12 @@ package healthcheck
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"time" "time"
) )
var ErrHTTPStatusNotOK = errors.New("HTTP response status is not OK")
type Client struct { type Client struct {
httpClient *http.Client httpClient *http.Client
} }
@@ -41,6 +38,6 @@ func (c *Client) Check(ctx context.Context, url string) error {
if err != nil { if err != nil {
return err return err
} }
return fmt.Errorf("%w: %d %s: %s", ErrHTTPStatusNotOK, return fmt.Errorf("HTTP response status is not OK: %d %s: %s",
response.StatusCode, response.Status, string(b)) response.StatusCode, response.Status, string(b))
} }
+1 -4
View File
@@ -2,7 +2,6 @@ package dns
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -41,8 +40,6 @@ func concatAddrPorts(addrs [][]netip.AddrPort) []netip.AddrPort {
return result return result
} }
var ErrLookupNoIPs = errors.New("no IPs found from DNS lookup")
func (c *Client) Check(ctx context.Context) error { func (c *Client) Check(ctx context.Context) error {
dnsAddr := c.serverAddrs[c.dnsIPIndex].String() dnsAddr := c.serverAddrs[c.dnsIPIndex].String()
resolver := &net.Resolver{ resolver := &net.Resolver{
@@ -59,7 +56,7 @@ func (c *Client) Check(ctx context.Context) error {
return fmt.Errorf("with DNS server %s: %w", dnsAddr, err) return fmt.Errorf("with DNS server %s: %w", dnsAddr, err)
case len(ips) == 0: case len(ips) == 0:
c.dnsIPIndex = (c.dnsIPIndex + 1) % len(c.serverAddrs) c.dnsIPIndex = (c.dnsIPIndex + 1) % len(c.serverAddrs)
return fmt.Errorf("with DNS server %s: %w", dnsAddr, ErrLookupNoIPs) return fmt.Errorf("with DNS server %s: no IPs found from DNS lookup", dnsAddr)
default: default:
return nil return nil
} }
+1 -3
View File
@@ -12,11 +12,9 @@ type handler struct {
logger Logger logger Logger
} }
var errHealthcheckNotRunYet = errors.New("healthcheck did not run yet")
func newHandler(logger Logger) *handler { func newHandler(logger Logger) *handler {
return &handler{ return &handler{
healthErr: errHealthcheckNotRunYet, healthErr: errors.New("healthcheck did not run yet"),
logger: logger, logger: logger,
} }
} }
+6 -13
View File
@@ -19,11 +19,6 @@ import (
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
) )
var (
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")
)
type Echoer struct { type Echoer struct {
buffer []byte buffer []byte
randomSource io.Reader randomSource io.Reader
@@ -60,10 +55,7 @@ func (e *Echoer) Reset() {
e.seqStart = time.Now() e.seqStart = time.Now()
} }
var ( var ErrNotPermitted = errors.New("not permitted")
ErrTimedOut = errors.New("timed out waiting for ICMP echo reply")
ErrNotPermitted = errors.New("not permitted")
)
func (e *Echoer) Echo(ctx context.Context, ip netip.Addr) (err error) { func (e *Echoer) Echo(ctx context.Context, ip netip.Addr) (err error) {
var ipVersion string var ipVersion string
@@ -114,14 +106,14 @@ func (e *Echoer) Echo(ctx context.Context, ip netip.Addr) (err error) {
receivedData, err := receiveEchoReply(conn, e.id, e.seq, e.buffer, ipVersion, e.logger) receivedData, err := receiveEchoReply(conn, e.id, e.seq, e.buffer, ipVersion, e.logger)
if err != nil { if err != nil {
if errors.Is(err, net.ErrClosed) && ctx.Err() != nil { if errors.Is(err, net.ErrClosed) && ctx.Err() != nil {
return fmt.Errorf("%w from %s", ErrTimedOut, ip) return fmt.Errorf("timed out waiting for ICMP echo reply from %s", ip)
} }
return fmt.Errorf("receiving ICMP echo reply from %s: %w", ip, err) return fmt.Errorf("receiving ICMP echo reply from %s: %w", ip, err)
} }
sentData := message.Body.(*icmp.Echo).Data //nolint:forcetypeassert sentData := message.Body.(*icmp.Echo).Data //nolint:forcetypeassert
if !bytes.Equal(receivedData, sentData) { if !bytes.Equal(receivedData, sentData) {
return fmt.Errorf("%w: sent %x to %s and received %x", ErrICMPEchoDataMismatch, sentData, ip, receivedData) return fmt.Errorf("ICMP data mismatch: sent %x to %s and received %x", sentData, ip, receivedData)
} }
return nil return nil
@@ -216,8 +208,9 @@ func receiveEchoReply(conn net.PacketConn, id, seq int, buffer []byte, ipVersion
message.Code, returnAddr, id, seq) message.Code, returnAddr, id, seq)
continue continue
default: default:
return nil, fmt.Errorf("%w: %T (type %d, code %d, return address %s, expected id %d and seq %d)", return nil, fmt.Errorf("ICMP body type is not supported: "+
ErrICMPBodyUnsupported, body, message.Type, message.Code, returnAddr, id, seq) "%T (type %d, code %d, return address %s, expected id %d and seq %d)",
body, message.Type, message.Code, returnAddr, id, seq)
} }
} }
} }
+4 -6
View File
@@ -6,7 +6,6 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
//go:generate mockgen -destination=logger_mock_test.go -package $GOPACKAGE . Logger //go:generate mockgen -destination=logger_mock_test.go -package $GOPACKAGE . Logger
@@ -20,11 +19,9 @@ func Test_New(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
settings Settings settings Settings
expected *Server expected *Server
errWrapped error
errMessage string errMessage string
}{ }{
"empty settings": { "empty settings": {
errWrapped: ErrHandlerIsNotSet,
errMessage: "http server settings validation failed: HTTP handler cannot be left unset", errMessage: "http server settings validation failed: HTTP handler cannot be left unset",
}, },
"filled settings": { "filled settings": {
@@ -52,9 +49,10 @@ func Test_New(t *testing.T) {
t.Parallel() t.Parallel()
server, err := New(testCase.settings) server, err := New(testCase.settings)
assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errMessage != "" {
if testCase.errWrapped != nil { assert.EqualError(t, err, testCase.errMessage)
require.EqualError(t, err, testCase.errMessage) } else {
assert.NoError(t, err)
} }
if server != nil { if server != nil {
+5 -19
View File
@@ -64,14 +64,6 @@ func (s *Settings) OverrideWith(other Settings) {
s.ShutdownTimeout = gosettings.OverrideWithComparable(s.ShutdownTimeout, other.ShutdownTimeout) s.ShutdownTimeout = gosettings.OverrideWithComparable(s.ShutdownTimeout, other.ShutdownTimeout)
} }
var (
ErrHandlerIsNotSet = errors.New("HTTP handler cannot be left unset")
ErrLoggerIsNotSet = errors.New("logger cannot be left unset")
ErrReadHeaderTimeoutTooSmall = errors.New("read header timeout is too small")
ErrReadTimeoutTooSmall = errors.New("read timeout is too small")
ErrShutdownTimeoutTooSmall = errors.New("shutdown timeout is too small")
)
func (s Settings) Validate() (err error) { func (s Settings) Validate() (err error) {
err = validate.ListeningAddress(s.Address, os.Getuid()) err = validate.ListeningAddress(s.Address, os.Getuid())
if err != nil { if err != nil {
@@ -79,31 +71,25 @@ func (s Settings) Validate() (err error) {
} }
if s.Handler == nil { if s.Handler == nil {
return fmt.Errorf("%w", ErrHandlerIsNotSet) return errors.New("HTTP handler cannot be left unset")
} }
if s.Logger == nil { if s.Logger == nil {
return fmt.Errorf("%w", ErrLoggerIsNotSet) return errors.New("logger cannot be left unset")
} }
const minReadTimeout = time.Millisecond const minReadTimeout = time.Millisecond
if s.ReadHeaderTimeout < minReadTimeout { if s.ReadHeaderTimeout < minReadTimeout {
return fmt.Errorf("%w: %s must be at least %s", return fmt.Errorf("read header timeout is too small: %s must be at least %s", s.ReadHeaderTimeout, minReadTimeout)
ErrReadHeaderTimeoutTooSmall,
s.ReadHeaderTimeout, minReadTimeout)
} }
if s.ReadTimeout < minReadTimeout { if s.ReadTimeout < minReadTimeout {
return fmt.Errorf("%w: %s must be at least %s", return fmt.Errorf("read timeout is too small: %s must be at least %s", s.ReadTimeout, minReadTimeout)
ErrReadTimeoutTooSmall,
s.ReadTimeout, minReadTimeout)
} }
const minShutdownTimeout = 5 * time.Millisecond const minShutdownTimeout = 5 * time.Millisecond
if s.ShutdownTimeout < minShutdownTimeout { if s.ShutdownTimeout < minShutdownTimeout {
return fmt.Errorf("%w: %s must be at least %s", return fmt.Errorf("shutdown timeout is too small: %s must be at least %s", s.ShutdownTimeout, minShutdownTimeout)
ErrShutdownTimeoutTooSmall,
s.ShutdownTimeout, minShutdownTimeout)
} }
return nil return nil
+4 -11
View File
@@ -5,7 +5,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/qdm12/gosettings/validate"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -189,30 +188,26 @@ func Test_Settings_Validate(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
settings Settings settings Settings
errWrapped error
errMessage string errMessage string
}{ }{
"bad_address": { "bad_address": {
settings: Settings{ settings: Settings{
Address: "address:notanint", Address: "address:notanint",
}, },
errWrapped: validate.ErrPortNotAnInteger,
errMessage: "port value is not an integer: notanint", errMessage: "port value is not an integer: notanint",
}, },
"nil handler": { "nil handler": {
settings: Settings{ settings: Settings{
Address: ":8000", Address: ":8000",
}, },
errWrapped: ErrHandlerIsNotSet, errMessage: "HTTP handler cannot be left unset",
errMessage: ErrHandlerIsNotSet.Error(),
}, },
"nil logger": { "nil logger": {
settings: Settings{ settings: Settings{
Address: ":8000", Address: ":8000",
Handler: someHandler, Handler: someHandler,
}, },
errWrapped: ErrLoggerIsNotSet, errMessage: "logger cannot be left unset",
errMessage: ErrLoggerIsNotSet.Error(),
}, },
"read header timeout too small": { "read header timeout too small": {
settings: Settings{ settings: Settings{
@@ -221,7 +216,6 @@ func Test_Settings_Validate(t *testing.T) {
Logger: someLogger, Logger: someLogger,
ReadHeaderTimeout: time.Nanosecond, ReadHeaderTimeout: time.Nanosecond,
}, },
errWrapped: ErrReadHeaderTimeoutTooSmall,
errMessage: "read header timeout is too small: 1ns must be at least 1ms", errMessage: "read header timeout is too small: 1ns must be at least 1ms",
}, },
"read timeout too small": { "read timeout too small": {
@@ -232,7 +226,6 @@ func Test_Settings_Validate(t *testing.T) {
ReadHeaderTimeout: time.Millisecond, ReadHeaderTimeout: time.Millisecond,
ReadTimeout: time.Nanosecond, ReadTimeout: time.Nanosecond,
}, },
errWrapped: ErrReadTimeoutTooSmall,
errMessage: "read timeout is too small: 1ns must be at least 1ms", errMessage: "read timeout is too small: 1ns must be at least 1ms",
}, },
"shutdown timeout too small": { "shutdown timeout too small": {
@@ -244,7 +237,6 @@ func Test_Settings_Validate(t *testing.T) {
ReadTimeout: time.Millisecond, ReadTimeout: time.Millisecond,
ShutdownTimeout: time.Millisecond, ShutdownTimeout: time.Millisecond,
}, },
errWrapped: ErrShutdownTimeoutTooSmall,
errMessage: "shutdown timeout is too small: 1ms must be at least 5ms", errMessage: "shutdown timeout is too small: 1ms must be at least 5ms",
}, },
"valid settings": { "valid settings": {
@@ -265,9 +257,10 @@ func Test_Settings_Validate(t *testing.T) {
err := testCase.settings.Validate() err := testCase.settings.Validate()
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errMessage != "" { if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage) assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
} }
}) })
} }
+2 -5
View File
@@ -2,15 +2,12 @@ package loopstate
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
) )
var ErrInvalidStatus = errors.New("invalid status")
// ApplyStatus sends signals to the running loop depending on the // ApplyStatus sends signals to the running loop depending on the
// current status and status requested, such that its next status // current status and status requested, such that its next status
// matches the requested one. It is thread safe and a synchronous call // matches the requested one. It is thread safe and a synchronous call
@@ -73,7 +70,7 @@ func (s *State) ApplyStatus(ctx context.Context, status models.LoopStatus) (
return newStatus.String(), nil return newStatus.String(), nil
default: default:
s.statusMu.Unlock() s.statusMu.Unlock()
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s", return "", fmt.Errorf("invalid status: %s: it can only be one of: %s, %s",
ErrInvalidStatus, status, constants.Running, constants.Stopped) status, constants.Running, constants.Stopped)
} }
} }
+4 -12
View File
@@ -3,19 +3,11 @@ package mod
import ( import (
"bufio" "bufio"
"compress/gzip" "compress/gzip"
"errors"
"fmt" "fmt"
"os" "os"
"strings" "strings"
) )
var (
errModuleNameUnknown = errors.New("unknown module name")
errKernelFeatureIsModule = errors.New("kernel feature is a module, not built-in")
errKernelFeatureNotSet = errors.New("kernel feature not set")
errKernelFeatureNotFound = errors.New("kernel feature not found")
)
// checkProcConfig checks /proc/config.gz for a the kernel feature corresponding // checkProcConfig checks /proc/config.gz for a the kernel feature corresponding
// to the given module name. If the kernel feature is found and set to "y", it returns nil. // to the given module name. If the kernel feature is found and set to "y", it returns nil.
// If the kernel feature is found and set to "m", it returns an error indicating that the kernel // If the kernel feature is found and set to "m", it returns an error indicating that the kernel
@@ -39,7 +31,7 @@ func checkProcConfig(moduleName string) error {
// If any group of kernel features is satisfied, then the module is considered supported. // If any group of kernel features is satisfied, then the module is considered supported.
kernelFeatureGroups, ok := moduleNameToKernelFeatureGroups(moduleName) kernelFeatureGroups, ok := moduleNameToKernelFeatureGroups(moduleName)
if !ok { if !ok {
return fmt.Errorf("%w: %s", errModuleNameUnknown, moduleName) return fmt.Errorf("unknown module name: %s", moduleName)
} }
groups := make([]map[string]bool, len(kernelFeatureGroups)) groups := make([]map[string]bool, len(kernelFeatureGroups))
for i, group := range kernelFeatureGroups { for i, group := range kernelFeatureGroups {
@@ -58,20 +50,20 @@ func checkProcConfig(moduleName string) error {
switch { switch {
case ok: case ok:
case strings.HasPrefix(line, name+"=m"): case strings.HasPrefix(line, name+"=m"):
return fmt.Errorf("%w: %s", errKernelFeatureIsModule, name) return fmt.Errorf("kernel feature is a module, not built-in: %s", name)
case strings.HasPrefix(line, name+"=y"): case strings.HasPrefix(line, name+"=y"):
featureToOK[name] = true featureToOK[name] = true
if allFeaturesOK(featureToOK) { if allFeaturesOK(featureToOK) {
return nil return nil
} }
case strings.HasPrefix(line, "# "+name+" is not set"): case strings.HasPrefix(line, "# "+name+" is not set"):
return fmt.Errorf("%w: %s", errKernelFeatureNotSet, name) return fmt.Errorf("kernel feature not set: %s", name)
} }
} }
} }
} }
return fmt.Errorf("%w: for module name %s", errKernelFeatureNotFound, moduleName) return fmt.Errorf("kernel feature not found: for module name %s", moduleName)
} }
func moduleNameToKernelFeatureGroups(moduleName string) (featureGroups [][]string, ok bool) { func moduleNameToKernelFeatureGroups(moduleName string) (featureGroups [][]string, ok bool) {
+1 -3
View File
@@ -181,8 +181,6 @@ func getLoadedModules(modulesInfo map[string]moduleInfo) (err error) {
return nil return nil
} }
var ErrModulePathNotFound = errors.New("module path not found")
func findModulePath(moduleName string, modulesInfo map[string]moduleInfo) (modulePath string, err error) { func findModulePath(moduleName string, modulesInfo map[string]moduleInfo) (modulePath string, err error) {
// Kernel module names can have underscores or hyphens in their names, // Kernel module names can have underscores or hyphens in their names,
// but only one or the other in one particular name. // but only one or the other in one particular name.
@@ -205,5 +203,5 @@ func findModulePath(moduleName string, modulesInfo map[string]moduleInfo) (modul
} }
} }
return "", fmt.Errorf("%w: for %q", ErrModulePathNotFound, moduleName) return "", fmt.Errorf("module path not found: for %q", moduleName)
} }
+2 -9
View File
@@ -1,7 +1,6 @@
package mod package mod
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
@@ -14,15 +13,10 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
var (
ErrModuleInfoNotFound = errors.New("module info not found")
ErrCircularDependency = errors.New("circular dependency")
)
func initDependencies(path string, modulesInfo map[string]moduleInfo) (err error) { func initDependencies(path string, modulesInfo map[string]moduleInfo) (err error) {
info, ok := modulesInfo[path] info, ok := modulesInfo[path]
if !ok { if !ok {
return fmt.Errorf("%w: %s", ErrModuleInfoNotFound, path) return fmt.Errorf("module info not found: %s", path)
} }
switch info.state { switch info.state {
@@ -30,8 +24,7 @@ func initDependencies(path string, modulesInfo map[string]moduleInfo) (err error
case loaded, builtin: case loaded, builtin:
return nil return nil
case loading: case loading:
return fmt.Errorf("%w: %s is already in the loading state", return fmt.Errorf("circular dependency: %s is already in the loading state", path)
ErrCircularDependency, path)
} }
info.state = loading info.state = loading
+1 -4
View File
@@ -1,7 +1,6 @@
package models package models
import ( import (
"errors"
"fmt" "fmt"
"strings" "strings"
@@ -109,8 +108,6 @@ func (s *Servers) toMarkdown(vpnProvider string) (formatted string, err error) {
return formatted, nil return formatted, nil
} }
var ErrMarkdownHeadersNotDefined = errors.New("markdown headers not defined")
func getMarkdownHeaders(vpnProvider string) (headers []string, err error) { func getMarkdownHeaders(vpnProvider string) (headers []string, err error) {
switch vpnProvider { switch vpnProvider {
case providers.Airvpn: case providers.Airvpn:
@@ -169,6 +166,6 @@ func getMarkdownHeaders(vpnProvider string) (headers []string, err error) {
case providers.Windscribe: case providers.Windscribe:
return []string{regionHeader, cityHeader, hostnameHeader, vpnHeader}, nil return []string{regionHeader, cityHeader, hostnameHeader, vpnHeader}, nil
default: default:
return nil, fmt.Errorf("%w: for %s", ErrMarkdownHeadersNotDefined, vpnProvider) return nil, fmt.Errorf("markdown headers not defined: for %s", vpnProvider)
} }
} }
+3 -4
View File
@@ -15,12 +15,10 @@ func Test_Servers_ToMarkdown(t *testing.T) {
provider string provider string
servers Servers servers Servers
formatted string formatted string
errWrapped error
errMessage string errMessage string
}{ }{
"unsupported_provider": { "unsupported_provider": {
provider: "unsupported", provider: "unsupported",
errWrapped: ErrMarkdownHeadersNotDefined,
errMessage: "getting markdown headers: markdown headers not defined: for unsupported", errMessage: "getting markdown headers: markdown headers not defined: for unsupported",
}, },
providers.Cyberghost: { providers.Cyberghost: {
@@ -58,9 +56,10 @@ func Test_Servers_ToMarkdown(t *testing.T) {
markdown, err := testCase.servers.toMarkdown(testCase.provider) markdown, err := testCase.servers.toMarkdown(testCase.provider)
assert.Equal(t, testCase.formatted, markdown) assert.Equal(t, testCase.formatted, markdown)
assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errMessage != "" {
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage) assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
} }
}) })
} }
+5 -14
View File
@@ -38,27 +38,18 @@ type Server struct {
IPs []netip.Addr `json:"ips,omitempty"` IPs []netip.Addr `json:"ips,omitempty"`
} }
var (
ErrVPNFieldEmpty = errors.New("vpn field is empty")
ErrHostnameFieldEmpty = errors.New("hostname field is empty")
ErrIPsFieldEmpty = errors.New("ips field is empty")
ErrNoNetworkProtocol = errors.New("both TCP and UDP fields are false for OpenVPN")
ErrNetworkProtocolSet = errors.New("no network protocol should be set")
ErrWireguardPublicKeyEmpty = errors.New("wireguard public key field is empty")
)
func (s *Server) HasMinimumInformation() (err error) { func (s *Server) HasMinimumInformation() (err error) {
switch { switch {
case s.VPN == "": case s.VPN == "":
return fmt.Errorf("%w", ErrVPNFieldEmpty) return errors.New("vpn field is empty")
case len(s.IPs) == 0: case len(s.IPs) == 0:
return fmt.Errorf("%w", ErrIPsFieldEmpty) return errors.New("ips field is empty")
case s.VPN == vpn.Wireguard && (s.TCP || s.UDP): case s.VPN == vpn.Wireguard && (s.TCP || s.UDP):
return fmt.Errorf("%w", ErrNetworkProtocolSet) return errors.New("no network protocol should be set")
case s.VPN == vpn.OpenVPN && !s.TCP && !s.UDP: case s.VPN == vpn.OpenVPN && !s.TCP && !s.UDP:
return fmt.Errorf("%w", ErrNoNetworkProtocol) return errors.New("both TCP and UDP fields are false for OpenVPN")
case s.VPN == vpn.Wireguard && s.WgPubKey == "": case s.VPN == vpn.Wireguard && s.WgPubKey == "":
return fmt.Errorf("%w", ErrWireguardPublicKeyEmpty) return errors.New("wireguard public key field is empty")
default: default:
return nil return nil
} }
+1 -4
View File
@@ -3,7 +3,6 @@ package models
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"math" "math"
"reflect" "reflect"
@@ -158,8 +157,6 @@ type Servers struct {
Servers []Server `json:"servers,omitempty"` Servers []Server `json:"servers,omitempty"`
} }
var ErrServersFormatNotSupported = errors.New("servers format not supported")
func (s *Servers) Format(vpnProvider, format string) (formatted string, err error) { func (s *Servers) Format(vpnProvider, format string) (formatted string, err error) {
switch format { switch format {
case "markdown": case "markdown":
@@ -167,7 +164,7 @@ func (s *Servers) Format(vpnProvider, format string) (formatted string, err erro
case "json": case "json":
return s.toJSON() return s.toJSON()
default: default:
return "", fmt.Errorf("%w: %s", ErrServersFormatNotSupported, format) return "", fmt.Errorf("servers format not supported: %s", format)
} }
} }
+9 -8
View File
@@ -16,7 +16,6 @@ func Test_AllServers_MarshalJSON(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
allServers *AllServers allServers *AllServers
dataString string dataString string
errWrapped error
errMessage string errMessage string
}{ }{
"no provider": { "no provider": {
@@ -58,16 +57,18 @@ func Test_AllServers_MarshalJSON(t *testing.T) {
t.Parallel() t.Parallel()
data, err := testCase.allServers.MarshalJSON() data, err := testCase.allServers.MarshalJSON()
assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errMessage != "" {
if err != nil {
assert.EqualError(t, err, testCase.errMessage) assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
} }
require.Equal(t, testCase.dataString, string(data)) require.Equal(t, testCase.dataString, string(data))
data, err = json.Marshal(testCase.allServers) data, err = json.Marshal(testCase.allServers)
assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errMessage != "" {
if err != nil {
assert.EqualError(t, err, testCase.errMessage) assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
} }
require.Equal(t, testCase.dataString, string(data)) require.Equal(t, testCase.dataString, string(data))
@@ -87,7 +88,6 @@ func Test_AllServers_UnmarshalJSON(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
dataString string dataString string
allServers AllServers allServers AllServers
errWrapped error
errMessage string errMessage string
}{ }{
"empty": { "empty": {
@@ -131,9 +131,10 @@ func Test_AllServers_UnmarshalJSON(t *testing.T) {
err := json.Unmarshal(data, &allServers) err := json.Unmarshal(data, &allServers)
assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errMessage != "" {
if err != nil {
assert.EqualError(t, err, testCase.errMessage) assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
} }
assert.Equal(t, testCase.allServers, allServers) assert.Equal(t, testCase.allServers, allServers)
}) })
+16 -33
View File
@@ -6,48 +6,40 @@ import (
"fmt" "fmt"
) )
var ErrRequestSizeTooSmall = errors.New("message size is too small")
func checkRequest(request []byte) (err error) { func checkRequest(request []byte) (err error) {
const minMessageSize = 2 // version number + operation code const minMessageSize = 2 // version number + operation code
if len(request) < minMessageSize { if len(request) < minMessageSize {
return fmt.Errorf("%w: need at least %d bytes and got %d byte(s)", return fmt.Errorf("message size is too small: need at least %d bytes and got %d byte(s)",
ErrRequestSizeTooSmall, minMessageSize, len(request)) minMessageSize, len(request))
} }
return nil return nil
} }
var (
ErrResponseSizeTooSmall = errors.New("response size is too small")
ErrResponseSizeUnexpected = errors.New("response size is unexpected")
ErrProtocolVersionUnknown = errors.New("protocol version is unknown")
ErrOperationCodeUnexpected = errors.New("operation code is unexpected")
)
func checkResponse(response []byte, expectedOperationCode byte, func checkResponse(response []byte, expectedOperationCode byte,
expectedResponseSize uint, expectedResponseSize uint,
) (err error) { ) (err error) {
const minResponseSize = 4 const minResponseSize = 4
if len(response) < minResponseSize { if len(response) < minResponseSize {
return fmt.Errorf("%w: need at least %d bytes and got %d byte(s)", return fmt.Errorf("response size is too small: "+
ErrResponseSizeTooSmall, minResponseSize, len(response)) "need at least %d bytes and got %d byte(s)",
minResponseSize, len(response))
} }
if uint(len(response)) != expectedResponseSize { if uint(len(response)) != expectedResponseSize {
return fmt.Errorf("%w: expected %d bytes and got %d byte(s)", return fmt.Errorf("response size is unexpected: "+
ErrResponseSizeUnexpected, expectedResponseSize, len(response)) "expected %d bytes and got %d byte(s)",
expectedResponseSize, len(response))
} }
protocolVersion := response[0] protocolVersion := response[0]
if protocolVersion != 0 { if protocolVersion != 0 {
return fmt.Errorf("%w: %d", ErrProtocolVersionUnknown, protocolVersion) return fmt.Errorf("protocol version is unknown: %d", protocolVersion)
} }
operationCode := response[1] operationCode := response[1]
if operationCode != expectedOperationCode { if operationCode != expectedOperationCode {
return fmt.Errorf("%w: expected 0x%x and got 0x%x", return fmt.Errorf("operation code is unexpected: expected 0x%x and got 0x%x", expectedOperationCode, operationCode)
ErrOperationCodeUnexpected, expectedOperationCode, operationCode)
} }
resultCode := binary.BigEndian.Uint16(response[2:4]) resultCode := binary.BigEndian.Uint16(response[2:4])
@@ -59,15 +51,6 @@ func checkResponse(response []byte, expectedOperationCode byte,
return nil return nil
} }
var (
ErrVersionNotSupported = errors.New("version is not supported")
ErrNotAuthorized = errors.New("not authorized")
ErrNetworkFailure = errors.New("network failure")
ErrOutOfResources = errors.New("out of resources")
ErrOperationCodeNotSupported = errors.New("operation code is not supported")
ErrResultCodeUnknown = errors.New("result code is unknown")
)
// checkResultCode checks the result code and returns an error // checkResultCode checks the result code and returns an error
// if the result code is not a success (0). // if the result code is not a success (0).
// See https://www.ietf.org/rfc/rfc6886.html#section-3.5 // See https://www.ietf.org/rfc/rfc6886.html#section-3.5
@@ -78,16 +61,16 @@ func checkResultCode(resultCode uint16) (err error) {
case 0: case 0:
return nil return nil
case 1: case 1:
return fmt.Errorf("%w", ErrVersionNotSupported) return errors.New("version is not supported")
case 2: case 2:
return fmt.Errorf("%w", ErrNotAuthorized) return errors.New("not authorized")
case 3: case 3:
return fmt.Errorf("%w", ErrNetworkFailure) return errors.New("network failure")
case 4: case 4:
return fmt.Errorf("%w", ErrOutOfResources) return errors.New("out of resources")
case 5: case 5:
return fmt.Errorf("%w", ErrOperationCodeNotSupported) return errors.New("operation code is not supported")
default: default:
return fmt.Errorf("%w: %d", ErrResultCodeUnknown, resultCode) return fmt.Errorf("result code is unknown: %d", resultCode)
} }
} }
+21 -17
View File
@@ -1,6 +1,7 @@
package natpmp package natpmp
import ( import (
"errors"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -11,12 +12,10 @@ func Test_checkRequest(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
request []byte request []byte
err error
errMessage string errMessage string
}{ }{
"too_short": { "too_short": {
request: []byte{1}, request: []byte{1},
err: ErrRequestSizeTooSmall,
errMessage: "message size is too small: need at least 2 bytes and got 1 byte(s)", errMessage: "message size is too small: need at least 2 bytes and got 1 byte(s)",
}, },
"success": { "success": {
@@ -30,9 +29,10 @@ func Test_checkRequest(t *testing.T) {
err := checkRequest(testCase.request) err := checkRequest(testCase.request)
assert.ErrorIs(t, err, testCase.err) if testCase.errMessage != "" {
if testCase.err != nil {
assert.EqualError(t, err, testCase.errMessage) assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
} }
}) })
} }
@@ -50,33 +50,33 @@ func Test_checkResponse(t *testing.T) {
}{ }{
"too_short": { "too_short": {
response: []byte{1}, response: []byte{1},
err: ErrResponseSizeTooSmall, err: errors.New("response size is too small"),
errMessage: "response size is too small: need at least 4 bytes and got 1 byte(s)", errMessage: "response size is too small: need at least 4 bytes and got 1 byte(s)",
}, },
"size_mismatch": { "size_mismatch": {
response: []byte{0, 0, 0, 0}, response: []byte{0, 0, 0, 0},
expectedResponseSize: 5, expectedResponseSize: 5,
err: ErrResponseSizeUnexpected, err: errors.New("response size is unexpected"),
errMessage: "response size is unexpected: expected 5 bytes and got 4 byte(s)", errMessage: "response size is unexpected: expected 5 bytes and got 4 byte(s)",
}, },
"protocol_unknown": { "protocol_unknown": {
response: []byte{1, 0, 0, 0}, response: []byte{1, 0, 0, 0},
expectedResponseSize: 4, expectedResponseSize: 4,
err: ErrProtocolVersionUnknown, err: errors.New("protocol version is unknown"),
errMessage: "protocol version is unknown: 1", errMessage: "protocol version is unknown: 1",
}, },
"operation_code_unexpected": { "operation_code_unexpected": {
response: []byte{0, 2, 0, 0}, response: []byte{0, 2, 0, 0},
expectedOperationCode: 1, expectedOperationCode: 1,
expectedResponseSize: 4, expectedResponseSize: 4,
err: ErrOperationCodeUnexpected, err: errors.New("operation code is unexpected"),
errMessage: "operation code is unexpected: expected 0x1 and got 0x2", errMessage: "operation code is unexpected: expected 0x1 and got 0x2",
}, },
"result_code_failure": { "result_code_failure": {
response: []byte{0, 1, 0, 1}, response: []byte{0, 1, 0, 1},
expectedOperationCode: 1, expectedOperationCode: 1,
expectedResponseSize: 4, expectedResponseSize: 4,
err: ErrVersionNotSupported, err: errors.New("version is not supported"),
errMessage: "result code: version is not supported", errMessage: "result code: version is not supported",
}, },
"success": { "success": {
@@ -94,9 +94,11 @@ func Test_checkResponse(t *testing.T) {
testCase.expectedOperationCode, testCase.expectedOperationCode,
testCase.expectedResponseSize) testCase.expectedResponseSize)
assert.ErrorIs(t, err, testCase.err)
if testCase.err != nil { if testCase.err != nil {
assert.ErrorContains(t, err, testCase.err.Error())
assert.EqualError(t, err, testCase.errMessage) assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
} }
}) })
} }
@@ -113,32 +115,32 @@ func Test_checkResultCode(t *testing.T) {
"success": {}, "success": {},
"version_unsupported": { "version_unsupported": {
resultCode: 1, resultCode: 1,
err: ErrVersionNotSupported, err: errors.New("version is not supported"),
errMessage: "version is not supported", errMessage: "version is not supported",
}, },
"not_authorized": { "not_authorized": {
resultCode: 2, resultCode: 2,
err: ErrNotAuthorized, err: errors.New("not authorized"),
errMessage: "not authorized", errMessage: "not authorized",
}, },
"network_failure": { "network_failure": {
resultCode: 3, resultCode: 3,
err: ErrNetworkFailure, err: errors.New("network failure"),
errMessage: "network failure", errMessage: "network failure",
}, },
"out_of_resources": { "out_of_resources": {
resultCode: 4, resultCode: 4,
err: ErrOutOfResources, err: errors.New("out of resources"),
errMessage: "out of resources", errMessage: "out of resources",
}, },
"unsupported_operation_code": { "unsupported_operation_code": {
resultCode: 5, resultCode: 5,
err: ErrOperationCodeNotSupported, err: errors.New("operation code is not supported"),
errMessage: "operation code is not supported", errMessage: "operation code is not supported",
}, },
"unknown": { "unknown": {
resultCode: 6, resultCode: 6,
err: ErrResultCodeUnknown, err: errors.New("result code is unknown"),
errMessage: "result code is unknown: 6", errMessage: "result code is unknown: 6",
}, },
} }
@@ -149,9 +151,11 @@ func Test_checkResultCode(t *testing.T) {
err := checkResultCode(testCase.resultCode) err := checkResultCode(testCase.resultCode)
assert.ErrorIs(t, err, testCase.err)
if testCase.err != nil { if testCase.err != nil {
assert.ErrorContains(t, err, testCase.err.Error())
assert.EqualError(t, err, testCase.errMessage) assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
} }
}) })
} }
+4 -9
View File
@@ -3,17 +3,11 @@ package natpmp
import ( import (
"context" "context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"net/netip" "net/netip"
"time" "time"
) )
var (
ErrNetworkProtocolUnknown = errors.New("network protocol is unknown")
ErrLifetimeTooLong = errors.New("lifetime is too long")
)
// Add or delete a port mapping. To delete a mapping, set both the // Add or delete a port mapping. To delete a mapping, set both the
// requestedExternalPort and lifetime to 0. // requestedExternalPort and lifetime to 0.
// See https://www.ietf.org/rfc/rfc6886.html#section-3.3 // See https://www.ietf.org/rfc/rfc6886.html#section-3.3
@@ -26,8 +20,9 @@ func (c *Client) AddPortMapping(ctx context.Context, gateway netip.Addr,
lifetimeSecondsFloat := lifetime.Seconds() lifetimeSecondsFloat := lifetime.Seconds()
const maxLifetimeSeconds = uint64(^uint32(0)) const maxLifetimeSeconds = uint64(^uint32(0))
if uint64(lifetimeSecondsFloat) > maxLifetimeSeconds { if uint64(lifetimeSecondsFloat) > maxLifetimeSeconds {
return 0, 0, 0, 0, fmt.Errorf("%w: %d seconds must at most %d seconds", return 0, 0, 0, 0, fmt.Errorf("lifetime is too long: "+
ErrLifetimeTooLong, uint64(lifetimeSecondsFloat), maxLifetimeSeconds) "%d seconds must at most %d seconds",
uint64(lifetimeSecondsFloat), maxLifetimeSeconds)
} }
const messageSize = 12 const messageSize = 12
message := make([]byte, messageSize) message := make([]byte, messageSize)
@@ -38,7 +33,7 @@ func (c *Client) AddPortMapping(ctx context.Context, gateway netip.Addr,
case "tcp": case "tcp":
message[1] = 2 // operationCode 2 message[1] = 2 // operationCode 2
default: default:
return 0, 0, 0, 0, fmt.Errorf("%w: %s", ErrNetworkProtocolUnknown, protocol) return 0, 0, 0, 0, fmt.Errorf("network protocol is unknown: %s", protocol)
} }
// [2:3] are reserved. // [2:3] are reserved.
binary.BigEndian.PutUint16(message[4:6], internalPort) binary.BigEndian.PutUint16(message[4:6], internalPort)
-7
View File
@@ -25,18 +25,15 @@ func Test_Client_AddPortMapping(t *testing.T) {
assignedInternalPort uint16 assignedInternalPort uint16
assignedExternalPort uint16 assignedExternalPort uint16
assignedLifetime time.Duration assignedLifetime time.Duration
err error
errMessage string errMessage string
}{ }{
"lifetime_too_long": { "lifetime_too_long": {
lifetime: time.Duration(uint64(^uint32(0))+1) * time.Second, lifetime: time.Duration(uint64(^uint32(0))+1) * time.Second,
err: ErrLifetimeTooLong,
errMessage: "lifetime is too long: 4294967296 seconds must at most 4294967295 seconds", errMessage: "lifetime is too long: 4294967296 seconds must at most 4294967295 seconds",
}, },
"protocol_unknown": { "protocol_unknown": {
lifetime: time.Second, lifetime: time.Second,
protocol: "xyz", protocol: "xyz",
err: ErrNetworkProtocolUnknown,
errMessage: "network protocol is unknown: xyz", errMessage: "network protocol is unknown: xyz",
}, },
"rpc_error": { "rpc_error": {
@@ -48,7 +45,6 @@ func Test_Client_AddPortMapping(t *testing.T) {
lifetime: 1200 * time.Second, lifetime: 1200 * time.Second,
initialConnectionDuration: time.Millisecond, initialConnectionDuration: time.Millisecond,
exchanges: []udpExchange{{close: true}}, exchanges: []udpExchange{{close: true}},
err: ErrConnectionTimeout,
errMessage: "executing remote procedure call: connection timeout: failed attempts: " + errMessage: "executing remote procedure call: connection timeout: failed attempts: " +
"read udp 127.0.0.1:[1-9][0-9]{0,4}->127.0.0.1:[1-9][0-9]{0,4}: i/o timeout \\(try 1\\)", "read udp 127.0.0.1:[1-9][0-9]{0,4}->127.0.0.1:[1-9][0-9]{0,4}: i/o timeout \\(try 1\\)",
}, },
@@ -136,9 +132,6 @@ func Test_Client_AddPortMapping(t *testing.T) {
assert.Equal(t, testCase.assignedExternalPort, assignedExternalPort) assert.Equal(t, testCase.assignedExternalPort, assignedExternalPort)
assert.Equal(t, testCase.assignedLifetime, assignedLifetime) assert.Equal(t, testCase.assignedLifetime, assignedLifetime)
if testCase.errMessage != "" { if testCase.errMessage != "" {
if testCase.err != nil {
assert.ErrorIs(t, err, testCase.err)
}
assert.Regexp(t, "^"+testCase.errMessage+"$", err.Error()) assert.Regexp(t, "^"+testCase.errMessage+"$", err.Error())
} else { } else {
assert.NoError(t, err) assert.NoError(t, err)
+2 -8
View File
@@ -11,17 +11,12 @@ import (
"time" "time"
) )
var (
ErrGatewayIPUnspecified = errors.New("gateway IP is unspecified")
ErrConnectionTimeout = errors.New("connection timeout")
)
func (c *Client) rpc(ctx context.Context, gateway netip.Addr, func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
request []byte, responseSize uint) ( request []byte, responseSize uint) (
response []byte, err error, response []byte, err error,
) { ) {
if gateway.IsUnspecified() || !gateway.IsValid() { if gateway.IsUnspecified() || !gateway.IsValid() {
return nil, fmt.Errorf("%w", ErrGatewayIPUnspecified) return nil, errors.New("gateway IP is unspecified")
} }
err = checkRequest(request) err = checkRequest(request)
@@ -114,8 +109,7 @@ func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
} }
if retryCount == c.maxRetries { if retryCount == c.maxRetries {
return nil, fmt.Errorf("%w: failed attempts: %s", return nil, fmt.Errorf("connection timeout: failed attempts: %s", dedupFailedAttempts(failedAttempts))
ErrConnectionTimeout, dedupFailedAttempts(failedAttempts))
} }
// Opcodes between 0 and 127 are client requests. Opcodes from 128 to // Opcodes between 0 and 127 are client requests. Opcodes from 128 to
-12
View File
@@ -20,20 +20,17 @@ func Test_Client_rpc(t *testing.T) {
initialConnectionDuration time.Duration initialConnectionDuration time.Duration
exchanges []udpExchange exchanges []udpExchange
expectedResponse []byte expectedResponse []byte
err error
errMessage string errMessage string
}{ }{
"gateway_ip_unspecified": { "gateway_ip_unspecified": {
gateway: netip.IPv6Unspecified(), gateway: netip.IPv6Unspecified(),
request: []byte{0, 0}, request: []byte{0, 0},
err: ErrGatewayIPUnspecified,
errMessage: "gateway IP is unspecified", errMessage: "gateway IP is unspecified",
}, },
"request_too_small": { "request_too_small": {
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}), gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
request: []byte{0}, request: []byte{0},
initialConnectionDuration: time.Nanosecond, // doesn't matter initialConnectionDuration: time.Nanosecond, // doesn't matter
err: ErrRequestSizeTooSmall,
errMessage: `checking request: message size is too small: ` + errMessage: `checking request: message size is too small: ` +
`need at least 2 bytes and got 1 byte\(s\)`, `need at least 2 bytes and got 1 byte\(s\)`,
}, },
@@ -53,7 +50,6 @@ func Test_Client_rpc(t *testing.T) {
exchanges: []udpExchange{ exchanges: []udpExchange{
{request: []byte{0, 1}, close: true}, {request: []byte{0, 1}, close: true},
}, },
err: ErrConnectionTimeout,
errMessage: "connection timeout: failed attempts: " + errMessage: "connection timeout: failed attempts: " +
"read udp 127.0.0.1:[1-9][0-9]{0,4}->127.0.0.1:[1-9][0-9]{0,4}: i/o timeout \\(try 1\\)", "read udp 127.0.0.1:[1-9][0-9]{0,4}->127.0.0.1:[1-9][0-9]{0,4}: i/o timeout \\(try 1\\)",
}, },
@@ -66,7 +62,6 @@ func Test_Client_rpc(t *testing.T) {
request: []byte{0, 0}, request: []byte{0, 0},
response: []byte{1}, response: []byte{1},
}}, }},
err: ErrResponseSizeTooSmall,
errMessage: `checking response: response size is too small: ` + errMessage: `checking response: response size is too small: ` +
`need at least 4 bytes and got 1 byte\(s\)`, `need at least 4 bytes and got 1 byte\(s\)`,
}, },
@@ -80,7 +75,6 @@ func Test_Client_rpc(t *testing.T) {
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0}, request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
response: []byte{0, 1, 2, 3}, // size 4 response: []byte{0, 1, 2, 3}, // size 4
}}, }},
err: ErrResponseSizeUnexpected,
errMessage: `checking response: response size is unexpected: ` + errMessage: `checking response: response size is unexpected: ` +
`expected 5 bytes and got 4 byte\(s\)`, `expected 5 bytes and got 4 byte\(s\)`,
}, },
@@ -94,7 +88,6 @@ func Test_Client_rpc(t *testing.T) {
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0}, request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
response: []byte{0x1, 0x82, 0x0, 0x0, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, response: []byte{0x1, 0x82, 0x0, 0x0, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
}}, }},
err: ErrProtocolVersionUnknown,
errMessage: "checking response: protocol version is unknown: 1", errMessage: "checking response: protocol version is unknown: 1",
}, },
"unexpected_operation_code": { "unexpected_operation_code": {
@@ -107,7 +100,6 @@ func Test_Client_rpc(t *testing.T) {
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0}, request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
response: []byte{0x0, 0x88, 0x0, 0x0, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, response: []byte{0x0, 0x88, 0x0, 0x0, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
}}, }},
err: ErrOperationCodeUnexpected,
errMessage: "checking response: operation code is unexpected: expected 0x82 and got 0x88", errMessage: "checking response: operation code is unexpected: expected 0x82 and got 0x88",
}, },
"failure_result_code": { "failure_result_code": {
@@ -120,7 +112,6 @@ func Test_Client_rpc(t *testing.T) {
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0}, request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
response: []byte{0x0, 0x82, 0x0, 0x11, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, response: []byte{0x0, 0x82, 0x0, 0x11, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
}}, }},
err: ErrResultCodeUnknown,
errMessage: "checking response: result code: result code is unknown: 17", errMessage: "checking response: result code: result code is unknown: 17",
}, },
"success": { "success": {
@@ -153,9 +144,6 @@ func Test_Client_rpc(t *testing.T) {
testCase.request, testCase.responseSize) testCase.request, testCase.responseSize)
if testCase.errMessage != "" { if testCase.errMessage != "" {
if testCase.err != nil {
assert.ErrorIs(t, err, testCase.err)
}
assert.Regexp(t, "^"+testCase.errMessage+"$", err.Error()) assert.Regexp(t, "^"+testCase.errMessage+"$", err.Error())
} else { } else {
assert.NoError(t, err) assert.NoError(t, err)
+2 -4
View File
@@ -41,8 +41,6 @@ func findAvailableTCPPort(t *testing.T) (port uint16) {
func Test_dialAddrThroughFirewall(t *testing.T) { func Test_dialAddrThroughFirewall(t *testing.T) {
t.Parallel() t.Parallel()
errTest := errors.New("test error")
const ipv6InternetWorks = false const ipv6InternetWorks = false
testCases := map[string]struct { testCases := map[string]struct {
@@ -102,7 +100,7 @@ func Test_dialAddrThroughFirewall(t *testing.T) {
}, },
}, },
"firewall_add_error": { "firewall_add_error": {
firewallAddErr: errTest, firewallAddErr: errors.New("test error"),
errMessageRegex: func() string { errMessageRegex: func() string {
return "accepting output traffic: test error" return "accepting output traffic: test error"
}, },
@@ -122,7 +120,7 @@ func Test_dialAddrThroughFirewall(t *testing.T) {
addrPort := netip.MustParseAddrPort(listener.Addr().String()) addrPort := netip.MustParseAddrPort(listener.Addr().String())
return netip.AddrPortFrom(loopback, addrPort.Port()) return netip.AddrPortFrom(loopback, addrPort.Port())
}, },
firewallRemoveErr: errTest, firewallRemoveErr: errors.New("test error"),
errMessageRegex: func() string { errMessageRegex: func() string {
return "removing output traffic rule: test error" return "removing output traffic rule: test error"
}, },
+3 -6
View File
@@ -1,7 +1,6 @@
package netlink package netlink
import ( import (
"errors"
"fmt" "fmt"
"github.com/jsimonetti/rtnetlink" "github.com/jsimonetti/rtnetlink"
@@ -47,8 +46,6 @@ func (n *NetLink) LinkList() (links []Link, err error) {
return links, nil return links, nil
} }
var ErrLinkNotFound = errors.New("link not found")
func (n *NetLink) LinkByName(name string) (link Link, err error) { func (n *NetLink) LinkByName(name string) (link Link, err error) {
links, err := n.LinkList() links, err := n.LinkList()
if err != nil { if err != nil {
@@ -61,7 +58,7 @@ func (n *NetLink) LinkByName(name string) (link Link, err error) {
} }
} }
return Link{}, fmt.Errorf("%w: for name %s", ErrLinkNotFound, name) return Link{}, fmt.Errorf("link not found: for name %s", name)
} }
func (n *NetLink) LinkByIndex(index uint32) (link Link, err error) { func (n *NetLink) LinkByIndex(index uint32) (link Link, err error) {
@@ -76,7 +73,7 @@ func (n *NetLink) LinkByIndex(index uint32) (link Link, err error) {
} }
} }
return Link{}, fmt.Errorf("%w: for index %d", ErrLinkNotFound, index) return Link{}, fmt.Errorf("link not found: for index %d", index)
} }
func (n *NetLink) LinkAdd(link Link) (linkIndex uint32, err error) { func (n *NetLink) LinkAdd(link Link) (linkIndex uint32, err error) {
@@ -114,7 +111,7 @@ func (n *NetLink) LinkAdd(link Link) (linkIndex uint32, err error) {
} }
} }
return 0, fmt.Errorf("%w: matching name %s", ErrLinkNotFound, link.Name) return 0, fmt.Errorf("link not found: matching name %s", link.Name)
} }
func (n *NetLink) LinkDel(linkIndex uint32) (err error) { func (n *NetLink) LinkDel(linkIndex uint32) (err error) {
-6
View File
@@ -1,17 +1,11 @@
package extract package extract
import ( import (
"errors"
"fmt" "fmt"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
) )
var (
ErrRead = errors.New("cannot read file")
ErrExtractConnection = errors.New("cannot extract connection from file")
)
// Data extracts the lines and connection from the OpenVPN configuration file. // Data extracts the lines and connection from the OpenVPN configuration file.
func (e *Extractor) Data(filepath string) (lines []string, func (e *Extractor) Data(filepath string) (lines []string,
connection models.Connection, err error, connection models.Connection, err error,
+10 -24
View File
@@ -11,8 +11,6 @@ import (
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
) )
var errRemoteLineNotFound = errors.New("remote line not found")
func extractDataFromLines(lines []string) ( func extractDataFromLines(lines []string) (
connection models.Connection, err error, connection models.Connection, err error,
) { ) {
@@ -35,7 +33,7 @@ func extractDataFromLines(lines []string) (
} }
if !connection.IP.IsValid() { if !connection.IP.IsValid() {
return connection, errRemoteLineNotFound return connection, errors.New("remote line not found")
} }
if connection.Protocol == "" { if connection.Protocol == "" {
@@ -81,19 +79,15 @@ func extractDataFromLine(line string) (
return ip, 0, "", nil return ip, 0, "", nil
} }
var errProtoLineFieldsCount = errors.New("proto line has not 2 fields as expected")
func extractProto(line string) (protocol string, err error) { func extractProto(line string) (protocol string, err error) {
fields := strings.Fields(line) fields := strings.Fields(line)
if len(fields) != 2 { //nolint:mnd if len(fields) != 2 { //nolint:mnd
return "", fmt.Errorf("%w: %s", errProtoLineFieldsCount, line) return "", fmt.Errorf("proto line has not 2 fields as expected: %s", line)
} }
return parseProto(fields[1]) return parseProto(fields[1])
} }
var errProtocolNotSupported = errors.New("network protocol not supported")
func parseProto(field string) (protocol string, err error) { func parseProto(field string) (protocol string, err error) {
switch field { switch field {
case "tcp", "tcp4", "tcp6", "tcp-client": case "tcp", "tcp4", "tcp6", "tcp-client":
@@ -106,16 +100,10 @@ func parseProto(field string) (protocol string, err error) {
// determined by the remote IP address version. // determined by the remote IP address version.
return constants.UDP, nil return constants.UDP, nil
default: default:
return "", fmt.Errorf("%w: %s", errProtocolNotSupported, field) return "", fmt.Errorf("network protocol not supported: %s", field)
} }
} }
var (
errRemoteLineFieldsCount = errors.New("remote line has not 2 fields as expected")
errHostNotIP = errors.New("host is not an IP address")
errPortNotValid = errors.New("port is not valid")
)
func extractRemote(line string) (ip netip.Addr, port uint16, func extractRemote(line string) (ip netip.Addr, port uint16,
protocol string, err error, protocol string, err error,
) { ) {
@@ -123,13 +111,13 @@ func extractRemote(line string) (ip netip.Addr, port uint16,
n := len(fields) n := len(fields)
if n < 2 || n > 4 { if n < 2 || n > 4 {
return netip.Addr{}, 0, "", fmt.Errorf("%w: %s", errRemoteLineFieldsCount, line) return netip.Addr{}, 0, "", fmt.Errorf("remote line has not 2 fields as expected: %s", line)
} }
host := fields[1] host := fields[1]
ip, err = netip.ParseAddr(host) ip, err = netip.ParseAddr(host)
if err != nil { if err != nil {
return netip.Addr{}, 0, "", fmt.Errorf("%w: %s", errHostNotIP, host) return netip.Addr{}, 0, "", fmt.Errorf("host is not an IP address: %s", host)
// TODO resolve hostname once there is an option to allow it through // TODO resolve hostname once there is an option to allow it through
// the firewall before the VPN is up. // the firewall before the VPN is up.
} }
@@ -137,9 +125,9 @@ func extractRemote(line string) (ip netip.Addr, port uint16,
if n > 2 { //nolint:mnd if n > 2 { //nolint:mnd
portInt, err := strconv.Atoi(fields[2]) portInt, err := strconv.Atoi(fields[2])
if err != nil { if err != nil {
return netip.Addr{}, 0, "", fmt.Errorf("%w: %s", errPortNotValid, line) return netip.Addr{}, 0, "", fmt.Errorf("port is not valid: %s", line)
} else if portInt < 1 || portInt > 65535 { } else if portInt < 1 || portInt > 65535 {
return netip.Addr{}, 0, "", fmt.Errorf("%w: %d must be between 1 and 65535", errPortNotValid, portInt) return netip.Addr{}, 0, "", fmt.Errorf("port is not valid: %d must be between 1 and 65535", portInt)
} }
port = uint16(portInt) port = uint16(portInt)
} }
@@ -154,20 +142,18 @@ func extractRemote(line string) (ip netip.Addr, port uint16,
return ip, port, protocol, nil return ip, port, protocol, nil
} }
var errPostLineFieldsCount = errors.New("post line has not 2 fields as expected")
func extractPort(line string) (port uint16, err error) { func extractPort(line string) (port uint16, err error) {
fields := strings.Fields(line) fields := strings.Fields(line)
const expectedFieldsCount = 2 const expectedFieldsCount = 2
if len(fields) != expectedFieldsCount { if len(fields) != expectedFieldsCount {
return 0, fmt.Errorf("%w: %s", errPostLineFieldsCount, line) return 0, fmt.Errorf("post line has not 2 fields as expected: %s", line)
} }
portInt, err := strconv.Atoi(fields[1]) portInt, err := strconv.Atoi(fields[1])
if err != nil { if err != nil {
return 0, fmt.Errorf("%w: %s", errPortNotValid, line) return 0, fmt.Errorf("port is not valid: %s", line)
} else if portInt < 1 || portInt > 65535 { } else if portInt < 1 || portInt > 65535 {
return 0, fmt.Errorf("%w: %d must be between 1 and 65535", errPortNotValid, portInt) return 0, fmt.Errorf("port is not valid: %d must be between 1 and 65535", portInt)
} }
port = uint16(portInt) port = uint16(portInt)
+11 -12
View File
@@ -17,7 +17,7 @@ func Test_extractDataFromLines(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
lines []string lines []string
connection models.Connection connection models.Connection
err error errMessage string
}{ }{
"success": { "success": {
lines: []string{"bla", "proto tcp", "remote 1.2.3.4 1194 tcp", "dev tun6"}, lines: []string{"bla", "proto tcp", "remote 1.2.3.4 1194 tcp", "dev tun6"},
@@ -29,7 +29,7 @@ func Test_extractDataFromLines(t *testing.T) {
}, },
"extraction error": { "extraction error": {
lines: []string{"bla", "proto bad", "remote 1.2.3.4 1194 tcp"}, lines: []string{"bla", "proto bad", "remote 1.2.3.4 1194 tcp"},
err: errors.New("on line 2: extracting protocol from proto line: network protocol not supported: bad"), errMessage: "on line 2: extracting protocol from proto line: network protocol not supported: bad",
}, },
"only use first values found": { "only use first values found": {
lines: []string{"proto udp", "proto tcp", "remote 1.2.3.4 443 tcp", "remote 5.2.3.4 1194 udp"}, lines: []string{"proto udp", "proto tcp", "remote 1.2.3.4 443 tcp", "remote 5.2.3.4 1194 udp"},
@@ -44,7 +44,7 @@ func Test_extractDataFromLines(t *testing.T) {
connection: models.Connection{ connection: models.Connection{
Protocol: constants.TCP, Protocol: constants.TCP,
}, },
err: errRemoteLineNotFound, errMessage: "remote line not found",
}, },
"default TCP port": { "default TCP port": {
lines: []string{"remote 1.2.3.4", "proto tcp"}, lines: []string{"remote 1.2.3.4", "proto tcp"},
@@ -70,9 +70,8 @@ func Test_extractDataFromLines(t *testing.T) {
connection, err := extractDataFromLines(testCase.lines) connection, err := extractDataFromLines(testCase.lines)
if testCase.err != nil { if testCase.errMessage != "" {
require.Error(t, err) assert.EqualError(t, err, testCase.errMessage)
assert.Equal(t, testCase.err.Error(), err.Error())
} else { } else {
assert.NoError(t, err) assert.NoError(t, err)
} }
@@ -90,14 +89,14 @@ func Test_extractDataFromLine(t *testing.T) {
ip netip.Addr ip netip.Addr
port uint16 port uint16
protocol string protocol string
isErr error errMessage string
}{ }{
"irrelevant line": { "irrelevant line": {
line: "bla", line: "bla",
}, },
"extract proto error": { "extract proto error": {
line: "proto bad", line: "proto bad",
isErr: errProtocolNotSupported, errMessage: "network protocol not supported",
}, },
"extract proto success": { "extract proto success": {
line: "proto tcp", line: "proto tcp",
@@ -109,7 +108,7 @@ func Test_extractDataFromLine(t *testing.T) {
}, },
"extract remote error": { "extract remote error": {
line: "remote bad", line: "remote bad",
isErr: errHostNotIP, errMessage: "host is not an IP address",
}, },
"extract remote success": { "extract remote success": {
line: "remote 1.2.3.4 1194 udp", line: "remote 1.2.3.4 1194 udp",
@@ -119,7 +118,7 @@ func Test_extractDataFromLine(t *testing.T) {
}, },
"extract_port_fail": { "extract_port_fail": {
line: "port a", line: "port a",
isErr: errPortNotValid, errMessage: "port is not valid",
}, },
"extract_port_success": { "extract_port_success": {
line: "port 1194", line: "port 1194",
@@ -133,8 +132,8 @@ func Test_extractDataFromLine(t *testing.T) {
ip, port, protocol, err := extractDataFromLine(testCase.line) ip, port, protocol, err := extractDataFromLine(testCase.line)
if testCase.isErr != nil { if testCase.errMessage != "" {
assert.ErrorIs(t, err, testCase.isErr) assert.ErrorContains(t, err, testCase.errMessage)
} else { } else {
assert.NoError(t, err) assert.NoError(t, err)
} }
+1 -4
View File
@@ -4,15 +4,12 @@ import (
"encoding/base64" "encoding/base64"
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt"
) )
var errPEMDecode = errors.New("cannot decode PEM encoded block")
func PEM(b []byte) (encodedData string, err error) { func PEM(b []byte) (encodedData string, err error) {
pemBlock, _ := pem.Decode(b) pemBlock, _ := pem.Decode(b)
if pemBlock == nil { if pemBlock == nil {
return "", fmt.Errorf("%w", errPEMDecode) return "", errors.New("cannot decode PEM encoded block")
} }
der := pemBlock.Bytes der := pemBlock.Bytes
+3 -5
View File
@@ -13,16 +13,13 @@ func Test_PEM(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
b []byte b []byte
encodedData string encodedData string
errWrapped error
errMessage string errMessage string
}{ }{
"no input": { "no input": {
errWrapped: errPEMDecode,
errMessage: "cannot decode PEM encoded block", errMessage: "cannot decode PEM encoded block",
}, },
"bad input": { "bad input": {
b: []byte{1, 2, 3}, b: []byte{1, 2, 3},
errWrapped: errPEMDecode,
errMessage: "cannot decode PEM encoded block", errMessage: "cannot decode PEM encoded block",
}, },
"valid data with extras": { "valid data with extras": {
@@ -46,9 +43,10 @@ func Test_PEM(t *testing.T) {
encodedData, err := PEM(testCase.b) encodedData, err := PEM(testCase.b)
assert.Equal(t, testCase.encodedData, encodedData) assert.Equal(t, testCase.encodedData, encodedData)
assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errMessage != "" {
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage) assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
} }
}) })
} }
+2 -5
View File
@@ -3,7 +3,6 @@ package pkcs8
import ( import (
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1" "encoding/asn1"
"errors"
"fmt" "fmt"
) )
@@ -11,8 +10,6 @@ import (
// https://www.ibm.com/docs/en/zos/2.3.0?topic=programming-object-identifiers // https://www.ibm.com/docs/en/zos/2.3.0?topic=programming-object-identifiers
var oidDESCBC = asn1.ObjectIdentifier{1, 3, 14, 3, 2, 7} //nolint:gochecknoglobals var oidDESCBC = asn1.ObjectIdentifier{1, 3, 14, 3, 2, 7} //nolint:gochecknoglobals
var ErrEncryptionAlgorithmNotPBES2 = errors.New("encryption algorithm is not PBES2")
type encryptedPrivateKey struct { type encryptedPrivateKey struct {
EncryptionAlgorithm pkix.AlgorithmIdentifier EncryptionAlgorithm pkix.AlgorithmIdentifier
EncryptedData []byte EncryptedData []byte
@@ -35,8 +32,8 @@ func getEncryptionAlgorithmOid(der []byte) (
oidPBES2 := asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 13} oidPBES2 := asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 13}
oidAlgorithm := encryptedPrivateKeyData.EncryptionAlgorithm.Algorithm oidAlgorithm := encryptedPrivateKeyData.EncryptionAlgorithm.Algorithm
if !oidAlgorithm.Equal(oidPBES2) { if !oidAlgorithm.Equal(oidPBES2) {
return nil, fmt.Errorf("%w: %s instead of PBES2 %s", return nil, fmt.Errorf("encryption algorithm is not PBES2: %s instead of PBES2 %s",
ErrEncryptionAlgorithmNotPBES2, oidAlgorithm, oidPBES2) oidAlgorithm, oidPBES2)
} }
var encryptionAlgorithmParams encryptedAlgorithmParams var encryptionAlgorithmParams encryptedAlgorithmParams
-3
View File
@@ -2,14 +2,11 @@ package pkcs8
import ( import (
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
pkcs8lib "github.com/youmark/pkcs8" pkcs8lib "github.com/youmark/pkcs8"
) )
var ErrUnsupportedKeyType = errors.New("unsupported key type")
// UpgradeEncryptedKey eventually upgrades an encrypted key to a newer encryption // UpgradeEncryptedKey eventually upgrades an encrypted key to a newer encryption
// if its encryption is too weak for Openvpn/Openssl. // if its encryption is too weak for Openvpn/Openssl.
// If the key is encrypted using DES-CBC, it is decrypted and re-encrypted using AES-256-CBC. // If the key is encrypted using DES-CBC, it is decrypted and re-encrypted using AES-256-CBC.
+1 -4
View File
@@ -2,15 +2,12 @@ package openvpn
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"os/exec" "os/exec"
"github.com/qdm12/gluetun/internal/constants/openvpn" "github.com/qdm12/gluetun/internal/constants/openvpn"
) )
var ErrVersionUnknown = errors.New("OpenVPN version is unknown")
const ( const (
binOpenvpn25 = "openvpn2.5" binOpenvpn25 = "openvpn2.5"
binOpenvpn26 = "openvpn2.6" binOpenvpn26 = "openvpn2.6"
@@ -26,7 +23,7 @@ func start(ctx context.Context, starter CmdStarter, version string, flags []stri
case openvpn.Openvpn26: case openvpn.Openvpn26:
bin = binOpenvpn26 bin = binOpenvpn26
default: default:
return nil, nil, nil, fmt.Errorf("%w: %s", ErrVersionUnknown, version) return nil, nil, nil, fmt.Errorf("OpenVPN version is unknown: %s", version)
} }
args := []string{"--config", configPath} args := []string{"--config", configPath}
+1 -4
View File
@@ -2,7 +2,6 @@ package openvpn
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"os/exec" "os/exec"
"strings" "strings"
@@ -16,8 +15,6 @@ func (c *Configurator) Version26(ctx context.Context) (version string, err error
return c.version(ctx, binOpenvpn26) return c.version(ctx, binOpenvpn26)
} }
var ErrVersionTooShort = errors.New("version output is too short")
func (c *Configurator) version(ctx context.Context, binName string) (version string, err error) { func (c *Configurator) version(ctx context.Context, binName string) (version string, err error) {
cmd := exec.CommandContext(ctx, binName, "--version") cmd := exec.CommandContext(ctx, binName, "--version")
output, err := c.cmder.Run(cmd) output, err := c.cmder.Run(cmd)
@@ -28,7 +25,7 @@ func (c *Configurator) version(ctx context.Context, binName string) (version str
words := strings.Fields(firstLine) words := strings.Fields(firstLine)
const minWords = 2 const minWords = 2
if len(words) < minWords { if len(words) < minWords {
return "", fmt.Errorf("%w: %s", ErrVersionTooShort, firstLine) return "", fmt.Errorf("version output is too short: %s", firstLine)
} }
return words[1], nil return words[1], nil
} }
+10 -21
View File
@@ -2,24 +2,17 @@ package icmp
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"golang.org/x/net/icmp" "golang.org/x/net/icmp"
) )
var (
ErrNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
ErrNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
)
func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) { func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) {
switch { switch {
case mtu < minMTU: case mtu < minMTU:
return fmt.Errorf("%w: %d", ErrNextHopMTUTooLow, mtu) return fmt.Errorf("ICMP Next Hop MTU is too low: %d", mtu)
case mtu > physicalLinkMTU: case mtu > physicalLinkMTU:
return fmt.Errorf("%w: %d is larger than physical link MTU %d", return fmt.Errorf("ICMP Next Hop MTU is too high: %d is larger than physical link MTU %d", mtu, physicalLinkMTU)
ErrNextHopMTUTooHigh, mtu, physicalLinkMTU)
default: default:
return nil return nil
} }
@@ -34,14 +27,12 @@ func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
} }
inboundBody, ok := inboundMessage.Body.(*icmp.Echo) inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok { if !ok {
return false, fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body) return false, fmt.Errorf("ICMP body type is not supported: %T", inboundMessage.Body)
} }
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
return inboundBody.ID == outboundBody.ID, nil return inboundBody.ID == outboundBody.ID, nil
} }
var ErrIDMismatch = errors.New("ICMP id mismatch")
func checkEchoReply(icmpProtocol int, received []byte, func checkEchoReply(icmpProtocol int, received []byte,
outboundMessage *icmp.Message, truncatedBody bool, outboundMessage *icmp.Message, truncatedBody bool,
) (err error) { ) (err error) {
@@ -51,12 +42,12 @@ func checkEchoReply(icmpProtocol int, received []byte,
} }
inboundBody, ok := inboundMessage.Body.(*icmp.Echo) inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok { if !ok {
return fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body) return fmt.Errorf("ICMP body type is not supported: %T", inboundMessage.Body)
} }
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
if inboundBody.ID != outboundBody.ID { if inboundBody.ID != outboundBody.ID {
return fmt.Errorf("%w: sent id %d and received id %d", return fmt.Errorf("ICMP id mismatch: sent id %d and received id %d",
ErrIDMismatch, outboundBody.ID, inboundBody.ID) outboundBody.ID, inboundBody.ID)
} }
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody) err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
if err != nil { if err != nil {
@@ -65,19 +56,17 @@ func checkEchoReply(icmpProtocol int, received []byte,
return nil return nil
} }
var ErrEchoDataMismatch = errors.New("ICMP data mismatch")
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) { func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
if len(received) > len(sent) { if len(received) > len(sent) {
return fmt.Errorf("%w: sent %d bytes and received %d bytes", return fmt.Errorf("ICMP data mismatch: sent %d bytes and received %d bytes",
ErrEchoDataMismatch, len(sent), len(received)) len(sent), len(received))
} }
if receivedTruncated { if receivedTruncated {
sent = sent[:len(received)] sent = sent[:len(received)]
} }
if !bytes.Equal(received, sent) { if !bytes.Equal(received, sent) {
return fmt.Errorf("%w: sent %x and received %x", return fmt.Errorf("ICMP data mismatch: sent %x and received %x",
ErrEchoDataMismatch, sent, received) sent, received)
} }
return nil return nil
} }
+1 -3
View File
@@ -10,9 +10,7 @@ import (
var ( var (
ErrNotPermitted = errors.New("ICMP not permitted") ErrNotPermitted = errors.New("ICMP not permitted")
ErrDestinationUnreachable = errors.New("ICMP destination unreachable") errCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
ErrCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
ErrBodyUnsupported = errors.New("ICMP body type is not supported")
ErrMTUNotFound = errors.New("MTU not found") ErrMTUNotFound = errors.New("MTU not found")
errTimeout = errors.New("operation timed out") errTimeout = errors.New("operation timed out")
) )
+1 -1
View File
@@ -25,7 +25,7 @@ func PathMTUDiscover(ctx context.Context, ip netip.Addr,
switch { switch {
case err == nil: case err == nil:
return mtu, nil return mtu, nil
case errors.Is(err, errTimeout) || errors.Is(err, ErrCommunicationAdministrativelyProhibited): // blackhole case errors.Is(err, errTimeout) || errors.Is(err, errCommunicationAdministrativelyProhibited): // blackhole
default: default:
return 0, fmt.Errorf("finding IPv4 next hop MTU to %s: %w", ip, err) return 0, fmt.Errorf("finding IPv4 next hop MTU to %s: %w", ip, err)
} }
+3 -6
View File
@@ -117,13 +117,10 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
case portUnreachable: // triggered by TCP or UDP from applications case portUnreachable: // triggered by TCP or UDP from applications
continue // ignore and wait for the next message continue // ignore and wait for the next message
case communicationAdministrativelyProhibitedCode: case communicationAdministrativelyProhibitedCode:
return 0, fmt.Errorf("%w: %w (code %d)", return 0, fmt.Errorf("ICMP destination unreachable: %w (code %d)", errCommunicationAdministrativelyProhibited,
ErrDestinationUnreachable,
ErrCommunicationAdministrativelyProhibited,
inboundMessage.Code) inboundMessage.Code)
default: default:
return 0, fmt.Errorf("%w: code %d", return 0, fmt.Errorf("ICMP destination unreachable: code %d", inboundMessage.Code)
ErrDestinationUnreachable, inboundMessage.Code)
} }
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4 // See https://datatracker.ietf.org/doc/html/rfc1191#section-4
@@ -158,7 +155,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
inboundID, outboundID) inboundID, outboundID)
continue continue
default: default:
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody) return 0, fmt.Errorf("ICMP body type is not supported: %T", typedBody)
} }
} }
} }
+3 -2
View File
@@ -2,6 +2,7 @@ package icmp
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -115,7 +116,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
if err != nil { if err != nil {
return 0, fmt.Errorf("checking invoking message id: %w", err) return 0, fmt.Errorf("checking invoking message id: %w", err)
} else if idMatch { } else if idMatch {
return 0, fmt.Errorf("%w", ErrDestinationUnreachable) return 0, errors.New("ICMP destination unreachable")
} }
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id") logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
continue continue
@@ -128,7 +129,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
inboundID, outboundID) inboundID, outboundID)
continue continue
default: default:
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody) return 0, fmt.Errorf("ICMP body type %T is not supported", typedBody)
} }
} }
} }
+4 -4
View File
@@ -71,7 +71,7 @@ func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()}) _, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil { if err != nil {
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") { if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
err = fmt.Errorf("%w", ErrNotPermitted) err = ErrNotPermitted
} }
return 0, fmt.Errorf("writing ICMP message: %w", err) return 0, fmt.Errorf("writing ICMP message: %w", err)
} }
@@ -157,7 +157,7 @@ func collectReplies(conn net.PacketConn, ipVersion string,
logger.Debugf("ignoring ICMP message (type: %d, code: %d)", message.Type, message.Code) logger.Debugf("ignoring ICMP message (type: %d, code: %d)", message.Type, message.Code)
continue continue
default: default:
return fmt.Errorf("%w: %T", ErrBodyUnsupported, message.Body) return fmt.Errorf("ICMP body type is not supported: %T", message.Body)
} }
echoBody, _ := message.Body.(*icmp.Echo) echoBody, _ := message.Body.(*icmp.Echo)
@@ -183,8 +183,8 @@ func collectReplies(conn net.PacketConn, ipVersion string,
ipPacketLength == conservativeReplyLength ipPacketLength == conservativeReplyLength
// Check the packet size is the same if the reply is not truncated // Check the packet size is the same if the reply is not truncated
if !truncated && sentBytes != ipPacketLength { if !truncated && sentBytes != ipPacketLength {
return fmt.Errorf("%w: sent %dB and received %dB", return fmt.Errorf("ICMP data mismatch: sent %dB and received %dB",
ErrEchoDataMismatch, sentBytes, ipPacketLength) sentBytes, ipPacketLength)
} }
// Truncated reply or matching reply size // Truncated reply or matching reply size
tests[testIndex].ok = true tests[testIndex].ok = true
+1 -1
View File
@@ -29,7 +29,7 @@ func SrcAddr(dst netip.AddrPort, proto int) (src netip.AddrPort, cleanup func(),
} }
var ( var (
errNoRoute = fmt.Errorf("no route to destination") errNoRoute = errors.New("no route to destination")
ErrNetworkUnreachable = errors.New("network unreachable") ErrNetworkUnreachable = errors.New("network unreachable")
) )
+3 -8
View File
@@ -13,11 +13,6 @@ import (
"github.com/qdm12/gluetun/internal/pmtud/tcp" "github.com/qdm12/gluetun/internal/pmtud/tcp"
) )
var (
ErrICMPOkTCPFail = errors.New("PMTUD succeeded with ICMP but failed with TCP")
ErrICMPFailTCPFail = errors.New("PMTUD failed with both ICMP and TCP")
)
// PathMTUDiscover discovers the maximum MTU using both ICMP and TCP. // PathMTUDiscover discovers the maximum MTU using both ICMP and TCP.
// Multiple ICMP addresses and TCP addresses can be specified for redundancy. // Multiple ICMP addresses and TCP addresses can be specified for redundancy.
// ICMP PMTUD is run first. If successful, the range of possible MTU values to // ICMP PMTUD is run first. If successful, the range of possible MTU values to
@@ -81,10 +76,10 @@ func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []net
} }
} }
if icmpSuccess { if icmpSuccess {
return 0, fmt.Errorf("%w - discarding ICMP obtained MTU %d", return 0, fmt.Errorf("PMTUD succeeded with ICMP but failed with TCP "+
ErrICMPOkTCPFail, maxPossibleMTU) "- discarding ICMP obtained MTU %d", maxPossibleMTU)
} }
return 0, fmt.Errorf("%w", ErrICMPFailTCPFail) return 0, errors.New("PMTUD failed with both ICMP and TCP")
} }
logger.Debugf("TCP path MTU discovery found maximum valid MTU %d", mtu) logger.Debugf("TCP path MTU discovery found maximum valid MTU %d", mtu)
return mtu, nil return mtu, nil
+2 -4
View File
@@ -57,8 +57,6 @@ func (l *noopLogger) Warn(_ string) {}
func (l *noopLogger) Warnf(_ string, _ ...any) {} func (l *noopLogger) Warnf(_ string, _ ...any) {}
func (l *noopLogger) Error(_ string) {} func (l *noopLogger) Error(_ string) {}
var errRouteNotFound = errors.New("route not found")
func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) { func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
routes, err := netlinker.RouteList(netlink.FamilyV4) routes, err := netlinker.RouteList(netlink.FamilyV4)
if err != nil { if err != nil {
@@ -76,7 +74,7 @@ func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
return min(link.MTU, maxMTU), nil return min(link.MTU, maxMTU), nil
} }
} }
return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound) return 0, errors.New("route not found: no loopback route found")
} }
func findDefaultRouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) { func findDefaultRouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
@@ -100,7 +98,7 @@ func findDefaultRouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
} }
} }
if mtu == 0 { if mtu == 0 {
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound) return 0, errors.New("route not found: no default route found")
} }
return mtu, nil return mtu, nil
} }
+5 -8
View File
@@ -12,8 +12,6 @@ import (
"github.com/qdm12/gluetun/internal/pmtud/ip" "github.com/qdm12/gluetun/internal/pmtud/ip"
) )
var errTCPServersUnreachable = errors.New("all TCP servers are unreachable")
// findHighestMSSDestination finds the destination with the highest // findHighestMSSDestination finds the destination with the highest
// MSS amongst the provided destinations. // MSS amongst the provided destinations.
func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescriptor, func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescriptor,
@@ -68,7 +66,7 @@ func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescr
} }
if mss == 0 { // no MSS found for any destination if mss == 0 { // no MSS found for any destination
return netip.AddrPort{}, 0, fmt.Errorf("%w (%d servers)", errTCPServersUnreachable, len(dsts)) return netip.AddrPort{}, 0, fmt.Errorf("all %d TCP servers are unreachable", len(dsts))
} }
maxPossibleMTU = ip.HeaderLength(dst.Addr().Is4()) + constants.BaseTCPHeaderLength + mss maxPossibleMTU = ip.HeaderLength(dst.Addr().Is4()) + constants.BaseTCPHeaderLength + mss
@@ -77,8 +75,6 @@ func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescr
return dst, mss, nil return dst, mss, nil
} }
var errMSSNotFound = errors.New("MSS option not found in reply")
func findMSS(ctx context.Context, fd fileDescriptor, dst netip.AddrPort, func findMSS(ctx context.Context, fd fileDescriptor, dst netip.AddrPort,
excludeMark int, tracker *tracker, firewall Firewall, logger Logger) ( excludeMark int, tracker *tracker, firewall Firewall, logger Logger) (
mss uint32, err error, mss uint32, err error,
@@ -132,11 +128,12 @@ func findMSS(ctx context.Context, fd fileDescriptor, dst netip.AddrPort,
case err != nil: case err != nil:
return 0, fmt.Errorf("parsing reply TCP header: %w", err) return 0, fmt.Errorf("parsing reply TCP header: %w", err)
case replyHeader.typ != packetTypeSYNACK: case replyHeader.typ != packetTypeSYNACK:
return 0, fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, replyHeader.typ) return 0, fmt.Errorf("TCP packet is not a SYN-ACK: unexpected packet type %s", replyHeader.typ)
case replyHeader.ack != synSeq+1: case replyHeader.ack != synSeq+1:
return 0, fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, replyHeader.ack) return 0, fmt.Errorf("TCP SYN-ACK ACK number %d does not match expected value %d",
replyHeader.ack, synSeq+1)
case replyHeader.options.mss == 0: case replyHeader.options.mss == 0:
return 0, fmt.Errorf("%w: MSS option not found in reply", errMSSNotFound) return 0, errors.New("MSS option not found in reply")
} }
err = sendRST(fd, src, dst, replyHeader.ack) err = sendRST(fd, src, dst, replyHeader.ack)
+1 -6
View File
@@ -12,11 +12,6 @@ import (
"github.com/qdm12/gluetun/internal/pmtud/test" "github.com/qdm12/gluetun/internal/pmtud/test"
) )
var (
ErrMTUNotFound = errors.New("MTU not found")
ErrMSSTooSmall = errors.New("TCP MSS is too small to find the MTU")
)
type testUnit struct { type testUnit struct {
mtu uint32 mtu uint32
ok bool ok bool
@@ -178,5 +173,5 @@ func pathMTUDiscover(ctx context.Context, fd fileDescriptor,
} }
} }
return 0, fmt.Errorf("%w: your connection might not be working at all", ErrMTUNotFound) return 0, errors.New("MTU not found: your connection might not be working at all")
} }
+6 -14
View File
@@ -75,13 +75,6 @@ func startRawSocket(family, excludeMark int) (fd fileDescriptor, stop func(), er
return fileDescriptor(fdPlatform), stop, nil return fileDescriptor(fdPlatform), stop, nil
} }
var (
errTCPPacketNotSynAck = errors.New("TCP packet is not a SYN-ACK")
errTCPSynAckAckMismatch = errors.New("TCP SYN-ACK ACK number does not match expected value")
errFinalPacketTypeUnexpected = errors.New("final TCP packet type is unexpected")
errTCPPacketLost = errors.New("TCP packet was lost")
)
// Craft and send a raw TCP packet to test the MTU. // Craft and send a raw TCP packet to test the MTU.
// It expects either an RST reply (if no server is listening) // It expects either an RST reply (if no server is listening)
// or a SYN-ACK/ACK reply (if a server is listening). // or a SYN-ACK/ACK reply (if a server is listening).
@@ -142,9 +135,10 @@ func runTest(ctx context.Context, dst netip.AddrPort, mtu uint32,
// server actively closed the connection, try sending a SYN with data // server actively closed the connection, try sending a SYN with data
return handleRSTReply(ctx, fd, ch, src, dst, mtu) return handleRSTReply(ctx, fd, ch, src, dst, mtu)
case firstReplyHeader.typ != packetTypeSYNACK: case firstReplyHeader.typ != packetTypeSYNACK:
return fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, firstReplyHeader.typ) return fmt.Errorf("TCP packet is not a SYN-ACK: unexpected packet type %s", firstReplyHeader.typ)
case firstReplyHeader.ack != synSeq+1: case firstReplyHeader.ack != synSeq+1:
return fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, firstReplyHeader.ack) return fmt.Errorf("TCP SYN-ACK ACK number does not match expected value: "+
"expected %d, got %d", synSeq+1, firstReplyHeader.ack)
} }
if firstReplyHeader.options.mss != 0 { if firstReplyHeader.options.mss != 0 {
@@ -191,15 +185,13 @@ func runTest(ctx context.Context, dst netip.AddrPort, mtu uint32,
} }
return nil return nil
case packetTypeSYNACK: // server never received our MTU-test ACK packet case packetTypeSYNACK: // server never received our MTU-test ACK packet
return fmt.Errorf("%w: server responded with second SYN-ACK packet", errTCPPacketLost) return errors.New("TCP packet was lost: server responded with second SYN-ACK packet")
default: default:
_ = sendRST(fd, src, dst, finalPacketHeader.ack) _ = sendRST(fd, src, dst, finalPacketHeader.ack)
return fmt.Errorf("%w: %s", errFinalPacketTypeUnexpected, finalPacketHeader.typ) return fmt.Errorf("final TCP packet type is unexpected: %s", finalPacketHeader.typ)
} }
} }
var errTCPPacketNotRST = errors.New("TCP packet is not an RST")
func handleRSTReply(ctx context.Context, fd fileDescriptor, ch <-chan []byte, func handleRSTReply(ctx context.Context, fd fileDescriptor, ch <-chan []byte,
src, dst netip.AddrPort, mtu uint32, src, dst netip.AddrPort, mtu uint32,
) error { ) error {
@@ -223,7 +215,7 @@ func handleRSTReply(ctx context.Context, fd fileDescriptor, ch <-chan []byte,
return fmt.Errorf("parsing reply TCP header: %w", err) return fmt.Errorf("parsing reply TCP header: %w", err)
} else if replyPacketHeader.typ != packetTypeRST && } else if replyPacketHeader.typ != packetTypeRST &&
replyPacketHeader.typ != packetTypeRSTACK { replyPacketHeader.typ != packetTypeRSTACK {
return fmt.Errorf("%w: %s", errTCPPacketNotRST, replyPacketHeader.typ) return fmt.Errorf("TCP packet is not an RST: %s", replyPacketHeader.typ)
} }
return nil return nil
} }
+18 -34
View File
@@ -2,7 +2,6 @@ package tcp
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"github.com/qdm12/gluetun/internal/pmtud/constants" "github.com/qdm12/gluetun/internal/pmtud/constants"
@@ -120,17 +119,11 @@ type tcpHeader struct {
options options options options
} }
var (
errTCPHeaderTooShort = errors.New("TCP header is too short")
errTCPHeaderDataOffset = errors.New("TCP header data offset is invalid")
errTCPPacketTypeUnknown = errors.New("TCP packet type is unknown")
)
// parseTCPHeader parses the TCP header from b. // parseTCPHeader parses the TCP header from b.
// b should be the entire TCP packet bytes. // b should be the entire TCP packet bytes.
func parseTCPHeader(b []byte) (header tcpHeader, err error) { func parseTCPHeader(b []byte) (header tcpHeader, err error) {
if len(b) < int(constants.BaseTCPHeaderLength) { if len(b) < int(constants.BaseTCPHeaderLength) {
return tcpHeader{}, fmt.Errorf("%w: %d bytes", errTCPHeaderTooShort, len(b)) return tcpHeader{}, fmt.Errorf("TCP header is too short: %d bytes", len(b))
} }
header.srcPort = binary.BigEndian.Uint16(b[0:2]) header.srcPort = binary.BigEndian.Uint16(b[0:2])
@@ -146,11 +139,11 @@ func parseTCPHeader(b []byte) (header tcpHeader, err error) {
switch { switch {
case uint32(header.dataOffset) < constants.BaseTCPHeaderLength: case uint32(header.dataOffset) < constants.BaseTCPHeaderLength:
return tcpHeader{}, fmt.Errorf("%w: data offset is %d bytes, expected at least %d bytes", return tcpHeader{}, fmt.Errorf("TCP header data offset is invalid: "+
errTCPHeaderDataOffset, header.dataOffset, constants.BaseTCPHeaderLength) "data offset is %d bytes, expected at least %d bytes", header.dataOffset, constants.BaseTCPHeaderLength)
case int(header.dataOffset) > len(b): case int(header.dataOffset) > len(b):
return tcpHeader{}, fmt.Errorf("%w: data offset is %d bytes, but packet is only %d bytes", return tcpHeader{}, fmt.Errorf("TCP header data offset is invalid: "+
errTCPHeaderDataOffset, header.dataOffset, len(b)) "data offset is %d bytes, but packet is only %d bytes", header.dataOffset, len(b))
} }
if uint32(header.dataOffset) > constants.BaseTCPHeaderLength { if uint32(header.dataOffset) > constants.BaseTCPHeaderLength {
@@ -186,7 +179,7 @@ func parseTCPHeader(b []byte) (header tcpHeader, err error) {
case flags&ackFlag != 0: case flags&ackFlag != 0:
header.typ = packetTypeACK header.typ = packetTypeACK
default: default:
return tcpHeader{}, fmt.Errorf("%w: flags are 0x%02x", errTCPPacketTypeUnknown, flags) return tcpHeader{}, fmt.Errorf("TCP packet type is unknown: flags are 0x%02x", flags)
} }
header.seq = binary.BigEndian.Uint32(b[4:8]) header.seq = binary.BigEndian.Uint32(b[4:8])
@@ -206,15 +199,6 @@ type optionTimestamps struct {
echo uint32 echo uint32
} }
var (
errTCPOptionLengthTruncated = errors.New("TCP option length is truncated")
ErrTCPOptionLengthInvalid = errors.New("TCP option length is invalid")
ErrTCPOptionMSSInvalid = errors.New("TCP option MSS value is invalid")
ErrTCPOptionWindowScaleInvalid = errors.New("TCP option Window Scale value is invalid")
ErrTCPOptionTimestampsInvalid = errors.New("TCP option Timestamps value is invalid")
errTCPOptionTypeUnknown = errors.New("TCP option type is unknown")
)
func parseTCPOptions(b []byte) (parsed options, err error) { func parseTCPOptions(b []byte) (parsed options, err error) {
i := 0 i := 0
for i < len(b) { for i < len(b) {
@@ -232,7 +216,7 @@ func parseTCPOptions(b []byte) (parsed options, err error) {
// Handle TLV (Type-Length-Value) options // Handle TLV (Type-Length-Value) options
if i+1 >= len(b) { if i+1 >= len(b) {
// This should not happen for DF packets. // This should not happen for DF packets.
return options{}, fmt.Errorf("%w: at offset %d", errTCPOptionLengthTruncated, i) return options{}, fmt.Errorf("TCP option length is truncated: at offset %d", i)
} }
length := int(b[i+1]) length := int(b[i+1])
@@ -240,11 +224,11 @@ func parseTCPOptions(b []byte) (parsed options, err error) {
maxLength := len(b) - i maxLength := len(b) - i
switch { switch {
case length < minLength: case length < minLength:
return options{}, fmt.Errorf("%w: type %d at offset %d has length %d < %d", return options{}, fmt.Errorf("TCP option length is invalid: "+
ErrTCPOptionLengthInvalid, optionType, i, length, minLength) "type %d at offset %d has length %d < %d", optionType, i, length, minLength)
case length > maxLength: case length > maxLength:
return options{}, fmt.Errorf("%w: type %d at offset %d has length %d > %d", return options{}, fmt.Errorf("TCP option length is invalid: "+
ErrTCPOptionLengthInvalid, optionType, i, length, maxLength) "type %d at offset %d has length %d > %d", optionType, i, length, maxLength)
} }
data := b[i+2 : i+length] data := b[i+2 : i+length]
@@ -259,15 +243,15 @@ func parseTCPOptions(b []byte) (parsed options, err error) {
case optionTypeMSS: case optionTypeMSS:
const expectedLength = 4 const expectedLength = 4
if length != expectedLength { if length != expectedLength {
return options{}, fmt.Errorf("%w: MSS option at offset %d has length %d, expected %d", return options{}, fmt.Errorf("TCP option MSS value is invalid: "+
ErrTCPOptionMSSInvalid, i, length, expectedLength) "MSS option at offset %d has length %d, expected %d", i, length, expectedLength)
} }
parsed.mss = uint32(binary.BigEndian.Uint16(data)) parsed.mss = uint32(binary.BigEndian.Uint16(data))
case optionTypeWindowScale: case optionTypeWindowScale:
const expectedLength = 3 const expectedLength = 3
if length != expectedLength { if length != expectedLength {
return options{}, fmt.Errorf("%w: window scale option at offset %d has length %d, expected %d", return options{}, fmt.Errorf("TCP option Window Scale value is invalid: "+
ErrTCPOptionWindowScaleInvalid, i, length, expectedLength) "window scale option at offset %d has length %d, expected %d", i, length, expectedLength)
} }
windowScale := data[0] windowScale := data[0]
parsed.windowScale = &windowScale parsed.windowScale = &windowScale
@@ -276,15 +260,15 @@ func parseTCPOptions(b []byte) (parsed options, err error) {
case optionTypeTimestamps: case optionTypeTimestamps:
const expectedLength = 10 const expectedLength = 10
if length != expectedLength { if length != expectedLength {
return options{}, fmt.Errorf("%w: timestamps option at offset %d has length %d, expected %d", return options{}, fmt.Errorf("TCP option Timestamps value is invalid: "+
ErrTCPOptionTimestampsInvalid, i, length, expectedLength) "timestamps option at offset %d has length %d, expected %d", i, length, expectedLength)
} }
parsed.timestamps = &optionTimestamps{ parsed.timestamps = &optionTimestamps{
value: binary.BigEndian.Uint32(data[:4]), value: binary.BigEndian.Uint32(data[:4]),
echo: binary.BigEndian.Uint32(data[4:]), echo: binary.BigEndian.Uint32(data[4:]),
} }
default: default:
return options{}, fmt.Errorf("%w: type %d", errTCPOptionTypeUnknown, optionType) return options{}, fmt.Errorf("TCP option type is unknown: type %d", optionType)
} }
i += length i += length
+1 -3
View File
@@ -177,11 +177,9 @@ func (l *Loop) GetPortsForwarded() (ports []uint16) {
return l.service.GetPortsForwarded() return l.service.GetPortsForwarded()
} }
var ErrServiceNotStarted = errors.New("port forwarding service not started")
func (l *Loop) SetPortsForwarded(ports []uint16) (err error) { func (l *Loop) SetPortsForwarded(ports []uint16) (err error) {
if l.service == nil { if l.service == nil {
return fmt.Errorf("%w", ErrServiceNotStarted) return errors.New("port forwarding service not started")
} }
return l.service.SetPortsForwarded(l.runCtx, ports) return l.service.SetPortsForwarded(l.runCtx, ports)
+12 -24
View File
@@ -55,23 +55,10 @@ func (s *Settings) OverrideWith(update Settings) {
s.Password = gosettings.OverrideWithComparable(s.Password, update.Password) s.Password = gosettings.OverrideWithComparable(s.Password, update.Password)
} }
var (
ErrPortForwarderNotSet = errors.New("port forwarder not set")
ErrServerNameNotSet = errors.New("server name not set")
ErrUsernameNotSet = errors.New("username not set")
ErrPasswordNotSet = errors.New("password not set")
ErrFilepathNotSet = errors.New("file path not set")
ErrInterfaceNotSet = errors.New("interface not set")
ErrPortsCountZero = errors.New("ports count cannot be zero")
ErrPortsCountTooHigh = errors.New("ports count too high")
ErrListeningPortsLen = errors.New("listening ports length must be equal to ports count")
ErrListeningPortZero = errors.New("listening port cannot be 0")
)
func (s *Settings) Validate(forStartup bool) (err error) { func (s *Settings) Validate(forStartup bool) (err error) {
// Minimal validation // Minimal validation
if s.Filepath == "" { if s.Filepath == "" {
return fmt.Errorf("%w", ErrFilepathNotSet) return errors.New("file path not set")
} }
if !forStartup { if !forStartup {
@@ -83,41 +70,42 @@ func (s *Settings) Validate(forStartup bool) (err error) {
// Startup validation requires additional fields set. // Startup validation requires additional fields set.
switch { switch {
case s.PortForwarder == nil: case s.PortForwarder == nil:
return fmt.Errorf("%w", ErrPortForwarderNotSet) return errors.New("port forwarder not set")
case s.Interface == "": case s.Interface == "":
return fmt.Errorf("%w", ErrInterfaceNotSet) return errors.New("interface not set")
case s.PortsCount == 0: case s.PortsCount == 0:
return fmt.Errorf("%w", ErrPortsCountZero) return errors.New("ports count cannot be zero")
} }
switch s.PortForwarder.Name() { switch s.PortForwarder.Name() {
case providers.PrivateInternetAccess: case providers.PrivateInternetAccess:
switch { switch {
case s.ServerName == "": case s.ServerName == "":
return fmt.Errorf("%w", ErrServerNameNotSet) return errors.New("server name not set")
case s.Username == "": case s.Username == "":
return fmt.Errorf("%w", ErrUsernameNotSet) return errors.New("username not set")
case s.Password == "": case s.Password == "":
return fmt.Errorf("%w", ErrPasswordNotSet) return errors.New("password not set")
} }
case providers.Protonvpn: case providers.Protonvpn:
const maxPortsCount = 4 const maxPortsCount = 4
if s.PortsCount > maxPortsCount { if s.PortsCount > maxPortsCount {
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, s.PortsCount, maxPortsCount) return fmt.Errorf("ports count too high: %d > %d", s.PortsCount, maxPortsCount)
} }
default: default:
const maxPortsCount = 1 const maxPortsCount = 1
if s.PortsCount > maxPortsCount { if s.PortsCount > maxPortsCount {
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, s.PortsCount, maxPortsCount) return fmt.Errorf("ports count too high: %d > %d", s.PortsCount, maxPortsCount)
} }
} }
if !slices.Equal(s.ListeningPorts, []uint16{0}) { if !slices.Equal(s.ListeningPorts, []uint16{0}) {
switch { switch {
case len(s.ListeningPorts) != int(s.PortsCount): case len(s.ListeningPorts) != int(s.PortsCount):
return fmt.Errorf("%w: %d != %d", ErrListeningPortsLen, len(s.ListeningPorts), s.PortsCount) return fmt.Errorf("listening ports length must be equal to ports count: %d != %d",
len(s.ListeningPorts), s.PortsCount)
case slices.Contains(s.ListeningPorts, 0): case slices.Contains(s.ListeningPorts, 0):
return fmt.Errorf("%w: in %v", ErrListeningPortZero, s.ListeningPorts) return fmt.Errorf("listening port cannot be 0: in %v", s.ListeningPorts)
} }
} }
+1 -1
View File
@@ -120,7 +120,7 @@ func Test_Server_BadSettings(t *testing.T) {
server, err := New(settings) server, err := New(settings)
assert.Nil(t, server) assert.Nil(t, server)
assert.ErrorIs(t, err, ErrBlockProfileRateNegative) assert.ErrorContains(t, err, "block profile rate cannot be negative")
const expectedErrMessage = "pprof settings failed validation: block profile rate cannot be negative" const expectedErrMessage = "pprof settings failed validation: block profile rate cannot be negative"
assert.EqualError(t, err, expectedErrMessage) assert.EqualError(t, err, expectedErrMessage)
} }
+2 -8
View File
@@ -2,7 +2,6 @@ package pprof
import ( import (
"errors" "errors"
"fmt"
"time" "time"
"github.com/qdm12/gluetun/internal/httpserver" "github.com/qdm12/gluetun/internal/httpserver"
@@ -51,18 +50,13 @@ func (s *Settings) OverrideWith(other Settings) {
s.HTTPServer.OverrideWith(other.HTTPServer) s.HTTPServer.OverrideWith(other.HTTPServer)
} }
var (
ErrBlockProfileRateNegative = errors.New("block profile rate cannot be negative")
ErrMutexProfileRateNegative = errors.New("mutex profile rate cannot be negative")
)
func (s Settings) Validate() (err error) { func (s Settings) Validate() (err error) {
if *s.BlockProfileRate < 0 { if *s.BlockProfileRate < 0 {
return fmt.Errorf("%w", ErrBlockProfileRateNegative) return errors.New("block profile rate cannot be negative")
} }
if *s.MutexProfileRate < 0 { if *s.MutexProfileRate < 0 {
return fmt.Errorf("%w", ErrMutexProfileRateNegative) return errors.New("mutex profile rate cannot be negative")
} }
return s.HTTPServer.Validate() return s.HTTPServer.Validate()
+4 -8
View File
@@ -6,7 +6,6 @@ import (
"time" "time"
"github.com/qdm12/gluetun/internal/httpserver" "github.com/qdm12/gluetun/internal/httpserver"
"github.com/qdm12/gosettings/validate"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -195,7 +194,6 @@ func Test_Settings_Validate(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
settings Settings settings Settings
errWrapped error
errMessage string errMessage string
}{ }{
"negative block profile rate": { "negative block profile rate": {
@@ -203,16 +201,14 @@ func Test_Settings_Validate(t *testing.T) {
BlockProfileRate: intPtr(-1), BlockProfileRate: intPtr(-1),
MutexProfileRate: intPtr(0), MutexProfileRate: intPtr(0),
}, },
errWrapped: ErrBlockProfileRateNegative, errMessage: "block profile rate cannot be negative",
errMessage: ErrBlockProfileRateNegative.Error(),
}, },
"negative mutex profile rate": { "negative mutex profile rate": {
settings: Settings{ settings: Settings{
BlockProfileRate: intPtr(0), BlockProfileRate: intPtr(0),
MutexProfileRate: intPtr(-1), MutexProfileRate: intPtr(-1),
}, },
errWrapped: ErrMutexProfileRateNegative, errMessage: "mutex profile rate cannot be negative",
errMessage: ErrMutexProfileRateNegative.Error(),
}, },
"http server validation error": { "http server validation error": {
settings: Settings{ settings: Settings{
@@ -222,7 +218,6 @@ func Test_Settings_Validate(t *testing.T) {
Address: ":x", Address: ":x",
}, },
}, },
errWrapped: validate.ErrPortNotAnInteger,
errMessage: "port value is not an integer: x", errMessage: "port value is not an integer: x",
}, },
"valid settings": { "valid settings": {
@@ -247,9 +242,10 @@ func Test_Settings_Validate(t *testing.T) {
err := testCase.settings.Validate() err := testCase.settings.Validate()
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errMessage != "" { if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage) assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
} }
}) })
} }
+2 -4
View File
@@ -6,8 +6,6 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/netip" "net/netip"
"github.com/qdm12/gluetun/internal/provider/common"
) )
type apiData struct { type apiData struct {
@@ -48,8 +46,8 @@ func fetchAPI(ctx context.Context, client *http.Client) (
if response.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
_ = response.Body.Close() _ = response.Body.Close()
return data, fmt.Errorf("%w: %d %s", return data, fmt.Errorf("HTTP status code not OK: %d %s",
common.ErrHTTPStatusCodeNotOK, response.StatusCode, response.Status) response.StatusCode, response.Status)
} }
decoder := json.NewDecoder(response.Body) decoder := json.NewDecoder(response.Body)
-5
View File
@@ -1,5 +0,0 @@
package common
import "errors"
var ErrPortForwardNotSupported = errors.New("port forwarding not supported")
+1 -3
View File
@@ -11,9 +11,7 @@ import (
var ( var (
ErrNotEnoughServers = errors.New("not enough servers found") ErrNotEnoughServers = errors.New("not enough servers found")
ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK") ErrCredentialsMissing = errors.New("credentials are missing")
ErrIPFetcherUnsupported = errors.New("IP fetcher not supported")
ErrCredentialsMissing = errors.New("credentials missing")
) )
type Fetcher interface { type Fetcher interface {
+1 -4
View File
@@ -1,7 +1,6 @@
package custom package custom
import ( import (
"errors"
"fmt" "fmt"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
@@ -10,8 +9,6 @@ import (
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
) )
var ErrVPNTypeNotSupported = errors.New("VPN type not supported for custom provider")
// GetConnection gets the connection from the OpenVPN configuration file. // GetConnection gets the connection from the OpenVPN configuration file.
func (p *Provider) GetConnection(selection settings.ServerSelection, _ bool) ( func (p *Provider) GetConnection(selection settings.ServerSelection, _ bool) (
connection models.Connection, err error, connection models.Connection, err error,
@@ -22,7 +19,7 @@ func (p *Provider) GetConnection(selection settings.ServerSelection, _ bool) (
case vpn.Wireguard, vpn.AmneziaWg: case vpn.Wireguard, vpn.AmneziaWg:
return getWireguardConnection(selection), nil return getWireguardConnection(selection), nil
default: default:
return connection, fmt.Errorf("%w: %s", ErrVPNTypeNotSupported, selection.VPN) return connection, fmt.Errorf("VPN type not supported for custom provider: %s", selection.VPN)
} }
} }
-3
View File
@@ -1,7 +1,6 @@
package custom package custom
import ( import (
"errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
@@ -12,8 +11,6 @@ import (
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
var ErrExtractData = errors.New("failed extracting information from custom configuration file")
func (p *Provider) OpenVPNConfig(connection models.Connection, func (p *Provider) OpenVPNConfig(connection models.Connection,
settings settings.OpenVPN, ipv6Supported bool, settings settings.OpenVPN, ipv6Supported bool,
) (lines []string) { ) (lines []string) {
+2 -5
View File
@@ -3,13 +3,10 @@ package updater
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
) )
var errHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
type apiData struct { type apiData struct {
Servers []apiServer `json:"servers"` Servers []apiServer `json:"servers"`
} }
@@ -42,8 +39,8 @@ func fetchAPI(ctx context.Context, client *http.Client) (
if response.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
_ = response.Body.Close() _ = response.Body.Close()
return data, fmt.Errorf("%w: %d %s", return data, fmt.Errorf("HTTP status code not OK: %d %s",
errHTTPStatusCodeNotOK, response.StatusCode, response.Status) response.StatusCode, response.Status)
} }
decoder := json.NewDecoder(response.Body) decoder := json.NewDecoder(response.Body)

Some files were not shown because too many files have changed in this diff Show More