1use std::fmt::Debug;
13use std::net::Ipv4Addr;
14
15use super::bytes::{InnerBytes, NetworkBytes};
16use crate::dumbo::pdu::bytes::NetworkBytesMut;
17use crate::dumbo::pdu::{ChecksumProto, Incomplete};
18
/// Byte offset of the 16-bit source port field.
const SOURCE_PORT_OFFSET: usize = 0;
/// Byte offset of the 16-bit destination port field.
const DESTINATION_PORT_OFFSET: usize = SOURCE_PORT_OFFSET + 2;
/// Byte offset of the 16-bit length field.
const LENGTH_OFFSET: usize = DESTINATION_PORT_OFFSET + 2;
/// Byte offset of the 16-bit checksum field.
const CHECKSUM_OFFSET: usize = LENGTH_OFFSET + 2;
/// Byte offset where the payload starts, directly after the four 16-bit header fields.
const PAYLOAD_OFFSET: usize = CHECKSUM_OFFSET + 2;
/// Size in bytes of the fixed UDP header (four 16-bit fields).
pub const UDP_HEADER_SIZE: usize = PAYLOAD_OFFSET;

/// Upper bound (65507) used when writing datagrams: the 16-bit length limit minus a minimal
/// 20-byte IPv4 header and the 8-byte UDP header.
// NOTE(review): 65507 is the classic maximum UDP *payload* over IPv4, yet the writer compares
// the whole datagram length (header + payload) against it — confirm which bound is intended.
const IPV4_MAX_UDP_PACKET_SIZE: u16 = u16::MAX - 20 - 8;
30
31#[derive(Debug, PartialEq, Eq, thiserror::Error, displaydoc::Display)]
33pub enum UdpError {
34 Checksum,
36 DatagramTooShort,
38 PayloadTooBig,
40}
41
/// A UDP datagram view over a byte buffer `T`.
///
/// Header fields are read and written in place at fixed offsets inside `bytes`; the
/// payload is everything after the fixed header.
#[derive(Debug)]
pub struct UdpDatagram<'a, T: 'a> {
    // The wrapped buffer; all accessors operate directly on these bytes.
    bytes: InnerBytes<'a, T>,
}
47
48#[allow(clippy::len_without_is_empty)]
49impl<T: NetworkBytes + Debug> UdpDatagram<'_, T> {
50 #[inline]
57 pub fn from_bytes_unchecked(bytes: T) -> Self {
58 UdpDatagram {
59 bytes: InnerBytes::new(bytes),
60 }
61 }
62
63 #[inline]
66 pub fn from_bytes(
67 bytes: T,
68 verify_checksum: Option<(Ipv4Addr, Ipv4Addr)>,
69 ) -> Result<Self, UdpError> {
70 if bytes.len() < UDP_HEADER_SIZE {
71 return Err(UdpError::DatagramTooShort);
72 }
73
74 let datagram = UdpDatagram::from_bytes_unchecked(bytes);
75 if let Some((src_addr, dst_addr)) = verify_checksum {
76 if datagram.checksum() != 0 && datagram.compute_checksum(src_addr, dst_addr) != 0xffff {
80 return Err(UdpError::Checksum);
81 }
82 }
83
84 Ok(datagram)
85 }
86
87 #[inline]
89 pub fn source_port(&self) -> u16 {
90 self.bytes.ntohs_unchecked(SOURCE_PORT_OFFSET)
91 }
92
93 #[inline]
95 pub fn destination_port(&self) -> u16 {
96 self.bytes.ntohs_unchecked(DESTINATION_PORT_OFFSET)
97 }
98
99 #[inline]
101 pub fn len(&self) -> u16 {
102 self.bytes.ntohs_unchecked(LENGTH_OFFSET)
103 }
104
105 #[inline]
107 pub fn checksum(&self) -> u16 {
108 self.bytes.ntohs_unchecked(CHECKSUM_OFFSET)
109 }
110
111 #[inline]
113 pub fn payload(&self) -> &[u8] {
114 self.bytes.split_at(PAYLOAD_OFFSET).1
116 }
117
118 #[inline]
120 pub fn compute_checksum(&self, src_addr: Ipv4Addr, dst_addr: Ipv4Addr) -> u16 {
121 crate::dumbo::pdu::compute_checksum(&self.bytes, src_addr, dst_addr, ChecksumProto::Udp)
122 }
123}
124
125impl<T: NetworkBytesMut + Debug> UdpDatagram<'_, T> {
126 #[inline]
134 pub fn write_incomplete_datagram(buf: T, payload: &[u8]) -> Result<Incomplete<Self>, UdpError> {
135 let mut packet = UdpDatagram::from_bytes(buf, None)?;
136 let len = payload.len() + UDP_HEADER_SIZE;
137
138 let len = match u16::try_from(len) {
139 Ok(len) if len <= IPV4_MAX_UDP_PACKET_SIZE => len,
140 _ => return Err(UdpError::PayloadTooBig),
141 };
142
143 packet.bytes.shrink_unchecked(len.into());
144 packet.payload_mut().copy_from_slice(payload);
145 packet.set_len(len);
146
147 Ok(Incomplete::new(packet))
148 }
149
150 #[inline]
152 pub fn set_source_port(&mut self, src_port: u16) -> &mut Self {
153 self.bytes.htons_unchecked(SOURCE_PORT_OFFSET, src_port);
154 self
155 }
156
157 #[inline]
159 pub fn set_destination_port(&mut self, dst_port: u16) -> &mut Self {
160 self.bytes
161 .htons_unchecked(DESTINATION_PORT_OFFSET, dst_port);
162 self
163 }
164
165 #[inline]
167 pub fn payload_mut(&mut self) -> &mut [u8] {
168 &mut self.bytes[PAYLOAD_OFFSET..]
169 }
170
171 #[inline]
173 pub fn set_len(&mut self, len: u16) -> &mut Self {
174 self.bytes.htons_unchecked(LENGTH_OFFSET, len);
175 self
176 }
177
178 #[inline]
180 pub fn set_checksum(&mut self, checksum: u16) -> &mut Self {
181 self.bytes.htons_unchecked(CHECKSUM_OFFSET, checksum);
182 self
183 }
184}
185
186impl<'a, T: NetworkBytesMut + Debug> Incomplete<UdpDatagram<'a, T>> {
187 #[inline]
190 pub fn finalize(
191 mut self,
192 src_port: u16,
193 dst_port: u16,
194 compute_checksum: Option<(Ipv4Addr, Ipv4Addr)>,
195 ) -> UdpDatagram<'a, T> {
196 self.inner.set_source_port(src_port);
197 self.inner.set_destination_port(dst_port);
198 self.inner.set_checksum(0);
199
200 if let Some((src_addr, dst_addr)) = compute_checksum {
201 let checksum = self.inner.compute_checksum(src_addr, dst_addr);
202 self.inner.set_checksum(checksum);
203 }
204
205 self.inner
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dumbo::pdu::udp::UdpDatagram;

    #[test]
    // Comparing the header length field against zero is intentional; the allow keeps clippy
    // from suggesting `is_empty()` for it.
    #[allow(clippy::len_zero)]
    fn test_set_get() {
        let mut raw = [0u8; 30];
        let total_len = raw.len();
        let mut p: UdpDatagram<&mut [u8]> = UdpDatagram::from_bytes_unchecked(raw.as_mut());

        // Every header field starts at zero in a zeroed buffer and must round-trip through
        // its setter.
        let src_port: u16 = 213;
        assert_eq!(p.source_port(), 0);
        p.set_source_port(src_port);
        assert_eq!(p.source_port(), src_port);

        let dst_port: u16 = 64193;
        assert_eq!(p.destination_port(), 0);
        p.set_destination_port(dst_port);
        assert_eq!(p.destination_port(), dst_port);

        let len: u16 = 12;
        assert_eq!(p.len(), 0);
        p.set_len(len);
        assert_eq!(p.len(), len);

        let checksum: u16 = 32;
        assert_eq!(p.checksum(), 0);
        p.set_checksum(checksum);
        assert_eq!(p.checksum(), checksum);

        // The payload is whatever follows the fixed header.
        let payload_length = total_len - UDP_HEADER_SIZE;
        assert_eq!(p.payload().len(), payload_length);

        let payload: Vec<u8> = (0..u8::try_from(payload_length).unwrap()).collect();
        p.payload_mut().copy_from_slice(&payload);
        assert_eq!(p.payload(), payload.as_slice());
    }

    #[test]
    fn test_failing_construction() {
        // A payload of IPV4_MAX_UDP_PACKET_SIZE bytes plus the header is over the limit.
        let mut raw = [0u8; 8];
        let huge_payload = [0u8; IPV4_MAX_UDP_PACKET_SIZE as usize];
        let err = UdpDatagram::write_incomplete_datagram(raw.as_mut(), &huge_payload).unwrap_err();
        assert_eq!(err, UdpError::PayloadTooBig);

        // Anything shorter than the fixed header is rejected outright.
        let mut short_header = [0u8; UDP_HEADER_SIZE - 1];
        let err = UdpDatagram::from_bytes(short_header.as_mut(), None).unwrap_err();
        assert_eq!(err, UdpError::DatagramTooShort);
    }

    #[test]
    fn test_construction() {
        let mut packet = [0u8; 32 + UDP_HEADER_SIZE];
        let payload: Vec<u8> = (0..32).collect();
        let src_port = 32133;
        let dst_port = 22113;
        let src_addr = Ipv4Addr::new(10, 100, 11, 21);
        let dst_addr = Ipv4Addr::new(192, 168, 121, 35);

        let incomplete =
            UdpDatagram::write_incomplete_datagram(packet.as_mut(), &payload[..]).unwrap();
        let mut p = incomplete.finalize(src_port, dst_port, Some((src_addr, dst_addr)));

        // Running the checksum over a correctly checksummed datagram yields all ones.
        let stored_checksum = p.checksum();
        assert_eq!(p.compute_checksum(src_addr, dst_addr), 0xffff);

        // With the field zeroed again, recomputing reproduces the stored checksum.
        p.set_checksum(0);
        assert_eq!(p.compute_checksum(src_addr, dst_addr), stored_checksum);

        // Corrupt the checksum of an arbitrary buffer; parsing with verification must fail.
        let mut a = [1u8; 128];
        {
            let mut corrupted = UdpDatagram::from_bytes_unchecked(a.as_mut());
            let checksum = corrupted.checksum();
            corrupted.set_checksum(checksum.wrapping_add(1));
        }
        let p_err = UdpDatagram::from_bytes(a.as_mut(), Some((src_addr, dst_addr))).unwrap_err();
        assert_eq!(p_err, UdpError::Checksum);
    }

    #[test]
    fn test_checksum() {
        // Known-good checksum for this two-byte payload, ports and address pair.
        let mut bytes = [0u8; 2 + UDP_HEADER_SIZE];
        let expected_checksum: u16 = 0x14de;
        let src_ip = Ipv4Addr::new(152, 1, 51, 27);
        let dst_ip = Ipv4Addr::new(152, 14, 94, 75);

        let p = UdpDatagram::write_incomplete_datagram(bytes.as_mut(), b"bb")
            .unwrap()
            .finalize(41103, 9876, Some((src_ip, dst_ip)));
        assert_eq!(p.checksum(), expected_checksum);
    }
}