tinycongress_api/http/
security.rs1use std::sync::Arc;
7
8use axum::{
9 extract::Request,
10 http::header::{
11 HeaderName, HeaderValue, CONTENT_SECURITY_POLICY, REFERRER_POLICY,
12 STRICT_TRANSPORT_SECURITY, X_CONTENT_TYPE_OPTIONS, X_FRAME_OPTIONS, X_XSS_PROTECTION,
13 },
14 middleware::Next,
15 response::Response,
16 Extension,
17};
18
19use crate::config::SecurityHeadersConfig;
20
21#[must_use]
26pub fn build_security_headers(
27 config: &SecurityHeadersConfig,
28) -> Arc<Vec<(HeaderName, HeaderValue)>> {
29 let mut headers = Vec::new();
30
31 headers.push((X_CONTENT_TYPE_OPTIONS, HeaderValue::from_static("nosniff")));
33
34 if let Ok(value) = HeaderValue::from_str(&config.frame_options) {
36 headers.push((X_FRAME_OPTIONS, value));
37 }
38
39 headers.push((X_XSS_PROTECTION, HeaderValue::from_static("1; mode=block")));
41
42 if let Ok(value) = HeaderValue::from_str(&config.content_security_policy) {
44 headers.push((CONTENT_SECURITY_POLICY, value));
45 }
46
47 if let Ok(value) = HeaderValue::from_str(&config.referrer_policy) {
49 headers.push((REFERRER_POLICY, value));
50 }
51
52 if config.hsts_enabled {
54 let hsts_value = if config.hsts_include_subdomains {
55 format!("max-age={}; includeSubDomains", config.hsts_max_age)
56 } else {
57 format!("max-age={}", config.hsts_max_age)
58 };
59 if let Ok(value) = HeaderValue::from_str(&hsts_value) {
60 headers.push((STRICT_TRANSPORT_SECURITY, value));
61 }
62 }
63
64 Arc::new(headers)
65}
66
67pub async fn security_headers_middleware(
89 Extension(headers): Extension<Arc<Vec<(HeaderName, HeaderValue)>>>,
90 request: Request,
91 next: Next,
92) -> Response {
93 let mut response = next.run(request).await;
94 let response_headers = response.headers_mut();
95 for (name, value) in headers.iter() {
96 response_headers.insert(name.clone(), value.clone());
97 }
98 response
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `SecurityHeadersConfig::default()` with `tweak` applied, so each test
    /// states only the fields it changes.
    fn config_with(tweak: impl FnOnce(&mut SecurityHeadersConfig)) -> SecurityHeadersConfig {
        let mut config = SecurityHeadersConfig::default();
        tweak(&mut config);
        config
    }

    /// Looks up `name` in a built header list and returns its value as text.
    fn header_value<'a>(
        headers: &'a [(HeaderName, HeaderValue)],
        name: &HeaderName,
    ) -> Option<&'a str> {
        headers
            .iter()
            .find(|(n, _)| n == name)
            .and_then(|(_, v)| v.to_str().ok())
    }

    #[test]
    fn test_build_security_headers_default() {
        let headers = build_security_headers(&SecurityHeadersConfig::default());

        for name in [
            X_CONTENT_TYPE_OPTIONS,
            X_FRAME_OPTIONS,
            X_XSS_PROTECTION,
            CONTENT_SECURITY_POLICY,
            REFERRER_POLICY,
        ] {
            assert!(
                headers.iter().any(|(n, _)| *n == name),
                "missing default header {name}"
            );
        }
        assert_eq!(
            header_value(&headers, &X_CONTENT_TYPE_OPTIONS),
            Some("nosniff")
        );
    }

    #[test]
    fn test_build_security_headers_with_hsts() {
        let config = config_with(|c| {
            c.hsts_enabled = true;
            c.hsts_max_age = 31_536_000;
            c.hsts_include_subdomains = true;
        });

        let headers = build_security_headers(&config);

        assert_eq!(
            header_value(&headers, &STRICT_TRANSPORT_SECURITY),
            Some("max-age=31536000; includeSubDomains")
        );
    }

    #[test]
    fn test_build_security_headers_hsts_without_subdomains() {
        let config = config_with(|c| {
            c.hsts_enabled = true;
            c.hsts_max_age = 600;
            c.hsts_include_subdomains = false;
        });

        let headers = build_security_headers(&config);

        assert_eq!(
            header_value(&headers, &STRICT_TRANSPORT_SECURITY),
            Some("max-age=600")
        );
    }

    #[test]
    fn test_build_security_headers_hsts_disabled() {
        let config = config_with(|c| c.hsts_enabled = false);

        let headers = build_security_headers(&config);

        assert_eq!(header_value(&headers, &STRICT_TRANSPORT_SECURITY), None);
    }

    #[test]
    fn test_build_security_headers_custom_frame_options() {
        let config = config_with(|c| c.frame_options = "SAMEORIGIN".to_string());

        let headers = build_security_headers(&config);

        assert_eq!(header_value(&headers, &X_FRAME_OPTIONS), Some("SAMEORIGIN"));
    }

    #[test]
    fn test_build_security_headers_skips_invalid_value() {
        // CR/LF are illegal in header values. The header must be dropped rather than
        // emitted, and the remaining headers must still be produced.
        let config = config_with(|c| c.frame_options = "DENY\r\nX-Injected: 1".to_string());

        let headers = build_security_headers(&config);

        assert_eq!(header_value(&headers, &X_FRAME_OPTIONS), None);
        assert!(headers.iter().any(|(n, _)| *n == X_CONTENT_TYPE_OPTIONS));
    }
}