vmm/logger/
logging.rs

1// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::fmt::Debug;
5use std::io::Write;
6use std::path::{Path, PathBuf};
7use std::str::FromStr;
8use std::sync::{Mutex, OnceLock};
9use std::thread;
10
11use log::{Log, Metadata, Record};
12use serde::{Deserialize, Deserializer, Serialize};
13use utils::time::LocalTime;
14
15use super::metrics::{IncMetric, METRICS};
16use crate::utils::open_file_write_nonblock;
17
/// Default level filter for logger matching the swagger specification
/// (`src/firecracker/swagger/firecracker.yaml`).
///
/// Installed by [`Logger::init`], and by [`Logger::update`] whenever the supplied
/// configuration carries no level.
pub const DEFAULT_LEVEL: log::LevelFilter = log::LevelFilter::Info;
/// Default instance id, printed in log lines while [`INSTANCE_ID`] is unset.
pub const DEFAULT_INSTANCE_ID: &str = "anonymous-instance";
/// Instance id printed in every log line.
///
/// This file only reads the value; it is presumably set once by the embedding
/// binary during start-up (TODO confirm against the caller).
pub static INSTANCE_ID: OnceLock<String> = OnceLock::new();
25
/// The logger.
///
/// Default values matching the swagger specification (`src/firecracker/swagger/firecracker.yaml`).
///
/// The initial state has no log file, no stdout handle and no module filter, and both the
/// level and the origin decorations are switched off. [`Logger::init`] and [`Logger::update`]
/// populate the outputs later on.
pub static LOGGER: Logger = Logger(Mutex::new(LoggerConfiguration {
    target: None,
    stdout: None,
    filter: LogFilter { module: None },
    format: LogFormat {
        show_level: false,
        show_log_origin: false,
    },
}));
38
/// Error type for [`Logger::init`].
///
/// Alias of the error produced by `log::set_logger`.
pub type LoggerInitError = log::SetLoggerError;

/// Error type for [`Logger::update`].
///
/// Wraps the I/O error raised while opening the requested log file.
#[derive(Debug, thiserror::Error)]
#[error("Failed to open target file: {0}")]
pub struct LoggerUpdateError(pub std::io::Error);
46
47impl Logger {
48    /// Initialize the logger.
49    pub fn init(&'static self) -> Result<(), LoggerInitError> {
50        log::set_logger(self)?;
51        log::set_max_level(DEFAULT_LEVEL);
52        let mut guard = self.0.lock().unwrap();
53        if guard.target.is_none() && guard.stdout.is_none() {
54            guard.stdout = open_stdout_nonblock();
55        }
56        Ok(())
57    }
58
59    /// Applies the given logger configuration the logger.
60    pub fn update(&self, config: LoggerConfig) -> Result<(), LoggerUpdateError> {
61        let mut guard = self.0.lock().unwrap();
62        log::set_max_level(
63            config
64                .level
65                .map(log::LevelFilter::from)
66                .unwrap_or(DEFAULT_LEVEL),
67        );
68
69        if let Some(log_path) = config.log_path {
70            let file = open_file_write_nonblock(&log_path).map_err(LoggerUpdateError)?;
71
72            guard.target = Some(file);
73        } else if guard.target.is_none() && guard.stdout.is_none() {
74            guard.stdout = open_stdout_nonblock();
75        };
76
77        if let Some(show_level) = config.show_level {
78            guard.format.show_level = show_level;
79        }
80
81        if let Some(show_log_origin) = config.show_log_origin {
82            guard.format.show_log_origin = show_log_origin;
83        }
84
85        if let Some(module) = config.module {
86            guard.filter.module = Some(module);
87        }
88
89        // Ensure we drop the guard before attempting to log, otherwise this
90        // would deadlock.
91        drop(guard);
92
93        Ok(())
94    }
95}
96
/// Filter deciding which records are written.
#[derive(Debug)]
pub struct LogFilter {
    /// When set, only records whose module path starts with this string are written; records
    /// without a module path are dropped. `None` lets every record through.
    pub module: Option<String>,
}
/// Optional decorations added to the prefix of each log line.
#[derive(Debug)]
pub struct LogFormat {
    /// Whether `:<LEVEL>` is appended to the line prefix.
    pub show_level: bool,
    /// Whether `:<file>:<line>` is appended to the line prefix (`?` stands in for unknown parts).
    pub show_log_origin: bool,
}
/// Mutable state of the logger, guarded by the mutex inside [`Logger`].
#[derive(Debug)]
pub struct LoggerConfiguration {
    /// Log file; takes precedence over `stdout` when set.
    pub target: Option<std::fs::File>,
    /// Handle obtained by opening `/dev/stdout`, used when `target` is `None`.
    pub stdout: Option<std::fs::File>,
    /// Record filter.
    pub filter: LogFilter,
    /// Line format options.
    pub format: LogFormat,
}
/// Logger implementing [`log::Log`] on top of a mutex-protected [`LoggerConfiguration`].
#[derive(Debug)]
pub struct Logger(pub Mutex<LoggerConfiguration>);
115
116impl Log for Logger {
117    // No additional filters to <https://docs.rs/log/latest/log/fn.max_level.html>.
118    fn enabled(&self, _metadata: &Metadata) -> bool {
119        true
120    }
121
122    fn log(&self, record: &Record) {
123        // Lock the logger.
124        let mut guard = self.0.lock().unwrap();
125
126        // Check if the log message is enabled
127        {
128            let enabled_module = match (&guard.filter.module, record.module_path()) {
129                (Some(filter), Some(source)) => source.starts_with(filter),
130                (Some(_), None) => false,
131                (None, _) => true,
132            };
133            let enabled = enabled_module;
134            if !enabled {
135                return;
136            }
137        }
138
139        // Prints log message
140        {
141            let thread = thread::current().name().unwrap_or("-").to_string();
142            let level = match guard.format.show_level {
143                true => format!(":{}", record.level()),
144                false => String::new(),
145            };
146
147            let origin = match guard.format.show_log_origin {
148                true => {
149                    let file = record.file().unwrap_or("?");
150                    let line = match record.line() {
151                        Some(x) => x.to_string(),
152                        None => String::from("?"),
153                    };
154                    format!(":{file}:{line}")
155                }
156                false => String::new(),
157            };
158
159            let message = format!(
160                "{} [{}:{thread}{level}{origin}] {}\n",
161                LocalTime::now(),
162                INSTANCE_ID
163                    .get()
164                    .map(|s| s.as_str())
165                    .unwrap_or(DEFAULT_INSTANCE_ID),
166                record.args()
167            );
168
169            let result = if let Some(file) = &mut guard.target {
170                file.write_all(message.as_bytes())
171            } else if let Some(file) = &mut guard.stdout {
172                file.write_all(message.as_bytes())
173            } else {
174                std::io::stdout().write_all(message.as_bytes())
175            };
176
177            // If the write returns an error, increment missed log count.
178            // No reason to log the error to stderr here, just increment the metric.
179            if result.is_err() {
180                METRICS.logger.missed_log_count.inc();
181            }
182        }
183    }
184
185    fn flush(&self) {}
186}
187
188fn open_stdout_nonblock() -> Option<std::fs::File> {
189    open_file_write_nonblock(Path::new("/dev/stdout")).ok()
190}
191
/// Strongly typed structure used to describe the logger.
///
/// Consumed by [`Logger::update`]. Unknown fields are rejected during deserialization.
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct LoggerConfig {
    /// Named pipe or file used as output for logs. `None` keeps the current output.
    pub log_path: Option<PathBuf>,
    /// The level of the Logger. `None` selects [`DEFAULT_LEVEL`].
    pub level: Option<LevelFilter>,
    /// Whether to show the log level in the log. `None` keeps the current setting.
    pub show_level: Option<bool>,
    /// Whether to show the log origin in the log. `None` keeps the current setting.
    pub show_log_origin: Option<bool>,
    /// The module to filter logs by. `None` keeps the current filter.
    pub module: Option<String>,
}
207
/// This is required since we originally supported `Warning` and uppercase variants being used as
/// the log level filter. It would be a breaking change to no longer support this. In the next
/// breaking release this should be removed (replaced with `log::LevelFilter` and only supporting
/// its default deserialization).
///
/// Both the [`FromStr`] and the `Deserialize` implementations accept the level names in any
/// letter case, plus `warning` as an alias for `warn`. `Serialize` is derived.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize)]
pub enum LevelFilter {
    /// [`log::LevelFilter::Off`]
    Off,
    /// [`log::LevelFilter::Trace`]
    Trace,
    /// [`log::LevelFilter::Debug`]
    Debug,
    /// [`log::LevelFilter::Info`]
    Info,
    /// [`log::LevelFilter::Warn`]
    Warn,
    /// [`log::LevelFilter::Error`]
    Error,
}
227impl From<LevelFilter> for log::LevelFilter {
228    fn from(filter: LevelFilter) -> log::LevelFilter {
229        match filter {
230            LevelFilter::Off => log::LevelFilter::Off,
231            LevelFilter::Trace => log::LevelFilter::Trace,
232            LevelFilter::Debug => log::LevelFilter::Debug,
233            LevelFilter::Info => log::LevelFilter::Info,
234            LevelFilter::Warn => log::LevelFilter::Warn,
235            LevelFilter::Error => log::LevelFilter::Error,
236        }
237    }
238}
239impl<'de> Deserialize<'de> for LevelFilter {
240    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
241    where
242        D: Deserializer<'de>,
243    {
244        use serde::de::Error;
245        let key = String::deserialize(deserializer)?;
246
247        match key.to_lowercase().as_str() {
248            "off" => Ok(LevelFilter::Off),
249            "trace" => Ok(LevelFilter::Trace),
250            "debug" => Ok(LevelFilter::Debug),
251            "info" => Ok(LevelFilter::Info),
252            "warn" | "warning" => Ok(LevelFilter::Warn),
253            "error" => Ok(LevelFilter::Error),
254            _ => Err(D::Error::custom("Invalid LevelFilter")),
255        }
256    }
257}
258
/// Error type for [`<LevelFilter as FromStr>::from_str`].
///
/// Carries the rejected input string exactly as it was passed in.
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
#[error("Failed to parse string to level filter: {0}")]
pub struct LevelFilterFromStrError(String);
263
264impl FromStr for LevelFilter {
265    type Err = LevelFilterFromStrError;
266    fn from_str(s: &str) -> Result<Self, Self::Err> {
267        match s.to_ascii_lowercase().as_str() {
268            "off" => Ok(Self::Off),
269            "trace" => Ok(Self::Trace),
270            "debug" => Ok(Self::Debug),
271            "info" => Ok(Self::Info),
272            "warn" | "warning" => Ok(Self::Warn),
273            "error" => Ok(Self::Error),
274            _ => Err(LevelFilterFromStrError(String::from(s))),
275        }
276    }
277}
278
#[cfg(test)]
mod tests {
    use log::Level;

    use super::*;

    /// Every API-level filter converts to the `log` crate filter of the same name.
    #[test]
    fn levelfilter_from_levelfilter() {
        assert_eq!(
            log::LevelFilter::from(LevelFilter::Off),
            log::LevelFilter::Off
        );
        assert_eq!(
            log::LevelFilter::from(LevelFilter::Trace),
            log::LevelFilter::Trace
        );
        assert_eq!(
            log::LevelFilter::from(LevelFilter::Debug),
            log::LevelFilter::Debug
        );
        assert_eq!(
            log::LevelFilter::from(LevelFilter::Info),
            log::LevelFilter::Info
        );
        assert_eq!(
            log::LevelFilter::from(LevelFilter::Warn),
            log::LevelFilter::Warn
        );
        assert_eq!(
            log::LevelFilter::from(LevelFilter::Error),
            log::LevelFilter::Error
        );
    }

    /// Every upper/lower-case spelling of every level name parses through both `FromStr` and
    /// `Deserialize`, and unknown names are rejected by both.
    #[test]
    fn levelfilter_from_str_all_variants() {
        use itertools::Itertools;

        #[derive(Deserialize)]
        struct Foo {
            #[allow(dead_code)]
            level: LevelFilter,
        }

        for (level, level_enum) in [
            ("off", LevelFilter::Off),
            ("trace", LevelFilter::Trace),
            ("debug", LevelFilter::Debug),
            ("info", LevelFilter::Info),
            ("warn", LevelFilter::Warn),
            ("warning", LevelFilter::Warn),
            ("error", LevelFilter::Error),
        ] {
            // One 0/1 choice per character: 0 keeps it lower case, 1 makes it upper case.
            let multi = level.chars().map(|_| 0..=1).multi_cartesian_product();
            for combination in multi {
                let variant = level
                    .chars()
                    .zip_eq(combination)
                    .map(|(c, v)| match v {
                        0 => c.to_ascii_lowercase(),
                        1 => c.to_ascii_uppercase(),
                        _ => unreachable!(),
                    })
                    .collect::<String>();

                let ex = format!("{{ \"level\": \"{}\" }}", variant);
                assert_eq!(LevelFilter::from_str(&variant), Ok(level_enum));
                assert!(serde_json::from_str::<Foo>(&ex).is_ok(), "{ex}");
            }
        }

        // This is a plain literal, not a `format!` string, so the braces must not be doubled:
        // with `{{ … }}` the input is malformed JSON and the assertion would pass without ever
        // reaching the level deserializer.
        let ex = r#"{ "level": "blah" }"#;
        assert!(
            serde_json::from_str::<Foo>(ex).is_err(),
            "expected error got {ex:#?}"
        );
        assert_eq!(
            LevelFilter::from_str("bad"),
            Err(LevelFilterFromStrError(String::from("bad")))
        );
    }

    /// A record passing the module filter is written to the target file with the level and
    /// origin decorations enabled.
    #[test]
    fn logger() {
        // Get temp file path.
        let file = vmm_sys_util::tempfile::TempFile::new().unwrap();
        let path = file.as_path().to_str().unwrap().to_string();
        drop(file);

        // Create temp file.
        let target = std::fs::OpenOptions::new()
            .create(true)
            .write(true)
            .truncate(true)
            .open(&path)
            .unwrap();

        // Create logger.
        let logger = Logger(Mutex::new(LoggerConfiguration {
            target: Some(target),
            // The file target is what is under test; no stdout handle is needed.
            stdout: None,
            filter: LogFilter {
                module: Some(String::from("module")),
            },
            format: LogFormat {
                show_level: true,
                show_log_origin: true,
            },
        }));

        // Assert results of enabled given specific metadata.
        assert!(logger.enabled(&Metadata::builder().level(Level::Warn).build()));
        assert!(logger.enabled(&Metadata::builder().level(Level::Debug).build()));

        // Log
        let metadata = Metadata::builder().level(Level::Error).build();
        let record = Record::builder()
            .args(format_args!("Error!"))
            .metadata(metadata)
            .file(Some("dir/app.rs"))
            .line(Some(200))
            .module_path(Some("module::server"))
            .build();
        logger.log(&record);

        // Test calling flush.
        logger.flush();

        // Asserts result of log.
        let contents = std::fs::read_to_string(&path).unwrap();
        let (_time, rest) = contents.split_once(' ').unwrap();
        let thread = thread::current().name().unwrap_or("-").to_string();
        assert_eq!(
            rest,
            format!("[{DEFAULT_INSTANCE_ID}:{thread}:ERROR:dir/app.rs:200] Error!\n")
        );

        std::fs::remove_file(path).unwrap();
    }
}