1use std::fmt;
5use std::fmt::{Display, Formatter};
6
7use serde::{Deserialize, Serialize};
8use serde_json::{Value, to_vec};
9
10use crate::mmds::token::{MmdsTokenError as TokenError, TokenAuthority};
11
/// The MMDS data store: a JSON document plus the settings used to store and serve it.
#[derive(Debug)]
pub struct Mmds {
    /// Version selected via `Mmds::set_version`; `MmdsVersion::V1` by default.
    version: MmdsVersion,
    /// The stored JSON document; `Value::default()` (null) until `put_data` succeeds.
    data_store: Value,
    /// Generates and validates tokens on behalf of this store.
    token_authority: TokenAuthority,
    /// Set by the first successful `put_data`; `patch_data` refuses to run until then.
    is_initialized: bool,
    /// Maximum size, in bytes of serialized JSON, that `put_data`/`patch_data` accept.
    data_store_limit: usize,
    /// When `true`, `get_value` always renders in the IMDS format, whatever was requested.
    imds_compat: bool,
}
22
/// MMDS version, selected with `Mmds::set_version`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
pub enum MmdsVersion {
    /// MMDS version 1; the default.
    #[default]
    V1,
    /// MMDS version 2.
    V2,
}
32
33impl Display for MmdsVersion {
34 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
35 match self {
36 MmdsVersion::V1 => write!(f, "V1"),
37 MmdsVersion::V2 => write!(f, "V2"),
38 }
39 }
40}
41
/// Rendering requested from `Mmds::get_value`.
#[derive(Debug, Clone, Copy)]
pub enum OutputFormat {
    /// The value serialized as JSON.
    Json,
    /// IMDS plain text: an object's keys one per line (keys of nested objects end with
    /// `/`), or the raw contents of a string. Other value types are rejected.
    Imds,
}
50
51#[derive(Debug, thiserror::Error, displaydoc::Display)]
52pub enum MmdsDatastoreError {
54 DataStoreLimitExceeded,
56 NotFound,
58 NotInitialized,
60 TokenAuthority(#[from] TokenError),
62 UnsupportedValueType,
64}
65
66impl Default for Mmds {
68 fn default() -> Self {
69 Self::try_new(51200).unwrap()
70 }
71}
72
73impl Mmds {
74 pub fn try_new(data_store_limit: usize) -> Result<Self, MmdsDatastoreError> {
76 Ok(Mmds {
77 version: MmdsVersion::default(),
78 data_store: Value::default(),
79 token_authority: TokenAuthority::try_new()?,
80 is_initialized: false,
81 data_store_limit,
82 imds_compat: false,
83 })
84 }
85
86 fn check_data_store_initialized(&self) -> Result<(), MmdsDatastoreError> {
90 if self.is_initialized {
91 Ok(())
92 } else {
93 Err(MmdsDatastoreError::NotInitialized)
94 }
95 }
96
97 pub fn set_version(&mut self, version: MmdsVersion) {
99 self.version = version;
100 }
101
102 pub fn version(&self) -> MmdsVersion {
104 self.version
105 }
106
107 pub fn set_imds_compat(&mut self, imds_compat: bool) {
109 self.imds_compat = imds_compat;
110 }
111
112 pub fn imds_compat(&self) -> bool {
114 self.imds_compat
115 }
116
117 pub fn set_aad(&mut self, instance_id: &str) {
120 self.token_authority.set_aad(instance_id);
121 }
122
123 pub fn is_valid_token(&self, token: &str) -> bool {
125 self.token_authority.is_valid(token)
126 }
127
128 pub fn generate_token(&mut self, ttl_seconds: u32) -> Result<String, TokenError> {
130 self.token_authority.generate_token_secret(ttl_seconds)
131 }
132
133 pub fn set_data_store_limit(&mut self, data_store_limit: usize) {
135 self.data_store_limit = data_store_limit;
136 }
137
138 pub fn put_data(&mut self, data: Value) -> Result<(), MmdsDatastoreError> {
140 if to_vec(&data).unwrap().len() > self.data_store_limit {
143 Err(MmdsDatastoreError::DataStoreLimitExceeded)
144 } else {
145 self.data_store = data;
146 self.is_initialized = true;
147
148 Ok(())
149 }
150 }
151
152 pub fn patch_data(&mut self, patch_data: Value) -> Result<(), MmdsDatastoreError> {
154 self.check_data_store_initialized()?;
155 let mut data_store_clone = self.data_store.clone();
156
157 super::json_patch(&mut data_store_clone, &patch_data);
158 if to_vec(&data_store_clone).unwrap().len() > self.data_store_limit {
161 return Err(MmdsDatastoreError::DataStoreLimitExceeded);
162 }
163 self.data_store = data_store_clone;
164 Ok(())
165 }
166
167 pub fn data_store_value(&self) -> Value {
172 self.data_store.clone()
173 }
174
175 fn format_imds(json: &Value) -> Result<String, MmdsDatastoreError> {
213 match json.as_object() {
217 Some(map) => {
218 let mut ret = Vec::new();
219 for key in map.keys() {
221 let mut key = key.clone();
222 if map[&key].is_object() {
225 key.push('/');
226 }
227
228 ret.push(key);
229 }
230 Ok(ret.join("\n"))
231 }
232 None => {
233 match json.as_str() {
236 Some(str_val) => Ok(str_val.to_string()),
237 None => Err(MmdsDatastoreError::UnsupportedValueType),
238 }
239 }
240 }
241 }
242
243 pub fn get_value(
246 &self,
247 path: String,
248 format: OutputFormat,
249 ) -> Result<String, MmdsDatastoreError> {
250 let value = if path.ends_with('/') {
253 self.data_store.pointer(&path.as_str()[..(path.len() - 1)])
254 } else {
255 self.data_store.pointer(path.as_str())
256 };
257
258 if let Some(json) = value {
259 match self.imds_compat {
260 true => Mmds::format_imds(json),
262 false => match format {
263 OutputFormat::Json => Ok(json.to_string()),
264 OutputFormat::Imds => Mmds::format_imds(json),
265 },
266 }
267 } else {
268 Err(MmdsDatastoreError::NotFound)
269 }
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    impl Mmds {
        /// Test helper: the data store as a JSON string, or `"{}"` while it is still
        /// null (i.e. before any successful `put_data`).
        fn get_data_str(&self) -> String {
            if self.data_store.is_null() {
                return String::from("{}");
            }
            self.data_store.to_string()
        }
    }

    /// `Display` prints the bare variant name, and the default version is `V1`.
    #[test]
    fn test_display_mmds_version() {
        assert_eq!(MmdsVersion::V1.to_string(), "V1");
        assert_eq!(MmdsVersion::V2.to_string(), "V2");
        assert_eq!(MmdsVersion::default().to_string(), "V1");
    }

    /// The version defaults to V1 and can be switched in both directions.
    #[test]
    fn test_mmds_version() {
        let mut mmds = Mmds::default();

        // Default version.
        assert_eq!(mmds.version(), MmdsVersion::V1);

        // Switch to V2.
        mmds.set_version(MmdsVersion::V2);
        assert_eq!(mmds.version(), MmdsVersion::V2);

        // Switch back to V1.
        mmds.set_version(MmdsVersion::V1);
        assert_eq!(mmds.version(), MmdsVersion::V1);
    }

    /// Initialization state, a full put, and a patch that overwrites one key.
    #[test]
    fn test_mmds() {
        let mut mmds = Mmds::default();

        // A fresh store is not initialized.
        assert_eq!(
            mmds.check_data_store_initialized().unwrap_err().to_string(),
            "The MMDS data store is not initialized.".to_string(),
        );

        let mut mmds_json = "{\"meta-data\":{\"iam\":\"dummy\"},\"user-data\":\"1522850095\"}";

        // A successful put initializes the store and round-trips the document.
        mmds.put_data(serde_json::from_str(mmds_json).unwrap())
            .unwrap();
        mmds.check_data_store_initialized().unwrap();

        assert_eq!(mmds.get_data_str(), mmds_json);

        // Patching replaces only the patched key.
        let patch_json = "{\"user-data\":\"10\"}";
        mmds.patch_data(serde_json::from_str(patch_json).unwrap())
            .unwrap();
        mmds_json = "{\"meta-data\":{\"iam\":\"dummy\"},\"user-data\":\"10\"}";
        assert_eq!(mmds.get_data_str(), mmds_json);
    }

    /// `get_value` for every value type, with and without IMDS compatibility and for
    /// both output formats. In IMDS rendering (compat mode or `OutputFormat::Imds`) only
    /// objects and strings are supported; everything else is `UnsupportedValueType`.
    #[test]
    fn test_get_value() {
        for imds_compat in [false, true] {
            let mut mmds = Mmds::default();
            mmds.set_imds_compat(imds_compat);
            let data = r#"{
"name": {
"first": "John",
"second": "Doe"
},
"age": 43,
"phones": [
"+401234567",
"+441234567"
],
"member": false,
"shares_percentage": 12.12,
"balance": -24,
"json_string": "{\n \"hello\": \"world\"\n}"
}"#;
            let data_store: Value = serde_json::from_str(data).unwrap();
            mmds.put_data(data_store).unwrap();

            for format in [OutputFormat::Imds, OutputFormat::Json] {
                // A missing path is NotFound in every mode.
                assert_eq!(
                    mmds.get_value("/invalid_path".to_string(), format)
                        .unwrap_err()
                        .to_string(),
                    MmdsDatastoreError::NotFound.to_string()
                );

                // Object: key listing in IMDS rendering, JSON otherwise.
                let expected = match (imds_compat, format) {
                    (false, OutputFormat::Imds) | (true, _) => "first\nsecond",
                    (false, OutputFormat::Json) => r#"{"first":"John","second":"Doe"}"#,
                };
                assert_eq!(
                    mmds.get_value("/name".to_string(), format).unwrap(),
                    expected
                );

                // Integer: unsupported in IMDS rendering.
                match (imds_compat, format) {
                    (false, OutputFormat::Imds) | (true, _) => assert_eq!(
                        mmds.get_value("/age".to_string(), format)
                            .err()
                            .unwrap()
                            .to_string(),
                        MmdsDatastoreError::UnsupportedValueType.to_string()
                    ),
                    (false, OutputFormat::Json) => {
                        assert_eq!(mmds.get_value("/age".to_string(), format).unwrap(), "43")
                    }
                };

                // Array addressed with a trailing '/': unsupported in IMDS rendering;
                // the trailing '/' is ignored for the lookup.
                match (imds_compat, format) {
                    (false, OutputFormat::Imds) | (true, _) => assert_eq!(
                        mmds.get_value("/phones/".to_string(), format)
                            .err()
                            .unwrap()
                            .to_string(),
                        MmdsDatastoreError::UnsupportedValueType.to_string()
                    ),
                    (false, OutputFormat::Json) => assert_eq!(
                        mmds.get_value("/phones/".to_string(), format).unwrap(),
                        r#"["+401234567","+441234567"]"#
                    ),
                }

                // Same array without the trailing '/'.
                match (imds_compat, format) {
                    (false, OutputFormat::Imds) | (true, _) => assert_eq!(
                        mmds.get_value("/phones".to_string(), format)
                            .err()
                            .unwrap()
                            .to_string(),
                        MmdsDatastoreError::UnsupportedValueType.to_string()
                    ),
                    (false, OutputFormat::Json) => assert_eq!(
                        mmds.get_value("/phones".to_string(), format).unwrap(),
                        r#"["+401234567","+441234567"]"#
                    ),
                }

                // Array element that is a string: raw text in IMDS, quoted in JSON.
                let expected = match (imds_compat, format) {
                    (false, OutputFormat::Imds) | (true, _) => "+401234567",
                    (false, OutputFormat::Json) => "\"+401234567\"",
                };
                assert_eq!(
                    mmds.get_value("/phones/0/".to_string(), format).unwrap(),
                    expected
                );

                // Boolean: unsupported in IMDS rendering.
                match (imds_compat, format) {
                    (false, OutputFormat::Imds) | (true, _) => assert_eq!(
                        mmds.get_value("/member".to_string(), format)
                            .err()
                            .unwrap()
                            .to_string(),
                        MmdsDatastoreError::UnsupportedValueType.to_string()
                    ),
                    (false, OutputFormat::Json) => assert_eq!(
                        mmds.get_value("/member".to_string(), format).unwrap(),
                        "false"
                    ),
                }

                // Float: unsupported in IMDS rendering.
                match (imds_compat, format) {
                    (false, OutputFormat::Imds) | (true, _) => assert_eq!(
                        mmds.get_value("/shares_percentage".to_string(), format)
                            .err()
                            .unwrap()
                            .to_string(),
                        MmdsDatastoreError::UnsupportedValueType.to_string()
                    ),
                    (false, OutputFormat::Json) => assert_eq!(
                        mmds.get_value("/shares_percentage".to_string(), format)
                            .unwrap(),
                        "12.12"
                    ),
                }

                // Negative integer: unsupported in IMDS rendering.
                match (imds_compat, format) {
                    (false, OutputFormat::Imds) | (true, _) => assert_eq!(
                        mmds.get_value("/balance".to_string(), format)
                            .err()
                            .unwrap()
                            .to_string(),
                        MmdsDatastoreError::UnsupportedValueType.to_string(),
                    ),
                    (false, OutputFormat::Json) => assert_eq!(
                        mmds.get_value("/balance".to_string(), format).unwrap(),
                        "-24"
                    ),
                }

                // String containing JSON text: returned verbatim in IMDS rendering,
                // re-escaped and quoted in JSON.
                let expected = match (imds_compat, format) {
                    (false, OutputFormat::Imds) | (true, _) => "{\n \"hello\": \"world\"\n}",
                    (false, OutputFormat::Json) => r#""{\n \"hello\": \"world\"\n}""#,
                };
                assert_eq!(
                    mmds.get_value("/json_string".to_string(), format).unwrap(),
                    expected
                );
            }
        }
    }

    /// Puts and patches, including the size limit applied to patches.
    #[test]
    fn test_update_data_store() {
        let mut mmds = Mmds::default();

        let data = r#"{
"name": {
"first": "John",
"second": "Doe"
},
"age": "43"
}"#;
        let data_store: Value = serde_json::from_str(data).unwrap();
        mmds.put_data(data_store).unwrap();

        let data = r#"{
"name": {
"first": "John",
"second": "Doe"
},
"age": "100"
}"#;
        let data_store: Value = serde_json::from_str(data).unwrap();
        mmds.patch_data(data_store).unwrap();

        let data = r#"{
"name": {
"first": "John",
"second": "Doe"
},
"age": 43
}"#;
        let data_store: Value = serde_json::from_str(data).unwrap();
        mmds.put_data(data_store).unwrap();

        // A `null` member removes `second`; `age` becomes the string "43".
        let data = r#"{
"name": {
"first": "John",
"second": null
},
"age": "43"
}"#;
        let data_store: Value = serde_json::from_str(data).unwrap();
        mmds.patch_data(data_store).unwrap();

        // The store now holds `age: "43"` and `name: {first: "John"}`: 36 bytes of
        // compact JSON. Adding `,"new_key":"` (12 bytes), 51151 `X`s and the closing
        // `"` reaches exactly 51200 bytes, the default limit, which is still accepted.
        let filling = (0..51151).map(|_| "X").collect::<String>();
        let data = "{\"new_key\": \"".to_string() + &filling + "\"}";

        let data_store: Value = serde_json::from_str(&data).unwrap();
        mmds.patch_data(data_store).unwrap();

        // Any further growth exceeds the limit, and the rejected patch is not applied.
        let data = "{\"new_key2\" : \"smth\"}";
        let data_store: Value = serde_json::from_str(data).unwrap();
        assert_eq!(
            mmds.patch_data(data_store).unwrap_err().to_string(),
            MmdsDatastoreError::DataStoreLimitExceeded.to_string()
        );
        assert!(!mmds.get_data_str().contains("smth"));

        // Shrinking `new_key` to "smth" frees space: 36 + 12 + 4 + 1 = 53 bytes.
        let data = "{\"new_key\" : \"smth\"}";
        let data_store: Value = serde_json::from_str(data).unwrap();
        mmds.patch_data(data_store).unwrap();
        assert!(mmds.get_data_str().contains("smth"));
        assert_eq!(mmds.get_data_str().len(), 53);

        // Adding `,"new_key2":"smth2"` (19 bytes) now fits: 53 + 19 = 72 bytes.
        let data = "{\"new_key2\" : \"smth2\"}";
        let data_store: Value = serde_json::from_str(data).unwrap();
        mmds.patch_data(data_store).unwrap();
        assert!(mmds.get_data_str().contains("smth2"));
        assert_eq!(mmds.get_data_str().len(), 72);
    }

    /// An oversized put is rejected and leaves the store uninitialized (still null,
    /// which `get_data_str` renders as the 2-byte "{}").
    #[test]
    fn test_put_size_limit() {
        let mut mmds = Mmds::default();
        // 51300 filler bytes alone already exceed the 51200-byte default limit.
        let filling = (0..51300).map(|_| "X").collect::<String>();
        let data = "{\"key\": \"".to_string() + &filling + "\"}";

        let data_store: Value = serde_json::from_str(&data).unwrap();

        assert_eq!(
            mmds.put_data(data_store).unwrap_err().to_string(),
            MmdsDatastoreError::DataStoreLimitExceeded.to_string()
        );

        assert_eq!(mmds.get_data_str().len(), 2);
    }
}