diff options
Diffstat (limited to 'src/server/store.rs')
-rw-r--r-- | src/server/store.rs | 330 |
1 files changed, 330 insertions, 0 deletions
diff --git a/src/server/store.rs b/src/server/store.rs new file mode 100644 index 0000000..98c1bcc --- /dev/null +++ b/src/server/store.rs @@ -0,0 +1,330 @@ +use std::{ + collections::HashMap, + path::PathBuf, + sync::Arc, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; + +use argon2::{ + password_hash::{rand_core::OsRng, Encoding, PasswordHashString, SaltString}, + Argon2, PasswordHash, PasswordHasher, +}; +use nanoid::nanoid; +use serde::{de::Unexpected, Deserialize, Serialize}; +use serde_json::{Map, Value}; +use tap::{Pipe, Tap}; +use tokio::sync::{RwLock, RwLockReadGuard}; + +#[derive(Clone)] +pub struct Store(Arc<RwLock<StoreInner>>); + +impl Store { + pub async fn load(path: PathBuf) -> eyre::Result<Self> { + let mut inner = StoreInner::new(path.clone()); + if path.try_exists()? { + let mut map: Map<String, Value> = + { std::fs::File::open(&path)?.pipe(serde_json::from_reader)? }; + let accounts: Vec<Account> = + serde_json::from_value(map.remove("accounts").unwrap_or_default())?; + let apps: Vec<App> = serde_json::from_value(map.remove("apps").unwrap_or_default())?; + for mut account in accounts { + account.tokens = account + .tokens + .into_iter() + .filter(|token| token.expires > SystemTime::now()) + .collect(); + for token in &account.tokens { + inner.token_map.insert( + token.value.clone(), + (account.name.clone(), token.expires.clone()), + ); + } + inner.accounts.insert(account.name.clone(), account); + } + for app in apps { + inner.apps.insert(app.name.clone(), app); + } + inner.invites = serde_json::from_value(map.remove("invites").unwrap_or_default())?; + inner.invites = inner + .invites + .into_iter() + .filter(|invite| invite.expires > SystemTime::now()) + .collect(); + } else { + Self::save(&mut inner).await?; + } + Ok(Store(Arc::new(RwLock::new(inner)))) + } + + async fn save(inner: &mut StoreInner) -> std::io::Result<()> { + let mut map: Map<String, Value> = Map::new(); + map.insert( + "accounts".into(), + serde_json::to_value(inner.accounts.values().collect::<Vec<_>>()).unwrap(), + ); + map.insert( + "apps".into(), + serde_json::to_value(inner.apps.values().collect::<Vec<_>>()).unwrap(), + ); + map.insert( + "invites".into(), + serde_json::to_value(&inner.invites).unwrap(), + ); + let data = serde_json::to_vec_pretty(&map).unwrap(); + let mut temp = inner.path.clone(); + temp.set_file_name( + temp.file_name() + .unwrap_or_default() + .to_os_string() + .tap_mut(|name| name.push(".tmp")), + ); + tokio::fs::write(&temp, data).await?; + tokio::fs::rename(temp, &inner.path).await?; + Ok(()) + } + + pub async fn get_account(&self, name: &str) -> Option<RwLockReadGuard<'_, Account>> { + let guard = self.0.read().await; + RwLockReadGuard::try_map(guard, |guard| guard.accounts.get(name)).ok() + } + + pub async fn create_account(&self, name: &str, password: &str) -> std::io::Result<bool> { + let hash = Argon2::default() + .hash_password(password.as_bytes(), &SaltString::generate(&mut OsRng)) + .unwrap() + .pipe(|hash| hash.serialize()) + .pipe(OwnedPasswordHash::from); + let mut guard = self.0.write().await; + if guard.accounts.get(name).is_some() { + return Ok(false); + } + guard.accounts.insert( + name.to_string(), + Account { + name: name.to_string(), + password: hash, + tokens: Default::default(), + scopes: Default::default(), + }, + ); + Self::save(&mut guard).await?; + Ok(true) + } + + pub async fn delete_account(&self, name: &str) -> std::io::Result<()> { + let mut guard = self.0.write().await; + if let Some(account) = guard.accounts.remove(name) { + for token in account.tokens { + guard.token_map.remove(&token.value); + } + Self::save(&mut guard).await + } else { + Ok(()) + } + } + + pub async fn update_account( + &self, + name: &str, + with: impl FnOnce(&mut Account), + ) -> std::io::Result<bool> { + let mut guard = self.0.write().await; + if let Some(account) = guard.accounts.get_mut(name) { + with(account); + Self::save(&mut guard).await.map(|_| true) + } else { + Ok(false) + } + } + + pub async fn list_accounts(&self) -> Vec<String> { + self.0 + .read() + .await + .accounts + .values() + .map(|account| account.name.clone()) + .collect() + } + + pub async fn create_token(&self, name: &str) -> std::io::Result<Option<String>> { + let mut guard = self.0.write().await; + let token = nanoid!(32); + if let Some(account) = guard.accounts.get_mut(name) { + let expires = SystemTime::now() + Duration::from_secs(60 * 60 * 24 * 30); + account.tokens.push(ExpiringValue { + value: token.clone(), + expires: expires.clone(), + }); + guard + .token_map + .insert(token.clone(), (name.to_string(), expires)); + Self::save(&mut guard).await?; + Ok(Some(token)) + } else { + Ok(None) + } + } + + pub async fn check_token(&self, token: &str) -> Option<String> { + let guard = self.0.read().await; + let Some((name, expires)) = guard.token_map.get(token) else { + return None; + }; + if *expires < SystemTime::now() { + return None; + } + Some(name.clone()) + } + + pub async fn create_invite(&self) -> std::io::Result<String> { + let mut guard = self.0.write().await; + let invite = nanoid!(32); + let expires = SystemTime::now() + Duration::from_secs(60 * 60 * 24 * 14); + guard.invites.push(ExpiringValue { + value: invite.clone(), + expires, + }); + Self::save(&mut guard).await?; + Ok(invite) + } + + pub async fn check_invite(&self, invite: &str) -> bool { + let guard = self.0.read().await; + let now = SystemTime::now(); + guard + .invites + .iter() + .any(|check| check.expires > now && check.value == invite) + } + + pub async fn use_invite(&self, invite: &str) -> std::io::Result<bool> { + let mut guard = self.0.write().await; + let now = SystemTime::now(); + if let Some((index, _)) = guard + .invites + .iter() + .enumerate() + .find(|(_, check)| check.expires > now && check.value == invite) + { + guard.invites.swap_remove(index); + Self::save(&mut guard).await?; + Ok(true) + } else { + Ok(false) + } + } +} + +pub struct StoreInner { + path: PathBuf, + accounts: HashMap<String, Account>, + apps: HashMap<String, App>, + invites: Vec<ExpiringValue>, + token_map: HashMap<String, (String, SystemTime)>, +} + +impl StoreInner { + fn new(at: PathBuf) -> Self { + Self { + path: at, + accounts: Default::default(), + apps: Default::default(), + invites: Default::default(), + token_map: Default::default(), + } + } +} + +#[derive(serde::Serialize, serde::Deserialize)] +pub struct Account { + pub name: String, + pub password: OwnedPasswordHash, + pub tokens: Vec<ExpiringValue>, + pub scopes: Vec<String>, +} + +#[ouroboros::self_referencing] +pub struct OwnedPasswordHash { + owned: PasswordHashString, + #[borrows(owned)] + #[not_covariant] + parsed: PasswordHash<'this>, +} + +impl OwnedPasswordHash { + pub fn from(inner: PasswordHashString) -> Self { + OwnedPasswordHash::new(inner, |inner| inner.password_hash()) + } + + pub fn parsed(&self) -> PasswordHash { + self.with_parsed(|x| x.clone()) + } +} + +impl Serialize for OwnedPasswordHash { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + self.with_owned(|owned| owned.as_str().serialize(serializer)) + } +} + +impl<'de> Deserialize<'de> for OwnedPasswordHash { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + Ok(String::deserialize(deserializer)? + .pipe(|hash| { + PasswordHashString::parse(&hash, Encoding::B64).map_err(|_| { + <D::Error as serde::de::Error>::invalid_value( + Unexpected::Str(&hash), + &"valid password hash", + ) + }) + })? + .pipe(OwnedPasswordHash::from)) + } +} + +pub struct ExpiringValue { + pub value: String, + pub expires: SystemTime, +} + +impl Serialize for ExpiringValue { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + <(&str, u64)>::serialize( + &( + &self.value, + (self.expires.duration_since(SystemTime::UNIX_EPOCH)) + .unwrap() + .as_secs(), + ), + serializer, + ) + } +} + +impl<'de> Deserialize<'de> for ExpiringValue { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + <(String, u64)>::deserialize(deserializer).map(|(token, unix)| ExpiringValue { + value: token, + expires: UNIX_EPOCH + Duration::from_secs(unix), + }) + } +} + +#[derive(serde::Serialize, serde::Deserialize)] +pub struct App { + name: String, + secret: String, +} |