tinycongress_api/
config.rs

1use figment::{
2    providers::{Env, Format, Serialized, Yaml},
3    Figment,
4};
5use serde::{Deserialize, Deserializer, Serialize};
6use serde_aux::prelude::deserialize_vec_from_string_or_vec;
7
8/// Application configuration loaded from multiple sources.
9///
10/// Configuration is loaded in priority order (lowest to highest):
11/// 1. Struct defaults
12/// 2. config.yaml file (if exists)
13/// 3. Environment variables with TC_ prefix (always wins)
14#[derive(Debug, Clone, Deserialize, Serialize)]
15pub struct Config {
16    pub database: DatabaseConfig,
17    pub server: ServerConfig,
18    pub logging: LoggingConfig,
19    #[serde(default)]
20    pub cors: CorsConfig,
21    #[serde(default)]
22    pub security_headers: SecurityHeadersConfig,
23    #[serde(default)]
24    pub graphql: GraphQLConfig,
25    #[serde(default)]
26    pub swagger: SwaggerConfig,
27}
28
29#[derive(Debug, Clone, Deserialize, Serialize)]
30pub struct DatabaseConfig {
31    /// `PostgreSQL` connection URL (required).
32    /// Example: `postgres://user:pass@host:5432/dbname`
33    pub url: String,
34
35    /// Maximum number of connections in the pool.
36    #[serde(default = "default_max_connections")]
37    pub max_connections: u32,
38
39    /// Optional custom migrations directory path.
40    pub migrations_dir: Option<String>,
41}
42
43#[derive(Debug, Clone, Deserialize, Serialize)]
44pub struct ServerConfig {
45    /// HTTP server port.
46    #[serde(default = "default_port")]
47    pub port: u16,
48
49    /// HTTP server bind address.
50    #[serde(default = "default_host")]
51    pub host: String,
52}
53
54#[derive(Debug, Clone, Deserialize, Serialize)]
55pub struct LoggingConfig {
56    /// Log level filter (debug, info, warn, error).
57    #[serde(default = "default_log_level")]
58    pub level: String,
59}
60
61#[derive(Debug, Clone, Deserialize, Serialize)]
62pub struct CorsConfig {
63    /// Allowed origins for CORS requests.
64    /// Use `"*"` to allow any origin (not recommended for production).
65    /// Accepts either an array or comma-separated string.
66    /// Example: `["http://localhost:5173"]` or `"http://localhost:5173,https://app.example.com"`
67    #[serde(
68        default = "default_allowed_origins",
69        deserialize_with = "deserialize_origins"
70    )]
71    pub allowed_origins: Vec<String>,
72}
73
74/// Deserialize origins from comma-separated string or array, filtering empty values.
75fn deserialize_origins<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
76where
77    D: Deserializer<'de>,
78{
79    let origins: Vec<String> = deserialize_vec_from_string_or_vec(deserializer)?;
80    Ok(origins.into_iter().filter(|s| !s.is_empty()).collect())
81}
82
// Default-value providers referenced by `#[serde(default = "...")]` attributes
// and by the `Default` impls in this file. serde only needs something callable
// as `fn() -> T`, which a `const fn` is, so every provider that can be
// evaluated at compile time is declared `const`.

/// Default pool size for [`DatabaseConfig::max_connections`].
const fn default_max_connections() -> u32 {
    10
}

/// Default listening port for [`ServerConfig::port`].
const fn default_port() -> u16 {
    8080
}

/// Default bind address for [`ServerConfig::host`].
fn default_host() -> String {
    "0.0.0.0".to_string()
}

/// Default log level filter for [`LoggingConfig::level`].
fn default_log_level() -> String {
    "info".to_string()
}

/// Default for [`CorsConfig::allowed_origins`]: an empty list, so no
/// cross-origin requests are allowed until origins are configured explicitly
/// via `TC_CORS__ALLOWED_ORIGINS` or `config.yaml`.
const fn default_allowed_origins() -> Vec<String> {
    Vec::new()
}
108
109impl Default for CorsConfig {
110    fn default() -> Self {
111        Self {
112            allowed_origins: default_allowed_origins(),
113        }
114    }
115}
116
117#[derive(Debug, Clone, Deserialize, Serialize)]
118pub struct SecurityHeadersConfig {
119    /// Enable security headers (default: true).
120    #[serde(default = "default_true")]
121    pub enabled: bool,
122
123    /// Enable HSTS header (default: false, enable in production with HTTPS).
124    #[serde(default)]
125    pub hsts_enabled: bool,
126
127    /// HSTS max-age in seconds (default: 31536000 = 1 year).
128    #[serde(default = "default_hsts_max_age")]
129    pub hsts_max_age: u64,
130
131    /// Include subdomains in HSTS (default: true).
132    #[serde(default = "default_true")]
133    pub hsts_include_subdomains: bool,
134
135    /// X-Frame-Options value: "DENY" or "SAMEORIGIN" (default: "DENY").
136    #[serde(default = "default_frame_options")]
137    pub frame_options: String,
138
139    /// Content-Security-Policy header value (default: "default-src 'self'").
140    #[serde(default = "default_csp")]
141    pub content_security_policy: String,
142
143    /// Referrer-Policy header value (default: "strict-origin-when-cross-origin").
144    #[serde(default = "default_referrer_policy")]
145    pub referrer_policy: String,
146}
147
// serde's `default = "..."` only requires a function callable as `fn() -> T`;
// a `const fn` satisfies that, so the compile-time-evaluable providers below
// are `const` and need no `missing_const_for_fn` suppression.

/// Default for boolean settings that are on unless explicitly disabled.
const fn default_true() -> bool {
    true
}

/// Default HSTS `max-age`, in seconds.
const fn default_hsts_max_age() -> u64 {
    31_536_000 // 365 days * 86_400 s = 1 year
}

/// Default `X-Frame-Options` value.
fn default_frame_options() -> String {
    "DENY".to_string()
}

/// Default `Content-Security-Policy` value.
fn default_csp() -> String {
    "default-src 'self'".to_string()
}

/// Default `Referrer-Policy` value.
fn default_referrer_policy() -> String {
    "strict-origin-when-cross-origin".to_string()
}
169
170impl Default for SecurityHeadersConfig {
171    fn default() -> Self {
172        Self {
173            enabled: default_true(),
174            hsts_enabled: false,
175            hsts_max_age: default_hsts_max_age(),
176            hsts_include_subdomains: default_true(),
177            frame_options: default_frame_options(),
178            content_security_policy: default_csp(),
179            referrer_policy: default_referrer_policy(),
180        }
181    }
182}
183
184#[derive(Debug, Clone, Default, Deserialize, Serialize)]
185pub struct GraphQLConfig {
186    /// Enable GraphQL Playground UI at /graphql (GET).
187    /// Default: false (disabled for security - exposes schema to potential attackers).
188    /// Enable in development via `TC_GRAPHQL__PLAYGROUND_ENABLED=true`
189    #[serde(default)]
190    pub playground_enabled: bool,
191}
192
193#[derive(Debug, Clone, Default, Deserialize, Serialize)]
194pub struct SwaggerConfig {
195    /// Enable Swagger UI at /swagger-ui.
196    /// Default: false (disabled for security - exposes API documentation).
197    /// Enable in development via `TC_SWAGGER__ENABLED=true`
198    #[serde(default)]
199    pub enabled: bool,
200}
201
202impl Default for Config {
203    fn default() -> Self {
204        Self {
205            database: DatabaseConfig {
206                url: String::new(), // Will fail validation if not provided
207                max_connections: default_max_connections(),
208                migrations_dir: None,
209            },
210            server: ServerConfig {
211                port: default_port(),
212                host: default_host(),
213            },
214            logging: LoggingConfig {
215                level: default_log_level(),
216            },
217            cors: CorsConfig::default(),
218            security_headers: SecurityHeadersConfig::default(),
219            graphql: GraphQLConfig::default(),
220            swagger: SwaggerConfig::default(),
221        }
222    }
223}
224
225/// Configuration loading and validation errors.
226#[derive(Debug, thiserror::Error)]
227pub enum ConfigError {
228    #[error("Configuration error: {0}")]
229    Figment(#[from] Box<figment::Error>),
230
231    #[error("Validation error: {0}")]
232    Validation(String),
233}
234
235impl From<figment::Error> for ConfigError {
236    fn from(err: figment::Error) -> Self {
237        Self::Figment(Box::new(err))
238    }
239}
240
241impl Config {
242    /// Load configuration from all sources.
243    ///
244    /// Sources are merged in priority order:
245    /// 1. Struct defaults (lowest)
246    /// 2. config.yaml file (if exists)
247    /// 3. Environment variables with TC_ prefix (highest)
248    ///
249    /// # Errors
250    /// Returns an error if configuration cannot be loaded or is invalid.
251    pub fn load() -> Result<Self, ConfigError> {
252        let config: Self = Figment::new()
253            .merge(Serialized::defaults(Self::default()))
254            .merge(Yaml::file("config.yaml"))
255            .merge(Env::prefixed("TC_").split("__"))
256            .extract()?;
257
258        config.validate()?;
259        Ok(config)
260    }
261
262    /// Load configuration with a custom YAML file path.
263    ///
264    /// # Errors
265    /// Returns an error if configuration cannot be loaded or is invalid.
266    pub fn load_from(yaml_path: &str) -> Result<Self, ConfigError> {
267        let config: Self = Figment::new()
268            .merge(Serialized::defaults(Self::default()))
269            .merge(Yaml::file(yaml_path))
270            .merge(Env::prefixed("TC_").split("__"))
271            .extract()?;
272
273        config.validate()?;
274        Ok(config)
275    }
276
277    /// Validate configuration values.
278    ///
279    /// # Errors
280    /// Returns an error if any configuration value is invalid.
281    pub fn validate(&self) -> Result<(), ConfigError> {
282        // Database URL is required and must be a postgres URL
283        if self.database.url.is_empty() {
284            return Err(ConfigError::Validation(
285                "database.url is required. Set TC_DATABASE__URL environment variable.".into(),
286            ));
287        }
288
289        if !self.database.url.starts_with("postgres://")
290            && !self.database.url.starts_with("postgresql://")
291        {
292            return Err(ConfigError::Validation(format!(
293                "database.url must start with postgres:// or postgresql://, got: {}",
294                &self.database.url[..self.database.url.len().min(20)]
295            )));
296        }
297
298        // Port must be non-zero
299        if self.server.port == 0 {
300            return Err(ConfigError::Validation("server.port cannot be 0".into()));
301        }
302
303        // Max connections must be at least 1
304        if self.database.max_connections == 0 {
305            return Err(ConfigError::Validation(
306                "database.max_connections cannot be 0".into(),
307            ));
308        }
309
310        // CORS origins must be valid URLs or "*"
311        for origin in &self.cors.allowed_origins {
312            if origin != "*" && !origin.starts_with("http://") && !origin.starts_with("https://") {
313                return Err(ConfigError::Validation(format!(
314                    "cors.allowed_origins contains invalid origin '{origin}'. Must be '*' or start with http:// or https://"
315                )));
316            }
317        }
318
319        // X-Frame-Options must be DENY or SAMEORIGIN
320        let frame_opts = self.security_headers.frame_options.to_uppercase();
321        if frame_opts != "DENY" && frame_opts != "SAMEORIGIN" {
322            return Err(ConfigError::Validation(format!(
323                "security_headers.frame_options must be 'DENY' or 'SAMEORIGIN', got: '{}'",
324                self.security_headers.frame_options
325            )));
326        }
327
328        Ok(())
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration with `database.url` replaced, so the remaining
    /// validation rules can be exercised in isolation.
    fn config_with_url(url: &str) -> Config {
        let mut config = Config::default();
        config.database.url = url.into();
        config
    }

    /// Runs `validate`, asserts that it failed, and returns the error text.
    fn validation_error(config: &Config) -> String {
        let result = config.validate();
        assert!(result.is_err());
        result.unwrap_err().to_string()
    }

    /// Parses a JSON document into a `CorsConfig`, panicking on failure.
    fn parse_cors(json: &str) -> CorsConfig {
        serde_json::from_str(json).expect("should parse")
    }

    #[test]
    fn test_defaults() {
        let Config {
            server,
            logging,
            database,
            ..
        } = Config::default();
        assert_eq!(server.port, 8080);
        assert_eq!(server.host, "0.0.0.0");
        assert_eq!(logging.level, "info");
        assert_eq!(database.max_connections, 10);
    }

    #[test]
    fn test_validation_rejects_empty_database_url() {
        let message = validation_error(&Config::default());
        assert!(message.contains("database.url is required"));
    }

    #[test]
    fn test_validation_rejects_non_postgres_url() {
        let message = validation_error(&config_with_url("mysql://localhost/db"));
        assert!(message.contains("must start with postgres://"));
    }

    #[test]
    fn test_validation_accepts_valid_config() {
        assert!(config_with_url("postgres://localhost/test").validate().is_ok());
    }

    #[test]
    fn test_validation_accepts_postgresql_scheme() {
        assert!(config_with_url("postgresql://localhost/test")
            .validate()
            .is_ok());
    }

    #[test]
    fn test_cors_defaults_to_empty() {
        assert!(CorsConfig::default().allowed_origins.is_empty());
    }

    #[test]
    fn test_cors_validation_accepts_valid_origins() {
        let mut config = config_with_url("postgres://localhost/test");
        config.cors.allowed_origins = ["http://localhost:3000", "https://app.example.com"]
            .iter()
            .map(ToString::to_string)
            .collect();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_cors_validation_accepts_wildcard() {
        let mut config = config_with_url("postgres://localhost/test");
        config.cors.allowed_origins = vec!["*".to_string()];
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_cors_validation_rejects_invalid_origin() {
        let mut config = config_with_url("postgres://localhost/test");
        config.cors.allowed_origins = vec!["not-a-url".to_string()];
        assert!(validation_error(&config).contains("invalid origin"));
    }

    #[test]
    fn test_cors_deserialize_comma_separated_string() {
        // Simulate what figment does with env var
        let json = r#"{"allowed_origins": "http://localhost:5173,https://app.example.com"}"#;
        let config = parse_cors(json);
        assert_eq!(
            config.allowed_origins,
            ["http://localhost:5173", "https://app.example.com"]
        );
    }

    #[test]
    fn test_cors_deserialize_array() {
        let json = r#"{"allowed_origins": ["http://localhost:5173", "https://app.example.com"]}"#;
        let config = parse_cors(json);
        assert_eq!(
            config.allowed_origins,
            ["http://localhost:5173", "https://app.example.com"]
        );
    }

    #[test]
    fn test_cors_deserialize_empty_string() {
        let config = parse_cors(r#"{"allowed_origins": ""}"#);
        assert!(config.allowed_origins.is_empty());
    }

    #[test]
    fn test_graphql_playground_disabled_by_default() {
        assert!(!GraphQLConfig::default().playground_enabled);
    }

    #[test]
    fn test_graphql_playground_can_be_enabled() {
        let config: GraphQLConfig =
            serde_json::from_str(r#"{"playground_enabled": true}"#).expect("should parse");
        assert!(config.playground_enabled);
    }

    #[test]
    fn test_swagger_disabled_by_default() {
        assert!(!SwaggerConfig::default().enabled);
    }

    #[test]
    fn test_swagger_can_be_enabled() {
        let config: SwaggerConfig =
            serde_json::from_str(r#"{"enabled": true}"#).expect("should parse");
        assert!(config.enabled);
    }

    // Table-driven boundary tests for validation rules

    #[test]
    fn database_url_scheme_boundaries() {
        let cases = [
            ("postgres://localhost/db", true, "standard postgres"),
            ("postgresql://localhost/db", true, "postgresql alias"),
            ("postgres://", true, "minimal postgres URL"),
            ("", false, "empty URL"),
            ("mysql://localhost/db", false, "wrong scheme"),
            ("http://localhost/db", false, "http scheme"),
            ("postgrex://localhost/db", false, "typo in scheme"),
            ("POSTGRES://localhost/db", false, "uppercase scheme"),
        ];

        for (url, should_pass, desc) in cases {
            let result = config_with_url(url).validate();
            assert_eq!(
                result.is_ok(),
                should_pass,
                "case '{desc}': expected {should_pass}, got {result:?}"
            );
        }
    }

    #[test]
    fn port_boundaries() {
        let cases = [
            (0u16, false, "zero port"),
            (1, true, "minimum valid port"),
            (80, true, "common HTTP port"),
            (8080, true, "default port"),
            (65535, true, "maximum port"),
        ];

        for (port, should_pass, desc) in cases {
            let mut config = config_with_url("postgres://localhost/db");
            config.server.port = port;
            let result = config.validate();
            assert_eq!(result.is_ok(), should_pass, "case '{desc}': {result:?}");
        }
    }

    #[test]
    fn max_connections_boundaries() {
        let cases = [
            (0u32, false, "zero connections"),
            (1, true, "minimum valid"),
            (10, true, "default value"),
            (100, true, "high value"),
        ];

        for (max, should_pass, desc) in cases {
            let mut config = config_with_url("postgres://localhost/db");
            config.database.max_connections = max;
            let result = config.validate();
            assert_eq!(result.is_ok(), should_pass, "case '{desc}': {result:?}");
        }
    }

    #[test]
    fn cors_origin_boundaries() {
        let cases = [
            (vec!["*"], true, "wildcard"),
            (vec!["http://localhost"], true, "http localhost"),
            (vec!["https://example.com"], true, "https domain"),
            (vec!["http://localhost:3000"], true, "with port"),
            (vec![], true, "empty list"),
            (vec!["ftp://files.com"], false, "ftp scheme"),
            (vec!["localhost"], false, "no scheme"),
            (vec!["//example.com"], false, "protocol-relative"),
        ];

        for (origins, should_pass, desc) in cases {
            let mut config = config_with_url("postgres://localhost/db");
            config.cors.allowed_origins = origins.into_iter().map(String::from).collect();
            let result = config.validate();
            assert_eq!(result.is_ok(), should_pass, "case '{desc}': {result:?}");
        }
    }

    #[test]
    fn frame_options_boundaries() {
        let cases = [
            ("DENY", true, "uppercase DENY"),
            ("SAMEORIGIN", true, "uppercase SAMEORIGIN"),
            ("deny", true, "lowercase deny"),
            ("sameorigin", true, "lowercase sameorigin"),
            ("Deny", true, "mixed case Deny"),
            ("ALLOW-FROM", false, "deprecated ALLOW-FROM"),
            ("", false, "empty string"),
            ("INVALID", false, "invalid value"),
        ];

        for (value, should_pass, desc) in cases {
            let mut config = config_with_url("postgres://localhost/db");
            config.security_headers.frame_options = value.into();
            let result = config.validate();
            assert_eq!(result.is_ok(), should_pass, "case '{desc}': {result:?}");
        }
    }
}