vmm/dumbo/pdu/
tcp.rs

1// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Contains support for parsing and writing TCP segments.
5//!
6//! [Here]'s a useful depiction of the TCP header layout (watch out for the MSB 0 bit numbering.)
7//!
8//! [Here]: https://en.wikipedia.org/wiki/Transmission_Control_Protocol#TCP_segment_structure
9
10use std::cmp::min;
11use std::fmt::Debug;
12use std::net::Ipv4Addr;
13use std::num::NonZeroU16;
14
15use bitflags::bitflags;
16
17use super::Incomplete;
18use super::bytes::{InnerBytes, NetworkBytes, NetworkBytesMut};
19use crate::dumbo::ByteBuffer;
20use crate::dumbo::pdu::ChecksumProto;
21
// Byte offsets of the fixed TCP header fields. Each offset is written as the previous field's
// offset plus that field's width in bytes, mirroring the on-the-wire layout.
const SOURCE_PORT_OFFSET: usize = 0;
const DESTINATION_PORT_OFFSET: usize = SOURCE_PORT_OFFSET + 2;
const SEQ_NUMBER_OFFSET: usize = DESTINATION_PORT_OFFSET + 2;
const ACK_NUMBER_OFFSET: usize = SEQ_NUMBER_OFFSET + 4;
const DATAOFF_RSVD_NS_OFFSET: usize = ACK_NUMBER_OFFSET + 4;
const FLAGS_AFTER_NS_OFFSET: usize = DATAOFF_RSVD_NS_OFFSET + 1;
const WINDOW_SIZE_OFFSET: usize = FLAGS_AFTER_NS_OFFSET + 1;
const CHECKSUM_OFFSET: usize = WINDOW_SIZE_OFFSET + 2;
const URG_POINTER_OFFSET: usize = CHECKSUM_OFFSET + 2;

// Options start right after the 2-byte urgent pointer. This is also the minimum header length.
const OPTIONS_OFFSET: u8 = 20;

// The data offset field is 4 bits wide and counts 32-bit words, so at most 15 * 4 bytes.
const MAX_HEADER_LEN: u8 = 15 * 4;

// TCP option kinds understood by this module.
const OPTION_KIND_EOL: u8 = 0x00;
const OPTION_KIND_NOP: u8 = 0x01;
const OPTION_KIND_MSS: u8 = 0x02;

// Total encoded length of the MSS option (kind byte + length byte + 2 data bytes).
const OPTION_LEN_MSS: u8 = 0x04;

// Lower bound for an acceptable MSS option value. The exact number is an arbitrary choice,
// used only for sanity checks.
const MSS_MIN: u16 = 100;
44
45bitflags! {
46    /// Represents the TCP header flags, with the exception of `NS`.
47    ///
48    /// These values are only valid in conjunction with the [`flags_after_ns()`] method (and its
49    /// associated setter method), which operates on the header byte containing every other flag
50    /// besides `NS`.
51    ///
52    /// [`flags_after_ns()`]: struct.TcpSegment.html#method.flags_after_ns
53    #[derive(Debug, Copy, Clone, PartialEq)]
54    pub struct Flags: u8 {
55        /// Congestion window reduced.
56        const CWR = 1 << 7;
57        /// ECN-echo.
58        const ECE = 1 << 6;
59        /// Urgent pointer.
60        const URG = 1 << 5;
61        /// The acknowledgement number field is valid.
62        const ACK = 1 << 4;
63        /// Push flag.
64        const PSH = 1 << 3;
65        /// Reset the connection.
66        const RST = 1 << 2;
67        /// SYN flag.
68        const SYN = 1 << 1;
69        /// FIN flag.
70        const FIN = 1;
71    }
72}
73
/// Describes the errors which may occur while handling TCP segments.
// NOTE: `displaydoc::Display` uses each variant's `///` doc comment as that variant's `Display`
// message, so those doc comments are user-visible text. The extra notes below use plain `//`
// comments so that the messages stay unchanged.
#[derive(Debug, PartialEq, Eq, thiserror::Error, displaydoc::Display)]
pub enum TcpError {
    // Returned by `from_bytes` when checksum verification was requested and failed.
    /// Invalid checksum.
    Checksum,
    // Returned by `write_incomplete_segment` when a payload is given but zero payload bytes fit.
    /// A payload has been specified for the segment, but the maximum readable length is 0.
    EmptyPayload,
    // Returned by `from_bytes` when the header length is below 20 bytes, above 60 bytes, or
    // larger than the segment itself.
    /// Invalid header length.
    HeaderLen,
    // Returned by `parse_mss_option_unchecked` when the advertised MSS is below `MSS_MIN`.
    /// The MSS option contains an invalid value.
    MssOption,
    // Returned by `write_incomplete_segment` when `mss_remaining` cannot cover the MSS option.
    /// The remaining segment length cannot accommodate the MSS option.
    MssRemaining,
    // Returned by `from_bytes` and `write_incomplete_segment` when the buffer is too small for
    // the (fixed or options-extended) header.
    /// The specified slice is shorter than the header length.
    SliceTooShort,
}
90
91// TODO: The implementation of TcpSegment is IPv4 specific in regard to checksum computation. Maybe
92// make it more generic at some point.
93
94/// Interprets the inner bytes as a TCP segment.
95#[derive(Debug)]
96pub struct TcpSegment<'a, T: 'a> {
97    bytes: InnerBytes<'a, T>,
98}
99
100#[allow(clippy::len_without_is_empty)]
101impl<T: NetworkBytes + Debug> TcpSegment<'_, T> {
102    /// Returns the source port.
103    #[inline]
104    pub fn source_port(&self) -> u16 {
105        self.bytes.ntohs_unchecked(SOURCE_PORT_OFFSET)
106    }
107
108    /// Returns the destination port.
109    #[inline]
110    pub fn destination_port(&self) -> u16 {
111        self.bytes.ntohs_unchecked(DESTINATION_PORT_OFFSET)
112    }
113
114    /// Returns the sequence number.
115    #[inline]
116    pub fn sequence_number(&self) -> u32 {
117        self.bytes.ntohl_unchecked(SEQ_NUMBER_OFFSET)
118    }
119
120    /// Returns the acknowledgement number (only valid if the `ACK` flag is set).
121    #[inline]
122    pub fn ack_number(&self) -> u32 {
123        self.bytes.ntohl_unchecked(ACK_NUMBER_OFFSET)
124    }
125
126    /// Returns the header length, the value of the reserved bits, and whether the `NS` flag
127    /// is set or not.
128    #[inline]
129    pub fn header_len_rsvd_ns(&self) -> (u8, u8, bool) {
130        let value = self.bytes[DATAOFF_RSVD_NS_OFFSET];
131        let data_offset = value >> 4;
132        let header_len = data_offset * 4;
133        let rsvd = value & 0x0e;
134        let ns = (value & 1) != 0;
135        (header_len, rsvd, ns)
136    }
137
138    /// Returns the length of the header.
139    #[inline]
140    pub fn header_len(&self) -> u8 {
141        self.header_len_rsvd_ns().0
142    }
143
144    /// Returns the TCP header flags, with the exception of `NS`.
145    #[inline]
146    pub fn flags_after_ns(&self) -> Flags {
147        Flags::from_bits_truncate(self.bytes[FLAGS_AFTER_NS_OFFSET])
148    }
149
150    /// Returns the value of the `window size` header field.
151    #[inline]
152    pub fn window_size(&self) -> u16 {
153        self.bytes.ntohs_unchecked(WINDOW_SIZE_OFFSET)
154    }
155
156    /// Returns the value of the `checksum` header field.
157    #[inline]
158    pub fn checksum(&self) -> u16 {
159        self.bytes.ntohs_unchecked(CHECKSUM_OFFSET)
160    }
161
162    /// Returns the value of the `urgent pointer` header field (only valid if the
163    /// `URG` flag is set).
164    #[inline]
165    pub fn urgent_pointer(&self) -> u16 {
166        self.bytes.ntohs_unchecked(URG_POINTER_OFFSET)
167    }
168
169    /// Returns the TCP header options as an `[&u8]` slice.
170    ///
171    /// # Panics
172    ///
173    /// This method may panic if the value of `header_len` is invalid.
174    #[inline]
175    pub fn options_unchecked(&self, header_len: usize) -> &[u8] {
176        &self.bytes[usize::from(OPTIONS_OFFSET)..header_len]
177    }
178
179    /// Returns a slice which contains the payload of the segment. May panic if the value of
180    /// `header_len` is invalid.
181    ///
182    /// # Panics
183    ///
184    /// This method may panic if the value of `header_len` is invalid.
185    #[inline]
186    pub fn payload_unchecked(&self, header_len: usize) -> &[u8] {
187        self.bytes.split_at(header_len).1
188    }
189
190    /// Returns the length of the segment.
191    #[inline]
192    pub fn len(&self) -> u16 {
193        // NOTE: This appears to be a safe conversion in all current cases.
194        // Packets are always set up in the context of an Ipv4Packet, which is
195        // capped at a u16 size. However, I'd rather be safe here.
196        u16::try_from(self.bytes.len()).unwrap_or(u16::MAX)
197    }
198
199    /// Returns a slice which contains the payload of the segment.
200    #[inline]
201    pub fn payload(&self) -> &[u8] {
202        self.payload_unchecked(self.header_len().into())
203    }
204
205    /// Returns the length of the payload.
206    #[inline]
207    pub fn payload_len(&self) -> u16 {
208        self.len() - u16::from(self.header_len())
209    }
210
211    /// Computes the TCP checksum of the segment. More details about TCP checksum computation can
212    /// be found [here].
213    ///
214    /// [here]: https://en.wikipedia.org/wiki/Transmission_Control_Protocol#Checksum_computation
215    pub fn compute_checksum(&self, src_addr: Ipv4Addr, dst_addr: Ipv4Addr) -> u16 {
216        crate::dumbo::pdu::compute_checksum(&self.bytes, src_addr, dst_addr, ChecksumProto::Tcp)
217    }
218
219    /// Parses TCP header options (only `MSS` is supported for now).
220    ///
221    /// If no error is encountered, returns the `MSS` value, or `None` if the option is not
222    /// present.
223    ///
224    /// # Panics
225    ///
226    /// This method may panic if the value of `header_len` is invalid.
227    pub fn parse_mss_option_unchecked(
228        &self,
229        header_len: usize,
230    ) -> Result<Option<NonZeroU16>, TcpError> {
231        let b = self.options_unchecked(header_len);
232        let mut i = 0;
233
234        // All TCP options (except EOL and NOP) are encoded using x bytes (x >= 2), where the first
235        // byte represents the option kind, the second is the option length (including these first
236        // two bytes), and finally the next x - 2 bytes represent option data. The length of
237        // the MSS option is 4, so the option data encodes an u16 in network order.
238
239        // The MSS option is 4 bytes wide, so we need at least 4 more bytes to look for it.
240        while i + 3 < b.len() {
241            match b[i] {
242                OPTION_KIND_EOL => break,
243                OPTION_KIND_NOP => {
244                    i += 1;
245                    continue;
246                }
247                OPTION_KIND_MSS => {
248                    // Read from option data (we skip checking if the len is valid).
249                    // TODO: To be super strict, we should make sure there aren't additional MSS
250                    // options present (which would be super wrong). Should we be super strict?
251                    let mss = b.ntohs_unchecked(i + 2);
252                    if mss < MSS_MIN {
253                        return Err(TcpError::MssOption);
254                    }
255                    // The unwarp() is safe because mms >= MSS_MIN at this point.
256                    return Ok(Some(NonZeroU16::new(mss).unwrap()));
257                }
258                _ => {
259                    // Some other option; just skip opt_len bytes in total.
260                    i += b[i + 1] as usize;
261                    continue;
262                }
263            }
264        }
265        Ok(None)
266    }
267
268    /// Interprets `bytes` as a TCP segment without any validity checks.
269    ///
270    /// # Panics
271    ///
272    /// This method does not panic, but further method calls on the resulting object may panic if
273    /// `bytes` contains invalid input.
274    #[inline]
275    pub fn from_bytes_unchecked(bytes: T) -> Self {
276        TcpSegment {
277            bytes: InnerBytes::new(bytes),
278        }
279    }
280
281    /// Attempts to interpret `bytes` as a TCP segment, checking the validity of the header fields.
282    ///
283    /// The `verify_checksum` parameter must contain the source and destination addresses from the
284    /// enclosing IPv4 packet if the TCP checksum must be validated.
285    #[inline]
286    pub fn from_bytes(
287        bytes: T,
288        verify_checksum: Option<(Ipv4Addr, Ipv4Addr)>,
289    ) -> Result<Self, TcpError> {
290        if bytes.len() < usize::from(OPTIONS_OFFSET) {
291            return Err(TcpError::SliceTooShort);
292        }
293
294        let segment = Self::from_bytes_unchecked(bytes);
295
296        // We skip checking if the reserved bits are 0b000 (and a couple of other things).
297
298        let header_len = segment.header_len();
299
300        if header_len < OPTIONS_OFFSET
301            || u16::from(header_len) > min(u16::from(MAX_HEADER_LEN), segment.len())
302        {
303            return Err(TcpError::HeaderLen);
304        }
305
306        if let Some((src_addr, dst_addr)) = verify_checksum
307            && segment.compute_checksum(src_addr, dst_addr) != 0
308        {
309            return Err(TcpError::Checksum);
310        }
311
312        Ok(segment)
313    }
314}
315
316impl<T: NetworkBytesMut + Debug> TcpSegment<'_, T> {
317    /// Sets the source port.
318    #[inline]
319    pub fn set_source_port(&mut self, value: u16) -> &mut Self {
320        self.bytes.htons_unchecked(SOURCE_PORT_OFFSET, value);
321        self
322    }
323
324    /// Sets the destination port.
325    #[inline]
326    pub fn set_destination_port(&mut self, value: u16) -> &mut Self {
327        self.bytes.htons_unchecked(DESTINATION_PORT_OFFSET, value);
328        self
329    }
330
331    /// Sets the value of the sequence number field.
332    #[inline]
333    pub fn set_sequence_number(&mut self, value: u32) -> &mut Self {
334        self.bytes.htonl_unchecked(SEQ_NUMBER_OFFSET, value);
335        self
336    }
337
338    /// Sets the value of the acknowledgement number field.
339    #[inline]
340    pub fn set_ack_number(&mut self, value: u32) -> &mut Self {
341        self.bytes.htonl_unchecked(ACK_NUMBER_OFFSET, value);
342        self
343    }
344
345    /// Sets the value of the `ihl` header field based on `header_len` (which should be a multiple
346    /// of 4), clears the reserved bits, and sets the `NS` flag according to the last parameter.
347    // TODO: Check that header_len | 0b11 == 0 and the resulting data_offset is valid?
348    #[inline]
349    pub fn set_header_len_rsvd_ns(&mut self, header_len: u8, ns: bool) -> &mut Self {
350        let mut value = header_len << 2;
351        if ns {
352            value |= 1;
353        }
354        self.bytes[DATAOFF_RSVD_NS_OFFSET] = value;
355        self
356    }
357
358    /// Sets the value of the header byte containing every TCP flag except `NS`.
359    #[inline]
360    pub fn set_flags_after_ns(&mut self, flags: Flags) -> &mut Self {
361        self.bytes[FLAGS_AFTER_NS_OFFSET] = flags.bits();
362        self
363    }
364
365    /// Sets the value of the `window size` field.
366    #[inline]
367    pub fn set_window_size(&mut self, value: u16) -> &mut Self {
368        self.bytes.htons_unchecked(WINDOW_SIZE_OFFSET, value);
369        self
370    }
371
372    /// Sets the value of the `checksum` field.
373    #[inline]
374    pub fn set_checksum(&mut self, value: u16) -> &mut Self {
375        self.bytes.htons_unchecked(CHECKSUM_OFFSET, value);
376        self
377    }
378
379    /// Sets the value of the `urgent pointer` field.
380    #[inline]
381    pub fn set_urgent_pointer(&mut self, value: u16) -> &mut Self {
382        self.bytes.htons_unchecked(URG_POINTER_OFFSET, value);
383        self
384    }
385
386    /// Returns a mutable slice containing the segment payload.
387    ///
388    /// # Panics
389    ///
390    /// This method may panic if the value of `header_len` is invalid.
391    #[inline]
392    pub fn payload_mut_unchecked(&mut self, header_len: usize) -> &mut [u8] {
393        self.bytes.split_at_mut(header_len).1
394    }
395
396    /// Returns a mutable slice containing the segment payload.
397    #[inline]
398    pub fn payload_mut(&mut self) -> &mut [u8] {
399        let header_len = self.header_len();
400        self.payload_mut_unchecked(header_len.into())
401    }
402
403    /// Writes a complete TCP segment.
404    ///
405    /// # Arguments
406    ///
407    /// * `buf` - Write the segment to this buffer.
408    /// * `src_port` - Source port.
409    /// * `dst_port` - Destination port.
410    /// * `seq_number` - Sequence number.
411    /// * `ack_number` - Acknowledgement number.
412    /// * `flags_after_ns` - TCP flags to set (except `NS`, which is always set to 0).
413    /// * `window_size` - Value to write in the `window size` field.
414    /// * `mss_option` - When a value is specified, use it to add a TCP MSS option to the header.
415    /// * `mss_remaining` - Represents an upper bound on the payload length (the number of bytes
416    ///   used up by things like IP options have to be subtracted from the MSS). There is some
417    ///   redundancy looking at this argument and the next one, so we might end up removing or
418    ///   changing something.
419    /// * `payload` - May contain a buffer which holds payload data and the maximum amount of bytes
420    ///   we should read from that buffer. When `None`, the TCP segment will carry no payload.
421    /// * `compute_checksum` - May contain the pair addresses from the enclosing IPv4 packet, which
422    ///   are required for TCP checksum computation. Skip the checksum altogether when `None`.
423    #[allow(clippy::too_many_arguments)]
424    #[inline]
425    pub fn write_segment<R: ByteBuffer + ?Sized + Debug>(
426        buf: T,
427        src_port: u16,
428        dst_port: u16,
429        seq_number: u32,
430        ack_number: u32,
431        flags_after_ns: Flags,
432        window_size: u16,
433        mss_option: Option<u16>,
434        mss_remaining: u16,
435        payload: Option<(&R, usize)>,
436        compute_checksum: Option<(Ipv4Addr, Ipv4Addr)>,
437    ) -> Result<Self, TcpError> {
438        Ok(Self::write_incomplete_segment(
439            buf,
440            seq_number,
441            ack_number,
442            flags_after_ns,
443            window_size,
444            mss_option,
445            mss_remaining,
446            payload,
447        )?
448        .finalize(src_port, dst_port, compute_checksum))
449    }
450
451    /// Writes an incomplete TCP segment, which is missing the `source port`, `destination port`,
452    /// and `checksum` fields.
453    ///
454    /// This method writes the rest of the segment, including data (when available). Only the `MSS`
455    /// option is supported for now. The `NS` flag, `URG` flag, and `urgent pointer` field are set
456    /// to 0.
457    ///
458    /// # Arguments
459    ///
460    /// * `buf` - Write the segment to this buffer.
461    /// * `seq_number` - Sequence number.
462    /// * `ack_number` - Acknowledgement number.
463    /// * `flags_after_ns` - TCP flags to set (except `NS`, which is always set to 0).
464    /// * `window_size` - Value to write in the `window size` field.
465    /// * `mss_option` - When a value is specified, use it to add a TCP MSS option to the header.
466    /// * `mss_remaining` - Represents an upper bound on the payload length (the number of bytes
467    ///   used up by things like IP options have to be subtracted from the MSS). There is some
468    ///   redundancy looking at this argument and the next one, so we might end up removing or
469    ///   changing something.
470    /// * `payload` - May contain a buffer which holds payload data and the maximum amount of bytes
471    ///   we should read from that buffer. When `None`, the TCP segment will carry no payload.
472    // Marked inline because a lot of code vanishes after constant folding when
473    // we don't add TCP options, or when mss_remaining is actually a constant, etc.
474    #[allow(clippy::too_many_arguments)]
475    #[inline]
476    pub fn write_incomplete_segment<R: ByteBuffer + ?Sized + Debug>(
477        buf: T,
478        seq_number: u32,
479        ack_number: u32,
480        flags_after_ns: Flags,
481        window_size: u16,
482        mss_option: Option<u16>,
483        mss_remaining: u16,
484        payload: Option<(&R, usize)>,
485    ) -> Result<Incomplete<Self>, TcpError> {
486        let mut mss_left = mss_remaining;
487
488        // We're going to need at least this many bytes.
489        let mut segment_len = u16::from(OPTIONS_OFFSET);
490
491        // The TCP options will require this much more bytes.
492        let options_len = if mss_option.is_some() {
493            mss_left = mss_left
494                .checked_sub(OPTION_LEN_MSS.into())
495                .ok_or(TcpError::MssRemaining)?;
496            OPTION_LEN_MSS
497        } else {
498            0
499        };
500
501        segment_len += u16::from(options_len);
502
503        if buf.len() < usize::from(segment_len) {
504            return Err(TcpError::SliceTooShort);
505        }
506
507        // The unchecked call is safe because buf.len() >= segment_len.
508        let mut segment = Self::from_bytes_unchecked(buf);
509
510        segment
511            .set_sequence_number(seq_number)
512            .set_ack_number(ack_number)
513            .set_header_len_rsvd_ns(OPTIONS_OFFSET + options_len, false)
514            .set_flags_after_ns(flags_after_ns)
515            .set_window_size(window_size)
516            .set_urgent_pointer(0);
517
518        // Let's write the MSS option if we have to.
519        if let Some(value) = mss_option {
520            segment.bytes[usize::from(OPTIONS_OFFSET)] = OPTION_KIND_MSS;
521            segment.bytes[usize::from(OPTIONS_OFFSET) + 1] = OPTION_LEN_MSS;
522            segment
523                .bytes
524                .htons_unchecked(usize::from(OPTIONS_OFFSET) + 2, value);
525        }
526
527        let payload_bytes_count = if let Some((payload_buf, max_payload_bytes)) = payload {
528            let left_to_read = min(payload_buf.len(), max_payload_bytes);
529
530            // The subtraction makes sense because we previously checked that
531            // buf.len() >= segment_len.
532            let mut room_for_payload = min(segment.len() - segment_len, mss_left);
533            // The unwrap is safe because room_for_payload is a u16.
534            room_for_payload =
535                u16::try_from(min(usize::from(room_for_payload), left_to_read)).unwrap();
536
537            if room_for_payload == 0 {
538                return Err(TcpError::EmptyPayload);
539            }
540
541            // Copy `room_for_payload` bytes into `payload_buf` using `offset=0`.
542            // Guaranteed not to panic since we checked above that:
543            // `offset + room_for_payload <= payload_buf.len()`.
544            payload_buf.read_to_slice(
545                0,
546                &mut segment.bytes
547                    [usize::from(segment_len)..usize::from(segment_len + room_for_payload)],
548            );
549            room_for_payload
550        } else {
551            0
552        };
553        segment_len += payload_bytes_count;
554
555        // This is ok because segment_len <= buf.len().
556        segment.bytes.shrink_unchecked(segment_len.into());
557
558        // Shrink the resulting segment to a slice of exact size, so using self.len() makes sense.
559        Ok(Incomplete::new(segment))
560    }
561}
562
563impl<'a, T: NetworkBytesMut + Debug> Incomplete<TcpSegment<'a, T>> {
564    /// Transforms `self` into a `TcpSegment<T>` by specifying values for the `source port`,
565    /// `destination port`, and (optionally) the information required to compute the TCP checksum.
566    #[inline]
567    pub fn finalize(
568        mut self,
569        src_port: u16,
570        dst_port: u16,
571        compute_checksum: Option<(Ipv4Addr, Ipv4Addr)>,
572    ) -> TcpSegment<'a, T> {
573        self.inner.set_source_port(src_port);
574        self.inner.set_destination_port(dst_port);
575        if let Some((src_addr, dst_addr)) = compute_checksum {
576            // Set this to 0 first.
577            self.inner.set_checksum(0);
578            let checksum = self.inner.compute_checksum(src_addr, dst_addr);
579            self.inner.set_checksum(checksum);
580        }
581        self.inner
582    }
583}
584
// Unit tests for `TcpSegment`. `test_constructors` is one sequential scenario: later steps reuse
// (and deliberately corrupt) the bytes that earlier steps wrote into the buffer `a`, so the
// order of its statements matters.
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trips every header field through its setter/getter pair on a zeroed buffer.
    #[test]
    fn test_set_get() {
        let mut a = [0u8; 100];
        let mut p = TcpSegment::from_bytes_unchecked(a.as_mut());

        assert_eq!(p.source_port(), 0);
        p.set_source_port(123);
        assert_eq!(p.source_port(), 123);

        assert_eq!(p.destination_port(), 0);
        p.set_destination_port(322);
        assert_eq!(p.destination_port(), 322);

        assert_eq!(p.sequence_number(), 0);
        p.set_sequence_number(1_234_567);
        assert_eq!(p.sequence_number(), 1_234_567);

        assert_eq!(p.ack_number(), 0);
        p.set_ack_number(345_234);
        assert_eq!(p.ack_number(), 345_234);

        assert_eq!(p.header_len_rsvd_ns(), (0, 0, false));
        assert_eq!(p.header_len(), 0);
        // Header_len must be a multiple of 4 here to be valid.
        let header_len = 60;
        p.set_header_len_rsvd_ns(header_len, true);
        assert_eq!(p.header_len_rsvd_ns(), (header_len, 0, true));
        assert_eq!(p.header_len(), header_len);

        assert_eq!(p.flags_after_ns().bits(), 0);
        p.set_flags_after_ns(Flags::SYN | Flags::URG);
        assert_eq!(p.flags_after_ns(), Flags::SYN | Flags::URG);

        assert_eq!(p.window_size(), 0);
        p.set_window_size(60000);
        assert_eq!(p.window_size(), 60000);

        assert_eq!(p.checksum(), 0);
        p.set_checksum(4321);
        assert_eq!(p.checksum(), 4321);

        assert_eq!(p.urgent_pointer(), 0);
        p.set_urgent_pointer(5554);
        assert_eq!(p.urgent_pointer(), 5554);
    }

    // Exercises write_segment, from_bytes and parse_mss_option_unchecked, including the error
    // values returned by from_bytes and write_segment.
    #[test]
    fn test_constructors() {
        // `a` is the output buffer; `b` is a payload that fits in `a`, and `c` one that does not.
        let mut a = [1u8; 1460];
        let b = [2u8; 1000];
        let c = [3u8; 2000];

        let src_addr = Ipv4Addr::new(10, 1, 2, 3);
        let dst_addr = Ipv4Addr::new(192, 168, 44, 77);
        let src_port = 1234;
        let dst_port = 5678;
        let seq_number = 11_111_222;
        let ack_number = 34_566_543;
        let flags_after_ns = Flags::SYN | Flags::RST;
        let window_size = 19999;
        let mss_left = 1460;
        let mss_option = Some(mss_left);
        let payload = Some((b.as_ref(), b.len()));

        // Fixed 20-byte header plus the 4-byte MSS option.
        let header_len = OPTIONS_OFFSET + OPTION_LEN_MSS;

        let segment_len = {
            let mut segment = TcpSegment::write_segment(
                a.as_mut(),
                src_port,
                dst_port,
                seq_number,
                ack_number,
                flags_after_ns,
                window_size,
                mss_option,
                mss_left,
                payload,
                Some((src_addr, dst_addr)),
            )
            .unwrap();

            assert_eq!(segment.source_port(), src_port);
            assert_eq!(segment.destination_port(), dst_port);
            assert_eq!(segment.sequence_number(), seq_number);
            assert_eq!(segment.ack_number(), ack_number);
            assert_eq!(segment.header_len_rsvd_ns(), (header_len, 0, false));
            assert_eq!(segment.flags_after_ns(), flags_after_ns);
            assert_eq!(segment.window_size(), window_size);

            // Recomputing with a zeroed checksum field must reproduce the stored checksum, and
            // a segment carrying its correct checksum must sum to 0.
            let checksum = segment.checksum();
            segment.set_checksum(0);
            let computed_checksum = segment.compute_checksum(src_addr, dst_addr);
            assert_eq!(checksum, computed_checksum);

            segment.set_checksum(checksum);
            assert_eq!(segment.compute_checksum(src_addr, dst_addr), 0);

            assert_eq!(segment.urgent_pointer(), 0);

            {
                let options = segment.options_unchecked(header_len.into());
                assert_eq!(options.len(), usize::from(OPTION_LEN_MSS));
                assert_eq!(options[0], OPTION_KIND_MSS);
                assert_eq!(options[1], OPTION_LEN_MSS);
                assert_eq!(options.ntohs_unchecked(2), mss_left);
            }

            // Payload was smaller than mss_left after options.
            assert_eq!(
                usize::from(segment.len()),
                usize::from(header_len) + b.len(),
            );
            segment.len()
            // Mutable borrow of a goes out of scope.
        };

        // Re-parse the freshly written bytes with checksum verification, then read back the MSS.
        {
            let segment =
                TcpSegment::from_bytes(&a[..segment_len.into()], Some((src_addr, dst_addr)))
                    .unwrap();
            assert_eq!(
                segment.parse_mss_option_unchecked(header_len.into()),
                Ok(Some(NonZeroU16::new(mss_left).unwrap()))
            );
        }

        // Let's quickly see what happens when the payload buf is larger than our mutable slice.
        // The buffer is the limit here: 1460 - 24 header bytes leave room for 1436 payload bytes,
        // so the segment fills all of `a`, whose length happens to equal mss_left. This write
        // also leaves a checksum in `a` that is valid over the whole buffer, which the steps
        // below rely on.
        {
            let segment_len = TcpSegment::write_segment(
                a.as_mut(),
                src_port,
                dst_port,
                seq_number,
                ack_number,
                flags_after_ns,
                window_size,
                mss_option,
                mss_left,
                Some((c.as_ref(), c.len())),
                Some((src_addr, dst_addr)),
            )
            .unwrap()
            .len();

            assert_eq!(segment_len, mss_left);
        }

        // Now let's test the error value for from_bytes().

        // Using a helper function here instead of a closure because it's hard (impossible?) to
        // specify lifetime bounds for closure arguments.
        fn p(buf: &mut [u8]) -> TcpSegment<'_, &mut [u8]> {
            TcpSegment::from_bytes_unchecked(buf)
        }

        // Just a helper closure.
        let look_for_error = |buf: &[u8], err: TcpError| {
            assert_eq!(
                TcpSegment::from_bytes(buf, Some((src_addr, dst_addr))).unwrap_err(),
                err
            );
        };

        // Header length too short.
        p(a.as_mut()).set_header_len_rsvd_ns(OPTIONS_OFFSET.checked_sub(1).unwrap(), false);
        look_for_error(a.as_ref(), TcpError::HeaderLen);

        // Header length too large.
        p(a.as_mut()).set_header_len_rsvd_ns(MAX_HEADER_LEN.checked_add(4).unwrap(), false);
        look_for_error(a.as_ref(), TcpError::HeaderLen);

        // The previously set checksum should be valid.
        // Restoring the original header length undoes the two corruptions above.
        assert_eq!(
            p(a.as_mut())
                .set_header_len_rsvd_ns(header_len, false)
                .compute_checksum(src_addr, dst_addr),
            0
        );

        // Let's make it invalid.
        let checksum = p(a.as_mut()).checksum();
        p(a.as_mut()).set_checksum(checksum.wrapping_add(1));
        look_for_error(a.as_ref(), TcpError::Checksum);

        // Now we use a very small buffer.
        let mut small_buf = [0u8; 1];
        look_for_error(small_buf.as_ref(), TcpError::SliceTooShort);

        assert_eq!(
            TcpSegment::write_segment(
                small_buf.as_mut(),
                src_port,
                dst_port,
                seq_number,
                ack_number,
                flags_after_ns,
                window_size,
                mss_option,
                mss_left,
                payload,
                Some((src_addr, dst_addr)),
            )
            .unwrap_err(),
            TcpError::SliceTooShort
        );

        // Make sure we get the proper error for an insufficient value of mss_remaining.
        // (mss_remaining is checked before the buffer length, so the tiny buffer doesn't matter.)
        assert_eq!(
            TcpSegment::write_segment(
                small_buf.as_mut(),
                src_port,
                dst_port,
                seq_number,
                ack_number,
                flags_after_ns,
                window_size,
                mss_option,
                0,
                payload,
                Some((src_addr, dst_addr)),
            )
            .unwrap_err(),
            TcpError::MssRemaining
        );
    }
}