Line data Source code
1 : use figment::{
2 : providers::{Env, Format, Serialized, Yaml},
3 : Figment,
4 : };
5 : use serde::{Deserialize, Deserializer, Serialize};
6 : use serde_aux::prelude::deserialize_vec_from_string_or_vec;
7 :
8 : /// Application configuration loaded from multiple sources.
9 : ///
10 : /// Configuration is loaded in priority order (lowest to highest):
11 : /// 1. Struct defaults
12 : /// 2. config.yaml file (if exists)
13 : /// 3. Environment variables with TC_ prefix (always wins)
14 : #[derive(Debug, Clone, Deserialize, Serialize)]
15 : pub struct Config {
16 : pub database: DatabaseConfig,
17 : pub server: ServerConfig,
18 : pub logging: LoggingConfig,
19 : #[serde(default)]
20 : pub cors: CorsConfig,
21 : #[serde(default)]
22 : pub security_headers: SecurityHeadersConfig,
23 : #[serde(default)]
24 : pub graphql: GraphQLConfig,
25 : #[serde(default)]
26 : pub swagger: SwaggerConfig,
27 : }
28 :
29 : #[derive(Debug, Clone, Deserialize, Serialize)]
30 : pub struct DatabaseConfig {
31 : /// `PostgreSQL` connection URL (required).
32 : /// Example: `postgres://user:pass@host:5432/dbname`
33 : pub url: String,
34 :
35 : /// Maximum number of connections in the pool.
36 : #[serde(default = "default_max_connections")]
37 : pub max_connections: u32,
38 :
39 : /// Optional custom migrations directory path.
40 : pub migrations_dir: Option<String>,
41 : }
42 :
43 : #[derive(Debug, Clone, Deserialize, Serialize)]
44 : pub struct ServerConfig {
45 : /// HTTP server port.
46 : #[serde(default = "default_port")]
47 : pub port: u16,
48 :
49 : /// HTTP server bind address.
50 : #[serde(default = "default_host")]
51 : pub host: String,
52 : }
53 :
54 : #[derive(Debug, Clone, Deserialize, Serialize)]
55 : pub struct LoggingConfig {
56 : /// Log level filter (debug, info, warn, error).
57 : #[serde(default = "default_log_level")]
58 : pub level: String,
59 : }
60 :
61 : #[derive(Debug, Clone, Deserialize, Serialize)]
62 : pub struct CorsConfig {
63 : /// Allowed origins for CORS requests.
64 : /// Use `"*"` to allow any origin (not recommended for production).
65 : /// Accepts either an array or comma-separated string.
66 : /// Example: `["http://localhost:5173"]` or `"http://localhost:5173,https://app.example.com"`
67 : #[serde(
68 : default = "default_allowed_origins",
69 : deserialize_with = "deserialize_origins"
70 : )]
71 : pub allowed_origins: Vec<String>,
72 : }
73 :
74 : /// Deserialize origins from comma-separated string or array, filtering empty values.
75 3 : fn deserialize_origins<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
76 3 : where
77 3 : D: Deserializer<'de>,
78 : {
79 3 : let origins: Vec<String> = deserialize_vec_from_string_or_vec(deserializer)?;
80 5 : Ok(origins.into_iter().filter(|s| !s.is_empty()).collect())
81 3 : }
82 :
// Default-value providers referenced by `#[serde(default = "...")]`.
// serde only needs a callable path here, so a `const fn` works as well as a
// plain `fn`. The numeric providers are `const` (which also satisfies
// clippy's `missing_const_for_fn`). The `String` providers cannot be `const`
// because they allocate.
const fn default_max_connections() -> u32 {
    10
}

const fn default_port() -> u16 {
    8080
}

fn default_host() -> String {
    "0.0.0.0".to_string()
}

fn default_log_level() -> String {
    "info".to_string()
}
101 :
/// Default CORS allow-list: empty, so no cross-origin requests are allowed.
///
/// This is the safe production default. Configure origins explicitly via
/// `TC_CORS__ALLOWED_ORIGINS` or `config.yaml`. `Vec::new()` is `const`,
/// so this provider can be `const fn` without a lint suppression.
const fn default_allowed_origins() -> Vec<String> {
    Vec::new()
}
108 :
109 : impl Default for CorsConfig {
110 42 : fn default() -> Self {
111 42 : Self {
112 42 : allowed_origins: default_allowed_origins(),
113 42 : }
114 42 : }
115 : }
116 :
117 : #[derive(Debug, Clone, Deserialize, Serialize)]
118 : pub struct SecurityHeadersConfig {
119 : /// Enable security headers (default: true).
120 : #[serde(default = "default_true")]
121 : pub enabled: bool,
122 :
123 : /// Enable HSTS header (default: false, enable in production with HTTPS).
124 : #[serde(default)]
125 : pub hsts_enabled: bool,
126 :
127 : /// HSTS max-age in seconds (default: 31536000 = 1 year).
128 : #[serde(default = "default_hsts_max_age")]
129 : pub hsts_max_age: u64,
130 :
131 : /// Include subdomains in HSTS (default: true).
132 : #[serde(default = "default_true")]
133 : pub hsts_include_subdomains: bool,
134 :
135 : /// X-Frame-Options value: "DENY" or "SAMEORIGIN" (default: "DENY").
136 : #[serde(default = "default_frame_options")]
137 : pub frame_options: String,
138 :
139 : /// Content-Security-Policy header value (default: "default-src 'self'").
140 : #[serde(default = "default_csp")]
141 : pub content_security_policy: String,
142 :
143 : /// Referrer-Policy header value (default: "strict-origin-when-cross-origin").
144 : #[serde(default = "default_referrer_policy")]
145 : pub referrer_policy: String,
146 : }
147 :
// Default-value providers for `SecurityHeadersConfig`, referenced from
// `#[serde(default = "...")]`. The scalar providers are `const fn`; serde
// only needs a callable path, so no lint suppression is required.
const fn default_true() -> bool {
    true
}

/// One year, in seconds.
const fn default_hsts_max_age() -> u64 {
    31_536_000
}

fn default_frame_options() -> String {
    "DENY".to_string()
}

fn default_csp() -> String {
    "default-src 'self'".to_string()
}

fn default_referrer_policy() -> String {
    "strict-origin-when-cross-origin".to_string()
}
169 :
170 : impl Default for SecurityHeadersConfig {
171 64 : fn default() -> Self {
172 64 : Self {
173 64 : enabled: default_true(),
174 64 : hsts_enabled: false,
175 64 : hsts_max_age: default_hsts_max_age(),
176 64 : hsts_include_subdomains: default_true(),
177 64 : frame_options: default_frame_options(),
178 64 : content_security_policy: default_csp(),
179 64 : referrer_policy: default_referrer_policy(),
180 64 : }
181 64 : }
182 : }
183 :
184 : #[derive(Debug, Clone, Default, Deserialize, Serialize)]
185 : pub struct GraphQLConfig {
186 : /// Enable GraphQL Playground UI at /graphql (GET).
187 : /// Default: false (disabled for security - exposes schema to potential attackers).
188 : /// Enable in development via `TC_GRAPHQL__PLAYGROUND_ENABLED=true`
189 : #[serde(default)]
190 : pub playground_enabled: bool,
191 : }
192 :
193 : #[derive(Debug, Clone, Default, Deserialize, Serialize)]
194 : pub struct SwaggerConfig {
195 : /// Enable Swagger UI at /swagger-ui.
196 : /// Default: false (disabled for security - exposes API documentation).
197 : /// Enable in development via `TC_SWAGGER__ENABLED=true`
198 : #[serde(default)]
199 : pub enabled: bool,
200 : }
201 :
202 : impl Default for Config {
203 41 : fn default() -> Self {
204 41 : Self {
205 41 : database: DatabaseConfig {
206 41 : url: String::new(), // Will fail validation if not provided
207 41 : max_connections: default_max_connections(),
208 41 : migrations_dir: None,
209 41 : },
210 41 : server: ServerConfig {
211 41 : port: default_port(),
212 41 : host: default_host(),
213 41 : },
214 41 : logging: LoggingConfig {
215 41 : level: default_log_level(),
216 41 : },
217 41 : cors: CorsConfig::default(),
218 41 : security_headers: SecurityHeadersConfig::default(),
219 41 : graphql: GraphQLConfig::default(),
220 41 : swagger: SwaggerConfig::default(),
221 41 : }
222 41 : }
223 : }
224 :
225 : /// Configuration loading and validation errors.
226 : #[derive(Debug, thiserror::Error)]
227 : pub enum ConfigError {
228 : #[error("Configuration error: {0}")]
229 : Figment(#[from] Box<figment::Error>),
230 :
231 : #[error("Validation error: {0}")]
232 : Validation(String),
233 : }
234 :
235 : impl From<figment::Error> for ConfigError {
236 0 : fn from(err: figment::Error) -> Self {
237 0 : Self::Figment(Box::new(err))
238 0 : }
239 : }
240 :
241 : impl Config {
242 : /// Load configuration from all sources.
243 : ///
244 : /// Sources are merged in priority order:
245 : /// 1. Struct defaults (lowest)
246 : /// 2. config.yaml file (if exists)
247 : /// 3. Environment variables with TC_ prefix (highest)
248 : ///
249 : /// # Errors
250 : /// Returns an error if configuration cannot be loaded or is invalid.
251 0 : pub fn load() -> Result<Self, ConfigError> {
252 0 : let config: Self = Figment::new()
253 0 : .merge(Serialized::defaults(Self::default()))
254 0 : .merge(Yaml::file("config.yaml"))
255 0 : .merge(Env::prefixed("TC_").split("__"))
256 0 : .extract()?;
257 :
258 0 : config.validate()?;
259 0 : Ok(config)
260 0 : }
261 :
262 : /// Load configuration with a custom YAML file path.
263 : ///
264 : /// # Errors
265 : /// Returns an error if configuration cannot be loaded or is invalid.
266 0 : pub fn load_from(yaml_path: &str) -> Result<Self, ConfigError> {
267 0 : let config: Self = Figment::new()
268 0 : .merge(Serialized::defaults(Self::default()))
269 0 : .merge(Yaml::file(yaml_path))
270 0 : .merge(Env::prefixed("TC_").split("__"))
271 0 : .extract()?;
272 :
273 0 : config.validate()?;
274 0 : Ok(config)
275 0 : }
276 :
277 : /// Validate configuration values.
278 : ///
279 : /// # Errors
280 : /// Returns an error if any configuration value is invalid.
281 40 : pub fn validate(&self) -> Result<(), ConfigError> {
282 : // Database URL is required and must be a postgres URL
283 40 : if self.database.url.is_empty() {
284 2 : return Err(ConfigError::Validation(
285 2 : "database.url is required. Set TC_DATABASE__URL environment variable.".into(),
286 2 : ));
287 38 : }
288 :
289 38 : if !self.database.url.starts_with("postgres://")
290 7 : && !self.database.url.starts_with("postgresql://")
291 : {
292 5 : return Err(ConfigError::Validation(format!(
293 5 : "database.url must start with postgres:// or postgresql://, got: {}",
294 5 : &self.database.url[..self.database.url.len().min(20)]
295 5 : )));
296 33 : }
297 :
298 : // Port must be non-zero
299 33 : if self.server.port == 0 {
300 1 : return Err(ConfigError::Validation("server.port cannot be 0".into()));
301 32 : }
302 :
303 : // Max connections must be at least 1
304 32 : if self.database.max_connections == 0 {
305 1 : return Err(ConfigError::Validation(
306 1 : "database.max_connections cannot be 0".into(),
307 1 : ));
308 31 : }
309 :
310 : // CORS origins must be valid URLs or "*"
311 38 : for origin in &self.cors.allowed_origins {
312 11 : if origin != "*" && !origin.starts_with("http://") && !origin.starts_with("https://") {
313 4 : return Err(ConfigError::Validation(format!(
314 4 : "cors.allowed_origins contains invalid origin '{origin}'. Must be '*' or start with http:// or https://"
315 4 : )));
316 7 : }
317 : }
318 :
319 : // X-Frame-Options must be DENY or SAMEORIGIN
320 27 : let frame_opts = self.security_headers.frame_options.to_uppercase();
321 27 : if frame_opts != "DENY" && frame_opts != "SAMEORIGIN" {
322 3 : return Err(ConfigError::Validation(format!(
323 3 : "security_headers.frame_options must be 'DENY' or 'SAMEORIGIN', got: '{}'",
324 3 : self.security_headers.frame_options
325 3 : )));
326 24 : }
327 :
328 24 : Ok(())
329 40 : }
330 : }
331 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_defaults() {
        let config = Config::default();
        assert_eq!(config.server.port, 8080);
        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.logging.level, "info");
        assert_eq!(config.database.max_connections, 10);
    }

    #[test]
    fn test_validation_rejects_empty_database_url() {
        let config = Config::default();
        let result = config.validate();
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("database.url is required"));
    }

    #[test]
    fn test_validation_rejects_non_postgres_url() {
        let mut config = Config::default();
        config.database.url = "mysql://localhost/db".into();
        let result = config.validate();
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("must start with postgres://"));
    }

    #[test]
    fn test_validation_accepts_valid_config() {
        let mut config = Config::default();
        config.database.url = "postgres://localhost/test".into();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validation_accepts_postgresql_scheme() {
        let mut config = Config::default();
        config.database.url = "postgresql://localhost/test".into();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_cors_defaults_to_empty() {
        let config = CorsConfig::default();
        assert!(config.allowed_origins.is_empty());
    }

    #[test]
    fn test_cors_validation_accepts_valid_origins() {
        let mut config = Config::default();
        config.database.url = "postgres://localhost/test".into();
        config.cors.allowed_origins = vec![
            "http://localhost:3000".into(),
            "https://app.example.com".into(),
        ];
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_cors_validation_accepts_wildcard() {
        let mut config = Config::default();
        config.database.url = "postgres://localhost/test".into();
        config.cors.allowed_origins = vec!["*".into()];
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_cors_validation_rejects_invalid_origin() {
        let mut config = Config::default();
        config.database.url = "postgres://localhost/test".into();
        config.cors.allowed_origins = vec!["not-a-url".into()];
        let result = config.validate();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("invalid origin"));
    }

    #[test]
    fn test_cors_deserialize_comma_separated_string() {
        // Simulate what figment does with env var
        let json = r#"{"allowed_origins": "http://localhost:5173,https://app.example.com"}"#;
        let config: CorsConfig = serde_json::from_str(json).expect("should parse");
        assert_eq!(config.allowed_origins.len(), 2);
        assert_eq!(config.allowed_origins[0], "http://localhost:5173");
        assert_eq!(config.allowed_origins[1], "https://app.example.com");
    }

    #[test]
    fn test_cors_deserialize_array() {
        let json = r#"{"allowed_origins": ["http://localhost:5173", "https://app.example.com"]}"#;
        let config: CorsConfig = serde_json::from_str(json).expect("should parse");
        assert_eq!(config.allowed_origins.len(), 2);
        assert_eq!(config.allowed_origins[0], "http://localhost:5173");
        assert_eq!(config.allowed_origins[1], "https://app.example.com");
    }

    #[test]
    fn test_cors_deserialize_empty_string() {
        let json = r#"{"allowed_origins": ""}"#;
        let config: CorsConfig = serde_json::from_str(json).expect("should parse");
        assert!(config.allowed_origins.is_empty());
    }

    #[test]
    fn test_graphql_playground_disabled_by_default() {
        let config = GraphQLConfig::default();
        assert!(!config.playground_enabled);
    }

    #[test]
    fn test_graphql_playground_can_be_enabled() {
        let json = r#"{"playground_enabled": true}"#;
        let config: GraphQLConfig = serde_json::from_str(json).expect("should parse");
        assert!(config.playground_enabled);
    }

    #[test]
    fn test_swagger_disabled_by_default() {
        let config = SwaggerConfig::default();
        assert!(!config.enabled);
    }

    #[test]
    fn test_swagger_can_be_enabled() {
        let json = r#"{"enabled": true}"#;
        let config: SwaggerConfig = serde_json::from_str(json).expect("should parse");
        assert!(config.enabled);
    }

    // A bare figment error must convert into ConfigError::Figment and keep
    // its message in the Display output (previously uncovered code path).
    #[test]
    fn test_figment_error_converts_into_config_error() {
        let figment_err = <figment::Error as serde::de::Error>::custom("boom");
        let err = ConfigError::from(figment_err);
        assert!(matches!(err, ConfigError::Figment(_)));
        assert!(err.to_string().contains("boom"));
    }

    // Exercise the real figment layering (defaults + YAML) instead of
    // simulating it with serde_json.
    #[test]
    fn test_yaml_layer_overrides_defaults() {
        let yaml = r#"
database:
  url: "postgres://localhost/db"
server:
  port: 9000
cors:
  allowed_origins: "http://a.com,https://b.com"
"#;
        let config: Config = Figment::new()
            .merge(Serialized::defaults(Config::default()))
            .merge(Yaml::string(yaml))
            .extract()
            .expect("should extract");
        assert_eq!(config.database.url, "postgres://localhost/db");
        assert_eq!(config.server.port, 9000);
        // Fields absent from the YAML keep their struct defaults.
        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.database.max_connections, 10);
        assert_eq!(
            config.cors.allowed_origins,
            vec!["http://a.com", "https://b.com"]
        );
        assert!(config.validate().is_ok());
    }

    // Table-driven boundary tests for validation rules

    #[test]
    fn database_url_scheme_boundaries() {
        let cases = [
            ("postgres://localhost/db", true, "standard postgres"),
            ("postgresql://localhost/db", true, "postgresql alias"),
            ("postgres://", true, "minimal postgres URL"),
            ("", false, "empty URL"),
            ("mysql://localhost/db", false, "wrong scheme"),
            ("http://localhost/db", false, "http scheme"),
            ("postgrex://localhost/db", false, "typo in scheme"),
            ("POSTGRES://localhost/db", false, "uppercase scheme"),
        ];

        for (url, should_pass, desc) in cases {
            let mut config = Config::default();
            config.database.url = url.into();
            let result = config.validate();
            assert_eq!(
                result.is_ok(),
                should_pass,
                "case '{}': expected {}, got {:?}",
                desc,
                should_pass,
                result
            );
        }
    }

    #[test]
    fn port_boundaries() {
        let cases = [
            (0u16, false, "zero port"),
            (1, true, "minimum valid port"),
            (80, true, "common HTTP port"),
            (8080, true, "default port"),
            (65535, true, "maximum port"),
        ];

        for (port, should_pass, desc) in cases {
            let mut config = Config::default();
            config.database.url = "postgres://localhost/db".into();
            config.server.port = port;
            let result = config.validate();
            assert_eq!(result.is_ok(), should_pass, "case '{}': {:?}", desc, result);
        }
    }

    #[test]
    fn max_connections_boundaries() {
        let cases = [
            (0u32, false, "zero connections"),
            (1, true, "minimum valid"),
            (10, true, "default value"),
            (100, true, "high value"),
        ];

        for (max, should_pass, desc) in cases {
            let mut config = Config::default();
            config.database.url = "postgres://localhost/db".into();
            config.database.max_connections = max;
            let result = config.validate();
            assert_eq!(result.is_ok(), should_pass, "case '{}': {:?}", desc, result);
        }
    }

    #[test]
    fn cors_origin_boundaries() {
        let cases = [
            (vec!["*"], true, "wildcard"),
            (vec!["http://localhost"], true, "http localhost"),
            (vec!["https://example.com"], true, "https domain"),
            (vec!["http://localhost:3000"], true, "with port"),
            (vec![], true, "empty list"),
            (vec!["ftp://files.com"], false, "ftp scheme"),
            (vec!["localhost"], false, "no scheme"),
            (vec!["//example.com"], false, "protocol-relative"),
        ];

        for (origins, should_pass, desc) in cases {
            let mut config = Config::default();
            config.database.url = "postgres://localhost/db".into();
            config.cors.allowed_origins = origins.into_iter().map(String::from).collect();
            let result = config.validate();
            assert_eq!(result.is_ok(), should_pass, "case '{}': {:?}", desc, result);
        }
    }

    #[test]
    fn frame_options_boundaries() {
        let cases = [
            ("DENY", true, "uppercase DENY"),
            ("SAMEORIGIN", true, "uppercase SAMEORIGIN"),
            ("deny", true, "lowercase deny"),
            ("sameorigin", true, "lowercase sameorigin"),
            ("Deny", true, "mixed case Deny"),
            ("ALLOW-FROM", false, "deprecated ALLOW-FROM"),
            ("", false, "empty string"),
            ("INVALID", false, "invalid value"),
        ];

        for (value, should_pass, desc) in cases {
            let mut config = Config::default();
            config.database.url = "postgres://localhost/db".into();
            config.security_headers.frame_options = value.into();
            let result = config.validate();
            assert_eq!(result.is_ok(), should_pass, "case '{}': {:?}", desc, result);
        }
    }
}
|