1use figment::{
2 providers::{Env, Format, Serialized, Yaml},
3 Figment,
4};
5use serde::{Deserialize, Deserializer, Serialize};
6use serde_aux::prelude::deserialize_vec_from_string_or_vec;
7
/// Top-level application configuration.
///
/// Normally produced by [`Config::load`] / [`Config::load_from`], which layer the values of
/// [`Config::default`], a YAML file and `TC_`-prefixed environment variables, and then run
/// [`Config::validate`] on the result.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
    /// Database connection settings. This section has no `#[serde(default)]`, so it must be
    /// present when a `Config` is deserialized directly (the loaders seed it from defaults).
    pub database: DatabaseConfig,
    /// HTTP server bind settings. Required when deserializing directly, like `database`.
    pub server: ServerConfig,
    /// Logging settings. Required when deserializing directly, like `database`.
    pub logging: LoggingConfig,
    /// CORS settings; falls back to `CorsConfig::default()` when the section is absent.
    #[serde(default)]
    pub cors: CorsConfig,
    /// Security response-header settings; falls back to `SecurityHeadersConfig::default()`
    /// when the section is absent.
    #[serde(default)]
    pub security_headers: SecurityHeadersConfig,
    /// GraphQL settings; falls back to `GraphQLConfig::default()` when the section is absent.
    #[serde(default)]
    pub graphql: GraphQLConfig,
    /// Swagger settings; falls back to `SwaggerConfig::default()` when the section is absent.
    #[serde(default)]
    pub swagger: SwaggerConfig,
}
28
/// Database connection settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct DatabaseConfig {
    /// Connection URL. `Config::validate` requires it to be non-empty and to start with
    /// `postgres://` or `postgresql://`; it can be supplied via `TC_DATABASE__URL`.
    pub url: String,

    /// Connection limit; defaults to 10 and must not be 0 (checked by `Config::validate`).
    /// Presumably the size cap of a connection pool — the consumer is not visible here.
    #[serde(default = "default_max_connections")]
    pub max_connections: u32,

    /// Optional migrations directory. A missing key deserializes as `None` (serde's handling
    /// of `Option` fields); how the value is used is not visible in this file.
    pub migrations_dir: Option<String>,
}
42
/// HTTP server settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ServerConfig {
    /// TCP port; defaults to 8080. `Config::validate` rejects 0.
    #[serde(default = "default_port")]
    pub port: u16,

    /// Host/interface string; defaults to `"0.0.0.0"`. Not validated here.
    #[serde(default = "default_host")]
    pub host: String,
}
53
/// Logging settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LoggingConfig {
    /// Log level string; defaults to `"info"`. The value is not validated in this file —
    /// NOTE(review): presumably parsed by the logging setup elsewhere; confirm accepted values.
    #[serde(default = "default_log_level")]
    pub level: String,
}
60
/// Cross-origin resource sharing settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CorsConfig {
    /// Allowed origins. Accepts either a list of strings or a single comma-separated string
    /// (convenient for environment variables); empty entries are dropped. Defaults to an
    /// empty list. `Config::validate` requires each entry to be `*` or to start with
    /// `http://` or `https://`.
    #[serde(
        default = "default_allowed_origins",
        deserialize_with = "deserialize_origins"
    )]
    pub allowed_origins: Vec<String>,
}
73
74fn deserialize_origins<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
76where
77 D: Deserializer<'de>,
78{
79 let origins: Vec<String> = deserialize_vec_from_string_or_vec(deserializer)?;
80 Ok(origins.into_iter().filter(|s| !s.is_empty()).collect())
81}
82
/// Serde default for `database.max_connections`.
///
/// Declared `const fn` instead of silencing `clippy::missing_const_for_fn`. The body is a
/// literal, and serde's `default = "..."` calls a const fn like any other function.
const fn default_max_connections() -> u32 {
    10
}
88
/// Serde default for `server.port`.
///
/// Declared `const fn` instead of silencing `clippy::missing_const_for_fn`. The body is a
/// literal, and serde's `default = "..."` calls a const fn like any other function.
const fn default_port() -> u16 {
    8080
}
93
/// Serde default for `server.host`: `"0.0.0.0"`, the IPv4 unspecified address.
/// Presumably used as a bind address meaning "all IPv4 interfaces" — the consumer is not
/// visible here.
fn default_host() -> String {
    String::from("0.0.0.0")
}
97
/// Serde default for `logging.level`.
fn default_log_level() -> String {
    "info".to_owned()
}
101
/// Serde default for `cors.allowed_origins`: an empty list.
///
/// Declared `const fn` using `Vec::new()`, which does not allocate, instead of silencing
/// `clippy::missing_const_for_fn` around a `vec![]`.
const fn default_allowed_origins() -> Vec<String> {
    Vec::new()
}
108
109impl Default for CorsConfig {
110 fn default() -> Self {
111 Self {
112 allowed_origins: default_allowed_origins(),
113 }
114 }
115}
116
/// Settings for security-related HTTP response headers.
///
/// Every field has a serde default, so an empty section is valid.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SecurityHeadersConfig {
    /// Defaults to `true`. Presumably a master switch for the headers below — the consumer
    /// is not visible here.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Whether HSTS is enabled; defaults to `false`.
    #[serde(default)]
    pub hsts_enabled: bool,

    /// HSTS max-age; defaults to 31_536_000. That is one year if the unit is seconds, which
    /// is the unit of the HSTS `max-age` directive — presumably the intended unit.
    #[serde(default = "default_hsts_max_age")]
    pub hsts_max_age: u64,

    /// Whether HSTS should cover subdomains; defaults to `true`.
    #[serde(default = "default_true")]
    pub hsts_include_subdomains: bool,

    /// Frame-options value; defaults to `"DENY"`. `Config::validate` accepts only `DENY` or
    /// `SAMEORIGIN` (case-insensitive) and does not normalize the stored casing.
    #[serde(default = "default_frame_options")]
    pub frame_options: String,

    /// Content-Security-Policy value; defaults to `"default-src 'self'"`. Not validated here.
    #[serde(default = "default_csp")]
    pub content_security_policy: String,

    /// Referrer-Policy value; defaults to `"strict-origin-when-cross-origin"`. Not validated
    /// here.
    #[serde(default = "default_referrer_policy")]
    pub referrer_policy: String,
}
147
/// Serde default helper for boolean fields that should be on unless configured otherwise.
///
/// Declared `const fn` instead of silencing `clippy::missing_const_for_fn`.
const fn default_true() -> bool {
    true
}
152
/// Serde default for `security_headers.hsts_max_age`.
///
/// 31_536_000 = 365 * 24 * 60 * 60, i.e. one year if read as seconds (the unit of the HSTS
/// `max-age` directive). Declared `const fn` instead of silencing
/// `clippy::missing_const_for_fn`.
const fn default_hsts_max_age() -> u64 {
    31_536_000
}
157
/// Serde default for `security_headers.frame_options`: the most restrictive accepted value.
fn default_frame_options() -> String {
    String::from("DENY")
}
161
/// Serde default for `security_headers.content_security_policy`.
fn default_csp() -> String {
    let policy = "default-src 'self'";
    policy.into()
}
165
/// Serde default for `security_headers.referrer_policy`.
fn default_referrer_policy() -> String {
    "strict-origin-when-cross-origin".to_owned()
}
169
170impl Default for SecurityHeadersConfig {
171 fn default() -> Self {
172 Self {
173 enabled: default_true(),
174 hsts_enabled: false,
175 hsts_max_age: default_hsts_max_age(),
176 hsts_include_subdomains: default_true(),
177 frame_options: default_frame_options(),
178 content_security_policy: default_csp(),
179 referrer_policy: default_referrer_policy(),
180 }
181 }
182}
183
/// GraphQL-related settings.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct GraphQLConfig {
    /// Defaults to `false`. Presumably toggles an interactive GraphQL playground endpoint —
    /// the consumer is not visible here.
    #[serde(default)]
    pub playground_enabled: bool,
}
192
/// Swagger-related settings.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct SwaggerConfig {
    /// Defaults to `false`. Presumably toggles serving Swagger/OpenAPI docs — the consumer
    /// is not visible here.
    #[serde(default)]
    pub enabled: bool,
}
201
202impl Default for Config {
203 fn default() -> Self {
204 Self {
205 database: DatabaseConfig {
206 url: String::new(), max_connections: default_max_connections(),
208 migrations_dir: None,
209 },
210 server: ServerConfig {
211 port: default_port(),
212 host: default_host(),
213 },
214 logging: LoggingConfig {
215 level: default_log_level(),
216 },
217 cors: CorsConfig::default(),
218 security_headers: SecurityHeadersConfig::default(),
219 graphql: GraphQLConfig::default(),
220 swagger: SwaggerConfig::default(),
221 }
222 }
223}
224
/// Errors returned while loading or validating [`Config`].
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// A figment source could not be read or extracted into `Config`. The error is boxed —
    /// presumably to keep `Result<_, ConfigError>` small; confirm intent.
    #[error("Configuration error: {0}")]
    Figment(#[from] Box<figment::Error>),

    /// A semantic check in `Config::validate` failed; the message names the offending key.
    #[error("Validation error: {0}")]
    Validation(String),
}
234
235impl From<figment::Error> for ConfigError {
236 fn from(err: figment::Error) -> Self {
237 Self::Figment(Box::new(err))
238 }
239}
240
241impl Config {
242 pub fn load() -> Result<Self, ConfigError> {
252 let config: Self = Figment::new()
253 .merge(Serialized::defaults(Self::default()))
254 .merge(Yaml::file("config.yaml"))
255 .merge(Env::prefixed("TC_").split("__"))
256 .extract()?;
257
258 config.validate()?;
259 Ok(config)
260 }
261
262 pub fn load_from(yaml_path: &str) -> Result<Self, ConfigError> {
267 let config: Self = Figment::new()
268 .merge(Serialized::defaults(Self::default()))
269 .merge(Yaml::file(yaml_path))
270 .merge(Env::prefixed("TC_").split("__"))
271 .extract()?;
272
273 config.validate()?;
274 Ok(config)
275 }
276
277 pub fn validate(&self) -> Result<(), ConfigError> {
282 if self.database.url.is_empty() {
284 return Err(ConfigError::Validation(
285 "database.url is required. Set TC_DATABASE__URL environment variable.".into(),
286 ));
287 }
288
289 if !self.database.url.starts_with("postgres://")
290 && !self.database.url.starts_with("postgresql://")
291 {
292 return Err(ConfigError::Validation(format!(
293 "database.url must start with postgres:// or postgresql://, got: {}",
294 &self.database.url[..self.database.url.len().min(20)]
295 )));
296 }
297
298 if self.server.port == 0 {
300 return Err(ConfigError::Validation("server.port cannot be 0".into()));
301 }
302
303 if self.database.max_connections == 0 {
305 return Err(ConfigError::Validation(
306 "database.max_connections cannot be 0".into(),
307 ));
308 }
309
310 for origin in &self.cors.allowed_origins {
312 if origin != "*" && !origin.starts_with("http://") && !origin.starts_with("https://") {
313 return Err(ConfigError::Validation(format!(
314 "cors.allowed_origins contains invalid origin '{origin}'. Must be '*' or start with http:// or https://"
315 )));
316 }
317 }
318
319 let frame_opts = self.security_headers.frame_options.to_uppercase();
321 if frame_opts != "DENY" && frame_opts != "SAMEORIGIN" {
322 return Err(ConfigError::Validation(format!(
323 "security_headers.frame_options must be 'DENY' or 'SAMEORIGIN', got: '{}'",
324 self.security_headers.frame_options
325 )));
326 }
327
328 Ok(())
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config with a well-formed database URL, so each test only changes the field
    /// it exercises.
    fn valid_config() -> Config {
        let mut config = Config::default();
        config.database.url = "postgres://localhost/db".into();
        config
    }

    #[test]
    fn test_defaults() {
        let config = Config::default();
        assert_eq!(config.server.port, 8080);
        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.logging.level, "info");
        assert_eq!(config.database.max_connections, 10);
        assert!(config.database.url.is_empty());
        assert!(config.database.migrations_dir.is_none());
    }

    #[test]
    fn test_security_headers_defaults() {
        let headers = SecurityHeadersConfig::default();
        assert!(headers.enabled);
        assert!(!headers.hsts_enabled);
        assert_eq!(headers.hsts_max_age, 31_536_000);
        assert!(headers.hsts_include_subdomains);
        assert_eq!(headers.frame_options, "DENY");
        assert_eq!(headers.content_security_policy, "default-src 'self'");
        assert_eq!(headers.referrer_policy, "strict-origin-when-cross-origin");
    }

    #[test]
    fn test_security_headers_serde_defaults_match_default_impl() {
        // The per-field serde defaults and `impl Default` must agree. Otherwise an omitted
        // section and an empty section would configure different headers.
        let from_empty: SecurityHeadersConfig =
            serde_json::from_str("{}").expect("empty object should parse");
        assert_eq!(
            serde_json::to_value(&from_empty).expect("serialize"),
            serde_json::to_value(SecurityHeadersConfig::default()).expect("serialize"),
        );
    }

    #[test]
    fn test_config_optional_sections_fall_back_to_defaults() {
        let json = r#"{
            "database": {"url": "postgres://localhost/db"},
            "server": {},
            "logging": {}
        }"#;
        let config: Config = serde_json::from_str(json).expect("should parse");
        assert_eq!(config.server.port, 8080);
        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.logging.level, "info");
        assert_eq!(config.database.max_connections, 10);
        assert!(config.database.migrations_dir.is_none());
        assert!(config.cors.allowed_origins.is_empty());
        assert!(config.security_headers.enabled);
        assert!(!config.graphql.playground_enabled);
        assert!(!config.swagger.enabled);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validation_rejects_empty_database_url() {
        let result = Config::default().validate();
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("database.url is required"));
    }

    #[test]
    fn test_validation_rejects_non_postgres_url() {
        let mut config = Config::default();
        config.database.url = "mysql://localhost/db".into();
        let result = config.validate();
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("must start with postgres://"));
    }

    #[test]
    fn test_validation_accepts_valid_config() {
        assert!(valid_config().validate().is_ok());
    }

    #[test]
    fn test_validation_accepts_postgresql_scheme() {
        let mut config = Config::default();
        config.database.url = "postgresql://localhost/test".into();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_cors_defaults_to_empty() {
        assert!(CorsConfig::default().allowed_origins.is_empty());
    }

    #[test]
    fn test_cors_validation_accepts_valid_origins() {
        let mut config = valid_config();
        config.cors.allowed_origins = vec![
            "http://localhost:3000".into(),
            "https://app.example.com".into(),
        ];
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_cors_validation_accepts_wildcard() {
        let mut config = valid_config();
        config.cors.allowed_origins = vec!["*".into()];
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_cors_validation_rejects_invalid_origin() {
        let mut config = valid_config();
        config.cors.allowed_origins = vec!["not-a-url".into()];
        let result = config.validate();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("invalid origin"));
    }

    #[test]
    fn test_cors_deserialize_comma_separated_string() {
        let json = r#"{"allowed_origins": "http://localhost:5173,https://app.example.com"}"#;
        let config: CorsConfig = serde_json::from_str(json).expect("should parse");
        assert_eq!(
            config.allowed_origins,
            ["http://localhost:5173", "https://app.example.com"]
        );
    }

    #[test]
    fn test_cors_deserialize_array() {
        let json = r#"{"allowed_origins": ["http://localhost:5173", "https://app.example.com"]}"#;
        let config: CorsConfig = serde_json::from_str(json).expect("should parse");
        assert_eq!(
            config.allowed_origins,
            ["http://localhost:5173", "https://app.example.com"]
        );
    }

    #[test]
    fn test_cors_deserialize_empty_string() {
        let json = r#"{"allowed_origins": ""}"#;
        let config: CorsConfig = serde_json::from_str(json).expect("should parse");
        assert!(config.allowed_origins.is_empty());
    }

    #[test]
    fn test_cors_deserialize_drops_empty_entries() {
        let json = r#"{"allowed_origins": "http://localhost:5173,,https://app.example.com,"}"#;
        let config: CorsConfig = serde_json::from_str(json).expect("should parse");
        assert_eq!(
            config.allowed_origins,
            ["http://localhost:5173", "https://app.example.com"]
        );

        let json = r#"{"allowed_origins": ["", "https://app.example.com"]}"#;
        let config: CorsConfig = serde_json::from_str(json).expect("should parse");
        assert_eq!(config.allowed_origins, ["https://app.example.com"]);
    }

    #[test]
    fn test_graphql_playground_disabled_by_default() {
        assert!(!GraphQLConfig::default().playground_enabled);
    }

    #[test]
    fn test_graphql_playground_can_be_enabled() {
        let config: GraphQLConfig =
            serde_json::from_str(r#"{"playground_enabled": true}"#).expect("should parse");
        assert!(config.playground_enabled);
    }

    #[test]
    fn test_swagger_disabled_by_default() {
        assert!(!SwaggerConfig::default().enabled);
    }

    #[test]
    fn test_swagger_can_be_enabled() {
        let config: SwaggerConfig =
            serde_json::from_str(r#"{"enabled": true}"#).expect("should parse");
        assert!(config.enabled);
    }

    #[test]
    fn database_url_scheme_boundaries() {
        let cases = [
            ("postgres://localhost/db", true, "standard postgres"),
            ("postgresql://localhost/db", true, "postgresql alias"),
            ("postgres://", true, "minimal postgres URL"),
            ("", false, "empty URL"),
            ("mysql://localhost/db", false, "wrong scheme"),
            ("http://localhost/db", false, "http scheme"),
            ("postgrex://localhost/db", false, "typo in scheme"),
            ("POSTGRES://localhost/db", false, "uppercase scheme"),
        ];

        for (url, should_pass, desc) in cases {
            let mut config = Config::default();
            config.database.url = url.into();
            let result = config.validate();
            assert_eq!(
                result.is_ok(),
                should_pass,
                "case '{}': expected {}, got {:?}",
                desc,
                should_pass,
                result
            );
        }
    }

    #[test]
    fn port_boundaries() {
        let cases = [
            (0u16, false, "zero port"),
            (1, true, "minimum valid port"),
            (80, true, "common HTTP port"),
            (8080, true, "default port"),
            (65535, true, "maximum port"),
        ];

        for (port, should_pass, desc) in cases {
            let mut config = valid_config();
            config.server.port = port;
            let result = config.validate();
            assert_eq!(result.is_ok(), should_pass, "case '{}': {:?}", desc, result);
        }
    }

    #[test]
    fn max_connections_boundaries() {
        let cases = [
            (0u32, false, "zero connections"),
            (1, true, "minimum valid"),
            (10, true, "default value"),
            (100, true, "high value"),
        ];

        for (max, should_pass, desc) in cases {
            let mut config = valid_config();
            config.database.max_connections = max;
            let result = config.validate();
            assert_eq!(result.is_ok(), should_pass, "case '{}': {:?}", desc, result);
        }
    }

    #[test]
    fn cors_origin_boundaries() {
        let cases = [
            (vec!["*"], true, "wildcard"),
            (vec!["http://localhost"], true, "http localhost"),
            (vec!["https://example.com"], true, "https domain"),
            (vec!["http://localhost:3000"], true, "with port"),
            (vec![], true, "empty list"),
            (vec!["ftp://files.com"], false, "ftp scheme"),
            (vec!["localhost"], false, "no scheme"),
            (vec!["//example.com"], false, "protocol-relative"),
        ];

        for (origins, should_pass, desc) in cases {
            let mut config = valid_config();
            config.cors.allowed_origins = origins.into_iter().map(String::from).collect();
            let result = config.validate();
            assert_eq!(result.is_ok(), should_pass, "case '{}': {:?}", desc, result);
        }
    }

    #[test]
    fn frame_options_boundaries() {
        let cases = [
            ("DENY", true, "uppercase DENY"),
            ("SAMEORIGIN", true, "uppercase SAMEORIGIN"),
            ("deny", true, "lowercase deny"),
            ("sameorigin", true, "lowercase sameorigin"),
            ("Deny", true, "mixed case Deny"),
            ("ALLOW-FROM", false, "deprecated ALLOW-FROM"),
            ("", false, "empty string"),
            ("INVALID", false, "invalid value"),
        ];

        for (value, should_pass, desc) in cases {
            let mut config = valid_config();
            config.security_headers.frame_options = value.into();
            let result = config.validate();
            assert_eq!(result.is_ok(), should_pass, "case '{}': {:?}", desc, result);
        }
    }
}