diff --git a/src/service.rs b/src/service.rs index 3f3f231..850f702 100644 --- a/src/service.rs +++ b/src/service.rs @@ -8,7 +8,7 @@ use std::{ }; use http::{Request, Response}; -use time::Duration; +use time::OffsetDateTime; #[cfg(any(feature = "signed", feature = "private"))] use tower_cookies::Key; use tower_cookies::{cookie::SameSite, Cookie, CookieManager, Cookies}; @@ -102,16 +102,20 @@ struct SessionConfig<'a> { } impl<'a> SessionConfig<'a> { - fn build_cookie(self, session_id: session::Id, expiry_age: Duration) -> Cookie<'a> { + fn build_cookie(self, session_id: session::Id, expiry: Option) -> Cookie<'a> { let mut cookie_builder = Cookie::build((self.name, session_id.to_string())) .http_only(self.http_only) .same_site(self.same_site) .secure(self.secure) .path(self.path); - if !matches!(self.expiry, Some(Expiry::OnSessionEnd) | None) { - cookie_builder = cookie_builder.max_age(expiry_age); - } + cookie_builder = match expiry { + Some(Expiry::OnInactivity(duration)) => cookie_builder.max_age(duration), + Some(Expiry::AtDateTime(datetime)) => { + cookie_builder.max_age(datetime - OffsetDateTime::now_utc()) + } + Some(Expiry::OnSessionEnd) | None => cookie_builder, + }; if let Some(domain) = self.domain { cookie_builder = cookie_builder.domain(domain); @@ -256,8 +260,8 @@ where return Ok(res); }; - let expiry_age = session.expiry_age(); - let session_cookie = session_config.build_cookie(session_id, expiry_age); + let expiry = session.expiry(); + let session_cookie = session_config.build_cookie(session_id, expiry); tracing::debug!("adding session cookie"); cookie_controller.add(&cookies, session_cookie); diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 7595194..134b9d6 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -58,6 +58,12 @@ fn routes() -> Router { session.set_expiry(Some(expiry)); }), ) + .route( + "/remove_expiry", + get(|session: Session| async move { + session.set_expiry(Some(Expiry::OnSessionEnd)); + }), + ) } pub fn build_app( @@ -382,5 +388,77 @@ macro_rules! route_tests { actual_duration ); } + + #[tokio::test] + async fn change_expiry_type() { + let app = $create_app(None, Some("localhost".to_string())).await; + + let req = Request::builder() + .uri("/insert") + .body(Body::empty()) + .unwrap(); + let res = app.clone().oneshot(req).await.unwrap(); + let session_cookie = get_session_cookie(res.headers()).unwrap(); + + let expected_duration = None; + let actual_duration = session_cookie.max_age(); + + assert_eq!(actual_duration, expected_duration, "Duration is not None"); + + let req = Request::builder() + .uri("/set_expiry") + .header(header::COOKIE, session_cookie.encoded().to_string()) + .body(Body::empty()) + .unwrap(); + let res = app.oneshot(req).await.unwrap(); + + let session_cookie = get_session_cookie(res.headers()).unwrap(); + + let expected_duration = Duration::days(1); + assert!(session_cookie.max_age().is_some(), "Duration is None"); + let actual_duration = session_cookie.max_age().unwrap(); + let tolerance = Duration::seconds(1); + + assert!( + actual_duration >= expected_duration - tolerance + && actual_duration <= expected_duration + tolerance, + "Duration is not within the acceptable range: {:?}", + actual_duration + ); + + let app2 = $create_app(Some(Duration::hours(1)), Some("localhost".to_string())).await; + + let req = Request::builder() + .uri("/insert") + .body(Body::empty()) + .unwrap(); + let res = app2.clone().oneshot(req).await.unwrap(); + let session_cookie = get_session_cookie(res.headers()).unwrap(); + + let expected_duration = Duration::hours(1); + let actual_duration = session_cookie.max_age().unwrap(); + let tolerance = Duration::seconds(1); + + assert!( + actual_duration >= expected_duration - tolerance + && actual_duration <= expected_duration + tolerance, + "Duration is not within the acceptable range: {:?}", + actual_duration + ); + + let req = Request::builder() + .uri("/remove_expiry") + .header(header::COOKIE, session_cookie.encoded().to_string()) + .body(Body::empty()) + .unwrap(); + let res = app2.oneshot(req).await.unwrap(); + + let session_cookie = get_session_cookie(res.headers()).unwrap(); + + let expected_duration = None; + let actual_duration = session_cookie.max_age(); + + assert_eq!(actual_duration, expected_duration, "Duration is not None"); + } }; }