vmm/mmds/
token_headers.rs

1// Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5
// Lowercase name of the `X-Forwarded-For` header.
pub(crate) const X_FORWARDED_FOR_HEADER: &str = "x-forwarded-for";
// Lowercase name of the `X-metadata-token` header.
pub(crate) const X_METADATA_TOKEN_HEADER: &str = "x-metadata-token";
// Lowercase name of the `X-aws-ec2-metadata-token` header.
pub(crate) const X_AWS_EC2_METADATA_TOKEN_HEADER: &str = "x-aws-ec2-metadata-token";
// Lowercase name of the `X-metadata-token-ttl-seconds` header.
pub(crate) const X_METADATA_TOKEN_TTL_SECONDS_HEADER: &str = "x-metadata-token-ttl-seconds";
// Lowercase name of the `X-aws-ec2-metadata-token-ttl-seconds` header.
// NOTE(review): the identifier says `SSL` while the value and the sibling constant above say
// `ttl`; this looks like a typo, but renaming it would require updating every user of the
// constant — confirm before changing.
pub(crate) const X_AWS_EC2_METADATA_TOKEN_SSL_SECONDS_HEADER: &str =
    "x-aws-ec2-metadata-token-ttl-seconds";
17
/// Finds the entry of `custom_headers` whose name matches one of `headers`.
///
/// Names are compared ASCII case-insensitively. `headers` is searched in order, so when
/// several accepted headers are present the one listed earliest in `headers` wins, rather
/// than whichever entry the `HashMap` happens to yield first.
///
/// Returns the key and value exactly as stored in `custom_headers`, or `None` when no entry
/// matches. If the map holds several keys that differ only in ASCII case, which of those is
/// returned still depends on the map's iteration order.
pub(crate) fn get_header_value_pair<'a>(
    custom_headers: &'a HashMap<String, String>,
    headers: &'a [&'static str],
) -> Option<(&'a String, &'a String)> {
    // Drive the search from `headers` (ordered) instead of from the map (unordered) so that
    // the precedence between accepted header names is stable across calls.
    headers.iter().find_map(|header| {
        custom_headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(header))
    })
}
26
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `s` with characters at even indices ASCII-lowercased and characters at odd
    /// indices ASCII-uppercased.
    fn to_mixed_case(s: &str) -> String {
        let mut mixed = String::with_capacity(s.len());
        for (idx, ch) in s.chars().enumerate() {
            mixed.push(if idx % 2 == 1 {
                ch.to_ascii_uppercase()
            } else {
                ch.to_ascii_lowercase()
            });
        }
        mixed
    }

    /// Builds an owned header map from borrowed name/value pairs.
    fn header_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(name, value)| ((*name).to_owned(), (*value).to_owned()))
            .collect()
    }

    #[test]
    fn test_get_header_value_pair() {
        const TOKEN: &str = "THIS_IS_TOKEN";

        let headers = [X_METADATA_TOKEN_HEADER, X_AWS_EC2_METADATA_TOKEN_HEADER];
        let token = TOKEN.to_owned();

        // An empty map has nothing to match.
        let no_headers = HashMap::new();
        assert!(get_header_value_pair(&no_headers, &headers).is_none());

        // A map holding only unrelated headers has nothing to match either.
        let unrelated = header_map(&[("Some-Header", "10"), ("Another-Header", "value")]);
        assert!(get_header_value_pair(&unrelated, &headers).is_none());

        for header in headers {
            let name = header.to_owned();

            // The accepted header is the only entry.
            let alone = header_map(&[(header, TOKEN)]);
            let found = get_header_value_pair(&alone, &headers).unwrap();
            assert_eq!(found, (&name, &token));

            // The accepted header sits next to unrelated entries.
            let crowded = header_map(&[
                ("Some-Header", "10"),
                ("Another-Header", "value"),
                (header, TOKEN),
            ]);
            let found = get_header_value_pair(&crowded, &headers).unwrap();
            assert_eq!(found, (&name, &token));

            // Name matching ignores ASCII case; the key comes back as stored in the map.
            let mixed = to_mixed_case(header);
            let cased = header_map(&[(mixed.as_str(), TOKEN)]);
            let found = get_header_value_pair(&cased, &headers).unwrap();
            assert_eq!(found, (&mixed, &token));
        }
    }
}