diff --git a/internal/netlink/conntrack_linux.go b/internal/netlink/conntrack_linux.go index 9ece2722..868e25e8 100644 --- a/internal/netlink/conntrack_linux.go +++ b/internal/netlink/conntrack_linux.go @@ -5,6 +5,7 @@ import ( "github.com/mdlayher/netlink" "github.com/ti-mo/netfilter" + "golang.org/x/sys/unix" ) func (n *NetLink) FlushConntrack() error { @@ -14,25 +15,21 @@ func (n *NetLink) FlushConntrack() error { } defer conn.Close() - families := [...]netfilter.ProtoFamily{netfilter.ProtoIPv4, netfilter.ProtoIPv6} - for _, family := range families { - const IPCtnlMsgCtDelete = 2 - request, err := netfilter.MarshalNetlink( - netfilter.Header{ - SubsystemID: netfilter.NFSubsysCTNetlink, - MessageType: netfilter.MessageType(IPCtnlMsgCtDelete), - Family: family, - Flags: netlink.Request | netlink.Acknowledge, - }, - nil) - if err != nil { - return fmt.Errorf("encoding netlink request: %w", err) - } + const ipCtnlMsgCtDelete = netfilter.MessageType(2) + header := netfilter.Header{ + SubsystemID: netfilter.NFSubsysCTNetlink, + MessageType: ipCtnlMsgCtDelete, + Family: unix.AF_UNSPEC, + Flags: netlink.Request | netlink.Acknowledge, + } + request, err := netfilter.MarshalNetlink(header, nil) + if err != nil { + return fmt.Errorf("encoding netlink request: %w", err) + } - _, err = conn.Query(request) - if err != nil { - return fmt.Errorf("querying netlink request: %w", err) - } + _, err = conn.Query(request) + if err != nil { + return fmt.Errorf("querying netlink request: %w", err) } return nil }