feat(wireguard): amneziawg implementation (#3150)

This commit is contained in:
Zhurik
2026-03-11 16:55:28 +03:00
committed by GitHub
parent f4eeffe79a
commit e6fc792f4f
20 changed files with 635 additions and 68 deletions
@@ -0,0 +1,228 @@
package settings
import (
"errors"
"fmt"
"strconv"
"strings"
"github.com/qdm12/gosettings"
"github.com/qdm12/gosettings/reader"
"github.com/qdm12/gotree"
)
type AmneziaWg struct {
JunkPacketCount *uint16 `json:"junk_packet_count"`
JunkPacketMin *uint16 `json:"junk_packet_min"`
JunkPacketMax *uint16 `json:"junk_packet_max"`
PaddingS1 *uint16 `json:"padding_s1"`
PaddingS2 *uint16 `json:"padding_s2"`
PaddingS3 *uint16 `json:"padding_s3"`
PaddingS4 *uint16 `json:"padding_s4"`
HeaderH1 *string `json:"header_h1"`
HeaderH2 *string `json:"header_h2"`
HeaderH3 *string `json:"header_h3"`
HeaderH4 *string `json:"header_h4"`
InitPacketI1 *string `json:"init_packet_i1"`
InitPacketI2 *string `json:"init_packet_i2"`
InitPacketI3 *string `json:"init_packet_i3"`
InitPacketI4 *string `json:"init_packet_i4"`
InitPacketI5 *string `json:"init_packet_i5"`
}
func (s *AmneziaWg) read(r *reader.Reader) error {
uint16Fields := map[string]*uint16{
"AMNEZIAWG_JC": s.JunkPacketCount,
"AMNEZIAWG_JMIN": s.JunkPacketMin,
"AMNEZIAWG_JMAX": s.JunkPacketMax,
"AMNEZIAWG_S1": s.PaddingS1,
"AMNEZIAWG_S2": s.PaddingS2,
"AMNEZIAWG_S3": s.PaddingS3,
"AMNEZIAWG_S4": s.PaddingS4,
}
for key, dst := range uint16Fields {
v, err := r.Uint16Ptr(key)
if err != nil {
return err
} else if v != nil {
*dst = *v
}
}
stringFields := map[string]*string{
"AMNEZIAWG_H1": s.HeaderH1,
"AMNEZIAWG_H2": s.HeaderH2,
"AMNEZIAWG_H3": s.HeaderH3,
"AMNEZIAWG_H4": s.HeaderH4,
"AMNEZIAWG_I1": s.InitPacketI1,
"AMNEZIAWG_I2": s.InitPacketI2,
"AMNEZIAWG_I3": s.InitPacketI3,
"AMNEZIAWG_I4": s.InitPacketI4,
"AMNEZIAWG_I5": s.InitPacketI5,
}
opt := reader.ForceLowercase(false)
for key, dst := range stringFields {
v := r.Get(key, opt)
if v != nil {
*dst = *v
}
}
return nil
}
func (s AmneziaWg) copy() (copied AmneziaWg) {
return AmneziaWg{
JunkPacketCount: gosettings.CopyPointer(s.JunkPacketCount),
JunkPacketMin: gosettings.CopyPointer(s.JunkPacketMin),
JunkPacketMax: gosettings.CopyPointer(s.JunkPacketMax),
PaddingS1: gosettings.CopyPointer(s.PaddingS1),
PaddingS2: gosettings.CopyPointer(s.PaddingS2),
PaddingS3: gosettings.CopyPointer(s.PaddingS3),
PaddingS4: gosettings.CopyPointer(s.PaddingS4),
HeaderH1: gosettings.CopyPointer(s.HeaderH1),
HeaderH2: gosettings.CopyPointer(s.HeaderH2),
HeaderH3: gosettings.CopyPointer(s.HeaderH3),
HeaderH4: gosettings.CopyPointer(s.HeaderH4),
InitPacketI1: gosettings.CopyPointer(s.InitPacketI1),
InitPacketI2: gosettings.CopyPointer(s.InitPacketI2),
InitPacketI3: gosettings.CopyPointer(s.InitPacketI3),
InitPacketI4: gosettings.CopyPointer(s.InitPacketI4),
InitPacketI5: gosettings.CopyPointer(s.InitPacketI5),
}
}
//nolint:dupl
func (s *AmneziaWg) overrideWith(other AmneziaWg) {
s.JunkPacketCount = gosettings.OverrideWithPointer(s.JunkPacketCount, other.JunkPacketCount)
s.JunkPacketMin = gosettings.OverrideWithPointer(s.JunkPacketMin, other.JunkPacketMin)
s.JunkPacketMax = gosettings.OverrideWithPointer(s.JunkPacketMax, other.JunkPacketMax)
s.PaddingS1 = gosettings.OverrideWithPointer(s.PaddingS1, other.PaddingS1)
s.PaddingS2 = gosettings.OverrideWithPointer(s.PaddingS2, other.PaddingS2)
s.PaddingS3 = gosettings.OverrideWithPointer(s.PaddingS3, other.PaddingS3)
s.PaddingS4 = gosettings.OverrideWithPointer(s.PaddingS4, other.PaddingS4)
s.HeaderH1 = gosettings.OverrideWithPointer(s.HeaderH1, other.HeaderH1)
s.HeaderH2 = gosettings.OverrideWithPointer(s.HeaderH2, other.HeaderH2)
s.HeaderH3 = gosettings.OverrideWithPointer(s.HeaderH3, other.HeaderH3)
s.HeaderH4 = gosettings.OverrideWithPointer(s.HeaderH4, other.HeaderH4)
s.InitPacketI1 = gosettings.OverrideWithPointer(s.InitPacketI1, other.InitPacketI1)
s.InitPacketI2 = gosettings.OverrideWithPointer(s.InitPacketI2, other.InitPacketI2)
s.InitPacketI3 = gosettings.OverrideWithPointer(s.InitPacketI3, other.InitPacketI3)
s.InitPacketI4 = gosettings.OverrideWithPointer(s.InitPacketI4, other.InitPacketI4)
s.InitPacketI5 = gosettings.OverrideWithPointer(s.InitPacketI5, other.InitPacketI5)
}
func (s *AmneziaWg) setDefaults() {
s.JunkPacketCount = gosettings.DefaultPointer(s.JunkPacketCount, 0)
s.JunkPacketMin = gosettings.DefaultPointer(s.JunkPacketMin, 0)
s.JunkPacketMax = gosettings.DefaultPointer(s.JunkPacketMax, 0)
s.PaddingS1 = gosettings.DefaultPointer(s.PaddingS1, 0)
s.PaddingS2 = gosettings.DefaultPointer(s.PaddingS2, 0)
s.PaddingS3 = gosettings.DefaultPointer(s.PaddingS3, 0)
s.PaddingS4 = gosettings.DefaultPointer(s.PaddingS4, 0)
s.HeaderH1 = gosettings.DefaultPointer(s.HeaderH1, "")
s.HeaderH2 = gosettings.DefaultPointer(s.HeaderH2, "")
s.HeaderH3 = gosettings.DefaultPointer(s.HeaderH3, "")
s.HeaderH4 = gosettings.DefaultPointer(s.HeaderH4, "")
s.InitPacketI1 = gosettings.DefaultPointer(s.InitPacketI1, "")
s.InitPacketI2 = gosettings.DefaultPointer(s.InitPacketI2, "")
s.InitPacketI3 = gosettings.DefaultPointer(s.InitPacketI3, "")
s.InitPacketI4 = gosettings.DefaultPointer(s.InitPacketI4, "")
s.InitPacketI5 = gosettings.DefaultPointer(s.InitPacketI5, "")
}
func (s AmneziaWg) toLinesNode() (node *gotree.Node) {
node = gotree.New("Amneziawg parameters:")
uintFields := []struct {
key string
val *uint16
}{
{"jc", s.JunkPacketCount},
{"jmin", s.JunkPacketMin},
{"jmax", s.JunkPacketMax},
{"s1", s.PaddingS1},
{"s2", s.PaddingS2},
{"s3", s.PaddingS3},
{"s4", s.PaddingS4},
}
for _, f := range uintFields {
node.Appendf("%s: %d", f.key, *f.val)
}
stringFields := []struct {
key string
val *string
}{
{"h1", s.HeaderH1},
{"h2", s.HeaderH2},
{"h3", s.HeaderH3},
{"h4", s.HeaderH4},
{"i1", s.InitPacketI1},
{"i2", s.InitPacketI2},
{"i3", s.InitPacketI3},
{"i4", s.InitPacketI4},
{"i5", s.InitPacketI5},
}
for _, f := range stringFields {
node.Appendf("%s: %s", f.key, *f.val)
}
return node
}
var (
ErrJunkPacketBounds = errors.New("junk packet minimum must be lower than or equal to maximum")
ErrJunkPacketMinMaxNotSet = errors.New("junk packet min and max must be set when junk packet count is set")
ErrJunkPacketCountNotSet = errors.New("junk packet count must be set when junk packet min or max is set")
ErrHeaderRangeMalformed = errors.New("header range is malformed")
)
func (s AmneziaWg) validate() error {
if *s.JunkPacketCount == 0 {
if *s.JunkPacketMin != 0 || *s.JunkPacketMax != 0 {
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
ErrJunkPacketCountNotSet, s.JunkPacketCount, *s.JunkPacketMin, *s.JunkPacketMax)
}
} else {
if *s.JunkPacketMin == 0 || *s.JunkPacketMax == 0 {
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
ErrJunkPacketMinMaxNotSet, s.JunkPacketCount, *s.JunkPacketMin, *s.JunkPacketMax)
} else if *s.JunkPacketMin > *s.JunkPacketMax {
return fmt.Errorf("%w: jmin=%d and jmax=%d",
ErrJunkPacketBounds, *s.JunkPacketMin, *s.JunkPacketMax)
}
}
nameToHeaderRange := map[string]string{
"h1": *s.HeaderH1,
"h2": *s.HeaderH2,
"h3": *s.HeaderH3,
"h4": *s.HeaderH4,
}
for name, headerRange := range nameToHeaderRange {
if headerRange == "" {
continue
}
fields := strings.Split(headerRange, "-")
switch len(fields) {
case 1:
_, err := strconv.Atoi(fields[0])
if err != nil {
return fmt.Errorf("%w: %s value %s is not a number",
ErrHeaderRangeMalformed, name, headerRange)
}
case 2: //nolint:mnd
for _, field := range fields {
_, err := strconv.Atoi(field)
if err != nil {
return fmt.Errorf("%w: %s value %s is not a valid range",
ErrHeaderRangeMalformed, name, headerRange)
}
}
default:
return fmt.Errorf("%w: %s value %s must be in the form n or n-m",
ErrHeaderRangeMalformed, name, headerRange)
}
}
return nil
}
@@ -268,6 +268,8 @@ func (o *OpenVPN) copy() (copied OpenVPN) {
// overrideWith overrides fields of the receiver
// settings object with any field set in the other
// settings.
//
//nolint:dupl
func (o *OpenVPN) overrideWith(other OpenVPN) {
o.Version = gosettings.OverrideWithComparable(o.Version, other.Version)
o.User = gosettings.OverrideWithPointer(o.User, other.User)
+22 -3
View File
@@ -42,10 +42,12 @@ type Wireguard struct {
// 0 indicating to use PMTUD.
MTU *uint32 `json:"mtu"`
// Implementation is the Wireguard implementation to use.
// It can be "auto", "userspace" or "kernelspace".
// It can be "auto", "userspace", "kernelspace" or "amneziawg".
// It defaults to "auto" and cannot be the empty string
// in the internal state.
Implementation string `json:"implementation"`
// AmneziaWG contains obfuscation parameters
AmneziaWG AmneziaWg `json:"amneziawg"`
}
var regexpInterfaceName = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
@@ -136,11 +138,16 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error)
ErrWireguardInterfaceNotValid, w.Interface, regexpInterfaceName)
}
validImplementations := []string{"auto", "userspace", "kernelspace"}
validImplementations := []string{"auto", "userspace", "kernelspace", "amneziawg"}
if err := validate.IsOneOf(w.Implementation, validImplementations...); err != nil {
return fmt.Errorf("%w: %w", ErrWireguardImplementationNotValid, err)
}
err = w.AmneziaWG.validate()
if err != nil {
return fmt.Errorf("amneziawg settings: %w", err)
}
return nil
}
@@ -154,6 +161,7 @@ func (w *Wireguard) copy() (copied Wireguard) {
Interface: w.Interface,
MTU: w.MTU,
Implementation: w.Implementation,
AmneziaWG: w.AmneziaWG.copy(),
}
}
@@ -167,6 +175,7 @@ func (w *Wireguard) overrideWith(other Wireguard) {
w.Interface = gosettings.OverrideWithComparable(w.Interface, other.Interface)
w.MTU = gosettings.OverrideWithComparable(w.MTU, other.MTU)
w.Implementation = gosettings.OverrideWithComparable(w.Implementation, other.Implementation)
w.AmneziaWG.overrideWith(other.AmneziaWG)
}
func (w *Wireguard) setDefaults(vpnProvider string) {
@@ -191,6 +200,7 @@ func (w *Wireguard) setDefaults(vpnProvider string) {
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
w.MTU = gosettings.DefaultPointer(w.MTU, 0)
w.Implementation = gosettings.DefaultComparable(w.Implementation, "auto")
w.AmneziaWG.setDefaults()
}
func (w Wireguard) String() string {
@@ -232,7 +242,11 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
}
if w.Implementation != "auto" {
node.Appendf("Implementation: %s", w.Implementation)
implNode := node.Appendf("Implementation: %s", w.Implementation)
if w.Implementation == "amneziawg" {
implNode.AppendNode(w.AmneziaWG.toLinesNode())
}
}
return node
@@ -245,6 +259,11 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
reader.RetroKeys("WIREGUARD_INTERFACE"), reader.ForceLowercase(false))
w.Implementation = r.String("WIREGUARD_IMPLEMENTATION")
err = w.AmneziaWG.read(r)
if err != nil {
return err
}
addressStrings := r.CSV("WIREGUARD_ADDRESSES", reader.RetroKeys("WIREGUARD_ADDRESS"))
// WARNING: do not initialize w.Addresses to an empty slice
// or the defaults for nordvpn will not work.
@@ -69,6 +69,38 @@ func (s *Source) Get(key string) (value string, isSet bool) {
return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointIP)
case "wireguard_endpoint_port":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointPort)
case "wireguard_jc":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jc)
case "wireguard_jmin":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmin)
case "wireguard_jmax":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmax)
case "wireguard_s1":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S1)
case "wireguard_s2":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S2)
case "wireguard_s3":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S3)
case "wireguard_s4":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S4)
case "wireguard_h1":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H1)
case "wireguard_h2":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H2)
case "wireguard_h3":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H3)
case "wireguard_h4":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H4)
case "wireguard_i1":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I1)
case "wireguard_i2":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I2)
case "wireguard_i3":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I3)
case "wireguard_i4":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I4)
case "wireguard_i5":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I5)
}
value, isSet, err := ReadFromFile(path)
@@ -25,13 +25,54 @@ func (s *Source) lazyLoadWireguardConf() WireguardConfig {
return s.cached.wireguardConf
}
type amneziaWgConfig struct {
Jc *string
Jmin *string
Jmax *string
S1 *string
S2 *string
S3 *string
S4 *string
H1 *string
H2 *string
H3 *string
H4 *string
I1 *string
I2 *string
I3 *string
I4 *string
I5 *string
}
func parseWireguardAmneziaInterfaceSection(interfaceSection *ini.Section) amneziaWgConfig {
return amneziaWgConfig{
Jc: getINIKeyFromSection(interfaceSection, "Jc"),
Jmin: getINIKeyFromSection(interfaceSection, "Jmin"),
Jmax: getINIKeyFromSection(interfaceSection, "Jmax"),
S1: getINIKeyFromSection(interfaceSection, "S1"),
S2: getINIKeyFromSection(interfaceSection, "S2"),
S3: getINIKeyFromSection(interfaceSection, "S3"),
S4: getINIKeyFromSection(interfaceSection, "S4"),
H1: getINIKeyFromSection(interfaceSection, "H1"),
H2: getINIKeyFromSection(interfaceSection, "H2"),
H3: getINIKeyFromSection(interfaceSection, "H3"),
H4: getINIKeyFromSection(interfaceSection, "H4"),
I1: getINIKeyFromSection(interfaceSection, "I1"),
I2: getINIKeyFromSection(interfaceSection, "I2"),
I3: getINIKeyFromSection(interfaceSection, "I3"),
I4: getINIKeyFromSection(interfaceSection, "I4"),
I5: getINIKeyFromSection(interfaceSection, "I5"),
}
}
type WireguardConfig struct {
PrivateKey *string
PreSharedKey *string
Addresses *string
PublicKey *string
EndpointIP *string
EndpointPort *string
PrivateKey *string
PreSharedKey *string
Addresses *string
PublicKey *string
EndpointIP *string
EndpointPort *string
AmneziaParams amneziaWgConfig
}
var regexINISectionNotExist = regexp.MustCompile(`^section ".+" does not exist$`)
@@ -48,6 +89,7 @@ func ParseWireguardConf(path string) (config WireguardConfig, err error) {
interfaceSection, err := iniFile.GetSection("Interface")
if err == nil {
config.PrivateKey, config.Addresses = parseWireguardInterfaceSection(interfaceSection)
config.AmneziaParams = parseWireguardAmneziaInterfaceSection(interfaceSection)
} else if !regexINISectionNotExist.MatchString(err.Error()) {
// can never happen
return WireguardConfig{}, fmt.Errorf("getting interface section: %w", err)
@@ -97,9 +97,10 @@ func Test_parseWireguardInterfaceSection(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
iniData string
privateKey *string
addresses *string
iniData string
privateKey *string
addresses *string
amneziaParams amneziaWgConfig
}{
"no_fields": {
iniData: `[Interface]`,
@@ -115,9 +116,17 @@ PrivateKey = x
[Interface]
PrivateKey = QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8=
Address = 10.38.22.35/32
Jc = 4
H1 = 721391205
I1 = <b 0x1234>
`,
privateKey: ptrTo("QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8="),
addresses: ptrTo("10.38.22.35/32"),
amneziaParams: amneziaWgConfig{
Jc: ptrTo("4"),
H1: ptrTo("721391205"),
I1: ptrTo("<b 0x1234>"),
},
},
}
@@ -131,9 +140,11 @@ Address = 10.38.22.35/32
require.NoError(t, err)
privateKey, addresses := parseWireguardInterfaceSection(iniSection)
amneziaWgConfig := parseWireguardAmneziaInterfaceSection(iniSection)
assert.Equal(t, testCase.privateKey, privateKey)
assert.Equal(t, testCase.addresses, addresses)
assert.Equal(t, testCase.amneziaParams, amneziaWgConfig)
})
}
}
@@ -83,6 +83,11 @@ func (s *Source) Get(key string) (value string, isSet bool) {
return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointPort)
}
value, isSet, matched := s.getAmneziaWg(key)
if matched {
return value, isSet
}
value, isSet, err := files.ReadFromFile(path)
if err != nil {
s.warner.Warnf("skipping %s: reading file: %s", path, err)
@@ -104,3 +109,43 @@ func (s *Source) KeyTransform(key string) string {
return key
}
}
func (s *Source) getAmneziaWg(key string) (value string, isSet, matched bool) {
switch key {
case "wireguard_jc":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jc)
case "wireguard_jmin":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmin)
case "wireguard_jmax":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmax)
case "wireguard_s1":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S1)
case "wireguard_s2":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S2)
case "wireguard_s3":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S3)
case "wireguard_s4":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S4)
case "wireguard_h1":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H1)
case "wireguard_h2":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H2)
case "wireguard_h3":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H3)
case "wireguard_h4":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H4)
case "wireguard_i1":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I1)
case "wireguard_i2":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I2)
case "wireguard_i3":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I3)
case "wireguard_i4":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I4)
case "wireguard_i5":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I5)
default:
return "", false, false
}
return value, isSet, true
}