1use std::collections::{HashMap, HashSet};
9use std::fmt::Debug;
10use std::net::Ipv4Addr;
11use std::num::NonZeroUsize;
12
13use micro_http::{Request, Response};
14
15use crate::dumbo::pdu::bytes::NetworkBytes;
16use crate::dumbo::pdu::ipv4::{IPv4Packet, Ipv4Error as IPv4PacketError, PROTOCOL_TCP};
17use crate::dumbo::pdu::tcp::{Flags as TcpFlags, TcpError as TcpSegmentError, TcpSegment};
18use crate::dumbo::tcp::endpoint::Endpoint;
19use crate::dumbo::tcp::{NextSegmentStatus, RstConfig};
20
/// Events reported by `TcpIPv4Handler::receive_packet` when a packet was processed without
/// returning an error.
#[derive(Debug, PartialEq, Eq)]
pub enum RecvEvent {
    /// The endpoint that received the segment reported being done afterwards, and its
    /// connection was removed from the handler.
    EndpointDone,
    /// A `SYN` segment arrived for an unknown connection, but an endpoint could not be created
    /// from it.
    FailedNewConnection,
    /// A new connection was added while the handler was below its connection limit.
    NewConnectionSuccessful,
    /// The handler was at its connection limit and no existing connection was evictable, so the
    /// new connection was not added (a `RST` for it is enqueued if the reset queue has room).
    NewConnectionDropped,
    /// The handler was at its connection limit; an evictable connection was removed and the new
    /// connection took its place.
    NewConnectionReplacing,
    /// The segment was handed to an existing endpoint, which is still running.
    Nothing,
    /// The segment does not belong to a known connection and its flags are not exactly `SYN`.
    UnexpectedSegment,
}
45
/// Events reported by `TcpIPv4Handler::write_next_packet` alongside the written length.
#[derive(Debug, PartialEq, Eq)]
pub enum WriteEvent {
    /// The endpoint that produced the written segment reported being done afterwards, and its
    /// connection was removed from the handler.
    EndpointDone,
    /// Nothing noteworthy happened (this is also the event reported for `RST` packets and when
    /// nothing was written).
    Nothing,
}
54
55#[derive(Debug, PartialEq, Eq, thiserror::Error, displaydoc::Display)]
61pub enum RecvError {
62 InvalidPort,
64 TcpSegment(#[from] TcpSegmentError),
66}
67
68#[derive(Debug, PartialEq, Eq, thiserror::Error, displaydoc::Display)]
74pub enum WriteNextError {
75 IPv4Packet(#[from] IPv4PacketError),
77 TcpSegment(#[from] TcpSegmentError),
79}
80
/// Key identifying a connection by its remote side only.
///
/// The local address and port are properties of the handler itself, so the remote address and
/// port are enough to tell connections apart.
#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq)]
struct ConnectionTuple {
    /// IPv4 address of the remote peer.
    remote_addr: Ipv4Addr,
    /// TCP port of the remote peer.
    remote_port: u16,
}

impl ConnectionTuple {
    /// Builds the key for the peer at `remote_addr:remote_port`.
    fn new(remote_addr: Ipv4Addr, remote_port: u16) -> Self {
        Self {
            remote_addr,
            remote_port,
        }
    }
}
98
/// Dispatches TCP segments carried in IPv4 packets to per-connection endpoints, and produces the
/// outgoing packets generated by those endpoints.
///
/// Connections are keyed by the remote address and port. The handler also keeps a bounded list
/// of pending `RST` segments, which are written before any connection traffic.
#[derive(Debug)]
pub struct TcpIPv4Handler {
    /// Address used as the source of every outgoing packet.
    local_ipv4_addr: Ipv4Addr,
    /// Port the handler accepts segments on; it is also the source port of outgoing segments.
    local_port: u16,
    /// Endpoints of the currently tracked connections.
    connections: HashMap<ConnectionTuple, Endpoint>,
    /// Upper bound on the number of entries in `connections`.
    max_connections: NonZeroUsize,
    /// Connections whose last observed next-segment status was `Available`.
    active_connections: HashSet<ConnectionTuple>,
    /// Smallest timeout value currently known, together with the connection that reported it.
    next_timeout: Option<(u64, ConnectionTuple)>,
    /// Pending `RST` segments; note that entries are taken from the back (`Vec::pop`).
    rst_queue: Vec<(ConnectionTuple, RstConfig)>,
    /// Upper bound on the number of entries in `rst_queue`.
    max_pending_resets: NonZeroUsize,
}
143
/// Intermediate classification of an incoming segment, computed by `receive_packet` before any
/// handler state is updated.
#[derive(Debug)]
enum RecvSegmentOutcome {
    /// An existing endpoint received the segment and is now done.
    EndpointDone,
    /// An existing endpoint received the segment and is still running; carries its next segment
    /// status.
    EndpointRunning(NextSegmentStatus),
    /// No endpoint exists for the sender and the segment flags are exactly `SYN`.
    NewConnection,
    /// No endpoint exists for the sender and the segment is not a plain `SYN`. The flag tells
    /// whether a `RST` should be enqueued in response (`false` when the incoming segment has
    /// the `RST` flag set itself).
    UnexpectedSegment(bool),
}
153
154impl TcpIPv4Handler {
155 #[inline]
161 pub fn new(
162 local_ipv4_addr: Ipv4Addr,
163 local_port: u16,
164 max_connections: NonZeroUsize,
165 max_pending_resets: NonZeroUsize,
166 ) -> Self {
167 TcpIPv4Handler {
168 local_ipv4_addr,
169 local_port,
170 connections: HashMap::with_capacity(max_connections.get()),
171 max_connections,
172 active_connections: HashSet::with_capacity(max_connections.get()),
173 next_timeout: None,
174 rst_queue: Vec::with_capacity(max_pending_resets.get()),
175 max_pending_resets,
176 }
177 }
178
/// Replaces the local IPv4 address, which is used as the source address of outgoing packets.
pub fn set_local_ipv4_addr(&mut self, ipv4_addr: Ipv4Addr) {
    self.local_ipv4_addr = ipv4_addr;
}
183
/// Returns the local IPv4 address of the handler.
pub fn local_ipv4_addr(&self) -> Ipv4Addr {
    self.local_ipv4_addr
}
188
/// Returns the local TCP port of the handler.
pub fn local_port(&self) -> u16 {
    self.local_port
}
193
/// Returns the maximum number of connections the handler tracks at the same time.
pub fn max_connections(&self) -> NonZeroUsize {
    self.max_connections
}
198
/// Returns the maximum number of `RST` segments that can be pending at the same time.
pub fn max_pending_resets(&self) -> NonZeroUsize {
    self.max_pending_resets
}
203
204 pub fn receive_packet<T: NetworkBytes + Debug, F: FnOnce(Request) -> Response>(
208 &mut self,
209 packet: &IPv4Packet<T>,
210 callback: F,
211 ) -> Result<RecvEvent, RecvError> {
212 let segment = TcpSegment::from_bytes(packet.payload(), None)?;
216
217 if segment.destination_port() != self.local_port {
218 return Err(RecvError::InvalidPort);
219 }
220
221 let tuple = ConnectionTuple::new(packet.source_address(), segment.source_port());
222
223 let outcome = if let Some(endpoint) = self.connections.get_mut(&tuple) {
224 endpoint.receive_segment(&segment, callback);
225 if endpoint.is_done() {
226 RecvSegmentOutcome::EndpointDone
227 } else {
228 RecvSegmentOutcome::EndpointRunning(endpoint.next_segment_status())
229 }
230 } else if segment.flags_after_ns() == TcpFlags::SYN {
231 RecvSegmentOutcome::NewConnection
232 } else {
233 RecvSegmentOutcome::UnexpectedSegment(
235 !segment.flags_after_ns().intersects(TcpFlags::RST),
236 )
237 };
238
239 match outcome {
240 RecvSegmentOutcome::EndpointDone => {
241 self.remove_connection(tuple);
242 Ok(RecvEvent::EndpointDone)
243 }
244 RecvSegmentOutcome::EndpointRunning(status) => {
245 if !self.check_next_segment_status(tuple, status) {
246 self.active_connections.remove(&tuple);
249 }
250 Ok(RecvEvent::Nothing)
251 }
252 RecvSegmentOutcome::NewConnection => {
253 let endpoint = match Endpoint::new_with_defaults(&segment) {
254 Ok(endpoint) => endpoint,
255 Err(_) => return Ok(RecvEvent::FailedNewConnection),
256 };
257
258 if self.connections.len() >= self.max_connections.get() {
259 if let Some(evict_tuple) = self.find_evictable_connection() {
260 let rst_config = self.connections[&evict_tuple]
261 .connection()
262 .make_rst_config();
263 self.enqueue_rst_config(evict_tuple, rst_config);
264 self.remove_connection(evict_tuple);
265 self.add_connection(tuple, endpoint);
266 Ok(RecvEvent::NewConnectionReplacing)
267 } else {
268 self.enqueue_rst(tuple, &segment);
271 Ok(RecvEvent::NewConnectionDropped)
272 }
273 } else {
274 self.add_connection(tuple, endpoint);
275 Ok(RecvEvent::NewConnectionSuccessful)
276 }
277 }
278 RecvSegmentOutcome::UnexpectedSegment(enqueue_rst) => {
279 if enqueue_rst {
280 self.enqueue_rst(tuple, &segment);
281 }
282 Ok(RecvEvent::UnexpectedSegment)
283 }
284 }
285 }
286
287 fn check_timeout(&mut self, value: u64, tuple: ConnectionTuple) {
288 match self.next_timeout {
289 Some((t, _)) if t > value => self.next_timeout = Some((value, tuple)),
290 None => self.next_timeout = Some((value, tuple)),
291 _ => (),
292 };
293 }
294
295 fn find_next_timeout(&mut self) {
296 let mut next_timeout = None;
297 for (tuple, endpoint) in self.connections.iter() {
298 if let NextSegmentStatus::Timeout(value) = endpoint.next_segment_status() {
299 if let Some((t, _)) = next_timeout {
300 if t > value {
301 next_timeout = Some((value, *tuple));
302 }
303 } else {
304 next_timeout = Some((value, *tuple));
305 }
306 }
307 }
308 self.next_timeout = next_timeout;
309 }
310
311 fn check_next_segment_status(
314 &mut self,
315 tuple: ConnectionTuple,
316 status: NextSegmentStatus,
317 ) -> bool {
318 if let Some((_, timeout_tuple)) = self.next_timeout
319 && tuple == timeout_tuple
320 {
321 self.find_next_timeout();
322 }
323 match status {
324 NextSegmentStatus::Available => {
325 self.active_connections.insert(tuple);
326 return true;
327 }
328 NextSegmentStatus::Timeout(value) => self.check_timeout(value, tuple),
329 NextSegmentStatus::Nothing => (),
330 };
331
332 false
333 }
334
335 fn add_connection(&mut self, tuple: ConnectionTuple, endpoint: Endpoint) {
336 self.check_next_segment_status(tuple, endpoint.next_segment_status());
337 self.connections.insert(tuple, endpoint);
338 }
339
340 fn remove_connection(&mut self, tuple: ConnectionTuple) {
341 self.active_connections.remove(&tuple);
343 self.connections.remove(&tuple);
344
345 if let Some((_, timeout_tuple)) = self.next_timeout
346 && timeout_tuple == tuple
347 {
348 self.find_next_timeout();
349 }
350 }
351
352 fn find_evictable_connection(&self) -> Option<ConnectionTuple> {
354 for (tuple, endpoint) in self.connections.iter() {
355 if endpoint.is_evictable() {
356 return Some(*tuple);
357 }
358 }
359 None
360 }
361
362 fn enqueue_rst_config(&mut self, tuple: ConnectionTuple, cfg: RstConfig) {
363 if self.rst_queue.len() < self.max_pending_resets.get() {
365 self.rst_queue.push((tuple, cfg));
366 }
367 }
368
369 fn enqueue_rst<T: NetworkBytes + Debug>(&mut self, tuple: ConnectionTuple, s: &TcpSegment<T>) {
370 self.enqueue_rst_config(tuple, RstConfig::new(s));
371 }
372
373 pub fn write_next_packet(
381 &mut self,
382 buf: &mut [u8],
383 ) -> Result<(Option<NonZeroUsize>, WriteEvent), WriteNextError> {
384 let mut len = None;
385 let mut writer_status = None;
386 let mut event = WriteEvent::Nothing;
387
388 let mut packet =
390 IPv4Packet::write_header(buf, PROTOCOL_TCP, Ipv4Addr::LOCALHOST, Ipv4Addr::LOCALHOST)?;
391
392 let mss_reserved = 0;
395
396 if let Some((tuple, rst_cfg)) = self.rst_queue.pop() {
400 let (seq, ack, flags_after_ns) = rst_cfg.seq_ack_tcp_flags();
401 let segment_len = TcpSegment::write_incomplete_segment::<[u8]>(
402 packet.inner_mut().payload_mut(),
403 seq,
404 ack,
405 flags_after_ns,
406 10000,
407 None,
408 0,
409 None,
410 )?
411 .finalize(
412 self.local_port,
413 tuple.remote_port,
414 Some((self.local_ipv4_addr, tuple.remote_addr)),
415 )
416 .len();
417
418 packet
419 .inner_mut()
420 .set_source_address(self.local_ipv4_addr)
421 .set_destination_address(tuple.remote_addr);
422
423 let packet_len = packet.with_payload_len_unchecked(segment_len, true).len();
424 return Ok((
426 Some(NonZeroUsize::new(packet_len).unwrap()),
427 WriteEvent::Nothing,
428 ));
429 }
430
431 for tuple in self
432 .active_connections
433 .iter()
434 .chain(self.next_timeout.as_ref().map(|(_, x)| x))
435 {
436 let endpoint = self.connections.get_mut(tuple).unwrap();
439 let segment_len = {
442 let maybe_segment =
443 endpoint.write_next_segment(packet.inner_mut().payload_mut(), mss_reserved);
444
445 match maybe_segment {
446 Some(segment) => segment
447 .finalize(
448 self.local_port,
449 tuple.remote_port,
450 Some((self.local_ipv4_addr, tuple.remote_addr)),
451 )
452 .len(),
453 None => continue,
454 }
455 };
456
457 packet
458 .inner_mut()
459 .set_source_address(self.local_ipv4_addr)
460 .set_destination_address(tuple.remote_addr);
461
462 let ip_len = packet.with_payload_len_unchecked(segment_len, true).len();
463
464 len = Some(NonZeroUsize::new(ip_len).unwrap());
466 writer_status = Some((*tuple, endpoint.is_done()));
467
468 break;
469 }
470
471 if let Some((tuple, is_done)) = writer_status {
472 if is_done {
473 self.remove_connection(tuple);
474 event = WriteEvent::EndpointDone;
475 } else {
476 let status = self.connections[&tuple].next_segment_status();
479 if !self.check_next_segment_status(tuple, status) {
480 self.active_connections.remove(&tuple);
481 }
482 }
483 }
484
485 Ok((len, event))
486 }
487
488 #[inline]
490 pub fn next_segment_status(&self) -> NextSegmentStatus {
491 if !self.active_connections.is_empty() || !self.rst_queue.is_empty() {
492 return NextSegmentStatus::Available;
493 }
494
495 if let Some((value, _)) = self.next_timeout {
496 return NextSegmentStatus::Timeout(value);
497 }
498
499 NextSegmentStatus::Nothing
500 }
501}
502
503#[cfg(test)]
504mod tests {
505 use std::fmt::Debug;
506
507 use super::*;
508 use crate::dumbo::pdu::bytes::NetworkBytesMut;
509 use crate::dumbo::tcp::tests::mock_callback;
510
/// Returns a mutable view over the TCP segment stored in the payload of `p`, so tests can tweak
/// individual segment fields in place.
///
/// Panics if the payload cannot be parsed as a TCP segment.
fn inner_tcp_mut<'a, T: NetworkBytesMut + Debug>(
    p: &'a mut IPv4Packet<'_, T>,
) -> TcpSegment<'a, &'a mut [u8]> {
    TcpSegment::from_bytes(p.payload_mut(), None).unwrap()
}
516
517 #[allow(clippy::type_complexity)]
518 fn write_next<'a>(
519 h: &mut TcpIPv4Handler,
520 buf: &'a mut [u8],
521 ) -> Result<(Option<IPv4Packet<'a, &'a mut [u8]>>, WriteEvent), WriteNextError> {
522 h.write_next_packet(buf).map(|(o, err)| {
523 (
524 o.map(move |len| {
525 let len = len.get();
526 IPv4Packet::from_bytes(buf.split_at_mut(len).0, true).unwrap()
527 }),
528 err,
529 )
530 })
531 }
532
533 fn next_written_segment<'a>(
534 h: &mut TcpIPv4Handler,
535 buf: &'a mut [u8],
536 expected_event: WriteEvent,
537 ) -> TcpSegment<'a, &'a mut [u8]> {
538 let (segment_start, segment_end) = {
539 let (o, event) = write_next(h, buf).unwrap();
540 assert_eq!(event, expected_event);
541 let p = o.unwrap();
542 (p.header_len(), p.len())
543 };
544
545 TcpSegment::from_bytes(&mut buf[segment_start.into()..segment_end], None).unwrap()
546 }
547
548 fn drain_packets(
552 h: &mut TcpIPv4Handler,
553 src_addr: Ipv4Addr,
554 remote_addr: Ipv4Addr,
555 ) -> Result<usize, WriteNextError> {
556 let mut buf = [0u8; 2000];
557 let mut count: usize = 0;
558 loop {
559 let (o, _) = write_next(h, buf.as_mut())?;
560 if let Some(packet) = o {
561 count += 1;
562 assert_eq!(packet.source_address(), src_addr);
563 assert_eq!(packet.destination_address(), remote_addr);
564 } else {
565 break;
566 }
567 }
568 Ok(count)
569 }
570
/// End-to-end scenario for the handler: port filtering, `RST` generation and queue limits,
/// connection setup and teardown, timeout bookkeeping, the connection limit, and eviction.
#[test]
#[allow(clippy::cognitive_complexity)]
fn test_handler() {
    // `buf` backs the incoming packet; `buf2` receives packets written by the handler.
    let mut buf = [0u8; 100];
    let mut buf2 = [0u8; 2000];

    let wrong_local_addr = Ipv4Addr::new(123, 123, 123, 123);
    let local_addr = Ipv4Addr::new(169, 254, 169, 254);
    let local_port = 80;
    let remote_addr = Ipv4Addr::new(10, 0, 0, 1);
    let remote_port = 1012;
    let max_connections = 2;
    let max_pending_resets = 2;

    let mut h = TcpIPv4Handler::new(
        local_addr,
        local_port,
        NonZeroUsize::new(max_connections).unwrap(),
        NonZeroUsize::new(max_pending_resets).unwrap(),
    );

    // The incoming packet starts out with a destination address other than `local_addr`; it is
    // corrected further below.
    let mut p =
        IPv4Packet::write_header(buf.as_mut(), PROTOCOL_TCP, remote_addr, wrong_local_addr)
            .unwrap();

    let seq_number = 123;

    let s_len = {
        // The segment carries no flags and targets `local_port + 1`, i.e. the wrong port.
        let s = TcpSegment::write_segment::<[u8]>(
            p.inner_mut().payload_mut(),
            remote_port,
            local_port + 1,
            seq_number,
            456,
            TcpFlags::empty(),
            10000,
            None,
            100,
            None,
            None,
        )
        .unwrap();
        s.len()
    };

    // A fresh handler has nothing to send.
    assert_eq!(h.next_segment_status(), NextSegmentStatus::Nothing);
    assert_eq!(drain_packets(&mut h, local_addr, remote_addr), Ok(0));

    let mut p = p.with_payload_len_unchecked(s_len, false);

    // The destination port is still wrong, so the packet is rejected.
    p.set_destination_address(local_addr);
    assert_eq!(
        h.receive_packet(&p, mock_callback).unwrap_err(),
        RecvError::InvalidPort
    );

    // With the right port, the flag-less segment from an unknown peer is unexpected and causes
    // a RST to be enqueued.
    assert_eq!(h.rst_queue.len(), 0);
    inner_tcp_mut(&mut p).set_destination_port(local_port);
    assert_eq!(
        h.receive_packet(&p, mock_callback),
        Ok(RecvEvent::UnexpectedSegment)
    );
    assert_eq!(h.rst_queue.len(), 1);
    assert_eq!(h.next_segment_status(), NextSegmentStatus::Available);
    {
        let s = next_written_segment(&mut h, buf2.as_mut(), WriteEvent::Nothing);
        assert!(s.flags_after_ns().intersects(TcpFlags::RST));
        assert_eq!(s.source_port(), local_port);
        assert_eq!(s.destination_port(), remote_port);
    }

    assert_eq!(h.rst_queue.len(), 0);
    assert_eq!(h.next_segment_status(), NextSegmentStatus::Nothing);

    // The RST queue holds at most `max_pending_resets` (2) entries; the third request is
    // dropped.
    assert_eq!(
        h.receive_packet(&p, mock_callback),
        Ok(RecvEvent::UnexpectedSegment)
    );
    assert_eq!(h.rst_queue.len(), 1);
    assert_eq!(
        h.receive_packet(&p, mock_callback),
        Ok(RecvEvent::UnexpectedSegment)
    );
    assert_eq!(h.rst_queue.len(), 2);
    assert_eq!(
        h.receive_packet(&p, mock_callback),
        Ok(RecvEvent::UnexpectedSegment)
    );
    assert_eq!(h.rst_queue.len(), 2);

    // Both queued RSTs get written.
    assert_eq!(h.next_segment_status(), NextSegmentStatus::Available);
    assert_eq!(drain_packets(&mut h, local_addr, remote_addr), Ok(2));
    assert_eq!(h.next_segment_status(), NextSegmentStatus::Nothing);

    // A SYN from an unknown peer creates a connection, which is immediately active.
    assert_eq!(h.connections.len(), 0);
    inner_tcp_mut(&mut p).set_flags_after_ns(TcpFlags::SYN);
    assert_eq!(
        h.receive_packet(&p, mock_callback),
        Ok(RecvEvent::NewConnectionSuccessful)
    );
    assert_eq!(h.connections.len(), 1);
    assert_eq!(h.active_connections.len(), 1);

    // A RST with the following sequence number finishes the endpoint, which gets removed.
    inner_tcp_mut(&mut p)
        .set_flags_after_ns(TcpFlags::RST)
        .set_sequence_number(seq_number.wrapping_add(1));
    assert_eq!(
        h.receive_packet(&p, mock_callback),
        Ok(RecvEvent::EndpointDone)
    );
    assert_eq!(h.connections.len(), 0);
    assert_eq!(h.active_connections.len(), 0);

    // Open the connection again with the original SYN.
    inner_tcp_mut(&mut p)
        .set_flags_after_ns(TcpFlags::SYN)
        .set_sequence_number(seq_number);
    assert_eq!(
        h.receive_packet(&p, mock_callback),
        Ok(RecvEvent::NewConnectionSuccessful)
    );
    assert_eq!(h.connections.len(), 1);
    assert_eq!(h.active_connections.len(), 1);

    // The new connection has exactly one packet to send (presumably the SYN-ACK).
    assert_eq!(h.next_segment_status(), NextSegmentStatus::Available);
    assert_eq!(drain_packets(&mut h, local_addr, remote_addr), Ok(1));

    let remote_tuple = ConnectionTuple::new(remote_addr, remote_port);
    let remote_tuple2 = ConnectionTuple::new(remote_addr, remote_port + 1);

    // After writing, the connection is no longer active, but it owns the next timeout.
    assert_eq!(h.active_connections.len(), 0);
    let old_timeout_value = if let Some((t, tuple)) = h.next_timeout {
        assert_eq!(tuple, remote_tuple);
        t
    } else {
        panic!("missing first expected timeout");
    };

    // Receiving the same SYN again keeps the connection and makes it write one more packet.
    assert_eq!(h.receive_packet(&p, mock_callback), Ok(RecvEvent::Nothing));
    assert_eq!(h.connections.len(), 1);
    assert_eq!(drain_packets(&mut h, local_addr, remote_addr), Ok(1));

    // The timeout still belongs to the same connection, but its value has increased.
    assert_eq!(h.active_connections.len(), 0);
    if let Some((t, tuple)) = h.next_timeout {
        assert_eq!(tuple, remote_tuple);
        assert!(t > old_timeout_value);
    } else {
        panic!("missing second expected timeout");
    };

    {
        // Acknowledge everything the endpoint has sent so far.
        let seq = h.connections[&remote_tuple].connection().first_not_sent().0;
        inner_tcp_mut(&mut p)
            .set_flags_after_ns(TcpFlags::ACK)
            .set_ack_number(seq);
        assert_eq!(h.receive_packet(&p, mock_callback), Ok(RecvEvent::Nothing));
    }

    // After the ACK there is neither an active connection nor a pending timeout.
    assert_eq!(h.active_connections.len(), 0);
    assert_eq!(h.next_timeout, None);

    inner_tcp_mut(&mut p).set_flags_after_ns(TcpFlags::SYN);

    // A SYN from a different remote port opens a second connection.
    inner_tcp_mut(&mut p).set_source_port(remote_port + 1);
    assert_eq!(
        h.receive_packet(&p, mock_callback),
        Ok(RecvEvent::NewConnectionSuccessful)
    );
    assert_eq!(h.connections.len(), 2);
    assert_eq!(h.active_connections.len(), 1);
    assert_eq!(drain_packets(&mut h, local_addr, remote_addr), Ok(1));

    // The pending timeout must not belong to the first connection.
    assert_eq!(h.active_connections.len(), 0);
    if let Some((_, tuple)) = h.next_timeout {
        assert_ne!(tuple, ConnectionTuple::new(remote_addr, remote_port));
    } else {
        panic!("missing third expected timeout");
    }

    {
        // The handler is at `max_connections` (2) and nothing is evictable, so a third
        // connection attempt is dropped and answered with a RST.
        let port = remote_port + 2;
        inner_tcp_mut(&mut p).set_source_port(port);
        assert_eq!(
            h.receive_packet(&p, mock_callback),
            Ok(RecvEvent::NewConnectionDropped)
        );
        assert_eq!(h.connections.len(), 2);

        assert_eq!(h.rst_queue.len(), 1);
        let s = next_written_segment(&mut h, buf2.as_mut(), WriteEvent::Nothing);
        assert!(s.flags_after_ns().intersects(TcpFlags::RST));
        assert_eq!(s.destination_port(), port);
    }

    // Lower the eviction threshold of the second connection to 0 (presumably making it
    // evictable right away; the following assertions rely on an eviction taking place).
    h.connections
        .get_mut(&remote_tuple2)
        .unwrap()
        .set_eviction_threshold(0);

    // The same connection attempt now replaces an existing connection.
    assert_eq!(
        h.receive_packet(&p, mock_callback),
        Ok(RecvEvent::NewConnectionReplacing)
    );
    assert_eq!(h.connections.len(), 2);
    assert_eq!(h.active_connections.len(), 1);

    // Two packets are pending: the RST for the evicted connection and one packet from the new
    // connection.
    assert_eq!(h.rst_queue.len(), 1);
    assert_eq!(drain_packets(&mut h, local_addr, remote_addr), Ok(2));
    assert_eq!(h.rst_queue.len(), 0);
    assert_eq!(h.active_connections.len(), 0);

    // Send the SYN from the first remote port again: the segment is accepted, one more packet
    // is written, and only one connection is left afterwards.
    inner_tcp_mut(&mut p).set_source_port(remote_port);
    assert_eq!(h.receive_packet(&p, mock_callback), Ok(RecvEvent::Nothing));
    assert_eq!(h.active_connections.len(), 1);
    assert_eq!(drain_packets(&mut h, local_addr, remote_addr), Ok(1));
    assert_eq!(h.connections.len(), 1);
    assert_eq!(h.active_connections.len(), 0);
}
825}