hotfix(pmtud/tcp): respect MSS from server into account

This commit is contained in:
Quentin McGaw
2026-02-18 18:32:10 +00:00
parent 183d351b58
commit 224618337c
2 changed files with 233 additions and 40 deletions
+34 -15
View File
@@ -58,6 +58,7 @@ var (
errTCPPacketNotSynAck = errors.New("TCP packet is not a SYN-ACK") errTCPPacketNotSynAck = errors.New("TCP packet is not a SYN-ACK")
errTCPSynAckAckMismatch = errors.New("TCP SYN-ACK ACK number does not match expected value") errTCPSynAckAckMismatch = errors.New("TCP SYN-ACK ACK number does not match expected value")
errFinalPacketTypeUnexpected = errors.New("final TCP packet type is unexpected") errFinalPacketTypeUnexpected = errors.New("final TCP packet type is unexpected")
errTCPPacketLost = errors.New("TCP packet was lost")
) )
// Craft and send a raw TCP packet to test the MTU. // Craft and send a raw TCP packet to test the MTU.
@@ -95,22 +96,37 @@ func runTest(ctx context.Context, fd fileDescriptor,
case reply = <-ch: case reply = <-ch:
} }
packetType, synAckSeq, synAckAck, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength]) firstReplyHeader, err := parseTCPHeader(reply)
switch { switch {
case err != nil: case err != nil:
return fmt.Errorf("parsing first reply TCP header: %w", err) return fmt.Errorf("parsing first reply TCP header: %w", err)
case packetType == packetTypeRST: case firstReplyHeader.typ == packetTypeRST,
firstReplyHeader.typ == packetTypeRSTACK:
// server actively closed the connection, try sending a SYN with data // server actively closed the connection, try sending a SYN with data
return handleRSTReply(ctx, fd, ch, src, dst, mtu) return handleRSTReply(ctx, fd, ch, src, dst, mtu)
case packetType != packetTypeSYNACK: case firstReplyHeader.typ != packetTypeSYNACK:
return fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, packetType) return fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, firstReplyHeader.typ)
case synAckAck != synSeq+1: case firstReplyHeader.ack != synSeq+1:
return fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, synAckAck) return fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, firstReplyHeader.ack)
}
if firstReplyHeader.options.mss != 0 {
// If the server sent an MSS option, make sure our test packet is not larger than that MSS.
tcpDataLength := getPayloadLength(mtu, dst) - constants.BaseTCPHeaderLength
if tcpDataLength > uint32(firstReplyHeader.options.mss) {
diff := tcpDataLength - uint32(firstReplyHeader.options.mss)
minMTU := constants.MinIPv4MTU
if dst.Addr().Is6() {
minMTU = constants.MinIPv6MTU
}
diff = min(diff, mtu-minMTU)
mtu -= diff
}
} }
// Send an ACK packet to finish the 3-way handshake, together with the // Send an ACK packet to finish the 3-way handshake, together with the
// data to test the MTU, using TCP fast-open. // data to test the MTU, using TCP fast-open.
ackPacket := createACKPacket(src, dst, synAckAck, synAckSeq+1, mtu) ackPacket := createACKPacket(src, dst, firstReplyHeader.ack, firstReplyHeader.seq+1, mtu)
err = sendTo(fd, ackPacket, sendToFlags, dstSockAddr) err = sendTo(fd, ackPacket, sendToFlags, dstSockAddr)
if err != nil { if err != nil {
return fmt.Errorf("sending ACK packet: %w", err) return fmt.Errorf("sending ACK packet: %w", err)
@@ -122,23 +138,25 @@ func runTest(ctx context.Context, fd fileDescriptor,
case reply = <-ch: case reply = <-ch:
} }
packetType, _, ack, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength]) finalPacketHeader, err := parseTCPHeader(reply)
if err != nil { if err != nil {
return fmt.Errorf("parsing second reply TCP header: %w", err) return fmt.Errorf("parsing second reply TCP header: %w", err)
} }
switch packetType { //nolint:exhaustive switch finalPacketHeader.typ { //nolint:exhaustive
case packetTypeRST: case packetTypeRST:
return nil return nil
case packetTypeACK: case packetTypeACK:
err = sendRST(fd, src, dst, ack) err = sendRST(fd, src, dst, finalPacketHeader.ack)
if err != nil { if err != nil {
return fmt.Errorf("sending RST packet: %w", err) return fmt.Errorf("sending RST packet: %w", err)
} }
return nil return nil
case packetTypeSYNACK: // server never received our MTU-test ACK packet
return fmt.Errorf("%w: server responded with second SYN-ACK packet", errTCPPacketLost)
default: default:
_ = sendRST(fd, src, dst, ack) _ = sendRST(fd, src, dst, finalPacketHeader.ack)
return fmt.Errorf("%w: %s", errFinalPacketTypeUnexpected, packetType) return fmt.Errorf("%w: %s", errFinalPacketTypeUnexpected, finalPacketHeader.typ)
} }
} }
@@ -161,11 +179,12 @@ func handleRSTReply(ctx context.Context, fd fileDescriptor, ch <-chan []byte,
case reply = <-ch: case reply = <-ch:
} }
packetType, _, _, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength]) replyPacketHeader, err := parseTCPHeader(reply)
if err != nil { if err != nil {
return fmt.Errorf("parsing reply TCP header: %w", err) return fmt.Errorf("parsing reply TCP header: %w", err)
} else if packetType != packetTypeRST { } else if replyPacketHeader.typ != packetTypeRST &&
return fmt.Errorf("%w: %s", errTCPPacketNotRST, packetType) replyPacketHeader.typ != packetTypeRSTACK {
return fmt.Errorf("%w: %s", errTCPPacketNotRST, replyPacketHeader.typ)
} }
return nil return nil
} }
+199 -25
View File
@@ -64,10 +64,14 @@ func tcpChecksum(ipHeader, tcpHeader, payload []byte) uint16 {
const ( const (
tcpFlagsOffset = 13 tcpFlagsOffset = 13
rstFlag byte = 0x04 finFlag byte = 0x01
synFlag byte = 0x02 synFlag byte = 0x02
ackFlag byte = 0x10 rstFlag byte = 0x04
pshFlag byte = 0x08 pshFlag byte = 0x08
ackFlag byte = 0x10
urgFlag byte = 0x20
eceFlag byte = 0x40
cwrFlag byte = 0x80
) )
type packetType uint8 type packetType uint8
@@ -75,8 +79,12 @@ type packetType uint8
const ( const (
packetTypeSYN packetType = iota + 1 packetTypeSYN packetType = iota + 1
packetTypeSYNACK packetTypeSYNACK
packetTypeACK packetTypeFIN
packetTypeFINACK
packetTypeRST packetTypeRST
packetTypeRSTACK
packetTypePSHACK
packetTypeACK
) )
func (p packetType) String() string { func (p packetType) String() string {
@@ -85,40 +93,206 @@ func (p packetType) String() string {
return "SYN" return "SYN"
case packetTypeSYNACK: case packetTypeSYNACK:
return "SYN-ACK" return "SYN-ACK"
case packetTypeACK: case packetTypeFIN:
return "ACK" return "FIN"
case packetTypeFINACK:
return "FIN-ACK"
case packetTypeRST: case packetTypeRST:
return "RST" return "RST"
case packetTypeRSTACK:
return "RST-ACK"
case packetTypePSHACK:
return "PSH-ACK"
case packetTypeACK:
return "ACK"
default: default:
panic("unknown packet type") panic("unknown packet type")
} }
} }
type tcpHeader struct {
typ packetType
srcPort uint16
dstPort uint16
seq uint32
ack uint32
dataOffset uint8
flags uint8
windowSize uint16
checksum uint16
urgentPtr uint16
options options
}
var ( var (
errTCPHeaderTooShort = errors.New("TCP header is too short") errTCPHeaderTooShort = errors.New("TCP header is too short")
errTCPHeaderDataOffset = errors.New("TCP header data offset is invalid")
errTCPPacketTypeUnknown = errors.New("TCP packet type is unknown") errTCPPacketTypeUnknown = errors.New("TCP packet type is unknown")
) )
// parseTCPHeader parses some elements from the TCP header. // parseTCPHeader parses the TCP header from b.
func parseTCPHeader(header []byte) (packetType packetType, seq, ack uint32, err error) { // b should be the entire TCP packet bytes.
if len(header) < int(constants.BaseTCPHeaderLength) { func parseTCPHeader(b []byte) (header tcpHeader, err error) {
return 0, 0, 0, fmt.Errorf("%w: %d bytes", errTCPHeaderTooShort, len(header)) if len(b) < int(constants.BaseTCPHeaderLength) {
} return tcpHeader{}, fmt.Errorf("%w: %d bytes", errTCPHeaderTooShort, len(b))
flags := header[tcpFlagsOffset]
switch {
case (flags&synFlag) != 0 && (flags&ackFlag) == 0:
packetType = packetTypeSYN
case (flags&synFlag) != 0 && (flags&ackFlag) != 0:
packetType = packetTypeSYNACK
case (flags & rstFlag) != 0:
packetType = packetTypeRST
case (flags & ackFlag) != 0:
packetType = packetTypeACK
default:
return 0, 0, 0, fmt.Errorf("%w: flags are 0x%02x", errTCPPacketTypeUnknown, flags)
} }
seq = binary.BigEndian.Uint32(header[4:8]) header.srcPort = binary.BigEndian.Uint16(b[0:2])
ack = binary.BigEndian.Uint32(header[8:12]) header.dstPort = binary.BigEndian.Uint16(b[2:4])
return packetType, seq, ack, nil header.seq = binary.BigEndian.Uint32(b[4:8])
header.ack = binary.BigEndian.Uint32(b[8:12])
// upper 4 bits of the 12th byte
header.dataOffset = (b[12] >> 4) * 4 //nolint:mnd
header.flags = b[13]
header.windowSize = binary.BigEndian.Uint16(b[14:16])
header.checksum = binary.BigEndian.Uint16(b[16:18])
header.urgentPtr = binary.BigEndian.Uint16(b[18:20])
switch {
case uint32(header.dataOffset) < constants.BaseTCPHeaderLength:
return tcpHeader{}, fmt.Errorf("%w: data offset is %d bytes, expected at least %d bytes",
errTCPHeaderDataOffset, header.dataOffset, constants.BaseTCPHeaderLength)
case int(header.dataOffset) > len(b):
return tcpHeader{}, fmt.Errorf("%w: data offset is %d bytes, but packet is only %d bytes",
errTCPHeaderDataOffset, header.dataOffset, len(b))
}
if uint32(header.dataOffset) > constants.BaseTCPHeaderLength {
optionsBytes := b[constants.BaseTCPHeaderLength:header.dataOffset]
header.options, err = parseTCPOptions(optionsBytes)
if err != nil {
return tcpHeader{}, fmt.Errorf("parsing TCP options: %w", err)
}
}
flags := header.flags
switch {
case flags&synFlag != 0:
if flags&ackFlag != 0 {
header.typ = packetTypeSYNACK
} else {
header.typ = packetTypeSYN
}
case flags&rstFlag != 0:
if flags&ackFlag != 0 {
header.typ = packetTypeRSTACK
} else {
header.typ = packetTypeRST
}
case flags&finFlag != 0:
if flags&ackFlag != 0 {
header.typ = packetTypeFINACK
} else {
header.typ = packetTypeFIN
}
case flags&pshFlag != 0:
header.typ = packetTypePSHACK
case flags&ackFlag != 0:
header.typ = packetTypeACK
default:
return tcpHeader{}, fmt.Errorf("%w: flags are 0x%02x", errTCPPacketTypeUnknown, flags)
}
header.seq = binary.BigEndian.Uint32(b[4:8])
header.ack = binary.BigEndian.Uint32(b[8:12])
return header, nil
}
type options struct {
mss uint16
windowScale *uint8 // Pointer to differentiate between 0 and "not present"
sackPermitted bool
timestamps *optionTimestamps
}
type optionTimestamps struct {
value uint32
echo uint32
}
var (
errTCPOptionLengthTruncated = errors.New("TCP option length is truncated")
ErrTCPOptionLengthInvalid = errors.New("TCP option length is invalid")
ErrTCPOptionMSSInvalid = errors.New("TCP option MSS value is invalid")
ErrTCPOptionWindowScaleInvalid = errors.New("TCP option Window Scale value is invalid")
ErrTCPOptionTimestampsInvalid = errors.New("TCP option Timestamps value is invalid")
errTCPOptionTypeUnknown = errors.New("TCP option type is unknown")
)
func parseTCPOptions(b []byte) (parsed options, err error) {
i := 0
for i < len(b) {
optionType := b[i]
// Handle single-byte options
if optionType == 0 { // End of List
break
}
if optionType == 1 { // No-Operation (Padding)
i++
continue
}
// Handle TLV (Type-Length-Value) options
if i+1 >= len(b) {
// This should not happen for DF packets.
return options{}, fmt.Errorf("%w: at offset %d", errTCPOptionLengthTruncated, i)
}
length := int(b[i+1])
const minLength = 2
maxLength := len(b) - i
switch {
case length < minLength:
return options{}, fmt.Errorf("%w: type %d at offset %d has length %d < %d",
ErrTCPOptionLengthInvalid, optionType, i, length, minLength)
case length > maxLength:
return options{}, fmt.Errorf("%w: type %d at offset %d has length %d > %d",
ErrTCPOptionLengthInvalid, optionType, i, length, maxLength)
}
data := b[i+2 : i+length]
const (
optionTypeMSS = 2
optionTypeWindowScale = 3
optionTypeSACKPermitted = 4
optionTypeTimestamps = 8
)
switch optionType {
case optionTypeMSS:
const expectedLength = 4
if length != expectedLength {
return options{}, fmt.Errorf("%w: MSS option at offset %d has length %d, expected %d",
ErrTCPOptionMSSInvalid, i, length, expectedLength)
}
parsed.mss = binary.BigEndian.Uint16(data)
case optionTypeWindowScale:
const expectedLength = 3
if length != expectedLength {
return options{}, fmt.Errorf("%w: window scale option at offset %d has length %d, expected %d",
ErrTCPOptionWindowScaleInvalid, i, length, expectedLength)
}
windowScale := data[0]
parsed.windowScale = &windowScale
case optionTypeSACKPermitted:
parsed.sackPermitted = true
case optionTypeTimestamps:
const expectedLength = 10
if length != expectedLength {
return options{}, fmt.Errorf("%w: timestamps option at offset %d has length %d, expected %d",
ErrTCPOptionTimestampsInvalid, i, length, expectedLength)
}
parsed.timestamps = &optionTimestamps{
value: binary.BigEndian.Uint32(data[:4]),
echo: binary.BigEndian.Uint32(data[4:]),
}
default:
return options{}, fmt.Errorf("%w: type %d", errTCPOptionTypeUnknown, optionType)
}
i += length
}
return parsed, nil
} }