diff --git a/cot/src/auth.rs b/cot/src/auth.rs index 0ee95c4f..550bf1d2 100644 --- a/cot/src/auth.rs +++ b/cot/src/auth.rs @@ -685,6 +685,51 @@ impl ToDbValue for PasswordHash { } } +#[cfg(feature = "db")] +impl FromDbValue for Option { + #[cfg(feature = "sqlite")] + fn from_sqlite(value: crate::db::impl_sqlite::SqliteValueRef<'_>) -> cot::db::Result { + value + .get::>() + .map(|op_str| op_str.map(PasswordHash::new))? + .transpose() + .map_err(cot::db::DatabaseError::value_decode) + } + + #[cfg(feature = "postgres")] + fn from_postgres( + value: crate::db::impl_postgres::PostgresValueRef<'_>, + ) -> cot::db::Result { + value + .get::>() + .map(|op_str| op_str.map(PasswordHash::new))? + .transpose() + .map_err(cot::db::DatabaseError::value_decode) + } + + #[cfg(feature = "mysql")] + fn from_mysql(value: crate::db::impl_mysql::MySqlValueRef<'_>) -> crate::db::Result + where + Self: Sized, + { + value + .get::>() + .map(|op_str| op_str.map(PasswordHash::new))? + .transpose() + .map_err(cot::db::DatabaseError::value_decode) + } +} + +#[cfg(feature = "db")] +impl ToDbValue for Option { + fn to_db_value(&self) -> DbValue { + match self { + Some(hash) => hash.to_db_value(), + None => >::None.to_db_value(), + } + } +} + /// Authentication helper structure. /// /// This is an object that provides methods to sign users in and out, by using diff --git a/cot/src/common_types.rs b/cot/src/common_types.rs index a120afe0..1d3616ce 100644 --- a/cot/src/common_types.rs +++ b/cot/src/common_types.rs @@ -393,6 +393,52 @@ impl FromDbValue for Url { } } +#[cfg(feature = "db")] +impl FromDbValue for Option { + #[cfg(feature = "sqlite")] + fn from_sqlite(value: SqliteValueRef<'_>) -> cot::db::Result + where + Self: Sized, + { + value + .get::>() + .map(|opt_str| opt_str.map(Url::new))? + .transpose() + .map_err(cot::db::DatabaseError::value_decode) + } + + #[cfg(feature = "postgres")] + fn from_postgres(value: PostgresValueRef<'_>) -> cot::db::Result + where + Self: Sized, + { + value + .get::>() + .map(|opt_str| opt_str.map(Url::new))? + .transpose() + .map_err(cot::db::DatabaseError::value_decode) + } + + #[cfg(feature = "mysql")] + fn from_mysql(value: MySqlValueRef<'_>) -> cot::db::Result + where + Self: Sized, + { + value + .get::>() + .map(|opt_str| opt_str.map(Url::new))? + .transpose() + .map_err(cot::db::DatabaseError::value_decode) + } +} + +#[cfg(feature = "db")] +impl ToDbValue for Option { + fn to_db_value(&self) -> DbValue { + self.clone().map(Url::into_string).into() + } +} + #[cfg(feature = "db")] impl DatabaseField for Url { const TYPE: ColumnType = ColumnType::Text; @@ -683,6 +729,52 @@ impl FromDbValue for Email { } } +#[cfg(feature = "db")] +impl ToDbValue for Option { + fn to_db_value(&self) -> DbValue { + self.clone().map(|email| email.email()).into() + } +} + +#[cfg(feature = "db")] +impl FromDbValue for Option { + #[cfg(feature = "sqlite")] + fn from_sqlite(value: SqliteValueRef<'_>) -> cot::db::Result + where + Self: Sized, + { + value + .get::>() + .map(|opt_str| opt_str.map(Email::new))? + .transpose() + .map_err(cot::db::DatabaseError::value_decode) + } + + #[cfg(feature = "postgres")] + fn from_postgres(value: PostgresValueRef<'_>) -> cot::db::Result + where + Self: Sized, + { + value + .get::>() + .map(|opt_str| opt_str.map(Email::new))? + .transpose() + .map_err(cot::db::DatabaseError::value_decode) + } + + #[cfg(feature = "mysql")] + fn from_mysql(value: MySqlValueRef<'_>) -> cot::db::Result + where + Self: Sized, + { + value + .get::>() + .map(|opt_str| opt_str.map(Email::new))? + .transpose() + .map_err(cot::db::DatabaseError::value_decode) + } +} + /// Defines the database field type for `Email`. /// /// Emails are stored as strings with a maximum length of 254 characters, diff --git a/cot/src/db/fields.rs b/cot/src/db/fields.rs index 50e8acfd..8c93a554 100644 --- a/cot/src/db/fields.rs +++ b/cot/src/db/fields.rs @@ -300,6 +300,35 @@ impl ToDbValue for Option> { } } +impl FromDbValue for Option> { + #[cfg(feature = "sqlite")] + fn from_sqlite(value: SqliteValueRef<'_>) -> Result { + value + .get::>() + .map(|opt_str| opt_str.map(LimitedString::new))? + .transpose() + .map_err(DatabaseError::value_decode) + } + + #[cfg(feature = "postgres")] + fn from_postgres(value: PostgresValueRef<'_>) -> Result { + value + .get::>() + .map(|opt_str| opt_str.map(LimitedString::new))? + .transpose() + .map_err(DatabaseError::value_decode) + } + + #[cfg(feature = "mysql")] + fn from_mysql(value: MySqlValueRef<'_>) -> Result { + value + .get::>() + .map(|opt_str| opt_str.map(LimitedString::new))? + .transpose() + .map_err(DatabaseError::value_decode) + } +} + impl DatabaseField for ForeignKey { const NULLABLE: bool = T::PrimaryKey::NULLABLE; const TYPE: ColumnType = T::PrimaryKey::TYPE; diff --git a/cot/tests/db.rs b/cot/tests/db.rs index f32d0961..107f48c4 100644 --- a/cot/tests/db.rs +++ b/cot/tests/db.rs @@ -1,6 +1,8 @@ #![cfg(feature = "fake")] #![cfg_attr(miri, ignore)] +use cot::auth::PasswordHash; +use cot::common_types::{Email, Password, Url}; use cot::db::migrations::{Field, Operation}; use cot::db::query::ExprEq; use cot::db::{ @@ -39,6 +41,31 @@ impl Dummy for chrono::WeekdaySet { } } +struct EmailFaker; + +impl Dummy for Email { + fn dummy_with_rng(_: &EmailFaker, rng: &mut R) -> Self { + let username: String = (0..10) + .map(|_| (0x61u8 + (rng.next_u32() % 26) as u8) as char) + .collect(); + let domain: String = (0..10) + .map(|_| (0x61u8 + (rng.next_u32() % 26) as u8) as char) + .collect(); + Email::new(format!("{username}@{domain}.com")).expect("Generated email should be valid") + } +} + +struct UrlFaker; + +impl Dummy for Url { + fn dummy_with_rng(_config: &UrlFaker, rng: &mut R) -> Self { + let domain: String = (0..10) + .map(|_| (0x61u8 + (rng.next_u32() % 26) as u8) as char) + .collect(); + Url::new(format!("https://{domain}.com")).expect("Generated URL should be valid") + } +} + #[derive(Debug, PartialEq)] #[model] struct TestModel { @@ -221,8 +248,17 @@ struct AllFieldsModel { field_blob: Vec, field_option: Option, field_limited_string: LimitedString<10>, + field_option_limited_string: Option>, #[dummy(faker = "WeekdaySetFaker")] field_weekday_set: chrono::WeekdaySet, + #[dummy(faker = "EmailFaker")] + field_email: Email, + #[dummy(faker = "EmailFaker")] + field_option_email: Option, + #[dummy(faker = "UrlFaker")] + field_url: Url, + #[dummy(faker = "UrlFaker")] + field_option_url: Option, } async fn migrate_all_fields_model(db: &Database) { @@ -254,7 +290,13 @@ const CREATE_ALL_FIELDS_MODEL: Operation = Operation::create_model() all_fields_migration_field!(blob, Vec), all_fields_migration_field!(option, Option), all_fields_migration_field!(limited_string, LimitedString<10>), + all_fields_migration_field!(option_limited_string, Option>), all_fields_migration_field!(weekday_set, chrono::WeekdaySet), + all_fields_migration_field!(email, Email), + all_fields_migration_field!(option_email, Option), + all_fields_migration_field!(url, Url), + all_fields_migration_field!(option_url, Option), + all_fields_migration_field!(option_password_hash, Option), ]) .build(); @@ -315,6 +357,52 @@ macro_rules! run_migrations { }; } +#[cot_macros::dbtest] +async fn password_hash_field(db: &TestDatabase) { + #[derive(Debug, Clone)] + #[model] + struct OptionPasswordHashModel { + #[model(primary_key)] + id: Auto, + password: Option, + } + + const CREATE_OPTIONAL_PASSWORD_HASH_MODEL: Operation = Operation::create_model() + .table_name(Identifier::new("cot__option_password_hash_model")) + .fields(&[ + Field::new(Identifier::new("id"), as DatabaseField>::TYPE) + .primary_key() + .auto(), + Field::new( + Identifier::new("password"), + as DatabaseField>::TYPE, + ) + .set_null( as DatabaseField>::NULLABLE), + ]) + .build(); + + run_migrations!(db, CREATE_OPTIONAL_PASSWORD_HASH_MODEL); + + let generated_password: String = Faker.fake(); + let mut with_password = OptionPasswordHashModel { + id: Auto::auto(), + password: Some(PasswordHash::from_password(&Password::new( + &generated_password, + ))), + }; + with_password.save(&**db).await.unwrap(); + + let mut without_password = OptionPasswordHashModel { + id: Auto::auto(), + password: None, + }; + without_password.save(&**db).await.unwrap(); + + let models = OptionPasswordHashModel::objects().all(&**db).await.unwrap(); + + assert_eq!(models.len(), 2); +} + #[cot_macros::dbtest] async fn foreign_keys(db: &mut TestDatabase) { #[derive(Debug, Clone, PartialEq)]