From 224618337c43cb5b3a73280f9a8fee6ffca046a8 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Wed, 18 Feb 2026 18:32:10 +0000 Subject: [PATCH] hotfix(pmtud/tcp): respect MSS from server into account --- internal/pmtud/tcp/tcp.go | 49 ++++--- internal/pmtud/tcp/tcpheader.go | 224 ++++++++++++++++++++++++++++---- 2 files changed, 233 insertions(+), 40 deletions(-) diff --git a/internal/pmtud/tcp/tcp.go b/internal/pmtud/tcp/tcp.go index c0676e16..0c88a841 100644 --- a/internal/pmtud/tcp/tcp.go +++ b/internal/pmtud/tcp/tcp.go @@ -58,6 +58,7 @@ var ( errTCPPacketNotSynAck = errors.New("TCP packet is not a SYN-ACK") errTCPSynAckAckMismatch = errors.New("TCP SYN-ACK ACK number does not match expected value") errFinalPacketTypeUnexpected = errors.New("final TCP packet type is unexpected") + errTCPPacketLost = errors.New("TCP packet was lost") ) // Craft and send a raw TCP packet to test the MTU. @@ -95,22 +96,37 @@ func runTest(ctx context.Context, fd fileDescriptor, case reply = <-ch: } - packetType, synAckSeq, synAckAck, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength]) + firstReplyHeader, err := parseTCPHeader(reply) switch { case err != nil: return fmt.Errorf("parsing first reply TCP header: %w", err) - case packetType == packetTypeRST: + case firstReplyHeader.typ == packetTypeRST, + firstReplyHeader.typ == packetTypeRSTACK: // server actively closed the connection, try sending a SYN with data return handleRSTReply(ctx, fd, ch, src, dst, mtu) - case packetType != packetTypeSYNACK: - return fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, packetType) - case synAckAck != synSeq+1: - return fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, synAckAck) + case firstReplyHeader.typ != packetTypeSYNACK: + return fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, firstReplyHeader.typ) + case firstReplyHeader.ack != synSeq+1: + return fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, 
firstReplyHeader.ack) + } + + if firstReplyHeader.options.mss != 0 { + // If the server sent an MSS option, make sure our test packet is not larger than that MSS. + tcpDataLength := getPayloadLength(mtu, dst) - constants.BaseTCPHeaderLength + if tcpDataLength > uint32(firstReplyHeader.options.mss) { + diff := tcpDataLength - uint32(firstReplyHeader.options.mss) + minMTU := constants.MinIPv4MTU + if dst.Addr().Is6() { + minMTU = constants.MinIPv6MTU + } + diff = min(diff, mtu-minMTU) + mtu -= diff + } } // Send an ACK packet to finish the 3-way handshake, together with the // data to test the MTU, using TCP fast-open. - ackPacket := createACKPacket(src, dst, synAckAck, synAckSeq+1, mtu) + ackPacket := createACKPacket(src, dst, firstReplyHeader.ack, firstReplyHeader.seq+1, mtu) err = sendTo(fd, ackPacket, sendToFlags, dstSockAddr) if err != nil { return fmt.Errorf("sending ACK packet: %w", err) @@ -122,23 +138,25 @@ func runTest(ctx context.Context, fd fileDescriptor, case reply = <-ch: } - packetType, _, ack, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength]) + finalPacketHeader, err := parseTCPHeader(reply) if err != nil { return fmt.Errorf("parsing second reply TCP header: %w", err) } - switch packetType { //nolint:exhaustive + switch finalPacketHeader.typ { //nolint:exhaustive case packetTypeRST: return nil case packetTypeACK: - err = sendRST(fd, src, dst, ack) + err = sendRST(fd, src, dst, finalPacketHeader.ack) if err != nil { return fmt.Errorf("sending RST packet: %w", err) } return nil + case packetTypeSYNACK: // server never received our MTU-test ACK packet + return fmt.Errorf("%w: server responded with second SYN-ACK packet", errTCPPacketLost) default: - _ = sendRST(fd, src, dst, ack) - return fmt.Errorf("%w: %s", errFinalPacketTypeUnexpected, packetType) + _ = sendRST(fd, src, dst, finalPacketHeader.ack) + return fmt.Errorf("%w: %s", errFinalPacketTypeUnexpected, finalPacketHeader.typ) } } @@ -161,11 +179,12 @@ func handleRSTReply(ctx 
context.Context, fd fileDescriptor, ch <-chan []byte, case reply = <-ch: } - packetType, _, _, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength]) + replyPacketHeader, err := parseTCPHeader(reply) if err != nil { return fmt.Errorf("parsing reply TCP header: %w", err) - } else if packetType != packetTypeRST { - return fmt.Errorf("%w: %s", errTCPPacketNotRST, packetType) + } else if replyPacketHeader.typ != packetTypeRST && + replyPacketHeader.typ != packetTypeRSTACK { + return fmt.Errorf("%w: %s", errTCPPacketNotRST, replyPacketHeader.typ) } return nil } diff --git a/internal/pmtud/tcp/tcpheader.go b/internal/pmtud/tcp/tcpheader.go index f9deb13b..c3b3c111 100644 --- a/internal/pmtud/tcp/tcpheader.go +++ b/internal/pmtud/tcp/tcpheader.go @@ -64,10 +64,14 @@ func tcpChecksum(ipHeader, tcpHeader, payload []byte) uint16 { const ( tcpFlagsOffset = 13 - rstFlag byte = 0x04 + finFlag byte = 0x01 synFlag byte = 0x02 - ackFlag byte = 0x10 + rstFlag byte = 0x04 pshFlag byte = 0x08 + ackFlag byte = 0x10 + urgFlag byte = 0x20 + eceFlag byte = 0x40 + cwrFlag byte = 0x80 ) type packetType uint8 @@ -75,8 +79,12 @@ type packetType uint8 const ( packetTypeSYN packetType = iota + 1 packetTypeSYNACK - packetTypeACK + packetTypeFIN + packetTypeFINACK packetTypeRST + packetTypeRSTACK + packetTypePSHACK + packetTypeACK ) func (p packetType) String() string { @@ -85,40 +93,206 @@ func (p packetType) String() string { return "SYN" case packetTypeSYNACK: return "SYN-ACK" - case packetTypeACK: - return "ACK" + case packetTypeFIN: + return "FIN" + case packetTypeFINACK: + return "FIN-ACK" case packetTypeRST: return "RST" + case packetTypeRSTACK: + return "RST-ACK" + case packetTypePSHACK: + return "PSH-ACK" + case packetTypeACK: + return "ACK" default: panic("unknown packet type") } } +type tcpHeader struct { + typ packetType + srcPort uint16 + dstPort uint16 + seq uint32 + ack uint32 + dataOffset uint8 + flags uint8 + windowSize uint16 + checksum uint16 + urgentPtr uint16 + options 
options +} + var ( errTCPHeaderTooShort = errors.New("TCP header is too short") + errTCPHeaderDataOffset = errors.New("TCP header data offset is invalid") errTCPPacketTypeUnknown = errors.New("TCP packet type is unknown") ) -// parseTCPHeader parses some elements from the TCP header. -func parseTCPHeader(header []byte) (packetType packetType, seq, ack uint32, err error) { - if len(header) < int(constants.BaseTCPHeaderLength) { - return 0, 0, 0, fmt.Errorf("%w: %d bytes", errTCPHeaderTooShort, len(header)) - } - flags := header[tcpFlagsOffset] - switch { - case (flags&synFlag) != 0 && (flags&ackFlag) == 0: - packetType = packetTypeSYN - case (flags&synFlag) != 0 && (flags&ackFlag) != 0: - packetType = packetTypeSYNACK - case (flags & rstFlag) != 0: - packetType = packetTypeRST - case (flags & ackFlag) != 0: - packetType = packetTypeACK - default: - return 0, 0, 0, fmt.Errorf("%w: flags are 0x%02x", errTCPPacketTypeUnknown, flags) +// parseTCPHeader parses the TCP header from b. +// b should be the entire TCP packet bytes. 
+func parseTCPHeader(b []byte) (header tcpHeader, err error) {
+	// The fixed part of the header must be fully present before any field is read.
+	if len(b) < int(constants.BaseTCPHeaderLength) {
+		return tcpHeader{}, fmt.Errorf("%w: %d bytes", errTCPHeaderTooShort, len(b))
 	}
-	seq = binary.BigEndian.Uint32(header[4:8])
-	ack = binary.BigEndian.Uint32(header[8:12])
-	return packetType, seq, ack, nil
+
+	// Fixed-size fields, all big endian.
+	header.srcPort = binary.BigEndian.Uint16(b[0:2])
+	header.dstPort = binary.BigEndian.Uint16(b[2:4])
+	header.seq = binary.BigEndian.Uint32(b[4:8])
+	header.ack = binary.BigEndian.Uint32(b[8:12])
+	// The data offset is the upper 4 bits of b[12], expressed in 32-bit words;
+	// multiply by 4 to obtain the header length in bytes (at most 60).
+	header.dataOffset = (b[12] >> 4) * 4 //nolint:mnd
+	header.flags = b[tcpFlagsOffset]
+	header.windowSize = binary.BigEndian.Uint16(b[14:16])
+	header.checksum = binary.BigEndian.Uint16(b[16:18])
+	header.urgentPtr = binary.BigEndian.Uint16(b[18:20])
+
+	// The data offset must cover at least the fixed header and must not
+	// point past the end of the bytes received.
+	switch {
+	case uint32(header.dataOffset) < constants.BaseTCPHeaderLength:
+		return tcpHeader{}, fmt.Errorf("%w: data offset is %d bytes, expected at least %d bytes",
+			errTCPHeaderDataOffset, header.dataOffset, constants.BaseTCPHeaderLength)
+	case int(header.dataOffset) > len(b):
+		return tcpHeader{}, fmt.Errorf("%w: data offset is %d bytes, but packet is only %d bytes",
+			errTCPHeaderDataOffset, header.dataOffset, len(b))
+	}
+
+	// Anything between the fixed header and the data offset is TCP options.
+	if uint32(header.dataOffset) > constants.BaseTCPHeaderLength {
+		optionsBytes := b[constants.BaseTCPHeaderLength:header.dataOffset]
+		header.options, err = parseTCPOptions(optionsBytes)
+		if err != nil {
+			return tcpHeader{}, fmt.Errorf("parsing TCP options: %w", err)
+		}
+	}
+
+	// Classify the packet from its flags. Precedence is SYN, RST, FIN, PSH, ACK:
+	// the first matching flag decides the type, and for SYN/RST/FIN the ACK flag
+	// selects the "-ACK" variant.
+	flags := header.flags
+	switch {
+	case flags&synFlag != 0:
+		if flags&ackFlag != 0 {
+			header.typ = packetTypeSYNACK
+		} else {
+			header.typ = packetTypeSYN
+		}
+	case flags&rstFlag != 0:
+		if flags&ackFlag != 0 {
+			header.typ = packetTypeRSTACK
+		} else {
+			header.typ = packetTypeRST
+		}
+	case flags&finFlag != 0:
+		if flags&ackFlag != 0 {
+			header.typ = packetTypeFINACK
+		} else {
+			header.typ = packetTypeFIN
+		}
+	case flags&pshFlag != 0:
+		// There is no PSH-only packet type: any segment with PSH set (and none of
+		// SYN/RST/FIN) is reported as PSH-ACK, whether or not ACK is set.
+		header.typ = packetTypePSHACK
+	case flags&ackFlag != 0:
+		header.typ = packetTypeACK
+	default:
+		return tcpHeader{}, fmt.Errorf("%w: flags are 0x%02x", errTCPPacketTypeUnknown, flags)
+	}
+
+	return header, nil
+}
+
+// options holds the TCP options this package understands.
+// Fields keep their zero value when the option is absent from the header.
+type options struct {
+	// mss is the Maximum Segment Size option value, 0 if not present.
+	mss uint16
+	// windowScale is the window scale shift count.
+	windowScale *uint8 // Pointer to differentiate between 0 and "not present"
+	// sackPermitted is true if the SACK-permitted option is present.
+	sackPermitted bool
+	// timestamps is nil if the timestamps option is not present.
+	timestamps *optionTimestamps
+}
+
+// optionTimestamps holds the two 32-bit fields of the timestamps option.
+type optionTimestamps struct {
+	value uint32 // first 4 bytes of the option data
+	echo  uint32 // last 4 bytes of the option data
+}
+
+var (
+	errTCPOptionLengthTruncated    = errors.New("TCP option length is truncated")
+	ErrTCPOptionLengthInvalid      = errors.New("TCP option length is invalid")
+	ErrTCPOptionMSSInvalid         = errors.New("TCP option MSS value is invalid")
+	ErrTCPOptionWindowScaleInvalid = errors.New("TCP option Window Scale value is invalid")
+	ErrTCPOptionTimestampsInvalid  = errors.New("TCP option Timestamps value is invalid")
+	// errTCPOptionTypeUnknown is no longer returned by parseTCPOptions, which
+	// now skips option kinds it does not implement. The declaration is kept in
+	// case other code in the package still references it.
+	errTCPOptionTypeUnknown = errors.New("TCP option type is unknown")
+)
+
+// parseTCPOptions parses the options area of a TCP header, that is the
+// bytes located between the fixed 20-byte header and the data offset.
+//
+// The MSS, window scale, SACK-permitted and timestamps options are decoded
+// into the returned options. Any other option kind carrying a valid length
+// is skipped without error, since a TCP implementation must ignore options
+// it does not implement. Parsing stops at the End of List option or at the
+// end of b.
+//
+// An error is returned if an option is truncated, if its length field is
+// smaller than 2 or runs past the end of b, or if a decoded option does
+// not have the length expected for its kind.
+func parseTCPOptions(b []byte) (parsed options, err error) {
+	const (
+		optionTypeEndOfList     = 0
+		optionTypeNoOperation   = 1
+		optionTypeMSS           = 2
+		optionTypeWindowScale   = 3
+		optionTypeSACKPermitted = 4
+		optionTypeTimestamps    = 8
+	)
+
+	i := 0
+	for i < len(b) {
+		optionType := b[i]
+
+		// Handle single-byte options
+		if optionType == optionTypeEndOfList {
+			break
+		}
+		if optionType == optionTypeNoOperation { // Padding
+			i++
+			continue
+		}
+
+		// Handle TLV (Type-Length-Value) options
+		if i+1 >= len(b) {
+			// This should not happen for DF packets.
+			return options{}, fmt.Errorf("%w: at offset %d", errTCPOptionLengthTruncated, i)
+		}
+
+		// The length covers the kind and length bytes themselves, so it is
+		// at least 2, and it cannot exceed what remains in b.
+		length := int(b[i+1])
+		const minLength = 2
+		maxLength := len(b) - i
+		switch {
+		case length < minLength:
+			return options{}, fmt.Errorf("%w: type %d at offset %d has length %d < %d",
+				ErrTCPOptionLengthInvalid, optionType, i, length, minLength)
+		case length > maxLength:
+			return options{}, fmt.Errorf("%w: type %d at offset %d has length %d > %d",
+				ErrTCPOptionLengthInvalid, optionType, i, length, maxLength)
+		}
+
+		data := b[i+2 : i+length]
+
+		switch optionType {
+		case optionTypeMSS:
+			const expectedLength = 4
+			if length != expectedLength {
+				return options{}, fmt.Errorf("%w: MSS option at offset %d has length %d, expected %d",
+					ErrTCPOptionMSSInvalid, i, length, expectedLength)
+			}
+			parsed.mss = binary.BigEndian.Uint16(data)
+		case optionTypeWindowScale:
+			const expectedLength = 3
+			if length != expectedLength {
+				return options{}, fmt.Errorf("%w: window scale option at offset %d has length %d, expected %d",
+					ErrTCPOptionWindowScaleInvalid, i, length, expectedLength)
+			}
+			windowScale := data[0]
+			parsed.windowScale = &windowScale
+		case optionTypeSACKPermitted:
+			parsed.sackPermitted = true
+		case optionTypeTimestamps:
+			const expectedLength = 10
+			if length != expectedLength {
+				return options{}, fmt.Errorf("%w: timestamps option at offset %d has length %d, expected %d",
+					ErrTCPOptionTimestampsInvalid, i, length, expectedLength)
+			}
+			parsed.timestamps = &optionTimestamps{
+				value: binary.BigEndian.Uint32(data[:4]),
+				echo:  binary.BigEndian.Uint32(data[4:]),
+			}
+		default:
+			// Option kind not implemented here (for example SACK blocks or
+			// a Fast Open cookie): ignore it. Its length was validated
+			// above, so stepping over it below is safe.
+		}
+
+		i += length
+	}
+
+	return parsed, nil
 }