diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index d71db02b..9e0bca57 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -634,6 +634,8 @@ type RunStarter interface { Run(cmd *exec.Cmd) (output string, err error) Start(cmd *exec.Cmd) (stdoutLines, stderrLines <-chan string, waitError <-chan error, err error) + RunAndLog(ctx context.Context, commandString string, + logger command.Logger) (err error) } const gluetunLogo = ` @@@ diff --git a/internal/command/interfaces.go b/internal/command/interfaces.go new file mode 100644 index 00000000..602987e7 --- /dev/null +++ b/internal/command/interfaces.go @@ -0,0 +1,6 @@ +package command + +type Logger interface { + Info(s string) + Error(s string) +} diff --git a/internal/command/split.go b/internal/command/split.go index 870d0a66..e1e8d223 100644 --- a/internal/command/split.go +++ b/internal/command/split.go @@ -9,13 +9,13 @@ import ( ) var ( - ErrCommandEmpty = errors.New("command is empty") - ErrSingleQuoteUnterminated = errors.New("unterminated single-quoted string") - ErrDoubleQuoteUnterminated = errors.New("unterminated double-quoted string") - ErrEscapeUnterminated = errors.New("unterminated backslash-escape") + errCommandEmpty = errors.New("command is empty") + errSingleQuoteUnterminated = errors.New("unterminated single-quoted string") + errDoubleQuoteUnterminated = errors.New("unterminated double-quoted string") + errEscapeUnterminated = errors.New("unterminated backslash-escape") ) -// Split splits a command string into a slice of arguments. +// split splits a command string into a slice of arguments. // This is especially important for commands such as: // /bin/sh -c "echo hello" // which should be split into: ["/bin/sh", "-c", "echo hello"] @@ -23,9 +23,9 @@ var ( // It does not support: // - the $" quoting style. // - expansion (brace, shell or pathname). -func Split(command string) (words []string, err error) { +func split(command string) (words []string, err error) { if command == "" { - return nil, fmt.Errorf("%w", ErrCommandEmpty) + return nil, fmt.Errorf("%w", errCommandEmpty) } const bufferSize = 1024 @@ -42,7 +42,7 @@ func Split(command string) (words []string, err error) { case character == '\\': // Look ahead to eventually skip an escaped newline if command[startIndex+runeSize:] == "" { - return nil, fmt.Errorf("%w: %q", ErrEscapeUnterminated, command) + return nil, fmt.Errorf("%w: %q", errEscapeUnterminated, command) } character, runeSize := utf8.DecodeRuneInString(command[startIndex+runeSize:]) if character == '\n' { @@ -119,7 +119,7 @@ func handleDoubleQuoted(input string, startIndex int, buffer *bytes.Buffer) ( startIndex = cursor } } - return "", 0, fmt.Errorf("%w", ErrDoubleQuoteUnterminated) + return "", 0, fmt.Errorf("%w", errDoubleQuoteUnterminated) } func handleSingleQuoted(input string, startIndex int, buffer *bytes.Buffer) ( @@ -127,7 +127,7 @@ func handleSingleQuoted(input string, startIndex int, buffer *bytes.Buffer) ( ) { closingQuoteIndex := strings.IndexRune(input[startIndex:], '\'') if closingQuoteIndex == -1 { - return "", 0, fmt.Errorf("%w", ErrSingleQuoteUnterminated) + return "", 0, fmt.Errorf("%w", errSingleQuoteUnterminated) } buffer.WriteString(input[startIndex : startIndex+closingQuoteIndex]) const singleQuoteRuneLength = 1 @@ -139,7 +139,7 @@ func handleEscaped(input string, startIndex int, buffer *bytes.Buffer) ( word string, newStartIndex int, err error, ) { if input[startIndex:] == "" { - return "", 0, fmt.Errorf("%w", ErrEscapeUnterminated) + return "", 0, fmt.Errorf("%w", errEscapeUnterminated) } character, runeLength := utf8.DecodeRuneInString(input[startIndex:]) if character != '\n' { // backslash-escaped newline is ignored diff --git a/internal/command/split_test.go b/internal/command/split_test.go index 96d38751..8eb14cbb 100644 --- a/internal/command/split_test.go +++ b/internal/command/split_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" ) -func Test_Split(t *testing.T) { +func Test_split(t *testing.T) { t.Parallel() testCases := map[string]struct { @@ -17,7 +17,7 @@ func Test_Split(t *testing.T) { }{ "empty": { command: "", - errWrapped: ErrCommandEmpty, + errWrapped: errCommandEmpty, errMessage: "command is empty", }, "concrete_sh_command": { @@ -74,22 +74,22 @@ func Test_Split(t *testing.T) { }, "unterminated_single_quote": { command: "'abc'\\''def", - errWrapped: ErrSingleQuoteUnterminated, + errWrapped: errSingleQuoteUnterminated, errMessage: `splitting word in "'abc'\\''def": unterminated single-quoted string`, }, "unterminated_double_quote": { command: "\"abc'def", - errWrapped: ErrDoubleQuoteUnterminated, + errWrapped: errDoubleQuoteUnterminated, errMessage: `splitting word in "\"abc'def": unterminated double-quoted string`, }, "unterminated_escape": { command: "abc\\", - errWrapped: ErrEscapeUnterminated, + errWrapped: errEscapeUnterminated, errMessage: `splitting word in "abc\\": unterminated backslash-escape`, }, "unterminated_escape_only": { command: " \\", - errWrapped: ErrEscapeUnterminated, + errWrapped: errEscapeUnterminated, errMessage: `unterminated backslash-escape: " \\"`, }, } @@ -98,7 +98,7 @@ func Test_Split(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - words, err := Split(testCase.command) + words, err := split(testCase.command) assert.Equal(t, testCase.words, words) assert.ErrorIs(t, err, testCase.errWrapped) diff --git a/internal/command/startnlog.go b/internal/command/startnlog.go new file mode 100644 index 00000000..59699df7 --- /dev/null +++ b/internal/command/startnlog.go @@ -0,0 +1,48 @@ +package command + +import ( + "context" + "fmt" + "os/exec" +) + +func (c *Cmder) RunAndLog(ctx context.Context, command string, logger Logger) (err error) { + args, err := split(command) + if err != nil { + return fmt.Errorf("parsing command: %w", err) + } + + cmd := exec.CommandContext(ctx, args[0], args[1:]...) // #nosec G204 + stdout, stderr, waitError, err := c.Start(cmd) + if err != nil { + return err + } + + streamCtx, streamCancel := context.WithCancel(context.Background()) + streamDone := make(chan struct{}) + go streamLines(streamCtx, streamDone, logger, stdout, stderr) + + err = <-waitError + streamCancel() + <-streamDone + return err +} + +func streamLines(ctx context.Context, done chan<- struct{}, + logger Logger, stdout, stderr <-chan string, +) { + defer close(done) + + var line string + + for { + select { + case <-ctx.Done(): + return + case line = <-stdout: + logger.Info(line) + case line = <-stderr: + logger.Error(line) + } + } +} diff --git a/internal/configuration/settings/vpn.go b/internal/configuration/settings/vpn.go index d8aa6f1c..25336cfd 100644 --- a/internal/configuration/settings/vpn.go +++ b/internal/configuration/settings/vpn.go @@ -19,6 +19,14 @@ type VPN struct { OpenVPN OpenVPN `json:"openvpn"` Wireguard Wireguard `json:"wireguard"` PMTUD PMTUD `json:"pmtud"` + // UpCommand is the command to use when the VPN connection is up. + // It can be the empty string to indicate not to run a command. + // It cannot be nil in the internal state. + UpCommand *string `json:"up_command"` + // DownCommand is the command to use after the VPN connection goes down. + // It can be the empty string to indicate to NOT run a command. + // It cannot be nil in the internal state. + DownCommand *string `json:"down_command"` } // TODO v4 remove pointer for receiver (because of Surfshark). @@ -56,11 +64,13 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo func (v *VPN) Copy() (copied VPN) { return VPN{ - Type: v.Type, - Provider: v.Provider.copy(), - OpenVPN: v.OpenVPN.copy(), - Wireguard: v.Wireguard.copy(), - PMTUD: v.PMTUD.copy(), + Type: v.Type, + Provider: v.Provider.copy(), + OpenVPN: v.OpenVPN.copy(), + Wireguard: v.Wireguard.copy(), + PMTUD: v.PMTUD.copy(), + UpCommand: gosettings.CopyPointer(v.UpCommand), + DownCommand: gosettings.CopyPointer(v.DownCommand), } } @@ -70,6 +80,8 @@ func (v *VPN) OverrideWith(other VPN) { v.OpenVPN.overrideWith(other.OpenVPN) v.Wireguard.overrideWith(other.Wireguard) v.PMTUD.overrideWith(other.PMTUD) + v.UpCommand = gosettings.OverrideWithPointer(v.UpCommand, other.UpCommand) + v.DownCommand = gosettings.OverrideWithPointer(v.DownCommand, other.DownCommand) } func (v *VPN) setDefaults() { @@ -78,6 +90,8 @@ func (v *VPN) setDefaults() { v.OpenVPN.setDefaults(v.Provider.Name) v.Wireguard.setDefaults(v.Provider.Name) v.PMTUD.setDefaults() + v.UpCommand = gosettings.DefaultPointer(v.UpCommand, "") + v.DownCommand = gosettings.DefaultPointer(v.DownCommand, "") } func (v VPN) String() string { @@ -96,6 +110,13 @@ func (v VPN) toLinesNode() (node *gotree.Node) { } node.AppendNode(v.PMTUD.toLinesNode()) + if *v.UpCommand != "" { + node.Appendf("Up command: %s", *v.UpCommand) + } + if *v.DownCommand != "" { + node.Appendf("Down command: %s", *v.DownCommand) + } + return node } @@ -122,5 +143,9 @@ func (v *VPN) read(r *reader.Reader) (err error) { return fmt.Errorf("PMTUD: %w", err) } + v.UpCommand = r.Get("VPN_UP_COMMAND", reader.ForceLowercase(false)) + + v.DownCommand = r.Get("VPN_DOWN_COMMAND", reader.ForceLowercase(false)) + return nil } diff --git a/internal/portforward/interfaces.go b/internal/portforward/interfaces.go index 68da0046..45191f7e 100644 --- a/internal/portforward/interfaces.go +++ b/internal/portforward/interfaces.go @@ -4,6 +4,8 @@ import ( "context" "net/netip" "os/exec" + + "github.com/qdm12/gluetun/internal/command" ) type Service interface { @@ -35,4 +37,5 @@ type Logger interface { type Cmder interface { Start(cmd *exec.Cmd) (stdoutLines, stderrLines <-chan string, waitError <-chan error, startErr error) + RunAndLog(ctx context.Context, commandString string, logger command.Logger) (err error) } diff --git a/internal/portforward/service/command.go b/internal/portforward/service/command.go index 4d351c25..6233accd 100644 --- a/internal/portforward/service/command.go +++ b/internal/portforward/service/command.go @@ -3,10 +3,7 @@ package service import ( "context" "fmt" - "os/exec" "strings" - - "github.com/qdm12/gluetun/internal/command" ) func runCommand(ctx context.Context, cmder Cmder, logger Logger, @@ -20,42 +17,5 @@ func runCommand(ctx context.Context, cmder Cmder, logger Logger, commandString := strings.ReplaceAll(commandTemplate, "{{PORTS}}", portsString) commandString = strings.ReplaceAll(commandString, "{{PORT}}", portStrings[0]) commandString = strings.ReplaceAll(commandString, "{{VPN_INTERFACE}}", vpnInterface) - args, err := command.Split(commandString) - if err != nil { - return fmt.Errorf("parsing command: %w", err) - } - - cmd := exec.CommandContext(ctx, args[0], args[1:]...) // #nosec G204 - stdout, stderr, waitError, err := cmder.Start(cmd) - if err != nil { - return err - } - - streamCtx, streamCancel := context.WithCancel(context.Background()) - streamDone := make(chan struct{}) - go streamLines(streamCtx, streamDone, logger, stdout, stderr) - - err = <-waitError - streamCancel() - <-streamDone - return err -} - -func streamLines(ctx context.Context, done chan<- struct{}, - logger Logger, stdout, stderr <-chan string, -) { - defer close(done) - - var line string - - for { - select { - case <-ctx.Done(): - return - case line = <-stdout: - logger.Info(line) - case line = <-stderr: - logger.Error(line) - } - } + return cmder.RunAndLog(ctx, commandString, logger) } diff --git a/internal/portforward/service/interfaces.go b/internal/portforward/service/interfaces.go index 33288a30..f1ea7298 100644 --- a/internal/portforward/service/interfaces.go +++ b/internal/portforward/service/interfaces.go @@ -3,8 +3,8 @@ package service import ( "context" "net/netip" - "os/exec" + "github.com/qdm12/gluetun/internal/command" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -35,6 +35,5 @@ type PortForwarder interface { } type Cmder interface { - Start(cmd *exec.Cmd) (stdoutLines, stderrLines <-chan string, - waitError <-chan error, startErr error) + RunAndLog(ctx context.Context, command string, logger command.Logger) (err error) } diff --git a/internal/vpn/cleanup.go b/internal/vpn/cleanup.go index 0b7db1a6..5c768757 100644 --- a/internal/vpn/cleanup.go +++ b/internal/vpn/cleanup.go @@ -3,9 +3,22 @@ package vpn import ( "context" "errors" + "strings" + + "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/vpn" ) func (l *Loop) cleanup() { + settings := l.GetSettings() + + var err error + commandString := strings.ReplaceAll(*settings.DownCommand, "{{VPN_INTERFACE}}", getVPNInterface(settings)) + err = l.cmder.RunAndLog(context.Background(), commandString, l.logger) + if err != nil { + l.logger.Error("failed to run VPN down command: " + err.Error()) + } + for _, vpnPort := range l.vpnInputPorts { err := l.fw.RemoveAllowedPort(context.Background(), vpnPort) if err != nil { @@ -13,7 +26,7 @@ func (l *Loop) cleanup() { } } - err := l.publicip.ClearData() + err = l.publicip.ClearData() if err != nil { l.logger.Error("clearing public IP data: " + err.Error()) } @@ -31,3 +44,14 @@ func (l *Loop) cleanup() { l.logger.Error("stopping boring poll: " + err.Error()) } } + +func getVPNInterface(settings settings.VPN) string { + switch settings.Type { + case vpn.OpenVPN: + return settings.OpenVPN.Interface + case vpn.Wireguard: + return settings.Wireguard.Interface + default: + panic("invalid VPN type: " + settings.Type) + } +} diff --git a/internal/vpn/interfaces.go b/internal/vpn/interfaces.go index 4189fe62..bfc52eb5 100644 --- a/internal/vpn/interfaces.go +++ b/internal/vpn/interfaces.go @@ -5,6 +5,7 @@ import ( "net/netip" "os/exec" + "github.com/qdm12/gluetun/internal/command" "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/netlink" @@ -120,3 +121,7 @@ type Service interface { Start() (runError <-chan error, err error) Stop() error } + +type Cmder interface { + RunAndLog(ctx context.Context, command string, logger command.Logger) (err error) +} diff --git a/internal/vpn/loop.go b/internal/vpn/loop.go index e28cc804..3ff7cde0 100644 --- a/internal/vpn/loop.go +++ b/internal/vpn/loop.go @@ -17,6 +17,7 @@ type Loop struct { state *state.State providers Providers storage Storage + cmder Cmder healthSettings settings.Health healthChecker HealthChecker healthServer HealthServer diff --git a/internal/vpn/run.go b/internal/vpn/run.go index 167750a3..5b23a22a 100644 --- a/internal/vpn/run.go +++ b/internal/vpn/run.go @@ -47,6 +47,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { continue } tunnelUpData := tunnelUpData{ + upCommand: *settings.UpCommand, pmtud: tunnelUpPMTUDData{ enabled: settings.Type != vpn.Wireguard || *settings.Wireguard.MTU == 0, vpnType: settings.Type, diff --git a/internal/vpn/tunnelup.go b/internal/vpn/tunnelup.go index 6d132ef4..a1ecff84 100644 --- a/internal/vpn/tunnelup.go +++ b/internal/vpn/tunnelup.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/netip" + "strings" "time" "github.com/qdm12/gluetun/internal/constants" @@ -16,6 +17,7 @@ import ( ) type tunnelUpData struct { + upCommand string // Healthcheck serverIP netip.Addr pmtud tunnelUpPMTUDData @@ -107,6 +109,14 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) { } } + if data.upCommand != "" { + commandString := strings.ReplaceAll(data.upCommand, "{{VPN_INTERFACE}}", data.vpnIntf) + err := l.cmder.RunAndLog(context.Background(), commandString, l.logger) + if err != nil { + l.logger.Error("failed to run VPN up command: " + err.Error()) + } + } + err = l.startPortForwarding(data) if err != nil { l.logger.Error(err.Error())