diff --git a/internal/firewall/iptables/iptables.go b/internal/firewall/iptables/iptables.go index d97b306b..68ed71e0 100644 --- a/internal/firewall/iptables/iptables.go +++ b/internal/firewall/iptables/iptables.go @@ -151,9 +151,6 @@ func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context, defaultInterface string, connection models.Connection, remove bool, ) error { protocol := connection.Protocol - if protocol == "tcp-client" { - protocol = "tcp" - } instruction := fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT", appendOrDelete(remove), connection.IP, defaultInterface, protocol, protocol, connection.Port) diff --git a/internal/openvpn/extract/extract.go b/internal/openvpn/extract/extract.go index 64ad6181..eaa90888 100644 --- a/internal/openvpn/extract/extract.go +++ b/internal/openvpn/extract/extract.go @@ -81,10 +81,7 @@ func extractDataFromLine(line string) ( return ip, 0, "", nil } -var ( - errProtoLineFieldsCount = errors.New("proto line has not 2 fields as expected") - errProtocolNotSupported = errors.New("network protocol not supported") -) +var errProtoLineFieldsCount = errors.New("proto line has not 2 fields as expected") func extractProto(line string) (protocol string, err error) { fields := strings.Fields(line) @@ -92,13 +89,25 @@ func extractProto(line string) (protocol string, err error) { return "", fmt.Errorf("%w: %s", errProtoLineFieldsCount, line) } - switch fields[1] { - case "tcp", "tcp4", "tcp6", "tcp-client", "udp", "udp4", "udp6": - default: - return "", fmt.Errorf("%w: %s", errProtocolNotSupported, fields[1]) - } + return parseProto(fields[1]) +} - return fields[1], nil +var errProtocolNotSupported = errors.New("network protocol not supported") + +func parseProto(field string) (protocol string, err error) { + switch field { + case "tcp", "tcp4", "tcp6", "tcp-client": + // tcp4, tcp6 can be assimilated as tcp since the IP version is + // determined by the remote IP address version. + // tcp-client is a synonym of tcp for OpenVPN 2.5+ acting in client mode. + return constants.TCP, nil + case "udp", "udp4", "udp6": + // udp4, udp6 can be assimilated as udp since the IP version is + // determined by the remote IP address version. + return constants.UDP, nil + default: + return "", fmt.Errorf("%w: %s", errProtocolNotSupported, field) + } } var ( @@ -136,11 +145,9 @@ func extractRemote(line string) (ip netip.Addr, port uint16, } if n > 3 { //nolint:mnd - switch fields[3] { - case "tcp", "udp": - protocol = fields[3] - default: - return netip.Addr{}, 0, "", fmt.Errorf("%w: %s", errProtocolNotSupported, fields[3]) + protocol, err = parseProto(fields[3]) + if err != nil { + return netip.Addr{}, 0, "", fmt.Errorf("parsing protocol from remote line: %w", err) } } diff --git a/internal/openvpn/extract/extract_test.go b/internal/openvpn/extract/extract_test.go index ee21588e..77e122ec 100644 --- a/internal/openvpn/extract/extract_test.go +++ b/internal/openvpn/extract/extract_test.go @@ -105,7 +105,7 @@ func Test_extractDataFromLine(t *testing.T) { }, "tcp-client": { line: "proto tcp-client", - protocol: "tcp-client", + protocol: constants.TCP, }, "extract remote error": { line: "remote bad", @@ -239,7 +239,7 @@ func Test_extractRemote(t *testing.T) { }, "invalid protocol": { line: "remote 1.2.3.4 8000 bad", - err: errors.New("network protocol not supported: bad"), + err: errors.New("parsing protocol from remote line: network protocol not supported: bad"), }, "IP host and port and protocol": { line: "remote 1.2.3.4 8000 udp",