1use std::convert::From;
11use std::fmt::Debug;
12use std::net::Ipv4Addr;
13
14use crate::dumbo::pdu::bytes::{InnerBytes, NetworkBytes, NetworkBytesMut};
15use crate::dumbo::pdu::{Incomplete, ethernet};
16
// Byte offsets of the fixed IPv4 header fields (RFC 791, section 3.1). Each offset is written
// as the previous field's offset plus that field's width, so the layout reads top to bottom.
const VERSION_AND_IHL_OFFSET: usize = 0;
const DSCP_AND_ECN_OFFSET: usize = VERSION_AND_IHL_OFFSET + 1;
const TOTAL_LEN_OFFSET: usize = DSCP_AND_ECN_OFFSET + 1;
const IDENTIFICATION_OFFSET: usize = TOTAL_LEN_OFFSET + 2;
const FLAGS_AND_FRAGMENTOFF_OFFSET: usize = IDENTIFICATION_OFFSET + 2;
const TTL_OFFSET: usize = FLAGS_AND_FRAGMENTOFF_OFFSET + 2;
const PROTOCOL_OFFSET: usize = TTL_OFFSET + 1;
const HEADER_CHECKSUM_OFFSET: usize = PROTOCOL_OFFSET + 1;
const SOURCE_ADDRESS_OFFSET: usize = HEADER_CHECKSUM_OFFSET + 2;
const DESTINATION_ADDRESS_OFFSET: usize = SOURCE_ADDRESS_OFFSET + 4;
// Options start right after the 4-byte destination address; this is also the length of a
// header without options. The value (20) always fits in a `u8`, so the cast cannot truncate.
const OPTIONS_OFFSET: u8 = (DESTINATION_ADDRESS_OFFSET + 4) as u8;

/// Value of the version nibble for IPv4.
pub const IPV4_VERSION: u8 = 4;
/// Time-to-live value written by [`IPv4Packet::write_header`].
pub const DEFAULT_TTL: u8 = 1;

/// Protocol field value identifying a TCP payload.
pub const PROTOCOL_TCP: u8 = 6;

/// Protocol field value identifying a UDP payload.
pub const PROTOCOL_UDP: u8 = 17;
39
/// Describes the errors which may occur while handling IPv4 packets.
// NOTE: `displaydoc::Display` derives each variant's `Display` text from its doc comment, so
// the variant comments below are also the user-visible error messages.
#[derive(Debug, PartialEq, Eq, thiserror::Error, displaydoc::Display)]
pub enum Ipv4Error {
    /// The header checksum is invalid.
    Checksum,
    /// The header length is invalid.
    HeaderLen,
    /// The total length of the packet is invalid.
    InvalidTotalLen,
    /// The length of the given slice does not match the length of the packet.
    SliceExactLen,
    /// The length of the given slice is less than the IPv4 header length.
    SliceTooShort,
    /// The version header field is invalid.
    Version,
}
56
/// A view that interprets a byte buffer as an IPv4 packet.
///
/// Field accessors read and write the wrapped bytes in place. The type does not validate its
/// contents by itself: [`IPv4Packet::from_bytes`] performs header checks, while
/// [`IPv4Packet::from_bytes_unchecked`] does not.
#[derive(Debug)]
pub struct IPv4Packet<'a, T: 'a> {
    // The wrapped packet bytes, starting at the first byte of the IPv4 header.
    bytes: InnerBytes<'a, T>,
}
62
63#[allow(clippy::len_without_is_empty)]
64impl<T: NetworkBytes + Debug> IPv4Packet<'_, T> {
65 #[inline]
73 pub fn from_bytes_unchecked(bytes: T) -> Self {
74 IPv4Packet {
75 bytes: InnerBytes::new(bytes),
76 }
77 }
78
79 pub fn from_bytes(bytes: T, verify_checksum: bool) -> Result<Self, Ipv4Error> {
82 let bytes_len = bytes.len();
83
84 if bytes_len < usize::from(OPTIONS_OFFSET) {
85 return Err(Ipv4Error::SliceTooShort);
86 }
87
88 let packet = IPv4Packet::from_bytes_unchecked(bytes);
89
90 let (version, header_len) = packet.version_and_header_len();
91
92 if version != IPV4_VERSION {
93 return Err(Ipv4Error::Version);
94 }
95
96 let total_len = packet.total_len() as usize;
97
98 if total_len < header_len.into() {
99 return Err(Ipv4Error::InvalidTotalLen);
100 }
101
102 if total_len != bytes_len {
103 return Err(Ipv4Error::SliceExactLen);
104 }
105
106 if header_len < OPTIONS_OFFSET {
107 return Err(Ipv4Error::HeaderLen);
108 }
109
110 if verify_checksum && packet.compute_checksum_unchecked(header_len.into()) != 0 {
114 return Err(Ipv4Error::Checksum);
115 }
116
117 Ok(packet)
118 }
119
120 #[inline]
125 pub fn version_and_header_len(&self) -> (u8, u8) {
126 let x = self.bytes[VERSION_AND_IHL_OFFSET];
127 let ihl = x & 0x0f;
128 let header_len = ihl << 2;
129 (x >> 4, header_len)
130 }
131
132 #[inline]
134 pub fn header_len(&self) -> u8 {
135 let (_, header_len) = self.version_and_header_len();
136 header_len
137 }
138
139 #[inline]
141 pub fn dscp_and_ecn(&self) -> (u8, u8) {
142 let x = self.bytes[DSCP_AND_ECN_OFFSET];
143 (x >> 2, x & 0b11)
144 }
145
146 #[inline]
148 pub fn total_len(&self) -> u16 {
149 self.bytes.ntohs_unchecked(TOTAL_LEN_OFFSET)
150 }
151
152 #[inline]
154 pub fn identification(&self) -> u16 {
155 self.bytes.ntohs_unchecked(IDENTIFICATION_OFFSET)
156 }
157
158 #[inline]
160 pub fn flags_and_fragment_offset(&self) -> (u8, u16) {
161 let x = self.bytes.ntohs_unchecked(FLAGS_AND_FRAGMENTOFF_OFFSET);
162 ((x >> 13) as u8, x & 0x1fff)
163 }
164
165 #[inline]
167 pub fn ttl(&self) -> u8 {
168 self.bytes[TTL_OFFSET]
169 }
170
171 #[inline]
173 pub fn protocol(&self) -> u8 {
174 self.bytes[PROTOCOL_OFFSET]
175 }
176
177 #[inline]
179 pub fn header_checksum(&self) -> u16 {
180 self.bytes.ntohs_unchecked(HEADER_CHECKSUM_OFFSET)
181 }
182
183 #[inline]
185 pub fn source_address(&self) -> Ipv4Addr {
186 Ipv4Addr::from(self.bytes.ntohl_unchecked(SOURCE_ADDRESS_OFFSET))
187 }
188
189 #[inline]
191 pub fn destination_address(&self) -> Ipv4Addr {
192 Ipv4Addr::from(self.bytes.ntohl_unchecked(DESTINATION_ADDRESS_OFFSET))
193 }
194
195 #[inline]
202 pub fn payload_unchecked(&self, header_len: usize) -> &[u8] {
203 self.bytes.split_at(header_len).1
204 }
205
206 #[inline]
208 pub fn payload(&self) -> &[u8] {
209 self.payload_unchecked(self.header_len().into())
210 }
211
212 #[inline]
217 pub fn len(&self) -> usize {
218 self.bytes.len()
219 }
220
221 pub fn compute_checksum_unchecked(&self, header_len: usize) -> u16 {
232 let mut sum = 0u32;
233 for i in 0..header_len / 2 {
234 sum += u32::from(self.bytes.ntohs_unchecked(i * 2));
235 }
236
237 while sum >> 16 != 0 {
238 sum = (sum & 0xffff) + (sum >> 16);
239 }
240
241 !u16::try_from(sum).unwrap()
243 }
244
245 #[inline]
247 pub fn compute_checksum(&self) -> u16 {
248 self.compute_checksum_unchecked(self.header_len().into())
249 }
250}
251
252impl<T: NetworkBytesMut + Debug> IPv4Packet<'_, T> {
253 pub fn write_header(
261 buf: T,
262 protocol: u8,
263 src_addr: Ipv4Addr,
264 dst_addr: Ipv4Addr,
265 ) -> Result<Incomplete<Self>, Ipv4Error> {
266 if buf.len() < usize::from(OPTIONS_OFFSET) {
267 return Err(Ipv4Error::SliceTooShort);
268 }
269 let mut packet = IPv4Packet::from_bytes_unchecked(buf);
270 packet
271 .set_version_and_header_len(IPV4_VERSION, OPTIONS_OFFSET)
272 .set_dscp_and_ecn(0, 0)
273 .set_identification(0)
274 .set_flags_and_fragment_offset(0, 0)
275 .set_ttl(DEFAULT_TTL)
276 .set_protocol(protocol)
277 .set_source_address(src_addr)
278 .set_destination_address(dst_addr);
279
280 Ok(Incomplete::new(packet))
281 }
282
283 #[inline]
286 pub fn set_version_and_header_len(&mut self, version: u8, header_len: u8) -> &mut Self {
287 let version = version << 4;
288 let ihl = (header_len >> 2) & 0xf;
289 self.bytes[VERSION_AND_IHL_OFFSET] = version | ihl;
290 self
291 }
292
293 #[inline]
295 pub fn set_dscp_and_ecn(&mut self, dscp: u8, ecn: u8) -> &mut Self {
296 self.bytes[DSCP_AND_ECN_OFFSET] = (dscp << 2) | ecn;
297 self
298 }
299
300 #[inline]
302 pub fn set_total_len(&mut self, value: u16) -> &mut Self {
303 self.bytes.htons_unchecked(TOTAL_LEN_OFFSET, value);
304 self
305 }
306
307 #[inline]
309 pub fn set_identification(&mut self, value: u16) -> &mut Self {
310 self.bytes.htons_unchecked(IDENTIFICATION_OFFSET, value);
311 self
312 }
313
314 #[inline]
316 pub fn set_flags_and_fragment_offset(&mut self, flags: u8, fragment_offset: u16) -> &mut Self {
317 let value = (u16::from(flags) << 13) | fragment_offset;
318 self.bytes
319 .htons_unchecked(FLAGS_AND_FRAGMENTOFF_OFFSET, value);
320 self
321 }
322
323 #[inline]
325 pub fn set_ttl(&mut self, value: u8) -> &mut Self {
326 self.bytes[TTL_OFFSET] = value;
327 self
328 }
329
330 #[inline]
332 pub fn set_protocol(&mut self, value: u8) -> &mut Self {
333 self.bytes[PROTOCOL_OFFSET] = value;
334 self
335 }
336
337 #[inline]
339 pub fn set_header_checksum(&mut self, value: u16) -> &mut Self {
340 self.bytes.htons_unchecked(HEADER_CHECKSUM_OFFSET, value);
341 self
342 }
343
344 #[inline]
346 pub fn set_source_address(&mut self, addr: Ipv4Addr) -> &mut Self {
347 self.bytes
348 .htonl_unchecked(SOURCE_ADDRESS_OFFSET, u32::from(addr));
349 self
350 }
351
352 #[inline]
354 pub fn set_destination_address(&mut self, addr: Ipv4Addr) -> &mut Self {
355 self.bytes
356 .htonl_unchecked(DESTINATION_ADDRESS_OFFSET, u32::from(addr));
357 self
358 }
359
360 #[inline]
367 pub fn payload_mut_unchecked(&mut self, header_len: usize) -> &mut [u8] {
368 self.bytes.split_at_mut(header_len).1
369 }
370
371 #[inline]
373 pub fn payload_mut(&mut self) -> &mut [u8] {
374 let header_len = self.header_len();
377 self.payload_mut_unchecked(header_len.into())
378 }
379}
380
381impl<'a, T: NetworkBytesMut + Debug> Incomplete<IPv4Packet<'a, T>> {
387 #[inline]
395 pub fn with_header_and_payload_len_unchecked(
396 mut self,
397 header_len: u8,
398 payload_len: u16,
399 compute_checksum: bool,
400 ) -> IPv4Packet<'a, T> {
401 let total_len = u16::from(header_len) + payload_len;
402 {
403 let packet = &mut self.inner;
404
405 packet.bytes.shrink_unchecked(total_len.into());
408 packet.set_total_len(total_len);
410 if compute_checksum {
411 packet.set_header_checksum(0);
413 let checksum = packet.compute_checksum_unchecked(header_len.into());
415 packet.set_header_checksum(checksum);
416 }
417 }
418 self.inner
419 }
420
421 #[inline]
428 pub fn with_options_and_payload_len_unchecked(
429 self,
430 options_len: u8,
431 payload_len: u16,
432 compute_checksum: bool,
433 ) -> IPv4Packet<'a, T> {
434 let header_len = OPTIONS_OFFSET + options_len;
435 self.with_header_and_payload_len_unchecked(header_len, payload_len, compute_checksum)
436 }
437
438 #[inline]
445 pub fn with_payload_len_unchecked(
446 self,
447 payload_len: u16,
448 compute_checksum: bool,
449 ) -> IPv4Packet<'a, T> {
450 let header_len = self.inner().header_len();
451 self.with_header_and_payload_len_unchecked(header_len, payload_len, compute_checksum)
452 }
453}
454
455#[inline]
458pub fn test_speculative_dst_addr(buf: &[u8], addr: Ipv4Addr) -> bool {
459 if buf.len() >= ethernet::PAYLOAD_OFFSET + usize::from(OPTIONS_OFFSET) {
461 let bytes = &buf[ethernet::PAYLOAD_OFFSET..];
462 if IPv4Packet::from_bytes_unchecked(bytes).destination_address() == addr {
463 return true;
464 }
465 }
466 false
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dumbo::MacAddr;

    // Largest header length the 4-bit IHL field can encode (15 words * 4 bytes).
    const MAX_HEADER_LEN: u8 = 60;

    // Round-trips every header field through its setter and getter on a zeroed buffer.
    #[test]
    fn test_set_get() {
        let mut a = [0u8; 100];
        let mut p = IPv4Packet::from_bytes_unchecked(a.as_mut());

        assert_eq!(p.version_and_header_len(), (0, 0));
        p.set_version_and_header_len(IPV4_VERSION, 24);
        assert_eq!(p.version_and_header_len(), (IPV4_VERSION, 24));

        assert_eq!(p.dscp_and_ecn(), (0, 0));
        p.set_dscp_and_ecn(3, 2);
        assert_eq!(p.dscp_and_ecn(), (3, 2));

        assert_eq!(p.total_len(), 0);
        p.set_total_len(123);
        assert_eq!(p.total_len(), 123);

        assert_eq!(p.identification(), 0);
        p.set_identification(1112);
        assert_eq!(p.identification(), 1112);

        assert_eq!(p.flags_and_fragment_offset(), (0, 0));
        p.set_flags_and_fragment_offset(7, 1000);
        assert_eq!(p.flags_and_fragment_offset(), (7, 1000));

        assert_eq!(p.ttl(), 0);
        p.set_ttl(123);
        assert_eq!(p.ttl(), 123);

        assert_eq!(p.protocol(), 0);
        p.set_protocol(114);
        assert_eq!(p.protocol(), 114);

        assert_eq!(p.header_checksum(), 0);
        p.set_header_checksum(1234);
        assert_eq!(p.header_checksum(), 1234);

        let addr = Ipv4Addr::new(10, 11, 12, 13);

        assert_eq!(p.source_address(), Ipv4Addr::from(0));
        p.set_source_address(addr);
        assert_eq!(p.source_address(), addr);

        assert_eq!(p.destination_address(), Ipv4Addr::from(0));
        p.set_destination_address(addr);
        assert_eq!(p.destination_address(), addr);
    }

    // Builds a packet with `write_header`, checks it parses, then corrupts one field at a time
    // to exercise each `Ipv4Error` variant returned by `from_bytes` / `write_header`.
    #[test]
    fn test_constructors() {
        // Filled with 1s, so the zero-valued field asserts below only pass if `write_header`
        // actually wrote those fields.
        let mut buf = [1u8; 100];

        let src = Ipv4Addr::new(10, 100, 11, 21);
        let dst = Ipv4Addr::new(192, 168, 121, 35);

        let buf_len = u16::try_from(buf.len()).unwrap();
        // No options: the header is exactly OPTIONS_OFFSET bytes and the payload is the rest.
        let header_len = OPTIONS_OFFSET;
        let payload_len = buf_len - u16::from(OPTIONS_OFFSET);

        {
            let mut p = IPv4Packet::write_header(buf.as_mut(), PROTOCOL_TCP, src, dst)
                .unwrap()
                .with_header_and_payload_len_unchecked(header_len, payload_len, true);

            assert_eq!(p.version_and_header_len(), (IPV4_VERSION, header_len));
            assert_eq!(p.dscp_and_ecn(), (0, 0));
            assert_eq!(p.total_len(), buf_len);
            assert_eq!(p.identification(), 0);
            assert_eq!(p.flags_and_fragment_offset(), (0, 0));
            assert_eq!(p.ttl(), DEFAULT_TTL);
            assert_eq!(p.protocol(), PROTOCOL_TCP);

            // Zeroing the checksum field and recomputing must reproduce the stored checksum.
            let checksum = p.header_checksum();
            p.set_header_checksum(0);
            let computed_checksum = p.compute_checksum();
            assert_eq!(computed_checksum, checksum);

            // With the correct checksum in place, the header sums to zero.
            p.set_header_checksum(computed_checksum);
            assert_eq!(p.compute_checksum(), 0);

            assert_eq!(p.source_address(), src);
            assert_eq!(p.destination_address(), dst);
        }

        // The packet written above must parse, including checksum verification.
        IPv4Packet::from_bytes(buf.as_ref(), true).unwrap();

        // Reinterprets `buf` so individual fields can be overwritten in place.
        fn p(buf: &mut [u8]) -> IPv4Packet<'_, &mut [u8]> {
            IPv4Packet::from_bytes_unchecked(buf)
        }

        // Asserts that parsing `buf` (with checksum verification) fails with exactly `err`.
        let look_for_error = |buf: &[u8], err: Ipv4Error| {
            assert_eq!(IPv4Packet::from_bytes(buf, true).unwrap_err(), err);
        };

        // The steps below mutate `buf` cumulatively; each relies on the fields left behind by
        // the previous steps (e.g. total_len stays at buf_len until it is changed explicitly).
        p(buf.as_mut()).set_version_and_header_len(IPV4_VERSION + 1, header_len);
        look_for_error(buf.as_ref(), Ipv4Error::Version);

        // (20 - 1) >> 2 == 4 words, i.e. a 16-byte header, below the 20-byte minimum.
        p(buf.as_mut()).set_version_and_header_len(IPV4_VERSION, OPTIONS_OFFSET - 1);
        look_for_error(buf.as_ref(), Ipv4Error::HeaderLen);

        // 64 bytes would need IHL = 16, which does not fit in 4 bits and is masked to 0.
        p(buf.as_mut()).set_version_and_header_len(IPV4_VERSION, MAX_HEADER_LEN + 4);
        look_for_error(buf.as_ref(), Ipv4Error::HeaderLen);

        // Restore a valid header length, then declare a total length shorter than the header.
        p(buf.as_mut())
            .set_version_and_header_len(IPV4_VERSION, OPTIONS_OFFSET)
            .set_total_len(u16::from(OPTIONS_OFFSET) - 1);
        look_for_error(buf.as_ref(), Ipv4Error::InvalidTotalLen);

        // A total length that disagrees with the slice length.
        p(buf.as_mut()).set_total_len(buf_len - 1);
        look_for_error(buf.as_ref(), Ipv4Error::SliceExactLen);

        // Restoring the total length brings the header back to its valid, checksummed state.
        assert_eq!(p(buf.as_mut()).set_total_len(buf_len).compute_checksum(), 0);

        // Perturb the stored checksum so verification fails.
        let checksum = p(buf.as_mut()).header_checksum();
        p(buf.as_mut()).set_header_checksum(checksum.wrapping_add(1));
        look_for_error(buf.as_ref(), Ipv4Error::Checksum);

        // A slice shorter than the fixed header is rejected by both parsing and writing.
        let mut small_buf = [0u8; 1];

        look_for_error(small_buf.as_ref(), Ipv4Error::SliceTooShort);

        assert_eq!(
            IPv4Packet::write_header(small_buf.as_mut(), PROTOCOL_TCP, src, dst).unwrap_err(),
            Ipv4Error::SliceTooShort
        );
    }

    // Both `Incomplete` finalization paths must produce a packet whose checksum verifies and
    // whose total length field matches the (shrunk) buffer length.
    #[test]
    fn test_incomplete() {
        let mut buf = [0u8; 100];
        let src = Ipv4Addr::new(10, 100, 11, 21);
        let dst = Ipv4Addr::new(192, 168, 121, 35);
        let payload_len = 30;
        // No options, so both paths below finalize a 20-byte header.
        let options_len = 0;
        let header_len = OPTIONS_OFFSET + options_len;

        {
            let p = IPv4Packet::write_header(buf.as_mut(), PROTOCOL_TCP, src, dst)
                .unwrap()
                .with_payload_len_unchecked(payload_len, true);

            assert_eq!(p.compute_checksum(), 0);
            assert_eq!(p.total_len() as usize, p.len());
            assert_eq!(p.len(), usize::from(header_len) + usize::from(payload_len));
        }

        {
            let p = IPv4Packet::write_header(buf.as_mut(), PROTOCOL_TCP, src, dst)
                .unwrap()
                .with_options_and_payload_len_unchecked(options_len, payload_len, true);

            assert_eq!(p.compute_checksum(), 0);
            assert_eq!(p.total_len() as usize, p.len());
            assert_eq!(p.len(), usize::from(header_len) + usize::from(payload_len));
        }
    }

    // `test_speculative_dst_addr` must match only the destination address actually written
    // after the Ethernet header, and must reject buffers that are too short.
    #[test]
    fn test_speculative() {
        let mut buf = [0u8; 1000];
        let mac = MacAddr::from_bytes_unchecked(&[0; 6]);
        let ip = Ipv4Addr::new(1, 2, 3, 4);
        let other_ip = Ipv4Addr::new(5, 6, 7, 8);

        // Write an Ethernet header, then put `ip` in the destination field of the IPv4 header
        // that follows it.
        {
            let mut eth = crate::dumbo::pdu::ethernet::EthernetFrame::write_incomplete(
                buf.as_mut(),
                mac,
                mac,
                0,
            )
            .unwrap();
            IPv4Packet::from_bytes_unchecked(eth.inner_mut().payload_mut())
                .set_destination_address(ip);
        }
        assert!(test_speculative_dst_addr(buf.as_ref(), ip));

        // Same frame, but addressed to a different host.
        {
            let mut eth = crate::dumbo::pdu::ethernet::EthernetFrame::write_incomplete(
                buf.as_mut(),
                mac,
                mac,
                0,
            )
            .unwrap();
            IPv4Packet::from_bytes_unchecked(eth.inner_mut().payload_mut())
                .set_destination_address(other_ip);
        }
        assert!(!test_speculative_dst_addr(buf.as_ref(), ip));

        let small = [0u8; 1];
        assert!(!test_speculative_dst_addr(small.as_ref(), ip));
    }
}