Line data Source code
1 : //! Security headers middleware for HTTP responses.
2 : //!
3 : //! This module provides middleware for adding security headers to HTTP responses,
4 : //! including CSP, HSTS, X-Frame-Options, and other protective headers.
5 :
6 : use std::sync::Arc;
7 :
8 : use axum::{
9 : extract::Request,
10 : http::header::{
11 : HeaderName, HeaderValue, CONTENT_SECURITY_POLICY, REFERRER_POLICY,
12 : STRICT_TRANSPORT_SECURITY, X_CONTENT_TYPE_OPTIONS, X_FRAME_OPTIONS, X_XSS_PROTECTION,
13 : },
14 : middleware::Next,
15 : response::Response,
16 : Extension,
17 : };
18 :
19 : use crate::config::SecurityHeadersConfig;
20 :
21 : /// Build security headers from configuration.
22 : ///
23 : /// Returns an `Arc`-wrapped vector of header name/value pairs that can be
24 : /// shared across requests via Axum's `Extension` layer.
25 : #[must_use]
26 22 : pub fn build_security_headers(
27 22 : config: &SecurityHeadersConfig,
28 22 : ) -> Arc<Vec<(HeaderName, HeaderValue)>> {
29 22 : let mut headers = Vec::new();
30 :
31 : // X-Content-Type-Options: nosniff (always)
32 22 : headers.push((X_CONTENT_TYPE_OPTIONS, HeaderValue::from_static("nosniff")));
33 :
34 : // X-Frame-Options
35 22 : if let Ok(value) = HeaderValue::from_str(&config.frame_options) {
36 22 : headers.push((X_FRAME_OPTIONS, value));
37 22 : }
38 :
39 : // X-XSS-Protection (legacy but still useful for older browsers)
40 22 : headers.push((X_XSS_PROTECTION, HeaderValue::from_static("1; mode=block")));
41 :
42 : // Content-Security-Policy
43 22 : if let Ok(value) = HeaderValue::from_str(&config.content_security_policy) {
44 22 : headers.push((CONTENT_SECURITY_POLICY, value));
45 22 : }
46 :
47 : // Referrer-Policy
48 22 : if let Ok(value) = HeaderValue::from_str(&config.referrer_policy) {
49 22 : headers.push((REFERRER_POLICY, value));
50 22 : }
51 :
52 : // HSTS (only if enabled - should only be used with HTTPS)
53 22 : if config.hsts_enabled {
54 1 : let hsts_value = if config.hsts_include_subdomains {
55 1 : format!("max-age={}; includeSubDomains", config.hsts_max_age)
56 : } else {
57 0 : format!("max-age={}", config.hsts_max_age)
58 : };
59 1 : if let Ok(value) = HeaderValue::from_str(&hsts_value) {
60 1 : headers.push((STRICT_TRANSPORT_SECURITY, value));
61 1 : }
62 21 : }
63 :
64 22 : Arc::new(headers)
65 22 : }
66 :
67 : /// Middleware to add security headers to all responses.
68 : ///
69 : /// This middleware reads the pre-built headers from an `Extension` and applies
70 : /// them to every response. It should be added as the outermost layer so headers
71 : /// are applied to all routes.
72 : ///
73 : /// # Example
74 : ///
75 : /// ```ignore
76 : /// use axum::{middleware, Router, Extension};
77 : /// use tinycongress_api::http::security::{build_security_headers, security_headers_middleware};
78 : /// use tinycongress_api::config::SecurityHeadersConfig;
79 : ///
80 : /// let config = SecurityHeadersConfig::default();
81 : /// let headers = build_security_headers(&config);
82 : ///
83 : /// let app = Router::new()
84 : /// // ... routes ...
85 : /// .layer(middleware::from_fn(security_headers_middleware))
86 : /// .layer(Extension(headers));
87 : /// ```
88 33 : pub async fn security_headers_middleware(
89 33 : Extension(headers): Extension<Arc<Vec<(HeaderName, HeaderValue)>>>,
90 33 : request: Request,
91 33 : next: Next,
92 33 : ) -> Response {
93 33 : let mut response = next.run(request).await;
94 33 : let response_headers = response.headers_mut();
95 165 : for (name, value) in headers.iter() {
96 165 : response_headers.insert(name.clone(), value.clone());
97 165 : }
98 33 : response
99 33 : }
100 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Find `name` in a built header list and return its value as text.
    ///
    /// Returns `None` when the header is absent or its value is not visible ASCII.
    fn header_str<'a>(
        headers: &'a [(HeaderName, HeaderValue)],
        name: &HeaderName,
    ) -> Option<&'a str> {
        headers
            .iter()
            .find(|(n, _)| n == name)
            .and_then(|(_, v)| v.to_str().ok())
    }

    #[test]
    fn test_build_security_headers_default() {
        let config = SecurityHeadersConfig::default();
        let headers = build_security_headers(&config);

        // Should have at least the mandatory headers
        for name in [
            X_CONTENT_TYPE_OPTIONS,
            X_FRAME_OPTIONS,
            X_XSS_PROTECTION,
            CONTENT_SECURITY_POLICY,
            REFERRER_POLICY,
        ] {
            assert!(
                headers.iter().any(|(n, _)| *n == name),
                "missing header {name:?}"
            );
        }
        assert_eq!(
            header_str(&headers, &X_CONTENT_TYPE_OPTIONS),
            Some("nosniff")
        );
    }

    #[test]
    fn test_build_security_headers_with_hsts() {
        let mut config = SecurityHeadersConfig::default();
        config.hsts_enabled = true;
        config.hsts_max_age = 31_536_000;
        config.hsts_include_subdomains = true;

        let headers = build_security_headers(&config);

        assert_eq!(
            header_str(&headers, &STRICT_TRANSPORT_SECURITY),
            Some("max-age=31536000; includeSubDomains")
        );
    }

    #[test]
    fn test_build_security_headers_hsts_without_subdomains() {
        // Exercises the branch that omits `includeSubDomains` (previously uncovered).
        let mut config = SecurityHeadersConfig::default();
        config.hsts_enabled = true;
        config.hsts_max_age = 86_400;
        config.hsts_include_subdomains = false;

        let headers = build_security_headers(&config);

        assert_eq!(
            header_str(&headers, &STRICT_TRANSPORT_SECURITY),
            Some("max-age=86400")
        );
    }

    #[test]
    fn test_build_security_headers_hsts_disabled() {
        let mut config = SecurityHeadersConfig::default();
        config.hsts_enabled = false;

        let headers = build_security_headers(&config);

        assert!(!headers
            .iter()
            .any(|(n, _)| n == STRICT_TRANSPORT_SECURITY));
    }

    #[test]
    fn test_build_security_headers_custom_frame_options() {
        let mut config = SecurityHeadersConfig::default();
        config.frame_options = "SAMEORIGIN".to_string();

        let headers = build_security_headers(&config);

        assert_eq!(header_str(&headers, &X_FRAME_OPTIONS), Some("SAMEORIGIN"));
    }

    #[test]
    fn test_build_security_headers_skips_invalid_values() {
        let mut config = SecurityHeadersConfig::default();
        // CR/LF is not a legal header-value byte; this must never reach the wire
        // (it would otherwise allow response-header injection).
        config.frame_options = "DENY\r\nX-Injected: yes".to_string();

        let headers = build_security_headers(&config);

        assert!(!headers.iter().any(|(n, _)| n == X_FRAME_OPTIONS));
        // The remaining headers are still produced.
        assert_eq!(
            header_str(&headers, &X_CONTENT_TYPE_OPTIONS),
            Some("nosniff")
        );
    }
}
|