1use std::collections::HashMap;
5use std::io::Read;
6use std::sync::Arc;
7
8use bincode::config;
9use bincode::config::{Configuration, Fixint, Limit, LittleEndian};
10
11const DESERIALIZATION_BYTES_LIMIT: usize = 100_000;
16
17const BINCODE_CONFIG: Configuration<LittleEndian, Fixint, Limit<DESERIALIZATION_BYTES_LIMIT>> =
18 config::standard()
19 .with_fixed_int_encoding()
20 .with_limit::<DESERIALIZATION_BYTES_LIMIT>()
21 .with_little_endian();
22
/// A single BPF instruction stored as a `u64`.
///
/// NOTE(review): presumably the 8-byte packed form of the kernel's
/// `struct sock_filter` (code/jt/jf/k); confirm against the filter compiler.
pub type BpfInstruction = u64;

/// An owned BPF program: a sequence of [`BpfInstruction`]s.
pub type BpfProgram = Vec<BpfInstruction>;

/// A borrowed view of a BPF program, as accepted by `apply_filter`.
pub type BpfProgramRef<'a> = &'a [BpfInstruction];

/// Map from thread category name (e.g. `"vmm"`, `"api"`, `"vcpu"`) to the
/// shared BPF program for that category.
pub type BpfThreadMap = HashMap<String, Arc<BpfProgram>>;

/// Error returned when a serialized filter map cannot be decoded.
pub type DeserializationError = bincode::error::DecodeError;
39
40pub fn get_empty_filters() -> BpfThreadMap {
42 let mut map = BpfThreadMap::new();
43 map.insert("vmm".to_string(), Arc::new(vec![]));
44 map.insert("api".to_string(), Arc::new(vec![]));
45 map.insert("vcpu".to_string(), Arc::new(vec![]));
46 map
47}
48
49pub fn deserialize_binary<R: Read>(mut reader: R) -> Result<BpfThreadMap, DeserializationError> {
51 let result: HashMap<String, _> = bincode::decode_from_std_read(&mut reader, BINCODE_CONFIG)?;
52
53 Ok(result
54 .into_iter()
55 .map(|(k, v)| (k.to_lowercase(), Arc::new(v)))
56 .collect())
57}
58
59#[derive(Debug, thiserror::Error, displaydoc::Display)]
61pub enum InstallationError {
62 FilterTooLarge,
64 Prctl(std::io::Error),
66}
67
/// Maximum number of instructions `apply_filter` accepts; longer programs are
/// rejected with `InstallationError::FilterTooLarge`.
pub const BPF_MAX_LEN: usize = 4096;

/// C-layout (`#[repr(C)]`) program descriptor whose address `apply_filter`
/// passes to `seccomp(SECCOMP_SET_MODE_FILTER, ..)`.
///
/// NOTE(review): presumably mirrors the kernel's `struct sock_fprog`; field
/// order and types must not change.
#[repr(C)]
#[derive(Debug)]
struct SockFprog {
    /// Number of instructions that `filter` points to.
    len: u16,
    /// Pointer to the first instruction of the program.
    filter: *const BpfInstruction,
}
79
80pub fn apply_filter(bpf_filter: BpfProgramRef) -> Result<(), InstallationError> {
82 if bpf_filter.is_empty() {
84 return Ok(());
85 }
86
87 if BPF_MAX_LEN < bpf_filter.len() {
90 return Err(InstallationError::FilterTooLarge);
91 }
92
93 let bpf_filter_len =
94 u16::try_from(bpf_filter.len()).map_err(|_| InstallationError::FilterTooLarge)?;
95
96 unsafe {
98 {
99 let rc = libc::prctl(libc::PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0);
100 if rc != 0 {
101 return Err(InstallationError::Prctl(std::io::Error::last_os_error()));
102 }
103 }
104
105 let bpf_prog = SockFprog {
106 len: bpf_filter_len,
107 filter: bpf_filter.as_ptr(),
108 };
109 let bpf_prog_ptr = &bpf_prog as *const SockFprog;
110 {
111 let rc = libc::syscall(
112 libc::SYS_seccomp,
113 libc::SECCOMP_SET_MODE_FILTER,
114 0,
115 bpf_prog_ptr,
116 );
117 if rc != 0 {
118 return Err(InstallationError::Prctl(std::io::Error::last_os_error()));
119 }
120 }
121 }
122
123 Ok(())
124}
125
#[cfg(test)]
mod tests {
    // The `unsafe` blocks below only call `libc::prctl(PR_GET_SECCOMP)` to read
    // the current thread's seccomp mode.
    #![allow(clippy::undocumented_unsafe_blocks)]

    use std::collections::HashMap;
    use std::sync::Arc;
    use std::thread;

    use super::*;

    #[test]
    fn test_deserialize_binary() {
        // Arbitrary bytes that are not a valid encoding must be rejected.
        let data = "adassafvc".to_string();
        deserialize_binary(data.as_bytes()).unwrap_err();

        // A valid map decodes successfully, with the category name lowercased.
        let bpf_prog = vec![0; 2];
        let mut filter_map: HashMap<String, BpfProgram> = HashMap::new();
        filter_map.insert("VcpU".to_string(), bpf_prog.clone());
        let bytes = bincode::serde::encode_to_vec(&filter_map, BINCODE_CONFIG).unwrap();

        let mut expected_res = BpfThreadMap::new();
        expected_res.insert("vcpu".to_string(), Arc::new(bpf_prog));
        assert_eq!(deserialize_binary(&bytes[..]).unwrap(), expected_res);

        // A program of DESERIALIZATION_BYTES_LIMIT + 1 instructions encodes to more
        // bytes than the decode limit allows, so decoding must hit LimitExceeded.
        let bpf_prog = vec![0; DESERIALIZATION_BYTES_LIMIT + 1];
        let mut filter_map: HashMap<String, BpfProgram> = HashMap::new();
        filter_map.insert("VcpU".to_string(), bpf_prog.clone());
        let bytes = bincode::serde::encode_to_vec(&filter_map, BINCODE_CONFIG).unwrap();
        assert!(matches!(
            deserialize_binary(&bytes[..]).unwrap_err(),
            bincode::error::DecodeError::LimitExceeded
        ));
    }

    #[test]
    fn test_filter_apply() {
        // Each case runs on its own spawned thread. Presumably this keeps any
        // no_new_privs/seccomp state a case sets away from the test harness thread
        // and from the other cases.
        thread::spawn(|| {
            // A program longer than BPF_MAX_LEN is rejected before any syscall.
            let filter: BpfProgram = vec![0; 5000];

            assert!(matches!(
                apply_filter(&filter).unwrap_err(),
                InstallationError::FilterTooLarge
            ));
        })
        .join()
        .unwrap();

        thread::spawn(|| {
            // An empty program is a no-op: the seccomp mode reads 0 before and after.
            let filter: BpfProgram = vec![];

            assert_eq!(filter.len(), 0);

            let seccomp_level = unsafe { libc::prctl(libc::PR_GET_SECCOMP) };
            assert_eq!(seccomp_level, 0);

            apply_filter(&filter).unwrap();

            let seccomp_level = unsafe { libc::prctl(libc::PR_GET_SECCOMP) };
            assert_eq!(seccomp_level, 0);
        })
        .join()
        .unwrap();

        thread::spawn(|| {
            // A one-instruction program of 0xFF, presumably rejected by the kernel as
            // invalid BPF. The failure surfaces as InstallationError::Prctl and no filter
            // ends up installed (the seccomp mode still reads 0).
            let filter = vec![0xFF; 1];

            let seccomp_level = unsafe { libc::prctl(libc::PR_GET_SECCOMP) };
            assert_eq!(seccomp_level, 0);

            assert!(matches!(
                apply_filter(&filter).unwrap_err(),
                InstallationError::Prctl(_)
            ));

            let seccomp_level = unsafe { libc::prctl(libc::PR_GET_SECCOMP) };
            assert_eq!(seccomp_level, 0);
        })
        .join()
        .unwrap();
    }
}