use std::{ collections::HashMap, path::PathBuf, sync::Arc, time::{Duration, SystemTime, UNIX_EPOCH}, }; use argon2::{ password_hash::{rand_core::OsRng, Encoding, PasswordHashString, SaltString}, Argon2, PasswordHash, PasswordHasher, }; use nanoid::nanoid; use serde::{de::Unexpected, Deserialize, Serialize}; use serde_json::{Map, Value}; use tap::{Pipe, Tap}; use tokio::sync::{RwLock, RwLockReadGuard}; #[derive(Clone)] pub struct Store(Arc>); impl Store { pub async fn load(path: PathBuf) -> eyre::Result { let mut inner = StoreInner::new(path.clone()); if path.try_exists()? { let mut map: Map = { std::fs::File::open(&path)?.pipe(serde_json::from_reader)? }; let accounts: Vec = serde_json::from_value(map.remove("accounts").unwrap_or_default())?; let apps: Vec = serde_json::from_value(map.remove("apps").unwrap_or_default())?; for mut account in accounts { account.tokens = account .tokens .into_iter() .filter(|token| token.expires > SystemTime::now()) .collect(); for token in &account.tokens { inner.token_map.insert( token.value.clone(), (account.name.clone(), token.expires.clone()), ); } inner.accounts.insert(account.name.clone(), account); } for app in apps { inner.apps.insert(app.name.clone(), app); } inner.invites = serde_json::from_value(map.remove("invites").unwrap_or_default())?; inner.invites = inner .invites .into_iter() .filter(|invite| invite.expires > SystemTime::now()) .collect(); } else { Self::save(&mut inner).await?; } Ok(Store(Arc::new(RwLock::new(inner)))) } async fn save(inner: &mut StoreInner) -> std::io::Result<()> { let mut map: Map = Map::new(); map.insert( "accounts".into(), serde_json::to_value(inner.accounts.values().collect::>()).unwrap(), ); map.insert( "apps".into(), serde_json::to_value(inner.apps.values().collect::>()).unwrap(), ); map.insert( "invites".into(), serde_json::to_value(&inner.invites).unwrap(), ); let data = serde_json::to_vec_pretty(&map).unwrap(); let mut temp = inner.path.clone(); temp.set_file_name( temp.file_name() .unwrap_or_default() .to_os_string() .tap_mut(|name| name.push(".tmp")), ); tokio::fs::write(&temp, data).await?; tokio::fs::rename(temp, &inner.path).await?; Ok(()) } pub async fn get_account(&self, name: &str) -> Option> { let guard = self.0.read().await; RwLockReadGuard::try_map(guard, |guard| guard.accounts.get(name)).ok() } pub async fn create_account(&self, name: &str, password: &str) -> std::io::Result { let hash = Argon2::default() .hash_password(password.as_bytes(), &SaltString::generate(&mut OsRng)) .unwrap() .pipe(|hash| hash.serialize()) .pipe(OwnedPasswordHash::from); let mut guard = self.0.write().await; if guard.accounts.get(name).is_some() { return Ok(false); } guard.accounts.insert( name.to_string(), Account { name: name.to_string(), password: hash, tokens: Default::default(), scopes: Default::default(), }, ); Self::save(&mut guard).await?; Ok(true) } pub async fn delete_account(&self, name: &str) -> std::io::Result<()> { let mut guard = self.0.write().await; if let Some(account) = guard.accounts.remove(name) { for token in account.tokens { guard.token_map.remove(&token.value); } Self::save(&mut guard).await } else { Ok(()) } } pub async fn update_account( &self, name: &str, with: impl FnOnce(&mut Account), ) -> std::io::Result { let mut guard = self.0.write().await; if let Some(account) = guard.accounts.get_mut(name) { with(account); Self::save(&mut guard).await.map(|_| true) } else { Ok(false) } } pub async fn list_accounts(&self) -> Vec { self.0 .read() .await .accounts .values() .map(|account| account.name.clone()) .collect() } pub async fn create_token(&self, name: &str) -> std::io::Result> { let mut guard = self.0.write().await; let token = nanoid!(32); if let Some(account) = guard.accounts.get_mut(name) { let expires = SystemTime::now() + Duration::from_secs(60 * 60 * 24 * 30); account.tokens.push(ExpiringValue { value: token.clone(), expires: expires.clone(), }); guard .token_map .insert(token.clone(), (name.to_string(), expires)); Self::save(&mut guard).await?; Ok(Some(token)) } else { Ok(None) } } pub async fn check_token(&self, token: &str) -> Option<(String, SystemTime)> { let guard = self.0.read().await; let Some((name, expires)) = guard.token_map.get(token) else { return None; }; if *expires < SystemTime::now() { return None; } Some((name.clone(), expires.clone())) } pub async fn create_invite(&self) -> std::io::Result { let mut guard = self.0.write().await; let invite = nanoid!(32); let expires = SystemTime::now() + Duration::from_secs(60 * 60 * 24 * 14); guard.invites.push(ExpiringValue { value: invite.clone(), expires, }); Self::save(&mut guard).await?; Ok(invite) } pub async fn check_invite(&self, invite: &str) -> bool { let guard = self.0.read().await; let now = SystemTime::now(); guard .invites .iter() .any(|check| check.expires > now && check.value == invite) } pub async fn use_invite(&self, invite: &str) -> std::io::Result { let mut guard = self.0.write().await; let now = SystemTime::now(); if let Some((index, _)) = guard .invites .iter() .enumerate() .find(|(_, check)| check.expires > now && check.value == invite) { guard.invites.swap_remove(index); Self::save(&mut guard).await?; Ok(true) } else { Ok(false) } } } pub struct StoreInner { path: PathBuf, accounts: HashMap, apps: HashMap, invites: Vec, token_map: HashMap, } impl StoreInner { fn new(at: PathBuf) -> Self { Self { path: at, accounts: Default::default(), apps: Default::default(), invites: Default::default(), token_map: Default::default(), } } } #[derive(serde::Serialize, serde::Deserialize)] pub struct Account { pub name: String, pub password: OwnedPasswordHash, pub tokens: Vec, pub scopes: Vec, } #[ouroboros::self_referencing] pub struct OwnedPasswordHash { owned: PasswordHashString, #[borrows(owned)] #[not_covariant] parsed: PasswordHash<'this>, } impl OwnedPasswordHash { pub fn from(inner: PasswordHashString) -> Self { OwnedPasswordHash::new(inner, |inner| inner.password_hash()) } pub fn parsed(&self) -> PasswordHash { self.with_parsed(|x| x.clone()) } } impl Serialize for OwnedPasswordHash { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { self.with_owned(|owned| owned.as_str().serialize(serializer)) } } impl<'de> Deserialize<'de> for OwnedPasswordHash { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { Ok(String::deserialize(deserializer)? .pipe(|hash| { PasswordHashString::parse(&hash, Encoding::B64).map_err(|_| { ::invalid_value( Unexpected::Str(&hash), &"valid password hash", ) }) })? .pipe(OwnedPasswordHash::from)) } } pub struct ExpiringValue { pub value: String, pub expires: SystemTime, } impl Serialize for ExpiringValue { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { <(&str, u64)>::serialize( &( &self.value, (self.expires.duration_since(SystemTime::UNIX_EPOCH)) .unwrap() .as_secs(), ), serializer, ) } } impl<'de> Deserialize<'de> for ExpiringValue { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { <(String, u64)>::deserialize(deserializer).map(|(token, unix)| ExpiringValue { value: token, expires: UNIX_EPOCH + Duration::from_secs(unix), }) } } #[derive(serde::Serialize, serde::Deserialize)] pub struct App { name: String, secret: String, }