vmm/mmds/
mod.rs

1// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4/// MMDS data store
5pub mod data_store;
6/// MMDS network stack
7pub mod ns;
8/// Defines the structures needed for saving/restoring MmdsNetworkStack.
9pub mod persist;
10mod token;
11/// MMDS token headers
12pub mod token_headers;
13
14use std::sync::{Arc, Mutex};
15
16use micro_http::{
17    Body, HttpHeaderError, MediaType, Method, Request, RequestError, Response, StatusCode, Version,
18};
19use serde_json::{Map, Value};
20
21use crate::logger::{IncMetric, METRICS};
22use crate::mmds::data_store::{Mmds, MmdsDatastoreError as MmdsError, MmdsVersion, OutputFormat};
23use crate::mmds::token::PATH_TO_TOKEN;
24use crate::mmds::token_headers::{
25    X_AWS_EC2_METADATA_TOKEN_HEADER, X_AWS_EC2_METADATA_TOKEN_SSL_SECONDS_HEADER,
26    X_FORWARDED_FOR_HEADER, X_METADATA_TOKEN_HEADER, X_METADATA_TOKEN_TTL_SECONDS_HEADER,
27    get_header_value_pair,
28};
29
// Errors produced while turning a guest HTTP request into an MMDS response.
//
// NOTE: `displaydoc::Display` uses each variant's `///` doc comment as its `Display`
// text (e.g. `{0}` in `ResourceNotFound` interpolates the path), and this module sends
// that text back as plain-text response bodies via `.to_string()`. Editing a variant's
// `///` line therefore changes what guests receive; put developer notes in `//` comments.
// Despite the type-level summary below, the enum is not limited to token errors: it also
// covers URI, method and lookup failures.
#[rustfmt::skip]
#[derive(Debug, thiserror::Error, displaydoc::Display)]
/// MMDS token errors
pub enum VmmMmdsError {
    // Returned (401) by MMDSv2 GETs whose token fails `Mmds::is_valid_token`.
    /// MMDS token not valid.
    InvalidToken,
    // Returned (400) when the request's absolute path is empty.
    /// Invalid URI.
    InvalidURI,
    // Returned (405) for any method other than GET or PUT.
    /// Not allowed HTTP method.
    MethodNotAllowed,
    // Returned (401) by MMDSv2 GETs that carry neither token header.
    /// No MMDS token provided. Use `X-metadata-token` or `X-aws-ec2-metadata-token` header to specify the session token.
    NoTokenProvided,
    // Returned (400) by token PUTs that carry neither TTL header.
    /// Token time to live value not found. Use `X-metadata-token-ttl-seconds` or `X-aws-ec2-metadata-token-ttl-seconds` header to specify the token's lifetime.
    NoTtlProvided,
    // Returned (404) with the requested path when it does not resolve.
    /// Resource not found: {0}.
    ResourceNotFound(String),
}
47
48impl From<MediaType> for OutputFormat {
49    fn from(media_type: MediaType) -> Self {
50        match media_type {
51            MediaType::ApplicationJson => OutputFormat::Json,
52            MediaType::PlainText => OutputFormat::Imds,
53        }
54    }
55}
56
57// Builds the `micro_http::Response` with a given HTTP version, status code, and body.
58fn build_response(
59    http_version: Version,
60    status_code: StatusCode,
61    content_type: MediaType,
62    body: Body,
63) -> Response {
64    let mut response = Response::new(http_version, status_code);
65    response.set_content_type(content_type);
66    response.set_body(body);
67    response
68}
69
70/// Patch provided JSON document (given as `serde_json::Value`) in-place with JSON Merge Patch
71/// [RFC 7396](https://tools.ietf.org/html/rfc7396).
72pub fn json_patch(target: &mut Value, patch: &Value) {
73    if patch.is_object() {
74        if !target.is_object() {
75            // Replace target with a serde_json object so we can recursively copy patch values.
76            *target = Value::Object(Map::new());
77        }
78
79        // This is safe since we make sure patch and target are objects beforehand.
80        let doc = target.as_object_mut().unwrap();
81        for (key, value) in patch.as_object().unwrap() {
82            if value.is_null() {
83                // If the value in the patch is null we remove the entry.
84                doc.remove(key.as_str());
85            } else {
86                // Recursive call to update target document.
87                // If `key` is not in the target document (it's a new field defined in `patch`)
88                // insert a null placeholder and pass it as the new target
89                // so we can insert new values recursively.
90                json_patch(doc.entry(key.as_str()).or_insert(Value::Null), value);
91            }
92        }
93    } else {
94        *target = patch.clone();
95    }
96}
97
// Make the URI a correct JSON pointer value.
//
// Collapses every run of consecutive `/` into a single `/` (e.g. `//a///b` -> `/a/b`) and
// leaves all other characters untouched. The result is what the data store expects as a
// strict JSON path.
//
// Runs in a single linear pass. The previous approach repeated `replace("//", "/")` until
// the length stopped shrinking, which reallocated the whole string on each pass and relied
// on a `u32::MAX` length sentinel.
fn sanitize_uri(uri: String) -> String {
    // Fast path: nothing to collapse, so hand the input back without reallocating.
    if !uri.contains("//") {
        return uri;
    }

    let mut sanitized = String::with_capacity(uri.len());
    let mut prev_is_slash = false;
    for ch in uri.chars() {
        let is_slash = ch == '/';
        // Drop a slash that directly follows another slash.
        if !(is_slash && prev_is_slash) {
            sanitized.push(ch);
        }
        prev_is_slash = is_slash;
    }
    sanitized
}
110
111/// Build a response for `request` and return response based on MMDS version
112pub fn convert_to_response(mmds: Arc<Mutex<Mmds>>, request: Request) -> Response {
113    // Check URI is not empty
114    let uri = request.uri().get_abs_path();
115    if uri.is_empty() {
116        return build_response(
117            request.http_version(),
118            StatusCode::BadRequest,
119            MediaType::PlainText,
120            Body::new(VmmMmdsError::InvalidURI.to_string()),
121        );
122    }
123
124    let mut mmds_guard = mmds.lock().expect("Poisoned lock");
125
126    // Allow only GET and PUT requests
127    match request.method() {
128        Method::Get => match mmds_guard.version() {
129            MmdsVersion::V1 => respond_to_get_request_v1(&mmds_guard, request),
130            MmdsVersion::V2 => respond_to_get_request_v2(&mmds_guard, request),
131        },
132        Method::Put => respond_to_put_request(&mut mmds_guard, request),
133        _ => {
134            let mut response = build_response(
135                request.http_version(),
136                StatusCode::MethodNotAllowed,
137                MediaType::PlainText,
138                Body::new(VmmMmdsError::MethodNotAllowed.to_string()),
139            );
140            response.allow_method(Method::Get);
141            response.allow_method(Method::Put);
142            response
143        }
144    }
145}
146
147fn respond_to_get_request_v1(mmds: &Mmds, request: Request) -> Response {
148    match get_header_value_pair(
149        request.headers.custom_entries(),
150        &[X_METADATA_TOKEN_HEADER, X_AWS_EC2_METADATA_TOKEN_HEADER],
151    ) {
152        Some((_, token)) => {
153            if !mmds.is_valid_token(token) {
154                METRICS.mmds.rx_invalid_token.inc();
155            }
156        }
157        None => {
158            METRICS.mmds.rx_no_token.inc();
159        }
160    }
161
162    respond_to_get_request(mmds, request)
163}
164
165fn respond_to_get_request_v2(mmds: &Mmds, request: Request) -> Response {
166    // Check whether a token exists.
167    let token = match get_header_value_pair(
168        request.headers.custom_entries(),
169        &[X_METADATA_TOKEN_HEADER, X_AWS_EC2_METADATA_TOKEN_HEADER],
170    ) {
171        Some((_, token)) => token,
172        None => {
173            METRICS.mmds.rx_no_token.inc();
174            let error_msg = VmmMmdsError::NoTokenProvided.to_string();
175            return build_response(
176                request.http_version(),
177                StatusCode::Unauthorized,
178                MediaType::PlainText,
179                Body::new(error_msg),
180            );
181        }
182    };
183
184    // Validate the token.
185    match mmds.is_valid_token(token) {
186        true => respond_to_get_request(mmds, request),
187        false => {
188            METRICS.mmds.rx_invalid_token.inc();
189            build_response(
190                request.http_version(),
191                StatusCode::Unauthorized,
192                MediaType::PlainText,
193                Body::new(VmmMmdsError::InvalidToken.to_string()),
194            )
195        }
196    }
197}
198
199fn respond_to_get_request(mmds: &Mmds, request: Request) -> Response {
200    let uri = request.uri().get_abs_path();
201
202    // The data store expects a strict json path, so we need to
203    // sanitize the URI.
204    let json_path = sanitize_uri(uri.to_string());
205
206    let content_type = request.headers.accept();
207
208    match mmds.get_value(json_path, content_type.into()) {
209        Ok(response_body) => build_response(
210            request.http_version(),
211            StatusCode::OK,
212            content_type,
213            Body::new(response_body),
214        ),
215        Err(err) => match err {
216            MmdsError::NotFound => {
217                let error_msg = VmmMmdsError::ResourceNotFound(String::from(uri)).to_string();
218                build_response(
219                    request.http_version(),
220                    StatusCode::NotFound,
221                    MediaType::PlainText,
222                    Body::new(error_msg),
223                )
224            }
225            MmdsError::UnsupportedValueType => build_response(
226                request.http_version(),
227                StatusCode::NotImplemented,
228                MediaType::PlainText,
229                Body::new(err.to_string()),
230            ),
231            MmdsError::DataStoreLimitExceeded => build_response(
232                request.http_version(),
233                StatusCode::PayloadTooLarge,
234                MediaType::PlainText,
235                Body::new(err.to_string()),
236            ),
237            _ => unreachable!(),
238        },
239    }
240}
241
242fn respond_to_put_request(mmds: &mut Mmds, request: Request) -> Response {
243    let custom_headers = request.headers.custom_entries();
244
245    // Reject `PUT` requests that contain `X-Forwarded-For` header.
246    if let Some((header, _)) = get_header_value_pair(custom_headers, &[X_FORWARDED_FOR_HEADER]) {
247        let error_msg =
248            RequestError::HeaderError(HttpHeaderError::UnsupportedName(header.to_string()))
249                .to_string();
250        return build_response(
251            request.http_version(),
252            StatusCode::BadRequest,
253            MediaType::PlainText,
254            Body::new(error_msg),
255        );
256    }
257
258    let uri = request.uri().get_abs_path();
259    // Sanitize the URI into a strict json path.
260    let json_path = sanitize_uri(uri.to_string());
261
262    // Only accept PUT requests towards TOKEN_PATH.
263    if json_path != PATH_TO_TOKEN {
264        let error_msg = VmmMmdsError::ResourceNotFound(String::from(uri)).to_string();
265        return build_response(
266            request.http_version(),
267            StatusCode::NotFound,
268            MediaType::PlainText,
269            Body::new(error_msg),
270        );
271    }
272
273    // Get token lifetime value.
274    let (header, ttl_seconds) = match get_header_value_pair(
275        custom_headers,
276        &[
277            X_METADATA_TOKEN_TTL_SECONDS_HEADER,
278            X_AWS_EC2_METADATA_TOKEN_SSL_SECONDS_HEADER,
279        ],
280    ) {
281        // Header found
282        Some((header, value)) => match value.parse::<u32>() {
283            Ok(ttl_seconds) => (header, ttl_seconds),
284            Err(_) => {
285                return build_response(
286                    request.http_version(),
287                    StatusCode::BadRequest,
288                    MediaType::PlainText,
289                    Body::new(
290                        RequestError::HeaderError(HttpHeaderError::InvalidValue(
291                            header.into(),
292                            value.into(),
293                        ))
294                        .to_string(),
295                    ),
296                );
297            }
298        },
299        // Header not found
300        None => {
301            return build_response(
302                request.http_version(),
303                StatusCode::BadRequest,
304                MediaType::PlainText,
305                Body::new(VmmMmdsError::NoTtlProvided.to_string()),
306            );
307        }
308    };
309
310    // Generate token.
311    let result = mmds.generate_token(ttl_seconds);
312    match result {
313        Ok(token) => {
314            let mut response = build_response(
315                request.http_version(),
316                StatusCode::OK,
317                MediaType::PlainText,
318                Body::new(token),
319            );
320            let custom_headers = [(header.into(), ttl_seconds.to_string())].into();
321            // Safe to unwrap because the header name and the value are valid as US-ASCII.
322            // - `header` is either `X_METADATA_TOKEN_TTL_SECONDS_HEADER` or
323            //   `X_AWS_EC2_METADATA_TOKEN_SSL_SECONDS_HEADER`.
324            // - `ttl_seconds` is a decimal number between `MIN_TOKEN_TTL_SECONDS` and
325            //   `MAX_TOKEN_TTL_SECONDS`.
326            response.set_custom_headers(&custom_headers).unwrap();
327            response
328        }
329        Err(err) => build_response(
330            request.http_version(),
331            StatusCode::BadRequest,
332            MediaType::PlainText,
333            Body::new(err.to_string()),
334        ),
335    }
336}
337
338#[cfg(test)]
339mod tests {
340    use std::time::Duration;
341
342    use super::*;
343    use crate::mmds::token::{MAX_TOKEN_TTL_SECONDS, MIN_TOKEN_TTL_SECONDS};
344
345    fn populate_mmds() -> Arc<Mutex<Mmds>> {
346        let data = r#"{
347            "name": {
348                "first": "John",
349                "second": "Doe"
350            },
351            "age": 43,
352            "phones": {
353                "home": {
354                    "RO": "+401234567",
355                    "UK": "+441234567"
356                },
357                "mobile": "+442345678"
358            }
359        }"#;
360        let mmds = Arc::new(Mutex::new(Mmds::default()));
361        mmds.lock()
362            .expect("Poisoned lock")
363            .put_data(serde_json::from_str(data).unwrap())
364            .unwrap();
365
366        mmds
367    }
368
369    fn get_json_data() -> &'static str {
370        r#"{
371            "age": 43,
372            "name": {
373                "first": "John",
374                "second": "Doe"
375            },
376            "phones": {
377                "home": {
378                    "RO": "+401234567",
379                    "UK": "+441234567"
380                },
381                "mobile": "+442345678"
382            }
383        }"#
384    }
385
386    fn get_plain_text_data() -> &'static str {
387        "age\nname/\nphones/"
388    }
389
390    fn generate_request_and_expected_response(
391        request_bytes: &[u8],
392        media_type: MediaType,
393    ) -> (Request, Response) {
394        let request = Request::try_from(request_bytes, None).unwrap();
395
396        let mut response = Response::new(Version::Http10, StatusCode::OK);
397        response.set_content_type(media_type);
398        let body = match media_type {
399            MediaType::ApplicationJson => {
400                let mut body = get_json_data().to_string();
401                body.retain(|c| !c.is_whitespace());
402                body
403            }
404            MediaType::PlainText => get_plain_text_data().to_string(),
405        };
406        response.set_body(Body::new(body));
407
408        (request, response)
409    }
410
411    #[test]
412    fn test_sanitize_uri() {
413        let sanitized = "/a/b/c/d";
414        assert_eq!(sanitize_uri("/a/b/c/d".to_owned()), sanitized);
415        assert_eq!(sanitize_uri("/a////b/c//d".to_owned()), sanitized);
416        assert_eq!(sanitize_uri("/a///b/c///d".to_owned()), sanitized);
417        assert_eq!(sanitize_uri("/a//b/c////d".to_owned()), sanitized);
418        assert_eq!(sanitize_uri("///////a//b///c//d".to_owned()), sanitized);
419        assert_eq!(sanitize_uri("a".to_owned()), "a");
420        assert_eq!(sanitize_uri("a/".to_owned()), "a/");
421        assert_eq!(sanitize_uri("aa//".to_owned()), "aa/");
422        assert_eq!(sanitize_uri("aa".to_owned()), "aa");
423        assert_eq!(sanitize_uri("/".to_owned()), "/");
424        assert_eq!(sanitize_uri("".to_owned()), "");
425        assert_eq!(sanitize_uri("////".to_owned()), "/");
426        assert_eq!(sanitize_uri("aa//bb///cc//d".to_owned()), "aa/bb/cc/d");
427        assert_eq!(sanitize_uri("//aa//bb///cc//d".to_owned()), "/aa/bb/cc/d");
428    }
429
430    #[test]
431    fn test_request_accept_header() {
432        // This test validates the response `Content-Type` header and the response content for
433        // various request `Accept` headers.
434
435        // Populate MMDS with data.
436        let mmds = populate_mmds();
437
438        // Test without `Accept` header. micro-http defaults to `Accept: text/plain`.
439        let (request, expected_response) = generate_request_and_expected_response(
440            b"GET http://169.254.169.254/ HTTP/1.0\r\n\r\n",
441            MediaType::PlainText,
442        );
443        assert_eq!(
444            convert_to_response(mmds.clone(), request),
445            expected_response
446        );
447
448        // Test with empty `Accept` header. micro-http defaults to `Accept: text/plain`.
449        let (request, expected_response) = generate_request_and_expected_response(
450            b"GET http://169.254.169.254/ HTTP/1.0\r\n\"
451              Accept:\r\n\r\n",
452            MediaType::PlainText,
453        );
454        assert_eq!(
455            convert_to_response(mmds.clone(), request),
456            expected_response
457        );
458
459        // Test with `Accept: */*` header.
460        let (request, expected_response) = generate_request_and_expected_response(
461            b"GET http://169.254.169.254/ HTTP/1.0\r\n\"
462              Accept: */*\r\n\r\n",
463            MediaType::PlainText,
464        );
465        assert_eq!(
466            convert_to_response(mmds.clone(), request),
467            expected_response
468        );
469
470        // Test with `Accept: text/plain`.
471        let (request, expected_response) = generate_request_and_expected_response(
472            b"GET http://169.254.169.254/ HTTP/1.0\r\n\
473              Accept: text/plain\r\n\r\n",
474            MediaType::PlainText,
475        );
476        assert_eq!(
477            convert_to_response(mmds.clone(), request),
478            expected_response
479        );
480
481        // Test with `Accept: application/json`.
482        let (request, expected_response) = generate_request_and_expected_response(
483            b"GET http://169.254.169.254/ HTTP/1.0\r\n\
484              Accept: application/json\r\n\r\n",
485            MediaType::ApplicationJson,
486        );
487        assert_eq!(convert_to_response(mmds, request), expected_response);
488    }
489
490    // Test the version-independent error paths of `convert_to_response()`.
491    #[test]
492    fn test_convert_to_response_negative() {
493        for version in [MmdsVersion::V1, MmdsVersion::V2] {
494            let mmds = populate_mmds();
495            mmds.lock().expect("Poisoned lock").set_version(version);
496
497            // Test InvalidURI (empty absolute path).
498            let request = Request::try_from(b"GET http:// HTTP/1.0\r\n\r\n", None).unwrap();
499            let mut expected_response = Response::new(Version::Http10, StatusCode::BadRequest);
500            expected_response.set_content_type(MediaType::PlainText);
501            expected_response.set_body(Body::new(VmmMmdsError::InvalidURI.to_string()));
502            let actual_response = convert_to_response(mmds.clone(), request);
503            assert_eq!(actual_response, expected_response);
504
505            // Test MethodNotAllowed (PATCH method).
506            let request =
507                Request::try_from(b"PATCH http://169.254.169.255/ HTTP/1.0\r\n\r\n", None).unwrap();
508            let mut expected_response =
509                Response::new(Version::Http10, StatusCode::MethodNotAllowed);
510            expected_response.set_content_type(MediaType::PlainText);
511            expected_response.set_body(Body::new(VmmMmdsError::MethodNotAllowed.to_string()));
512            expected_response.allow_method(Method::Get);
513            expected_response.allow_method(Method::Put);
514            let actual_response = convert_to_response(mmds.clone(), request);
515            assert_eq!(actual_response, expected_response);
516        }
517    }
518
    // MMDSv1: GET requests are answered whether or not a session token is supplied. The
    // token only drives metrics: a missing token bumps `rx_no_token`, an invalid one bumps
    // `rx_invalid_token`, and a valid one bumps neither. Counters are read right before and
    // right after each request, so the assertions check deltas rather than absolute values.
    // NOTE(review): `METRICS` appears to be shared state; if other tests touching these
    // counters run concurrently the deltas could be perturbed — confirm isolation.
    #[test]
    fn test_respond_to_request_mmdsv1() {
        let mmds = populate_mmds();
        mmds.lock()
            .expect("Poisoned lock")
            .set_version(MmdsVersion::V1);

        // Test valid v1 GET request.
        // No token header: data is returned and exactly one `rx_no_token` is recorded.
        let (request, expected_response) = generate_request_and_expected_response(
            b"GET http://169.254.169.254/ HTTP/1.0\r\n\
              Accept: application/json\r\n\r\n",
            MediaType::ApplicationJson,
        );
        let prev_rx_invalid_token = METRICS.mmds.rx_invalid_token.count();
        let prev_rx_no_token = METRICS.mmds.rx_no_token.count();
        let actual_response = convert_to_response(mmds.clone(), request);
        assert_eq!(actual_response, expected_response);
        assert_eq!(prev_rx_invalid_token, METRICS.mmds.rx_invalid_token.count());
        assert_eq!(prev_rx_no_token + 1, METRICS.mmds.rx_no_token.count());

        // Test valid PUT request to generate a valid token.
        let request = Request::try_from(
            b"PUT http://169.254.169.254/latest/api/token HTTP/1.0\r\n\
              X-metadata-token-ttl-seconds: 60\r\n\r\n",
            None,
        )
        .unwrap();
        let actual_response = convert_to_response(mmds.clone(), request);
        assert_eq!(actual_response.status(), StatusCode::OK);
        assert_eq!(actual_response.content_type(), MediaType::PlainText);
        // The response body is the token itself.
        let valid_token = String::from_utf8(actual_response.body().unwrap().body).unwrap();

        // Test valid v2 GET request.
        // A v2-style request (valid token) is also served under v1, with no metric bumps.
        #[rustfmt::skip]
        let (request, expected_response) = generate_request_and_expected_response(
            format!(
                "GET http://169.254.169.254/ HTTP/1.0\r\n\
                 Accept: application/json\r\n\
                 X-metadata-token: {valid_token}\r\n\r\n",
            )
            .as_bytes(),
            MediaType::ApplicationJson,
        );
        let prev_rx_invalid_token = METRICS.mmds.rx_invalid_token.count();
        let prev_rx_no_token = METRICS.mmds.rx_no_token.count();
        let actual_response = convert_to_response(mmds.clone(), request);
        assert_eq!(actual_response, expected_response);
        assert_eq!(prev_rx_invalid_token, METRICS.mmds.rx_invalid_token.count());
        assert_eq!(prev_rx_no_token, METRICS.mmds.rx_no_token.count());

        // Test GET request with invalid token is accepted when v1 is configured.
        // The data is still returned; only `rx_invalid_token` is bumped.
        let (request, expected_response) = generate_request_and_expected_response(
            b"GET http://169.254.169.254/ HTTP/1.0\r\n\
              Accept: application/json\r\n\
              X-metadata-token: INVALID_TOKEN\r\n\r\n",
            MediaType::ApplicationJson,
        );
        let prev_rx_invalid_token = METRICS.mmds.rx_invalid_token.count();
        let prev_rx_no_token = METRICS.mmds.rx_no_token.count();
        let actual_response = convert_to_response(mmds, request);
        assert_eq!(actual_response, expected_response);
        assert_eq!(
            prev_rx_invalid_token + 1,
            METRICS.mmds.rx_invalid_token.count()
        );
        assert_eq!(prev_rx_no_token, METRICS.mmds.rx_no_token.count());
    }
586
    // MMDSv2: GET requires a valid session token obtained through PUT. Requests without a
    // token (`NoTokenProvided`) or with an invalid/expired token (`InvalidToken`) get
    // `401 Unauthorized`. The matching `rx_no_token` / `rx_invalid_token` counters are
    // checked as before/after deltas.
    #[test]
    fn test_respond_to_request_mmdsv2() {
        let mmds = populate_mmds();
        mmds.lock()
            .expect("Poisoned lock")
            .set_version(MmdsVersion::V2);

        // Test valid PUT to generate a valid token.
        let request = Request::try_from(
            b"PUT http://169.254.169.254/latest/api/token HTTP/1.0\r\n\
              X-metadata-token-ttl-seconds: 60\r\n\r\n",
            None,
        )
        .unwrap();
        let actual_response = convert_to_response(mmds.clone(), request);
        assert_eq!(actual_response.status(), StatusCode::OK);
        assert_eq!(actual_response.content_type(), MediaType::PlainText);
        // The response body is the token itself.
        let valid_token = String::from_utf8(actual_response.body().unwrap().body).unwrap();

        // Test valid GET.
        // With a valid token neither token counter moves.
        #[rustfmt::skip]
        let (request, expected_response) = generate_request_and_expected_response(
            format!(
                "GET http://169.254.169.254/ HTTP/1.0\r\n\
                 Accept: application/json\r\n\
                 X-metadata-token: {valid_token}\r\n\r\n",
            )
            .as_bytes(),
            MediaType::ApplicationJson,
        );
        let prev_rx_invalid_token = METRICS.mmds.rx_invalid_token.count();
        let prev_rx_no_token = METRICS.mmds.rx_no_token.count();
        let actual_response = convert_to_response(mmds.clone(), request);
        assert_eq!(actual_response, expected_response);
        assert_eq!(prev_rx_invalid_token, METRICS.mmds.rx_invalid_token.count());
        assert_eq!(prev_rx_no_token, METRICS.mmds.rx_no_token.count());

        // Test GET request without token should return Unauthorized status code.
        let request =
            Request::try_from(b"GET http://169.254.169.254/ HTTP/1.0\r\n\r\n", None).unwrap();
        let mut expected_response = Response::new(Version::Http10, StatusCode::Unauthorized);
        expected_response.set_content_type(MediaType::PlainText);
        expected_response.set_body(Body::new(VmmMmdsError::NoTokenProvided.to_string()));
        let prev_rx_no_token = METRICS.mmds.rx_no_token.count();
        let actual_response = convert_to_response(mmds.clone(), request);
        assert_eq!(actual_response, expected_response);
        assert_eq!(prev_rx_no_token + 1, METRICS.mmds.rx_no_token.count());

        // Create an expired token.
        let request = Request::try_from(
            b"PUT http://169.254.169.254/latest/api/token HTTP/1.0\r\n\
              X-metadata-token-ttl-seconds: 1\r\n\r\n",
            None,
        )
        .unwrap();
        let actual_response = convert_to_response(mmds.clone(), request);
        assert_eq!(actual_response.status(), StatusCode::OK);
        assert_eq!(actual_response.content_type(), MediaType::PlainText);
        let expired_token = String::from_utf8(actual_response.body().unwrap().body).unwrap();
        // Wait out the 1-second TTL so the token is rejected below.
        std::thread::sleep(Duration::from_secs(1));

        // Test GET request with invalid tokens.
        // Both a malformed token and the expired one must yield `InvalidToken` and bump
        // only `rx_invalid_token`.
        let tokens = ["INVALID_TOKEN", &expired_token];
        for token in tokens.iter() {
            #[rustfmt::skip]
            let request = Request::try_from(
                format!(
                    "GET http://169.254.169.254/ HTTP/1.0\r\n\
                     X-metadata-token: {token}\r\n\r\n",
                )
                .as_bytes(),
                None,
            )
            .unwrap();
            let mut expected_response = Response::new(Version::Http10, StatusCode::Unauthorized);
            expected_response.set_content_type(MediaType::PlainText);
            expected_response.set_body(Body::new(VmmMmdsError::InvalidToken.to_string()));
            let prev_rx_invalid_token = METRICS.mmds.rx_invalid_token.count();
            let prev_rx_no_token = METRICS.mmds.rx_no_token.count();
            let actual_response = convert_to_response(mmds.clone(), request);
            assert_eq!(actual_response, expected_response);
            assert_eq!(
                prev_rx_invalid_token + 1,
                METRICS.mmds.rx_invalid_token.count()
            );
            assert_eq!(prev_rx_no_token, METRICS.mmds.rx_no_token.count());
        }
    }
675
    // Test the version-independent parts of GET request handling, under both MMDSv1 and
    // MMDSv2 and always with a valid token:
    // * an unknown path gives `404`;
    // * an unsupported value type gives `501`;
    // * a TTL header on a GET is ignored.
    #[test]
    fn test_respond_to_get_request() {
        for version in [MmdsVersion::V1, MmdsVersion::V2] {
            let mmds = populate_mmds();
            mmds.lock().expect("Poisoned lock").set_version(version);

            // Generate a token
            let request = Request::try_from(
                b"PUT http://169.254.169.254/latest/api/token HTTP/1.0\r\n\
                  X-metadata-token-ttl-seconds: 60\r\n\r\n",
                None,
            )
            .unwrap();
            let actual_response = convert_to_response(mmds.clone(), request);
            assert_eq!(actual_response.status(), StatusCode::OK);
            assert_eq!(actual_response.content_type(), MediaType::PlainText);
            let valid_token = String::from_utf8(actual_response.body().unwrap().body).unwrap();

            // Test invalid path
            // The error message echoes the requested path (`/invalid`).
            #[rustfmt::skip]
            let request = Request::try_from(
                format!(
                    "GET http://169.254.169.254/invalid HTTP/1.0\r\n\
                     X-metadata-token: {valid_token}\r\n\r\n",
                )
                .as_bytes(),
                None,
            )
            .unwrap();
            let mut expected_response = Response::new(Version::Http10, StatusCode::NotFound);
            expected_response.set_content_type(MediaType::PlainText);
            expected_response.set_body(Body::new(
                VmmMmdsError::ResourceNotFound(String::from("/invalid")).to_string(),
            ));
            let actual_response = convert_to_response(mmds.clone(), request);
            assert_eq!(actual_response, expected_response);

            // Test unsupported type
            // `/age` holds a number; with no `Accept` header (plain text) the data store is
            // expected to reject it as an unsupported value type. The request uses
            // origin-form HTTP/1.1, so the response version is HTTP/1.1 too.
            #[rustfmt::skip]
            let request = Request::try_from(
                format!(
                    "GET /age HTTP/1.1\r\n\
                     X-metadata-token: {valid_token}\r\n\r\n",
                )
                .as_bytes(),
                None,
            )
            .unwrap();
            let mut expected_response = Response::new(Version::Http11, StatusCode::NotImplemented);
            expected_response.set_content_type(MediaType::PlainText);
            let body = "Cannot retrieve value. The value has an unsupported type.".to_string();
            expected_response.set_body(Body::new(body));
            let actual_response = convert_to_response(mmds.clone(), request);
            assert_eq!(actual_response, expected_response);

            // Test invalid `X-metadata-token-ttl-seconds` value is ignored if not PUT request.
            // The non-numeric TTL must not affect a GET; the normal root listing comes back.
            #[rustfmt::skip]
            let (request, expected_response) = generate_request_and_expected_response(
                format!(
                    "GET http://169.254.169.254/ HTTP/1.0\r\n\
                     X-metadata-token: {valid_token}\r\n\
                     X-metadata-token-ttl-seconds: application/json\r\n\r\n",
                )
                .as_bytes(),
                MediaType::PlainText,
            );
            let actual_response = convert_to_response(mmds.clone(), request);
            assert_eq!(actual_response, expected_response);
        }
    }
747
748    // Test PUT request (version-independent)
749    #[test]
750    fn test_respond_to_put_request() {
751        for version in [MmdsVersion::V1, MmdsVersion::V2] {
752            let mmds = populate_mmds();
753            mmds.lock().expect("Poisoned lock").set_version(version);
754
755            // Test valid PUT
756            let request = Request::try_from(
757                b"PUT http://169.254.169.254/latest/api/token HTTP/1.0\r\n\
758                  X-metadata-token-ttl-seconds: 60\r\n\r\n",
759                None,
760            )
761            .unwrap();
762            let actual_response = convert_to_response(mmds.clone(), request);
763            assert_eq!(actual_response.status(), StatusCode::OK);
764            assert_eq!(actual_response.content_type(), MediaType::PlainText);
765            assert_eq!(
766                actual_response
767                    .custom_headers()
768                    .get("X-metadata-token-ttl-seconds")
769                    .unwrap(),
770                "60"
771            );
772
773            // Test unsupported `X-Forwarded-For` header
774            for header in ["X-Forwarded-For", "x-forwarded-for", "X-fOrWaRdEd-FoR"] {
775                #[rustfmt::skip]
776                let request = Request::try_from(
777                    format!(
778                        "PUT http://169.254.169.254/latest/api/token HTTP/1.0\r\n\
779                         {header}: 203.0.113.195\r\n\r\n"
780                    )
781                    .as_bytes(),
782                    None,
783                )
784                .unwrap();
785                let mut expected_response = Response::new(Version::Http10, StatusCode::BadRequest);
786                expected_response.set_content_type(MediaType::PlainText);
787                expected_response.set_body(Body::new(format!(
788                    "Invalid header. Reason: Unsupported header name. Key: {header}"
789                )));
790                let actual_response = convert_to_response(mmds.clone(), request);
791                assert_eq!(actual_response, expected_response);
792            }
793
794            // Test invalid path
795            let request = Request::try_from(
796                b"PUT http://169.254.169.254/token HTTP/1.0\r\n\
797                  X-metadata-token-ttl-seconds: 60\r\n\r\n",
798                None,
799            )
800            .unwrap();
801            let mut expected_response = Response::new(Version::Http10, StatusCode::NotFound);
802            expected_response.set_content_type(MediaType::PlainText);
803            expected_response.set_body(Body::new(
804                VmmMmdsError::ResourceNotFound(String::from("/token")).to_string(),
805            ));
806            let actual_response = convert_to_response(mmds.clone(), request);
807            assert_eq!(actual_response, expected_response);
808
809            // Test non-numeric `X-metadata-token-ttl-seconds` value
810            let request = Request::try_from(
811                b"PUT http://169.254.169.254/latest/api/token HTTP/1.0\r\n\
812                  X-metadata-token-ttl-seconds: application/json\r\n\r\n",
813                None,
814            )
815            .unwrap();
816            let mut expected_response = Response::new(Version::Http10, StatusCode::BadRequest);
817            expected_response.set_content_type(MediaType::PlainText);
818            #[rustfmt::skip]
819            expected_response.set_body(Body::new(
820                "Invalid header. Reason: Invalid value. \
821                 Key:X-metadata-token-ttl-seconds; Value:application/json"
822                    .to_string(),
823            ));
824            let actual_response = convert_to_response(mmds.clone(), request);
825            assert_eq!(actual_response, expected_response);
826
827            // Test out-of-range `X-metadata-token-ttl-seconds` value
828            let invalid_values = [MIN_TOKEN_TTL_SECONDS - 1, MAX_TOKEN_TTL_SECONDS + 1];
829            for invalid_value in invalid_values.iter() {
830                #[rustfmt::skip]
831                let request = Request::try_from(
832                    format!(
833                        "PUT http://169.254.169.254/latest/api/token HTTP/1.0\r\n\
834                         X-metadata-token-ttl-seconds: {invalid_value}\r\n\r\n",
835                    )
836                    .as_bytes(),
837                    None,
838                )
839                .unwrap();
840                let mut expected_response = Response::new(Version::Http10, StatusCode::BadRequest);
841                expected_response.set_content_type(MediaType::PlainText);
842                #[rustfmt::skip]
843                let error_msg = format!(
844                    "Invalid time to live value provided for token: {invalid_value}. \
845                     Please provide a value between {MIN_TOKEN_TTL_SECONDS} and {MAX_TOKEN_TTL_SECONDS}.",
846                );
847                expected_response.set_body(Body::new(error_msg));
848                let actual_response = convert_to_response(mmds.clone(), request);
849                assert_eq!(actual_response, expected_response);
850            }
851
852            // Test lack of `X-metadata-token-ttl-seconds` header
853            let request = Request::try_from(
854                b"PUT http://169.254.169.254/latest/api/token HTTP/1.0\r\n\r\n",
855                None,
856            )
857            .unwrap();
858            let mut expected_response = Response::new(Version::Http10, StatusCode::BadRequest);
859            expected_response.set_content_type(MediaType::PlainText);
860            expected_response.set_body(Body::new(VmmMmdsError::NoTtlProvided.to_string()));
861            let actual_response = convert_to_response(mmds.clone(), request);
862            assert_eq!(actual_response, expected_response);
863        }
864    }
865
866    #[test]
867    fn test_json_patch() {
868        let mut data = serde_json::json!({
869            "name": {
870                "first": "John",
871                "second": "Doe"
872            },
873            "age": "43",
874            "phones": {
875                "home": {
876                    "RO": "+40 1234567",
877                    "UK": "+44 1234567"
878                },
879                "mobile": "+44 2345678"
880            }
881        });
882
883        let patch = serde_json::json!({
884            "name": {
885                "second": null,
886                "last": "Kennedy"
887            },
888            "age": "44",
889            "phones": {
890                "home": "+44 1234567",
891                "mobile": {
892                    "RO": "+40 2345678",
893                    "UK": "+44 2345678"
894                }
895            }
896        });
897        json_patch(&mut data, &patch);
898
899        // Test value replacement in target document.
900        assert_eq!(data["age"], patch["age"]);
901
902        // Test null value removal from target document.
903        assert_eq!(data["name"]["second"], Value::Null);
904
905        // Test add value to target document.
906        assert_eq!(data["name"]["last"], patch["name"]["last"]);
907        assert!(!data["phones"]["home"].is_object());
908        assert_eq!(data["phones"]["home"], patch["phones"]["home"]);
909        assert!(data["phones"]["mobile"].is_object());
910        assert_eq!(
911            data["phones"]["mobile"]["RO"],
912            patch["phones"]["mobile"]["RO"]
913        );
914        assert_eq!(
915            data["phones"]["mobile"]["UK"],
916            patch["phones"]["mobile"]["UK"]
917        );
918    }
919
920    #[test]
921    fn test_error_display() {
922        assert_eq!(
923            VmmMmdsError::InvalidToken.to_string(),
924            "MMDS token not valid."
925        );
926
927        assert_eq!(VmmMmdsError::InvalidURI.to_string(), "Invalid URI.");
928
929        assert_eq!(
930            VmmMmdsError::MethodNotAllowed.to_string(),
931            "Not allowed HTTP method."
932        );
933
934        assert_eq!(
935            VmmMmdsError::NoTokenProvided.to_string(),
936            "No MMDS token provided. Use `X-metadata-token` or `X-aws-ec2-metadata-token` header \
937             to specify the session token."
938        );
939
940        assert_eq!(
941            VmmMmdsError::NoTtlProvided.to_string(),
942            "Token time to live value not found. Use `X-metadata-token-ttl-seconds` or \
943             `X-aws-ec2-metadata-token-ttl-seconds` header to specify the token's lifetime."
944        );
945
946        assert_eq!(
947            VmmMmdsError::ResourceNotFound(String::from("invalid/")).to_string(),
948            "Resource not found: invalid/."
949        )
950    }
951}