chore!(amneziawg): refactor to be separate from wireguard

- amneziawg is now a VPN protocol and no longer a Wireguard implementation
- Use it with VPN_TYPE=amneziawg
- document AMNEZIAWG_* options in Dockerfile
- document amneziawg support in readme
- separate amneziawg settings and code from wireguard
- re-use code from wireguard whenever possible
This commit is contained in:
Quentin McGaw
2026-03-11 16:35:18 +00:00
parent efea169495
commit b04529c380
54 changed files with 1608 additions and 741 deletions
+139 -119
View File
@@ -12,33 +12,42 @@ import (
)
type AmneziaWg struct {
JunkPacketCount *uint16 `json:"junk_packet_count"`
JunkPacketMin *uint16 `json:"junk_packet_min"`
JunkPacketMax *uint16 `json:"junk_packet_max"`
PaddingS1 *uint16 `json:"padding_s1"`
PaddingS2 *uint16 `json:"padding_s2"`
PaddingS3 *uint16 `json:"padding_s3"`
PaddingS4 *uint16 `json:"padding_s4"`
HeaderH1 *string `json:"header_h1"`
HeaderH2 *string `json:"header_h2"`
HeaderH3 *string `json:"header_h3"`
HeaderH4 *string `json:"header_h4"`
InitPacketI1 *string `json:"init_packet_i1"`
InitPacketI2 *string `json:"init_packet_i2"`
InitPacketI3 *string `json:"init_packet_i3"`
InitPacketI4 *string `json:"init_packet_i4"`
InitPacketI5 *string `json:"init_packet_i5"`
// Wireguard contains the configuration for Wireguard, given
// AmneziaWg is based on Wireguard
Wireguard Wireguard `json:"wireguard"`
JunkPacketCount *uint16 `json:"junk_packet_count"`
JunkPacketMin *uint16 `json:"junk_packet_min"`
JunkPacketMax *uint16 `json:"junk_packet_max"`
PaddingS1 *uint16 `json:"padding_s1"`
PaddingS2 *uint16 `json:"padding_s2"`
PaddingS3 *uint16 `json:"padding_s3"`
PaddingS4 *uint16 `json:"padding_s4"`
HeaderH1 *string `json:"header_h1"`
HeaderH2 *string `json:"header_h2"`
HeaderH3 *string `json:"header_h3"`
HeaderH4 *string `json:"header_h4"`
InitPacketI1 *string `json:"init_packet_i1"`
InitPacketI2 *string `json:"init_packet_i2"`
InitPacketI3 *string `json:"init_packet_i3"`
InitPacketI4 *string `json:"init_packet_i4"`
InitPacketI5 *string `json:"init_packet_i5"`
}
func (s *AmneziaWg) read(r *reader.Reader) (err error) {
func (a *AmneziaWg) read(r *reader.Reader) (err error) {
const amneziawg = true
err = a.Wireguard.read(r, amneziawg)
if err != nil {
return err // do not wrap this error
}
uint16Fields := map[string]**uint16{
"AMNEZIAWG_JC": &s.JunkPacketCount,
"AMNEZIAWG_JMIN": &s.JunkPacketMin,
"AMNEZIAWG_JMAX": &s.JunkPacketMax,
"AMNEZIAWG_S1": &s.PaddingS1,
"AMNEZIAWG_S2": &s.PaddingS2,
"AMNEZIAWG_S3": &s.PaddingS3,
"AMNEZIAWG_S4": &s.PaddingS4,
"AMNEZIAWG_JC": &a.JunkPacketCount,
"AMNEZIAWG_JMIN": &a.JunkPacketMin,
"AMNEZIAWG_JMAX": &a.JunkPacketMax,
"AMNEZIAWG_S1": &a.PaddingS1,
"AMNEZIAWG_S2": &a.PaddingS2,
"AMNEZIAWG_S3": &a.PaddingS3,
"AMNEZIAWG_S4": &a.PaddingS4,
}
for key, dst := range uint16Fields {
*dst, err = r.Uint16Ptr(key)
@@ -47,15 +56,15 @@ func (s *AmneziaWg) read(r *reader.Reader) (err error) {
}
}
stringFields := map[string]**string{
"AMNEZIAWG_H1": &s.HeaderH1,
"AMNEZIAWG_H2": &s.HeaderH2,
"AMNEZIAWG_H3": &s.HeaderH3,
"AMNEZIAWG_H4": &s.HeaderH4,
"AMNEZIAWG_I1": &s.InitPacketI1,
"AMNEZIAWG_I2": &s.InitPacketI2,
"AMNEZIAWG_I3": &s.InitPacketI3,
"AMNEZIAWG_I4": &s.InitPacketI4,
"AMNEZIAWG_I5": &s.InitPacketI5,
"AMNEZIAWG_H1": &a.HeaderH1,
"AMNEZIAWG_H2": &a.HeaderH2,
"AMNEZIAWG_H3": &a.HeaderH3,
"AMNEZIAWG_H4": &a.HeaderH4,
"AMNEZIAWG_I1": &a.InitPacketI1,
"AMNEZIAWG_I2": &a.InitPacketI2,
"AMNEZIAWG_I3": &a.InitPacketI3,
"AMNEZIAWG_I4": &a.InitPacketI4,
"AMNEZIAWG_I5": &a.InitPacketI5,
}
opt := reader.ForceLowercase(false)
for key, dst := range stringFields {
@@ -64,80 +73,84 @@ func (s *AmneziaWg) read(r *reader.Reader) (err error) {
return nil
}
func (s AmneziaWg) copy() (copied AmneziaWg) {
func (a AmneziaWg) copy() (copied AmneziaWg) {
return AmneziaWg{
JunkPacketCount: gosettings.CopyPointer(s.JunkPacketCount),
JunkPacketMin: gosettings.CopyPointer(s.JunkPacketMin),
JunkPacketMax: gosettings.CopyPointer(s.JunkPacketMax),
PaddingS1: gosettings.CopyPointer(s.PaddingS1),
PaddingS2: gosettings.CopyPointer(s.PaddingS2),
PaddingS3: gosettings.CopyPointer(s.PaddingS3),
PaddingS4: gosettings.CopyPointer(s.PaddingS4),
HeaderH1: gosettings.CopyPointer(s.HeaderH1),
HeaderH2: gosettings.CopyPointer(s.HeaderH2),
HeaderH3: gosettings.CopyPointer(s.HeaderH3),
HeaderH4: gosettings.CopyPointer(s.HeaderH4),
InitPacketI1: gosettings.CopyPointer(s.InitPacketI1),
InitPacketI2: gosettings.CopyPointer(s.InitPacketI2),
InitPacketI3: gosettings.CopyPointer(s.InitPacketI3),
InitPacketI4: gosettings.CopyPointer(s.InitPacketI4),
InitPacketI5: gosettings.CopyPointer(s.InitPacketI5),
Wireguard: a.Wireguard.copy(),
JunkPacketCount: gosettings.CopyPointer(a.JunkPacketCount),
JunkPacketMin: gosettings.CopyPointer(a.JunkPacketMin),
JunkPacketMax: gosettings.CopyPointer(a.JunkPacketMax),
PaddingS1: gosettings.CopyPointer(a.PaddingS1),
PaddingS2: gosettings.CopyPointer(a.PaddingS2),
PaddingS3: gosettings.CopyPointer(a.PaddingS3),
PaddingS4: gosettings.CopyPointer(a.PaddingS4),
HeaderH1: gosettings.CopyPointer(a.HeaderH1),
HeaderH2: gosettings.CopyPointer(a.HeaderH2),
HeaderH3: gosettings.CopyPointer(a.HeaderH3),
HeaderH4: gosettings.CopyPointer(a.HeaderH4),
InitPacketI1: gosettings.CopyPointer(a.InitPacketI1),
InitPacketI2: gosettings.CopyPointer(a.InitPacketI2),
InitPacketI3: gosettings.CopyPointer(a.InitPacketI3),
InitPacketI4: gosettings.CopyPointer(a.InitPacketI4),
InitPacketI5: gosettings.CopyPointer(a.InitPacketI5),
}
}
//nolint:dupl
func (s *AmneziaWg) overrideWith(other AmneziaWg) {
s.JunkPacketCount = gosettings.OverrideWithPointer(s.JunkPacketCount, other.JunkPacketCount)
s.JunkPacketMin = gosettings.OverrideWithPointer(s.JunkPacketMin, other.JunkPacketMin)
s.JunkPacketMax = gosettings.OverrideWithPointer(s.JunkPacketMax, other.JunkPacketMax)
s.PaddingS1 = gosettings.OverrideWithPointer(s.PaddingS1, other.PaddingS1)
s.PaddingS2 = gosettings.OverrideWithPointer(s.PaddingS2, other.PaddingS2)
s.PaddingS3 = gosettings.OverrideWithPointer(s.PaddingS3, other.PaddingS3)
s.PaddingS4 = gosettings.OverrideWithPointer(s.PaddingS4, other.PaddingS4)
s.HeaderH1 = gosettings.OverrideWithPointer(s.HeaderH1, other.HeaderH1)
s.HeaderH2 = gosettings.OverrideWithPointer(s.HeaderH2, other.HeaderH2)
s.HeaderH3 = gosettings.OverrideWithPointer(s.HeaderH3, other.HeaderH3)
s.HeaderH4 = gosettings.OverrideWithPointer(s.HeaderH4, other.HeaderH4)
s.InitPacketI1 = gosettings.OverrideWithPointer(s.InitPacketI1, other.InitPacketI1)
s.InitPacketI2 = gosettings.OverrideWithPointer(s.InitPacketI2, other.InitPacketI2)
s.InitPacketI3 = gosettings.OverrideWithPointer(s.InitPacketI3, other.InitPacketI3)
s.InitPacketI4 = gosettings.OverrideWithPointer(s.InitPacketI4, other.InitPacketI4)
s.InitPacketI5 = gosettings.OverrideWithPointer(s.InitPacketI5, other.InitPacketI5)
func (a *AmneziaWg) overrideWith(other AmneziaWg) {
a.Wireguard.overrideWith(other.Wireguard)
a.JunkPacketCount = gosettings.OverrideWithPointer(a.JunkPacketCount, other.JunkPacketCount)
a.JunkPacketMin = gosettings.OverrideWithPointer(a.JunkPacketMin, other.JunkPacketMin)
a.JunkPacketMax = gosettings.OverrideWithPointer(a.JunkPacketMax, other.JunkPacketMax)
a.PaddingS1 = gosettings.OverrideWithPointer(a.PaddingS1, other.PaddingS1)
a.PaddingS2 = gosettings.OverrideWithPointer(a.PaddingS2, other.PaddingS2)
a.PaddingS3 = gosettings.OverrideWithPointer(a.PaddingS3, other.PaddingS3)
a.PaddingS4 = gosettings.OverrideWithPointer(a.PaddingS4, other.PaddingS4)
a.HeaderH1 = gosettings.OverrideWithPointer(a.HeaderH1, other.HeaderH1)
a.HeaderH2 = gosettings.OverrideWithPointer(a.HeaderH2, other.HeaderH2)
a.HeaderH3 = gosettings.OverrideWithPointer(a.HeaderH3, other.HeaderH3)
a.HeaderH4 = gosettings.OverrideWithPointer(a.HeaderH4, other.HeaderH4)
a.InitPacketI1 = gosettings.OverrideWithPointer(a.InitPacketI1, other.InitPacketI1)
a.InitPacketI2 = gosettings.OverrideWithPointer(a.InitPacketI2, other.InitPacketI2)
a.InitPacketI3 = gosettings.OverrideWithPointer(a.InitPacketI3, other.InitPacketI3)
a.InitPacketI4 = gosettings.OverrideWithPointer(a.InitPacketI4, other.InitPacketI4)
a.InitPacketI5 = gosettings.OverrideWithPointer(a.InitPacketI5, other.InitPacketI5)
}
func (s *AmneziaWg) setDefaults() {
s.JunkPacketCount = gosettings.DefaultPointer(s.JunkPacketCount, 0)
s.JunkPacketMin = gosettings.DefaultPointer(s.JunkPacketMin, 0)
s.JunkPacketMax = gosettings.DefaultPointer(s.JunkPacketMax, 0)
s.PaddingS1 = gosettings.DefaultPointer(s.PaddingS1, 0)
s.PaddingS2 = gosettings.DefaultPointer(s.PaddingS2, 0)
s.PaddingS3 = gosettings.DefaultPointer(s.PaddingS3, 0)
s.PaddingS4 = gosettings.DefaultPointer(s.PaddingS4, 0)
s.HeaderH1 = gosettings.DefaultPointer(s.HeaderH1, "")
s.HeaderH2 = gosettings.DefaultPointer(s.HeaderH2, "")
s.HeaderH3 = gosettings.DefaultPointer(s.HeaderH3, "")
s.HeaderH4 = gosettings.DefaultPointer(s.HeaderH4, "")
s.InitPacketI1 = gosettings.DefaultPointer(s.InitPacketI1, "")
s.InitPacketI2 = gosettings.DefaultPointer(s.InitPacketI2, "")
s.InitPacketI3 = gosettings.DefaultPointer(s.InitPacketI3, "")
s.InitPacketI4 = gosettings.DefaultPointer(s.InitPacketI4, "")
s.InitPacketI5 = gosettings.DefaultPointer(s.InitPacketI5, "")
func (a *AmneziaWg) setDefaults(vpnProvider string) {
a.Wireguard.setDefaults(vpnProvider)
a.Wireguard.Implementation = "userspace" // unused except in logs
a.JunkPacketCount = gosettings.DefaultPointer(a.JunkPacketCount, 0)
a.JunkPacketMin = gosettings.DefaultPointer(a.JunkPacketMin, 0)
a.JunkPacketMax = gosettings.DefaultPointer(a.JunkPacketMax, 0)
a.PaddingS1 = gosettings.DefaultPointer(a.PaddingS1, 0)
a.PaddingS2 = gosettings.DefaultPointer(a.PaddingS2, 0)
a.PaddingS3 = gosettings.DefaultPointer(a.PaddingS3, 0)
a.PaddingS4 = gosettings.DefaultPointer(a.PaddingS4, 0)
a.HeaderH1 = gosettings.DefaultPointer(a.HeaderH1, "")
a.HeaderH2 = gosettings.DefaultPointer(a.HeaderH2, "")
a.HeaderH3 = gosettings.DefaultPointer(a.HeaderH3, "")
a.HeaderH4 = gosettings.DefaultPointer(a.HeaderH4, "")
a.InitPacketI1 = gosettings.DefaultPointer(a.InitPacketI1, "")
a.InitPacketI2 = gosettings.DefaultPointer(a.InitPacketI2, "")
a.InitPacketI3 = gosettings.DefaultPointer(a.InitPacketI3, "")
a.InitPacketI4 = gosettings.DefaultPointer(a.InitPacketI4, "")
a.InitPacketI5 = gosettings.DefaultPointer(a.InitPacketI5, "")
}
func (s AmneziaWg) toLinesNode() (node *gotree.Node) {
node = gotree.New("Amneziawg parameters:")
func (a AmneziaWg) toLinesNode() (node *gotree.Node) {
node = gotree.New("AmneziaWG settings:")
node.AppendNode(a.Wireguard.toLinesNode())
uintFields := []struct {
key string
val *uint16
}{
{"jc", s.JunkPacketCount},
{"jmin", s.JunkPacketMin},
{"jmax", s.JunkPacketMax},
{"s1", s.PaddingS1},
{"s2", s.PaddingS2},
{"s3", s.PaddingS3},
{"s4", s.PaddingS4},
{"JC", a.JunkPacketCount},
{"JMIN", a.JunkPacketMin},
{"JMAX", a.JunkPacketMax},
{"S1", a.PaddingS1},
{"S2", a.PaddingS2},
{"S3", a.PaddingS3},
{"S4", a.PaddingS4},
}
for _, f := range uintFields {
node.Appendf("%s: %d", f.key, *f.val)
@@ -147,15 +160,15 @@ func (s AmneziaWg) toLinesNode() (node *gotree.Node) {
key string
val *string
}{
{"h1", s.HeaderH1},
{"h2", s.HeaderH2},
{"h3", s.HeaderH3},
{"h4", s.HeaderH4},
{"i1", s.InitPacketI1},
{"i2", s.InitPacketI2},
{"i3", s.InitPacketI3},
{"i4", s.InitPacketI4},
{"i5", s.InitPacketI5},
{"H1", a.HeaderH1},
{"H2", a.HeaderH2},
{"H3", a.HeaderH3},
{"H4", a.HeaderH4},
{"I1", a.InitPacketI1},
{"I2", a.InitPacketI2},
{"I3", a.InitPacketI3},
{"I4", a.InitPacketI4},
{"I5", a.InitPacketI5},
}
for _, f := range stringFields {
node.Appendf("%s: %s", f.key, *f.val)
@@ -165,33 +178,40 @@ func (s AmneziaWg) toLinesNode() (node *gotree.Node) {
}
var (
ErrJunkPacketBounds = errors.New("junk packet minimum must be lower than or equal to maximum")
ErrJunkPacketMinMaxNotSet = errors.New("junk packet min and max must be set when junk packet count is set")
ErrJunkPacketCountNotSet = errors.New("junk packet count must be set when junk packet min or max is set")
ErrHeaderRangeMalformed = errors.New("header range is malformed")
ErrAmenziawgImplementationNotValid = errors.New("AmneziaWG implementation is not valid")
ErrJunkPacketBounds = errors.New("junk packet minimum must be lower than or equal to maximum")
ErrJunkPacketMinMaxNotSet = errors.New("junk packet min and max must be set when junk packet count is set")
ErrJunkPacketCountNotSet = errors.New("junk packet count must be set when junk packet min or max is set")
ErrHeaderRangeMalformed = errors.New("header range is malformed")
)
func (s AmneziaWg) validate() error {
if *s.JunkPacketCount == 0 {
if *s.JunkPacketMin != 0 || *s.JunkPacketMax != 0 {
func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
const amneziaWG = true
err := a.Wireguard.validate(vpnProvider, ipv6Supported, amneziaWG)
if err != nil {
return fmt.Errorf("wireguard settings: %w", err)
}
if *a.JunkPacketCount == 0 {
if *a.JunkPacketMin != 0 || *a.JunkPacketMax != 0 {
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
ErrJunkPacketCountNotSet, s.JunkPacketCount, *s.JunkPacketMin, *s.JunkPacketMax)
ErrJunkPacketCountNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
}
} else {
if *s.JunkPacketMin == 0 || *s.JunkPacketMax == 0 {
if *a.JunkPacketMin == 0 || *a.JunkPacketMax == 0 {
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
ErrJunkPacketMinMaxNotSet, s.JunkPacketCount, *s.JunkPacketMin, *s.JunkPacketMax)
} else if *s.JunkPacketMin > *s.JunkPacketMax {
ErrJunkPacketMinMaxNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
} else if *a.JunkPacketMin > *a.JunkPacketMax {
return fmt.Errorf("%w: jmin=%d and jmax=%d",
ErrJunkPacketBounds, *s.JunkPacketMin, *s.JunkPacketMax)
ErrJunkPacketBounds, *a.JunkPacketMin, *a.JunkPacketMax)
}
}
nameToHeaderRange := map[string]string{
"h1": *s.HeaderH1,
"h2": *s.HeaderH2,
"h3": *s.HeaderH3,
"h4": *s.HeaderH4,
"h1": *a.HeaderH1,
"h2": *a.HeaderH2,
"h3": *a.HeaderH3,
"h4": *a.HeaderH4,
}
for name, headerRange := range nameToHeaderRange {
if headerRange == "" {
@@ -268,8 +268,6 @@ func (o *OpenVPN) copy() (copied OpenVPN) {
// overrideWith overrides fields of the receiver
// settings object with any field set in the other
// settings.
//
//nolint:dupl
func (o *OpenVPN) overrideWith(other OpenVPN) {
o.Version = gosettings.OverrideWithComparable(o.Version, other.Version)
o.User = gosettings.OverrideWithPointer(o.User, other.User)
+6 -3
View File
@@ -30,7 +30,10 @@ type Provider struct {
func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGetter, warner Warner) (err error) {
// Validate Name
var validNames []string
if vpnType == vpn.OpenVPN {
switch vpnType {
case vpn.AmneziaWg:
validNames = []string{providers.Custom}
case vpn.OpenVPN:
validNames = providers.AllWithCustom()
validNames = append(validNames, "pia") // Retro-compatibility
// Remove Mullvad since it no longer supports OpenVPN as of January 15th, 2026
@@ -38,7 +41,7 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
validNames[mullvadIndex], validNames[len(validNames)-1] = validNames[len(validNames)-1], validNames[mullvadIndex]
validNames = validNames[:len(validNames)-1]
sort.Strings(validNames)
} else { // Wireguard
case vpn.Wireguard:
validNames = []string{
providers.Airvpn,
providers.Custom,
@@ -52,7 +55,7 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
}
}
if err = validate.IsOneOf(p.Name, validNames...); err != nil {
return fmt.Errorf("%w for Wireguard: %w", ErrVPNProviderNameNotValid, err)
return fmt.Errorf("%w for %s: %w", ErrVPNProviderNameNotValid, vpnType, err)
}
err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner)
@@ -87,7 +87,7 @@ func (ss *ServerSelection) validate(vpnServiceProvider string,
filterChoicesGetter FilterChoicesGetter, warner Warner,
) (err error) {
switch ss.VPN {
case vpn.OpenVPN, vpn.Wireguard:
case vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard:
default:
return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN)
}
+29 -7
View File
@@ -16,6 +16,7 @@ type VPN struct {
// empty string in the internal state.
Type string `json:"type"`
Provider Provider `json:"provider"`
AmneziaWg AmneziaWg `json:"amneziawg"`
OpenVPN OpenVPN `json:"openvpn"`
Wireguard Wireguard `json:"wireguard"`
PMTUD PMTUD `json:"pmtud"`
@@ -29,10 +30,12 @@ type VPN struct {
DownCommand *string `json:"down_command"`
}
// Validate validates VPN settings, using the filter choices getter (aka servers data storage),
// and if IPv6 is supported or not.
// TODO v4 remove pointer for receiver (because of Surfshark).
func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bool, warner Warner) (err error) {
// Validate Type
validVPNTypes := []string{vpn.OpenVPN, vpn.Wireguard}
validVPNTypes := []string{vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard}
if err = validate.IsOneOf(v.Type, validVPNTypes...); err != nil {
return fmt.Errorf("%w: %w", ErrVPNTypeNotValid, err)
}
@@ -42,13 +45,20 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo
return fmt.Errorf("provider settings: %w", err)
}
if v.Type == vpn.OpenVPN {
switch v.Type {
case vpn.AmneziaWg:
err = v.AmneziaWg.validate(v.Provider.Name, ipv6Supported)
if err != nil {
return fmt.Errorf("AmneziaWG settings: %w", err)
}
case vpn.OpenVPN:
err := v.OpenVPN.validate(v.Provider.Name)
if err != nil {
return fmt.Errorf("OpenVPN settings: %w", err)
}
} else {
err := v.Wireguard.validate(v.Provider.Name, ipv6Supported)
case vpn.Wireguard:
const amneziawg = false
err := v.Wireguard.validate(v.Provider.Name, ipv6Supported, amneziawg)
if err != nil {
return fmt.Errorf("Wireguard settings: %w", err)
}
@@ -66,6 +76,7 @@ func (v *VPN) Copy() (copied VPN) {
return VPN{
Type: v.Type,
Provider: v.Provider.copy(),
AmneziaWg: v.AmneziaWg.copy(),
OpenVPN: v.OpenVPN.copy(),
Wireguard: v.Wireguard.copy(),
PMTUD: v.PMTUD.copy(),
@@ -77,6 +88,7 @@ func (v *VPN) Copy() (copied VPN) {
func (v *VPN) OverrideWith(other VPN) {
v.Type = gosettings.OverrideWithComparable(v.Type, other.Type)
v.Provider.overrideWith(other.Provider)
v.AmneziaWg.overrideWith(other.AmneziaWg)
v.OpenVPN.overrideWith(other.OpenVPN)
v.Wireguard.overrideWith(other.Wireguard)
v.PMTUD.overrideWith(other.PMTUD)
@@ -87,6 +99,7 @@ func (v *VPN) OverrideWith(other VPN) {
func (v *VPN) setDefaults() {
v.Type = gosettings.DefaultComparable(v.Type, vpn.OpenVPN)
v.Provider.setDefaults()
v.AmneziaWg.setDefaults(v.Provider.Name)
v.OpenVPN.setDefaults(v.Provider.Name)
v.Wireguard.setDefaults(v.Provider.Name)
v.PMTUD.setDefaults()
@@ -103,9 +116,12 @@ func (v VPN) toLinesNode() (node *gotree.Node) {
node.AppendNode(v.Provider.toLinesNode())
if v.Type == vpn.OpenVPN {
switch v.Type {
case vpn.AmneziaWg:
node.AppendNode(v.AmneziaWg.toLinesNode())
case vpn.OpenVPN:
node.AppendNode(v.OpenVPN.toLinesNode())
} else {
case vpn.Wireguard:
node.AppendNode(v.Wireguard.toLinesNode())
}
node.AppendNode(v.PMTUD.toLinesNode())
@@ -128,12 +144,18 @@ func (v *VPN) read(r *reader.Reader) (err error) {
return fmt.Errorf("VPN provider: %w", err)
}
err = v.AmneziaWg.read(r)
if err != nil {
return fmt.Errorf("AmneziaWG: %w", err)
}
err = v.OpenVPN.read(r)
if err != nil {
return fmt.Errorf("OpenVPN: %w", err)
}
err = v.Wireguard.read(r)
const amneziawg = false
err = v.Wireguard.read(r, amneziawg)
if err != nil {
return fmt.Errorf("wireguard: %w", err)
}
+23 -49
View File
@@ -7,7 +7,6 @@ import (
"strings"
"time"
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gosettings"
"github.com/qdm12/gosettings/reader"
@@ -42,34 +41,17 @@ type Wireguard struct {
// 0 indicating to use PMTUD.
MTU *uint32 `json:"mtu"`
// Implementation is the Wireguard implementation to use.
// It can be "auto", "userspace", "kernelspace" or "amneziawg".
// It can be "auto", "userspace" or "kernelspace".
// It defaults to "auto" and cannot be the empty string
// in the internal state.
Implementation string `json:"implementation"`
// AmneziaWG contains obfuscation parameters
AmneziaWG AmneziaWg `json:"amneziawg"`
}
var regexpInterfaceName = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
// Validate validates Wireguard settings.
// It should only be ran if the VPN type chosen is Wireguard.
func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error) {
if !helpers.IsOneOf(vpnProvider,
providers.Airvpn,
providers.Custom,
providers.Fastestvpn,
providers.Ivpn,
providers.Mullvad,
providers.Nordvpn,
providers.Protonvpn,
providers.Surfshark,
providers.Windscribe,
) {
// do not validate for VPN provider not supporting Wireguard
return nil
}
// It should only be ran if the VPN type chosen is Wireguard or AmneziaWg.
func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (err error) {
// Validate PrivateKey
if *w.PrivateKey == "" {
return fmt.Errorf("%w", ErrWireguardPrivateKeyNotSet)
@@ -138,14 +120,11 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error)
ErrWireguardInterfaceNotValid, w.Interface, regexpInterfaceName)
}
validImplementations := []string{"auto", "userspace", "kernelspace", "amneziawg"}
if err := validate.IsOneOf(w.Implementation, validImplementations...); err != nil {
return fmt.Errorf("%w: %w", ErrWireguardImplementationNotValid, err)
}
err = w.AmneziaWG.validate()
if err != nil {
return fmt.Errorf("amneziawg settings: %w", err)
if !amneziawg { // amneziawg should have its own Implementation field and ignore this one
validImplementations := []string{"auto", "userspace", "kernelspace"}
if err := validate.IsOneOf(w.Implementation, validImplementations...); err != nil {
return fmt.Errorf("%w: %w", ErrWireguardImplementationNotValid, err)
}
}
return nil
@@ -161,7 +140,6 @@ func (w *Wireguard) copy() (copied Wireguard) {
Interface: w.Interface,
MTU: w.MTU,
Implementation: w.Implementation,
AmneziaWG: w.AmneziaWG.copy(),
}
}
@@ -175,7 +153,6 @@ func (w *Wireguard) overrideWith(other Wireguard) {
w.Interface = gosettings.OverrideWithComparable(w.Interface, other.Interface)
w.MTU = gosettings.OverrideWithComparable(w.MTU, other.MTU)
w.Implementation = gosettings.OverrideWithComparable(w.Implementation, other.Implementation)
w.AmneziaWG.overrideWith(other.AmneziaWG)
}
func (w *Wireguard) setDefaults(vpnProvider string) {
@@ -200,7 +177,6 @@ func (w *Wireguard) setDefaults(vpnProvider string) {
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
w.MTU = gosettings.DefaultPointer(w.MTU, 0)
w.Implementation = gosettings.DefaultComparable(w.Implementation, "auto")
w.AmneziaWG.setDefaults()
}
func (w Wireguard) String() string {
@@ -242,29 +218,27 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
}
if w.Implementation != "auto" {
implNode := node.Appendf("Implementation: %s", w.Implementation)
if w.Implementation == "amneziawg" {
implNode.AppendNode(w.AmneziaWG.toLinesNode())
}
node.Appendf("Implementation: %s", w.Implementation)
}
return node
}
func (w *Wireguard) read(r *reader.Reader) (err error) {
w.PrivateKey = r.Get("WIREGUARD_PRIVATE_KEY", reader.ForceLowercase(false))
w.PreSharedKey = r.Get("WIREGUARD_PRESHARED_KEY", reader.ForceLowercase(false))
func (w *Wireguard) read(r *reader.Reader, amneziaWG bool) (err error) {
prefix := "WIREGUARD"
if amneziaWG {
prefix = "AMNEZIAWG"
}
w.PrivateKey = r.Get(prefix+"_PRIVATE_KEY", reader.ForceLowercase(false))
w.PreSharedKey = r.Get(prefix+"_PRESHARED_KEY", reader.ForceLowercase(false))
w.Interface = r.String("VPN_INTERFACE",
reader.RetroKeys("WIREGUARD_INTERFACE"), reader.ForceLowercase(false))
w.Implementation = r.String("WIREGUARD_IMPLEMENTATION")
reader.RetroKeys(prefix+"_INTERFACE"), reader.ForceLowercase(false))
err = w.AmneziaWG.read(r)
if err != nil {
return err
if !amneziaWG {
w.Implementation = r.String("WIREGUARD_IMPLEMENTATION")
}
addressStrings := r.CSV("WIREGUARD_ADDRESSES", reader.RetroKeys("WIREGUARD_ADDRESS"))
addressStrings := r.CSV(prefix+"_ADDRESSES", reader.RetroKeys(prefix+"_ADDRESS"))
// WARNING: do not initialize w.Addresses to an empty slice
// or the defaults for nordvpn will not work.
for _, addressString := range addressStrings {
@@ -279,17 +253,17 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
w.Addresses = append(w.Addresses, address)
}
w.AllowedIPs, err = r.CSVNetipPrefixes("WIREGUARD_ALLOWED_IPS")
w.AllowedIPs, err = r.CSVNetipPrefixes(prefix + "_ALLOWED_IPS")
if err != nil {
return err // already wrapped
}
w.PersistentKeepaliveInterval, err = r.DurationPtr("WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL")
w.PersistentKeepaliveInterval, err = r.DurationPtr(prefix + "_PERSISTENT_KEEPALIVE_INTERVAL")
if err != nil {
return err
}
w.MTU, err = r.Uint32Ptr("WIREGUARD_MTU")
w.MTU, err = r.Uint32Ptr(prefix + "_MTU")
if err != nil {
return err
}