1use std::cmp::min;
11use std::fmt::Debug;
12use std::net::Ipv4Addr;
13use std::num::NonZeroU16;
14
15use bitflags::bitflags;
16
17use super::Incomplete;
18use super::bytes::{InnerBytes, NetworkBytes, NetworkBytesMut};
19use crate::dumbo::ByteBuffer;
20use crate::dumbo::pdu::ChecksumProto;
21
// Byte offsets of the fixed-size TCP header fields. Each field starts where the previous one
// ends, so the offsets are spelled as "previous offset + previous field width".
const SOURCE_PORT_OFFSET: usize = 0;
const DESTINATION_PORT_OFFSET: usize = SOURCE_PORT_OFFSET + 2;
const SEQ_NUMBER_OFFSET: usize = DESTINATION_PORT_OFFSET + 2;
const ACK_NUMBER_OFFSET: usize = SEQ_NUMBER_OFFSET + 4;
const DATAOFF_RSVD_NS_OFFSET: usize = ACK_NUMBER_OFFSET + 4;
const FLAGS_AFTER_NS_OFFSET: usize = DATAOFF_RSVD_NS_OFFSET + 1;
const WINDOW_SIZE_OFFSET: usize = FLAGS_AFTER_NS_OFFSET + 1;
const CHECKSUM_OFFSET: usize = WINDOW_SIZE_OFFSET + 2;
const URG_POINTER_OFFSET: usize = CHECKSUM_OFFSET + 2;

/// Length of the fixed part of the header; options (if any) begin at this offset.
const OPTIONS_OFFSET: u8 = 20;

/// Largest header length expressible by the 4-bit data offset (15 words of 4 bytes).
const MAX_HEADER_LEN: u8 = 15 * 4;

// Option kinds recognized by the MSS option parser and writer.
const OPTION_KIND_EOL: u8 = 0;
const OPTION_KIND_NOP: u8 = 1;
const OPTION_KIND_MSS: u8 = 2;

/// Total encoded size of the MSS option: kind byte, length byte, and a 16-bit value.
const OPTION_LEN_MSS: u8 = 4;

/// Smallest MSS option value accepted when parsing; lower values are reported as errors.
const MSS_MIN: u16 = 100;
44
45bitflags! {
46 #[derive(Debug, Copy, Clone, PartialEq)]
54 pub struct Flags: u8 {
55 const CWR = 1 << 7;
57 const ECE = 1 << 6;
59 const URG = 1 << 5;
61 const ACK = 1 << 4;
63 const PSH = 1 << 3;
65 const RST = 1 << 2;
67 const SYN = 1 << 1;
69 const FIN = 1;
71 }
72}
73
74#[derive(Debug, PartialEq, Eq, thiserror::Error, displaydoc::Display)]
76pub enum TcpError {
77 Checksum,
79 EmptyPayload,
81 HeaderLen,
83 MssOption,
85 MssRemaining,
87 SliceTooShort,
89}
90
91#[derive(Debug)]
96pub struct TcpSegment<'a, T: 'a> {
97 bytes: InnerBytes<'a, T>,
98}
99
100#[allow(clippy::len_without_is_empty)]
101impl<T: NetworkBytes + Debug> TcpSegment<'_, T> {
102 #[inline]
104 pub fn source_port(&self) -> u16 {
105 self.bytes.ntohs_unchecked(SOURCE_PORT_OFFSET)
106 }
107
108 #[inline]
110 pub fn destination_port(&self) -> u16 {
111 self.bytes.ntohs_unchecked(DESTINATION_PORT_OFFSET)
112 }
113
114 #[inline]
116 pub fn sequence_number(&self) -> u32 {
117 self.bytes.ntohl_unchecked(SEQ_NUMBER_OFFSET)
118 }
119
120 #[inline]
122 pub fn ack_number(&self) -> u32 {
123 self.bytes.ntohl_unchecked(ACK_NUMBER_OFFSET)
124 }
125
126 #[inline]
129 pub fn header_len_rsvd_ns(&self) -> (u8, u8, bool) {
130 let value = self.bytes[DATAOFF_RSVD_NS_OFFSET];
131 let data_offset = value >> 4;
132 let header_len = data_offset * 4;
133 let rsvd = value & 0x0e;
134 let ns = (value & 1) != 0;
135 (header_len, rsvd, ns)
136 }
137
138 #[inline]
140 pub fn header_len(&self) -> u8 {
141 self.header_len_rsvd_ns().0
142 }
143
144 #[inline]
146 pub fn flags_after_ns(&self) -> Flags {
147 Flags::from_bits_truncate(self.bytes[FLAGS_AFTER_NS_OFFSET])
148 }
149
150 #[inline]
152 pub fn window_size(&self) -> u16 {
153 self.bytes.ntohs_unchecked(WINDOW_SIZE_OFFSET)
154 }
155
156 #[inline]
158 pub fn checksum(&self) -> u16 {
159 self.bytes.ntohs_unchecked(CHECKSUM_OFFSET)
160 }
161
162 #[inline]
165 pub fn urgent_pointer(&self) -> u16 {
166 self.bytes.ntohs_unchecked(URG_POINTER_OFFSET)
167 }
168
169 #[inline]
175 pub fn options_unchecked(&self, header_len: usize) -> &[u8] {
176 &self.bytes[usize::from(OPTIONS_OFFSET)..header_len]
177 }
178
179 #[inline]
186 pub fn payload_unchecked(&self, header_len: usize) -> &[u8] {
187 self.bytes.split_at(header_len).1
188 }
189
190 #[inline]
192 pub fn len(&self) -> u16 {
193 u16::try_from(self.bytes.len()).unwrap_or(u16::MAX)
197 }
198
199 #[inline]
201 pub fn payload(&self) -> &[u8] {
202 self.payload_unchecked(self.header_len().into())
203 }
204
205 #[inline]
207 pub fn payload_len(&self) -> u16 {
208 self.len() - u16::from(self.header_len())
209 }
210
211 pub fn compute_checksum(&self, src_addr: Ipv4Addr, dst_addr: Ipv4Addr) -> u16 {
216 crate::dumbo::pdu::compute_checksum(&self.bytes, src_addr, dst_addr, ChecksumProto::Tcp)
217 }
218
219 pub fn parse_mss_option_unchecked(
228 &self,
229 header_len: usize,
230 ) -> Result<Option<NonZeroU16>, TcpError> {
231 let b = self.options_unchecked(header_len);
232 let mut i = 0;
233
234 while i + 3 < b.len() {
241 match b[i] {
242 OPTION_KIND_EOL => break,
243 OPTION_KIND_NOP => {
244 i += 1;
245 continue;
246 }
247 OPTION_KIND_MSS => {
248 let mss = b.ntohs_unchecked(i + 2);
252 if mss < MSS_MIN {
253 return Err(TcpError::MssOption);
254 }
255 return Ok(Some(NonZeroU16::new(mss).unwrap()));
257 }
258 _ => {
259 i += b[i + 1] as usize;
261 continue;
262 }
263 }
264 }
265 Ok(None)
266 }
267
268 #[inline]
275 pub fn from_bytes_unchecked(bytes: T) -> Self {
276 TcpSegment {
277 bytes: InnerBytes::new(bytes),
278 }
279 }
280
281 #[inline]
286 pub fn from_bytes(
287 bytes: T,
288 verify_checksum: Option<(Ipv4Addr, Ipv4Addr)>,
289 ) -> Result<Self, TcpError> {
290 if bytes.len() < usize::from(OPTIONS_OFFSET) {
291 return Err(TcpError::SliceTooShort);
292 }
293
294 let segment = Self::from_bytes_unchecked(bytes);
295
296 let header_len = segment.header_len();
299
300 if header_len < OPTIONS_OFFSET
301 || u16::from(header_len) > min(u16::from(MAX_HEADER_LEN), segment.len())
302 {
303 return Err(TcpError::HeaderLen);
304 }
305
306 if let Some((src_addr, dst_addr)) = verify_checksum
307 && segment.compute_checksum(src_addr, dst_addr) != 0
308 {
309 return Err(TcpError::Checksum);
310 }
311
312 Ok(segment)
313 }
314}
315
316impl<T: NetworkBytesMut + Debug> TcpSegment<'_, T> {
317 #[inline]
319 pub fn set_source_port(&mut self, value: u16) -> &mut Self {
320 self.bytes.htons_unchecked(SOURCE_PORT_OFFSET, value);
321 self
322 }
323
324 #[inline]
326 pub fn set_destination_port(&mut self, value: u16) -> &mut Self {
327 self.bytes.htons_unchecked(DESTINATION_PORT_OFFSET, value);
328 self
329 }
330
331 #[inline]
333 pub fn set_sequence_number(&mut self, value: u32) -> &mut Self {
334 self.bytes.htonl_unchecked(SEQ_NUMBER_OFFSET, value);
335 self
336 }
337
338 #[inline]
340 pub fn set_ack_number(&mut self, value: u32) -> &mut Self {
341 self.bytes.htonl_unchecked(ACK_NUMBER_OFFSET, value);
342 self
343 }
344
345 #[inline]
349 pub fn set_header_len_rsvd_ns(&mut self, header_len: u8, ns: bool) -> &mut Self {
350 let mut value = header_len << 2;
351 if ns {
352 value |= 1;
353 }
354 self.bytes[DATAOFF_RSVD_NS_OFFSET] = value;
355 self
356 }
357
358 #[inline]
360 pub fn set_flags_after_ns(&mut self, flags: Flags) -> &mut Self {
361 self.bytes[FLAGS_AFTER_NS_OFFSET] = flags.bits();
362 self
363 }
364
365 #[inline]
367 pub fn set_window_size(&mut self, value: u16) -> &mut Self {
368 self.bytes.htons_unchecked(WINDOW_SIZE_OFFSET, value);
369 self
370 }
371
372 #[inline]
374 pub fn set_checksum(&mut self, value: u16) -> &mut Self {
375 self.bytes.htons_unchecked(CHECKSUM_OFFSET, value);
376 self
377 }
378
379 #[inline]
381 pub fn set_urgent_pointer(&mut self, value: u16) -> &mut Self {
382 self.bytes.htons_unchecked(URG_POINTER_OFFSET, value);
383 self
384 }
385
386 #[inline]
392 pub fn payload_mut_unchecked(&mut self, header_len: usize) -> &mut [u8] {
393 self.bytes.split_at_mut(header_len).1
394 }
395
396 #[inline]
398 pub fn payload_mut(&mut self) -> &mut [u8] {
399 let header_len = self.header_len();
400 self.payload_mut_unchecked(header_len.into())
401 }
402
403 #[allow(clippy::too_many_arguments)]
424 #[inline]
425 pub fn write_segment<R: ByteBuffer + ?Sized + Debug>(
426 buf: T,
427 src_port: u16,
428 dst_port: u16,
429 seq_number: u32,
430 ack_number: u32,
431 flags_after_ns: Flags,
432 window_size: u16,
433 mss_option: Option<u16>,
434 mss_remaining: u16,
435 payload: Option<(&R, usize)>,
436 compute_checksum: Option<(Ipv4Addr, Ipv4Addr)>,
437 ) -> Result<Self, TcpError> {
438 Ok(Self::write_incomplete_segment(
439 buf,
440 seq_number,
441 ack_number,
442 flags_after_ns,
443 window_size,
444 mss_option,
445 mss_remaining,
446 payload,
447 )?
448 .finalize(src_port, dst_port, compute_checksum))
449 }
450
451 #[allow(clippy::too_many_arguments)]
475 #[inline]
476 pub fn write_incomplete_segment<R: ByteBuffer + ?Sized + Debug>(
477 buf: T,
478 seq_number: u32,
479 ack_number: u32,
480 flags_after_ns: Flags,
481 window_size: u16,
482 mss_option: Option<u16>,
483 mss_remaining: u16,
484 payload: Option<(&R, usize)>,
485 ) -> Result<Incomplete<Self>, TcpError> {
486 let mut mss_left = mss_remaining;
487
488 let mut segment_len = u16::from(OPTIONS_OFFSET);
490
491 let options_len = if mss_option.is_some() {
493 mss_left = mss_left
494 .checked_sub(OPTION_LEN_MSS.into())
495 .ok_or(TcpError::MssRemaining)?;
496 OPTION_LEN_MSS
497 } else {
498 0
499 };
500
501 segment_len += u16::from(options_len);
502
503 if buf.len() < usize::from(segment_len) {
504 return Err(TcpError::SliceTooShort);
505 }
506
507 let mut segment = Self::from_bytes_unchecked(buf);
509
510 segment
511 .set_sequence_number(seq_number)
512 .set_ack_number(ack_number)
513 .set_header_len_rsvd_ns(OPTIONS_OFFSET + options_len, false)
514 .set_flags_after_ns(flags_after_ns)
515 .set_window_size(window_size)
516 .set_urgent_pointer(0);
517
518 if let Some(value) = mss_option {
520 segment.bytes[usize::from(OPTIONS_OFFSET)] = OPTION_KIND_MSS;
521 segment.bytes[usize::from(OPTIONS_OFFSET) + 1] = OPTION_LEN_MSS;
522 segment
523 .bytes
524 .htons_unchecked(usize::from(OPTIONS_OFFSET) + 2, value);
525 }
526
527 let payload_bytes_count = if let Some((payload_buf, max_payload_bytes)) = payload {
528 let left_to_read = min(payload_buf.len(), max_payload_bytes);
529
530 let mut room_for_payload = min(segment.len() - segment_len, mss_left);
533 room_for_payload =
535 u16::try_from(min(usize::from(room_for_payload), left_to_read)).unwrap();
536
537 if room_for_payload == 0 {
538 return Err(TcpError::EmptyPayload);
539 }
540
541 payload_buf.read_to_slice(
545 0,
546 &mut segment.bytes
547 [usize::from(segment_len)..usize::from(segment_len + room_for_payload)],
548 );
549 room_for_payload
550 } else {
551 0
552 };
553 segment_len += payload_bytes_count;
554
555 segment.bytes.shrink_unchecked(segment_len.into());
557
558 Ok(Incomplete::new(segment))
560 }
561}
562
563impl<'a, T: NetworkBytesMut + Debug> Incomplete<TcpSegment<'a, T>> {
564 #[inline]
567 pub fn finalize(
568 mut self,
569 src_port: u16,
570 dst_port: u16,
571 compute_checksum: Option<(Ipv4Addr, Ipv4Addr)>,
572 ) -> TcpSegment<'a, T> {
573 self.inner.set_source_port(src_port);
574 self.inner.set_destination_port(dst_port);
575 if let Some((src_addr, dst_addr)) = compute_checksum {
576 self.inner.set_checksum(0);
578 let checksum = self.inner.compute_checksum(src_addr, dst_addr);
579 self.inner.set_checksum(checksum);
580 }
581 self.inner
582 }
583}
584
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trips every header field through its setter and getter on a zeroed buffer.
    #[test]
    fn test_set_get() {
        let mut a = [0u8; 100];
        let mut p = TcpSegment::from_bytes_unchecked(a.as_mut());

        assert_eq!(p.source_port(), 0);
        p.set_source_port(123);
        assert_eq!(p.source_port(), 123);

        assert_eq!(p.destination_port(), 0);
        p.set_destination_port(322);
        assert_eq!(p.destination_port(), 322);

        assert_eq!(p.sequence_number(), 0);
        p.set_sequence_number(1_234_567);
        assert_eq!(p.sequence_number(), 1_234_567);

        assert_eq!(p.ack_number(), 0);
        p.set_ack_number(345_234);
        assert_eq!(p.ack_number(), 345_234);

        assert_eq!(p.header_len_rsvd_ns(), (0, 0, false));
        assert_eq!(p.header_len(), 0);
        // The header length must be a multiple of 4 to survive the encode/decode round trip;
        // 60 is the largest such value.
        let header_len = 60;
        p.set_header_len_rsvd_ns(header_len, true);
        assert_eq!(p.header_len_rsvd_ns(), (header_len, 0, true));
        assert_eq!(p.header_len(), header_len);

        assert_eq!(p.flags_after_ns().bits(), 0);
        p.set_flags_after_ns(Flags::SYN | Flags::URG);
        assert_eq!(p.flags_after_ns(), Flags::SYN | Flags::URG);

        assert_eq!(p.window_size(), 0);
        p.set_window_size(60000);
        assert_eq!(p.window_size(), 60000);

        assert_eq!(p.checksum(), 0);
        p.set_checksum(4321);
        assert_eq!(p.checksum(), 4321);

        assert_eq!(p.urgent_pointer(), 0);
        p.set_urgent_pointer(5554);
        assert_eq!(p.urgent_pointer(), 5554);
    }

    // Exercises write_segment, from_bytes and parse_mss_option_unchecked, then the error paths.
    // The steps share and progressively modify the buffer `a`, so their order matters.
    #[test]
    fn test_constructors() {
        let mut a = [1u8; 1460];
        let b = [2u8; 1000];
        let c = [3u8; 2000];

        let src_addr = Ipv4Addr::new(10, 1, 2, 3);
        let dst_addr = Ipv4Addr::new(192, 168, 44, 77);
        let src_port = 1234;
        let dst_port = 5678;
        let seq_number = 11_111_222;
        let ack_number = 34_566_543;
        let flags_after_ns = Flags::SYN | Flags::RST;
        let window_size = 19999;
        let mss_left = 1460;
        let mss_option = Some(mss_left);
        let payload = Some((b.as_ref(), b.len()));

        // Fixed header plus the 4-byte MSS option.
        let header_len = OPTIONS_OFFSET + OPTION_LEN_MSS;

        // Write a segment carrying all of `b` and check every field; yields its total length.
        let segment_len = {
            let mut segment = TcpSegment::write_segment(
                a.as_mut(),
                src_port,
                dst_port,
                seq_number,
                ack_number,
                flags_after_ns,
                window_size,
                mss_option,
                mss_left,
                payload,
                Some((src_addr, dst_addr)),
            )
            .unwrap();

            assert_eq!(segment.source_port(), src_port);
            assert_eq!(segment.destination_port(), dst_port);
            assert_eq!(segment.sequence_number(), seq_number);
            assert_eq!(segment.ack_number(), ack_number);
            assert_eq!(segment.header_len_rsvd_ns(), (header_len, 0, false));
            assert_eq!(segment.flags_after_ns(), flags_after_ns);
            assert_eq!(segment.window_size(), window_size);

            // Recomputing over a zeroed checksum field must reproduce the stored checksum...
            let checksum = segment.checksum();
            segment.set_checksum(0);
            let computed_checksum = segment.compute_checksum(src_addr, dst_addr);
            assert_eq!(checksum, computed_checksum);

            // ...and computing with the correct checksum in place must yield 0.
            segment.set_checksum(checksum);
            assert_eq!(segment.compute_checksum(src_addr, dst_addr), 0);

            assert_eq!(segment.urgent_pointer(), 0);

            {
                let options = segment.options_unchecked(header_len.into());
                assert_eq!(options.len(), usize::from(OPTION_LEN_MSS));
                assert_eq!(options[0], OPTION_KIND_MSS);
                assert_eq!(options[1], OPTION_LEN_MSS);
                assert_eq!(options.ntohs_unchecked(2), mss_left);
            }

            // The whole of `b` fits, so the segment is header + payload.
            assert_eq!(
                usize::from(segment.len()),
                usize::from(header_len) + b.len(),
            );
            segment.len()
        };

        // The written bytes must parse back (with checksum verification) and expose the MSS.
        {
            let segment =
                TcpSegment::from_bytes(&a[..segment_len.into()], Some((src_addr, dst_addr)))
                    .unwrap();
            assert_eq!(
                segment.parse_mss_option_unchecked(header_len.into()),
                Ok(Some(NonZeroU16::new(mss_left).unwrap()))
            );
        }

        // A 2000-byte payload does not fit: the payload is cut to the room left in the
        // 1460-byte buffer (24-byte header + 1436 bytes), so the total equals 1460 == mss_left.
        // This also leaves `a` holding a valid 1460-byte segment, which the steps below reuse.
        {
            let segment_len = TcpSegment::write_segment(
                a.as_mut(),
                src_port,
                dst_port,
                seq_number,
                ack_number,
                flags_after_ns,
                window_size,
                mss_option,
                mss_left,
                Some((c.as_ref(), c.len())),
                Some((src_addr, dst_addr)),
            )
            .unwrap()
            .len();

            assert_eq!(segment_len, mss_left);
        }

        // Creates a short-lived unchecked view of a buffer, so `a` can be modified through it
        // and then borrowed immutably again by `look_for_error`.
        fn p(buf: &mut [u8]) -> TcpSegment<'_, &mut [u8]> {
            TcpSegment::from_bytes_unchecked(buf)
        }

        // Asserts that parsing `buf` (with checksum verification) fails with `err`.
        let look_for_error = |buf: &[u8], err: TcpError| {
            assert_eq!(
                TcpSegment::from_bytes(buf, Some((src_addr, dst_addr))).unwrap_err(),
                err
            );
        };

        // 19 is not a multiple of 4; the encoded data offset decodes to 16, below the minimum.
        p(a.as_mut()).set_header_len_rsvd_ns(OPTIONS_OFFSET.checked_sub(1).unwrap(), false);
        look_for_error(a.as_ref(), TcpError::HeaderLen);

        // 64 exceeds MAX_HEADER_LEN and cannot be encoded; the resulting header length is
        // rejected.
        p(a.as_mut()).set_header_len_rsvd_ns(MAX_HEADER_LEN.checked_add(4).unwrap(), false);
        look_for_error(a.as_ref(), TcpError::HeaderLen);

        // Restoring the real header length makes `a` a valid segment again.
        assert_eq!(
            p(a.as_mut())
                .set_header_len_rsvd_ns(header_len, false)
                .compute_checksum(src_addr, dst_addr),
            0
        );

        // Corrupting the stored checksum must be detected.
        let checksum = p(a.as_mut()).checksum();
        p(a.as_mut()).set_checksum(checksum.wrapping_add(1));
        look_for_error(a.as_ref(), TcpError::Checksum);

        // A 1-byte buffer cannot hold the fixed header.
        let mut small_buf = [0u8; 1];
        look_for_error(small_buf.as_ref(), TcpError::SliceTooShort);

        assert_eq!(
            TcpSegment::write_segment(
                small_buf.as_mut(),
                src_port,
                dst_port,
                seq_number,
                ack_number,
                flags_after_ns,
                window_size,
                mss_option,
                mss_left,
                payload,
                Some((src_addr, dst_addr)),
            )
            .unwrap_err(),
            TcpError::SliceTooShort
        );

        // With an MSS budget of 0 the MSS option cannot fit; this is reported before the
        // buffer length is checked.
        assert_eq!(
            TcpSegment::write_segment(
                small_buf.as_mut(),
                src_port,
                dst_port,
                seq_number,
                ack_number,
                flags_after_ns,
                window_size,
                mss_option,
                0,
                payload,
                Some((src_addr, dst_addr)),
            )
            .unwrap_err(),
            TcpError::MssRemaining
        );
    }
}