chore: do not use sentinel errors when unneeded

- main reason being it's a burden to always define sentinel errors at global scope, wrap them with `%w` instead of using a string directly
- only use sentinel errors when it has to be checked using `errors.Is`
- replace all usage of these sentinel errors in `fmt.Errorf` with direct strings that were in the sentinel error
- exclude the sentinel error definition requirement from .golangci.yml
- update unit tests to use ContainersError instead of ErrorIs so it stays as a "not a change detector test" without requiring a sentinel error
This commit is contained in:
Quentin McGaw
2026-05-02 00:50:16 +00:00
parent 9b6f048fe8
commit 4a78989d9d
172 changed files with 666 additions and 1433 deletions
+12 -21
View File
@@ -1,7 +1,6 @@
package settings
import (
"errors"
"fmt"
"strconv"
"strings"
@@ -177,14 +176,6 @@ func (a AmneziaWg) toLinesNode() (node *gotree.Node) {
return node
}
var (
ErrAmenziawgImplementationNotValid = errors.New("AmneziaWG implementation is not valid")
ErrJunkPacketBounds = errors.New("junk packet minimum must be lower than or equal to maximum")
ErrJunkPacketMinMaxNotSet = errors.New("junk packet min and max must be set when junk packet count is set")
ErrJunkPacketCountNotSet = errors.New("junk packet count must be set when junk packet min or max is set")
ErrHeaderRangeMalformed = errors.New("header range is malformed")
)
func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
const amneziaWG = true
err := a.Wireguard.validate(vpnProvider, ipv6Supported, amneziaWG)
@@ -194,16 +185,16 @@ func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
if *a.JunkPacketCount == 0 {
if *a.JunkPacketMin != 0 || *a.JunkPacketMax != 0 {
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
ErrJunkPacketCountNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
return fmt.Errorf("junk packet count must be set when junk packet min or max is set: "+
"jc=%d and jmin=%d and jmax=%d", a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
}
} else {
if *a.JunkPacketMin == 0 || *a.JunkPacketMax == 0 {
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
ErrJunkPacketMinMaxNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
return fmt.Errorf("junk packet min and max must be set when junk packet count is set: "+
"jc=%d and jmin=%d and jmax=%d", a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
} else if *a.JunkPacketMin > *a.JunkPacketMax {
return fmt.Errorf("%w: jmin=%d and jmax=%d",
ErrJunkPacketBounds, *a.JunkPacketMin, *a.JunkPacketMax)
return fmt.Errorf("junk packet minimum must be lower than or equal to maximum: "+
"jmin=%d and jmax=%d", *a.JunkPacketMin, *a.JunkPacketMax)
}
}
@@ -222,20 +213,20 @@ func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
case 1:
_, err := strconv.Atoi(fields[0])
if err != nil {
return fmt.Errorf("%w: %s value %s is not a number",
ErrHeaderRangeMalformed, name, headerRange)
return fmt.Errorf("header range is malformed: "+
"%s value %s is not a number", name, headerRange)
}
case 2: //nolint:mnd
for _, field := range fields {
_, err := strconv.Atoi(field)
if err != nil {
return fmt.Errorf("%w: %s value %s is not a valid range",
ErrHeaderRangeMalformed, name, headerRange)
return fmt.Errorf("header range is malformed: "+
"%s value %s is not a valid range", name, headerRange)
}
}
default:
return fmt.Errorf("%w: %s value %s must be in the form n or n-m",
ErrHeaderRangeMalformed, name, headerRange)
return fmt.Errorf("header range is malformed: "+
"%s value %s must be in the form n or n-m", name, headerRange)
}
}
+7 -13
View File
@@ -1,7 +1,6 @@
package settings
import (
"errors"
"fmt"
"net/netip"
"time"
@@ -48,22 +47,15 @@ type DNS struct {
UpstreamPlainAddresses []netip.AddrPort
}
var (
ErrDNSUpstreamTypeNotValid = errors.New("DNS upstream type is not valid")
ErrDNSUpdatePeriodTooShort = errors.New("update period is too short")
ErrDNSUpstreamPlainNoIPv6 = errors.New("upstream plain addresses do not contain any IPv6 address")
ErrDNSUpstreamPlainNoIPv4 = errors.New("upstream plain addresses do not contain any IPv4 address")
)
func (d DNS) validate() (err error) {
if !helpers.IsOneOf(d.UpstreamType, DNSUpstreamTypeDot, DNSUpstreamTypeDoh, DNSUpstreamTypePlain) {
return fmt.Errorf("%w: %s", ErrDNSUpstreamTypeNotValid, d.UpstreamType)
return fmt.Errorf("DNS upstream type is not valid: %s", d.UpstreamType)
}
const minUpdatePeriod = 30 * time.Second
if *d.UpdatePeriod != 0 && *d.UpdatePeriod < minUpdatePeriod {
return fmt.Errorf("%w: %s must be bigger than %s",
ErrDNSUpdatePeriodTooShort, *d.UpdatePeriod, minUpdatePeriod)
return fmt.Errorf("update period is too short: %s must be bigger than %s",
*d.UpdatePeriod, minUpdatePeriod)
}
if d.UpstreamType == DNSUpstreamTypePlain {
@@ -81,9 +73,11 @@ func (d DNS) validate() (err error) {
}
switch {
case *d.IPv6 && !selectedHasPlainIPv6:
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv6, len(d.UpstreamPlainAddresses))
return fmt.Errorf("upstream plain addresses do not contain any IPv6 address: "+
"in %d addresses", len(d.UpstreamPlainAddresses))
case !*d.IPv6 && !selectedHasPlainIPv4:
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv4, len(d.UpstreamPlainAddresses))
return fmt.Errorf("upstream plain addresses do not contain any IPv4 address: "+
"in %d addresses", len(d.UpstreamPlainAddresses))
}
}
// Note: all DNS built in providers have both IPv4 and IPv6 addresses for all modes
@@ -1,7 +1,6 @@
package settings
import (
"errors"
"fmt"
"net/http"
"net/netip"
@@ -37,22 +36,16 @@ func (b *DNSBlacklist) setDefaults() {
var hostRegex = regexp.MustCompile(`^([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9_])(\.([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9]))*$`) //nolint:lll
var (
ErrAllowedHostNotValid = errors.New("allowed host is not valid")
ErrBlockedHostNotValid = errors.New("blocked host is not valid")
ErrRebindingProtectionExemptHostNotValid = errors.New("rebinding protection exempt host is not valid")
)
func (b DNSBlacklist) validate() (err error) {
for _, host := range b.AllowedHosts {
if !hostRegex.MatchString(host) {
return fmt.Errorf("%w: %s", ErrAllowedHostNotValid, host)
return fmt.Errorf("allowed host is not valid: %s", host)
}
}
for _, host := range b.AddBlockedHosts {
if !hostRegex.MatchString(host) {
return fmt.Errorf("%w: %s", ErrBlockedHostNotValid, host)
return fmt.Errorf("blocked host is not valid: %s", host)
}
}
@@ -61,7 +54,7 @@ func (b DNSBlacklist) validate() (err error) {
host = host[2:]
}
if !hostRegex.MatchString(host) {
return fmt.Errorf("%w: %s", ErrRebindingProtectionExemptHostNotValid, host)
return fmt.Errorf("rebinding protection exempt host is not valid: %s", host)
}
}
@@ -209,8 +202,6 @@ func readDNSBlockedIPs(r *reader.Reader) (ips []netip.Addr,
return ips, ipPrefixes, nil
}
var ErrPrivateAddressNotValid = errors.New("private address is not a valid IP or CIDR range")
func readDNSPrivateAddresses(r *reader.Reader) (ips []netip.Addr,
ipPrefixes []netip.Prefix, err error,
) {
@@ -236,8 +227,9 @@ func readDNSPrivateAddresses(r *reader.Reader) (ips []netip.Addr,
}
return nil, nil, fmt.Errorf(
"environment variable DOT_PRIVATE_ADDRESS: %w: %s",
ErrPrivateAddressNotValid, privateAddress)
"environment variable DOT_PRIVATE_ADDRESS: "+
"private address is not a valid IP or CIDR range: %s",
privateAddress)
}
return ips, ipPrefixes, nil
-58
View File
@@ -1,58 +0,0 @@
package settings
import "errors"
var (
ErrValueUnknown = errors.New("value is unknown")
ErrCityNotValid = errors.New("the city specified is not valid")
ErrControlServerPrivilegedPort = errors.New("cannot use privileged port without running as root")
ErrCategoryNotValid = errors.New("the category specified is not valid")
ErrCountryNotValid = errors.New("the country specified is not valid")
ErrFilepathMissing = errors.New("filepath is missing")
ErrFirewallZeroPort = errors.New("cannot have a zero port")
ErrFirewallPublicOutboundSubnet = errors.New("outbound subnet has an unspecified address")
ErrHostnameNotValid = errors.New("the hostname specified is not valid")
ErrISPNotValid = errors.New("the ISP specified is not valid")
ErrMinRatioNotValid = errors.New("minimum ratio is not valid")
ErrMissingValue = errors.New("missing value")
ErrNameNotValid = errors.New("the server name specified is not valid")
ErrOpenVPNClientKeyMissing = errors.New("client key is missing")
ErrOpenVPNCustomPortNotAllowed = errors.New("custom endpoint port is not allowed")
ErrOpenVPNEncryptionPresetNotValid = errors.New("PIA encryption preset is not valid")
ErrOpenVPNInterfaceNotValid = errors.New("interface name is not valid")
ErrOpenVPNKeyPassphraseIsEmpty = errors.New("key passphrase is empty")
ErrOpenVPNMSSFixIsTooHigh = errors.New("mssfix option value is too high")
ErrOpenVPNPasswordIsEmpty = errors.New("password is empty")
ErrOpenVPNTCPNotSupported = errors.New("TCP protocol is not supported")
ErrOpenVPNUserIsEmpty = errors.New("user is empty")
ErrOpenVPNVerbosityIsOutOfBounds = errors.New("verbosity value is out of bounds")
ErrOpenVPNVersionIsNotValid = errors.New("version is not valid")
ErrPortForwardingEnabled = errors.New("port forwarding cannot be enabled")
ErrPortForwardingUserEmpty = errors.New("port forwarding username is empty")
ErrPortForwardingPasswordEmpty = errors.New("port forwarding password is empty")
ErrRegionNotValid = errors.New("the region specified is not valid")
ErrServerAddressNotValid = errors.New("server listening address is not valid")
ErrSystemPGIDNotValid = errors.New("process group id is not valid")
ErrSystemPUIDNotValid = errors.New("process user id is not valid")
ErrSystemTimezoneNotValid = errors.New("timezone is not valid")
ErrUpdaterPeriodTooSmall = errors.New("VPN server data updater period is too small")
ErrUpdaterProtonPasswordMissing = errors.New("proton password is missing")
ErrUpdaterProtonEmailMissing = errors.New("proton email is missing")
ErrVPNProviderNameNotValid = errors.New("VPN provider name is not valid")
ErrVPNTypeNotValid = errors.New("VPN type is not valid")
ErrWireguardAllowedIPNotSet = errors.New("allowed IP is not set")
ErrWireguardAllowedIPsNotSet = errors.New("allowed IPs is not set")
ErrWireguardEndpointIPNotSet = errors.New("endpoint IP is not set")
ErrWireguardEndpointPortNotAllowed = errors.New("endpoint port is not allowed")
ErrWireguardEndpointPortNotSet = errors.New("endpoint port is not set")
ErrWireguardEndpointPortSet = errors.New("endpoint port is set")
ErrWireguardInterfaceAddressNotSet = errors.New("interface address is not set")
ErrWireguardInterfaceAddressIPv6 = errors.New("interface address is IPv6 but IPv6 is not supported")
ErrWireguardInterfaceNotValid = errors.New("interface name is not valid")
ErrWireguardPreSharedKeyNotSet = errors.New("pre-shared key is not set")
ErrWireguardPrivateKeyNotSet = errors.New("private key is not set")
ErrWireguardPublicKeyNotSet = errors.New("public key is not set")
ErrWireguardPublicKeyNotValid = errors.New("public key is not valid")
ErrWireguardKeepAliveNegative = errors.New("persistent keep alive interval is negative")
ErrWireguardImplementationNotValid = errors.New("implementation is not valid")
)
+4 -3
View File
@@ -1,6 +1,7 @@
package settings
import (
"errors"
"fmt"
"net/netip"
@@ -20,16 +21,16 @@ type Firewall struct {
func (f Firewall) validate() (err error) {
if hasZeroPort(f.VPNInputPorts) {
return fmt.Errorf("VPN input ports: %w", ErrFirewallZeroPort)
return errors.New("VPN input ports: cannot have a zero port")
}
if hasZeroPort(f.InputPorts) {
return fmt.Errorf("input ports: %w", ErrFirewallZeroPort)
return errors.New("input ports: cannot have a zero port")
}
for _, subnet := range f.OutboundSubnets {
if subnet.Addr().IsUnspecified() {
return fmt.Errorf("%w: %s", ErrFirewallPublicOutboundSubnet, subnet)
return fmt.Errorf("outbound subnet has an unspecified address: %s", subnet)
}
}
@@ -13,25 +13,21 @@ func Test_Firewall_validate(t *testing.T) {
testCases := map[string]struct {
firewall Firewall
errWrapped error
errMessage string
}{
"empty": {
errWrapped: log.ErrLevelNotRecognized,
errMessage: "iptables settings: log level: level is not recognized: ",
},
"zero_vpn_input_port": {
firewall: Firewall{
VPNInputPorts: []uint16{0},
},
errWrapped: ErrFirewallZeroPort,
errMessage: "VPN input ports: cannot have a zero port",
},
"zero_input_port": {
firewall: Firewall{
InputPorts: []uint16{0},
},
errWrapped: ErrFirewallZeroPort,
errMessage: "input ports: cannot have a zero port",
},
"unspecified_outbound_subnet": {
@@ -40,7 +36,6 @@ func Test_Firewall_validate(t *testing.T) {
netip.MustParsePrefix("0.0.0.0/0"),
},
},
errWrapped: ErrFirewallPublicOutboundSubnet,
errMessage: "outbound subnet has an unspecified address: 0.0.0.0/0",
},
"public_outbound_subnet": {
@@ -70,9 +65,10 @@ func Test_Firewall_validate(t *testing.T) {
err := testCase.firewall.validate()
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
+4 -10
View File
@@ -38,12 +38,6 @@ type Health struct {
RestartVPN *bool
}
var (
ErrICMPTargetIPNotValid = errors.New("ICMP target IP address is not valid")
ErrICMPTargetIPsNotCompatible = errors.New("ICMP target IP addresses are not compatible")
ErrSmallCheckTypeNotValid = errors.New("small check type is not valid")
)
func (h Health) Validate() (err error) {
err = validate.ListeningAddress(h.ServerAddress, os.Getuid())
if err != nil {
@@ -53,16 +47,16 @@ func (h Health) Validate() (err error) {
for _, ip := range h.ICMPTargetIPs {
switch {
case !ip.IsValid():
return fmt.Errorf("%w: %s", ErrICMPTargetIPNotValid, ip)
return fmt.Errorf("ICMP target IP address is not valid: %s", ip)
case ip.IsUnspecified() && len(h.ICMPTargetIPs) > 1:
return fmt.Errorf("%w: only a single IP address must be set if it is to be unspecified",
ErrICMPTargetIPsNotCompatible)
return errors.New("ICMP target IP addresses are not compatible: " +
"only a single IP address must be set if it is to be unspecified")
}
}
err = validate.IsOneOf(h.SmallCheckType, "icmp", "dns")
if err != nil {
return fmt.Errorf("%w: %s", ErrSmallCheckTypeNotValid, err)
return fmt.Errorf("small check type is not valid: %w", err)
}
return nil
+2 -3
View File
@@ -48,7 +48,7 @@ func (h HTTPProxy) validate() (err error) {
// Do not validate user and password
err = validate.ListeningAddress(h.ListeningAddress, os.Getuid())
if err != nil {
return fmt.Errorf("%w: %s", ErrServerAddressNotValid, h.ListeningAddress)
return fmt.Errorf("server listening address is not valid: %w", err)
}
return nil
@@ -176,7 +176,6 @@ func readHTTProxyLog(r *reader.Reader) (enabled *bool, err error) {
case "disabled", "no", "off":
return ptrTo(false), nil
default:
return nil, fmt.Errorf("HTTP retro-compatible proxy log setting: %w: %s",
ErrValueUnknown, value)
return nil, fmt.Errorf("HTTP retro-compatible proxy log setting: value is unknown: %s", value)
}
}
+12 -14
View File
@@ -2,6 +2,7 @@ package settings
import (
"encoding/base64"
"errors"
"fmt"
"regexp"
"strings"
@@ -92,7 +93,7 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
// Validate version
validVersions := []string{openvpn.Openvpn25, openvpn.Openvpn26}
if err = validate.IsOneOf(o.Version, validVersions...); err != nil {
return fmt.Errorf("%w: %w", ErrOpenVPNVersionIsNotValid, err)
return fmt.Errorf("version is not valid: %w", err)
}
isCustom := vpnProvider == providers.Custom
@@ -101,14 +102,14 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
vpnProvider != providers.VPNSecure
if isUserRequired && *o.User == "" {
return fmt.Errorf("%w", ErrOpenVPNUserIsEmpty)
return errors.New("user is empty")
}
passwordRequired := isUserRequired &&
(vpnProvider != providers.Ivpn || !ivpnAccountID.MatchString(*o.User))
if passwordRequired && *o.Password == "" {
return fmt.Errorf("%w", ErrOpenVPNPasswordIsEmpty)
return errors.New("password is empty")
}
err = validateOpenVPNConfigFilepath(isCustom, *o.ConfFile)
@@ -132,23 +133,20 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
}
if *o.EncryptedKey != "" && *o.KeyPassphrase == "" {
return fmt.Errorf("%w", ErrOpenVPNKeyPassphraseIsEmpty)
return errors.New("key passphrase is empty")
}
const maxMSSFix = 10000
if *o.MSSFix > maxMSSFix {
return fmt.Errorf("%w: %d is over the maximum value of %d",
ErrOpenVPNMSSFixIsTooHigh, *o.MSSFix, maxMSSFix)
return fmt.Errorf("mssfix option value is too high: %d is over the maximum value of %d", *o.MSSFix, maxMSSFix)
}
if !regexpInterfaceName.MatchString(o.Interface) {
return fmt.Errorf("%w: '%s' does not match regex '%s'",
ErrOpenVPNInterfaceNotValid, o.Interface, regexpInterfaceName)
return fmt.Errorf("interface name is not valid: '%s' does not match regex '%s'", o.Interface, regexpInterfaceName)
}
if *o.Verbosity < 0 || *o.Verbosity > 6 {
return fmt.Errorf("%w: %d can only be between 0 and 5",
ErrOpenVPNVerbosityIsOutOfBounds, o.Verbosity)
return fmt.Errorf("verbosity value is out of bounds: %d can only be between 0 and 5", o.Verbosity)
}
return nil
@@ -162,7 +160,7 @@ func validateOpenVPNConfigFilepath(isCustom bool,
}
if confFile == "" {
return fmt.Errorf("%w", ErrFilepathMissing)
return errors.New("filepath is missing")
}
err = validate.FileExists(confFile)
@@ -189,7 +187,7 @@ func validateOpenVPNClientCertificate(vpnProvider,
providers.VPNSecure,
providers.VPNUnlimited:
if clientCert == "" {
return fmt.Errorf("%w", ErrMissingValue)
return errors.New("missing value")
}
}
@@ -211,7 +209,7 @@ func validateOpenVPNClientKey(vpnProvider, clientKey string) (err error) {
providers.Cyberghost,
providers.VPNUnlimited:
if clientKey == "" {
return fmt.Errorf("%w", ErrMissingValue)
return errors.New("missing value")
}
}
@@ -230,7 +228,7 @@ func validateOpenVPNEncryptedKey(vpnProvider,
encryptedPrivateKey string,
) (err error) {
if vpnProvider == providers.VPNSecure && encryptedPrivateKey == "" {
return fmt.Errorf("%w", ErrMissingValue)
return errors.New("missing value")
}
if encryptedPrivateKey == "" {
@@ -62,8 +62,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
providers.Perfectprivacy,
providers.Vyprvpn,
) {
return fmt.Errorf("%w: for VPN service provider %s",
ErrOpenVPNTCPNotSupported, vpnProvider)
return fmt.Errorf("TCP protocol is not supported: for VPN service provider %s", vpnProvider)
}
// Validate CustomPort
@@ -78,8 +77,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
providers.Nordvpn, providers.Purevpn,
providers.Surfshark, providers.VPNSecure,
providers.VPNUnlimited, providers.Vyprvpn:
return fmt.Errorf("%w: for VPN service provider %s",
ErrOpenVPNCustomPortNotAllowed, vpnProvider)
return fmt.Errorf("custom endpoint port is not allowed: for VPN service provider %s", vpnProvider)
default:
var allowedTCP, allowedUDP []uint16
switch vpnProvider {
@@ -123,8 +121,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
}
err = validate.IsOneOf(*o.CustomPort, allowedPorts...)
if err != nil {
return fmt.Errorf("%w: for VPN service provider %s: %w",
ErrOpenVPNCustomPortNotAllowed, vpnProvider, err)
return fmt.Errorf("custom endpoint port is not allowed: for VPN service provider %s: %w", vpnProvider, err)
}
}
}
@@ -136,7 +133,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
presets.Strong,
}
if err = validate.IsOneOf(*o.PIAEncPreset, validEncryptionPresets...); err != nil {
return fmt.Errorf("%w: %w", ErrOpenVPNEncryptionPresetNotValid, err)
return fmt.Errorf("PIA encryption preset is not valid: %w", err)
}
}
+2 -8
View File
@@ -1,7 +1,6 @@
package settings
import (
"errors"
"fmt"
"net/netip"
@@ -24,21 +23,16 @@ type PMTUD struct {
TCPAddresses []netip.AddrPort `json:"tcp_addresses"`
}
var (
ErrPMTUDICMPAddressNotValid = errors.New("PMTUD ICMP address is not valid")
ErrPMTUDTCPAddressNotValid = errors.New("PMTUD TCP address is not valid")
)
// Validate validates PMTUD settings.
func (p PMTUD) validate() (err error) {
for i, addr := range p.ICMPAddresses {
if !addr.IsValid() {
return fmt.Errorf("%w: at index %d", ErrPMTUDICMPAddressNotValid, i)
return fmt.Errorf("PMTUD ICMP address is not valid: at index %d", i)
}
}
for i, addr := range p.TCPAddresses {
if !addr.IsValid() {
return fmt.Errorf("%w: at index %d", ErrPMTUDTCPAddressNotValid, i)
return fmt.Errorf("PMTUD TCP address is not valid: at index %d", i)
}
}
return nil
+9 -14
View File
@@ -55,12 +55,6 @@ type PortForwarding struct {
Password string `json:"password"`
}
var (
ErrPortsCountTooHigh = errors.New("ports count too high")
ErrListeningPortsLen = errors.New("listening ports length must be equal to ports count")
ErrListeningPortZero = errors.New("listening port cannot be 0")
)
func (p PortForwarding) Validate(vpnProvider string) (err error) {
if !*p.Enabled {
return nil
@@ -78,7 +72,7 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) {
providers.Protonvpn,
}
if err = validate.IsOneOf(providerSelected, validProviders...); err != nil {
return fmt.Errorf("%w: %w", ErrPortForwardingEnabled, err)
return fmt.Errorf("port forwarding cannot be enabled: %w", err)
}
// Validate Filepath
@@ -94,30 +88,31 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) {
const maxPortsCount = 1
switch {
case p.PortsCount > maxPortsCount:
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount)
return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
case p.Username == "":
return fmt.Errorf("%w", ErrPortForwardingUserEmpty)
return errors.New("port forwarding username is empty")
case p.Password == "":
return fmt.Errorf("%w", ErrPortForwardingPasswordEmpty)
return errors.New("port forwarding password is empty")
}
case providers.Protonvpn:
const maxPortsCount = 4
if p.PortsCount > maxPortsCount {
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount)
return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
}
default:
const maxPortsCount = 1
if p.PortsCount > maxPortsCount {
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount)
return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
}
}
if !slices.Equal(p.ListeningPorts, []uint16{0}) {
switch {
case len(p.ListeningPorts) != int(p.PortsCount):
return fmt.Errorf("%w: %d != %d", ErrListeningPortsLen, len(p.ListeningPorts), p.PortsCount)
return fmt.Errorf("listening ports length must be equal to ports count: "+
"%d != %d", len(p.ListeningPorts), p.PortsCount)
case slices.Contains(p.ListeningPorts, 0):
return fmt.Errorf("%w: in %v", ErrListeningPortZero, p.ListeningPorts)
return fmt.Errorf("listening port cannot be 0: in %v", p.ListeningPorts)
}
}
+1 -1
View File
@@ -55,7 +55,7 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
}
}
if err = validate.IsOneOf(p.Name, validNames...); err != nil {
return fmt.Errorf("%w for %s: %w", ErrVPNProviderNameNotValid, vpnType, err)
return fmt.Errorf("VPN provider name is not valid for %s: %w", vpnType, err)
}
err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner)
@@ -15,7 +15,6 @@ func Test_PublicIP_read(t *testing.T) {
makeReader func(ctrl *gomock.Controller) *reader.Reader
makeWarner func(ctrl *gomock.Controller) Warner
settings PublicIP
errWrapped error
errMessage string
}{
"nothing_read": {
@@ -152,9 +151,10 @@ func Test_PublicIP_read(t *testing.T) {
err := settings.read(reader, warner)
assert.Equal(t, testCase.settings, settings)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
+1 -2
View File
@@ -46,8 +46,7 @@ func (c ControlServer) validate() (err error) {
uid := os.Getuid()
const maxPrivilegedPort = 1023
if uid != 0 && port != 0 && port <= maxPrivilegedPort {
return fmt.Errorf("%w: %d when running with user ID %d",
ErrControlServerPrivilegedPort, port, uid)
return fmt.Errorf("cannot use privileged port without running as root: %d when running with user ID %d", port, uid)
}
jsonDecoder := json.NewDecoder(bytes.NewBufferString(c.AuthDefaultRole))
@@ -71,25 +71,13 @@ type ServerSelection struct {
Wireguard WireguardSelection `json:"wireguard"`
}
var (
ErrOwnedOnlyNotSupported = errors.New("owned only filter is not supported")
ErrFreeOnlyNotSupported = errors.New("free only filter is not supported")
ErrPremiumOnlyNotSupported = errors.New("premium only filter is not supported")
ErrStreamOnlyNotSupported = errors.New("stream only filter is not supported")
ErrMultiHopOnlyNotSupported = errors.New("multi hop only filter is not supported")
ErrPortForwardOnlyNotSupported = errors.New("port forwarding only filter is not supported")
ErrFreePremiumBothSet = errors.New("free only and premium only filters are both set")
ErrSecureCoreOnlyNotSupported = errors.New("secure core only filter is not supported")
ErrTorOnlyNotSupported = errors.New("tor only filter is not supported")
)
func (ss *ServerSelection) validate(vpnServiceProvider string,
filterChoicesGetter FilterChoicesGetter, warner Warner,
) (err error) {
switch ss.VPN {
case vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard:
default:
return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN)
return fmt.Errorf("VPN type is not valid: %s", ss.VPN)
}
filterChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, filterChoicesGetter, warner)
@@ -150,7 +138,7 @@ func getLocationFilterChoices(vpnServiceProvider string,
// Only return error comparing with newer regions, we don't want to confuse the user
// with the retro regions in the error message.
err = atLeastOneIsOneOfCaseInsensitive(ss.Regions, filterChoices.Regions, warner)
return models.FilterChoices{}, fmt.Errorf("%w: %w", ErrRegionNotValid, err)
return models.FilterChoices{}, fmt.Errorf("the region specified is not valid: %w", err)
}
}
@@ -164,27 +152,27 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
) (err error) {
err = atLeastOneIsOneOfCaseInsensitive(settings.Countries, filterChoices.Countries, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrCountryNotValid, err)
return fmt.Errorf("the country specified is not valid: %w", err)
}
err = atLeastOneIsOneOfCaseInsensitive(settings.Regions, filterChoices.Regions, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrRegionNotValid, err)
return fmt.Errorf("the region specified is not valid: %w", err)
}
err = atLeastOneIsOneOfCaseInsensitive(settings.Cities, filterChoices.Cities, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrCityNotValid, err)
return fmt.Errorf("the city specified is not valid: %w", err)
}
err = atLeastOneIsOneOfCaseInsensitive(settings.ISPs, filterChoices.ISPs, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrISPNotValid, err)
return fmt.Errorf("the ISP specified is not valid: %w", err)
}
err = atLeastOneIsOneOfCaseInsensitive(settings.Hostnames, filterChoices.Hostnames, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrHostnameNotValid, err)
return fmt.Errorf("the hostname specified is not valid: %w", err)
}
if vpnServiceProvider == providers.Custom {
@@ -196,19 +184,19 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
// which requires a server name for TLS verification.
filterChoices.Names = settings.Names
default:
return fmt.Errorf("%w: %d names specified instead of "+
"0 or 1 for the custom provider",
ErrNameNotValid, len(settings.Names))
return fmt.Errorf("name is not valid: "+
"%d names specified instead of 0 or 1 for the custom provider",
len(settings.Names))
}
}
err = atLeastOneIsOneOfCaseInsensitive(settings.Names, filterChoices.Names, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrNameNotValid, err)
return fmt.Errorf("the server name specified is not valid: %w", err)
}
err = atLeastOneIsOneOfCaseInsensitive(settings.Categories, filterChoices.Categories, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrCategoryNotValid, err)
return fmt.Errorf("the category specified is not valid: %w", err)
}
return nil
@@ -255,12 +243,12 @@ func validateSubscriptionTierFilters(settings ServerSelection, vpnServiceProvide
switch {
case *settings.FreeOnly &&
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
return fmt.Errorf("%w", ErrFreeOnlyNotSupported)
return errors.New("free only filter is not supported")
case *settings.PremiumOnly &&
!helpers.IsOneOf(vpnServiceProvider, providers.VPNSecure):
return fmt.Errorf("%w", ErrPremiumOnlyNotSupported)
return errors.New("premium only filter is not supported")
case *settings.FreeOnly && *settings.PremiumOnly:
return fmt.Errorf("%w", ErrFreePremiumBothSet)
return errors.New("free only and premium only filters are both set")
default:
return nil
}
@@ -269,21 +257,21 @@ func validateSubscriptionTierFilters(settings ServerSelection, vpnServiceProvide
func validateFeatureFilters(settings ServerSelection, vpnServiceProvider string) error {
switch {
case *settings.OwnedOnly && vpnServiceProvider != providers.Mullvad:
return fmt.Errorf("%w", ErrOwnedOnlyNotSupported)
return errors.New("owned only filter is not supported")
case vpnServiceProvider == providers.Protonvpn && *settings.FreeOnly && *settings.PortForwardOnly:
return fmt.Errorf("%w: together with free only filter", ErrPortForwardOnlyNotSupported)
return errors.New("port forwarding only filter is not supported: together with free only filter")
case *settings.StreamOnly &&
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
return fmt.Errorf("%w", ErrStreamOnlyNotSupported)
return errors.New("stream only filter is not supported")
case *settings.MultiHopOnly && vpnServiceProvider != providers.Surfshark:
return fmt.Errorf("%w", ErrMultiHopOnlyNotSupported)
return errors.New("multi hop only filter is not supported")
case *settings.PortForwardOnly &&
!helpers.IsOneOf(vpnServiceProvider, providers.PrivateInternetAccess, providers.Protonvpn):
return fmt.Errorf("%w", ErrPortForwardOnlyNotSupported)
return errors.New("port forwarding only filter is not supported")
case *settings.SecureCoreOnly && vpnServiceProvider != providers.Protonvpn:
return fmt.Errorf("%w", ErrSecureCoreOnlyNotSupported)
return errors.New("secure core only filter is not supported")
case *settings.TorOnly && vpnServiceProvider != providers.Protonvpn:
return fmt.Errorf("%w", ErrTorOnlyNotSupported)
return errors.New("tor only filter is not supported")
default:
return nil
}
+8 -7
View File
@@ -1,6 +1,7 @@
package settings
import (
"errors"
"fmt"
"slices"
"strings"
@@ -37,20 +38,20 @@ type Updater struct {
func (u Updater) Validate() (err error) {
const minPeriod = time.Minute
if *u.Period > 0 && *u.Period < minPeriod {
return fmt.Errorf("%w: %d must be larger than %s",
ErrUpdaterPeriodTooSmall, *u.Period, minPeriod)
return fmt.Errorf("VPN server data updater period is too small: "+
"%d must be larger than %s", *u.Period, minPeriod)
}
if u.MinRatio <= 0 || u.MinRatio > 1 {
return fmt.Errorf("%w: %.2f must be between 0+ and 1",
ErrMinRatioNotValid, u.MinRatio)
return fmt.Errorf("minimum ratio is not valid: "+
"%.2f must be between 0+ and 1", u.MinRatio)
}
validProviders := providers.All()
for _, provider := range u.Providers {
err = validate.IsOneOf(provider, validProviders...)
if err != nil {
return fmt.Errorf("%w: %w", ErrVPNProviderNameNotValid, err)
return fmt.Errorf("VPN provider name is not valid: %w", err)
}
if provider == providers.Protonvpn {
@@ -58,9 +59,9 @@ func (u Updater) Validate() (err error) {
if authenticatedAPI {
switch {
case *u.ProtonEmail == "":
return fmt.Errorf("%w", ErrUpdaterProtonEmailMissing)
return errors.New("proton email is missing")
case *u.ProtonPassword == "":
return fmt.Errorf("%w", ErrUpdaterProtonPasswordMissing)
return errors.New("proton password is missing")
}
}
}
+1 -1
View File
@@ -37,7 +37,7 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo
// Validate Type
validVPNTypes := []string{vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard}
if err = validate.IsOneOf(v.Type, validVPNTypes...); err != nil {
return fmt.Errorf("%w: %w", ErrVPNTypeNotValid, err)
return fmt.Errorf("VPN type is not valid: %w", err)
}
err = v.Provider.validate(v.Type, filterChoicesGetter, warner)
+11 -15
View File
@@ -1,6 +1,7 @@
package settings
import (
"errors"
"fmt"
"net/netip"
"regexp"
@@ -54,7 +55,7 @@ var regexpInterfaceName = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (err error) {
// Validate PrivateKey
if *w.PrivateKey == "" {
return fmt.Errorf("%w", ErrWireguardPrivateKeyNotSet)
return errors.New("private key is not set")
}
_, err = wgtypes.ParseKey(*w.PrivateKey)
if err != nil {
@@ -68,7 +69,7 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (
if vpnProvider == providers.Airvpn {
if *w.PreSharedKey == "" {
return fmt.Errorf("%w", ErrWireguardPreSharedKeyNotSet)
return errors.New("pre-shared key is not set")
}
}
@@ -82,17 +83,15 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (
// Validate Addresses
if len(w.Addresses) == 0 {
return fmt.Errorf("%w", ErrWireguardInterfaceAddressNotSet)
return errors.New("interface address is not set")
}
for i, ipNet := range w.Addresses {
if !ipNet.IsValid() {
return fmt.Errorf("%w: for address at index %d",
ErrWireguardInterfaceAddressNotSet, i)
return fmt.Errorf("interface address is not set: for address at index %d", i)
}
if !ipv6Supported && ipNet.Addr().Is6() {
return fmt.Errorf("%w: address %s",
ErrWireguardInterfaceAddressIPv6, ipNet.String())
return fmt.Errorf("interface address is IPv6 but IPv6 is not supported: address %s", ipNet.String())
}
}
@@ -100,30 +99,27 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (
// WARNING: do not check for IPv6 networks in the allowed IPs,
// the wireguard code will take care to ignore it.
if len(w.AllowedIPs) == 0 {
return fmt.Errorf("%w", ErrWireguardAllowedIPsNotSet)
return errors.New("allowed IPs is not set")
}
for i, allowedIP := range w.AllowedIPs {
if !allowedIP.IsValid() {
return fmt.Errorf("%w: for allowed ip %d of %d",
ErrWireguardAllowedIPNotSet, i+1, len(w.AllowedIPs))
return fmt.Errorf("allowed IP is not set: for allowed ip %d of %d", i+1, len(w.AllowedIPs))
}
}
if *w.PersistentKeepaliveInterval < 0 {
return fmt.Errorf("%w: %s", ErrWireguardKeepAliveNegative,
*w.PersistentKeepaliveInterval)
return fmt.Errorf("persistent keep alive interval is negative: %s", *w.PersistentKeepaliveInterval)
}
// Validate interface
if !regexpInterfaceName.MatchString(w.Interface) {
return fmt.Errorf("%w: '%s' does not match regex '%s'",
ErrWireguardInterfaceNotValid, w.Interface, regexpInterfaceName)
return fmt.Errorf("interface name is not valid: '%s' does not match regex '%s'", w.Interface, regexpInterfaceName)
}
if !amneziawg { // amneziawg should have its own Implementation field and ignore this one
validImplementations := []string{"auto", "userspace", "kernelspace"}
if err := validate.IsOneOf(w.Implementation, validImplementations...); err != nil {
return fmt.Errorf("%w: %w", ErrWireguardImplementationNotValid, err)
return fmt.Errorf("implementation is not valid: %w", err)
}
}
@@ -1,6 +1,7 @@
package settings
import (
"errors"
"fmt"
"net/netip"
@@ -44,7 +45,7 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
// endpoint IP addresses are baked in
case providers.Custom:
if !w.EndpointIP.IsValid() || w.EndpointIP.IsUnspecified() {
return fmt.Errorf("%w", ErrWireguardEndpointIPNotSet)
return errors.New("endpoint IP is not set")
}
default: // Providers not supporting Wireguard
}
@@ -54,13 +55,13 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
// EndpointPort is required
case providers.Custom:
if *w.EndpointPort == 0 {
return fmt.Errorf("%w", ErrWireguardEndpointPortNotSet)
return errors.New("endpoint port is not set")
}
// EndpointPort cannot be set
case providers.Fastestvpn, providers.Nordvpn,
providers.Protonvpn, providers.Surfshark:
if *w.EndpointPort != 0 {
return fmt.Errorf("%w", ErrWireguardEndpointPortSet)
return errors.New("endpoint port is set")
}
case providers.Airvpn, providers.Ivpn, providers.Mullvad, providers.Windscribe:
// EndpointPort is optional and can be 0
@@ -84,8 +85,7 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
if err == nil {
break
}
return fmt.Errorf("%w: for VPN service provider %s: %w",
ErrWireguardEndpointPortNotAllowed, vpnProvider, err)
return fmt.Errorf("endpoint port is not allowed: for VPN service provider %s: %w", vpnProvider, err)
default: // Providers not supporting Wireguard
}
@@ -96,15 +96,14 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
// public keys are baked in
case providers.Custom:
if w.PublicKey == "" {
return fmt.Errorf("%w", ErrWireguardPublicKeyNotSet)
return errors.New("public key is not set")
}
default: // Providers not supporting Wireguard
}
if w.PublicKey != "" {
_, err := wgtypes.ParseKey(w.PublicKey)
if err != nil {
return fmt.Errorf("%w: %s: %s",
ErrWireguardPublicKeyNotValid, w.PublicKey, err)
return fmt.Errorf("public key is not valid: %s: %s", w.PublicKey, err)
}
}
@@ -74,8 +74,6 @@ func parseWireguardInterfaceSection(interfaceSection *ini.Section) (
return privateKey, addresses
}
var ErrEndpointHostNotIP = errors.New("endpoint host is not an IP")
func parseWireguardPeerSection(peerSection *ini.Section) (
preSharedKey, publicKey, endpointIP, endpointPort *string,
) {